Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithMatrix.h"
23 : #include "core/ActionRegister.h"
24 :
25 : //+PLUMEDOC MCOLVAR QUATERNION_PRODUCT_MATRIX
26 : /*
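Calculate the outer product matrix from two vectors of quaternions. Element (i,j) is the quaternion product of the conjugate of the i-th quaternion in the first vector with the j-th quaternion in the second vector. The w, i, j and k parts of these products are output as four separate matrices.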
28 :
29 : \par Examples
30 :
31 : */
32 : //+ENDPLUMEDOC
33 :
34 : namespace PLMD {
35 : namespace crystdistrib {
36 :
37 : class QuaternionProductMatrix : public ActionWithMatrix {
38 : private:
39 : unsigned nderivatives;
40 : public:
41 : static void registerKeywords( Keywords& keys );
42 : explicit QuaternionProductMatrix(const ActionOptions&);
43 : unsigned getNumberOfDerivatives();
44 48 : unsigned getNumberOfColumns() const override {
45 48 : return getConstPntrToComponent(0)->getShape()[1];
46 : }
47 : void setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const ;
48 : void performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const override;
49 : void runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const override ;
50 : };
51 :
52 : PLUMED_REGISTER_ACTION(QuaternionProductMatrix,"QUATERNION_PRODUCT_MATRIX")
53 :
54 9 : void QuaternionProductMatrix::registerKeywords( Keywords& keys ) {
55 9 : ActionWithMatrix::registerKeywords(keys);
56 9 : keys.use("ARG");
57 18 : keys.addOutputComponent("w","default","the real component of quaternion");
58 18 : keys.addOutputComponent("i","default","the i component of the quaternion");
59 18 : keys.addOutputComponent("j","default","the j component of the quaternion");
60 18 : keys.addOutputComponent("k","default","the k component of the quaternion");
61 9 : }
62 :
63 6 : QuaternionProductMatrix::QuaternionProductMatrix(const ActionOptions&ao):
64 : Action(ao),
65 6 : ActionWithMatrix(ao) {
66 6 : if( getNumberOfArguments()!=8 ) {
67 0 : error("should be eight arguments to this action. Four quaternions for each set of atoms. You can repeat actions");
68 : }
69 6 : unsigned nquat = getPntrToArgument(0)->getNumberOfValues();
70 54 : for(unsigned i=0; i<8; ++i) {
71 : Value* myarg=getPntrToArgument(i);
72 48 : if( i==4 ) {
73 6 : nquat = getPntrToArgument(i)->getNumberOfValues();
74 : }
75 48 : if( myarg->getRank()!=1 ) {
76 0 : error("all arguments to this action should be vectors");
77 : }
78 48 : if( (myarg->getPntrToAction())->getName()!="QUATERNION_VECTOR" ) {
79 0 : error("all arguments to this action should be quaternions");
80 : }
81 48 : std::string mylab=getPntrToArgument(i)->getName();
82 48 : std::size_t dot=mylab.find_first_of(".");
83 72 : if( (i==0 || i==4) && mylab.substr(dot+1)!="w" ) {
84 0 : error("quaternion arguments are in wrong order");
85 : }
86 72 : if( (i==1 || i==5) && mylab.substr(dot+1)!="i" ) {
87 0 : error("quaternion arguments are in wrong order");
88 : }
89 72 : if( (i==2 || i==6) && mylab.substr(dot+1)!="j" ) {
90 0 : error("quaternion arguments are in wrong order");
91 : }
92 72 : if( (i==3 || i==7) && mylab.substr(dot+1)!="k" ) {
93 0 : error("quaternion arguments are in wrong order");
94 : }
95 : }
96 6 : std::vector<unsigned> shape(2);
97 6 : shape[0]=getPntrToArgument(0)->getShape()[0];
98 6 : shape[1]=getPntrToArgument(4)->getShape()[0];
99 6 : addComponent( "w", shape );
100 6 : componentIsNotPeriodic("w");
101 6 : addComponent( "i", shape );
102 6 : componentIsNotPeriodic("i");
103 6 : addComponent( "j", shape );
104 6 : componentIsNotPeriodic("j");
105 6 : addComponent( "k", shape );
106 6 : componentIsNotPeriodic("k");
107 6 : nderivatives = buildArgumentStore(0);
108 6 : }
109 :
110 16 : unsigned QuaternionProductMatrix::getNumberOfDerivatives() {
111 16 : return nderivatives;
112 : }
113 :
114 18 : void QuaternionProductMatrix::setupForTask( const unsigned& task_index, std::vector<unsigned>& indices, MultiValue& myvals ) const {
115 18 : unsigned start_n = getPntrToArgument(0)->getShape()[0], size_v = getPntrToArgument(4)->getShape()[0];
116 18 : if( indices.size()!=size_v+1 ) {
117 6 : indices.resize( size_v+1 );
118 : }
119 63 : for(unsigned i=0; i<size_v; ++i) {
120 45 : indices[i+1] = start_n + i;
121 : }
122 : myvals.setSplitIndex( size_v + 1 );
123 18 : }
124 :
125 399598 : void QuaternionProductMatrix::performTask( const std::string& controller, const unsigned& index1, const unsigned& index2, MultiValue& myvals ) const {
126 399598 : unsigned ostrn, ind2=index2;
127 399598 : if( index2>=getPntrToArgument(0)->getShape()[0] ) {
128 45 : ind2 = index2 - getPntrToArgument(0)->getShape()[0];
129 : }
130 :
131 399598 : std::vector<double> quat1(4), quat2(4);
132 :
133 : // Retrieve the first quaternion
134 1997990 : for(unsigned i=0; i<4; ++i) {
135 1598392 : quat1[i] = getArgumentElement( i, index1, myvals );
136 : }
137 : // Retrieve the second quaternion
138 1997990 : for(unsigned i=0; i<4; ++i) {
139 1598392 : quat2[i] = getArgumentElement( 4+i, ind2, myvals );
140 : }
141 :
142 : //make q1 the conjugate
143 399598 : quat1[1] *= -1;
144 399598 : quat1[2] *= -1;
145 399598 : quat1[3] *= -1;
146 :
147 :
148 : double pref=1;
149 : double pref2=1;
150 : double conj=1;
151 : //real part of q1*q2
152 1997990 : for(unsigned i=0; i<4; ++i) {
153 1598392 : if( i>0 ) {
154 : pref=-1;
155 : pref2=-1;
156 : }
157 1598392 : myvals.addValue( getConstPntrToComponent(0)->getPositionInStream(), pref*quat1[i]*quat2[i] );
158 1598392 : if( doNotCalculateDerivatives() ) {
159 839152 : continue ;
160 : }
161 759240 : if (i>0) {
162 : conj=-1;
163 : }
164 759240 : addDerivativeOnVectorArgument( false, 0, i, index1, conj*pref*quat2[i], myvals );
165 759240 : addDerivativeOnVectorArgument( false, 0, 4+i, ind2, pref2*quat1[i], myvals );
166 : }
167 : //i component
168 : pref=1;
169 : conj=1;
170 : pref2=1;
171 1997990 : for (unsigned i=0; i<4; i++) {
172 1598392 : if(i==3) {
173 : pref=-1;
174 : } else {
175 : pref=1;
176 : }
177 1598392 : if(i==2) {
178 : pref2=-1;
179 : } else {
180 : pref2=1;
181 : }
182 1598392 : myvals.addValue( getConstPntrToComponent(1)->getPositionInStream(), pref*quat1[i]*quat2[(5-i)%4]);
183 1598392 : if( doNotCalculateDerivatives() ) {
184 839152 : continue ;
185 : }
186 759240 : if (i>0) {
187 : conj=-1;
188 : }
189 759240 : addDerivativeOnVectorArgument( false, 1, i, index1, conj*pref*quat2[(5-i)%4], myvals );
190 759240 : addDerivativeOnVectorArgument( false, 1, 4+i, ind2, pref2*quat1[(5-i)%4], myvals );
191 : }
192 :
193 : //j component
194 : pref=1;
195 : conj=1;
196 : pref2=1;
197 1997990 : for (unsigned i=0; i<4; i++) {
198 1598392 : if(i==1) {
199 : pref=-1;
200 : } else {
201 : pref=1;
202 : }
203 1598392 : if (i==3) {
204 : pref2=-1;
205 : } else {
206 : pref2=1;
207 : }
208 1598392 : myvals.addValue( getConstPntrToComponent(2)->getPositionInStream(), pref*quat1[i]*quat2[(i+2)%4]);
209 1598392 : if( doNotCalculateDerivatives() ) {
210 839152 : continue ;
211 : }
212 759240 : if (i>0) {
213 : conj=-1;
214 : }
215 759240 : addDerivativeOnVectorArgument( false, 2, i, index1, conj*pref*quat2[(i+2)%4], myvals );
216 759240 : addDerivativeOnVectorArgument( false, 2, 4+i, ind2, pref2*quat1[(i+2)%4], myvals );
217 : }
218 :
219 : //k component
220 : pref=1;
221 : conj=1;
222 : pref2=1;
223 1997990 : for (unsigned i=0; i<4; i++) {
224 1598392 : if(i==2) {
225 : pref=-1;
226 : } else {
227 : pref=1;
228 : }
229 1598392 : if(i==1) {
230 : pref2=-1;
231 : } else {
232 : pref2=1;
233 : }
234 1598392 : myvals.addValue( getConstPntrToComponent(3)->getPositionInStream(), pref*quat1[i]*quat2[(3-i)]);
235 1598392 : if( doNotCalculateDerivatives() ) {
236 839152 : continue ;
237 : }
238 759240 : if (i>0) {
239 : conj=-1;
240 : }
241 759240 : addDerivativeOnVectorArgument( false, 3, i, index1, conj*pref*quat2[3-i], myvals );
242 759240 : addDerivativeOnVectorArgument( false, 3, 4+i, ind2, pref2*quat1[3-i], myvals );
243 :
244 : }
245 :
246 :
247 399598 : }
248 :
249 2028 : void QuaternionProductMatrix::runEndOfRowJobs( const unsigned& ival, const std::vector<unsigned> & indices, MultiValue& myvals ) const {
250 2028 : if( doNotCalculateDerivatives() || !matrixChainContinues() ) {
251 : return ;
252 : }
253 :
254 3915 : for(unsigned j=0; j<getNumberOfComponents(); ++j) {
255 3132 : unsigned nmat = getConstPntrToComponent(j)->getPositionInMatrixStash(), nmat_ind = myvals.getNumberOfMatrixRowDerivatives( nmat );
256 : std::vector<unsigned>& matrix_indices( myvals.getMatrixRowDerivativeIndices( nmat ) );
257 : unsigned ntwo_atoms = myvals.getSplitIndex();
258 : // Quaternion for first molecule
259 : unsigned base = 0;
260 15660 : for(unsigned k=0; k<4; ++k) {
261 12528 : matrix_indices[nmat_ind] = base + ival;
262 12528 : base += getPntrToArgument(k)->getShape()[0];
263 12528 : nmat_ind++;
264 : }
265 : // Loop over row of matrix
266 762372 : for(unsigned i=1; i<ntwo_atoms; ++i) {
267 759240 : unsigned ind2 = indices[i];
268 759240 : if( ind2>=getPntrToArgument(0)->getShape()[0] ) {
269 0 : ind2 = indices[i] - getPntrToArgument(0)->getShape()[0];
270 : }
271 759240 : base = 4*getPntrToArgument(0)->getShape()[0];
272 : // Quaternion of second molecule
273 3796200 : for(unsigned k=0; k<4; ++k) {
274 3036960 : matrix_indices[nmat_ind] = base + ind2;
275 3036960 : base += getPntrToArgument(4+k)->getShape()[0];
276 3036960 : nmat_ind++;
277 : }
278 : }
279 : myvals.setNumberOfMatrixRowDerivatives( nmat, nmat_ind );
280 : }
281 :
282 : }
283 :
284 : }
285 : }
|