Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #ifndef __PLUMED_matrixtools_MatrixTimesMatrix_h
23 : #define __PLUMED_matrixtools_MatrixTimesMatrix_h
24 :
25 : #include "core/ActionWithMatrix.h"
26 : #include "core/ParallelTaskManager.h"
27 :
28 : namespace PLMD {
29 : namespace matrixtools {
30 :
31 : template <class T>
32 58 : struct MatrixTimesMatrixInput {
33 : T funcinput;
34 : RequiredMatrixElements outmat;
35 : #ifdef __PLUMED_USE_OPENACC
36 : void toACCDevice() const {
37 : #pragma acc enter data copyin(this[0:1])
38 : funcinput.toACCDevice();
39 : outmat.toACCDevice();
40 : }
41 : void removeFromACCDevice() const {
42 : funcinput.removeFromACCDevice();
43 : outmat.removeFromACCDevice();
44 : #pragma acc exit data delete(this[0:1])
45 : }
46 : #endif //__PLUMED_USE_OPENACC
47 : };
48 :
49 : class InputVectors {
50 : public:
51 : std::size_t nelem;
52 : View<double> arg1;
53 : View<double> arg2;
54 71433 : InputVectors( std::size_t n, double* b ) : nelem(n), arg1(b,n), arg2(b+n,n) {}
55 : };
56 :
57 : template <class T>
58 : class MatrixTimesMatrix : public ActionWithMatrix {
59 : public:
60 : using input_type = MatrixTimesMatrixInput<T>;
61 : using PTM = ParallelTaskManager<MatrixTimesMatrix<T>>;
62 : private:
63 : PTM taskmanager;
64 : public:
65 : static void registerKeywords( Keywords& keys );
66 : explicit MatrixTimesMatrix(const ActionOptions&);
67 : void prepare() override ;
68 : unsigned getNumberOfDerivatives() override;
69 : void calculate() override ;
70 : void applyNonZeroRankForces( std::vector<double>& outforces ) override ;
71 : static void performTask( std::size_t task_index, const MatrixTimesMatrixInput<T>& actiondata, ParallelActionsInput& input, ParallelActionsOutput& output );
72 : static int getNumberOfValuesPerTask( std::size_t task_index, const MatrixTimesMatrixInput<T>& actiondata );
73 : static void getForceIndices( std::size_t task_index, std::size_t colno, std::size_t ntotal_force, const MatrixTimesMatrixInput<T>& actiondata, const ParallelActionsInput& input, ForceIndexHolder force_indices );
74 : };
75 :
76 : template <class T>
77 98 : void MatrixTimesMatrix<T>::registerKeywords( Keywords& keys ) {
78 98 : ActionWithMatrix::registerKeywords(keys);
79 196 : keys.addInputKeyword("optional","MASK","matrix","a matrix that is used to used to determine which elements of the output matrix to compute");
80 196 : keys.addInputKeyword("compulsory","ARG","matrix","the label of the two matrices from which the product is calculated");
81 196 : if( keys.getDisplayName()=="MATRIX_PRODUCT" ) {
82 77 : keys.addFlag("ELEMENTS_ON_DIAGONAL_ARE_ZERO",false,"set all diagonal elements to zero");
83 : }
84 98 : T::registerKeywords( keys );
85 98 : PTM::registerKeywords( keys );
86 98 : }
87 :
88 : template <class T>
89 58 : MatrixTimesMatrix<T>::MatrixTimesMatrix(const ActionOptions&ao):
90 : Action(ao),
91 : ActionWithMatrix(ao),
92 58 : taskmanager(this) {
93 : int nm=getNumberOfMasks();
94 : if( nm<0 ) {
95 : nm = 0;
96 : }
97 58 : if( getNumberOfArguments()-nm!=2 ) {
98 0 : error("should be two arguments to this action, a matrix and a vector");
99 : }
100 58 : if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) {
101 0 : error("first argument to this action should be a matrix");
102 : }
103 58 : if( getPntrToArgument(1)->getRank()!=2 || getPntrToArgument(1)->hasDerivatives() ) {
104 0 : error("second argument to this action should be a matrix");
105 : }
106 58 : if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(1)->getShape()[0] ) {
107 0 : error("number of columns in first matrix does not equal number of rows in second matrix");
108 : }
109 58 : std::vector<std::size_t> shape(2);
110 58 : shape[0]=getPntrToArgument(0)->getShape()[0];
111 58 : shape[1]=getPntrToArgument(1)->getShape()[1];
112 58 : addValue( shape );
113 58 : setNotPeriodic();
114 58 : getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
115 58 : if( getName()!="DISSIMILARITIES" && getPntrToArgument(0)->isDerivativeZeroWhenValueIsZero() && getPntrToArgument(1)->isDerivativeZeroWhenValueIsZero() ) {
116 6 : getPntrToComponent(0)->setDerivativeIsZeroWhenValueIsZero();
117 : }
118 :
119 58 : if( nm>0 ) {
120 14 : unsigned iarg = getNumberOfArguments()-1;
121 14 : if( getPntrToArgument(iarg)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) {
122 0 : error("argument passed to MASK keyword should be a matrix");
123 : }
124 14 : if( getPntrToArgument(iarg)->getShape()[0]!=shape[0] || getPntrToArgument(iarg)->getShape()[1]!=shape[1] ) {
125 0 : error("argument passed to MASK keyword has the wrong shape");
126 : }
127 : }
128 : MatrixTimesMatrixInput<T> actdata;
129 12 : actdata.funcinput.setup( this, getPntrToArgument(0) );
130 58 : taskmanager.setActionInput( actdata );
131 58 : }
132 :
133 : template <class T>
134 64 : unsigned MatrixTimesMatrix<T>::getNumberOfDerivatives() {
135 64 : return getPntrToArgument(0)->getNumberOfStoredValues() + getPntrToArgument(1)->getNumberOfStoredValues();
136 : }
137 :
138 : template <class T>
139 757 : void MatrixTimesMatrix<T>::prepare() {
140 757 : ActionWithVector::prepare();
141 757 : Value* myval = getPntrToComponent(0);
142 757 : if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] && myval->getShape()[1]==getPntrToArgument(1)->getShape()[1] ) {
143 740 : return;
144 : }
145 17 : std::vector<std::size_t> shape(2);
146 17 : shape[0]=getPntrToArgument(0)->getShape()[0];
147 17 : shape[1]=getPntrToArgument(1)->getShape()[1];
148 17 : myval->setShape(shape);
149 17 : myval->reshapeMatrixStore( shape[1] );
150 : }
151 :
152 : template <class T>
153 753 : void MatrixTimesMatrix<T>::calculate() {
154 753 : if( !getPntrToComponent(0)->isDerivativeZeroWhenValueIsZero() ) {
155 731 : if( getPntrToArgument(0)->getNumberOfColumns()<getPntrToArgument(0)->getShape()[1] ) {
156 0 : if( !doNotCalculateDerivatives() ) {
157 0 : error("cannot calculate derivatives for this action with sparse matrices");
158 0 : } else if( getName()=="DISSIMILARITIES" ) {
159 0 : error("cannot calculate dissimilarities for sparse matrices");
160 : }
161 : }
162 731 : if( getPntrToArgument(1)->getNumberOfColumns()<getPntrToArgument(1)->getShape()[1] ) {
163 6 : if( !doNotCalculateDerivatives() ) {
164 0 : error("cannot calculate derivatives for this action with sparse matrices");
165 6 : } else if( getName()=="DISSIMILARITIES" ) {
166 0 : error("cannot calculate dissimilarities for sparse matrices");
167 : }
168 : }
169 : }
170 753 : updateBookeepingArrays( taskmanager.getActionInput().outmat );
171 753 : taskmanager.setupParallelTaskManager( 2*getPntrToArgument(0)->getNumberOfColumns(), getPntrToArgument(1)->getNumberOfStoredValues() );
172 753 : taskmanager.setWorkspaceSize( 2*getPntrToArgument(0)->getNumberOfColumns() );
173 753 : taskmanager.runAllTasks();
174 753 : }
175 :
176 : template <class T>
177 71433 : void MatrixTimesMatrix<T>::performTask( std::size_t task_index,
178 : const MatrixTimesMatrixInput<T>& actiondata,
179 : ParallelActionsInput& input,
180 : ParallelActionsOutput& output ) {
181 71433 : auto arg0=ArgumentBookeepingHolder::create( 0, input );
182 71433 : auto arg1=ArgumentBookeepingHolder::create( 1, input );
183 71433 : std::size_t fpos = task_index*(1+arg0.ncols);
184 71433 : std::size_t nmult = arg0.bookeeping[fpos];
185 71433 : std::size_t vstart = task_index*arg0.ncols;
186 : InputVectors vectors( nmult, output.buffer.data() );
187 71433 : if( arg1.ncols<arg1.shape[1] ) {
188 244 : std::size_t fstart = task_index*(1+actiondata.outmat.ncols);
189 : std::size_t nelements = actiondata.outmat[fstart];
190 2425 : for(unsigned i=0; i<nelements; ++i) {
191 : std::size_t nm = 0;
192 262316 : for(unsigned j=0; j<nmult; ++j) {
193 260135 : std::size_t kind = arg0.bookeeping[fpos+1+j];
194 260135 : std::size_t bstart = kind*(arg1.ncols + 1);
195 260135 : std::size_t nr = arg1.bookeeping[bstart];
196 539226 : for(unsigned k=0; k<nr; ++k) {
197 288695 : if( arg1.bookeeping[bstart+1+k]==actiondata.outmat[fstart+1+i] ) {
198 9604 : nm++;
199 9604 : break;
200 : }
201 : }
202 : }
203 2181 : vectors.nelem = nm;
204 : nm = 0;
205 262316 : for(unsigned j=0; j<nmult; ++j) {
206 260135 : std::size_t kind = arg0.bookeeping[fpos+1+j];
207 260135 : std::size_t bstart = kind*(arg1.ncols + 1);
208 260135 : std::size_t nr = arg1.bookeeping[bstart];
209 539226 : for(unsigned k=0; k<nr; ++k) {
210 288695 : if( arg1.bookeeping[bstart+1+k]==actiondata.outmat[fstart+1+i] ) {
211 9604 : vectors.arg1[nm] = input.inputdata[ vstart + j ];
212 9604 : vectors.arg2[nm] = input.inputdata[ arg1.start + kind*arg1.ncols + k ];
213 9604 : nm++;
214 9604 : break;
215 : }
216 : }
217 : }
218 2181 : MatrixElementOutput elem( 1, 2*nmult, output.values.data() + i, output.derivatives.data() + 2*nmult*i );
219 2181 : T::calculate( input.noderiv, actiondata.funcinput, vectors, elem );
220 252712 : for(unsigned ii=vectors.nelem; ii<nmult; ++ii) {
221 250531 : elem.derivs[0][ii] = 0;
222 : }
223 : }
224 : } else {
225 : // Retrieve the row of the first matrix
226 1796906 : for(unsigned i=0; i<nmult; ++i) {
227 1725717 : vectors.arg1[i] = input.inputdata[ vstart + i ];
228 : }
229 :
230 : // Now do our multiplications
231 71189 : std::size_t fstart = task_index*(1+actiondata.outmat.ncols);
232 : std::size_t nelements = actiondata.outmat[fstart];
233 5811368 : for(unsigned i=0; i<nelements; ++i) {
234 5740179 : std::size_t base = arg1.start + actiondata.outmat[fstart+1+i];
235 139900314 : for(unsigned j=0; j<nmult; ++j) {
236 134160135 : vectors.arg2[j] = input.inputdata[ base + arg1.ncols*arg0.bookeeping[fpos+1+j] ];
237 : }
238 5740179 : MatrixElementOutput elem( 1, 2*nmult, output.values.data() + i, output.derivatives.data() + 2*nmult*i );
239 5740179 : T::calculate( input.noderiv, actiondata.funcinput, vectors, elem );
240 : }
241 : }
242 71433 : }
243 :
244 : template <class T>
245 691 : void MatrixTimesMatrix<T>::applyNonZeroRankForces( std::vector<double>& outforces ) {
246 691 : taskmanager.applyForces( outforces );
247 691 : }
248 :
249 : template <class T>
250 3404 : int MatrixTimesMatrix<T>::getNumberOfValuesPerTask( std::size_t task_index,
251 : const MatrixTimesMatrixInput<T>& actiondata ) {
252 3404 : std::size_t fstart = task_index*(1+actiondata.outmat.ncols);
253 3404 : return actiondata.outmat[fstart];
254 : }
255 :
256 : template <class T>
257 124671 : void MatrixTimesMatrix<T>::getForceIndices( std::size_t task_index,
258 : std::size_t colno,
259 : std::size_t ntotal_force,
260 : const MatrixTimesMatrixInput<T>& actiondata,
261 : const ParallelActionsInput& input,
262 : ForceIndexHolder force_indices ) {
263 124671 : auto arg0=ArgumentBookeepingHolder::create( 0, input );
264 124671 : auto arg1=ArgumentBookeepingHolder::create( 1, input );
265 124671 : std::size_t fpos = task_index*(1+arg0.ncols);
266 124671 : std::size_t nmult = arg0.bookeeping[fpos];
267 124671 : std::size_t fstart = task_index*(1+actiondata.outmat.ncols);
268 124671 : if( arg1.ncols<arg1.shape[1] ) {
269 : std::size_t nmult_r = 0;
270 5831 : for(unsigned j=0; j<nmult; ++j) {
271 4998 : std::size_t kind = arg0.bookeeping[fpos+1+j];
272 4998 : std::size_t bstart = kind*(arg1.ncols + 1);
273 4998 : std::size_t nr = arg1.bookeeping[bstart];
274 19992 : for(unsigned k=0; k<nr; ++k) {
275 19278 : if( arg1.bookeeping[bstart+1+k]==actiondata.outmat[fstart+1+colno] ) {
276 4284 : nmult_r++;
277 4284 : break;
278 : }
279 : }
280 : }
281 : std::size_t n = 0;
282 5831 : for(unsigned j=0; j<nmult; ++j) {
283 4998 : std::size_t kind = arg0.bookeeping[fpos+1+j];
284 4998 : std::size_t bstart = kind*(arg1.ncols + 1);
285 4998 : std::size_t nr = arg1.bookeeping[bstart];
286 19992 : for(unsigned k=0; k<nr; ++k) {
287 19278 : if( arg1.bookeeping[bstart+1+k]==actiondata.outmat[fstart+1+colno] ) {
288 4284 : force_indices.indices[0][n] = task_index*arg0.ncols + j;
289 4284 : force_indices.indices[0][nmult+n] = arg1.start + arg0.bookeeping[fpos+1+j]*arg1.ncols + k;
290 4284 : n++;
291 4284 : break;
292 : }
293 : }
294 : }
295 833 : force_indices.threadsafe_derivatives_end[0] = nmult_r;
296 833 : force_indices.tot_indices[0] = nmult + nmult_r;
297 : } else {
298 939230 : for(unsigned j=0; j<nmult; ++j) {
299 815392 : force_indices.indices[0][j] = task_index*arg0.ncols + j;
300 815392 : force_indices.indices[0][nmult+j] = arg1.start + arg0.bookeeping[fpos+1+j]*arg1.ncols + actiondata.outmat[fstart+1+colno];
301 : }
302 123838 : force_indices.threadsafe_derivatives_end[0] = nmult;
303 123838 : force_indices.tot_indices[0] = 2*nmult;
304 : }
305 124671 : }
306 :
307 : }
308 : }
309 : #endif
|