Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2014-2017 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionWithValue.h"
23 : #include "core/ActionWithArguments.h"
24 : #include "core/ActionRegister.h"
25 : #include "core/PlumedMain.h"
26 : #include "core/ActionSet.h"
27 :
28 : //+PLUMEDOC PRINTANALYSIS SELECT_WITH_MASK
29 : /*
30 : Use a mask to select elements of an array
31 :
32 : Output a scalar, vector or matrix that contains a subset of the elements in the input vector or matrix.
33 : The following example shows how we can output a scalar, `v`, that contains the distance between atoms 3 and 4
34 : by using the mask vector `m` to select this element from the three element vector `d`:
35 :
36 : ```plumed
37 : d: DISTANCE ATOMS1=1,2 ATOMS2=3,4 ATOMS3=5,6
38 : m: CONSTANT VALUES=1,0,1
39 : v: SELECT_WITH_MASK ARG=d MASK=m
40 : ```
41 :
42 : The value, `m`, that is passed to the keyword MASK here is a vector with the same length as `d`.
43 : Elements of `d` whose corresponding elements in `m` are zero are copied to the output value `v`.
44 : When elements of `m` are non-zero the corresponding elements in `d` are not transferred to the output
45 : value - they are masked.
46 :
47 : If you use this action with matrices you must use the keywords `ROW_MASK` and `COLUMN_MASK`. As shown in the example
48 : inputs below, these keywords take vectors as input. In this first example, the output matrix is $3 \times 5$ as rows
49 : of the matrix whose corresponding elements in `m` are non-zero are not transferred:
50 :
51 : ```plumed
52 : d: DISTANCE_MATRIX GROUP=1-5
53 : m: CONSTANT VALUES=0,1,1,0,0
54 : v: SELECT_WITH_MASK ARG=d ROW_MASK=m
55 : ```
56 :
57 : For this second example the output matrix is $5 \times 3$ as columns of the matrix whose corresponding elements in `m` are non-zero
58 : are not transferred:
59 :
60 : ```plumed
61 : d: DISTANCE_MATRIX GROUP=1-5
62 : m: CONSTANT VALUES=0,1,1,0,0
63 : v: SELECT_WITH_MASK ARG=d COLUMN_MASK=m
64 : ```
65 :
66 : For this final example the output matrix is $3 \times 3$ as we do not transfer the rows and the columns in `d` whose corresponding
67 : elements in `m` are non-zero.
68 :
69 : ```plumed
70 : d: DISTANCE_MATRIX GROUP=1-5
71 : m: CONSTANT VALUES=0,1,1,0,0
72 : v: SELECT_WITH_MASK ARG=d ROW_MASK=m COLUMN_MASK=m
73 : ```
74 :
75 : */
76 : //+ENDPLUMEDOC
77 :
78 : namespace PLMD {
79 : namespace valtools {
80 :
81 : class SelectWithMask :
82 : public ActionWithValue,
83 : public ActionWithArguments {
84 : private:
85 : unsigned getOutputVectorLength( const Value* mask ) const ;
86 : public:
87 : static void registerKeywords( Keywords& keys );
88 : /// Constructor
89 : explicit SelectWithMask(const ActionOptions&);
90 : /// Get the number of derivatives
91 98 : unsigned getNumberOfDerivatives() override {
92 98 : return 0;
93 : }
94 : ///
95 : void getMatrixColumnTitles( std::vector<std::string>& argnames ) const override ;
96 : ///
97 : void prepare() override ;
98 : /// Do the calculation
99 : void calculate() override;
100 : ///
101 : void apply() override;
102 : };
103 :
104 : PLUMED_REGISTER_ACTION(SelectWithMask,"SELECT_WITH_MASK")
105 :
106 178 : void SelectWithMask::registerKeywords( Keywords& keys ) {
107 178 : Action::registerKeywords( keys );
108 178 : ActionWithValue::registerKeywords( keys );
109 178 : ActionWithArguments::registerKeywords( keys );
110 356 : keys.addInputKeyword("compulsory","ARG","scalar/vector/matrix","the label for the value upon which you are going to apply the mask");
111 356 : keys.addInputKeyword("optional","ROW_MASK","vector","an array with ones in the rows of the matrix that you want to discard");
112 356 : keys.addInputKeyword("optional","COLUMN_MASK","vector","an array with ones in the columns of the matrix that you want to discard");
113 356 : keys.addInputKeyword("compulsory","MASK","vector/matrix","an array with ones in the components that you want to discard");
114 356 : keys.setValueDescription("vector/matrix","a vector/matrix of values that is obtained using a mask to select elements of interest");
115 178 : keys.remove("NUMERICAL_DERIVATIVES");
116 178 : }
117 :
118 93 : SelectWithMask::SelectWithMask(const ActionOptions& ao):
119 : Action(ao),
120 : ActionWithValue(ao),
121 93 : ActionWithArguments(ao) {
122 93 : if( getNumberOfArguments()!=1 ) {
123 0 : error("should only be one argument for this action");
124 : }
125 : std::vector<std::size_t> shape;
126 93 : if( getPntrToArgument(0)->getRank()==1 ) {
127 : std::vector<Value*> mask;
128 136 : parseArgumentList("MASK",mask);
129 68 : if( mask.size()!=1 ) {
130 0 : error("should only be one input for mask");
131 : }
132 68 : if( mask[0]->getNumberOfValues()!=getPntrToArgument(0)->getNumberOfValues() ) {
133 0 : error("mismatch between size of mask and input vector");
134 : }
135 68 : log.printf(" creating vector from elements of %s who have a corresponding element in %s that is zero\n", getPntrToArgument(0)->getName().c_str(), mask[0]->getName().c_str() );
136 68 : std::vector<Value*> args( getArguments() );
137 68 : args.push_back( mask[0] );
138 68 : requestArguments( args );
139 68 : shape.resize(1,0);
140 68 : shape[0]=getOutputVectorLength(mask[0]);
141 25 : } else if( getPntrToArgument(0)->getRank()==2 ) {
142 : std::vector<Value*> rmask, cmask;
143 25 : parseArgumentList("ROW_MASK",rmask);
144 50 : parseArgumentList("COLUMN_MASK",cmask);
145 25 : if( rmask.size()==0 && cmask.size()==0 ) {
146 0 : error("no mask elements have been specified");
147 25 : } else if( cmask.size()==0 ) {
148 11 : std::string con="0";
149 144 : for(unsigned i=1; i<getPntrToArgument(0)->getShape()[1]; ++i) {
150 : con += ",0";
151 : }
152 11 : plumed.readInputWords( Tools::getWords(getLabel() + "_colmask: CONSTANT VALUES=" + con), false );
153 11 : std::vector<std::string> labs(1, getLabel() + "_colmask");
154 11 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, cmask );
155 25 : } else if( rmask.size()==0 ) {
156 1 : std::string con="0";
157 13 : for(unsigned i=1; i<getPntrToArgument(0)->getShape()[0]; ++i) {
158 : con += ",0";
159 : }
160 1 : plumed.readInputWords( Tools::getWords(getLabel() + "_rowmask: CONSTANT VALUES=" + con), false );
161 1 : std::vector<std::string> labs(1, getLabel() + "_rowmask");
162 1 : ActionWithArguments::interpretArgumentList( labs, plumed.getActionSet(), this, rmask );
163 1 : }
164 25 : shape.resize(2);
165 25 : shape[0] = getOutputVectorLength( rmask[0] );
166 25 : shape[1] = getOutputVectorLength( cmask[0] );
167 25 : std::vector<Value*> args( getArguments() );
168 25 : args.push_back( rmask[0] );
169 25 : args.push_back( cmask[0] );
170 25 : requestArguments( args );
171 : } else {
172 0 : error("input should be vector or matrix");
173 : }
174 :
175 93 : addValue( shape );
176 93 : if( getPntrToArgument(0)->isPeriodic() ) {
177 : std::string min, max;
178 7 : getPntrToArgument(0)->getDomain( min, max );
179 7 : setPeriodic( min, max );
180 : } else {
181 86 : setNotPeriodic();
182 : }
183 93 : if( getPntrToComponent(0)->getRank()==2 ) {
184 25 : getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
185 : }
186 93 : }
187 :
188 10704 : unsigned SelectWithMask::getOutputVectorLength( const Value* mask ) const {
189 : unsigned l=0;
190 154174 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
191 143470 : if( fabs(mask->get(i))>0 ) {
192 10015 : continue;
193 : }
194 133455 : l++;
195 : }
196 10704 : return l;
197 : }
198 :
199 18 : void SelectWithMask::getMatrixColumnTitles( std::vector<std::string>& argnames ) const {
200 : std::vector<std::string> alltitles;
201 18 : (getPntrToArgument(0)->getPntrToAction())->getMatrixColumnTitles( alltitles );
202 103 : for(unsigned i=0; i<alltitles.size(); ++i) {
203 85 : if( fabs(getPntrToArgument(2)->get(i))>0 ) {
204 34 : continue;
205 : }
206 51 : argnames.push_back( alltitles[i] );
207 : }
208 18 : }
209 :
210 10551 : void SelectWithMask::prepare() {
211 : Value* arg = getPntrToArgument(0);
212 10551 : Value* out = getPntrToComponent(0);
213 10551 : if( arg->getRank()==1 ) {
214 : Value* mask = getPntrToArgument(1);
215 10516 : std::vector<std::size_t> shape(1);
216 10516 : shape[0]=getOutputVectorLength( mask );
217 10516 : if( out->getNumberOfValues()!=shape[0] ) {
218 19 : if( shape[0]==1 ) {
219 0 : shape.resize(0);
220 : }
221 19 : out->setShape(shape);
222 : }
223 35 : } else if( arg->getRank()==2 ) {
224 35 : std::vector<std::size_t> outshape(2);
225 : Value* rmask = getPntrToArgument(1);
226 35 : outshape[0] = getOutputVectorLength( rmask );
227 : Value* cmask = getPntrToArgument(2);
228 35 : outshape[1] = getOutputVectorLength( cmask );
229 35 : if( out->getShape()[0]!=outshape[0] || out->getShape()[1]!=outshape[1] ) {
230 19 : out->setShape(outshape);
231 19 : out->reshapeMatrixStore( outshape[1] );
232 : }
233 : }
234 10551 : }
235 :
236 10543 : void SelectWithMask::calculate() {
237 : Value* arg = getPntrToArgument(0);
238 10543 : Value* out = getPntrToComponent(0);
239 10543 : if( arg->getRank()==1 ) {
240 : Value* mask = getPntrToArgument(1);
241 : unsigned n=0;
242 149144 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
243 138632 : if( fabs(mask->get(i))>0 ) {
244 7434 : continue;
245 : }
246 131198 : out->set(n, arg->get(i) );
247 131198 : n++;
248 : }
249 31 : } else if ( arg->getRank()==2 ) {
250 31 : std::vector<std::size_t> outshape( out->getShape() );
251 : unsigned n = 0;
252 31 : std::vector<std::size_t> inshape( arg->getShape() );
253 : Value* rmask = getPntrToArgument(1);
254 : Value* cmask = getPntrToArgument(2);
255 1774 : for(unsigned i=0; i<inshape[0]; ++i) {
256 1743 : if( fabs(rmask->get(i))>0 ) {
257 592 : continue;
258 : }
259 : unsigned m = 0;
260 378651 : for(unsigned j=0; j<inshape[1]; ++j) {
261 377500 : if( fabs(cmask->get(j))>0 ) {
262 188095 : continue;
263 : }
264 189405 : out->set( n*outshape[1] + m, arg->get(i*inshape[1] + j) );
265 189405 : m++;
266 : }
267 1151 : n++;
268 : }
269 : }
270 10543 : }
271 :
272 10505 : void SelectWithMask::apply() {
273 10505 : if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) {
274 62 : return ;
275 : }
276 :
277 : Value* arg = getPntrToArgument(0);
278 10443 : Value* out = getPntrToComponent(0);
279 10443 : if( arg->getRank()==1 ) {
280 : unsigned n=0;
281 : Value* mask = getPntrToArgument(1);
282 145276 : for(unsigned i=0; i<mask->getNumberOfValues(); ++i) {
283 134833 : if( fabs(mask->get(i))>0 ) {
284 4153 : continue;
285 : }
286 130680 : arg->addForce(i, out->getForce(n) );
287 130680 : n++;
288 : }
289 0 : } else if( arg->getRank()==2 ) {
290 : unsigned n = 0;
291 0 : std::vector<std::size_t> inshape( arg->getShape() );
292 0 : std::vector<std::size_t> outshape( out->getShape() );
293 : Value* rmask = getPntrToArgument(1);
294 : Value* cmask = getPntrToArgument(2);
295 0 : for(unsigned i=0; i<inshape[0]; ++i) {
296 0 : if( fabs(rmask->get(i))>0 ) {
297 0 : continue;
298 : }
299 : unsigned m = 0;
300 0 : for(unsigned j=0; j<inshape[1]; ++j) {
301 0 : if( fabs(cmask->get(j))>0 ) {
302 0 : continue;
303 : }
304 0 : arg->addForce( i*inshape[1] + j, out->getForce(n*outshape[1] + m) );
305 0 : m++;
306 : }
307 0 : n++;
308 : }
309 : }
310 : }
311 :
312 :
313 :
314 : }
315 : }
|