Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2011-2023 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #ifndef __PLUMED_matrixtools_OuterProduct_h 23 : #define __PLUMED_matrixtools_OuterProduct_h 24 : 25 : #include "core/ActionWithMatrix.h" 26 : #include "core/ParallelTaskManager.h" 27 : 28 : namespace PLMD { 29 : namespace matrixtools { 30 : 31 : template <class T> 32 8 : class OuterProductInput { 33 : public: 34 : T funcinput; 35 : RequiredMatrixElements outmat; 36 : }; 37 : 38 : template <class T> 39 : class OuterProductBase : public ActionWithMatrix { 40 : public: 41 : using input_type = OuterProductInput<T>; 42 : using PTM = ParallelTaskManager<OuterProductBase<T>>; 43 : private: 44 : bool isproduct; 45 : PTM taskmanager; 46 : public: 47 : static void registerKeywords( Keywords& keys ); 48 : explicit OuterProductBase(const ActionOptions&); 49 : unsigned getNumberOfDerivatives() override; 50 : void prepare() override ; 51 : int checkTaskIsActive( const unsigned& itask ) const override ; 52 : void calculate() override ; 53 : void applyNonZeroRankForces( std::vector<double>& outforces ) override ; 54 : static void performTask( std::size_t task_index, 55 : const OuterProductInput<T>& actiondata, 56 : ParallelActionsInput& input, 57 : ParallelActionsOutput& output ); 58 : static int getNumberOfValuesPerTask( std::size_t task_index, 59 : const OuterProductInput<T>& actiondata ); 60 : static void getForceIndices( std::size_t task_index, 61 : std::size_t colno, 62 : std::size_t ntotal_force, 63 : const OuterProductInput<T>& actiondata, 64 : const ParallelActionsInput& input, 65 : ForceIndexHolder force_indices ); 66 : }; 67 : 68 : template <class T> 69 177 : void OuterProductBase<T>::registerKeywords( Keywords& keys ) { 70 177 : ActionWithMatrix::registerKeywords(keys); 71 177 : T::registerKeywords( keys ); 72 177 : PTM::registerKeywords( keys ); 73 177 : keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log"); 74 177 : } 75 : 76 : template <class T> 77 87 : OuterProductBase<T>::OuterProductBase(const ActionOptions&ao): 78 : Action(ao), 79 : ActionWithMatrix(ao), 80 87 : isproduct(false), 81 87 : taskmanager(this) { 82 : unsigned nargs=getNumberOfArguments(); 83 : if( getNumberOfMasks()>0 ) { 84 21 : nargs = nargs - getNumberOfMasks(); 85 : } 86 87 : if( nargs%2!=0 ) { 87 0 : error("should be an even number of arguments to this action, they should all be vectors"); 88 : } 89 87 : std::size_t nvals = nargs / 2; 90 192 : for(unsigned i=0; i<nvals; ++i) { 91 105 : if( getPntrToArgument(i)->getRank()!=1 || getPntrToArgument(i)->hasDerivatives() ) { 92 0 : error("first argument to this action should be a vector"); 93 : } 94 105 : if( getPntrToArgument(0)->getShape()[0]!=getPntrToArgument(i)->getShape()[0] ) { 95 0 : error("mismatch between sizes of input vectors"); 96 : } 97 105 : if( getPntrToArgument(nvals+i)->getRank()!=1 || getPntrToArgument(nvals+i)->hasDerivatives() ) { 98 0 : error("first argument to this action should be a vector"); 99 : } 100 : if( getPntrToArgument(nvals)->getShape()[0]!=getPntrToArgument(nvals+i)->getShape()[0] ) { 101 0 : error("mismatch between sizes of input vectors"); 102 : } 103 : } 104 87 : if( getNumberOfMasks()==1 ) { 105 21 : if( getPntrToArgument(nargs)->getRank()!=2 || getPntrToArgument(nargs)->hasDerivatives() ) { 106 0 : error("mask argument should be a matrix"); 107 : } 108 54 : for(unsigned i=0; i<nvals; ++i) { 109 33 : if( getPntrToArgument(nargs)->getShape()[0]!=getPntrToArgument(i)->getShape()[0] ) { 110 0 : error("mask argument has wrong size"); 111 : } 112 33 : if( getPntrToArgument(nargs)->getShape()[1]!=getPntrToArgument(nvals+i)->getShape()[0] ) { 113 0 : error("mask argument has wrong size"); 114 : } 115 : } 116 : } 117 : 118 87 : std::vector<std::size_t> shape(2); 119 87 : shape[0]=getPntrToArgument(0)->getShape()[0]; 120 87 : shape[1]=getPntrToArgument(nvals)->getShape()[0]; 121 : 122 : std::string func; 123 87 : if( keywords.exists("FUNC") ) { 124 162 : parse("FUNC",func); 125 81 : isproduct=(func=="x*y"); 126 : } 127 79 : OuterProductInput<T> actiondata; 128 87 : actiondata.funcinput.setup( shape, func, this ); 129 : 130 87 : if( getNumberOfComponents()==0 ) { 131 81 : addValue( shape ); 132 81 : setNotPeriodic(); 133 : } 134 87 : if( getPntrToArgument(0)->isDerivativeZeroWhenValueIsZero() || getPntrToArgument(nvals)->isDerivativeZeroWhenValueIsZero() ) { 135 114 : for(unsigned i=0; i<getNumberOfComponents(); ++i) { 136 57 : getPntrToComponent(i)->setDerivativeIsZeroWhenValueIsZero(); 137 : } 138 : } 139 87 : taskmanager.setActionInput( actiondata ); 140 166 : } 141 : 142 : template <class T> 143 178 : unsigned OuterProductBase<T>::getNumberOfDerivatives() { 144 178 : unsigned nc = getNumberOfComponents(); 145 178 : return nc*(getPntrToArgument(0)->getNumberOfStoredValues() + getPntrToArgument(nc)->getNumberOfStoredValues()); 146 : } 147 : 148 : template <class T> 149 201 : void OuterProductBase<T>::prepare() { 150 201 : ActionWithVector::prepare(); 151 201 : Value* myval=getPntrToComponent(0); 152 201 : if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] && myval->getShape()[1]==getPntrToArgument(getNumberOfComponents())->getShape()[0] ) { 153 184 : return; 154 : } 155 17 : std::vector<std::size_t> shape(2); 156 17 : shape[0] = getPntrToArgument(0)->getShape()[0]; 157 17 : shape[1] = getPntrToArgument(getNumberOfComponents())->getShape()[0]; 158 34 : for(unsigned i=0; i<getNumberOfComponents(); ++i) { 159 17 : getPntrToComponent(i)->setShape( shape ); 160 : } 161 : } 162 : 163 : template <class T> 164 366938 : int OuterProductBase<T>::checkTaskIsActive( const unsigned& itask ) const { 165 366938 : if( getNumberOfMasks()>0 || !isproduct ) { 166 108212 : return ActionWithVector::checkTaskIsActive( itask ); 167 : } 168 258726 : if( fabs( getPntrToArgument(0)->get(itask))>epsilon ) { 169 172473 : return 1; 170 : } 171 : return -1; 172 : } 173 : 174 : template <class T> 175 199 : void OuterProductBase<T>::calculate() { 176 199 : updateBookeepingArrays( taskmanager.getActionInput().outmat ); 177 199 : taskmanager.setupParallelTaskManager( 2*getNumberOfComponents(), getNumberOfComponents()*getPntrToComponent(0)->getShape()[1] ); 178 199 : taskmanager.setWorkspaceSize( 2*getNumberOfComponents() ); 179 199 : taskmanager.runAllTasks(); 180 199 : } 181 : 182 : template <class T> 183 185110 : void OuterProductBase<T>::performTask( std::size_t task_index, 184 : const OuterProductInput<T>& actiondata, 185 : ParallelActionsInput& input, 186 : ParallelActionsOutput& output ) { 187 185110 : auto args = output.buffer.subview(0, 2*input.ncomponents); 188 381776 : for(unsigned i=0; i<input.ncomponents; ++i) { 189 196666 : args[i] = input.inputdata[input.argstarts[i] + task_index]; 190 : } 191 185110 : unsigned fstart = task_index*(1+actiondata.outmat.ncols); 192 185110 : unsigned nelements = actiondata.outmat[fstart]; 193 11628321 : for(unsigned i=0; i<nelements; ++i) { 194 11443211 : std::size_t argpos = actiondata.outmat[fstart+1+i]; 195 24777601 : for(unsigned j=0; j<input.ncomponents; ++j) { 196 13334390 : args[input.ncomponents+j] = input.inputdata[input.argstarts[input.ncomponents+j] + argpos]; 197 : } 198 11443211 : MatrixElementOutput matout( input.ncomponents, 199 : 2*input.ncomponents, 200 11443211 : output.values.data()+i*input.ncomponents, 201 11443211 : output.derivatives.data() + 2*i*input.ncomponents*input.ncomponents ); 202 11443211 : T::calculate( input.noderiv, actiondata.funcinput, {args.data(),args.size()}, matout ); 203 : } 204 185110 : } 205 : 206 : template <class T> 207 51 : void OuterProductBase<T>::applyNonZeroRankForces( std::vector<double>& outforces ) { 208 51 : taskmanager.applyForces( outforces ); 209 51 : } 210 : 211 : template <class T> 212 8831 : int OuterProductBase<T>::getNumberOfValuesPerTask( std::size_t task_index, 213 : const OuterProductInput<T>& actiondata ) { 214 8831 : unsigned fstart = task_index*(1+actiondata.outmat.ncols); 215 8831 : return actiondata.outmat[fstart]; 216 : } 217 : 218 : template <class T> 219 573184 : void OuterProductBase<T>::getForceIndices( std::size_t task_index, 220 : std::size_t colno, 221 : std::size_t ntotal_force, 222 : const OuterProductInput<T>& actiondata, 223 : const ParallelActionsInput& input, 224 : ForceIndexHolder force_indices ) { 225 573184 : unsigned fstart = task_index*(1+actiondata.outmat.ncols); 226 1807256 : for(unsigned j=0; j<input.ncomponents; ++j) { 227 5111696 : for(unsigned k=0; k<input.ncomponents; ++k) { 228 3877624 : force_indices.indices[j][k] = input.argstarts[k] + task_index; 229 3877624 : force_indices.indices[j][input.ncomponents+k] = input.argstarts[input.ncomponents+k] + actiondata.outmat[fstart+1+colno]; 230 : } 231 1234072 : force_indices.threadsafe_derivatives_end[j] = input.ncomponents; 232 1234072 : force_indices.tot_indices[j] = 2*input.ncomponents; 233 : } 234 573184 : } 235 : 236 : } 237 : } 238 : #endif