Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2012-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "CLTool.h"
23 : #include "core/CLToolRegister.h"
24 : #include "tools/Tools.h"
25 : #include "config/Config.h"
26 : #include "core/PlumedMain.h"
27 : #include "core/ActionSet.h"
28 : #include "core/ActionRegister.h"
29 : #include "core/ActionShortcut.h"
30 : #include "core/ActionToPutData.h"
31 : #include "core/ActionWithVirtualAtom.h"
32 : #include "core/ActionWithVector.h"
33 : #include <cstdio>
34 : #include <string>
35 : #include <iostream>
36 :
37 : namespace PLMD {
38 : namespace cltools {
39 :
40 : //+PLUMEDOC TOOLS show_graph
41 : /*
42 : show_graph is a tool that takes a plumed input and generates a graph showing how
43 : data flows through the action set involved.
44 :
45 : If this tool is invoked without the --force keyword then the way data is passed through the code during the forward pass
46 : through the action is shown.
47 :
48 : When the --force keyword is used then the way forces are passed from biases through actions is shown.
49 :
50 : \par Examples
51 :
52 : The following generates the mermaid file for the input in plumed.dat
53 : \verbatim
54 : plumed show_graph --plumed plumed.dat
55 : \endverbatim
56 :
57 : */
58 : //+ENDPLUMEDOC
59 :
60 : class ShowGraph :
61 : public CLTool {
62 : public:
63 : static void registerKeywords( Keywords& keys );
64 : explicit ShowGraph(const CLToolOptions& co );
65 : int main(FILE* in, FILE*out,Communicator& pc);
66 4 : std::string description()const {
67 4 : return "generate a graph showing how data flows through a PLUMED action set";
68 : }
69 : std::string getLabel(const Action* a, const bool& amp=false);
70 : std::string getLabel(const std::string& s, const bool& amp=false );
71 : void printStyle( const unsigned& linkcount, const Value* v, OFile& ofile );
72 : void printArgumentConnections( const ActionWithArguments* a, unsigned& linkcount, const bool& force, OFile& ofile );
73 : void printAtomConnections( const ActionAtomistic* a, unsigned& linkcount, const bool& force, OFile& ofile );
74 : void drawActionWithVectorNode( OFile& ofile, PlumedMain& p, Action* ag, const std::vector<std::string>& mychain, std::vector<bool>& printed );
75 : };
76 :
77 16338 : PLUMED_REGISTER_CLTOOL(ShowGraph,"show_graph")
78 :
79 5442 : void ShowGraph::registerKeywords( Keywords& keys ) {
80 5442 : CLTool::registerKeywords( keys );
81 10884 : keys.add("compulsory","--plumed","plumed.dat","the plumed input that we are generating the graph for");
82 10884 : keys.add("compulsory","--out","graph.md","the dot file containing the graph that has been generated");
83 10884 : keys.addFlag("--force",false,"print a graph that shows how forces are passed through the actions");
84 5442 : }
85 :
86 12 : ShowGraph::ShowGraph(const CLToolOptions& co ):
87 12 : CLTool(co) {
88 12 : inputdata=commandline;
89 12 : }
90 :
91 377 : std::string ShowGraph::getLabel(const Action* a, const bool& amp) {
92 377 : return getLabel( a->getLabel(), amp );
93 : }
94 :
95 453 : std::string ShowGraph::getLabel( const std::string& s, const bool& amp ) {
96 453 : if( s.find("@")==std::string::npos ) {
97 405 : return s;
98 : }
99 48 : std::size_t p=s.find_first_of("@");
100 48 : if( amp ) {
101 30 : return "#64;" + s.substr(p+1);
102 : }
103 33 : return s.substr(p+1);
104 : }
105 :
106 85 : void ShowGraph::printStyle( const unsigned& linkcount, const Value* v, OFile& ofile ) {
107 85 : if( v->getRank()>0 && v->hasDerivatives() ) {
108 0 : ofile.printf("linkStyle %d stroke:green,color:green;\n", linkcount);
109 85 : } else if( v->getRank()==1 ) {
110 33 : ofile.printf("linkStyle %d stroke:blue,color:blue;\n", linkcount);
111 52 : } else if ( v->getRank()==2 ) {
112 30 : ofile.printf("linkStyle %d stroke:red,color:red;\n", linkcount);
113 : }
114 85 : }
115 :
116 63 : void ShowGraph::printArgumentConnections( const ActionWithArguments* a, unsigned& linkcount, const bool& force, OFile& ofile ) {
117 63 : if( !a ) {
118 : return;
119 : }
120 101 : for(const auto & v : a->getArguments() ) {
121 55 : if( force && v->forcesWereAdded() ) {
122 28 : ofile.printf("%s -- %s --> %s\n", getLabel(a).c_str(), v->getName().c_str(), getLabel(v->getPntrToAction()).c_str() );
123 14 : printStyle( linkcount, v, ofile );
124 14 : linkcount++;
125 41 : } else if( !force ) {
126 66 : ofile.printf("%s -- %s --> %s\n", getLabel(v->getPntrToAction()).c_str(),v->getName().c_str(),getLabel(a).c_str() );
127 33 : printStyle( linkcount, v, ofile );
128 33 : linkcount++;
129 : }
130 : }
131 : }
132 :
133 55 : void ShowGraph::printAtomConnections( const ActionAtomistic* a, unsigned& linkcount, const bool& force, OFile& ofile ) {
134 55 : if( !a ) {
135 : return;
136 : }
137 179 : for(const auto & d : a->getDependencies() ) {
138 138 : ActionToPutData* dp=dynamic_cast<ActionToPutData*>(d);
139 138 : if( dp && dp->getLabel()=="posx" ) {
140 18 : if( force && (dp->copyOutput(0))->forcesWereAdded() ) {
141 8 : ofile.printf("%s --> MD\n", getLabel(a).c_str() );
142 8 : ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount);
143 8 : linkcount++;
144 : } else {
145 10 : ofile.printf("MD --> %s\n", getLabel(a).c_str() );
146 10 : ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount);
147 10 : linkcount++;
148 : }
149 120 : } else if( dp && dp->getLabel()!="posy" && dp->getLabel()!="posz" && dp->getLabel()!="Masses" && dp->getLabel()!="Charges" ) {
150 21 : if( force && (dp->copyOutput(0))->forcesWereAdded() ) {
151 18 : ofile.printf("%s -- %s --> %s\n",getLabel(a).c_str(), getLabel(d).c_str(), getLabel(d).c_str() );
152 9 : printStyle( linkcount, dp->copyOutput(0), ofile );
153 9 : linkcount++;
154 : } else {
155 24 : ofile.printf("%s -- %s --> %s\n", getLabel(d).c_str(),getLabel(d).c_str(),getLabel(a).c_str() );
156 12 : printStyle( linkcount, dp->copyOutput(0), ofile );
157 12 : linkcount++;
158 : }
159 21 : continue;
160 : }
161 117 : ActionWithVirtualAtom* dv=dynamic_cast<ActionWithVirtualAtom*>(d);
162 117 : if( dv ) {
163 4 : if( force && (dv->copyOutput(0))->forcesWereAdded() ) {
164 2 : ofile.printf("%s -- %s --> %s\n", getLabel(a).c_str(),getLabel(d).c_str(),getLabel(d).c_str() );
165 1 : ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount);
166 1 : linkcount++;
167 : } else {
168 6 : ofile.printf("%s -- %s --> %s\n", getLabel(d).c_str(),getLabel(d).c_str(),getLabel(a).c_str() );
169 3 : ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount);
170 3 : linkcount++;
171 : }
172 : }
173 : }
174 : }
175 :
176 30 : void ShowGraph::drawActionWithVectorNode( OFile& ofile, PlumedMain& p, Action* ag, const std::vector<std::string>& mychain, std::vector<bool>& printed ) {
177 30 : ActionWithVector* agg=dynamic_cast<ActionWithVector*>(ag);
178 : std::vector<std::string> matchain;
179 30 : agg->getAllActionLabelsInMatrixChain( matchain );
180 30 : if( matchain.size()>0 ) {
181 16 : ofile.printf("subgraph sub%s_mat [%s]\n",getLabel(agg).c_str(), getLabel(agg).c_str());
182 24 : for(unsigned j=0; j<matchain.size(); ++j ) {
183 16 : Action* agm=p.getActionSet().selectWithLabel<Action*>(matchain[j]);
184 60 : for(unsigned k=0; k<mychain.size(); ++k ) {
185 60 : if( mychain[k]==matchain[j] ) {
186 : printed[k]=true;
187 16 : break;
188 : }
189 : }
190 32 : ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(matchain[j]).c_str(), getLabel(matchain[j],true).c_str(), agm->writeInGraph().c_str() );
191 : }
192 8 : ofile.printf("end\n");
193 16 : ofile.printf("style sub%s_mat fill:lightblue\n",getLabel(ag).c_str());
194 : } else {
195 44 : ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(ag->getLabel()).c_str(), getLabel(ag->getLabel(),true).c_str(), ag->writeInGraph().c_str() );
196 : }
197 30 : }
198 :
199 8 : int ShowGraph::main(FILE* in, FILE*out,Communicator& pc) {
200 :
201 : std::string inpt;
202 16 : parse("--plumed",inpt);
203 : std::string outp;
204 8 : parse("--out",outp);
205 : bool forces;
206 8 : parseFlag("--force",forces);
207 :
208 : // Create a plumed main object and initilize
209 8 : PlumedMain p;
210 8 : int rr=sizeof(double);
211 8 : p.cmd("setRealPrecision",&rr);
212 8 : double lunit=1.0;
213 8 : p.cmd("setMDLengthUnits",&lunit);
214 8 : double cunit=1.0;
215 8 : p.cmd("setMDChargeUnits",&cunit);
216 8 : double munit=1.0;
217 8 : p.cmd("setMDMassUnits",&munit);
218 8 : p.cmd("setPlumedDat",inpt.c_str());
219 8 : p.cmd("setLog",out);
220 8 : int natoms=1000000;
221 8 : p.cmd("setNatoms",&natoms);
222 8 : p.cmd("init");
223 :
224 8 : unsigned linkcount=0;
225 8 : OFile ofile;
226 8 : ofile.open(outp);
227 8 : if( forces ) {
228 : unsigned step=1;
229 4 : p.cmd("setStep",step);
230 4 : p.cmd("prepareCalc");
231 4 : ofile.printf("flowchart BT \n");
232 : std::vector<std::string> drawn_nodes;
233 : std::set<std::string> atom_force_set;
234 103 : for(auto pp=p.getActionSet().rbegin(); pp!=p.getActionSet().rend(); ++pp) {
235 : const auto & a(pp->get());
236 534 : if( a->getName()=="DOMAIN_DECOMPOSITION" || a->getLabel()=="posx" || a->getLabel()=="posy" || a->getLabel()=="posz" || a->getLabel()=="Masses" || a->getLabel()=="Charges" ) {
237 24 : continue;
238 : }
239 :
240 75 : if(a->isActive()) {
241 44 : ActionToPutData* ap=dynamic_cast<ActionToPutData*>(a);
242 44 : if( ap ) {
243 8 : ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
244 4 : continue;
245 : }
246 40 : ActionWithValue* av=dynamic_cast<ActionWithValue*>(a);
247 40 : if( !av ) {
248 5 : continue ;
249 : }
250 : // Now apply the force if there is one
251 35 : a->apply();
252 : bool hasforce=false;
253 67 : for(int i=0; i<av->getNumberOfComponents(); ++i) {
254 42 : if( (av->copyOutput(i))->forcesWereAdded() ) {
255 : hasforce=true;
256 : break;
257 : }
258 : }
259 : //Check if there are forces here
260 35 : ActionWithArguments* aaa=dynamic_cast<ActionWithArguments*>(a);
261 35 : if( aaa ) {
262 46 : for(const auto & v : aaa->getArguments() ) {
263 30 : if( v->forcesWereAdded() ) {
264 : hasforce=true;
265 : break;
266 : }
267 : }
268 : }
269 35 : if( !hasforce ) {
270 14 : continue;
271 : }
272 21 : ActionWithVector* avec=dynamic_cast<ActionWithVector*>(a);
273 21 : if( avec ) {
274 8 : ActionWithVector* head=avec->getFirstActionInChain();
275 : std::vector<std::string> mychain;
276 8 : head->getAllActionLabelsInChain( mychain );
277 8 : std::vector<bool> printed(mychain.size(),false);
278 16 : ofile.printf("subgraph sub%s [%s]\n",getLabel(head).c_str(),getLabel(head).c_str());
279 70 : for(unsigned i=0; i<mychain.size(); ++i) {
280 : bool drawn=false;
281 314 : for(unsigned j=0; j<drawn_nodes.size(); ++j ) {
282 294 : if( drawn_nodes[j]==mychain[i] ) {
283 : drawn=true;
284 : break;
285 : }
286 : }
287 62 : if( drawn ) {
288 42 : continue;
289 : }
290 20 : ActionWithVector* ag=p.getActionSet().selectWithLabel<ActionWithVector*>(mychain[i]);
291 20 : plumed_assert( ag );
292 20 : drawn_nodes.push_back( mychain[i] );
293 20 : if( !printed[i] ) {
294 16 : drawActionWithVectorNode( ofile, p, ag, mychain, printed );
295 : printed[i]=true;
296 : }
297 41 : for(const auto & v : ag->getArguments() ) {
298 : bool chain_conn=false;
299 109 : for(unsigned j=0; j<mychain.size(); ++j) {
300 105 : if( (v->getPntrToAction())->getLabel()==mychain[j] ) {
301 : chain_conn=true;
302 : break;
303 : }
304 : }
305 21 : if( !chain_conn ) {
306 4 : continue;
307 : }
308 34 : ofile.printf("%s -. %s .-> %s\n", getLabel(v->getPntrToAction()).c_str(),v->getName().c_str(),getLabel(ag).c_str() );
309 17 : printStyle( linkcount, v, ofile );
310 17 : linkcount++;
311 : }
312 : }
313 8 : ofile.printf("end\n");
314 8 : if( avec!=head ) {
315 70 : for(unsigned i=0; i<mychain.size(); ++i) {
316 62 : ActionWithVector* c = p.getActionSet().selectWithLabel<ActionWithVector*>( mychain[i] );
317 62 : plumed_assert(c);
318 62 : if( c->getNumberOfAtoms()>0 || c->hasStoredArguments() ) {
319 60 : for(unsigned j=0; j<avec->getNumberOfComponents(); ++j ) {
320 30 : if( avec->copyOutput(j)->getRank()>0 ) {
321 20 : continue;
322 : }
323 20 : ofile.printf("%s == %s ==> %s\n", getLabel(avec).c_str(), avec->copyOutput(j)->getName().c_str(), getLabel(c).c_str() );
324 10 : linkcount++;
325 : }
326 30 : if( c->getNumberOfAtoms()>0 ) {
327 16 : atom_force_set.insert( c->getLabel() );
328 : }
329 : }
330 : }
331 : }
332 8 : } else {
333 : // Print out the node if we have force on it
334 26 : ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
335 : }
336 : // Check where this force is being added
337 21 : printArgumentConnections( aaa, linkcount, true, ofile );
338 : }
339 : }
340 : // Now draw connections from action atomistic to relevant actions
341 4 : std::vector<ActionAtomistic*> all_atoms = p.getActionSet().select<ActionAtomistic*>();
342 33 : for(const auto & at : all_atoms ) {
343 29 : ActionWithValue* av=dynamic_cast<ActionWithValue*>(at);
344 : bool hasforce=false;
345 29 : if( av ) {
346 44 : for(unsigned i=0; i<av->getNumberOfComponents(); ++i ) {
347 26 : if( av->copyOutput(i)->forcesWereAdded() ) {
348 8 : printAtomConnections( at, linkcount, true, ofile );
349 8 : atom_force_set.erase( av->getLabel() );
350 : break;
351 : }
352 : }
353 : }
354 : }
355 9 : for(const auto & l : atom_force_set ) {
356 5 : ActionAtomistic* at = p.getActionSet().selectWithLabel<ActionAtomistic*>(l);
357 5 : plumed_assert(at);
358 5 : printAtomConnections( at, linkcount, true, ofile );
359 : }
360 4 : ofile.printf("MD(positions from MD)\n");
361 : return 0;
362 4 : }
363 :
364 4 : ofile.printf("flowchart TB \n");
365 4 : ofile.printf("MD(positions from MD)\n");
366 98 : for(const auto & aa : p.getActionSet() ) {
367 : Action* a(aa.get());
368 504 : if( a->getName()=="DOMAIN_DECOMPOSITION" || a->getLabel()=="posx" || a->getLabel()=="posy" || a->getLabel()=="posz" || a->getLabel()=="Masses" || a->getLabel()=="Charges" ) {
369 24 : continue;
370 : }
371 70 : ActionToPutData* ap=dynamic_cast<ActionToPutData*>(a);
372 70 : if( ap ) {
373 8 : ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
374 4 : continue;
375 : }
376 66 : ActionShortcut* as=dynamic_cast<ActionShortcut*>(a);
377 66 : if( as ) {
378 24 : continue ;
379 : }
380 42 : ActionWithValue* av=dynamic_cast<ActionWithValue*>(a);
381 42 : ActionWithArguments* aaa=dynamic_cast<ActionWithArguments*>(a);
382 42 : ActionAtomistic* at=dynamic_cast<ActionAtomistic*>(a);
383 42 : ActionWithVector* avec=dynamic_cast<ActionWithVector*>(a);
384 : // Print out the connections between nodes
385 42 : printAtomConnections( at, linkcount, false, ofile );
386 42 : printArgumentConnections( aaa, linkcount, false, ofile );
387 : // Print out the nodes
388 42 : if( avec && !avec->actionInChain() ) {
389 6 : ofile.printf("subgraph sub%s [%s]\n",getLabel(a).c_str(),getLabel(a).c_str());
390 : std::vector<std::string> mychain;
391 3 : avec->getAllActionLabelsInChain( mychain );
392 3 : std::vector<bool> printed(mychain.size(),false);
393 21 : for(unsigned i=0; i<mychain.size(); ++i) {
394 18 : Action* ag=p.getActionSet().selectWithLabel<Action*>(mychain[i]);
395 18 : if( !printed[i] ) {
396 14 : drawActionWithVectorNode( ofile, p, ag, mychain, printed );
397 : printed[i]=true;
398 : }
399 : }
400 3 : ofile.printf("end\n");
401 42 : } else if( !av ) {
402 22 : ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
403 28 : } else if( !avec ) {
404 26 : ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
405 : }
406 : }
407 4 : ofile.close();
408 :
409 : return 0;
410 8 : }
411 :
412 : } // End of namespace
413 : }
|