LCOV - code coverage report
Current view: top level - core - ActionWithVector.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 144 163 88.3 %
Date: 2025-12-04 11:19:34 Functions: 14 16 87.5 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "ActionWithVector.h"
      23             : #include "ActionWithMatrix.h"
      24             : #include "PlumedMain.h"
      25             : #include "ActionSet.h"
      26             : #include "tools/OpenMP.h"
      27             : #include "tools/Communicator.h"
      28             : 
      29             : namespace PLMD {
      30             : 
// Register the keywords accepted by every ActionWithVector.
// Keywords are collected from the base classes in this order: Action, ActionAtomistic,
// ActionWithValue, then ActionWithArguments.
// NUMERICAL_DERIVATIVES is removed after ActionWithValue has registered its keywords.
// NOTE(review): presumably that keyword is added by ActionWithValue::registerKeywords.
// It is dropped because this class cannot compute numerical derivatives
// (calculateNumericalDerivatives always raises an error).
void ActionWithVector::registerKeywords( Keywords& keys ) {
  Action::registerKeywords( keys );
  ActionAtomistic::registerKeywords( keys );
  ActionWithValue::registerKeywords( keys );
  keys.remove("NUMERICAL_DERIVATIVES");
  ActionWithArguments::registerKeywords( keys );
}
      38             : 
      39        3268 : ActionWithVector::ActionWithVector(const ActionOptions&ao):
      40             :   Action(ao),
      41             :   ActionAtomistic(ao),
      42             :   ActionWithValue(ao),
      43        3268 :   ActionWithArguments(ao) {
      44        3268 :   if( !keywords.exists("MASKED_INPUT_ALLOWED") ) {
      45        3407 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
      46        2006 :       ActionWithVector* av = dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
      47        2006 :       if( av && av->getNumberOfMasks()>=0 ) {
      48          57 :         nmask=0;
      49             :       }
      50             :     }
      51             :   }
      52             : 
      53        3268 :   if( keywords.exists("MASK") ) {
      54             :     std::vector<Value*> mask;
      55        4930 :     parseArgumentList("MASK",mask);
      56        2465 :     if( mask.size()>0 ) {
      57         197 :       if( nmask>=0 && getNumberOfArguments()==1 ) {
      58          45 :         ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
      59          45 :         plumed_massert( av, "input should be a vector from ActionWithVector" );
      60          45 :         unsigned j=0, nargs = av->getNumberOfArguments();
      61          92 :         for(unsigned i=nargs-av->nmask; i<nargs; ++i) {
      62          47 :           if( av->getPntrToArgument(i)!=mask[j] ) {
      63           0 :             error("the masks in subsequent actions do not match");
      64             :           }
      65          47 :           j++;
      66             :         }
      67             :       }
      68         197 :       if( getNumberOfArguments()>0 && getName().find("EVALUATE_FUNCTION_FROM_GRID")==std::string::npos && getPntrToArgument(0)->hasDerivatives() ) {
      69           0 :         error("input for mask should be vector or matrix");
      70         197 :       } else if( mask[0]->getRank()==2 ) {
      71          86 :         if( mask.size()>1 ) {
      72           0 :           error("MASK should only have one argument");
      73             :         }
      74          86 :         log.printf("  only computing elements of matrix that correspond to non-zero elements of matrix %s \n", mask[0]->getName().c_str() );
      75         111 :       } else if( mask[0]->getRank()==1 ) {
      76         111 :         log.printf("  only computing elements of vector that correspond to non-zero elements of vectors %s", mask[0]->getName().c_str() );
      77         115 :         for(unsigned i=1; i<mask.size(); ++i) {
      78           4 :           if( mask[i]->getRank()!=1 ) {
      79           0 :             log.printf("\n");
      80           0 :             error("input to mask should be vector");
      81             :           }
      82           4 :           log.printf(", %s", mask[i]->getName().c_str() );
      83             :         }
      84         111 :         log.printf("\n");
      85             :       }
      86         197 :       std::vector<Value*> allargs( getArguments() );
      87         197 :       nmask=mask.size();
      88         398 :       for(unsigned i=0; i<mask.size(); ++i) {
      89         201 :         allargs.push_back( mask[i] );
      90             :       }
      91         197 :       requestArguments( allargs );
      92             :     }
      93             :   }
      94        3268 : }
      95             : 
