Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2012-2023 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "ActionWithVessel.h"
23 : #include "tools/Communicator.h"
24 : #include "Vessel.h"
25 : #include "ShortcutVessel.h"
26 : #include "StoreDataVessel.h"
27 : #include "VesselRegister.h"
28 : #include "BridgeVessel.h"
29 : #include "FunctionVessel.h"
30 : #include "StoreDataVessel.h"
31 : #include "tools/OpenMP.h"
32 : #include "tools/Stopwatch.h"
33 :
34 : namespace PLMD {
35 : namespace vesselbase {
36 :
37 1040 : void ActionWithVessel::registerKeywords(Keywords& keys) {
38 2080 : keys.add("hidden","TOL","this keyword can be used to speed up your calculation. When accumulating sums in which the individual "
39 : "terms are numbers in between zero and one it is assumed that terms less than a certain tolerance "
40 : "make only a small contribution to the sum. They can thus be safely ignored as can the the derivatives "
41 : "wrt these small quantities.");
42 2080 : keys.add("hidden","MAXDERIVATIVES","The maximum number of derivatives that can be used when storing data. This controls when "
43 : "we have to start using lowmem");
44 2080 : keys.addFlag("SERIAL",false,"do the calculation in serial. Do not use MPI");
45 2080 : keys.addFlag("LOWMEM",false,"lower the memory requirements");
46 2080 : keys.addFlag("TIMINGS",false,"output information on the timings of the various parts of the calculation");
47 2080 : keys.reserveFlag("HIGHMEM",false,"use a more memory intensive version of this collective variable");
48 1040 : keys.add( vesselRegister().getKeywords() );
49 1040 : }
50 :
51 588 : ActionWithVessel::ActionWithVessel(const ActionOptions&ao):
52 : Action(ao),
53 588 : serial(false),
54 588 : lowmem(false),
55 588 : noderiv(true),
56 588 : actionIsBridged(false),
57 588 : nactive_tasks(0),
58 588 : dertime_can_be_off(false),
59 588 : dertime(true),
60 588 : contributorsAreUnlocked(false),
61 588 : weightHasDerivatives(false),
62 588 : mydata(NULL) {
63 588 : maxderivatives=309;
64 588 : parse("MAXDERIVATIVES",maxderivatives);
65 1176 : if( keywords.exists("SERIAL") ) {
66 1110 : parseFlag("SERIAL",serial);
67 : } else {
68 33 : serial=true;
69 : }
70 588 : if(serial) {
71 33 : log.printf(" doing calculation in serial\n");
72 : }
73 1176 : if( keywords.exists("LOWMEM") ) {
74 1112 : plumed_assert( !keywords.exists("HIGHMEM") );
75 556 : parseFlag("LOWMEM",lowmem);
76 556 : if(lowmem) {
77 29 : log.printf(" lowering memory requirements\n");
78 29 : dertime_can_be_off=true;
79 : }
80 : }
81 1176 : if( keywords.exists("HIGHMEM") ) {
82 46 : plumed_assert( !keywords.exists("LOWMEM") );
83 : bool highmem;
84 23 : parseFlag("HIGHMEM",highmem);
85 23 : lowmem=!highmem;
86 23 : if(!lowmem) {
87 2 : log.printf(" increasing the memory requirements\n");
88 : }
89 : }
90 588 : tolerance=nl_tolerance=epsilon;
91 1176 : if( keywords.exists("TOL") ) {
92 954 : parse("TOL",tolerance);
93 : }
94 588 : if( tolerance>epsilon) {
95 6 : log.printf(" Ignoring contributions less than %f \n",tolerance);
96 : }
97 588 : parseFlag("TIMINGS",timers);
98 588 : stopwatch.start();
99 588 : stopwatch.pause();
100 588 : }
101 :
102 588 : ActionWithVessel::~ActionWithVessel() {
103 588 : stopwatch.start();
104 588 : stopwatch.stop();
105 588 : if(timers) {
106 0 : log.printf("timings for action %s with label %s \n", getName().c_str(), getLabel().c_str() );
107 0 : log<<stopwatch;
108 : }
109 588 : }
110 :
111 377 : void ActionWithVessel::addVessel( const std::string& name, const std::string& input, const int numlab ) {
112 377 : VesselOptions da(name,"",numlab,input,this);
113 754 : auto vv=vesselRegister().create(name,da);
114 377 : FunctionVessel* fv=dynamic_cast<FunctionVessel*>(vv.get());
115 377 : if( fv ) {
116 347 : std::string mylabel=Vessel::transformName( name );
117 347 : plumed_massert( keywords.outputComponentExists(mylabel,false), "a description of the value calculated by vessel " + name + " has not been added to the manual");
118 : }
119 377 : addVessel(std::move(vv));
120 377 : }
121 :
122 573 : void ActionWithVessel::addVessel( std::unique_ptr<Vessel> vv_ptr ) {
123 :
124 : // In the original code, the dynamically casted pointer was deleted here.
125 : // Now that vv_ptr is a unique_ptr, the object will be deleted automatically when
126 : // exiting this routine.
127 573 : if(dynamic_cast<ShortcutVessel*>(vv_ptr.get())) {
128 : return;
129 : }
130 :
131 562 : vv_ptr->checkRead();
132 :
133 562 : StoreDataVessel* mm=dynamic_cast<StoreDataVessel*>( vv_ptr.get() );
134 562 : if( mydata && mm ) {
135 0 : error("cannot have more than one StoreDataVessel in one action");
136 562 : } else if( mm ) {
137 131 : mydata=mm;
138 : } else {
139 431 : dertime_can_be_off=false;
140 : }
141 :
142 : // Ownership is transferred to functions
143 562 : functions.emplace_back(std::move(vv_ptr));
144 : }
145 :
146 46 : BridgeVessel* ActionWithVessel::addBridgingVessel( ActionWithVessel* tome ) {
147 92 : VesselOptions da("","",0,"",this);
148 46 : auto bv=Tools::make_unique<BridgeVessel>(da);
149 46 : bv->setOutputAction( tome );
150 46 : tome->actionIsBridged=true;
151 46 : dertime_can_be_off=false;
152 : // store this pointer in order to return it later.
153 : // notice that I cannot access this with functions.tail().get()
154 : // since functions contains pointers to a different class (Vessel)
155 : auto toBeReturned=bv.get();
156 46 : functions.emplace_back( std::move(bv) );
157 46 : resizeFunctions();
158 46 : return toBeReturned;
159 46 : }
160 :
161 192 : StoreDataVessel* ActionWithVessel::buildDataStashes( ActionWithVessel* actionThatUses ) {
162 192 : if(mydata) {
163 87 : if( actionThatUses ) {
164 74 : mydata->addActionThatUses( actionThatUses );
165 : }
166 87 : return mydata;
167 : }
168 :
169 210 : VesselOptions da("","",0,"",this);
170 105 : auto mm=Tools::make_unique<StoreDataVessel>(da);
171 105 : if( actionThatUses ) {
172 53 : mm->addActionThatUses( actionThatUses );
173 : }
174 105 : addVessel(std::move(mm));
175 :
176 : // Make sure resizing of vessels is done
177 105 : resizeFunctions();
178 :
179 105 : return mydata;
180 105 : }
181 :
182 12749659 : void ActionWithVessel::addTaskToList( const unsigned& taskCode ) {
183 12749659 : fullTaskList.push_back( taskCode );
184 12749659 : taskFlags.push_back(0);
185 12749659 : plumed_assert( fullTaskList.size()==taskFlags.size() );
186 12749659 : }
187 :
188 435 : void ActionWithVessel::readVesselKeywords() {
189 : // Set maxderivatives if it is too big
190 435 : if( maxderivatives>getNumberOfDerivatives() ) {
191 214 : maxderivatives=getNumberOfDerivatives();
192 : }
193 :
194 : // Loop over all keywords find the vessels and create appropriate functions
195 10586 : for(unsigned i=0; i<keywords.size(); ++i) {
196 : std::string thiskey,input;
197 10151 : thiskey=keywords.getKeyword(i);
198 : // Check if this is a key for a vessel
199 10151 : if( vesselRegister().check(thiskey) ) {
200 5706 : plumed_assert( keywords.style(thiskey,"vessel") );
201 2853 : bool dothis=false;
202 2853 : parseFlag(thiskey,dothis);
203 2853 : if(dothis) {
204 119 : addVessel( thiskey, input );
205 : }
206 :
207 2853 : parse(thiskey,input);
208 2853 : if(input.size()!=0) {
209 143 : addVessel( thiskey, input );
210 : } else {
211 2710 : for(unsigned i=1;; ++i) {
212 2735 : if( !parseNumbered(thiskey,i,input) ) {
213 : break;
214 : }
215 : std::string ss;
216 25 : Tools::convert(i,ss);
217 25 : addVessel( thiskey, input, i );
218 : input.clear();
219 25 : }
220 : }
221 : }
222 : }
223 :
224 : // Make sure all vessels have had been resized at start
225 435 : if( functions.size()>0 ) {
226 307 : resizeFunctions();
227 : }
228 435 : }
229 :
230 1441 : void ActionWithVessel::resizeFunctions() {
231 3680 : for(unsigned i=0; i<functions.size(); ++i) {
232 2239 : functions[i]->resize();
233 : }
234 1441 : }
235 :
236 860 : void ActionWithVessel::needsDerivatives() {
237 : // Turn on the derivatives and resize
238 860 : noderiv=false;
239 860 : resizeFunctions();
240 : // Setting contributors unlocked here ensures that link cells are ignored
241 : contributorsAreUnlocked=true;
242 860 : contributorsAreUnlocked=false;
243 : // And turn on the derivatives in all actions on which we are dependent
244 1158 : for(unsigned i=0; i<getDependencies().size(); ++i) {
245 298 : ActionWithVessel* vv=dynamic_cast<ActionWithVessel*>( getDependencies()[i] );
246 298 : if(vv) {
247 284 : vv->needsDerivatives();
248 : }
249 : }
250 860 : }
251 :
252 3343 : void ActionWithVessel::lockContributors() {
253 3343 : nactive_tasks = 0;
254 18625172 : for(unsigned i=0; i<fullTaskList.size(); ++i) {
255 18621829 : if( taskFlags[i]>0 ) {
256 5614828 : nactive_tasks++;
257 : }
258 : }
259 :
260 : unsigned n=0;
261 3343 : partialTaskList.resize( nactive_tasks );
262 3343 : indexOfTaskInFullList.resize( nactive_tasks );
263 18625172 : for(unsigned i=0; i<fullTaskList.size(); ++i) {
264 : // Deactivate sets inactive tasks to number not equal to zero
265 18621829 : if( taskFlags[i]>0 ) {
266 5614828 : partialTaskList[n] = fullTaskList[i];
267 5614828 : indexOfTaskInFullList[n]=i;
268 5614828 : n++;
269 : }
270 : }
271 : plumed_dbg_assert( n==nactive_tasks );
272 8250 : for(unsigned i=0; i<functions.size(); ++i) {
273 4907 : BridgeVessel* bb = dynamic_cast<BridgeVessel*>( functions[i].get() );
274 4907 : if( bb ) {
275 795 : bb->copyTaskFlags();
276 : }
277 : }
278 : // Resize mydata to accommodate all active tasks
279 3343 : if( mydata ) {
280 1090 : mydata->resize();
281 : }
282 3343 : contributorsAreUnlocked=false;
283 3343 : }
284 :
285 3343 : void ActionWithVessel::deactivateAllTasks() {
286 3343 : contributorsAreUnlocked=true;
287 3343 : nactive_tasks = 0;
288 3343 : taskFlags.assign(taskFlags.size(),0);
289 3343 : }
290 :
291 213721 : bool ActionWithVessel::taskIsCurrentlyActive( const unsigned& index ) const {
292 : plumed_dbg_assert( index<taskFlags.size() );
293 213721 : return (taskFlags[index]>0);
294 : }
295 :
296 23099 : void ActionWithVessel::doJobsRequiredBeforeTaskList() {
297 : // Do any preparatory stuff for functions
298 58476 : for(unsigned j=0; j<functions.size(); ++j) {
299 35377 : functions[j]->prepare();
300 : }
301 23099 : }
302 :
303 23612 : unsigned ActionWithVessel::getSizeOfBuffer( unsigned& bufsize ) {
304 59665 : for(unsigned i=0; i<functions.size(); ++i) {
305 36053 : functions[i]->setBufferStart( bufsize );
306 : }
307 23612 : if( buffer.size()!=bufsize ) {
308 521 : buffer.resize( bufsize );
309 : }
310 23612 : if( mydata ) {
311 : unsigned dsize=mydata->getSizeOfDerivativeList();
312 2554 : if( der_list.size()!=dsize ) {
313 113 : der_list.resize( dsize );
314 : }
315 : }
316 23612 : return bufsize;
317 : }
318 :
// Run every active task, accumulate the vessel contributions into buffer,
// reduce the results over MPI and let each vessel finish its computation.
//
// Tasks are distributed over MPI ranks with a stride (unless SERIAL, in which
// case every rank runs every task). They are spread over OpenMP threads only
// when there are enough tasks and threadSafe() allows it. A task whose first
// quantity (myvals.get(0)) is below tolerance is skipped entirely.
// With TIMINGS, the four phases are timed separately.
void ActionWithVessel::runAllTasks() {
  plumed_massert( !contributorsAreUnlocked && functions.size()>0, "you must have a call to readVesselKeywords somewhere" );
  unsigned stride=comm.Get_size();
  unsigned rank=comm.Get_rank();
  if(serial) {
    stride=1;
    rank=0;
  }

  // Let every vessel prepare itself before the task loop
  if(timers) {
    stopwatch.start("1 Prepare Tasks");
  }
  doJobsRequiredBeforeTaskList();
  if(timers) {
    stopwatch.stop("1 Prepare Tasks");
  }

  // Get number of threads for OpenMP: fall back to a single thread when there
  // are fewer than two active tasks per thread across all ranks, or when the
  // action is not thread safe
  unsigned nt=OpenMP::getNumThreads();
  if( nt*stride*2>nactive_tasks || !threadSafe()) {
    nt=1;
  }

  // Get size for buffer (bsize is advanced by getSizeOfBuffer; both end up equal)
  unsigned bsize=0, bufsize=getSizeOfBuffer( bsize );
  // Clear buffer
  buffer.assign( buffer.size(), 0.0 );
  // Switch off calculation of derivatives in main loop
  // (only possible in LOWMEM mode when no vessel forbids it)
  if( dertime_can_be_off ) {
    dertime=false;
  }

  if(timers) {
    stopwatch.start("2 Loop over tasks");
  }
  #pragma omp parallel num_threads(nt)
  {
    // With several threads each one accumulates into a private buffer that is
    // added to the shared buffer at the end; with one thread buffer is used directly
    std::vector<double> omp_buffer;
    if( nt>1 ) {
      omp_buffer.resize( bufsize, 0.0 );
    }
    MultiValue myvals( getNumberOfQuantities(), getNumberOfDerivatives() );
    MultiValue bvals( getNumberOfQuantities(), getNumberOfDerivatives() );
    myvals.clearAll();
    bvals.clearAll();

    // This rank handles tasks rank, rank+stride, ... of the partial task list
    #pragma omp for nowait schedule(dynamic)
    for(unsigned i=rank; i<nactive_tasks; i+=stride) {
      // Calculate the stuff in the loop for this action
      performTask( indexOfTaskInFullList[i], partialTaskList[i], myvals );

      // Check for conditions that allow us to just to skip the calculation
      // the condition is that the weight of the contribution is low
      // N.B. Here weights are assumed to be between zero and one
      if( myvals.get(0)<tolerance ) {
        // Clear the derivatives
        myvals.clearAll();
        continue;
      }

      // Now calculate all the functions
      // If the contribution of this quantity is very small at neighbour list time ignore it
      // until next neighbour list time
      if( nt>1 ) {
        calculateAllVessels( indexOfTaskInFullList[i], myvals, bvals, omp_buffer, der_list );
      } else {
        calculateAllVessels( indexOfTaskInFullList[i], myvals, bvals, buffer, der_list );
      }

      // Clear the value
      myvals.clearAll();
    }
    // Serialise the reduction of the per-thread buffers into the shared buffer
    #pragma omp critical
    if(nt>1)
      for(unsigned i=0; i<bufsize; ++i) {
        buffer[i]+=omp_buffer[i];
      }
  }
  if(timers) {
    stopwatch.stop("2 Loop over tasks");
  }
  // Turn back on derivative calculation
  dertime=true;

  if(timers) {
    stopwatch.start("3 MPI gather");
  }
  // MPI Gather everything (not needed in serial: every rank ran every task)
  if( !serial && buffer.size()>0 ) {
    comm.Sum( buffer );
  }
  // MPI Gather index stores, then update the elements that are making
  // contributions to the sum here - this causes problems if we do it in prepare.
  // NOTE(review): unlike buffer above, der_list is summed even when serial is
  // set, although then every rank has processed every task; confirm this
  // reduction is intended in the SERIAL case.
  if( mydata && !lowmem && !noderiv ) {
    comm.Sum( der_list );
    mydata->setActiveValsAndDerivatives( der_list );
  }
  if(timers) {
    stopwatch.stop("3 MPI gather");
  }

  if(timers) {
    stopwatch.start("4 Finishing computations");
  }
  finishComputations( buffer );
  if(timers) {
    stopwatch.stop("4 Finishing computations");
  }
}
431 0 : void ActionWithVessel::transformBridgedDerivatives( const unsigned& current, MultiValue& invals, MultiValue& outvals ) const {
432 0 : plumed_error();
433 : }
434 :
435 515746 : void ActionWithVessel::calculateAllVessels( const unsigned& taskCode, MultiValue& myvals, MultiValue& bvals, std::vector<double>& buffer, std::vector<unsigned>& der_list ) {
436 1176200 : for(unsigned j=0; j<functions.size(); ++j) {
437 : // Calculate returns a bool that tells us if this particular
438 : // quantity is contributing more than the tolerance
439 660454 : functions[j]->calculate( taskCode, functions[j]->transformDerivatives(taskCode, myvals, bvals), buffer, der_list );
440 660454 : if( !actionIsBridged ) {
441 584733 : bvals.clearAll();
442 : }
443 : }
444 515746 : return;
445 : }
446 :
447 23099 : void ActionWithVessel::finishComputations( const std::vector<double>& buffer ) {
448 : // Set the final value of the function
449 58476 : for(unsigned j=0; j<functions.size(); ++j) {
450 35377 : functions[j]->finish( buffer );
451 : }
452 23099 : }
453 :
454 7330 : bool ActionWithVessel::getForcesFromVessels( std::vector<double>& forcesToApply ) {
455 : #ifndef NDEBUG
456 : if( forcesToApply.size()>0 ) {
457 : plumed_dbg_assert( forcesToApply.size()==getNumberOfDerivatives() );
458 : }
459 : #endif
460 7330 : if(tmpforces.size()!=forcesToApply.size() ) {
461 224 : tmpforces.resize( forcesToApply.size() );
462 : }
463 :
464 7330 : forcesToApply.assign( forcesToApply.size(),0.0 );
465 : bool wasforced=false;
466 21326 : for(unsigned i=0; i<getNumberOfVessels(); ++i) {
467 13996 : if( (functions[i]->applyForce( tmpforces )) ) {
468 : wasforced=true;
469 758615 : for(unsigned j=0; j<forcesToApply.size(); ++j) {
470 758228 : forcesToApply[j]+=tmpforces[j];
471 : }
472 : }
473 : }
474 7330 : return wasforced;
475 : }
476 :
477 0 : void ActionWithVessel::retrieveDomain( std::string& min, std::string& max ) {
478 0 : plumed_merror("If your function is periodic you need to add a retrieveDomain function so that ActionWithVessel can retrieve the domain");
479 : }
480 :
481 0 : Vessel* ActionWithVessel::getVesselWithName( const std::string& mynam ) {
482 : int target=-1;
483 0 : for(unsigned i=0; i<functions.size(); ++i) {
484 0 : if( functions[i]->getName().find(mynam)!=std::string::npos ) {
485 0 : if( target<0 ) {
486 0 : target=i;
487 : } else {
488 0 : error("found more than one " + mynam + " object in action");
489 : }
490 : }
491 : }
492 0 : plumed_assert(target>=0);
493 0 : return functions[target].get();
494 : }
495 :
496 : }
497 : }
|