LCOV - code coverage report
Current view: top level - matrixtools - MatrixTimesMatrix.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 136 148 91.9 %
Date: 2025-12-04 11:19:34 Functions: 14 18 77.8 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_matrixtools_MatrixTimesMatrix_h
      23             : #define __PLUMED_matrixtools_MatrixTimesMatrix_h
      24             : 
      25             : #include "core/ActionWithMatrix.h"
      26             : #include "core/ParallelTaskManager.h"
      27             : 
      28             : namespace PLMD {
      29             : namespace matrixtools {
      30             : 
      31             : template <class T>
      32          58 : struct MatrixTimesMatrixInput {
      33             :   T funcinput;
      34             :   RequiredMatrixElements outmat;
      35             : #ifdef __PLUMED_USE_OPENACC
      36             :   void toACCDevice() const {
      37             : #pragma acc enter data copyin(this[0:1])
      38             :     funcinput.toACCDevice();
      39             :     outmat.toACCDevice();
      40             :   }
      41             :   void removeFromACCDevice() const {
      42             :     funcinput.removeFromACCDevice();
      43             :     outmat.removeFromACCDevice();
      44             : #pragma acc exit data delete(this[0:1])
      45             :   }
      46             : #endif //__PLUMED_USE_OPENACC
      47             : };
      48             : 
      49             : class InputVectors {
      50             : public:
      51             :   std::size_t nelem;
      52             :   View<double> arg1;
      53             :   View<double> arg2;
      54       71433 :   InputVectors( std::size_t n,  double* b ) : nelem(n), arg1(b,n), arg2(b+n,n) {}
      55             : };
      56             : 
      57             : template <class T>
      58             : class MatrixTimesMatrix : public ActionWithMatrix {
      59             : public:
      60             :   using input_type = MatrixTimesMatrixInput<T>;
      61             :   using PTM = ParallelTaskManager<MatrixTimesMatrix<T>>;
      62             : private:
      63             :   PTM taskmanager;
      64             : public:
      65             :   static void registerKeywords( Keywords& keys );
      66             :   explicit MatrixTimesMatrix(const ActionOptions&);
      67             :   void prepare() override ;
      68             :   unsigned getNumberOfDerivatives() override;
      69             :   void calculate() override ;
      70             :   void applyNonZeroRankForces( std::vector<double>& outforces ) override ;
      71             :   static void performTask( std::size_t task_index, const MatrixTimesMatrixInput<T>& actiondata, ParallelActionsInput& input, ParallelActionsOutput& output );
      72             :   static int getNumberOfValuesPerTask( std::size_t task_index, const MatrixTimesMatrixInput<T>& actiondata );
      73             :   static void getForceIndices( std::size_t task_index, std::size_t colno, std::size_t ntotal_force, const MatrixTimesMatrixInput<T>& actiondata, const ParallelActionsInput& input, ForceIndexHolder force_indices );
      74             : };
      75             : 
      76             : template <class T>
      77          98 : void MatrixTimesMatrix<T>::registerKeywords( Keywords& keys ) {
      78          98 :   ActionWithMatrix::registerKeywords(keys);
      79         196 :   keys.addInputKeyword("optional","MASK","matrix","a matrix that is used to used to determine which elements of the output matrix to compute");
      80         196 :   keys.addInputKeyword("compulsory","ARG","matrix","the label of the two matrices from which the product is calculated");
      81         196 :   if( keys.getDisplayName()=="MATRIX_PRODUCT" ) {
      82          77 :     keys.addFlag("ELEMENTS_ON_DIAGONAL_ARE_ZERO",false,"set all diagonal elements to zero");
      83             :   }
      84          98 :   T::registerKeywords( keys );
      85          98 :   PTM::registerKeywords( keys );
      86          98 : }
      87             : 
      88             : template <class T>
      89          58 : MatrixTimesMatrix<T>::MatrixTimesMatrix(const ActionOptions&ao):
      90             :   Action(ao),
      91             :   ActionWithMatrix(ao),
      92          58 :   taskmanager(this) {
      93             :   int nm=getNumberOfMasks();
      94             :   if( nm<0 ) {
      95             :     nm = 0;
      96             :   }
      97          58 :   if( getNumberOfArguments()-nm!=2 ) {
      98           0 :     error("should be two arguments to this action, a matrix and a vector");
      99             :   }
     100          58 :   if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) {
     101           0 :     error("first argument to this action should be a matrix");
     102             :   }
     103          58 :   if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) {
     104           0 :     error("second argument to this action should be a matrix");
     105             :   }
     106          58 :   if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(1)->getShape()[0] ) {
     107           0 :     error("number of columns in first matrix does not equal number of rows in second matrix");
     108             :   }
     109          58 :   std::vector<std::size_t> shape(2);
     110          58 :   shape[0]=getPntrToArgument(0)->getShape()[0];
     111          58 :   shape[1]=getPntrToArgument(1)->getShape()[1];
     112          58 :   addValue( shape );
     113          58 :   setNotPeriodic();
     114          58 :   getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
     115          58 :   if( getName()!="DISSIMILARITIES" && getPntrToArgument(0)->isDerivativeZeroWhenValueIsZero() && getPntrToArgument(1)->isDerivativeZeroWhenValueIsZero() ) {
     116           6 :     getPntrToComponent(0)->setDerivativeIsZeroWhenValueIsZero();
     117             :   }
     118             : 
     119          58 :   if( nm>0 ) {
     120          14 :     unsigned iarg = getNumberOfArguments()-1;
     121          14 :     if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) {
     122           0 :       error("argument passed to MASK keyword should be a matrix");
     123             :     }
     124          14 :     if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) {
     125           0 :       error("argument passed to MASK keyword has the wrong shape");
     126             :     }
     127             :   }
     128             :   MatrixTimesMatrixInput<T> actdata;
     129          12 :   actdata.funcinput.setup( this, getPntrToArgument(0) );
     130          58 :   taskmanager.setActionInput( actdata );
     131          58 : }
     132             : 
     133             : template <class T>
     134          64 : unsigned MatrixTimesMatrix<T>::getNumberOfDerivatives() {
     135          64 :   return getPntrToArgument(0)->getNumberOfStoredValues() + getPntrToArgument(1)->getNumberOfStoredValues();
     136             : }
     137             : 
     138             : template <class T>
     139         757 : void MatrixTimesMatrix<T>::prepare() {
     140         757 :   ActionWithVector::prepare();
     141         757 :   Value* myval = getPntrToComponent(0);
     142         757 :   if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] && myval->getShape()[1]==getPntrToArgument(1)->getShape()[1] ) {
     143         740 :     return;
     144             :   }
     145          17 :   std::vector<std::size_t> shape(2);
     146          17 :   shape[0]=getPntrToArgument(0)->getShape()[0];
     147          17 :   shape[1]=getPntrToArgument(1)->getShape()[1];
     148          17 :   myval->setShape(shape);
     149          17 :   myval->reshapeMatrixStore( shape[1] );
     150             : }
     151             : 
     152             : template <class T>
     153         753 : void MatrixTimesMatrix<T>::calculate() {
     154         753 :   if( !getPntrToComponent(0)->isDerivativeZeroWhenValueIsZero() ) {
     155         731 :     if( getPntrToArgument(0)->getNumberOfColumns()<getPntrToArgument(0)->getShape()[1] ) {
     156           0 :       if( !doNotCalculateDerivatives() ) {
     157           0 :         error("cannot calculate derivatives for this action with sparse matrices");
     158           0 :       } else if( getName()=="DISSIMILARITIES" ) {
     159           0 :         error("cannot calculate dissimilarities for sparse matrices");
     160             :       }
     161             :     }
     162         731 :     if( getPntrToArgument(1)->getNumberOfColumns()<getPntrToArgument(1)->getShape()[1] ) {
     163           6 :       if( !doNotCalculateDerivatives() ) {
     164           0 :         error("cannot calculate derivatives for this action with sparse matrices");
     165           6 :       } else if( getName()=="DISSIMILARITIES" ) {
     166           0 :         error("cannot calculate dissimilarities for sparse matrices");
     167             :       }
     168             :     }
     169             :   }
     170         753 :   updateBookeepingArrays( taskmanager.getActionInput().outmat );
     171         753 :   taskmanager.setupParallelTaskManager( 2*getPntrToArgument(0)->getNumberOfColumns(), getPntrToArgument(1)->getNumberOfStoredValues() );
     172         753 :   taskmanager.setWorkspaceSize( 2*getPntrToArgument(0)->getNumberOfColumns() );
     173         753 :   taskmanager.runAllTasks();
     174         753 : }
     175             : 
     176             : template <class T>
     177       71433 : void MatrixTimesMatrix<T>::performTask( std::size_t task_index,
     178             :                                         const MatrixTimesMatrixInput<T>& actiondata,
     179             :                                         ParallelActionsInput& input,
     180             :                                         ParallelActionsOutput& output ) {
     181       71433 :   auto arg0=ArgumentBookeepingHolder::create( 0, input );
     182       71433 :   auto arg1=ArgumentBookeepingHolder::create( 1, input );
     183       71433 :   std::size_t fpos = task_index*(1+arg0.ncols);
     184       71433 :   std::size_t nmult = arg0.bookeeping[fpos];
     185       71433 :   std::size_t vstart = task_index*arg0.ncols;
     186             :   InputVectors vectors( nmult, output.buffer.data() );
     187       71433 :   if( arg1.ncols<arg1.shape[1] ) {
     188         244 :     std::size_t fstart = task_index*(1+actiondata.outmat.ncols);
     189             :     std::size_t nelements = actiondata.outmat[fstart];
     190        2425 :     for(unsigned i=0; i<nelements; ++i) {
     191             :       std::size_t nm = 0;
     192      262316 :       for(unsigned j=0; j<nmult; ++j) {
     193      260135 :         std::size_t kind = arg0.bookeeping[fpos+1+j];
     194      260135 :         std::size_t bstart = kind*(arg1.ncols + 1);
     195      260135 :         std::size_t nr = arg1.bookeeping[bstart];
     196      539226 :         for(unsigned k=0; k<nr; ++k) {
     197      288695 :           if( arg1.bookeeping[bstart+1+k]==actiondata.outmat[fstart+1+i] ) {
     198        9604 :             nm++;
     199        9604 :             break;
     200             :           }
     201             :         }
     202             :       }
     203        2181 :       vectors.nelem = nm;
     204             :       nm = 0;
     205      262316 :       for(unsigned j=0; j<nmult; ++j) {
     206      260135 :         std::size_t kind = arg0.bookeeping[fpos+1+j];
     207      260135 :         std::size_t bstart = kind*(arg1.ncols + 1);
     208      260135 :         std::size_t nr = arg1.bookeeping[bstart];
     209      539226 :         for(unsigned k=0; k<nr; ++k) {
     210      288695 :           if( arg1.bookeeping[bstart+1+k]==actiondata.outmat[fstart+1+i] ) {
     211        9604 :             vectors.arg1[nm] = input.inputdata[ vstart + j ];
     212        9604 :             vectors.arg2[nm] = input.inputdata[ arg1.start + kind*arg1.ncols + k ];
     213        9604 :             nm++;
     214        9604 :             break;
     215             :           }
     216             :         }
     217             :       }
     218        2181 :       MatrixElementOutput elem( 1, 2*nmult, output.values.data() + i, output.derivatives.data() + 2*nmult*i );
     219        2181 :       T::calculate( input.noderiv, actiondata.funcinput, vectors, elem );
     220      252712 :       for(unsigned ii=vectors.nelem; ii<nmult; ++ii) {
     221      250531 :         elem.derivs[0][ii] = 0;
     222             :       }
     223             :     }
     224             :   } else {
     225             :     // Retrieve the row of the first matrix
     226     1796906 :     for(unsigned i=0; i<nmult; ++i) {
     227     1725717 :       vectors.arg1[i] = input.inputdata[ vstart + i ];
     228             :     }
     229             : 
     230             :     // Now do our multiplications
     231       71189 :     std::size_t fstart = task_index*(1+actiondata.outmat.ncols);
     232             :     std::size_t nelements = actiondata.outmat[fstart];
     233     5811368 :     for(unsigned i=0; i<nelements; ++i) {
     234     5740179 :       std::size_t base = arg1.start + actiondata.outmat[fstart+1+i];
     235   139900314 :       for(unsigned j=0; j<nmult; ++j) {
     236   134160135 :         vectors.arg2[j] = input.inputdata[ base + arg1.ncols*arg0.bookeeping[fpos+1+j] ];
     237             :       }
     238     5740179 :       MatrixElementOutput elem( 1, 2*nmult, output.values.data() + i, output.derivatives.data() + 2*nmult*i );
     239     5740179 :       T::calculate( input.noderiv, actiondata.funcinput, vectors, elem );
     240             :     }
     241             :   }
     242       71433 : }
     243             : 
     244             : template <class T>
     245         691 : void MatrixTimesMatrix<T>::applyNonZeroRankForces( std::vector<double>& outforces ) {
     246         691 :   taskmanager.applyForces( outforces );
     247         691 : }
     248             : 
     249             : template <class T>
     250        3404 : int MatrixTimesMatrix<T>::getNumberOfValuesPerTask( std::size_t task_index,
     251             :     const MatrixTimesMatrixInput<T>& actiondata ) {
     252        3404 :   std::size_t fstart = task_index*(1+actiondata.outmat.ncols);
     253        3404 :   return actiondata.outmat[fstart];
     254             : }
     255             : 
     256             : template <class T>
     257      124671 : void MatrixTimesMatrix<T>::getForceIndices( std::size_t task_index,
     258             :     std::size_t colno,
     259             :     std::size_t ntotal_force,
     260             :     const MatrixTimesMatrixInput<T>& actiondata,
     261             :     const ParallelActionsInput& input,
     262             :     ForceIndexHolder force_indices ) {
     263      124671 :   auto arg0=ArgumentBookeepingHolder::create( 0, input );
     264      124671 :   auto arg1=ArgumentBookeepingHolder::create( 1, input );
     265      124671 :   std::size_t fpos = task_index*(1+arg0.ncols);
     266      124671 :   std::size_t nmult = arg0.bookeeping[fpos];
     267      124671 :   std::size_t fstart = task_index*(1+actiondata.outmat.ncols);
     268      124671 :   if( arg1.ncols<arg1.shape[1] ) {
     269             :     std::size_t nmult_r = 0;
     270        5831 :     for(unsigned j=0; j<nmult; ++j) {
     271        4998 :       std::size_t kind = arg0.bookeeping[fpos+1+j];
     272        4998 :       std::size_t bstart = kind*(arg1.ncols + 1);
     273        4998 :       std::size_t nr = arg1.bookeeping[bstart];
     274       19992 :       for(unsigned k=0; k<nr; ++k) {
     275       19278 :         if( arg1.bookeeping[bstart+1+k]==actiondata.outmat[fstart+1+colno] ) {
     276        4284 :           nmult_r++;
     277        4284 :           break;
     278             :         }
     279             :       }
     280             :     }
     281             :     std::size_t n = 0;
     282        5831 :     for(unsigned j=0; j<nmult; ++j) {
     283        4998 :       std::size_t kind = arg0.bookeeping[fpos+1+j];
     284        4998 :       std::size_t bstart = kind*(arg1.ncols + 1);
     285        4998 :       std::size_t nr = arg1.bookeeping[bstart];
     286       19992 :       for(unsigned k=0; k<nr; ++k) {
     287       19278 :         if( arg1.bookeeping[bstart+1+k]==actiondata.outmat[fstart+1+colno] ) {
     288        4284 :           force_indices.indices[0][n] = task_index*arg0.ncols + j;
     289        4284 :           force_indices.indices[0][nmult+n] = arg1.start + arg0.bookeeping[fpos+1+j]*arg1.ncols + k;
     290        4284 :           n++;
     291        4284 :           break;
     292             :         }
     293             :       }
     294             :     }
     295         833 :     force_indices.threadsafe_derivatives_end[0] = nmult_r;
     296         833 :     force_indices.tot_indices[0] = nmult + nmult_r;
     297             :   } else {
     298      939230 :     for(unsigned j=0; j<nmult; ++j) {
     299      815392 :       force_indices.indices[0][j] = task_index*arg0.ncols + j;
     300      815392 :       force_indices.indices[0][nmult+j] = arg1.start + arg0.bookeeping[fpos+1+j]*arg1.ncols + actiondata.outmat[fstart+1+colno];
     301             :     }
     302      123838 :     force_indices.threadsafe_derivatives_end[0] = nmult;
     303      123838 :     force_indices.tot_indices[0] = 2*nmult;
     304             :   }
     305      124671 : }
     306             : 
     307             : }
     308             : }
     309             : #endif

Generated by: LCOV version 1.16