Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #ifndef __PLUMED_matrixtools_MatrixTimesVectorBase_h
23 : #define __PLUMED_matrixtools_MatrixTimesVectorBase_h
24 :
25 : #include "core/ActionWithVector.h"
26 : #include "core/ParallelTaskManager.h"
27 :
28 : namespace PLMD {
29 : namespace matrixtools {
30 :
31 1460 : class MatrixTimesVectorData {
32 : public:
33 : std::size_t fshift;
34 : Matrix<std::size_t> pairs;
35 : #ifdef __PLUMED_USE_OPENACC
36 : void toACCDevice() const {
37 : #pragma acc enter data copyin(this[0:1],fshift)
38 : pairs.toACCDevice();
39 : }
40 : void removeFromACCDevice() const {
41 : pairs.removeFromACCDevice();
42 : #pragma acc exit data delete(fshift,this[0:1])
43 : }
44 : #endif //__PLUMED_USE_OPENACC
45 : };
46 :
47 : class MatrixForceIndexInput {
48 : public:
49 : std::size_t rowlen;
50 : View<const std::size_t> indices;
51 1820421 : MatrixForceIndexInput( std::size_t task_index,
52 : std::size_t ipair,
53 : const MatrixTimesVectorData& actiondata,
54 1820421 : const ParallelActionsInput& input ):
55 1820421 : rowlen(input.bookeeping[input.bookstarts[actiondata.pairs[ipair][0]]
56 1820421 : + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index]),
57 1820421 : indices(input.bookeeping + input.bookstarts[actiondata.pairs[ipair][0]]
58 1820421 : + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index + 1,
59 1820421 : rowlen) {}
60 : };
61 :
62 : class MatrixTimesVectorInput {
63 : public:
64 : bool noderiv;
65 : std::size_t rowlen;
66 : View<const std::size_t> indices;
67 : View<const double> matrow;
68 : View<const double> vector;
69 10144979 : MatrixTimesVectorInput( std::size_t task_index,
70 : std::size_t ipair,
71 : const MatrixTimesVectorData& actiondata,
72 : const ParallelActionsInput& input,
73 10144979 : double* argdata ):
74 10144979 : noderiv(input.noderiv),
75 10144979 : rowlen(input.bookeeping[input.bookstarts[actiondata.pairs[ipair][0]]
76 10144979 : + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index]),
77 10144979 : indices(input.bookeeping + input.bookstarts[actiondata.pairs[ipair][0]]
78 10144979 : + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index + 1,rowlen),
79 10144979 : matrow(argdata + input.argstarts[actiondata.pairs[ipair][0]]
80 10144979 : + task_index*input.ncols[actiondata.pairs[ipair][0]],rowlen),
81 10144979 : vector(argdata + input.argstarts[actiondata.pairs[ipair][1]], input.shapedata[1]) {
82 10144979 : }
83 : };
84 :
85 : class MatrixTimesVectorOutput {
86 : public:
87 : std::size_t rowlen;
88 : View<double,1> values;
89 : View<double> matrow_deriv;
90 : View<double> vector_deriv;
91 10144979 : MatrixTimesVectorOutput( std::size_t task_index,
92 : std::size_t ipair,
93 : std::size_t nder,
94 : const MatrixTimesVectorData& actiondata,
95 : const ParallelActionsInput& input,
96 10144979 : ParallelActionsOutput& output ):
97 10144979 : rowlen(input.bookeeping[input.bookstarts[actiondata.pairs[ipair][0]]
98 10144979 : + (1+input.ncols[actiondata.pairs[ipair][0]])*task_index]),
99 10144979 : values(output.values.data()+ipair),
100 10144979 : matrow_deriv(output.derivatives.data()+ipair*nder,rowlen),
101 10144979 : vector_deriv(output.derivatives.data()+ipair*nder+rowlen,rowlen) {
102 10144979 : }
103 : };
104 :
105 : template <class T>
106 : class MatrixTimesVectorBase : public ActionWithVector {
107 : public:
108 : using input_type = MatrixTimesVectorData;
109 : using PTM = ParallelTaskManager<MatrixTimesVectorBase<T>>;
110 : private:
111 : /// The parallel task manager
112 : PTM taskmanager;
113 : public:
114 : static void registerKeywords( Keywords& keys );
115 : static void registerLocalKeywords( Keywords& keys );
116 : explicit MatrixTimesVectorBase(const ActionOptions&);
117 : std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
118 : unsigned getNumberOfDerivatives() override ;
119 : void prepare() override ;
120 : void calculate() override ;
121 : void applyNonZeroRankForces( std::vector<double>& outforces ) override ;
122 : int checkTaskIsActive( const unsigned& itask ) const override ;
123 : /// Override this so we write the graph properly
124 10 : std::string writeInGraph() const override {
125 10 : return "MATRIX_VECTOR_PRODUCT";
126 : }
127 : static void performTask( std::size_t task_index,
128 : const MatrixTimesVectorData& actiondata,
129 : ParallelActionsInput& input,
130 : ParallelActionsOutput& output );
131 : static int getNumberOfValuesPerTask( std::size_t task_index,
132 : const MatrixTimesVectorData& actiondata );
133 : static void getForceIndices( std::size_t task_index,
134 : std::size_t colno,
135 : std::size_t ntotal_force,
136 : const MatrixTimesVectorData& actiondata,
137 : const ParallelActionsInput& input,
138 : ForceIndexHolder force_indices );
139 : };
140 :
141 : template <class T>
142 738 : void MatrixTimesVectorBase<T>::registerKeywords( Keywords& keys ) {
143 738 : ActionWithVector::registerKeywords(keys);
144 738 : keys.setDisplayName("MATRIX_VECTOR_PRODUCT");
145 738 : registerLocalKeywords( keys );
146 738 : ActionWithValue::useCustomisableComponents(keys);
147 738 : keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
148 738 : }
149 :
150 : template <class T>
151 1384 : void MatrixTimesVectorBase<T>::registerLocalKeywords( Keywords& keys ) {
152 1384 : PTM::registerKeywords( keys );
153 2768 : keys.addInputKeyword("compulsory","ARG","matrix/vector/scalar","the label for the matrix and the vector/scalar that are being multiplied. Alternatively, you can provide labels for multiple matrices and a single vector or labels for a single matrix and multiple vectors. In these cases multiple matrix vector products will be computed.");
154 1384 : keys.add("hidden","MASKED_INPUT_ALLOWED","turns on that you are allowed to use masked inputs ");
155 2768 : keys.setValueDescription("vector","the vector that is obtained by taking the product between the matrix and the vector that were input");
156 1384 : ActionWithValue::useCustomisableComponents(keys);
157 1384 : }
158 :
159 : template <class T>
160 6 : std::string MatrixTimesVectorBase<T>::getOutputComponentDescription( const std::string& cname,
161 : const Keywords& keys ) const {
162 6 : if( getPntrToArgument(1)->getRank()==1 ) {
163 0 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
164 0 : if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
165 0 : return "the product of the matrix " + getPntrToArgument(0)->getName() + " and the vector " + getPntrToArgument(i)->getName();
166 : }
167 : }
168 : }
169 21 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
170 21 : if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
171 12 : return "the product of the matrix " + getPntrToArgument(i)->getName() + " and the vector " + getPntrToArgument(getNumberOfArguments()-1)->getName();
172 : }
173 : }
174 0 : plumed_merror( "could not understand request for component " + cname );
175 : return "";
176 : }
177 :
178 : template <class T>
179 365 : MatrixTimesVectorBase<T>::MatrixTimesVectorBase(const ActionOptions&ao):
180 : Action(ao),
181 : ActionWithVector(ao),
182 365 : taskmanager(this) {
183 365 : if( getNumberOfArguments()<2 ) {
184 0 : error("Not enough arguments specified");
185 : }
186 : bool vectormask=false, derivbool = true;
187 1917 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
188 1552 : if( getPntrToArgument(i)->hasDerivatives() ) {
189 0 : error("arguments should be vectors or matrices");
190 : }
191 1552 : if( getPntrToArgument(i)->getRank()<=1 ) {
192 550 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(i)->getPntrToAction() );
193 550 : if( av && av->getNumberOfMasks()>=0 ) {
194 : vectormask=true;
195 : }
196 : }
197 1552 : if( !getPntrToArgument(i)->isDerivativeZeroWhenValueIsZero() ) {
198 : derivbool = false;
199 : }
200 : }
201 365 : if( !vectormask ) {
202 365 : ignoreMaskArguments();
203 : }
204 :
205 365 : std::vector<std::size_t> shape(1);
206 365 : shape[0]=getPntrToArgument(0)->getShape()[0];
207 365 : if( getNumberOfArguments()==2 ) {
208 313 : addValue( shape );
209 313 : setNotPeriodic();
210 313 : if( derivbool ) {
211 233 : getPntrToComponent(0)->setDerivativeIsZeroWhenValueIsZero();
212 : }
213 : } else {
214 : unsigned namestart=1, nameend=getNumberOfArguments();
215 52 : if( getPntrToArgument(1)->getRank()==2 ) {
216 : namestart = 0;
217 43 : nameend = getNumberOfArguments()-1;
218 : }
219 :
220 926 : for(unsigned i=namestart; i<nameend; ++i) {
221 874 : std::string name = getPntrToArgument(i)->getName();
222 874 : if( name.find_first_of(".")!=std::string::npos ) {
223 680 : std::size_t dot=name.find_first_of(".");
224 1360 : name = name.substr(dot+1);
225 : }
226 874 : addComponent( name, shape );
227 874 : componentIsNotPeriodic( name );
228 874 : if( derivbool ) {
229 1740 : copyOutput( getLabel() + "." + name )->setDerivativeIsZeroWhenValueIsZero();
230 : }
231 : }
232 : }
233 : // This sets up an array in the parallel task manager to hold all the indices
234 : // Sets up the index list in the task manager
235 365 : std::size_t nder = getPntrToArgument(getNumberOfArguments()-1)->getNumberOfStoredValues();
236 : MatrixTimesVectorData input;
237 365 : input.pairs.resize( getNumberOfArguments()-1, 2 );
238 365 : if( getPntrToArgument(1)->getRank()==2 ) {
239 723 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
240 680 : input.pairs[i-1][0] = i-1;
241 680 : input.pairs[i-1][1] = getNumberOfArguments()-1;
242 : }
243 43 : input.fshift=0;
244 : } else {
245 829 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
246 507 : input.pairs[i-1][0] = 0;
247 507 : input.pairs[i-1][1] = i;
248 : }
249 322 : input.fshift=nder;
250 : }
251 365 : taskmanager.setActionInput( input );
252 365 : }
253 :
254 : template <class T>
255 28652 : unsigned MatrixTimesVectorBase<T>::getNumberOfDerivatives() {
256 : unsigned nderivatives=0;
257 695284 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
258 666632 : nderivatives += getPntrToArgument(i)->getNumberOfStoredValues();
259 : }
260 28652 : return nderivatives;
261 : }
262 :
263 : template <class T>
264 13715 : void MatrixTimesVectorBase<T>::prepare() {
265 13715 : ActionWithVector::prepare();
266 13715 : Value* myval = getPntrToComponent(0);
267 13715 : if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] ) {
268 13705 : return;
269 : }
270 10 : std::vector<std::size_t> shape(1);
271 10 : shape[0] = getPntrToArgument(0)->getShape()[0];
272 10 : myval->setShape(shape);
273 : }
274 :
275 : template <class T>
276 13708 : void MatrixTimesVectorBase<T>::calculate() {
277 13708 : std::size_t nvectors, nder = getPntrToArgument(getNumberOfArguments()-1)->getNumberOfStoredValues();
278 13708 : if( getPntrToArgument(1)->getRank()==2 ) {
279 : nvectors = 1;
280 : } else {
281 13585 : nvectors = getNumberOfArguments()-1;
282 : }
283 13708 : if( getName()=="MATRIX_VECTOR_PRODUCT_ROWSUMS" ) {
284 11317 : taskmanager.setupParallelTaskManager( nder, 0 );
285 : } else {
286 2391 : taskmanager.setupParallelTaskManager( 2*nder, nvectors*nder );
287 : }
288 13708 : taskmanager.runAllTasks();
289 13708 : }
290 :
291 : template <class T>
292 1824672 : int MatrixTimesVectorBase<T>::checkTaskIsActive( const unsigned& itask ) const {
293 2516128 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
294 : Value* myarg = getPntrToArgument(i);
295 2516128 : if( myarg->getRank()==1 && !myarg->hasDerivatives() ) {
296 : return 0;
297 1824673 : } else if( myarg->getRank()==2 && !myarg->hasDerivatives() ) {
298 1824673 : if (myarg->checkValueIsActiveForMMul(itask)) {
299 : return 1;
300 : }
301 : } else {
302 0 : plumed_merror("should not be in action " + getName() );
303 : }
304 : }
305 : return 0;
306 : }
307 :
308 : template <class T>
309 1513469 : void MatrixTimesVectorBase<T>::performTask( std::size_t task_index,
310 : const MatrixTimesVectorData& actiondata,
311 : ParallelActionsInput& input,
312 : ParallelActionsOutput& output ) {
313 11658448 : for(unsigned i=0; i<actiondata.pairs.nrows(); ++i) {
314 10144979 : MatrixTimesVectorOutput doutput( task_index,
315 : i,
316 : input.nderivatives_per_scalar,
317 : actiondata,
318 : input,
319 : output );
320 10144979 : T::performTask( MatrixTimesVectorInput( task_index,
321 : i,
322 : actiondata,
323 : input,
324 : input.inputdata ),
325 : doutput );
326 : }
327 1513469 : }
328 :
329 : template <class T>
330 11890 : void MatrixTimesVectorBase<T>::applyNonZeroRankForces( std::vector<double>& outforces ) {
331 11890 : taskmanager.applyForces( outforces );
332 11890 : }
333 :
334 : template <class T>
335 443916 : int MatrixTimesVectorBase<T>::getNumberOfValuesPerTask( std::size_t task_index,
336 : const MatrixTimesVectorData& actiondata ) {
337 443916 : return 1;
338 : }
339 :
340 : template <class T>
341 443916 : void MatrixTimesVectorBase<T>::getForceIndices( std::size_t task_index,
342 : std::size_t colno,
343 : std::size_t ntotal_force,
344 : const MatrixTimesVectorData& actiondata,
345 : const ParallelActionsInput& input,
346 : ForceIndexHolder force_indices ) {
347 2264337 : for(unsigned i=0; i<actiondata.pairs.nrows(); ++i) {
348 1820421 : std::size_t base = input.argstarts[actiondata.pairs[i][0]]
349 1820421 : + task_index*input.ncols[actiondata.pairs[i][0]];
350 1820421 : std::size_t n = input.bookeeping[input.bookstarts[actiondata.pairs[i][0]]
351 1820421 : + (1+input.ncols[actiondata.pairs[i][0]])*task_index];
352 76718715 : for(unsigned j=0; j<n; ++j) {
353 74898294 : force_indices.indices[i][j] = base + j;
354 : }
355 1820421 : force_indices.threadsafe_derivatives_end[i] = n;
356 1820421 : force_indices.tot_indices[i] = T::getAdditionalIndices( n,
357 124096 : input.argstarts[actiondata.pairs[i][1]],
358 1944517 : MatrixForceIndexInput( task_index, i, actiondata, input ),
359 : force_indices.indices[i] );
360 : }
361 443916 : }
362 :
363 : }
364 : }
365 : #endif
|