LCOV - code coverage report
Current view: top level - function - FunctionOfMatrix.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 244 268 91.0 %
Date: 2026-03-30 11:13:23 Functions: 85 108 78.7 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2020 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_function_FunctionOfMatrix_h
      23             : #define __PLUMED_function_FunctionOfMatrix_h
      24             : 
      25             : #include "core/ActionWithMatrix.h"
      26             : #include "FunctionOfVector.h"
      27             : #include "Sum.h"
      28             : #include "tools/Matrix.h"
      29             : 
      30             : namespace PLMD {
      31             : namespace function {
      32             : 
      33             : template <class T>
      34             : class FunctionOfMatrix : public ActionWithMatrix {
      35             : private:
      36             : /// Is this the first step of the calculation
      37             :   bool firststep;
      38             : /// The function that is being computed
      39             :   T myfunc;
      40             : /// The number of derivatives for this action
      41             :   unsigned nderivatives;
      42             : /// A vector that tells us if we have stored the input value
      43             :   std::vector<bool> stored_arguments;
      44             : /// Switch off updating the arguments for this action
      45             :   std::vector<bool> update_arguments;
      46             : /// The list of actiosn in this chain
      47             :   std::vector<std::string> actionsLabelsInChain;
      48             : /// Get the shape of the output matrix
      49             :   std::vector<unsigned> getValueShapeFromArguments();
      50             : public:
      51             :   static void registerKeywords(Keywords&);
      52             :   explicit FunctionOfMatrix(const ActionOptions&);
      53             : /// Get the label to write in the graph
      54           0 :   std::string writeInGraph() const override {
      55           0 :     return myfunc.getGraphInfo( getName() );
      56             :   }
      57             : /// Make sure the derivatives are turned on
      58             :   void turnOnDerivatives() override;
      59             : /// Get the number of derivatives for this action
      60             :   unsigned getNumberOfDerivatives() override ;
      61             : /// Resize the matrices
      62             :   void prepare() override ;
      63             : /// This gets the number of columns
      64             :   unsigned getNumberOfColumns() const override ;
      65             : /// This checks for tasks in the parent class
      66             : //  void buildTaskListFromArgumentRequests( const unsigned& ntasks, bool& reduce, std::set<AtomNumber>& otasks ) override ;
      67             : /// This ensures that we create some bookeeping stuff during the first step
      68             :   void setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) override ;
      69             : /// This sets up for the task
      70             :   void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
      71             : /// Calculate the full matrix
      72             :   void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override ;
      73             : /// This updates the indices for the matrix
      74             : //  void updateCentralMatrixIndex( const unsigned& ind, const std::vector<unsigned>& indices, MultiValue& myvals ) const override ;
      75             :   void runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
      76             : };
      77             : 
      78             : template <class T>
      79        1037 : void FunctionOfMatrix<T>::registerKeywords(Keywords& keys ) {
      80        1037 :   ActionWithMatrix::registerKeywords(keys);
      81        1037 :   keys.use("ARG");
      82        1037 :   std::string name = keys.getDisplayName();
      83        1037 :   std::size_t und=name.find("_MATRIX");
      84        1037 :   keys.setDisplayName( name.substr(0,und) );
      85        2074 :   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
      86        2074 :   keys.reserve("compulsory","PERIODIC","if the output of your function is periodic then you should specify the periodicity of the function.  If the output is not periodic you must state this using PERIODIC=NO");
      87         749 :   T tfunc;
      88        1037 :   tfunc.registerKeywords( keys );
      89        2074 :   if( keys.getDisplayName()=="SUM" ) {
      90         180 :     keys.setValueDescription("the sum of all the elements in the input matrix");
      91        1894 :   } else if( keys.getDisplayName()=="HIGHEST" ) {
      92           0 :     keys.setValueDescription("the largest element of the input matrix");
      93        1894 :   } else if( keys.getDisplayName()=="LOWEST" ) {
      94           0 :     keys.setValueDescription("the smallest element in the input matrix");
      95        1894 :   } else if( keys.outputComponentExists(".#!value") ) {
      96        1700 :     keys.setValueDescription("the matrix obtained by doing an element-wise application of " + keys.getOutputComponentDescription(".#!value") + " to the input matrix");
      97             :   }
      98        1931 : }
      99             : 
     100             : template <class T>
     101         497 : FunctionOfMatrix<T>::FunctionOfMatrix(const ActionOptions&ao):
     102             :   Action(ao),
     103             :   ActionWithMatrix(ao),
     104         497 :   firststep(true) {
     105         453 :   if( myfunc.getArgStart()>0 ) {
     106             :     error("this has not beeen implemented -- if you are interested email gareth.tribello@gmail.com");
     107             :   }
     108             :   // Get the shape of the output
     109         497 :   std::vector<unsigned> shape( getValueShapeFromArguments() );
     110             :   // Check if the output matrix is symmetric
     111         497 :   bool symmetric=true;
     112             :   unsigned argstart=myfunc.getArgStart();
     113        1512 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     114        1015 :     if( getPntrToArgument(i)->getRank()==2 ) {
     115         950 :       if( !getPntrToArgument(i)->isSymmetric() ) {
     116         833 :         symmetric=false;
     117             :       }
     118             :     }
     119             :   }
     120             :   // Read the input and do some checks
     121         497 :   myfunc.read( this );
     122             :   // Setup to do this in chain if possible
     123             :   if( myfunc.doWithTasks() ) {
     124         497 :     done_in_chain=true;
     125             :   }
     126             :   // Check we are not calculating a sum
     127          43 :   if( myfunc.zeroRank() ) {
     128          43 :     shape.resize(0);
     129             :   }
     130             :   // Get the names of the components
     131         497 :   std::vector<std::string> components( keywords.getOutputComponents() );
     132             :   // Create the values to hold the output
     133          42 :   std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
     134        1038 :   for(unsigned i=0; i<components.size(); ++i) {
     135          84 :     if( str_ind.size()>0 ) {
     136          84 :       std::string compstr = components[i];
     137          84 :       if( components[i]==".#!value" ) {
     138             :         compstr = "";
     139             :       }
     140         760 :       for(unsigned j=0; j<str_ind.size(); ++j) {
     141             :         if( myfunc.zeroRank() ) {
     142             :           addComponentWithDerivatives( compstr + str_ind[j], shape );
     143             :         } else {
     144        1352 :           addComponent( compstr + str_ind[j], shape );
     145         676 :           getPntrToComponent(i*str_ind.size()+j)->setSymmetric( symmetric );
     146             :         }
     147             :       }
     148          43 :     } else if( components[i]==".#!value" && myfunc.zeroRank() ) {
     149          43 :       addValueWithDerivatives( shape );
     150         414 :     } else if( components[i]==".#!value" ) {
     151         410 :       addValue( shape );
     152         410 :       getPntrToComponent(0)->setSymmetric( symmetric );
     153           4 :     } else if( components[i].find_first_of("_")!=std::string::npos ) {
     154           0 :       if( getNumberOfArguments()-argstart==1 ) {
     155           0 :         addValue( shape );
     156           0 :         getPntrToComponent(0)->setSymmetric( symmetric );
     157             :       } else {
     158           0 :         for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     159           0 :           addComponent( getPntrToArgument(j)->getName() + components[i], shape );
     160           0 :           getPntrToComponent(i*(getNumberOfArguments()-argstart)+j-argstart)->setSymmetric( symmetric );
     161             :         }
     162             :       }
     163             :     } else {
     164           4 :       addComponent( components[i], shape );
     165           4 :       getPntrToComponent(i)->setSymmetric( symmetric );
     166             :     }
     167             :   }
     168             :   // Check if this can be sped up
     169         370 :   if( myfunc.getDerivativeZeroIfValueIsZero() )  {
     170         174 :     for(int i=0; i<getNumberOfComponents(); ++i) {
     171          87 :       getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
     172             :     }
     173             :   }
     174             :   // Set the periodicities of the output components
     175         497 :   myfunc.setPeriodicityForOutputs( this );
     176             :   // We can't do this with if we are dividing a stack by some a product v.v^T product as we need to store the vector
     177             :   // In order to do this type of calculation.  There should be a neater fix than this but I can't see it.
     178             :   bool foundneigh=false;
     179             :   const ActionWithMatrix* chainstart = NULL;
     180        1507 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     181        1013 :     if( getPntrToArgument(i)->isConstant() && getNumberOfArguments()>1 ) {
     182         275 :       continue ;
     183             :     }
     184         934 :     std::string argname=(getPntrToArgument(i)->getPntrToAction())->getName();
     185         934 :     if( argname=="NEIGHBORS" ) {
     186             :       foundneigh=true;
     187             :       break;
     188             :     }
     189         931 :     ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
     190         931 :     if( !av ) {
     191          31 :       done_in_chain=false;
     192             :     }
     193         931 :     if( getPntrToArgument(i)->getRank()==0 ) {
     194           0 :       function::FunctionOfVector<function::Sum>* as = dynamic_cast<function::FunctionOfVector<function::Sum>*>( getPntrToArgument(i)->getPntrToAction() );
     195           0 :       if(as) {
     196           0 :         done_in_chain=false;
     197             :       }
     198         931 :     } else if( getPntrToArgument(i)->ignoreStoredValue( getLabel() ) ) {
     199             :       // This option deals with the case when you have two adjacency matrices, A_ij and B_ij, multiplied together.  This cannot be done in the chain as the rows
     200             :       // of the two adjacency matrix are run over separately.  The value A_ij is thus not available when B_ij is calculated.
     201         853 :       ActionWithMatrix* am = dynamic_cast<ActionWithMatrix*>( getPntrToArgument(i)->getPntrToAction() );
     202         853 :       plumed_assert( am );
     203         853 :       const ActionWithMatrix* thischain = am->getFirstMatrixInChain();
     204         853 :       if( !thischain->isAdjacencyMatrix() && thischain->getName()!="VSTACK" ) {
     205             :         continue;
     206             :       }
     207         657 :       if( !chainstart ) {
     208             :         chainstart = thischain;
     209         317 :       } else if( thischain!=chainstart ) {
     210           1 :         done_in_chain=false;
     211             :       }
     212             :     }
     213             :   }
     214             :   // If we are working with neighbors we trick PLUMED into storing ALL the components of the other arguments
     215             :   // in this way we can ensure that the function of the neighbours matrix is in a chain starting from the
     216             :   // Neighbours matrix action.
     217             :   if( foundneigh ) {
     218           9 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     219           6 :       ActionWithValue* av=getPntrToArgument(i)->getPntrToAction();
     220           6 :       if( av->getName()!="NEIGHBORS" ) {
     221           8 :         for(int i=0; i<av->getNumberOfComponents(); ++i) {
     222           5 :           (av->copyOutput(i))->buildDataStore();
     223             :         }
     224             :       }
     225             :     }
     226             :   }
     227             :   // Now setup the action in the chain if we can
     228         497 :   nderivatives = buildArgumentStore(myfunc.getArgStart());
     229         994 : }
     230             : 
     231             : template <class T>
     232        1925 : void FunctionOfMatrix<T>::turnOnDerivatives() {
     233        1925 :   if( !myfunc.derivativesImplemented() ) {
     234             :     error("derivatives have not been implemended for " + getName() );
     235             :   }
     236        1925 :   ActionWithValue::turnOnDerivatives();
     237        1925 :   myfunc.setup(this);
     238        1925 : }
     239             : 
     240             : template <class T>
     241       30417 : unsigned FunctionOfMatrix<T>::getNumberOfDerivatives() {
     242       30417 :   return nderivatives;
     243             : }
     244             : 
     245             : template <class T>
     246        2231 : void FunctionOfMatrix<T>::prepare() {
     247             :   unsigned argstart = myfunc.getArgStart();
     248        2231 :   std::vector<unsigned> shape(2);
     249        2231 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     250        2231 :     if( getPntrToArgument(i)->getRank()==2 ) {
     251        2231 :       shape[0] = getPntrToArgument(i)->getShape()[0];
     252        2231 :       shape[1] = getPntrToArgument(i)->getShape()[1];
     253        2231 :       break;
     254             :     }
     255             :   }
     256        6686 :   for(unsigned i=0; i<getNumberOfComponents(); ++i) {
     257        4455 :     Value* myval = getPntrToComponent(i);
     258        4455 :     if( myval->getRank()==2 && (myval->getShape()[0]!=shape[0] || myval->getShape()[1]!=shape[1]) ) {
     259          18 :       myval->setShape(shape);
     260          18 :       if( myval->valueIsStored() ) {
     261          18 :         myval->reshapeMatrixStore( shape[1] );
     262             :       }
     263             :     }
     264             :   }
     265        2231 :   ActionWithVector::prepare();
     266        2231 : }
     267             : 
     268             : template <class T>
     269      281844 : unsigned FunctionOfMatrix<T>::getNumberOfColumns() const {
     270      281844 :   if( getConstPntrToComponent(0)->getRank()==2 ) {
     271             :     unsigned argstart=myfunc.getArgStart();
     272      281844 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     273      281844 :       if( getPntrToArgument(i)->getRank()==2 ) {
     274      281844 :         ActionWithMatrix* am=dynamic_cast<ActionWithMatrix*>( getPntrToArgument(i)->getPntrToAction() );
     275      281844 :         if( am ) {
     276      279606 :           return am->getNumberOfColumns();
     277             :         }
     278        2238 :         return getPntrToArgument(i)->getShape()[1];
     279             :       }
     280             :     }
     281             :   }
     282           0 :   plumed_error();
     283             :   return 0;
     284             : }
     285             : 
     286             : template <class T>
     287        4294 : void FunctionOfMatrix<T>::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
     288       11837 :   for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     289        7543 :     plumed_assert( getPntrToArgument(i)->getRank()==2 );
     290             :   }
     291        4294 :   unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getShape()[1];
     292        4294 :   if( indices.size()!=size_v+1 ) {
     293         424 :     indices.resize( size_v+1 );
     294             :   }
     295      642944 :   for(unsigned i=0; i<size_v; ++i) {
     296      638650 :     indices[i+1] = start_n + i;
     297             :   }
     298             :   myvals.setSplitIndex( size_v + 1 );
     299        4294 : }
     300             : 
     301             : // template <class T>
     302             : // void FunctionOfMatrix<T>::buildTaskListFromArgumentRequests( const unsigned& ntasks, bool& reduce, std::set<AtomNumber>& otasks ) {
     303             : //   // Check if this is the first element in a chain
     304             : //   if( actionInChain() ) return;
     305             : //   // If it is computed outside a chain get the tassks the daughter chain needs
     306             : //   propegateTaskListsForValue( 0, ntasks, reduce, otasks );
     307             : // }
     308             : 
     309             : template <class T>
     310        2527 : void FunctionOfMatrix<T>::setupStreamedComponents( const std::string& headstr, unsigned& nquants, unsigned& nmat, unsigned& maxcol, unsigned& nbookeeping ) {
     311        2527 :   if( firststep ) {
     312         491 :     stored_arguments.resize( getNumberOfArguments() );
     313         491 :     update_arguments.resize( getNumberOfArguments(), true );
     314         491 :     std::string control = getFirstActionInChain()->getLabel();
     315        1488 :     for(unsigned i=0; i<stored_arguments.size(); ++i) {
     316         997 :       stored_arguments[i] = !getPntrToArgument(i)->ignoreStoredValue( control );
     317         997 :       if( !stored_arguments[i] ) {
     318             :         update_arguments[i] = true;
     319             :       } else {
     320         166 :         update_arguments[i] = !argumentDependsOn( headstr, this, getPntrToArgument(i) );
     321             :       }
     322             :     }
     323         491 :     firststep=false;
     324             :   }
     325        2527 :   ActionWithMatrix::setupStreamedComponents( headstr, nquants, nmat, maxcol, nbookeeping );
     326        2527 : }
     327             : 
     328             : template <class T>
     329    27928897 : void FunctionOfMatrix<T>::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
     330             :   unsigned argstart=myfunc.getArgStart();
     331    27928897 :   std::vector<double> args( getNumberOfArguments() - argstart );
     332    27928897 :   unsigned ind2 = index2;
     333    27928897 :   if( getConstPntrToComponent(0)->getRank()==2 && index2>=getConstPntrToComponent(0)->getShape()[0] ) {
     334     3636383 :     ind2 = index2 - getConstPntrToComponent(0)->getShape()[0];
     335    24292514 :   } else if( index2>=getPntrToArgument(0)->getShape()[0] ) {
     336      448225 :     ind2 = index2 - getPntrToArgument(0)->getShape()[0];
     337             :   }
     338    27928897 :   if( actionInChain() ) {
     339    85619946 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     340    58329699 :       if( getPntrToArgument(i)->getRank()==0 ) {
     341      135720 :         args[i-argstart] = getPntrToArgument(i)->get();
     342    58193979 :       } else if( !getPntrToArgument(i)->valueHasBeenSet() ) {
     343    57005386 :         args[i-argstart] = myvals.get( getPntrToArgument(i)->getPositionInStream() );
     344             :       } else {
     345     1188593 :         args[i-argstart] = getPntrToArgument(i)->get( getPntrToArgument(i)->getShape()[1]*index1 + ind2 );
     346             :       }
     347             :     }
     348             :   } else {
     349     1727564 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     350     1088914 :       if( getPntrToArgument(i)->getRank()==2 ) {
     351     1088914 :         args[i-argstart]=getPntrToArgument(i)->get( getPntrToArgument(i)->getShape()[1]*index1 + ind2 );
     352             :       } else {
     353           0 :         args[i-argstart] = getPntrToArgument(i)->get();
     354             :       }
     355             :     }
     356             :   }
     357             :   // Calculate the function and its derivatives
     358    27928897 :   std::vector<double> vals( getNumberOfComponents() );
     359    27928897 :   Matrix<double> derivatives( getNumberOfComponents(), getNumberOfArguments()-argstart );
     360    27928897 :   myfunc.calc( this, args, vals, derivatives );
     361             :   // And set the values
     362    99634847 :   for(unsigned i=0; i<vals.size(); ++i) {
     363    71705950 :     myvals.addValue( getConstPntrToComponent(i)->getPositionInStream(), vals[i] );
     364             :   }
     365             :   // Return if we are not computing derivatives
     366    27928897 :   if( doNotCalculateDerivatives() ) {
     367             :     return;
     368             :   }
     369             : 
     370     5399865 :   if( actionInChain() ) {
     371    33385311 :     for(int i=0; i<getNumberOfComponents(); ++i) {
     372    27990552 :       unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     373   131996523 :       for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     374   104005971 :         if( getPntrToArgument(j)->getRank()==2 ) {
     375   103890411 :           unsigned istrn = getPntrToArgument(j)->getPositionInStream();
     376   103890411 :           if( stored_arguments[j] ) {
     377      395048 :             unsigned task_index = getPntrToArgument(i)->getShape()[1]*index1 + ind2;
     378      395048 :             myvals.clearDerivatives(istrn);
     379      395048 :             myvals.addDerivative( istrn, task_index, 1.0 );
     380      395048 :             myvals.updateIndex( istrn, task_index );
     381             :           }
     382   470717695 :           for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
     383   366827284 :             unsigned kind=myvals.getActiveIndex(istrn,k);
     384   366827284 :             myvals.addDerivative( ostrn, arg_deriv_starts[j] + kind, derivatives(i,j)*myvals.getDerivative( istrn, kind ) );
     385             :           }
     386             :         }
     387             :       }
     388             :     }
     389             :     // If we are computing a matrix we need to update the indices here so that derivatives are calcualted correctly in functions of these
     390     5394759 :     if( getConstPntrToComponent(0)->getRank()==2 ) {
     391    32784527 :       for(int i=0; i<getNumberOfComponents(); ++i) {
     392    27690160 :         unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     393   131395739 :         for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     394   103705579 :           if( !update_arguments[j] || getPntrToArgument(j)->getRank()==0 ) {
     395      115584 :             continue ;
     396             :           }
     397             :           // Ensure we only store one lot of derivative indices
     398             :           bool found=false;
     399   105009550 :           for(unsigned k=0; k<j; ++k) {
     400    76601003 :             if( arg_deriv_starts[k]==arg_deriv_starts[j] ) {
     401             :               found=true;
     402             :               break;
     403             :             }
     404             :           }
     405   103589995 :           if( found ) {
     406    75181448 :             continue;
     407             :           }
     408             :           unsigned istrn = getPntrToArgument(j)->getPositionInStream();
     409   138447375 :           for(unsigned k=0; k<myvals.getNumberActive(istrn); ++k) {
     410   110038828 :             unsigned kind=myvals.getActiveIndex(istrn,k);
     411   110038828 :             myvals.updateIndex( ostrn, arg_deriv_starts[j] + kind );
     412             :           }
     413             :         }
     414             :       }
     415             :     }
     416             :   } else {
     417             :     unsigned base=0;
     418        5106 :     ind2 = index2;
     419        5106 :     for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     420        5106 :       if( getPntrToArgument(j)->getRank()!=2 ) {
     421             :         continue ;
     422             :       }
     423        5106 :       if( index2>=getPntrToArgument(j)->getShape()[0] ) {
     424        5106 :         ind2 = index2 - getPntrToArgument(j)->getShape()[0];
     425             :       }
     426             :       break;
     427             :     }
     428       14457 :     for(unsigned j=argstart; j<getNumberOfArguments(); ++j) {
     429        9351 :       if( getPntrToArgument(j)->getRank()==2 ) {
     430       18702 :         for(int i=0; i<getNumberOfComponents(); ++i) {
     431        9351 :           unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     432        9351 :           unsigned myind = base + getPntrToArgument(j)->getShape()[1]*index1 + ind2;
     433        9351 :           myvals.addDerivative( ostrn, myind, derivatives(i,j) );
     434        9351 :           myvals.updateIndex( ostrn, myind );
     435             :         }
     436             :       } else {
     437           0 :         for(int i=0; i<getNumberOfComponents(); ++i) {
     438           0 :           unsigned ostrn=getConstPntrToComponent(i)->getPositionInStream();
     439           0 :           myvals.addDerivative( ostrn, base, derivatives(i,j) );
     440           0 :           myvals.updateIndex( ostrn, base );
     441             :         }
     442             :       }
     443        9351 :       base += getPntrToArgument(j)->getNumberOfValues();
     444             :     }
     445             :   }
     446             : }
     447             : 
     448             : template <class T>
     449      226474 : void FunctionOfMatrix<T>::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
     450      226474 :   if( doNotCalculateDerivatives() ) {
     451             :     return;
     452             :   }
     453             : 
     454             :   unsigned argstart=myfunc.getArgStart();
     455       71322 :   if( actionInChain() && getConstPntrToComponent(0)->getRank()==2 ) {
     456             :     // This is triggered if we are outputting a matrix
     457      624578 :     for(int vv=0; vv<getNumberOfComponents(); ++vv) {
     458      558183 :       unsigned nmat = getConstPntrToComponent(vv)->getPositionInMatrixStash();
     459             :       std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
     460             :       unsigned ntot_mat=0;
     461      558183 :       if( mat_indices.size()<nderivatives ) {
     462           0 :         mat_indices.resize( nderivatives );
     463             :       }
     464     2701544 :       for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     465     2143361 :         if( !update_arguments[i] || getPntrToArgument(i)->getRank()==0 ) {
     466        1084 :           continue ;
     467             :         }
     468             :         // Ensure we only store one lot of derivative indices
     469             :         bool found=false;
     470     2168017 :         for(unsigned j=0; j<i; ++j) {
     471     1591516 :           if( arg_deriv_starts[j]==arg_deriv_starts[i] ) {
     472             :             found=true;
     473             :             break;
     474             :           }
     475             :         }
     476     2142277 :         if( found ) {
     477     1565776 :           continue;
     478             :         }
     479             : 
     480      576501 :         if( stored_arguments[i] ) {
     481       15483 :           unsigned tbase = getPntrToArgument(i)->getShape()[1]*ind;
     482      410507 :           for(unsigned k=1; k<indices.size(); ++k) {
     483      395024 :             unsigned ind2 = indices[k] - getConstPntrToComponent(0)->getShape()[0];
     484      395024 :             mat_indices[ntot_mat + k - 1] = arg_deriv_starts[i] + tbase + ind2;
     485             :           }
     486       15483 :           ntot_mat += indices.size()-1;
     487             :         } else {
     488             :           unsigned istrn = getPntrToArgument(i)->getPositionInMatrixStash();
     489             :           std::vector<unsigned>& imat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
     490    31848426 :           for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) {
     491    31287408 :             mat_indices[ntot_mat + k] = arg_deriv_starts[i] + imat_indices[k];
     492             :           }
     493      561018 :           ntot_mat += myvals.getNumberOfMatrixRowDerivatives( istrn );
     494             :         }
     495             :       }
     496             :       myvals.setNumberOfMatrixRowDerivatives( nmat, ntot_mat );
     497             :     }
     498        4927 :   } else if( actionInChain() ) {
     499             :     // This is triggered if we are calculating a single scalar in the function
     500        8822 :     for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     501             :       bool found=false;
     502        4411 :       for(unsigned j=0; j<i; ++j) {
     503           0 :         if( arg_deriv_starts[j]==arg_deriv_starts[i] ) {
     504             :           found=true;
     505             :           break;
     506             :         }
     507             :       }
     508        4411 :       if( found ) {
     509             :         continue;
     510             :       }
     511             :       unsigned istrn = getPntrToArgument(i)->getPositionInMatrixStash();
     512             :       std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
     513      926766 :       for(unsigned k=0; k<myvals.getNumberOfMatrixRowDerivatives( istrn ); ++k) {
     514     1844710 :         for(int j=0; j<getNumberOfComponents(); ++j) {
     515      922355 :           unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
     516      922355 :           myvals.updateIndex( ostrn, arg_deriv_starts[i] + mat_indices[k] );
     517             :         }
     518             :       }
     519             :     }
     520         516 :   } else if( getConstPntrToComponent(0)->getRank()==2 ) {
     521         760 :     for(int vv=0; vv<getNumberOfComponents(); ++vv) {
     522         380 :       unsigned nmat = getConstPntrToComponent(vv)->getPositionInMatrixStash();
     523             :       std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
     524             :       unsigned ntot_mat=0;
     525         380 :       if( mat_indices.size()<nderivatives ) {
     526           0 :         mat_indices.resize( nderivatives );
     527             :       }
     528             :       unsigned matderbase = 0;
     529         986 :       for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     530         606 :         if( getPntrToArgument(i)->getRank()==0 ) {
     531           0 :           continue ;
     532             :         }
     533         606 :         unsigned ss = getPntrToArgument(i)->getShape()[1];
     534         606 :         unsigned tbase = matderbase + ss*myvals.getTaskIndex();
     535        9558 :         for(unsigned k=0; k<ss; ++k) {
     536        8952 :           mat_indices[ntot_mat + k] = tbase + k;
     537             :         }
     538         606 :         ntot_mat += ss;
     539         606 :         matderbase += getPntrToArgument(i)->getNumberOfValues();
     540             :       }
     541             :       myvals.setNumberOfMatrixRowDerivatives( nmat, ntot_mat );
     542             :     }
     543             :   }
     544             : }
     545             : 
     546             : template <class T>
     547         497 : std::vector<unsigned> FunctionOfMatrix<T>::getValueShapeFromArguments() {
     548             :   unsigned argstart=myfunc.getArgStart();
     549         497 :   std::vector<unsigned> shape(2);
     550         497 :   shape[0]=shape[1]=0;
     551        1512 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     552        1015 :     plumed_assert( getPntrToArgument(i)->getRank()==2 || getPntrToArgument(i)->getRank()==0 );
     553        1015 :     if( getPntrToArgument(i)->getRank()==2 ) {
     554         950 :       if( shape[0]>0 && (getPntrToArgument(i)->getShape()[0]!=shape[0] || getPntrToArgument(i)->getShape()[1]!=shape[1]) ) {
     555           0 :         error("all matrices input should have the same shape");
     556         950 :       } else if( shape[0]==0 ) {
     557         511 :         shape[0]=getPntrToArgument(i)->getShape()[0];
     558         511 :         shape[1]=getPntrToArgument(i)->getShape()[1];
     559             :       }
     560         950 :       plumed_assert( !getPntrToArgument(i)->hasDerivatives() );
     561             :     }
     562             :   }
     563          43 :   myfunc.setPrefactor( this, 1.0 );
     564         497 :   return shape;
     565             : }
     566             : 
     567             : }
     568             : }
     569             : #endif

Generated by: LCOV version 1.16