LCOV - code coverage report
Current view: top level - valtools - Concatenate.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 113 134 84.3 %
Date: 2025-11-25 13:55:50 Functions: 5 6 83.3 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2014-2017 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionWithValue.h"
      23             : #include "core/ActionWithArguments.h"
      24             : #include "core/ActionRegister.h"
      25             : 
      26             : //+PLUMEDOC MCOLVAR CONCATENATE
      27             : /*
      28             : Join vectors or matrices together
      29             : 
      30             : \par Examples
      31             : 
      32             : */
      33             : //+ENDPLUMEDOC
      34             : 
      35             : namespace PLMD {
      36             : namespace valtools {
      37             : 
      38             : class Concatenate :
      39             :   public ActionWithValue,
      40             :   public ActionWithArguments {
      41             : private:
      42             :   bool vectors;
      43             :   std::vector<unsigned> row_starts;
      44             :   std::vector<unsigned> col_starts;
      45             : public:
      46             :   static void registerKeywords( Keywords& keys );
      47             : /// Constructor
      48             :   explicit Concatenate(const ActionOptions&);
      49             : /// Get the number of derivatives
      50         257 :   unsigned getNumberOfDerivatives() override {
      51         257 :     return 0;
      52             :   }
      53             : /// Do the calculation
      54             :   void calculate() override;
      55             : ///
      56             :   void apply();
      57             : };
      58             : 
      59             : PLUMED_REGISTER_ACTION(Concatenate,"CONCATENATE")
      60             : 
      61         353 : void Concatenate::registerKeywords( Keywords& keys ) {
      62         353 :   Action::registerKeywords( keys );
      63         353 :   ActionWithValue::registerKeywords( keys );
      64         353 :   ActionWithArguments::registerKeywords( keys );
      65         353 :   keys.use("ARG");
      66         706 :   keys.add("numbered","MATRIX","specify the matrices that you wish to join together into a single matrix");
      67         706 :   keys.reset_style("MATRIX","compulsory");
      68         353 :   keys.setValueDescription("the concatenated vector/matrix that was constructed from the input values");
      69         353 : }
      70             : 
      71         176 : Concatenate::Concatenate(const ActionOptions& ao):
      72             :   Action(ao),
      73             :   ActionWithValue(ao),
      74         176 :   ActionWithArguments(ao) {
      75         176 :   if( getNumberOfArguments()>0 ) {
      76         172 :     vectors=true;
      77         172 :     std::vector<unsigned> shape(1);
      78         172 :     shape[0]=0;
      79         547 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
      80         375 :       if( getPntrToArgument(i)->getRank()>1 ) {
      81           0 :         error("cannot concatenate matrix with vectors");
      82             :       }
      83         375 :       getPntrToArgument(i)->buildDataStore();
      84         375 :       shape[0] += getPntrToArgument(i)->getNumberOfValues();
      85             :     }
      86         172 :     log.printf("  creating vector with %d elements \n", shape[0] );
      87         172 :     addValue( shape );
      88         172 :     bool period=getPntrToArgument(0)->isPeriodic();
      89             :     std::string min, max;
      90         172 :     if( period ) {
      91           0 :       getPntrToArgument(0)->getDomain( min, max );
      92             :     }
      93         375 :     for(unsigned i=1; i<getNumberOfArguments(); ++i) {
      94         203 :       if( period!=getPntrToArgument(i)->isPeriodic() ) {
      95           0 :         error("periods of input arguments should match");
      96             :       }
      97         203 :       if( period ) {
      98             :         std::string min0, max0;
      99           0 :         getPntrToArgument(i)->getDomain( min0, max0 );
     100           0 :         if( min0!=min || max0!=max ) {
     101           0 :           error("domains of input arguments should match");
     102             :         }
     103             :       }
     104             :     }
     105         172 :     if( period ) {
     106           0 :       setPeriodic( min, max );
     107             :     } else {
     108         172 :       setNotPeriodic();
     109             :     }
     110         172 :     getPntrToComponent(0)->buildDataStore();
     111         172 :     if( getPntrToComponent(0)->getRank()==2 ) {
     112           0 :       getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
     113             :     }
     114             :   } else {
     115             :     unsigned nrows=0, ncols=0;
     116             :     std::vector<Value*> arglist;
     117           4 :     vectors=false;
     118           7 :     for(unsigned i=1;; i++) {
     119             :       unsigned nt_cols=0;
     120             :       unsigned size_b4 = arglist.size();
     121          14 :       for(unsigned j=1;; j++) {
     122          25 :         if( j==10 ) {
     123           0 :           error("cannot combine more than 9 matrices");
     124             :         }
     125             :         std::vector<Value*> argn;
     126          50 :         parseArgumentList("MATRIX", i*10+j, argn);
     127          25 :         if( argn.size()==0 ) {
     128             :           break;
     129             :         }
     130          14 :         if( argn.size()>1 ) {
     131           0 :           error("should only be one argument to each matrix keyword");
     132             :         }
     133          14 :         if( argn[0]->getRank()!=0 && argn[0]->getRank()!=2 ) {
     134           0 :           error("input arguments for this action should be matrices");
     135             :         }
     136          14 :         argn[0]->buildDataStore();
     137          14 :         arglist.push_back( argn[0] );
     138          14 :         nt_cols++;
     139          14 :         if( argn[0]->getRank()==0 ) {
     140           0 :           log.printf("  %d %d component of composed matrix is scalar labelled %s\n", i, j, argn[0]->getName().c_str() );
     141             :         } else {
     142          14 :           log.printf("  %d %d component of composed matrix is %d by %d matrix labelled %s\n", i, j, argn[0]->getShape()[0], argn[0]->getShape()[1], argn[0]->getName().c_str() );
     143             :         }
     144          14 :       }
     145          11 :       if( arglist.size()==size_b4 ) {
     146             :         break;
     147             :       }
     148           7 :       if( i==1 ) {
     149             :         ncols=nt_cols;
     150           3 :       } else if( nt_cols!=ncols ) {
     151           0 :         error("should be joining same number of matrices in each row");
     152             :       }
     153           7 :       nrows++;
     154           7 :     }
     155             : 
     156           4 :     std::vector<unsigned> shape(2);
     157           4 :     shape[0]=0;
     158             :     unsigned k=0;
     159           4 :     row_starts.resize( arglist.size() );
     160           4 :     col_starts.resize( arglist.size() );
     161          11 :     for(unsigned i=0; i<nrows; ++i) {
     162             :       unsigned cstart = 0, nr = 1;
     163           7 :       if( arglist[k]->getRank()==2 ) {
     164           7 :         nr=arglist[k]->getShape()[0];
     165             :       }
     166          21 :       for(unsigned j=0; j<ncols; ++j) {
     167          14 :         if( arglist[k]->getRank()==0 ) {
     168           0 :           if( nr!=1 ) {
     169           0 :             error("mismatched matrix sizes");
     170             :           }
     171          14 :         } else if( nrows>1 && arglist[k]->getShape()[0]!=nr ) {
     172           0 :           error("mismatched matrix sizes");
     173             :         }
     174          14 :         row_starts[k] = shape[0];
     175          14 :         col_starts[k] = cstart;
     176          14 :         if( arglist[k]->getRank()==0 ) {
     177           0 :           cstart += 1;
     178             :         } else {
     179          14 :           cstart += arglist[k]->getShape()[1];
     180             :         }
     181          14 :         k++;
     182             :       }
     183           7 :       if( i==0 ) {
     184           4 :         shape[1]=cstart;
     185           3 :       } else if( cstart!=shape[1] ) {
     186           0 :         error("mismatched matrix sizes");
     187             :       }
     188           7 :       if( arglist[k-1]->getRank()==0 ) {
     189           0 :         shape[0] += 1;
     190             :       } else {
     191           7 :         shape[0] += arglist[k-1]->getShape()[0];
     192             :       }
     193             :     }
     194             :     // Now request the arguments to make sure we store things we need
     195           4 :     requestArguments(arglist);
     196           4 :     addValue( shape );
     197           4 :     setNotPeriodic();
     198           4 :     getPntrToComponent(0)->buildDataStore();
     199           4 :     if( getPntrToComponent(0)->getRank()==2 ) {
     200           4 :       getPntrToComponent(0)->reshapeMatrixStore( shape[1] );
     201             :     }
     202             :   }
     203         176 : }
     204             : 
     205       12191 : void Concatenate::calculate() {
     206       12191 :   Value* myval = getPntrToComponent(0);
     207       12191 :   if( vectors ) {
     208             :     unsigned k=0;
     209       61297 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     210             :       Value* myarg=getPntrToArgument(i);
     211       49158 :       unsigned nvals=myarg->getNumberOfValues();
     212      404266 :       for(unsigned j=0; j<nvals; ++j) {
     213      355108 :         myval->set( k, myarg->get(j) );
     214      355108 :         k++;
     215             :       }
     216             :     }
     217             :   } else {
     218             :     // Retrieve the matrix from input
     219          52 :     unsigned ncols = myval->getShape()[1];
     220         258 :     for(unsigned k=0; k<getNumberOfArguments(); ++k) {
     221             :       Value* argn = getPntrToArgument(k);
     222         206 :       if( argn->getRank()==0 ) {
     223           0 :         myval->set( ncols*row_starts[k]+col_starts[k], argn->get() );
     224             :       } else {
     225             :         std::vector<double> vals;
     226             :         std::vector<std::pair<unsigned,unsigned> > pairs;
     227             :         bool symmetric=getPntrToArgument(k)->isSymmetric();
     228         206 :         unsigned nedge=0;
     229         206 :         getPntrToArgument(k)->retrieveEdgeList( nedge, pairs, vals );
     230        8946 :         for(unsigned l=0; l<nedge; ++l ) {
     231        8740 :           unsigned i=pairs[l].first, j=pairs[l].second;
     232        8740 :           myval->set( ncols*(row_starts[k]+i)+col_starts[k]+j, vals[l] );
     233        8740 :           if( symmetric ) {
     234        2142 :             myval->set( ncols*(row_starts[k]+j)+col_starts[k]+i, vals[l] );
     235             :           }
     236             :         }
     237             :       }
     238             :     }
     239             :   }
     240       12191 : }
     241             : 
     242       12070 : void Concatenate::apply() {
     243       12070 :   if( doNotCalculateDerivatives() || !getPntrToComponent(0)->forcesWereAdded() ) {
     244        7037 :     return;
     245             :   }
     246             : 
     247        5033 :   Value* val=getPntrToComponent(0);
     248        5033 :   if( vectors ) {
     249             :     unsigned k=0;
     250       19923 :     for(unsigned i=0; i<getNumberOfArguments(); ++i) {
     251             :       Value* myarg=getPntrToArgument(i);
     252       14942 :       unsigned nvals=myarg->getNumberOfValues();
     253      205938 :       for(unsigned j=0; j<nvals; ++j) {
     254      190996 :         myarg->addForce( j, val->getForce(k) );
     255      190996 :         k++;
     256             :       }
     257             :     }
     258             :   } else {
     259          52 :     unsigned ncols=val->getShape()[1];
     260         258 :     for(unsigned k=0; k<getNumberOfArguments(); ++k) {
     261             :       Value* argn=getPntrToArgument(k);
     262         206 :       if( argn->getRank()==0 ) {
     263           0 :         argn->addForce( 0, val->getForce(ncols*row_starts[k]+col_starts[k]) );
     264             :       } else {
     265             :         unsigned val_ncols=val->getNumberOfColumns();
     266             :         unsigned arg_ncols=argn->getNumberOfColumns();
     267        1686 :         for(unsigned i=0; i<argn->getShape()[0]; ++i) {
     268             :           unsigned ncol = argn->getRowLength(i);
     269       13140 :           for(unsigned j=0; j<ncol; ++j) {
     270       11660 :             argn->addForce( i*arg_ncols+j, val->getForce( val_ncols*(row_starts[k]+i)+col_starts[k]+argn->getRowIndex(i,j) ), false );
     271             :           }
     272             :         }
     273             :       }
     274             :     }
     275             :   }
     276             : }
     277             : 
     278             : }
     279             : }

Generated by: LCOV version 1.16