Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2014-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionWithValue.h" 23 : #include "core/ActionWithArguments.h" 24 : #include "core/ActionRegister.h" 25 : 26 : //+PLUMEDOC MCOLVAR CONCATENATE 27 : /* 28 : Join vectors or matrices together 29 : 30 : \par Examples 31 : 32 : */ 33 : //+ENDPLUMEDOC 34 : 35 : namespace PLMD { 36 : namespace valtools { 37 : 38 : class Concatenate : 39 : public ActionWithValue, 40 : public ActionWithArguments { 41 : private: 42 : bool vectors; 43 : std::vector<unsigned> row_starts; 44 : std::vector<unsigned> col_starts; 45 : public: 46 : static void registerKeywords( Keywords& keys ); 47 : /// Constructor 48 : explicit Concatenate(const ActionOptions&); 49 : /// Get the number of derivatives 50 257 : unsigned getNumberOfDerivatives() override { 51 257 : return 0; 52 : } 53 : /// Do the calculation 54 : void calculate() override; 55 : /// 56 : void apply(); 57 : }; 58 : 59 : PLUMED_REGISTER_ACTION(Concatenate,"CONCATENATE") 60 : 61 353 : void Concatenate::registerKeywords( Keywords& keys ) { 62 353 : Action::registerKeywords( keys ); 63 353 : ActionWithValue::registerKeywords( keys ); 64 353 : ActionWithArguments::registerKeywords( keys ); 65 353 : keys.use("ARG"); 66 706 : keys.add("numbered","MATRIX","specify the matrices that you wish to join together into a single matrix"); 67 706 : keys.reset_style("MATRIX","compulsory"); 68 353 : keys.setValueDescription("the concatenated vector/matrix that was constructed from the input values"); 69 353 : } 70 : 71 176 : Concatenate::Concatenate(const ActionOptions& ao): 72 : Action(ao), 73 : ActionWithValue(ao), 74 176 : ActionWithArguments(ao) { 75 176 : if( getNumberOfArguments()>0 ) { 76 172 : vectors=true; 77 172 : std::vector<unsigned> shape(1); 78 172 : shape[0]=0; 79 547 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 80 375 : if( getPntrToArgument(i)->getRank()>1 ) { 81 0 : error("cannot concatenate matrix with vectors"); 82 : } 83 375 : getPntrToArgument(i)->buildDataStore(); 84 375 : shape[0] += getPntrToArgument(i)->getNumberOfValues(); 85 : } 86 172 : log.printf(" creating vector with %d elements \n", shape[0] ); 87 172 : addValue( shape ); 88 172 : bool period=getPntrToArgument(0)->isPeriodic(); 89 : std::string min, max; 90 172 : if( period ) { 91 0 : getPntrToArgument(0)->getDomain( min, max ); 92 : } 93 375 : for(unsigned i=1; i<getNumberOfArguments(); ++i) { 94 203 : if( period!=getPntrToArgument(i)->isPeriodic() ) { 95 0 : error("periods of input arguments should match"); 96 : } 97 203 : if( period ) { 98 : std::string min0, max0; 99 0 : getPntrToArgument(i)->getDomain( min0, max0 ); 100 0 : if( min0!=min || max0!=max ) { 101 0 : error("domains of input arguments should match"); 102 : } 103 : } 104 : } 105 172 : if( period ) { 106 0 : setPeriodic( min, max ); 107 : } else { 108 172 : setNotPeriodic(); 109 : } 110 172 : getPntrToComponent(0)->buildDataStore(); 111 172 : if( getPntrToComponent(0)->getRank()==2 ) { 112 0 : getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 113 : } 114 : } else { 115 : unsigned nrows=0, ncols=0; 116 : std::vector<Value*> arglist; 117 4 : vectors=false; 118 7 : for(unsigned i=1;; i++) { 119 : unsigned nt_cols=0; 120 : unsigned size_b4 = arglist.size(); 121 14 : for(unsigned j=1;; j++) { 122 25 : if( j==10 ) { 123 0 : error("cannot combine more than 9 matrices"); 124 : } 125 : std::vector<Value*> argn; 126 50 : parseArgumentList("MATRIX", i*10+j, argn); 127 25 : if( argn.size()==0 ) { 128 : break; 129 : } 130 14 : if( argn.size()>1 ) { 131 0 : error("should only be one argument to each matrix keyword"); 132 : } 133 14 : if( argn[0]->getRank()!=0 && argn[0]->getRank()!=2 ) { 134 0 : error("input arguments for this action should be matrices"); 135 : } 136 14 : argn[0]->buildDataStore(); 137 14 : arglist.push_back( argn[0] ); 138 14 : nt_cols++; 139 14 : if( argn[0]->getRank()==0 ) { 140 0 : log.printf(" %d %d component of composed matrix is scalar labelled %s\n", i, j, argn[0]->getName().c_str() ); 141 : } else { 142 14 : log.printf(" %d %d component of composed matrix is %d by %d matrix labelled %s\n", i, j, argn[0]->getShape()[0], argn[0]->getShape()[1], argn[0]->getName().c_str() ); 143 : } 144 14 : } 145 11 : if( arglist.size()==size_b4 ) { 146 : break; 147 : } 148 7 : if( i==1 ) { 149 : ncols=nt_cols; 150 3 : } else if( nt_cols!=ncols ) { 151 0 : error("should be joining same number of matrices in each row"); 152 : } 153 7 : nrows++; 154 7 : } 155 : 156 4 : std::vector<unsigned> shape(2); 157 4 : shape[0]=0; 158 : unsigned k=0; 159 4 : row_starts.resize( arglist.size() ); 160 4 : col_starts.resize( arglist.size() ); 161 11 : for(unsigned i=0; i<nrows; ++i) { 162 : unsigned cstart = 0, nr = 1; 163 7 : if( arglist[k]->getRank()==2 ) { 164 7 : nr=arglist[k]->getShape()[0]; 165 : } 166 21 : for(unsigned j=0; j<ncols; ++j) { 167 14 : if( arglist[k]->getRank()==0 ) { 168 0 : if( nr!=1 ) { 169 0 : error("mismatched matrix sizes"); 170 : } 171 14 : } else if( nrows>1 && arglist[k]->getShape()[0]!=nr ) { 172 0 : error("mismatched matrix sizes"); 173 : } 174 14 : row_starts[k] = shape[0]; 175 14 : col_starts[k] = cstart; 176 14 : if( arglist[k]->getRank()==0 ) { 177 0 : cstart += 1; 178 : } else { 179 14 : cstart += arglist[k]->getShape()[1]; 180 : } 181 14 : k++; 182 : } 183 7 : if( i==0 ) { 184 4 : shape[1]=cstart; 185 3 : } else if( cstart!=shape[1] ) { 186 0 : error("mismatched matrix sizes"); 187 : } 188 7 : if( arglist[k-1]->getRank()==0 ) { 189 0 : shape[0] += 1; 190 : } else { 191 7 : shape[0] += arglist[k-1]->getShape()[0]; 192 : } 193 : } 194 : // Now request the arguments to make sure we store things we need 195 4 : requestArguments(arglist); 196 4 : addValue( shape ); 197 4 : setNotPeriodic(); 198 4 : getPntrToComponent(0)->buildDataStore(); 199 4 : if( getPntrToComponent(0)->getRank()==2 ) { 200 4 : getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 201 : } 202 : } 203 176 : } 204 : 205 12191 : void Concatenate::calculate() { 206 12191 : Value* myval = getPntrToComponent(0); 207 12191 : if( vectors ) { 208 : unsigned k=0; 209 61297 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 210 : Value* myarg=getPntrToArgument(i); 211 49158 : unsigned nvals=myarg->getNumberOfValues(); 212 404266 : for(unsigned j=0; j<nvals; ++j) { 213 355108 : myval->set( k, myarg->get(j) ); 214 355108 : k++; 215 : } 216 : } 217 : } else { 218 : // Retrieve the matrix from input 219 52 : unsigned ncols = myval->getShape()[1]; 220 258 : for(unsigned k=0; k<getNumberOfArguments(); ++k) { 221 : Value* argn = getPntrToArgument(k); 222 206 : if( argn->getRank()==0 ) { 223 0 : myval->set( ncols*row_starts[k]+col_starts[k], argn->get() ); 224 : } else { 225 : std::vector<double> vals; 226 : std::vector<std::pair<unsigned,unsigned> > pairs; 227 : bool symmetric=getPntrToArgument(k)->isSymmetric(); 228 206 : unsigned nedge=0; 229 206 : getPntrToArgument(k)->retrieveEdgeList( nedge, pairs, vals ); 230 8946 : for(unsigned l=0; l<nedge; ++l ) { 231 8740 : unsigned i=pairs[l].first, j=pairs[l].second; 232 8740 : myval->set( ncols*(row_starts[k]+i)+col_starts[k]+j, vals[l] ); 233 8740 : if( symmetric ) { 234 2142 : myval->set( ncols*(row_starts[k]+j)+col_starts[k]+i, vals[l] ); 235 : } 236 : } 237 : } 238 : } 239 : } 240 12191 : } 241 : 242 12070 : void Concatenate::apply() { 243 12070 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) { 244 7037 : return; 245 : } 246 : 247 5033 : Value* val=getPntrToComponent(0); 248 5033 : if( vectors ) { 249 : unsigned k=0; 250 19923 : for(unsigned i=0; i<getNumberOfArguments(); ++i) { 251 : Value* myarg=getPntrToArgument(i); 252 14942 : unsigned nvals=myarg->getNumberOfValues(); 253 205938 : for(unsigned j=0; j<nvals; ++j) { 254 190996 : myarg->addForce( j, val->getForce(k) ); 255 190996 : k++; 256 : } 257 : } 258 : } else { 259 52 : unsigned ncols=val->getShape()[1]; 260 258 : for(unsigned k=0; k<getNumberOfArguments(); ++k) { 261 : Value* argn=getPntrToArgument(k); 262 206 : if( argn->getRank()==0 ) { 263 0 : argn->addForce( 0, val->getForce(ncols*row_starts[k]+col_starts[k]) ); 264 : } else { 265 : unsigned val_ncols=val->getNumberOfColumns(); 266 : unsigned arg_ncols=argn->getNumberOfColumns(); 267 1686 : for(unsigned i=0; i<argn->getShape()[0]; ++i) { 268 : unsigned ncol = argn->getRowLength(i); 269 13140 : for(unsigned j=0; j<ncol; ++j) { 270 11660 : argn->addForce( i*arg_ncols+j, val->getForce( val_ncols*(row_starts[k]+i)+col_starts[k]+argn->getRowIndex(i,j) ), false ); 271 : } 272 : } 273 : } 274 : } 275 : } 276 : } 277 : 278 : } 279 : }