// Lock requests on both request-handling base classes: first ActionAtomistic (atoms),
// then ActionWithArguments (arguments).
// NOTE(review): this override presumably exists so that a single call reaches both bases.
void ActionWithVector::lockRequests() {
  ActionAtomistic::lockRequests();
  ActionWithArguments::lockRequests();
}
     100             : 
// Unlock requests on both request-handling base classes, in the same order as
// lockRequests(): first ActionAtomistic, then ActionWithArguments.
void ActionWithVector::unlockRequests() {
  ActionAtomistic::unlockRequests();
  ActionWithArguments::unlockRequests();
}
     105             : 
     106           0 : void ActionWithVector::calculateNumericalDerivatives(ActionWithValue* av) {
     107           0 :   plumed_merror("cannot calculate numerical derivative for action " + getName() + " with label " + getLabel() );
     108             : }
     109             : 
     110      233257 : void ActionWithVector::prepare() {
     111      233257 :   active_tasks.resize(0);
     112      233257 : }
     113             : 
     114    17161021 : int ActionWithVector::checkTaskIsActive( const unsigned& itask ) const {
     115    17161021 :   unsigned nargs = getNumberOfArguments();
     116    17161021 :   if( nargs==0 ) {
     117             :     return 1;
     118    14333993 :   } else if( nmask>0 ) {
     119     6160405 :     for(unsigned j=nargs-nmask; j<nargs; ++j) {
     120             :       Value* myarg = getPntrToArgument(j);
     121     3247814 :       if( myarg->getRank()==1 && !myarg->hasDerivatives() ) {
     122     2985644 :         if( fabs(myarg->get(itask))>0.0 ) {
     123             :           return 1;
     124             :         }
     125      262170 :       } else if( myarg->getRank()==2 && !myarg->hasDerivatives() ) {
     126      262170 :         if( myarg->getRowLength(itask)>0 ) {
     127             :           return 1;
     128             :         }
     129             :       } else {
     130           0 :         plumed_merror("only matrices and vectors should be used as masks");
     131             :       }
     132             :     }
     133             :   } else {
     134    12692349 :     for(unsigned i=0; i<nargs; ++i) {
     135             :       Value* myarg = getPntrToArgument(i);
     136    12425337 :       if( !myarg->isDerivativeZeroWhenValueIsZero() ) {
     137             :         return 1;
     138             :       }
     139             : 
     140     7726431 :       if( myarg->getRank()==0 ) {
     141             :         return 1;
     142     7726431 :       } else if( myarg->getRank()==1 && !myarg->hasDerivatives() ) {
     143     6971650 :         if( fabs(myarg->get(itask))>0.0 ) {
     144             :           return 1;
     145             :         }
     146      754781 :       } else if( myarg->getRank()==2 && !myarg->hasDerivatives() ) {
     147        2877 :         const unsigned ncol = myarg->getRowLength(itask);
     148        2877 :         const unsigned base = itask*myarg->getNumberOfColumns();
     149        2919 :         for(unsigned k=0; k<ncol; ++k) {
     150        2919 :           if( fabs(myarg->get(base+k,false))>0.0 ) {
     151             :             return 1;
     152             :           }
     153             :         }
     154             :       } else if( myarg->getRank()>0 ) {
     155             :         return 1;
     156             :       } else {
     157             :         plumed_merror("should not be in action " + getName() );
     158             :       }
     159             :     }
     160             :   }
     161             :   return -1;
     162             : }
     163             : 
     164      404844 : std::vector<unsigned>& ActionWithVector::getListOfActiveTasks( ActionWithVector* action ) {
     165      404844 :   if( active_tasks.size()>0 ) {
     166      195598 :     return active_tasks;
     167             :   }
     168      209246 :   unsigned ntasks=0;
     169      209246 :   getNumberOfTasks( ntasks );
     170             : 
     171      209246 :   active_tasks.resize(0);
     172      209246 :   active_tasks.reserve(ntasks);
     173    20290708 :   for(unsigned i=0; i<ntasks; ++i) {
     174    20081462 :     if( checkTaskIsActive(i)>0 ) {
     175             : //no resize are triggered, since we have reserved the number of tasks
     176    15695071 :       active_tasks.push_back(i);
     177             :     }
     178             :   }
     179      209246 :   return active_tasks;
     180             : }
     181             : 
     182       41265 : void ActionWithVector::getInputData( std::vector<double>& inputdata ) const {
     183             :   plumed_dbg_assert( getNumberOfAtoms()==0 );
     184       41265 :   unsigned nargs = getNumberOfArguments();
     185       41265 :   unsigned nmasks=getNumberOfMasks();
     186             :   // getNumberOfMasks(); returns nmask, that it is an int
     187             :   // nmasks cant be <0 (it is unsigned), so I check nmask for that
     188       41265 :   if( nargs>=nmasks && nmask>0 ) {
     189         118 :     nargs = nargs - nmasks;
     190             :   }
     191             : 
     192             :   std::size_t total_args = 0;
     193      127378 :   for(unsigned i=0; i<nargs; ++i) {
     194       86113 :     total_args += getPntrToArgument(i)->getNumberOfStoredValues();
     195             :   }
     196             : 
     197       41265 :   if( inputdata.size()!=total_args ) {
     198         884 :     inputdata.resize( total_args );
     199             :   }
     200             : 
     201             :   total_args = 0;
     202      127378 :   for(unsigned i=0; i<nargs; ++i) {
     203             :     Value* myarg = getPntrToArgument(i);
     204       86113 :     total_args+= myarg->assignValues(View{&inputdata[total_args],inputdata.size()-total_args});
     205             :   }
     206       41265 : }
     207             : 
     208      191930 : void ActionWithVector::transferStashToValues( const std::vector<unsigned>& partialTaskList, const std::vector<double>& stash ) {
     209      191930 :   unsigned ntask = partialTaskList.size();
     210      191930 :   unsigned ncomponents = getNumberOfComponents();
     211      409527 :   for(unsigned i=0; i<ncomponents; ++i) {
     212      217597 :     Value* myval = copyOutput(i);
     213   442468770 :     for(unsigned j=0; j<ntask; ++j) {
     214   442251173 :       myval->set( partialTaskList[j], stash[partialTaskList[j]*ncomponents+i] );
     215             :     }
     216             :   }
     217      191930 : }
     218             : 
     219      173703 : void ActionWithVector::transferForcesToStash( const std::vector<unsigned>& partialTaskList, std::vector<double>& stash ) const {
     220      173703 :   unsigned ntask = partialTaskList.size();
     221      173703 :   unsigned ncomponents = getNumberOfComponents();
     222      363505 :   for(unsigned i=0; i<ncomponents; ++i) {
     223      189802 :     auto myval = getConstPntrToComponent(i);
     224    85934848 :     for(unsigned j=0; j<ntask; ++j) {
     225    85745046 :       stash[partialTaskList[j]*ncomponents+i] = myval->getForce( partialTaskList[j] );
     226             :     }
     227             :   }
     228      173703 : }
     229             : 
     230      239890 : void ActionWithVector::getNumberOfTasks( unsigned& ntasks ) {
     231      239890 :   if( ntasks==0 ) {
     232      239890 :     if( getNumberOfArguments()==1 && getNumberOfComponents()==1 && getPntrToComponent(0)->getRank()==0 ) {
     233           0 :       if( !getPntrToArgument(0)->hasDerivatives() && getPntrToArgument(0)->getRank()==2 ) {
     234           0 :         ntasks = getPntrToArgument(0)->getShape()[0];
     235             :       } else {
     236           0 :         ntasks = getPntrToArgument(0)->getNumberOfValues();
     237             :       }
     238             :     } else {
     239      239890 :       plumed_assert( getNumberOfComponents()>0 && getPntrToComponent(0)->getRank()>0 );
     240      239890 :       if( getPntrToComponent(0)->hasDerivatives() ) {
     241        9173 :         ntasks = getPntrToComponent(0)->getNumberOfValues();
     242             :       } else {
     243      230717 :         ntasks = getPntrToComponent(0)->getShape()[0];
     244             :       }
     245             :     }
     246             :   }
     247      509546 :   for(unsigned i=0; i<getNumberOfComponents(); ++i) {
     248      269656 :     if( getPntrToComponent(i)->getRank()==0 ) {
     249           0 :       if( getNumberOfArguments()!=1 ) {
     250           0 :         error("mismatched numbers of tasks in streamed quantities");
     251             :       }
     252           0 :       if( getPntrToArgument(0)->hasDerivatives() && ntasks!=getPntrToArgument(0)->getNumberOfValues() ) {
     253           0 :         error("mismatched numbers of tasks in streamed quantities");
     254           0 :       } else if ( !getPntrToArgument(0)->hasDerivatives() && ntasks!=getPntrToArgument(0)->getShape()[0] ) {
     255           0 :         error("mismatched numbers of tasks in streamed quantities");
     256             :       }
     257      269656 :     } else if( getPntrToComponent(i)->hasDerivatives() && ntasks!=getPntrToComponent(i)->getNumberOfValues() ) {
     258           0 :       error("mismatched numbers of tasks in streamed quantities");
     259      269656 :     } else if( !getPntrToComponent(i)->hasDerivatives() && ntasks!=getPntrToComponent(i)->getShape()[0] ) {
     260           0 :       error("mismatched numbers of tasks in streamed quantities");
     261             :     }
     262             :   }
     263      239890 : }
     264             : 
     265      254397 : unsigned ActionWithVector::getNumberOfForceDerivatives() const {
     266             :   unsigned nforces=0;
     267      254397 :   unsigned nargs = getNumberOfArguments();
     268      254397 :   unsigned  nmasks = getNumberOfMasks();
     269      254397 :   if( nargs>=nmasks && nmasks>0 ) {
     270        1965 :     nargs = nargs - nmasks;
     271             :   }
     272      254397 :   if( getNumberOfAtoms()>0 ) {
     273       53303 :     nforces += 3*getNumberOfAtoms() + 9;
     274             :   }
     275      600226 :   for(unsigned i=0; i<nargs; ++i) {
     276      345829 :     nforces += getPntrToArgument(i)->getNumberOfStoredValues();
     277             :   }
     278      254397 :   return nforces;
     279             : }
     280             : 
     281      217788 : bool ActionWithVector::checkForForces() {
     282      217788 :   if( getPntrToComponent(0)->getRank()==0 ) {
     283        6492 :     return ActionWithValue::checkForForces();
     284             :   }
     285             : 
     286             :   // Check if there are any forces
     287             :   bool hasforce=false;
     288      247247 :   for(unsigned i=0; i<getNumberOfComponents(); ++i) {
     289      222372 :     if( getConstPntrToComponent(i)->getRank()>0 && getConstPntrToComponent(i)->forcesWereAdded() ) {
     290             :       hasforce=true;
     291             :       break;
     292             :     }
     293             :   }
     294      211296 :   if( !hasforce ) {
     295             :     return false;
     296             :   }
     297      186421 :   applyNonZeroRankForces( forcesForApply );
     298      186421 :   return true;
     299             : }
     300             : 
     301      217788 : void ActionWithVector::apply() {
     302      217788 :   unsigned nf =  getNumberOfForceDerivatives();
     303      217788 :   if( forcesForApply.size()!=nf ) {
     304        3052 :     forcesForApply.resize( nf, 0 );
     305             :   }
     306             : 
     307      217788 :   if( !checkForForces() ) {
     308       26594 :     return;
     309             :   }
     310             :   // Find the top of the chain and add forces
     311      191194 :   unsigned ind=0;
     312      191194 :   addForcesOnArguments( 0, forcesForApply, ind );
     313      191194 :   setForcesOnAtoms( forcesForApply, ind );
     314             : }
     315             : 
     316             : }

Generated by: LCOV version 1.16