LCOV - code coverage report
Current view: top level - cltools - ShowGraph.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 217 218 99.5 %
Date: 2025-11-25 13:55:50 Functions: 13 13 100.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2012-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "CLTool.h"
      23             : #include "core/CLToolRegister.h"
      24             : #include "tools/Tools.h"
      25             : #include "config/Config.h"
      26             : #include "core/PlumedMain.h"
      27             : #include "core/ActionSet.h"
      28             : #include "core/ActionRegister.h"
      29             : #include "core/ActionShortcut.h"
      30             : #include "core/ActionToPutData.h"
      31             : #include "core/ActionWithVirtualAtom.h"
      32             : #include "core/ActionWithVector.h"
      33             : #include <cstdio>
      34             : #include <string>
      35             : #include <iostream>
      36             : 
      37             : namespace PLMD {
      38             : namespace cltools {
      39             : 
      40             : //+PLUMEDOC TOOLS show_graph
      41             : /*
      42             : show_graph is a tool that takes a plumed input and generates a graph showing how
      43             : data flows through the action set involved.
      44             : 
      45             : If this tool is invoked without the --force keyword then the way data is passed through the code during the forward pass
      46             : through the action is shown.
      47             : 
      48             : When the --force keyword is used then the way forces are passed from biases through actions is shown.
      49             : 
      50             : \par Examples
      51             : 
      52             : The following generates the mermaid file for the input in plumed.dat
      53             : \verbatim
      54             : plumed show_graph --plumed plumed.dat
      55             : \endverbatim
      56             : 
      57             : */
      58             : //+ENDPLUMEDOC
      59             : 
      60             : class ShowGraph :
      61             :   public CLTool {
      62             : public:
      63             :   static void registerKeywords( Keywords& keys );
      64             :   explicit ShowGraph(const CLToolOptions& co );
      65             :   int main(FILE* in, FILE*out,Communicator& pc);
      66           4 :   std::string description()const {
      67           4 :     return "generate a graph showing how data flows through a PLUMED action set";
      68             :   }
      69             :   std::string getLabel(const Action* a, const bool& amp=false);
      70             :   std::string getLabel(const std::string& s, const bool& amp=false );
      71             :   void printStyle( const unsigned& linkcount, const Value* v, OFile& ofile );
      72             :   void printArgumentConnections( const ActionWithArguments* a, unsigned& linkcount, const bool& force, OFile& ofile );
      73             :   void printAtomConnections( const ActionAtomistic* a, unsigned& linkcount, const bool& force, OFile& ofile );
      74             :   void drawActionWithVectorNode( OFile& ofile, PlumedMain& p, Action* ag, const std::vector<std::string>& mychain, std::vector<bool>& printed );
      75             : };
      76             : 
      77       16338 : PLUMED_REGISTER_CLTOOL(ShowGraph,"show_graph")
      78             : 
      79        5442 : void ShowGraph::registerKeywords( Keywords& keys ) {
      80        5442 :   CLTool::registerKeywords( keys );
      81       10884 :   keys.add("compulsory","--plumed","plumed.dat","the plumed input that we are generating the graph for");
      82       10884 :   keys.add("compulsory","--out","graph.md","the dot file containing the graph that has been generated");
      83       10884 :   keys.addFlag("--force",false,"print a graph that shows how forces are passed through the actions");
      84        5442 : }
      85             : 
      86          12 : ShowGraph::ShowGraph(const CLToolOptions& co ):
      87          12 :   CLTool(co) {
      88          12 :   inputdata=commandline;
      89          12 : }
      90             : 
      91         377 : std::string ShowGraph::getLabel(const Action* a, const bool& amp) {
      92         377 :   return getLabel( a->getLabel(), amp );
      93             : }
      94             : 
      95         453 : std::string ShowGraph::getLabel( const std::string& s, const bool& amp ) {
      96         453 :   if( s.find("@")==std::string::npos ) {
      97         405 :     return s;
      98             :   }
      99          48 :   std::size_t p=s.find_first_of("@");
     100          48 :   if( amp ) {
     101          30 :     return "#64;" + s.substr(p+1);
     102             :   }
     103          33 :   return s.substr(p+1);
     104             : }
     105             : 
     106          85 : void ShowGraph::printStyle( const unsigned& linkcount, const Value* v, OFile& ofile ) {
     107          85 :   if( v->getRank()>0 && v->hasDerivatives() ) {
     108           0 :     ofile.printf("linkStyle %d stroke:green,color:green;\n", linkcount);
     109          85 :   } else if( v->getRank()==1 ) {
     110          33 :     ofile.printf("linkStyle %d stroke:blue,color:blue;\n", linkcount);
     111          52 :   } else if ( v->getRank()==2 ) {
     112          30 :     ofile.printf("linkStyle %d stroke:red,color:red;\n", linkcount);
     113             :   }
     114          85 : }
     115             : 
     116          63 : void ShowGraph::printArgumentConnections( const ActionWithArguments* a, unsigned& linkcount, const bool& force, OFile& ofile ) {
     117          63 :   if( !a ) {
     118             :     return;
     119             :   }
     120         101 :   for(const auto & v : a->getArguments() ) {
     121          55 :     if( force && v->forcesWereAdded() ) {
     122          28 :       ofile.printf("%s -- %s --> %s\n", getLabel(a).c_str(), v->getName().c_str(), getLabel(v->getPntrToAction()).c_str() );
     123          14 :       printStyle( linkcount, v, ofile );
     124          14 :       linkcount++;
     125          41 :     } else if( !force ) {
     126          66 :       ofile.printf("%s -- %s --> %s\n", getLabel(v->getPntrToAction()).c_str(),v->getName().c_str(),getLabel(a).c_str() );
     127          33 :       printStyle( linkcount, v, ofile );
     128          33 :       linkcount++;
     129             :     }
     130             :   }
     131             : }
     132             : 
     133          55 : void ShowGraph::printAtomConnections( const ActionAtomistic* a, unsigned& linkcount, const bool& force, OFile& ofile ) {
     134          55 :   if( !a ) {
     135             :     return;
     136             :   }
     137         179 :   for(const auto & d : a->getDependencies() ) {
     138         138 :     ActionToPutData* dp=dynamic_cast<ActionToPutData*>(d);
     139         138 :     if( dp && dp->getLabel()=="posx" ) {
     140          18 :       if( force && (dp->copyOutput(0))->forcesWereAdded() ) {
     141           8 :         ofile.printf("%s --> MD\n", getLabel(a).c_str() );
     142           8 :         ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount);
     143           8 :         linkcount++;
     144             :       } else {
     145          10 :         ofile.printf("MD --> %s\n", getLabel(a).c_str() );
     146          10 :         ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount);
     147          10 :         linkcount++;
     148             :       }
     149         120 :     } else if( dp && dp->getLabel()!="posy" && dp->getLabel()!="posz" && dp->getLabel()!="Masses" && dp->getLabel()!="Charges" ) {
     150          21 :       if( force && (dp->copyOutput(0))->forcesWereAdded() ) {
     151          18 :         ofile.printf("%s -- %s --> %s\n",getLabel(a).c_str(), getLabel(d).c_str(), getLabel(d).c_str() );
     152           9 :         printStyle( linkcount, dp->copyOutput(0), ofile );
     153           9 :         linkcount++;
     154             :       } else {
     155          24 :         ofile.printf("%s -- %s --> %s\n", getLabel(d).c_str(),getLabel(d).c_str(),getLabel(a).c_str() );
     156          12 :         printStyle( linkcount, dp->copyOutput(0), ofile );
     157          12 :         linkcount++;
     158             :       }
     159          21 :       continue;
     160             :     }
     161         117 :     ActionWithVirtualAtom* dv=dynamic_cast<ActionWithVirtualAtom*>(d);
     162         117 :     if( dv ) {
     163           4 :       if( force && (dv->copyOutput(0))->forcesWereAdded() ) {
     164           2 :         ofile.printf("%s -- %s --> %s\n", getLabel(a).c_str(),getLabel(d).c_str(),getLabel(d).c_str() );
     165           1 :         ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount);
     166           1 :         linkcount++;
     167             :       } else {
     168           6 :         ofile.printf("%s -- %s --> %s\n", getLabel(d).c_str(),getLabel(d).c_str(),getLabel(a).c_str() );
     169           3 :         ofile.printf("linkStyle %d stroke:violet,color:violet;\n", linkcount);
     170           3 :         linkcount++;
     171             :       }
     172             :     }
     173             :   }
     174             : }
     175             : 
     176          30 : void ShowGraph::drawActionWithVectorNode( OFile& ofile, PlumedMain& p, Action* ag, const std::vector<std::string>& mychain, std::vector<bool>& printed ) {
     177          30 :   ActionWithVector* agg=dynamic_cast<ActionWithVector*>(ag);
     178             :   std::vector<std::string> matchain;
     179          30 :   agg->getAllActionLabelsInMatrixChain( matchain );
     180          30 :   if( matchain.size()>0 ) {
     181          16 :     ofile.printf("subgraph sub%s_mat [%s]\n",getLabel(agg).c_str(), getLabel(agg).c_str());
     182          24 :     for(unsigned j=0; j<matchain.size(); ++j ) {
     183          16 :       Action* agm=p.getActionSet().selectWithLabel<Action*>(matchain[j]);
     184          60 :       for(unsigned k=0; k<mychain.size(); ++k ) {
     185          60 :         if( mychain[k]==matchain[j] ) {
     186             :           printed[k]=true;
     187          16 :           break;
     188             :         }
     189             :       }
     190          32 :       ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(matchain[j]).c_str(), getLabel(matchain[j],true).c_str(), agm->writeInGraph().c_str() );
     191             :     }
     192           8 :     ofile.printf("end\n");
     193          16 :     ofile.printf("style sub%s_mat fill:lightblue\n",getLabel(ag).c_str());
     194             :   } else {
     195          44 :     ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(ag->getLabel()).c_str(), getLabel(ag->getLabel(),true).c_str(), ag->writeInGraph().c_str() );
     196             :   }
     197          30 : }
     198             : 
     199           8 : int ShowGraph::main(FILE* in, FILE*out,Communicator& pc) {
     200             : 
     201             :   std::string inpt;
     202          16 :   parse("--plumed",inpt);
     203             :   std::string outp;
     204           8 :   parse("--out",outp);
     205             :   bool forces;
     206           8 :   parseFlag("--force",forces);
     207             : 
     208             :   // Create a plumed main object and initilize
     209           8 :   PlumedMain p;
     210           8 :   int rr=sizeof(double);
     211           8 :   p.cmd("setRealPrecision",&rr);
     212           8 :   double lunit=1.0;
     213           8 :   p.cmd("setMDLengthUnits",&lunit);
     214           8 :   double cunit=1.0;
     215           8 :   p.cmd("setMDChargeUnits",&cunit);
     216           8 :   double munit=1.0;
     217           8 :   p.cmd("setMDMassUnits",&munit);
     218           8 :   p.cmd("setPlumedDat",inpt.c_str());
     219           8 :   p.cmd("setLog",out);
     220           8 :   int natoms=1000000;
     221           8 :   p.cmd("setNatoms",&natoms);
     222           8 :   p.cmd("init");
     223             : 
     224           8 :   unsigned linkcount=0;
     225           8 :   OFile ofile;
     226           8 :   ofile.open(outp);
     227           8 :   if( forces ) {
     228             :     unsigned step=1;
     229           4 :     p.cmd("setStep",step);
     230           4 :     p.cmd("prepareCalc");
     231           4 :     ofile.printf("flowchart BT \n");
     232             :     std::vector<std::string> drawn_nodes;
     233             :     std::set<std::string> atom_force_set;
     234         103 :     for(auto pp=p.getActionSet().rbegin(); pp!=p.getActionSet().rend(); ++pp) {
     235             :       const auto & a(pp->get());
     236         534 :       if( a->getName()=="DOMAIN_DECOMPOSITION" || a->getLabel()=="posx" || a->getLabel()=="posy" || a->getLabel()=="posz" || a->getLabel()=="Masses" || a->getLabel()=="Charges" ) {
     237          24 :         continue;
     238             :       }
     239             : 
     240          75 :       if(a->isActive()) {
     241          44 :         ActionToPutData* ap=dynamic_cast<ActionToPutData*>(a);
     242          44 :         if( ap ) {
     243           8 :           ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     244           4 :           continue;
     245             :         }
     246          40 :         ActionWithValue* av=dynamic_cast<ActionWithValue*>(a);
     247          40 :         if( !av ) {
     248           5 :           continue ;
     249             :         }
     250             :         // Now apply the force if there is one
     251          35 :         a->apply();
     252             :         bool hasforce=false;
     253          67 :         for(int i=0; i<av->getNumberOfComponents(); ++i) {
     254          42 :           if( (av->copyOutput(i))->forcesWereAdded() ) {
     255             :             hasforce=true;
     256             :             break;
     257             :           }
     258             :         }
     259             :         //Check if there are forces here
     260          35 :         ActionWithArguments* aaa=dynamic_cast<ActionWithArguments*>(a);
     261          35 :         if( aaa ) {
     262          46 :           for(const auto & v : aaa->getArguments() ) {
     263          30 :             if( v->forcesWereAdded() ) {
     264             :               hasforce=true;
     265             :               break;
     266             :             }
     267             :           }
     268             :         }
     269          35 :         if( !hasforce ) {
     270          14 :           continue;
     271             :         }
     272          21 :         ActionWithVector* avec=dynamic_cast<ActionWithVector*>(a);
     273          21 :         if( avec ) {
     274           8 :           ActionWithVector* head=avec->getFirstActionInChain();
     275             :           std::vector<std::string> mychain;
     276           8 :           head->getAllActionLabelsInChain( mychain );
     277           8 :           std::vector<bool> printed(mychain.size(),false);
     278          16 :           ofile.printf("subgraph sub%s [%s]\n",getLabel(head).c_str(),getLabel(head).c_str());
     279          70 :           for(unsigned i=0; i<mychain.size(); ++i) {
     280             :             bool drawn=false;
     281         314 :             for(unsigned j=0; j<drawn_nodes.size(); ++j ) {
     282         294 :               if( drawn_nodes[j]==mychain[i] ) {
     283             :                 drawn=true;
     284             :                 break;
     285             :               }
     286             :             }
     287          62 :             if( drawn ) {
     288          42 :               continue;
     289             :             }
     290          20 :             ActionWithVector* ag=p.getActionSet().selectWithLabel<ActionWithVector*>(mychain[i]);
     291          20 :             plumed_assert( ag );
     292          20 :             drawn_nodes.push_back( mychain[i] );
     293          20 :             if( !printed[i] ) {
     294          16 :               drawActionWithVectorNode( ofile, p, ag, mychain, printed );
     295             :               printed[i]=true;
     296             :             }
     297          41 :             for(const auto & v : ag->getArguments() ) {
     298             :               bool chain_conn=false;
     299         109 :               for(unsigned j=0; j<mychain.size(); ++j) {
     300         105 :                 if( (v->getPntrToAction())->getLabel()==mychain[j] ) {
     301             :                   chain_conn=true;
     302             :                   break;
     303             :                 }
     304             :               }
     305          21 :               if( !chain_conn ) {
     306           4 :                 continue;
     307             :               }
     308          34 :               ofile.printf("%s -. %s .-> %s\n", getLabel(v->getPntrToAction()).c_str(),v->getName().c_str(),getLabel(ag).c_str() );
     309          17 :               printStyle( linkcount, v, ofile );
     310          17 :               linkcount++;
     311             :             }
     312             :           }
     313           8 :           ofile.printf("end\n");
     314           8 :           if( avec!=head ) {
     315          70 :             for(unsigned i=0; i<mychain.size(); ++i) {
     316          62 :               ActionWithVector* c = p.getActionSet().selectWithLabel<ActionWithVector*>( mychain[i] );
     317          62 :               plumed_assert(c);
     318          62 :               if( c->getNumberOfAtoms()>0 || c->hasStoredArguments() ) {
     319          60 :                 for(unsigned j=0; j<avec->getNumberOfComponents(); ++j ) {
     320          30 :                   if( avec->copyOutput(j)->getRank()>0 ) {
     321          20 :                     continue;
     322             :                   }
     323          20 :                   ofile.printf("%s == %s ==> %s\n", getLabel(avec).c_str(), avec->copyOutput(j)->getName().c_str(), getLabel(c).c_str() );
     324          10 :                   linkcount++;
     325             :                 }
     326          30 :                 if( c->getNumberOfAtoms()>0 ) {
     327          16 :                   atom_force_set.insert( c->getLabel() );
     328             :                 }
     329             :               }
     330             :             }
     331             :           }
     332           8 :         } else {
     333             :           // Print out the node if we have force on it
     334          26 :           ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     335             :         }
     336             :         // Check where this force is being added
     337          21 :         printArgumentConnections( aaa, linkcount, true, ofile );
     338             :       }
     339             :     }
     340             :     // Now draw connections from action atomistic to relevant actions
     341           4 :     std::vector<ActionAtomistic*> all_atoms = p.getActionSet().select<ActionAtomistic*>();
     342          33 :     for(const auto & at : all_atoms ) {
     343          29 :       ActionWithValue* av=dynamic_cast<ActionWithValue*>(at);
     344             :       bool hasforce=false;
     345          29 :       if( av ) {
     346          44 :         for(unsigned i=0; i<av->getNumberOfComponents(); ++i ) {
     347          26 :           if( av->copyOutput(i)->forcesWereAdded() ) {
     348           8 :             printAtomConnections( at, linkcount, true, ofile );
     349           8 :             atom_force_set.erase( av->getLabel() );
     350             :             break;
     351             :           }
     352             :         }
     353             :       }
     354             :     }
     355           9 :     for(const auto & l : atom_force_set ) {
     356           5 :       ActionAtomistic* at = p.getActionSet().selectWithLabel<ActionAtomistic*>(l);
     357           5 :       plumed_assert(at);
     358           5 :       printAtomConnections( at, linkcount, true, ofile );
     359             :     }
     360           4 :     ofile.printf("MD(positions from MD)\n");
     361             :     return 0;
     362           4 :   }
     363             : 
     364           4 :   ofile.printf("flowchart TB \n");
     365           4 :   ofile.printf("MD(positions from MD)\n");
     366          98 :   for(const auto & aa : p.getActionSet() ) {
     367             :     Action* a(aa.get());
     368         504 :     if( a->getName()=="DOMAIN_DECOMPOSITION" || a->getLabel()=="posx" || a->getLabel()=="posy" || a->getLabel()=="posz" || a->getLabel()=="Masses" || a->getLabel()=="Charges" ) {
     369          24 :       continue;
     370             :     }
     371          70 :     ActionToPutData* ap=dynamic_cast<ActionToPutData*>(a);
     372          70 :     if( ap ) {
     373           8 :       ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     374           4 :       continue;
     375             :     }
     376          66 :     ActionShortcut* as=dynamic_cast<ActionShortcut*>(a);
     377          66 :     if( as ) {
     378          24 :       continue ;
     379             :     }
     380          42 :     ActionWithValue* av=dynamic_cast<ActionWithValue*>(a);
     381          42 :     ActionWithArguments* aaa=dynamic_cast<ActionWithArguments*>(a);
     382          42 :     ActionAtomistic* at=dynamic_cast<ActionAtomistic*>(a);
     383          42 :     ActionWithVector* avec=dynamic_cast<ActionWithVector*>(a);
     384             :     // Print out the connections between nodes
     385          42 :     printAtomConnections( at, linkcount, false, ofile );
     386          42 :     printArgumentConnections( aaa, linkcount, false, ofile );
     387             :     // Print out the nodes
     388          42 :     if( avec && !avec->actionInChain() ) {
     389           6 :       ofile.printf("subgraph sub%s [%s]\n",getLabel(a).c_str(),getLabel(a).c_str());
     390             :       std::vector<std::string> mychain;
     391           3 :       avec->getAllActionLabelsInChain( mychain );
     392           3 :       std::vector<bool> printed(mychain.size(),false);
     393          21 :       for(unsigned i=0; i<mychain.size(); ++i) {
     394          18 :         Action* ag=p.getActionSet().selectWithLabel<Action*>(mychain[i]);
     395          18 :         if( !printed[i] ) {
     396          14 :           drawActionWithVectorNode( ofile, p, ag, mychain, printed );
     397             :           printed[i]=true;
     398             :         }
     399             :       }
     400           3 :       ofile.printf("end\n");
     401          42 :     } else if( !av ) {
     402          22 :       ofile.printf("%s(\"label=%s \n %s \n\")\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     403          28 :     } else if( !avec ) {
     404          26 :       ofile.printf("%s([\"label=%s \n %s \n\"])\n", getLabel(a).c_str(), getLabel(a,true).c_str(), a->writeInGraph().c_str() );
     405             :     }
     406             :   }
     407           4 :   ofile.close();
     408             : 
     409             :   return 0;
     410           8 : }
     411             : 
     412             : } // End of namespace
     413             : }

Generated by: LCOV version 1.16