Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithMatrix.h"
23 : #include "core/ActionRegister.h"
24 :
25 : //+PLUMEDOC MCOLVAR MATRIX_VECTOR_PRODUCT
26 : /*
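Calculate the product of a matrix and a vector.

Several matrices can be multiplied by the same vector, or a single matrix can be multiplied by several vectors. When more than one product is computed, each result is stored as a separate component.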
28 :
29 : \par Examples
30 :
31 : */
32 : //+ENDPLUMEDOC
33 :
34 : namespace PLMD {
35 : namespace matrixtools {
36 :
37 : class MatrixTimesVector : public ActionWithMatrix {
38 : private:
39 : bool sumrows;
40 : unsigned nderivatives;
41 : std::vector<bool> stored_arg;
42 : public:
43 : static void registerKeywords( Keywords& keys );
44 : explicit MatrixTimesVector(const ActionOptions&);
45 : std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ;
46 0 : unsigned getNumberOfColumns() const override {
47 0 : plumed_error();
48 : }
49 : unsigned getNumberOfDerivatives();
50 : void prepare() override ;
51 2151 : bool isInSubChain( unsigned& nder ) override {
52 2151 : nder = arg_deriv_starts[0];
53 2151 : return true;
54 : }
55 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
56 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
57 : void runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
58 : void updateAdditionalIndices( const unsigned& ostrn, MultiValue& myvals ) const override ;
59 : };
60 :
61 : PLUMED_REGISTER_ACTION(MatrixTimesVector,"MATRIX_VECTOR_PRODUCT")
62 :
63 629 : void MatrixTimesVector::registerKeywords( Keywords& keys ) {
64 629 : ActionWithMatrix::registerKeywords(keys);
65 629 : keys.use("ARG");
66 629 : keys.setValueDescription("the vector that is obtained by taking the product between the matrix and the vector that were input");
67 629 : ActionWithValue::useCustomisableComponents(keys);
68 629 : }
69 :
70 6 : std::string MatrixTimesVector::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const {
71 6 : if( getPntrToArgument(1)->getRank()==1 ) {
72 0 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
73 0 : if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
74 0 : return "the product of the matrix " + getPntrToArgument(0)->getName() + " and the vector " + getPntrToArgument(i)->getName();
75 : }
76 : }
77 : }
78 21 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
79 21 : if( getPntrToArgument(i)->getName().find(cname)!=std::string::npos ) {
80 12 : return "the product of the matrix " + getPntrToArgument(i)->getName() + " and the vector " + getPntrToArgument(getNumberOfArguments()-1)->getName();
81 : }
82 : }
83 0 : plumed_merror( "could not understand request for component " + cname );
84 : return "";
85 : }
86 :
87 352 : MatrixTimesVector::MatrixTimesVector(const ActionOptions&ao):
88 : Action(ao),
89 : ActionWithMatrix(ao),
90 352 : sumrows(false) {
91 352 : if( getNumberOfArguments()<2 ) {
92 0 : error("Not enough arguments specified");
93 : }
94 : unsigned nvectors=0, nmatrices=0;
95 1875 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
96 1523 : if( getPntrToArgument(i)->hasDerivatives() ) {
97 0 : error("arguments should be vectors or matrices");
98 : }
99 1523 : if( getPntrToArgument(i)->getRank()<=1 ) {
100 537 : nvectors++;
101 : }
102 1523 : if( getPntrToArgument(i)->getRank()==2 ) {
103 986 : nmatrices++;
104 : }
105 : }
106 :
107 352 : std::vector<unsigned> shape(1);
108 352 : shape[0]=getPntrToArgument(0)->getShape()[0];
109 352 : if( nvectors==1 ) {
110 : unsigned n = getNumberOfArguments()-1;
111 1320 : for(unsigned i=0; i<n; ++i) {
112 977 : if( getPntrToArgument(i)->getRank()!=2 || getPntrToArgument(i)->hasDerivatives() ) {
113 0 : error("all arguments other than last argument should be matrices");
114 : }
115 977 : if( getPntrToArgument(n)->getRank()==0 ) {
116 1 : if( getPntrToArgument(i)->getShape()[1]!=1 ) {
117 0 : error("number of columns in input matrix does not equal number of elements in vector");
118 : }
119 976 : } else if( getPntrToArgument(i)->getShape()[1]!=getPntrToArgument(n)->getShape()[0] ) {
120 0 : error("number of columns in input matrix does not equal number of elements in vector");
121 : }
122 : }
123 343 : if( getPntrToArgument(n)->getRank()>0 ) {
124 342 : if( getPntrToArgument(n)->getRank()!=1 || getPntrToArgument(n)->hasDerivatives() ) {
125 0 : error("last argument to this action should be a vector");
126 : }
127 : }
128 343 : getPntrToArgument(n)->buildDataStore();
129 :
130 343 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
131 343 : if( av ) {
132 314 : done_in_chain=canBeAfterInChain( av );
133 : }
134 :
135 343 : if( getNumberOfArguments()==2 ) {
136 301 : addValue( shape );
137 301 : setNotPeriodic();
138 : } else {
139 718 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
140 676 : std::string name = getPntrToArgument(i)->getName();
141 676 : if( name.find_first_of(".")!=std::string::npos ) {
142 676 : std::size_t dot=name.find_first_of(".");
143 1352 : name = name.substr(dot+1);
144 : }
145 676 : addComponent( name, shape );
146 676 : componentIsNotPeriodic( name );
147 : }
148 : }
149 343 : if( (getPntrToArgument(n)->getPntrToAction())->getName()=="CONSTANT" ) {
150 306 : sumrows=true;
151 306 : if( getPntrToArgument(n)->getRank()==0 ) {
152 1 : if( fabs( getPntrToArgument(n)->get() - 1.0 )>epsilon ) {
153 0 : sumrows = false;
154 : }
155 : } else {
156 180438 : for(unsigned i=0; i<getPntrToArgument(n)->getShape()[0]; ++i) {
157 180141 : if( fabs( getPntrToArgument(n)->get(i) - 1.0 )>epsilon ) {
158 8 : sumrows=false;
159 8 : break;
160 : }
161 : }
162 : }
163 : }
164 9 : } else if( nmatrices==1 ) {
165 9 : if( getPntrToArgument(0)->getRank()!=2 || getPntrToArgument(0)->hasDerivatives() ) {
166 0 : error("first argument to this action should be a matrix");
167 : }
168 203 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
169 194 : if( getPntrToArgument(i)->getRank()>1 || getPntrToArgument(i)->hasDerivatives() ) {
170 0 : error("all arguments other than first argument should be vectors");
171 : }
172 194 : if( getPntrToArgument(i)->getRank()==0 ) {
173 0 : if( getPntrToArgument(0)->getShape()[1]!=1 ) {
174 0 : error("number of columns in input matrix does not equal number of elements in vector");
175 : }
176 194 : } else if( getPntrToArgument(0)->getShape()[1]!=getPntrToArgument(i)->getShape()[0] ) {
177 0 : error("number of columns in input matrix does not equal number of elements in vector");
178 : }
179 194 : getPntrToArgument(i)->buildDataStore();
180 : }
181 :
182 9 : ActionWithVector* av=dynamic_cast<ActionWithVector*>( getPntrToArgument(0)->getPntrToAction() );
183 9 : if( av ) {
184 9 : done_in_chain=canBeAfterInChain( av );
185 : }
186 :
187 203 : for(unsigned i=1; i<getNumberOfArguments(); ++i) {
188 194 : std::string name = getPntrToArgument(i)->getName();
189 194 : if( name.find_first_of(".")!=std::string::npos ) {
190 0 : std::size_t dot=name.find_first_of(".");
191 0 : name = name.substr(dot+1);
192 : }
193 194 : addComponent( name, shape );
194 194 : componentIsNotPeriodic( name );
195 : }
196 : } else {
197 0 : error("You should either have one vector or one matrix in input");
198 : }
199 :
200 352 : nderivatives = buildArgumentStore(0);
201 352 : std::string headstr=getFirstActionInChain()->getLabel();
202 352 : stored_arg.resize( getNumberOfArguments() );
203 1875 : for(unsigned i=0; i<getNumberOfArguments(); ++i) {
204 1523 : stored_arg[i] = getPntrToArgument(i)->ignoreStoredValue( headstr );
205 : }
206 352 : }
207 :
208 31643 : unsigned MatrixTimesVector::getNumberOfDerivatives() {
209 31643 : return nderivatives;
210 : }
211 :
212 13575 : void MatrixTimesVector::prepare() {
213 13575 : ActionWithVector::prepare();
214 13575 : Value* myval = getPntrToComponent(0);
215 13575 : if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] ) {
216 13565 : return;
217 : }
218 10 : std::vector<unsigned> shape(1);
219 10 : shape[0] = getPntrToArgument(0)->getShape()[0];
220 10 : myval->setShape(shape);
221 : }
222 :
223 6574 : void MatrixTimesVector::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
224 6574 : unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(0)->getRowLength(task_index);
225 6574 : if( indices.size()!=size_v+1 ) {
226 3508 : indices.resize( size_v + 1 );
227 : }
228 842762 : for(unsigned i=0; i<size_v; ++i) {
229 836188 : indices[i+1] = start_n + i;
230 : }
231 : myvals.setSplitIndex( size_v + 1 );
232 : }
233 :
234 23970940 : void MatrixTimesVector::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
235 23970940 : unsigned ind2 = index2;
236 23970940 : if( index2>=getPntrToArgument(0)->getShape()[0] ) {
237 1600742 : ind2 = index2 - getPntrToArgument(0)->getShape()[0];
238 : }
239 23970940 : if( sumrows ) {
240 22303792 : unsigned n=getNumberOfArguments()-1;
241 : double matval = 0;
242 87441027 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
243 65137235 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
244 : Value* myarg = getPntrToArgument(i);
245 65137235 : if( !myarg->valueHasBeenSet() ) {
246 65122517 : myvals.addValue( ostrn, myvals.get( myarg->getPositionInStream() ) );
247 : } else {
248 14718 : myvals.addValue( ostrn, myarg->get( index1*myarg->getNumberOfColumns() + ind2, false ) );
249 : }
250 : // Now lets work out the derivatives
251 65137235 : if( doNotCalculateDerivatives() ) {
252 32313889 : continue;
253 : }
254 32823346 : addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, 1.0, myvals );
255 : }
256 1667148 : } else if( getPntrToArgument(1)->getRank()==1 ) {
257 : double matval = 0;
258 : Value* myarg = getPntrToArgument(0);
259 1667148 : unsigned vcol = ind2;
260 1667148 : if( !myarg->valueHasBeenSet() ) {
261 840110 : matval = myvals.get( myarg->getPositionInStream() );
262 : } else {
263 827038 : matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
264 827038 : vcol = getPntrToArgument(0)->getRowIndex( index1, ind2 );
265 : }
266 18356786 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
267 16689638 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
268 16689638 : double vecval=getArgumentElement( i+1, vcol, myvals );
269 : // And add this part of the product
270 16689638 : myvals.addValue( ostrn, matval*vecval );
271 : // Now lets work out the derivatives
272 16689638 : if( doNotCalculateDerivatives() ) {
273 1000870 : continue;
274 : }
275 15688768 : addDerivativeOnMatrixArgument( stored_arg[0], i, 0, index1, ind2, vecval, myvals );
276 15688768 : addDerivativeOnVectorArgument( stored_arg[i+1], i, i+1, vcol, matval, myvals );
277 : }
278 : } else {
279 0 : unsigned n=getNumberOfArguments()-1;
280 0 : double matval = 0;
281 0 : unsigned vcol = ind2;
282 0 : for(unsigned i=0; i<getNumberOfArguments()-1; ++i) {
283 0 : unsigned ostrn = getConstPntrToComponent(i)->getPositionInStream();
284 : Value* myarg = getPntrToArgument(i);
285 0 : if( !myarg->valueHasBeenSet() ) {
286 0 : matval = myvals.get( myarg->getPositionInStream() );
287 : } else {
288 0 : matval = myarg->get( index1*myarg->getNumberOfColumns() + ind2, false );
289 0 : vcol = getPntrToArgument(i)->getRowIndex( index1, ind2 );
290 : }
291 0 : double vecval=getArgumentElement( n, vcol, myvals );
292 : // And add this part of the product
293 0 : myvals.addValue( ostrn, matval*vecval );
294 : // Now lets work out the derivatives
295 0 : if( doNotCalculateDerivatives() ) {
296 0 : continue;
297 : }
298 0 : addDerivativeOnMatrixArgument( stored_arg[i], i, i, index1, ind2, vecval, myvals );
299 0 : addDerivativeOnVectorArgument( stored_arg[n], i, n, vcol, matval, myvals );
300 : }
301 : }
302 23970940 : }
303 :
304 472445 : void MatrixTimesVector::runEndOfRowJobs( const unsigned& ind, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
305 472445 : if( doNotCalculateDerivatives() || !actionInChain() ) {
306 : return ;
307 : }
308 :
309 358714 : if( getPntrToArgument(1)->getRank()==1 ) {
310 : unsigned istrn = getPntrToArgument(0)->getPositionInMatrixStash();
311 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
312 1010565 : for(unsigned j=0; j<getNumberOfComponents(); ++j) {
313 671975 : unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
314 40971258 : for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(istrn); ++i) {
315 40299283 : myvals.updateIndex( ostrn, mat_indices[i] );
316 : }
317 : }
318 : } else {
319 530036 : for(unsigned j=0; j<getNumberOfComponents(); ++j) {
320 : unsigned istrn = getPntrToArgument(j)->getPositionInMatrixStash();
321 509912 : unsigned ostrn = getConstPntrToComponent(j)->getPositionInStream();
322 : std::vector<unsigned>& mat_indices( myvals.getMatrixRowDerivativeIndices( istrn ) );
323 17456348 : for(unsigned i=0; i<myvals.getNumberOfMatrixRowDerivatives(istrn); ++i) {
324 16946436 : myvals.updateIndex( ostrn, mat_indices[i] );
325 : }
326 : }
327 : }
328 : }
329 :
330 372677 : void MatrixTimesVector::updateAdditionalIndices( const unsigned& ostrn, MultiValue& myvals ) const {
331 372677 : unsigned n = getNumberOfArguments()-1;
332 372677 : if( getPntrToArgument(1)->getRank()==1 ) {
333 : n = 1;
334 : }
335 372677 : unsigned nvals = getPntrToArgument(n)->getNumberOfValues();
336 1387754027 : for(unsigned i=0; i<nvals; ++i) {
337 1387381350 : myvals.updateIndex( ostrn, arg_deriv_starts[n] + i );
338 : }
339 372677 : }
340 :
341 : }
342 : }
|