LCOV - code coverage report
Current view: top level - matrixtools - OuterProduct.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 92 100 92.0 %
Date: 2025-12-04 11:19:34 Functions: 27 40 67.5 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_matrixtools_OuterProduct_h
      23             : #define __PLUMED_matrixtools_OuterProduct_h
      24             : 
      25             : #include "core/ActionWithMatrix.h"
      26             : #include "core/ParallelTaskManager.h"
      27             : 
      28             : namespace PLMD {
      29             : namespace matrixtools {
      30             : 
      31             : template <class T>
      32           8 : class OuterProductInput {
      33             : public:
      34             :   T funcinput;
      35             :   RequiredMatrixElements outmat;
      36             : };
      37             : 
      38             : template <class T>
      39             : class OuterProductBase : public ActionWithMatrix {
      40             : public:
      41             :   using input_type = OuterProductInput<T>;
      42             :   using PTM = ParallelTaskManager<OuterProductBase<T>>;
      43             : private:
      44             :   bool isproduct;
      45             :   PTM taskmanager;
      46             : public:
      47             :   static void registerKeywords( Keywords& keys );
      48             :   explicit OuterProductBase(const ActionOptions&);
      49             :   unsigned getNumberOfDerivatives() override;
      50             :   void prepare() override ;
      51             :   int checkTaskIsActive( const unsigned& itask ) const override ;
      52             :   void calculate() override ;
      53             :   void applyNonZeroRankForces( std::vector<double>& outforces ) override ;
      54             :   static void performTask( std::size_t task_index,
      55             :                            const OuterProductInput<T>& actiondata,
      56             :                            ParallelActionsInput& input,
      57             :                            ParallelActionsOutput& output );
      58             :   static int getNumberOfValuesPerTask( std::size_t task_index,
      59             :                                        const OuterProductInput<T>& actiondata );
      60             :   static void getForceIndices( std::size_t task_index,
      61             :                                std::size_t colno,
      62             :                                std::size_t ntotal_force,
      63             :                                const OuterProductInput<T>& actiondata,
      64             :                                const ParallelActionsInput& input,
      65             :                                ForceIndexHolder force_indices );
      66             : };
      67             : 
      68             : template <class T>
      69         177 : void OuterProductBase<T>::registerKeywords( Keywords& keys ) {
      70         177 :   ActionWithMatrix::registerKeywords(keys);
      71         177 :   T::registerKeywords( keys );
      72         177 :   PTM::registerKeywords( keys );
      73         177 :   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
      74         177 : }
      75             : 
      76             : template <class T>
      77          87 : OuterProductBase<T>::OuterProductBase(const ActionOptions&ao):
      78             :   Action(ao),
      79             :   ActionWithMatrix(ao),
      80          87 :   isproduct(false),
      81          87 :   taskmanager(this) {
      82             :   unsigned nargs=getNumberOfArguments();
      83             :   if( getNumberOfMasks()>0 ) {
      84          21 :     nargs = nargs - getNumberOfMasks();
      85             :   }
      86          87 :   if( nargs%2!=0 ) {
      87           0 :     error("should be an even number of arguments to this action, they should all be vectors");
      88             :   }
      89          87 :   std::size_t nvals = nargs / 2;
      90         192 :   for(unsigned i=0; i<nvals; ++i) {
      91         105 :     if( getPntrToArgument(i)->getRank()!=1 || getPntrToArgument(i)->hasDerivatives() ) {
      92           0 :       error("first argument to this action should be a vector");
      93             :     }
      94         105 :     if( getPntrToArgument(0)->getShape()[0]!=getPntrToArgument(i)->getShape()[0] ) {
      95           0 :       error("mismatch between sizes of input vectors");
      96             :     }
      97         105 :     if( getPntrToArgument(nvals+i)->getRank()!=1 || getPntrToArgument(nvals+i)->hasDerivatives() ) {
      98           0 :       error("first argument to this action should be a vector");
      99             :     }
     100             :     if( getPntrToArgument(nvals)->getShape()[0]!=getPntrToArgument(nvals+i)->getShape()[0] ) {
     101           0 :       error("mismatch between sizes of input vectors");
     102             :     }
     103             :   }
     104          87 :   if( getNumberOfMasks()==1 ) {
     105          21 :     if( getPntrToArgument(nargs)->getRank()!=2 || getPntrToArgument(nargs)->hasDerivatives() ) {
     106           0 :       error("mask argument should be a matrix");
     107             :     }
     108          54 :     for(unsigned i=0; i<nvals; ++i) {
     109          33 :       if( getPntrToArgument(nargs)->getShape()[0]!=getPntrToArgument(i)->getShape()[0] ) {
     110           0 :         error("mask argument has wrong size");
     111             :       }
     112          33 :       if( getPntrToArgument(nargs)->getShape()[1]!=getPntrToArgument(nvals+i)->getShape()[0] ) {
     113           0 :         error("mask argument has wrong size");
     114             :       }
     115             :     }
     116             :   }
     117             : 
     118          87 :   std::vector<std::size_t> shape(2);
     119          87 :   shape[0]=getPntrToArgument(0)->getShape()[0];
     120          87 :   shape[1]=getPntrToArgument(nvals)->getShape()[0];
     121             : 
     122             :   std::string func;
     123          87 :   if( keywords.exists("FUNC") ) {
     124         162 :     parse("FUNC",func);
     125          81 :     isproduct=(func=="x*y");
     126             :   }
     127          79 :   OuterProductInput<T> actiondata;
     128          87 :   actiondata.funcinput.setup( shape, func, this );
     129             : 
     130          87 :   if( getNumberOfComponents()==0 ) {
     131          81 :     addValue( shape );
     132          81 :     setNotPeriodic();
     133             :   }
     134          87 :   if( getPntrToArgument(0)->isDerivativeZeroWhenValueIsZero() || getPntrToArgument(nvals)->isDerivativeZeroWhenValueIsZero() ) {
     135         114 :     for(unsigned i=0; i<getNumberOfComponents(); ++i) {
     136          57 :       getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero();
     137             :     }
     138             :   }
     139          87 :   taskmanager.setActionInput( actiondata );
     140         166 : }
     141             : 
     142             : template <class T>
     143         178 : unsigned OuterProductBase<T>::getNumberOfDerivatives() {
     144         178 :   unsigned nc = getNumberOfComponents();
     145         178 :   return nc*(getPntrToArgument(0)->getNumberOfStoredValues() + getPntrToArgument(nc)->getNumberOfStoredValues());
     146             : }
     147             : 
     148             : template <class T>
     149         201 : void OuterProductBase<T>::prepare() {
     150         201 :   ActionWithVector::prepare();
     151         201 :   Value* myval=getPntrToComponent(0);
     152         201 :   if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] && myval->getShape()[1]==getPntrToArgument(getNumberOfComponents())->getShape()[0] ) {
     153         184 :     return;
     154             :   }
     155          17 :   std::vector<std::size_t> shape(2);
     156          17 :   shape[0] = getPntrToArgument(0)->getShape()[0];
     157          17 :   shape[1] = getPntrToArgument(getNumberOfComponents())->getShape()[0];
     158          34 :   for(unsigned i=0; i<getNumberOfComponents(); ++i) {
     159          17 :     getPntrToComponent(i)->setShape( shape );
     160             :   }
     161             : }
     162             : 
     163             : template <class T>
     164      366938 : int OuterProductBase<T>::checkTaskIsActive( const unsigned& itask ) const {
     165      366938 :   if( getNumberOfMasks()>0 || !isproduct ) {
     166      108212 :     return ActionWithVector::checkTaskIsActive( itask );
     167             :   }
     168      258726 :   if( fabs( getPntrToArgument(0)->get(itask))>epsilon ) {
     169      172473 :     return 1;
     170             :   }
     171             :   return -1;
     172             : }
     173             : 
     174             : template <class T>
     175         199 : void OuterProductBase<T>::calculate() {
     176         199 :   updateBookeepingArrays( taskmanager.getActionInput().outmat );
     177         199 :   taskmanager.setupParallelTaskManager( 2*getNumberOfComponents(), getNumberOfComponents()*getPntrToComponent(0)->getShape()[1] );
     178         199 :   taskmanager.setWorkspaceSize( 2*getNumberOfComponents() );
     179         199 :   taskmanager.runAllTasks();
     180         199 : }
     181             : 
     182             : template <class T>
     183      185110 : void OuterProductBase<T>::performTask( std::size_t task_index,
     184             :                                        const OuterProductInput<T>& actiondata,
     185             :                                        ParallelActionsInput& input,
     186             :                                        ParallelActionsOutput& output ) {
     187      185110 :   auto args = output.buffer.subview(0, 2*input.ncomponents);
     188      381776 :   for(unsigned i=0; i<input.ncomponents; ++i) {
     189      196666 :     args[i] = input.inputdata[input.argstarts[i] + task_index];
     190             :   }
     191      185110 :   unsigned fstart = task_index*(1+actiondata.outmat.ncols);
     192      185110 :   unsigned nelements = actiondata.outmat[fstart];
     193    11628321 :   for(unsigned i=0; i<nelements; ++i) {
     194    11443211 :     std::size_t argpos = actiondata.outmat[fstart+1+i];
     195    24777601 :     for(unsigned j=0; j<input.ncomponents; ++j) {
     196    13334390 :       args[input.ncomponents+j] = input.inputdata[input.argstarts[input.ncomponents+j] + argpos];
     197             :     }
     198    11443211 :     MatrixElementOutput matout( input.ncomponents,
     199             :                                 2*input.ncomponents,
     200    11443211 :                                 output.values.data()+i*input.ncomponents,
     201    11443211 :                                 output.derivatives.data() + 2*i*input.ncomponents*input.ncomponents );
     202    11443211 :     T::calculate( input.noderiv, actiondata.funcinput, {args.data(),args.size()}, matout );
     203             :   }
     204      185110 : }
     205             : 
     206             : template <class T>
     207          51 : void OuterProductBase<T>::applyNonZeroRankForces( std::vector<double>& outforces ) {
     208          51 :   taskmanager.applyForces( outforces );
     209          51 : }
     210             : 
     211             : template <class T>
     212        8831 : int OuterProductBase<T>::getNumberOfValuesPerTask( std::size_t task_index,
     213             :     const OuterProductInput<T>& actiondata ) {
     214        8831 :   unsigned fstart = task_index*(1+actiondata.outmat.ncols);
     215        8831 :   return actiondata.outmat[fstart];
     216             : }
     217             : 
     218             : template <class T>
     219      573184 : void OuterProductBase<T>::getForceIndices( std::size_t task_index,
     220             :     std::size_t colno,
     221             :     std::size_t ntotal_force,
     222             :     const OuterProductInput<T>& actiondata,
     223             :     const ParallelActionsInput& input,
     224             :     ForceIndexHolder force_indices ) {
     225      573184 :   unsigned fstart = task_index*(1+actiondata.outmat.ncols);
     226     1807256 :   for(unsigned j=0; j<input.ncomponents; ++j) {
     227     5111696 :     for(unsigned k=0; k<input.ncomponents; ++k) {
     228     3877624 :       force_indices.indices[j][k] = input.argstarts[k] + task_index;
     229     3877624 :       force_indices.indices[j][input.ncomponents+k] = input.argstarts[input.ncomponents+k] + actiondata.outmat[fstart+1+colno];
     230             :     }
     231     1234072 :     force_indices.threadsafe_derivatives_end[j] = input.ncomponents;
     232     1234072 :     force_indices.tot_indices[j] = 2*input.ncomponents;
     233             :   }
     234      573184 : }
     235             : 
     236             : }
     237             : }
     238             : #endif

Generated by: LCOV version 1.16