LCOV - code coverage report
Current view: top level - matrixtools - MatrixTimesVectorBase.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 139 146 95.2 %
Date: 2025-12-04 11:19:34 Functions: 28 30 93.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_matrixtools_MatrixTimesVectorBase_h
      23             : #define __PLUMED_matrixtools_MatrixTimesVectorBase_h
      24             : 
      25             : #include "core/ActionWithVector.h"
      26             : #include "core/ParallelTaskManager.h"
      27             : 
      28             : namespace PLMD {
      29             : namespace matrixtools {
      30             : 
      31        1460 : class MatrixTimesVectorData {
      32             : public:
      33             :   std::size_t fshift;
      34             :   Matrix<std::size_t> pairs;
      35             : #ifdef __PLUMED_USE_OPENACC
      36             :   void toACCDevice() const {
      37             : #pragma acc enter data copyin(this[0:1],fshift)
      38             :     pairs.toACCDevice();
      39             :   }
      40             :   void removeFromACCDevice() const {
      41             :     pairs.removeFromACCDevice();
      42             : #pragma acc exit data delete(fshift,this[0:1])
      43             :   }
      44             : #endif //__PLUMED_USE_OPENACC
      45             : };
      46             : 
/// Describes the row of one matrix argument that a task works on; it is built
/// in getForceIndices and passed on to T::getAdditionalIndices.
///
/// The bookkeeping data of matrix argument actiondata.pairs[ipair][0] holds one
/// record of (1+ncols) entries per row: the number of stored elements of the
/// row, followed by that many indices (presumably column indices -- confirm).
class MatrixForceIndexInput {
public:
  /// Number of stored elements in row task_index (first entry of its record).
  std::size_t rowlen;
  /// The rowlen indices that follow the row length in the bookkeeping record.
  View<const std::size_t> indices;
  MatrixForceIndexInput( std::size_t task_index,
                         std::size_t ipair,
                         const MatrixTimesVectorData& actiondata,
                         const ParallelActionsInput& input ):
    rowlen(input.bookeeping[input.bookstarts[actiondata.pairs[ipair][0]]
                            + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index]),
    indices(input.bookeeping + input.bookstarts[actiondata.pairs[ipair][0]]
            + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index + 1,
            rowlen) {}
};
      61             : 
/// Read-only inputs for computing product ipair on row task_index.
///
/// The matrix is argument actiondata.pairs[ipair][0] and the vector is argument
/// actiondata.pairs[ipair][1]; both are located inside the flat argdata buffer
/// through input.argstarts.
class MatrixTimesVectorInput {
public:
  /// Copied from input.noderiv (presumably true when derivatives are not
  /// required -- confirm in ParallelActionsInput).
  bool noderiv;
  /// Number of stored elements in this matrix row (first entry of the row's
  /// bookkeeping record, which has a stride of 1+ncols per row).
  std::size_t rowlen;
  /// The rowlen indices that follow the row length in the bookkeeping record
  /// (presumably the column indices of the stored elements -- confirm).
  View<const std::size_t> indices;
  /// rowlen matrix values starting at argstarts[matrix] + task_index*ncols[matrix].
  View<const double> matrow;
  /// The vector argument, input.shapedata[1] values long starting at
  /// argstarts[vector].
  View<const double> vector;
  MatrixTimesVectorInput( std::size_t task_index,
                          std::size_t ipair,
                          const MatrixTimesVectorData& actiondata,
                          const ParallelActionsInput& input,
                          double* argdata ):
    noderiv(input.noderiv),
    rowlen(input.bookeeping[input.bookstarts[actiondata.pairs[ipair][0]]
                            + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index]),
    indices(input.bookeeping + input.bookstarts[actiondata.pairs[ipair][0]]
            + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index + 1,rowlen),
    matrow(argdata + input.argstarts[actiondata.pairs[ipair][0]]
           + task_index*input.ncols[actiondata.pairs[ipair][0]],rowlen),
    vector(argdata + input.argstarts[actiondata.pairs[ipair][1]], input.shapedata[1]) {
  }
};
      84             : 
/// Where one task writes the value and derivatives of product ipair.
///
/// Inside output.derivatives each product owns a block of nder entries that
/// starts at ipair*nder: the first rowlen entries are matrow_deriv and the next
/// rowlen entries are vector_deriv.
class MatrixTimesVectorOutput {
public:
  /// Number of stored elements in this matrix row, read from the bookkeeping
  /// record of matrix argument actiondata.pairs[ipair][0].
  std::size_t rowlen;
  /// Single output value: output.values[ipair].
  View<double,1> values;
  /// rowlen derivative slots starting at output.derivatives[ipair*nder].
  View<double> matrow_deriv;
  /// rowlen derivative slots immediately after matrow_deriv.
  View<double> vector_deriv;
  MatrixTimesVectorOutput( std::size_t task_index,
                           std::size_t ipair,
                           std::size_t nder,
                           const MatrixTimesVectorData& actiondata,
                           const ParallelActionsInput& input,
                           ParallelActionsOutput& output ):
    rowlen(input.bookeeping[input.bookstarts[actiondata.pairs[ipair][0]]
                            + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index]),
    values(output.values.data()+ipair),
    matrow_deriv(output.derivatives.data()+ipair*nder,rowlen),
    vector_deriv(output.derivatives.data()+ipair*nder+rowlen,rowlen) {
  }
};
     104             : 
     105             : template <class T>
     106             : class MatrixTimesVectorBase : public ActionWithVector {
     107             : public:
     108             :   using input_type = MatrixTimesVectorData;
     109             :   using PTM = ParallelTaskManager<MatrixTimesVectorBase<T>>;
     110             : private:
     111             : /// The parallel task manager
     112             :   PTM taskmanager;
     113             : public:
     114             :   static void registerKeywords( Keywords& keys );
     115             :   static void registerLocalKeywords( Keywords& keys );
     116             :   explicit MatrixTimesVectorBase(const ActionOptions&);
     117             :   std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
     118             :   unsigned getNumberOfDerivatives() override ;
     119             :   void prepare() override ;
     120             :   void calculate() override ;
     121             :   void applyNonZeroRankForces( std::vector<double>& outforces ) override ;
     122             :   int checkTaskIsActive( const unsigned& itask ) const override ;
     123             :   /// Override this so we write the graph properly
     124          10 :   std::string writeInGraph() const override {
     125          10 :     return "MATRIX_VECTOR_PRODUCT";
     126             :   }
     127             :   static void performTask( std::size_t task_index,
     128             :                            const MatrixTimesVectorData& actiondata,
     129             :                            ParallelActionsInput& input,
     130             :                            ParallelActionsOutput& output );
     131             :   static int getNumberOfValuesPerTask( std::size_t task_index,
     132             :                                        const MatrixTimesVectorData& actiondata );
     133             :   static void getForceIndices( std::size_t task_index,
     134             :                                std::size_t colno,
     135             :                                std::size_t ntotal_force,
     136             :                                const MatrixTimesVectorData& actiondata,
     137             :                                const ParallelActionsInput& input,
     138             :                                ForceIndexHolder force_indices );
     139             : };
     140             : 
     141             : template <class T>
     142         738 : void MatrixTimesVectorBase<T>::registerKeywords( Keywords& keys ) {
     143         738 :   ActionWithVector::registerKeywords(keys);
     144         738 :   keys.setDisplayName("MATRIX_VECTOR_PRODUCT");
     145         738 :   registerLocalKeywords( keys );
     146         738 :   ActionWithValue::useCustomisableComponents(keys);
     147         738 :   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
     148         738 : }
     149             : 
     150             : template <class T>
     151        1384 : void MatrixTimesVectorBase<T>::registerLocalKeywords( Keywords& keys ) {
     152        1384 :   PTM::registerKeywords( keys );
     153        2768 :   keys.addInputKeyword("compulsory","ARG","matrix/vector/scalar","the label for the matrix and the vector/scalar that are being multiplied.  Alternatively, you can provide labels for multiple matrices and a single vector or labels for a single matrix and multiple vectors. In these cases multiple matrix vector products will be computed.");
     154        1384 :   keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs ");
     155        2768 :   keys.setValueDescription("vector","the vector that is obtained by taking the product between the matrix and the vector that were input");
     156        1384 :   ActionWithValue::useCustomisableComponents(keys);
     157        1384 : }
     158             : 
     159             : template <class T>
     160           6 : std::string MatrixTimesVectorBase<T>::getOutputComponentDescription( const std::string& cname,
     161             :     const Keywords& keys ) const {
     162           6 :   if( getPntrToArgument(1)->getRank()==1 ) {
     163           0 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
     164           0 :       if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
     165           0 :         return "the product of the matrix " + getPntrToArgument(0)->getName() + " and the vector " + getPntrToArgument(i)->getName();
     166             :       }
     167             :     }
     168             :   }
     169          21 :   for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
     170          21 :     if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
     171          12 :       return "the product of the matrix " + getPntrToArgument(i)->getName() + " and the vector " + getPntrToArgument(getNumberOfArguments()-1)->getName();
     172             :     }
     173             :   }
     174           0 :   plumed_merror( "could not understand request for component " + cname );
     175             :   return "";
     176             : }
     177             : 
     178             : template <class T>
     179         365 : MatrixTimesVectorBase<T>::MatrixTimesVectorBase(const ActionOptions&ao):
     180             :   Action(ao),
     181             :   ActionWithVector(ao),
     182         365 :   taskmanager(this) {
     183         365 :   if( getNumberOfArguments()<2 ) {
     184           0 :     error("Not enough arguments specified");
     185             :   }
     186             :   bool vectormask=false, derivbool = true;
     187        1917 :   for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     188        1552 :     if( getPntrToArgument(i)->hasDerivatives() ) {
     189           0 :       error("arguments should be vectors or matrices");
     190             :     }
     191        1552 :     if( getPntrToArgument(i)->getRank()<=1 ) {
     192         550 :       ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
     193         550 :       if( av && av->getNumberOfMasks()>=0 ) {
     194             :         vectormask=true;
     195             :       }
     196             :     }
     197        1552 :     if( !getPntrToArgument(i)->isDerivativeZeroWhenValueIsZero() ) {
     198             :       derivbool = false;
     199             :     }
     200             :   }
     201         365 :   if( !vectormask ) {
     202         365 :     ignoreMaskArguments();
     203             :   }
     204             : 
     205         365 :   std::vector<std::size_t> shape(1);
     206         365 :   shape[0]=getPntrToArgument(0)->getShape()[0];
     207         365 :   if( getNumberOfArguments()==2 ) {
     208         313 :     addValue( shape );
     209         313 :     setNotPeriodic();
     210         313 :     if( derivbool ) {
     211         233 :       getPntrToComponent(0)->setDerivativeIsZeroWhenValueIsZero();
     212             :     }
     213             :   } else {
     214             :     unsigned namestart=1, nameend=getNumberOfArguments();
     215          52 :     if( getPntrToArgument(1)->getRank()==2 ) {
     216             :       namestart = 0;
     217          43 :       nameend = getNumberOfArguments()-1;
     218             :     }
     219             : 
     220         926 :     for(unsigned i=namestart; i<nameend; ++i) {
     221         874 :       std::string name = getPntrToArgument(i)->getName();
     222         874 :       if( name.find_first_of(".")!=std::string::npos ) {
     223         680 :         std::size_t dot=name.find_first_of(".");
     224        1360 :         name = name.substr(dot+1);
     225             :       }
     226         874 :       addComponent( name, shape );
     227         874 :       componentIsNotPeriodic( name );
     228         874 :       if( derivbool ) {
     229        1740 :         copyOutput( getLabel() + "." + name )->setDerivativeIsZeroWhenValueIsZero();
     230             :       }
     231             :     }
     232             :   }
     233             :   // This sets up an array in the parallel task manager to hold all the indices
     234             :   // Sets up the index list in the task manager
     235         365 :   std::size_t nder = getPntrToArgument(getNumberOfArguments()-1)->getNumberOfStoredValues();
     236             :   MatrixTimesVectorData input;
     237         365 :   input.pairs.resize( getNumberOfArguments()-1, 2 );
     238         365 :   if( getPntrToArgument(1)->getRank()==2 ) {
     239         723 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
     240         680 :       input.pairs[i-1][0] = i-1;
     241         680 :       input.pairs[i-1][1] = getNumberOfArguments()-1;
     242             :     }
     243          43 :     input.fshift=0;
     244             :   } else {
     245         829 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
     246         507 :       input.pairs[i-1][0] = 0;
     247         507 :       input.pairs[i-1][1] = i;
     248             :     }
     249         322 :     input.fshift=nder;
     250             :   }
     251         365 :   taskmanager.setActionInput( input );
     252         365 : }
     253             : 
     254             : template <class T>
     255       28652 : unsigned MatrixTimesVectorBase<T>::getNumberOfDerivatives() {
     256             :   unsigned nderivatives=0;
     257      695284 :   for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     258      666632 :     nderivatives += getPntrToArgument(i)->getNumberOfStoredValues();
     259             :   }
     260       28652 :   return nderivatives;
     261             : }
     262             : 
     263             : template <class T>
     264       13715 : void MatrixTimesVectorBase<T>::prepare() {
     265       13715 :   ActionWithVector::prepare();
     266       13715 :   Value* myval = getPntrToComponent(0);
     267       13715 :   if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] ) {
     268       13705 :     return;
     269             :   }
     270          10 :   std::vector<std::size_t> shape(1);
     271          10 :   shape[0] = getPntrToArgument(0)->getShape()[0];
     272          10 :   myval->setShape(shape);
     273             : }
     274             : 
     275             : template <class T>
     276       13708 : void MatrixTimesVectorBase<T>::calculate() {
     277       13708 :   std::size_t nvectors, nder = getPntrToArgument(getNumberOfArguments()-1)->getNumberOfStoredValues();
     278       13708 :   if( getPntrToArgument(1)->getRank()==2 ) {
     279             :     nvectors = 1;
     280             :   } else {
     281       13585 :     nvectors = getNumberOfArguments()-1;
     282             :   }
     283       13708 :   if( getName()=="MATRIX_VECTOR_PRODUCT_ROWSUMS" ) {
     284       11317 :     taskmanager.setupParallelTaskManager( nder, 0 );
     285             :   } else {
     286        2391 :     taskmanager.setupParallelTaskManager( 2*nder, nvectors*nder );
     287             :   }
     288       13708 :   taskmanager.runAllTasks();
     289       13708 : }
     290             : 
     291             : template <class T>
     292     1824672 : int MatrixTimesVectorBase<T>::checkTaskIsActive( const unsigned& itask ) const {
     293     2516128 :   for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     294             :     Value* myarg = getPntrToArgument(i);
     295     2516128 :     if( myarg->getRank()==1 && !myarg->hasDerivatives() ) {
     296             :       return 0;
     297     1824673 :     } else if( myarg->getRank()==2 && !myarg->hasDerivatives() ) {
     298     1824673 :       if (myarg->checkValueIsActiveForMMul(itask)) {
     299             :         return 1;
     300             :       }
     301             :     } else {
     302           0 :       plumed_merror("should not be in action " + getName() );
     303             :     }
     304             :   }
     305             :   return 0;
     306             : }
     307             : 
     308             : template <class T>
     309     1513469 : void MatrixTimesVectorBase<T>::performTask( std::size_t task_index,
     310             :     const MatrixTimesVectorData& actiondata,
     311             :     ParallelActionsInput& input,
     312             :     ParallelActionsOutput& output ) {
     313    11658448 :   for(unsigned i=0; i<actiondata.pairs.nrows(); ++i) {
     314    10144979 :     MatrixTimesVectorOutput doutput( task_index,
     315             :                                      i,
     316             :                                      input.nderivatives_per_scalar,
     317             :                                      actiondata,
     318             :                                      input,
     319             :                                      output );
     320    10144979 :     T::performTask( MatrixTimesVectorInput( task_index,
     321             :                                             i,
     322             :                                             actiondata,
     323             :                                             input,
     324             :                                             input.inputdata ),
     325             :                     doutput );
     326             :   }
     327     1513469 : }
     328             : 
     329             : template <class T>
     330       11890 : void MatrixTimesVectorBase<T>::applyNonZeroRankForces( std::vector<double>& outforces ) {
     331       11890 :   taskmanager.applyForces( outforces );
     332       11890 : }
     333             : 
     334             : template <class T>
     335      443916 : int MatrixTimesVectorBase<T>::getNumberOfValuesPerTask( std::size_t task_index,
     336             :     const MatrixTimesVectorData& actiondata ) {
     337      443916 :   return 1;
     338             : }
     339             : 
     340             : template <class T>
     341      443916 : void MatrixTimesVectorBase<T>::getForceIndices( std::size_t task_index,
     342             :     std::size_t colno,
     343             :     std::size_t ntotal_force,
     344             :     const MatrixTimesVectorData& actiondata,
     345             :     const ParallelActionsInput& input,
     346             :     ForceIndexHolder force_indices ) {
     347     2264337 :   for(unsigned i=0; i<actiondata.pairs.nrows(); ++i) {
     348     1820421 :     std::size_t base = input.argstarts[actiondata.pairs[i][0]]
     349     1820421 :                        + task_index*input.ncols[actiondata.pairs[i][0]];
     350     1820421 :     std::size_t n = input.bookeeping[input.bookstarts[actiondata.pairs[i][0]]
     351     1820421 :                                      + (1+input.ncols[actiondata.pairs[i][0]])*task_index];
     352    76718715 :     for(unsigned j=0; j<n; ++j) {
     353    74898294 :       force_indices.indices[i][j] = base + j;
     354             :     }
     355     1820421 :     force_indices.threadsafe_derivatives_end[i] = n;
     356     1820421 :     force_indices.tot_indices[i] = T::getAdditionalIndices( n,
     357      124096 :                                    input.argstarts[actiondata.pairs[i][1]],
     358     1944517 :                                    MatrixForceIndexInput( task_index, i, actiondata, input ),
     359             :                                    force_indices.indices[i] );
     360             :   }
     361      443916 : }
     362             : 
     363             : }
     364             : }
     365             : #endif

Generated by: LCOV version 1.16