Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2014-2017 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionWithValue.h" 23 : #include "core/ActionWithArguments.h" 24 : #include "core/ActionRegister.h" 25 : #include "core/PlumedMain.h" 26 : #include "core/ActionSet.h" 27 : 28 : //+PLUMEDOC PRINTANALYSIS SELECT_WITH_MASK 29 : /* 30 : Use a mask to select elements of an array 31 : 32 : \par Examples 33 : 34 : */ 35 : //+ENDPLUMEDOC 36 : 37 : namespace PLMD { 38 : namespace valtools { 39 : 40 : class SelectWithMask : 41 : public ActionWithValue, 42 : public ActionWithArguments { 43 : private: 44 : unsigned getOutputVectorLength( const Value* mask ) const ; 45 : public: 46 : static void registerKeywords( Keywords& keys ); 47 : /// Constructor 48 : explicit SelectWithMask(const ActionOptions&); 49 : /// Get the number of derivatives 50 98 : unsigned getNumberOfDerivatives() override { 51 98 : return 0; 52 : } 53 : /// 54 : void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ; 55 : /// 56 : void prepare() override ; 57 : /// Do the calculation 58 : void calculate() override; 59 : /// 60 : void apply() override; 61 : }; 62 : 63 : PLUMED_REGISTER_ACTION(SelectWithMask,"SELECT_WITH_MASK") 64 : 65 178 : void SelectWithMask::registerKeywords( Keywords& keys ) { 66 178 : Action::registerKeywords( keys ); 67 178 : ActionWithValue::registerKeywords( keys ); 68 178 : ActionWithArguments::registerKeywords( keys ); 69 178 : keys.use("ARG"); 70 356 : keys.add("optional","ROW_MASK","an array with ones in the rows of the matrix that you want to discard"); 71 356 : keys.add("optional","COLUMN_MASK","an array with ones in the columns of the matrix that you want to discard"); 72 356 : keys.add("compulsory","MASK","an array with ones in the components that you want to discard"); 73 178 : keys.setValueDescription("a vector/matrix of values that is obtained using a mask to select elements of interest"); 74 178 : } 75 : 76 93 : SelectWithMask::SelectWithMask(const ActionOptions& ao): 77 : Action(ao), 78 : ActionWithValue(ao), 79 93 : ActionWithArguments(ao) { 80 93 : if( getNumberOfArguments()!=1 ) { 81 0 : error("should only be one argument for this action"); 82 : } 83 93 : getPntrToArgument(0)->buildDataStore(); 84 : std::vector<unsigned> shape; 85 93 : if( getPntrToArgument(0)->getRank()==1 ) { 86 : std::vector<Value*> mask; 87 136 : parseArgumentList("MASK",mask); 88 68 : if( mask.size()!=1 ) { 89 0 : error("should only be one input for mask"); 90 : } 91 68 : if( mask[0]->getNumberOfValues()!=getPntrToArgument(0)->getNumberOfValues() ) { 92 0 : error("mismatch between size of mask and input vector"); 93 : } 94 68 : log.printf(" creating vector from elements of %s who have a corresponding element in %s that is zero\n", getPntrToArgument(0)->getName().c_str(), mask[0]->getName().c_str() ); 95 68 : std::vector<Value*> args( getArguments() ); 96 68 : args.push_back( mask[0] ); 97 68 : requestArguments( args ); 98 68 : shape.resize(1,0); 99 68 : if( (mask[0]->getPntrToAction())->getName()=="CONSTANT" ) { 100 57 : shape[0]=getOutputVectorLength(mask[0]); 101 : } 102 25 : } else if( getPntrToArgument(0)->getRank()==2 ) { 103 : std::vector<Value*> rmask, cmask; 104 25 : parseArgumentList("ROW_MASK",rmask); 105 50 : parseArgumentList("COLUMN_MASK",cmask); 106 25 : if( rmask.size()==0 && cmask.size()==0 ) { 107 0 : error("no mask elements have been specified"); 108 25 : } else if( cmask.size()==0 ) { 109 11 : std::string con="0"; 110 144 : for(unsigned i=1; i<getPntrToArgument(0)->getShape()[1]; ++i) { 111 : con += ",0"; 112 : } 113 11 : plumed.readInputWords( Tools::getWords(getLabel() + "_colmask: CONSTANT VALUES=" + con), false ); 114 11 : std::vector<std::string> labs(1, getLabel() + "_colmask"); 115 11 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, cmask ); 116 25 : } else if( rmask.size()==0 ) { 117 1 : std::string con="0"; 118 13 : for(unsigned i=1; i<getPntrToArgument(0)->getShape()[0]; ++i) { 119 : con += ",0"; 120 : } 121 1 : plumed.readInputWords( Tools::getWords(getLabel() + "_rowmask: CONSTANT VALUES=" + con), false ); 122 1 : std::vector<std::string> labs(1, getLabel() + "_rowmask"); 123 1 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, rmask ); 124 1 : } 125 25 : shape.resize(2); 126 25 : rmask[0]->buildDataStore(); 127 25 : shape[0] = getOutputVectorLength( rmask[0] ); 128 25 : cmask[0]->buildDataStore(); 129 25 : shape[1] = getOutputVectorLength( cmask[0] ); 130 25 : std::vector<Value*> args( getArguments() ); 131 25 : args.push_back( rmask[0] ); 132 25 : args.push_back( cmask[0] ); 133 25 : requestArguments( args ); 134 : } else { 135 0 : error("input should be vector or matrix"); 136 : } 137 : 138 93 : addValue( shape ); 139 93 : getPntrToComponent(0)->buildDataStore(); 140 93 : if( getPntrToArgument(0)->isPeriodic() ) { 141 : std::string min, max; 142 7 : getPntrToArgument(0)->getDomain( min, max ); 143 7 : setPeriodic( min, max ); 144 : } else { 145 86 : setNotPeriodic(); 146 : } 147 93 : if( getPntrToComponent(0)->getRank()==2 ) { 148 25 : getPntrToComponent(0)->reshapeMatrixStore( shape[1] ); 149 : } 150 93 : } 151 : 152 10693 : unsigned SelectWithMask::getOutputVectorLength( const Value* mask ) const { 153 : unsigned l=0; 154 154139 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 155 143446 : if( fabs(mask->get(i))>0 ) { 156 10011 : continue; 157 : } 158 133435 : l++; 159 : } 160 10693 : return l; 161 : } 162 : 163 18 : void SelectWithMask::getMatrixColumnTitles( std::vector<std::string>& argnames ) const { 164 : std::vector<std::string> alltitles; 165 18 : (getPntrToArgument(0)->getPntrToAction())->getMatrixColumnTitles( alltitles ); 166 103 : for(unsigned i=0; i<alltitles.size(); ++i) { 167 85 : if( fabs(getPntrToArgument(2)->get(i))>0 ) { 168 34 : continue; 169 : } 170 51 : argnames.push_back( alltitles[i] ); 171 : } 172 18 : } 173 : 174 10551 : void SelectWithMask::prepare() { 175 : Value* arg = getPntrToArgument(0); 176 10551 : Value* out = getPntrToComponent(0); 177 10551 : if( arg->getRank()==1 ) { 178 : Value* mask = getPntrToArgument(1); 179 10516 : std::vector<unsigned> shape(1); 180 10516 : shape[0]=getOutputVectorLength( mask ); 181 10516 : if( out->getNumberOfValues()!=shape[0] ) { 182 19 : if( shape[0]==1 ) { 183 0 : shape.resize(0); 184 : } 185 19 : out->setShape(shape); 186 : } 187 35 : } else if( arg->getRank()==2 ) { 188 35 : std::vector<unsigned> outshape(2); 189 : Value* rmask = getPntrToArgument(1); 190 35 : outshape[0] = getOutputVectorLength( rmask ); 191 : Value* cmask = getPntrToArgument(2); 192 35 : outshape[1] = getOutputVectorLength( cmask ); 193 35 : if( out->getShape()[0]!=outshape[0] || out->getShape()[1]!=outshape[1] ) { 194 20 : out->setShape(outshape); 195 20 : out->reshapeMatrixStore( outshape[1] ); 196 : } 197 : } 198 10551 : } 199 : 200 10543 : void SelectWithMask::calculate() { 201 : Value* arg = getPntrToArgument(0); 202 10543 : Value* out = getPntrToComponent(0); 203 10543 : if( arg->getRank()==1 ) { 204 : Value* mask = getPntrToArgument(1); 205 : unsigned n=0; 206 149144 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 207 138632 : if( fabs(mask->get(i))>0 ) { 208 7434 : continue; 209 : } 210 131198 : out->set(n, arg->get(i) ); 211 131198 : n++; 212 : } 213 31 : } else if ( arg->getRank()==2 ) { 214 31 : std::vector<unsigned> outshape( out->getShape() ); 215 : unsigned n = 0; 216 31 : std::vector<unsigned> inshape( arg->getShape() ); 217 : Value* rmask = getPntrToArgument(1); 218 : Value* cmask = getPntrToArgument(2); 219 1774 : for(unsigned i=0; i<inshape[0]; ++i) { 220 1743 : if( fabs(rmask->get(i))>0 ) { 221 592 : continue; 222 : } 223 : unsigned m = 0; 224 378651 : for(unsigned j=0; j<inshape[1]; ++j) { 225 377500 : if( fabs(cmask->get(j))>0 ) { 226 188095 : continue; 227 : } 228 189405 : out->set( n*outshape[1] + m, arg->get(i*inshape[1] + j) ); 229 189405 : m++; 230 : } 231 1151 : n++; 232 : } 233 : } 234 10543 : } 235 : 236 10505 : void SelectWithMask::apply() { 237 10505 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) { 238 62 : return ; 239 : } 240 : 241 : Value* arg = getPntrToArgument(0); 242 10443 : Value* out = getPntrToComponent(0); 243 10443 : if( arg->getRank()==1 ) { 244 : unsigned n=0; 245 : Value* mask = getPntrToArgument(1); 246 145276 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) { 247 134833 : if( fabs(mask->get(i))>0 ) { 248 4153 : continue; 249 : } 250 130680 : arg->addForce(i, out->getForce(n) ); 251 130680 : n++; 252 : } 253 0 : } else if( arg->getRank()==2 ) { 254 : unsigned n = 0; 255 0 : std::vector<unsigned> inshape( arg->getShape() ); 256 0 : std::vector<unsigned> outshape( out->getShape() ); 257 : Value* rmask = getPntrToArgument(1); 258 : Value* cmask = getPntrToArgument(2); 259 0 : for(unsigned i=0; i<inshape[0]; ++i) { 260 0 : if( fabs(rmask->get(i))>0 ) { 261 0 : continue; 262 : } 263 : unsigned m = 0; 264 0 : for(unsigned j=0; j<inshape[1]; ++j) { 265 0 : if( fabs(cmask->get(j))>0 ) { 266 0 : continue; 267 : } 268 0 : arg->addForce( i*inshape[1] + j, out->getForce(n*outshape[1] + m) ); 269 0 : m++; 270 : } 271 0 : n++; 272 : } 273 : } 274 : } 275 : 276 : 277 : 278 : } 279 : }