Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithMatrix.h"
23 : #include "core/ActionRegister.h"
24 : #include "tools/LeptonCall.h"
25 :
26 : //+PLUMEDOC COLVAR OUTER_PRODUCT
27 : /*
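Calculate the outer product matrix of two vectors.

Element \f$(i,j)\f$ of the output matrix is obtained by applying the function given by FUNC to the \f$i\f$-th element of the first input vector (variable x) and the \f$j\f$-th element of the second input vector (variable y); by default FUNC=x*y. Setting FUNC=min or FUNC=max instead puts the smaller or larger of the two elements in each matrix element.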
29 :
30 : \par Examples
31 :
32 : */
33 : //+ENDPLUMEDOC
34 :
35 : namespace PLMD {
36 : namespace matrixtools {
37 :
38 : class OuterProduct : public ActionWithMatrix {
39 : private:
40 : bool domin, domax, diagzero;
41 : LeptonCall function;
42 : unsigned nderivatives;
43 : bool stored_vector1, stored_vector2;
44 : public:
45 : static void registerKeywords( Keywords& keys );
46 : explicit OuterProduct(const ActionOptions&);
47 : unsigned getNumberOfDerivatives();
48 : void prepare() override ;
49 2298 : unsigned getNumberOfColumns() const override {
50 2298 : return getConstPntrToComponent(0)->getShape()[1];
51 : }
52 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
53 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
54 : void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
55 : };
56 :
// Register the OuterProduct class under the action name OUTER_PRODUCT.
PLUMED_REGISTER_ACTION(OuterProduct,"OUTER_PRODUCT")
58 :
59 137 : void OuterProduct::registerKeywords( Keywords& keys ) {
60 137 : ActionWithMatrix::registerKeywords(keys);
61 137 : keys.use("ARG");
62 274 : keys.add("compulsory","FUNC","x*y","the function of the input vectors that should be put in the elements of the outer product");
63 274 : keys.addFlag("ELEMENTS_ON_DIAGONAL_ARE_ZERO",false,"set all diagonal elements to zero");
64 137 : keys.setValueDescription("a matrix containing the outer product of the two input vectors that was obtained using the function that was input");
65 137 : }
66 :
67 77 : OuterProduct::OuterProduct(const ActionOptions&ao):
68 : Action(ao),
69 : ActionWithMatrix(ao),
70 77 : domin(false),
71 77 : domax(false) {
72 77 : if( getNumberOfArguments()!=2 ) {
73 0 : error("should be two arguments to this action, a matrix and a vector");
74 : }
75 77 : if( getPntrToArgument(0)->getRank()!=1 || getPntrToArgument(0)->hasDerivatives() ) {
76 0 : error("first argument to this action should be a vector");
77 : }
78 77 : if( getPntrToArgument(1)->getRank()!=1 || getPntrToArgument(1)->hasDerivatives() ) {
79 0 : error("first argument to this action should be a vector");
80 : }
81 :
82 : std::string func;
83 154 : parse("FUNC",func);
84 77 : if( func=="min") {
85 0 : domin=true;
86 0 : log.printf(" taking minimum of two input vectors \n");
87 77 : } else if( func=="max" ) {
88 2 : domax=true;
89 2 : log.printf(" taking maximum of two input vectors \n");
90 : } else {
91 75 : log.printf(" with function : %s \n", func.c_str() );
92 75 : std::vector<std::string> var(2);
93 : var[0]="x";
94 : var[1]="y";
95 75 : function.set( func, var, this );
96 75 : }
97 77 : parseFlag("ELEMENTS_ON_DIAGONAL_ARE_ZERO",diagzero);
98 77 : if( diagzero ) {
99 2 : log.printf(" setting diagonal elements equal to zero\n");
100 : }
101 :
102 77 : std::vector<unsigned> shape(2);
103 77 : shape[0]=getPntrToArgument(0)->getShape()[0];
104 77 : shape[1]=getPntrToArgument(1)->getShape()[0];
105 77 : addValue( shape );
106 77 : setNotPeriodic();
107 77 : nderivatives = buildArgumentStore(0);
108 77 : std::string headstr=getFirstActionInChain()->getLabel();
109 77 : stored_vector1 = getPntrToArgument(0)->ignoreStoredValue( headstr );
110 77 : stored_vector2 = getPntrToArgument(1)->ignoreStoredValue( headstr );
111 77 : if( getPntrToArgument(0)->isDerivativeZeroWhenValueIsZero() || getPntrToArgument(1)->isDerivativeZeroWhenValueIsZero() ) {
112 15 : getPntrToComponent(0)->setDerivativeIsZeroWhenValueIsZero();
113 : }
114 77 : }
115 :
116 96 : unsigned OuterProduct::getNumberOfDerivatives() {
117 96 : return nderivatives;
118 : }
119 :
120 158 : void OuterProduct::prepare() {
121 158 : ActionWithVector::prepare();
122 158 : Value* myval=getPntrToComponent(0);
123 158 : if( myval->getShape()[0]==getPntrToArgument(0)->getShape()[0] && myval->getShape()[1]==getPntrToArgument(1)->getShape()[0] ) {
124 141 : return;
125 : }
126 17 : std::vector<unsigned> shape(2);
127 17 : shape[0] = getPntrToArgument(0)->getShape()[0];
128 17 : shape[1] = getPntrToArgument(1)->getShape()[0];
129 17 : myval->setShape( shape );
130 : }
131 :
132 27151 : void OuterProduct::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
133 27151 : unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(1)->getShape()[0];
134 27151 : if( diagzero ) {
135 990 : if( indices.size()!=size_v ) {
136 20 : indices.resize( size_v );
137 : }
138 : unsigned k=1;
139 99000 : for(unsigned i=0; i<size_v; ++i) {
140 98010 : if( task_index==i ) {
141 990 : continue ;
142 : }
143 97020 : indices[k] = size_v + i;
144 97020 : k++;
145 : }
146 : myvals.setSplitIndex( size_v );
147 : } else {
148 26161 : if( indices.size()!=size_v+1 ) {
149 249 : indices.resize( size_v+1 );
150 : }
151 1690193 : for(unsigned i=0; i<size_v; ++i) {
152 1664032 : indices[i+1] = start_n + i;
153 : }
154 : myvals.setSplitIndex( size_v + 1 );
155 : }
156 27151 : }
157 :
158 6874326 : void OuterProduct::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
159 6874326 : unsigned ostrn = getConstPntrToComponent(0)->getPositionInStream(), ind2=index2;
160 6874326 : if( index2>=getPntrToArgument(0)->getShape()[0] ) {
161 1976448 : ind2 = index2 - getPntrToArgument(0)->getShape()[0];
162 : }
163 6874326 : if( diagzero && index1==ind2 ) {
164 6511300 : return;
165 : }
166 :
167 : double fval;
168 6874326 : unsigned jarg = 0, kelem = index1;
169 6874326 : bool jstore=stored_vector1;
170 6874326 : std::vector<double> args(2);
171 6874326 : args[0] = getArgumentElement( 0, index1, myvals );
172 6874326 : args[1] = getArgumentElement( 1, ind2, myvals );
173 6874326 : if( domin ) {
174 0 : fval=args[0];
175 0 : if( args[1]<args[0] ) {
176 : fval=args[1];
177 0 : jarg=1;
178 0 : kelem=ind2;
179 0 : jstore=stored_vector2;
180 : }
181 6874326 : } else if( domax ) {
182 315192 : fval=args[0];
183 315192 : if( args[1]>args[0] ) {
184 : fval=args[1];
185 2055 : jarg=1;
186 2055 : kelem=ind2;
187 2055 : jstore=stored_vector2;
188 : }
189 : } else {
190 6559134 : fval=function.evaluate( args );
191 : }
192 :
193 6874326 : myvals.addValue( ostrn, fval );
194 6874326 : if( doNotCalculateDerivatives() ) {
195 : return ;
196 : }
197 :
198 366326 : if( domin || domax ) {
199 0 : addDerivativeOnVectorArgument( jstore, 0, jarg, kelem, 1.0, myvals );
200 : } else {
201 366326 : addDerivativeOnVectorArgument( stored_vector1, 0, 0, index1, function.evaluateDeriv( 0, args ), myvals );
202 366326 : addDerivativeOnVectorArgument( stored_vector2, 0, 1, ind2, function.evaluateDeriv( 1, args ), myvals );
203 : }
204 366326 : if( doNotCalculateDerivatives() || !matrixChainContinues() ) {
205 : return ;
206 : }
207 363026 : unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
208 363026 : myvals.getMatrixRowDerivativeIndices( nmat )[nmat_ind] = arg_deriv_starts[1] + ind2;
209 363026 : myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+1 );
210 : }
211 :
212 39963 : void OuterProduct::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
213 39963 : if( doNotCalculateDerivatives() || !matrixChainContinues() ) {
214 : return ;
215 : }
216 11402 : unsigned nmat = getConstPntrToComponent(0)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
217 11402 : myvals.getMatrixRowDerivativeIndices( nmat )[nmat_ind] = ival;
218 11402 : myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind+1 );
219 : }
220 :
221 : }
222 : }
|