LCOV - code coverage report
Current view: top level - function - FunctionOfScalar.h (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 61 65 93.8 %
Date: 2026-03-30 11:13:23 Functions: 52 80 65.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2011-2023 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #ifndef __PLUMED_function_FunctionOfScalar_h
      23             : #define __PLUMED_function_FunctionOfScalar_h
      24             : 
      25             : #include "Function.h"
      26             : #include "tools/Matrix.h"
      27             : 
      28             : namespace PLMD {
      29             : namespace function {
      30             : 
      31             : /**
      32             : \ingroup INHERIT
      33             : This is the abstract base class to use for implementing a new CV function. Within it there is
      34             : \ref AddingAFunction "information" on how to go about implementing a new function.
      35             : */
      36             : 
      37             : template <class T> // T is the wrapped function type; this class forwards read(), setup(), calc(), etc. to an instance of it
      38             : class FunctionOfScalar : public Function {
      39             : private:
      40             : /// The function object (of type T) that is being computed; all the real work is delegated to it
      41             :   T myfunc;
      42             : /// True until calculate() has run once; used so that myfunc.setup() is called only on the first calculate()
      43             :   bool firststep;
      44             : public:
      45             :   explicit FunctionOfScalar(const ActionOptions&); // lets myfunc read its input and creates the output values/components
      46        2878 :   virtual ~FunctionOfScalar() {}
      47             : /// Get the label to write in the graph (delegates to myfunc.getGraphInfo() with the name of this action)
      48           3 :   std::string writeInGraph() const override {
      49           3 :     return myfunc.getGraphInfo( getName() );
      50             :   }
      51             :   std::string getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const override ; // custom text for actions whose name contains "SORT", ActionWithValue default otherwise
      52             :   void calculate() override; // evaluates myfunc on the scalar arguments and stores values (and derivatives unless they are switched off)
      53             :   static void registerKeywords(Keywords&); // keywords of Function plus those registered by a default-constructed T
      54             :   void turnOnDerivatives() override; // calls error() if myfunc reports that derivatives are not implemented
      55             : };
      56             : 
      57             : template <class T>
      58        2936 : void FunctionOfScalar<T>::registerKeywords(Keywords& keys) { // Register the keywords of Function, of this wrapper, and of the wrapped function T
      59        2936 :   Function::registerKeywords(keys);
      60        2936 :   keys.use("ARG");
      61        2936 :   std::string name = keys.getDisplayName();
      62        2936 :   std::size_t und=name.find("_SCALAR"); // position of "_SCALAR" in the display name, std::string::npos if absent
      63        2936 :   keys.setDisplayName( name.substr(0,und) ); // keep only the part before "_SCALAR"; with npos substr returns the whole name
      64        5872 :   keys.add("hidden","NO_ACTION_LOG","suppresses printing from action on the log");
      65        2382 :   T tfunc; // default-constructed T, used here only to register its keywords
      66        2936 :   tfunc.registerKeywords( keys );
      67        5872 :   if( keys.getDisplayName()=="SUM" ) { // SUM and MEAN get a dedicated description of the output value
      68           8 :     keys.setValueDescription("the sum of all the input arguments");
      69        5864 :   } else if( keys.getDisplayName()=="MEAN" ) {
      70           8 :     keys.setValueDescription("the mean of all the input arguments");
      71             :   }
      72        5803 : }
      73             : 
      74             : template <class T>
      75        1444 : FunctionOfScalar<T>::FunctionOfScalar(const ActionOptions&ao): // Let myfunc read its input, then create the output values/components
      76             :   Action(ao),
      77             :   Function(ao),
      78        1444 :   firststep(true) { // myfunc.setup() is deferred until the first calculate()
      79        1444 :   myfunc.read( this );
      80             :   // Get the names of the output components declared in the keywords
      81        1439 :   std::vector<std::string> components( keywords.getOutputComponents() );
      82             :   // Create the values to hold the output; str_ind holds strings supplied by myfunc that are appended to the component names
      83        1396 :   std::vector<std::string> str_ind( myfunc.getComponentsPerLabel() );
      84        2878 :   for(unsigned i=0; i<components.size(); ++i) {
      85          13 :     if( str_ind.size()>0 ) { // myfunc supplied suffixes: one component per (component name, suffix) pair
      86          13 :       std::string compstr = components[i];
      87          13 :       if( compstr==".#!value" ) { // the ".#!value" entry contributes no prefix, so only the suffix is used as the name
      88             :         compstr = "";
      89             :       }
      90          40 :       for(unsigned j=0; j<str_ind.size(); ++j) {
      91          54 :         addComponentWithDerivatives( compstr + str_ind[j] );
      92             :       }
      93        1426 :     } else if( components[i]==".#!value" ) { // no suffixes and the ".#!value" entry: a single unnamed value
      94        1424 :       addValueWithDerivatives();
      95           2 :     } else if( components[i].find_first_of("_")!=std::string::npos ) { // component name contains an underscore
      96           2 :       if( getNumberOfArguments()==1 ) { // one argument: a single unnamed value is enough
      97           1 :         addValueWithDerivatives();
      98             :       } else {
      99           3 :         for(unsigned j=0; j<getNumberOfArguments(); ++j) { // several arguments: one component per argument, named <argument name><component>
     100           4 :           addComponentWithDerivatives( getPntrToArgument(j)->getName() + components[i] );
     101             :         }
     102             :       }
     103             :     } else { // any other component name is used verbatim
     104           0 :       addComponentWithDerivatives( components[i] );
     105             :     }
     106             :   }
     107             :   // Set the periodicities of the output components
     108        1439 :   myfunc.setPeriodicityForOutputs( this );
     109           0 :   myfunc.setPrefactor( this, 1.0 ); // NOTE(review): what the prefactor means is decided by T and is not visible here
     110        1452 : }
     111             : 
     112             : template <class T>
     113           3 : std::string FunctionOfScalar<T>::getOutputComponentDescription( const std::string& cname, const Keywords& keys ) const { // Return the description text for output component cname
     114           3 :   if( getName().find("SORT")==std::string::npos ) { // only actions whose name contains "SORT" get the custom text below
     115           0 :     return ActionWithValue::getOutputComponentDescription( cname, keys );
     116             :   }
     117           6 :   return "the " + cname + "th largest of the input scalars"; // NOTE(review): the suffix is always "th", so a cname of "1" reads "1th"
     118             : }
     119             : 
     120             : template <class T>
     121        2192 : void FunctionOfScalar<T>::turnOnDerivatives() { // Switch on derivatives for this action, provided that myfunc implements them
     122             :   if( !myfunc.derivativesImplemented() ) {
     123           0 :     error("derivatives have not been implemented for " + getName() ); // report the problem through error(), naming this action
     124             :   }
     125        2192 :   ActionWithValue::turnOnDerivatives(); // hand over to the ActionWithValue implementation
     126        2192 : }
     127             : 
     128             : template <class T>
     129       92873 : void FunctionOfScalar<T>::calculate() { // Evaluate myfunc on the current scalar arguments and store the results in the output values
     130       92873 :   if( firststep ) { // first call only: run the setup of myfunc
     131        1390 :     myfunc.setup( this );
     132        1390 :     firststep=false;
     133             :   }
     134        1675 :   unsigned argstart = myfunc.getArgStart(); // arguments with index below argstart are not passed to calc()
     135       92873 :   std::vector<double> args( getNumberOfArguments() - argstart ); // current values of the arguments from argstart onwards
     136      210504 :   for(unsigned i=argstart; i<getNumberOfArguments(); ++i) {
     137      117631 :     args[i-argstart]=getPntrToArgument(i)->get();
     138             :   }
     139       92873 :   std::vector<double> vals( getNumberOfComponents() ); // one output value per component
     140       92873 :   Matrix<double> derivatives( getNumberOfComponents(), args.size() ); // derivatives(i,j) is used below as the derivative of component i with respect to args[j]
     141       92873 :   myfunc.calc( this, args, vals, derivatives );
     142      185783 :   for(unsigned i=0; i<vals.size(); ++i) {
     143       92910 :     copyOutput(i)->set(vals[i]);
     144             :   }
     145       92873 :   if( doNotCalculateDerivatives() ) { // values are always stored; derivatives are copied only when they are required
     146             :     return;
     147             :   }
     148             : // Copy the derivatives computed by myfunc.calc() into the output components
     149      160614 :   for(unsigned i=0; i<vals.size(); ++i) {
     150       80322 :     Value* val = getPntrToComponent(i);
     151      177007 :     for(unsigned j=0; j<args.size(); ++j) {
     152       96685 :       setDerivative( val, j, derivatives(i,j) ); // NOTE(review): j counts from argstart, not from the first argument; confirm this is intended when argstart>0
     153             :     }
     154             :   }
     155             : }
     156             : 
     157             : }
     158             : }
     159             : #endif

Generated by: LCOV version 1.16