LCOV - code coverage report
Current view: top level - dimred - SketchMap.cpp (source / functions)
Test: plumed test coverage
Date: 2025-12-04 11:19:34

             Hit    Total    Coverage
Lines:       117    123      95.1 %
Functions:   2      3        66.7 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             :    Copyright (c) 2015-2020 The plumed team
       3             :    (see the PEOPLE file at the root of the distribution for a list of names)
       4             : 
       5             :    See http://www.plumed.org for more information.
       6             : 
       7             :    This file is part of plumed, version 2.
       8             : 
       9             :    plumed is free software: you can redistribute it and/or modify
      10             :    it under the terms of the GNU Lesser General Public License as published by
      11             :    the Free Software Foundation, either version 3 of the License, or
      12             :    (at your option) any later version.
      13             : 
      14             :    plumed is distributed in the hope that it will be useful,
      15             :    but WITHOUT ANY WARRANTY; without even the implied warranty of
      16             :    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      17             :    GNU Lesser General Public License for more details.
      18             : 
      19             :    You should have received a copy of the GNU Lesser General Public License
      20             :    along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      21             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      22             : #include "core/ActionShortcut.h"
      23             : #include "core/PlumedMain.h"
      24             : #include "core/ActionSet.h"
      25             : #include "core/ActionRegister.h"
      26             : #include "core/ActionWithValue.h"
      27             : 
      28             : //+PLUMEDOC DIMRED SKETCHMAP
      29             : /*
      30             : Construct a sketch map projection of the input data
      31             : 
      32             : This shortcut allows you to construct a sketch-map projection from a trajectory. The sketch-map algorithm
      33             : is introduced and examples of how it is used are given in the papers cited below.
      34             : 
      35             : The following input illustrates how to run a sketch-map calculation with PLUMED:
      36             : 
      37             : ```plumed
      38             : d1: DISTANCE ATOMS=1,2
      39             : d2: DISTANCE ATOMS=3,4
      40             : d3: DISTANCE ATOMS=5,6
      41             : 
      42             : ff: COLLECT_FRAMES STRIDE=1 ARG=d1,d2,d3
      43             : lwe: CUSTOM ARG=ff_logweights FUNC=exp(x) PERIODIC=NO
      44             : smap: SKETCHMAP ...
      45             :    ARG=ff NLOW_DIM=2
      46             :    HIGH_DIM_FUNCTION={SMAP R_0=2 A=3 B=9}
      47             :    LOW_DIM_FUNCTION={SMAP R_0=2 A=2 B=2}
      48             :    WEIGHTS=lwe NCYCLES=3 FGRID_SIZE=50,50
      49             : ...
      50             : 
      51             : DUMPVECTOR ARG=smap.*,lwe FILE=smap
      52             : ```
      53             : 
      54             : The sketch-map projection here is constructed from all the frames in the input trajectory and is output at the end
      55             : of the simulation.  Dissimilarities between the trajectory frames are calculated from the differences in the three
      56             : distances that are passed to the [COLLECT_FRAMES](COLLECT_FRAMES.md) action.  However,
      57             : if you want to use RMSD distances to compute these dissimilarities instead, you can use an input like the one shown below:
      58             : 
      59             : ```plumed
      60             : ff: COLLECT_FRAMES STRIDE=1 ATOMS=1-256
      61             : lwe: CUSTOM ARG=ff_logweights FUNC=exp(x) PERIODIC=NO
      62             : smap: SKETCHMAP ...
      63             :    ARG=ff NLOW_DIM=2
      64             :    HIGH_DIM_FUNCTION={SMAP R_0=2 A=3 B=9}
      65             :    LOW_DIM_FUNCTION={SMAP R_0=2 A=2 B=2}
      66             :    WEIGHTS=lwe NCYCLES=3 FGRID_SIZE=50,50
      67             : ...
      68             : 
      69             : DUMPVECTOR ARG=smap.*,lwe FILE=smap
      70             : ```
      71             : 
      72             : Information on how we optimise the sketch-map stress function can be found by expanding the SKETCHMAP shortcut in the above
      73             : input and reading the documentation for [ARRANGE_POINTS](ARRANGE_POINTS.md).
      74             : 
      75             : ## Using landmarks
      76             : 
      77             : Optimising the sketch-map stress function is computationally expensive and so the usual practice with this method is to
      78             : pick a subset of [landmark](module_landmarks.md) points.  Projections for these points are found by optimising the sketch-map
      79             : stress function using [ARRANGE_POINTS](ARRANGE_POINTS.md).  Projections for non-landmark points are then found by using
      80             : [PROJECT_POINTS](PROJECT_POINTS.md).  The following example input illustrates how you can perform such a calculation with PLUMED:
      81             : 
      82             : ```plumed
      83             : d1: DISTANCE ATOMS=1,2
      84             : d2: DISTANCE ATOMS=3,4
      85             : d3: DISTANCE ATOMS=5,6
      86             : 
      87             : ff: COLLECT_FRAMES ARG=d1,d2,d3
      88             : ff_dataT: TRANSPOSE ARG=ff_data
      89             : ll: LANDMARK_SELECT_STRIDE ARG=ff NLANDMARKS=250
      90             : 
      91             : # Calculate the weights
      92             : voro: VORONOI ARG=ll_rectdissims
      93             : weights: CUSTOM ARG=ff.logweights FUNC=exp(x) PERIODIC=NO
      94             : weightsT: TRANSPOSE ARG=weights
      95             : lweT: MATRIX_PRODUCT ARG=weightsT,voro
      96             : lwe: TRANSPOSE ARG=lweT
      97             : 
      98             : smap: SKETCHMAP ...
      99             :   ARG=ll NLOW_DIM=2 PROJECT_ALL
     100             :   HIGH_DIM_FUNCTION={SMAP R_0=4 A=3 B=2}
     101             :   LOW_DIM_FUNCTION={SMAP R_0=4 A=1 B=2}
     102             : ...
     103             : 
     104             : DUMPVECTOR ARG=smap,lwe FILE=smap
     105             : DUMPVECTOR ARG=smap_osample,weights FILE=projections
     106             : ```
     107             : 
     108             : ## Using SMACOF
     109             : 
     110             : By default PLUMED uses the method described in the paper cited below to optimise the sketch-map stress function.
     111             : In other words, it uses a combination of conjugate gradients and a pointwise global optimisation algorithm.  Within
     112             : the code there is also an experimental implementation of optimisation using a variant on the [smacof](https://en.wikipedia.org/wiki/Stress_majorization)
     113             : algorithm.  If you would like to experiment with this option you can use the USE_SMACOF flag as illustrated below:
     114             : 
     115             : ```plumed
     116             : d1: DISTANCE ATOMS=1,2
     117             : d2: DISTANCE ATOMS=3,4
     118             : d3: DISTANCE ATOMS=5,6
     119             : 
     120             : ff: COLLECT_FRAMES STRIDE=1 ARG=d1,d2,d3
     121             : lwe: CUSTOM ARG=ff_logweights FUNC=exp(x) PERIODIC=NO
     122             : smap: SKETCHMAP ...
     123             :    ARG=ff NLOW_DIM=2 USE_SMACOF
     124             :    HIGH_DIM_FUNCTION={SMAP R_0=2 A=3 B=9}
     125             :    LOW_DIM_FUNCTION={SMAP R_0=2 A=2 B=2}
     126             :    WEIGHTS=lwe
     127             : ...
     128             : 
     129             : DUMPVECTOR ARG=smap.*,lwe FILE=smap
     130             : ```
     131             : 
     132             : 
     133             : */
     134             : //+ENDPLUMEDOC
     135             : 
     136             : namespace PLMD {
     137             : namespace dimred {
     138             : 
     139             : class SketchMap : public ActionShortcut {
     140             : public:
     141             :   static void registerKeywords( Keywords& keys );
     142             :   explicit SketchMap( const ActionOptions& ao );
     143             : };
     144             : 
     145             : PLUMED_REGISTER_ACTION(SketchMap,"SKETCHMAP")
     146             : 
     147           9 : void SketchMap::registerKeywords( Keywords& keys ) {
     148           9 :   ActionShortcut::registerKeywords( keys );
     149           9 :   keys.add("compulsory","NLOW_DIM","number of low-dimensional coordinates required");
     150           9 :   keys.add("optional","WEIGHTS","a vector containing the weights of the points");
     151           9 :   keys.add("compulsory","ARG","the matrix of high dimensional coordinates that you want to project in the low dimensional space");
     152           9 :   keys.add("compulsory","HIGH_DIM_FUNCTION","the parameters of the switching function in the high dimensional space");
     153           9 :   keys.add("compulsory","LOW_DIM_FUNCTION","the parameters of the switching function in the low dimensional space");
     154           9 :   keys.add("compulsory","CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the projection of the landmarks");
     155           9 :   keys.add("compulsory","MAXITER","1000","maximum number of optimization cycles for optimisation algorithms");
     156           9 :   keys.add("compulsory","NCYCLES","0","The number of cycles of pointwise global optimisation that are required");
     157           9 :   keys.add("compulsory","BUFFER","1.1","grid extent for search is (max projection - minimum projection) multiplied by this value");
     158           9 :   keys.add("compulsory","CGRID_SIZE","10","number of points to use in each grid direction");
     159           9 :   keys.add("compulsory","FGRID_SIZE","0","interpolate the grid onto this number of points -- only works in 2D");
     160           9 :   keys.addFlag("PROJECT_ALL",false,"if the input are landmark coordinates then project the out of sample configurations");
     161           9 :   keys.add("compulsory","OS_CGTOL","1E-6","The tolerance for the conjugate gradient minimization that finds the out of sample projections");
     162           9 :   keys.addFlag("USE_SMACOF",false,"find the projection in the low dimensional space using the SMACOF algorithm");
     163           9 :   keys.add("compulsory","SMACTOL","1E-4","the tolerance for the smacof algorithm");
     164           9 :   keys.add("compulsory","SMACREG","0.001","this is used to ensure that we don't divide by zero when updating weights for SMACOF algorithm");
     165          18 :   keys.setValueDescription("matrix","the sketch-map projection of the input points");
     166          18 :   keys.addOutputComponent("osample","PROJECT_ALL","matrix","the out-of-sample projections");
     167           9 :   keys.addDOI("10.1073/pnas.1108486108");
     168           9 :   keys.addDOI("10.1073/pnas.1201152109");
     169           9 :   keys.addDOI("10.1021/ct3010563");
     170           9 :   keys.addDOI("10.1021/ct500950z");
     171           9 :   keys.addDOI("10.1021/acs.jctc.5b00714");
     172           9 :   keys.needsAction("CLASSICAL_MDS");
     173           9 :   keys.needsAction("MORE_THAN");
     174           9 :   keys.needsAction("SUM");
     175           9 :   keys.needsAction("CUSTOM");
     176           9 :   keys.needsAction("OUTER_PRODUCT");
     177           9 :   keys.needsAction("ARRANGE_POINTS");
     178           9 :   keys.needsAction("PROJECT_POINTS");
     179           9 :   keys.needsAction("VSTACK");
     180           9 : }
     181             : 
     182           4 : SketchMap::SketchMap( const ActionOptions& ao):
     183             :   Action(ao),
     184           4 :   ActionShortcut(ao) {
     185             :   // Get the high dimensional data
     186             :   std::string argn;
     187           4 :   parse("ARG",argn);
     188           4 :   std::string dissimilarities = getShortcutLabel() + "_mds_mat";
     189           4 :   ActionShortcut* as = plumed.getActionSet().getShortcutActionWithLabel( argn );
     190           4 :   if( !as ) {
     191           0 :     error("found no action with name " + argn );
     192             :   }
     193           4 :   if( as->getName()!="COLLECT_FRAMES" ) {
     194           1 :     if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
     195           0 :       error("found no COLLECT_FRAMES or LANDMARK_SELECT action with label " + argn );
     196             :     } else {
     197           1 :       ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_sqrdissims");
     198           1 :       if( dissims ) {
     199           2 :         dissimilarities = argn + "_sqrdissims";
     200             :       }
     201             :     }
     202             :   }
     203             :   unsigned ndim;
     204           8 :   parse("NLOW_DIM",ndim);
     205             :   std::string str_ndim;
     206           4 :   Tools::convert( ndim, str_ndim );
     207             :   // Construct a projection using classical MDS
     208           8 :   readInputLine( getShortcutLabel() + "_mds: CLASSICAL_MDS ARG=" + argn + " NLOW_DIM=" + str_ndim );
     209             :   // Transform the dissimilarities using the switching function
     210             :   std::string hdfunc;
     211           4 :   parse("HIGH_DIM_FUNCTION",hdfunc);
     212           8 :   readInputLine( getShortcutLabel() + "_hdmat: MORE_THAN ARG=" + dissimilarities + " SQUARED SWITCH={" + hdfunc + "}");
     213             :   // Now for the weights - read the vector of weights first
     214             :   std::string wvec;
     215           8 :   parse("WEIGHTS",wvec);
     216           4 :   if( wvec.length()==0 ) {
     217           2 :     wvec = argn + "_weights";
     218             :   }
     219             :   // Now calculate the sum of these weights
     220           8 :   readInputLine( wvec + "_sum: SUM ARG=" + wvec + " PERIODIC=NO");
     221             :   // And normalise the vector of weights using this sum
     222           8 :   readInputLine( wvec + "_normed: CUSTOM ARG=" + wvec + "," + wvec + "_sum FUNC=x/y PERIODIC=NO");
     223             :   // And now create the matrix of weights
     224           8 :   readInputLine( wvec + "_mat: OUTER_PRODUCT ARG=" + wvec + "_normed," + wvec + "_normed");
     225             :   // Run the arrange points object
     226             :   std::string ldfunc, cgtol, maxiter;
     227           4 :   parse("LOW_DIM_FUNCTION",ldfunc);
     228           4 :   parse("CGTOL",cgtol);
     229           4 :   parse("MAXITER",maxiter);
     230             :   unsigned ncycles;
     231           8 :   parse("NCYCLES",ncycles);
     232           4 :   std::string num, argstr, lname=getShortcutLabel() + "_ap";
     233           4 :   if( ncycles>0 ) {
     234           2 :     lname = getShortcutLabel() + "_cg";
     235             :   }
     236           8 :   argstr = "ARG=" + getShortcutLabel() + "_mds-1";
     237           8 :   for(unsigned i=1; i<ndim; ++i) {
     238           4 :     Tools::convert( i+1, num );
     239           8 :     argstr += "," + getShortcutLabel() + "_mds-" + num;
     240             :   }
     241             :   bool usesmacof;
     242           4 :   parseFlag("USE_SMACOF",usesmacof);
     243           4 :   if( usesmacof ) {
     244             :     std::string smactol, smacreg;
     245           1 :     parse("SMACTOL",smactol);
     246           1 :     parse("SMACREG",smacreg);
     247           3 :     readInputLine( lname + ": ARRANGE_POINTS " + argstr  + " MINTYPE=smacof TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat" +
     248           3 :                    " MAXITER=" + maxiter + " SMACTOL=" + smactol + " SMACREG=" + smacreg + " TARGET2=" + getShortcutLabel() + "_mds_mat WEIGHTS2=" + wvec + "_mat");
     249             :   } else {
     250           6 :     readInputLine( lname + ": ARRANGE_POINTS " + argstr  + " MINTYPE=conjgrad TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
     251           3 :     if( ncycles>0 ) {
     252             :       std::string buf;
     253           2 :       parse("BUFFER",buf);
     254             :       std::vector<std::string> fgrid;
     255           2 :       parseVector("FGRID_SIZE",fgrid);
     256             :       std::string ncyc;
     257           1 :       Tools::convert(ncycles,ncyc);
     258           2 :       std::string pwise_args=" NCYCLES=" + ncyc + " BUFFER=" + buf;
     259           1 :       if( fgrid.size()>0 ) {
     260           1 :         if( fgrid.size()!=ndim ) {
     261           0 :           error("number of elements of fgrid is not correct");
     262             :         }
     263           1 :         pwise_args += " FGRID_SIZE=" + fgrid[0];
     264           2 :         for(unsigned i=1; i<fgrid.size(); ++i) {
     265           2 :           pwise_args += "," + fgrid[i];
     266             :         }
     267             :       }
     268           1 :       std::vector<std::string> cgrid(ndim);
     269           2 :       parseVector("CGRID_SIZE",cgrid);
     270           1 :       pwise_args += " CGRID_SIZE=" + cgrid[0];
     271           2 :       for(unsigned i=1; i<cgrid.size(); ++i) {
     272           2 :         pwise_args += "," + cgrid[i];
     273             :       }
     274           2 :       argstr="ARG=" + getShortcutLabel() + "_cg.coord-1";
     275           2 :       for(unsigned i=1; i<ndim; ++i) {
     276           1 :         Tools::convert( i+1, num );
     277           2 :         argstr += "," + getShortcutLabel() + "_cg.coord-" + num;
     278             :       }
     279           2 :       readInputLine( getShortcutLabel() + "_ap: ARRANGE_POINTS " + argstr  + pwise_args + " MINTYPE=pointwise TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
     280           2 :     }
     281             :   }
     282           8 :   argstr="ARG=" + getShortcutLabel() + "_ap.coord-1";
     283           8 :   for(unsigned i=1; i<ndim; ++i) {
     284           4 :     Tools::convert( i+1, num );
     285           8 :     argstr += "," + getShortcutLabel() + "_ap.coord-" + num;
     286             :   }
     287           8 :   readInputLine( getShortcutLabel() + ": VSTACK " + argstr );
     288             :   bool projall;
     289           4 :   parseFlag("PROJECT_ALL",projall);
     290           4 :   if( !projall ) {
     291             :     return ;
     292             :   }
     293           1 :   parse("OS_CGTOL",cgtol);
     294           1 :   argstr = getShortcutLabel() + "_ap.coord-1";
     295           2 :   for(unsigned i=1; i<ndim; ++i) {
     296           1 :     Tools::convert( i+1, num );
     297           2 :     argstr += "," + getShortcutLabel() + "_ap.coord-" + num;
     298             :   }
     299           1 :   if( as->getName().find("LANDMARK_SELECT")==std::string::npos ) {
     300           0 :     readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS ARG=" + argstr + " TARGET1=" + getShortcutLabel() + "_hdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
     301             :   } else {
     302           1 :     ActionWithValue* dissims = plumed.getActionSet().selectWithLabel<ActionWithValue*>( argn + "_rectdissims");
     303           1 :     if( !dissims ) {
     304           0 :       error("cannot PROJECT_ALL as " + as->getName() + " with label " + argn + " was invoked without the DISSIMILARITIES keyword");
     305             :     }
     306           2 :     readInputLine( getShortcutLabel() + "_lhdmat: MORE_THAN ARG=" + argn + "_rectdissims SQUARED SWITCH={" + hdfunc + "}");
     307           2 :     readInputLine( getShortcutLabel() + "_osample_pp: PROJECT_POINTS ARG=" + argstr + " TARGET1=" + getShortcutLabel() + "_lhdmat FUNC1={" + ldfunc + "} WEIGHTS1=" + wvec + "_normed CGTOL=" + cgtol );
     308             :   }
     309           2 :   argstr="ARG=" + getShortcutLabel() + "_osample_pp.coord-1";
     310           2 :   for(unsigned i=1; i<ndim; ++i) {
     311           1 :     Tools::convert( i+1, num );
     312           2 :     argstr += "," + getShortcutLabel() + "_osample_pp.coord-" + num;
     313             :   }
     314           2 :   readInputLine( getShortcutLabel() + "_osample: VSTACK " + argstr );
     315           0 : }
     316             : 
     317             : }
     318             : }
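
The constructor above works entirely by concatenating strings into PLUMED input lines and passing them to readInputLine(). As a rough illustration of what that expansion looks like, here is a minimal standalone sketch that returns the lines instead of submitting them. It assumes the simplest path through the source: ARG is a COLLECT_FRAMES label, USE_SMACOF and PROJECT_ALL are off, and NCYCLES is 0. The names `coordArgs` and `expandSketchMap` are hypothetical helpers introduced here for illustration and are not part of PLUMED; `std::to_string` stands in for `Tools::convert`.

```cpp
#include <string>
#include <vector>

// Hypothetical helper (not in PLUMED): builds "ARG=<stem>1,<stem>2,...,<stem>ndim",
// mirroring the loops over ndim that recur throughout the constructor.
std::string coordArgs(const std::string& stem, unsigned ndim) {
  std::string args = "ARG=" + stem + "1";
  for (unsigned i = 1; i < ndim; ++i) {
    args += "," + stem + std::to_string(i + 1);
  }
  return args;
}

// Hypothetical stand-in for the shortcut constructor. It collects the input
// lines rather than calling readInputLine(), and covers only the
// conjugate-gradient path with NCYCLES=0 and no PROJECT_ALL.
std::vector<std::string> expandSketchMap(const std::string& label,
                                         const std::string& arg,
                                         unsigned ndim,
                                         const std::string& hdfunc,
                                         const std::string& ldfunc,
                                         std::string wvec,
                                         const std::string& cgtol = "1E-6") {
  // With no WEIGHTS keyword the source falls back on <arg>_weights
  if (wvec.empty()) {
    wvec = arg + "_weights";
  }
  std::vector<std::string> lines;
  // Initial guess for the projection from classical MDS
  lines.push_back(label + "_mds: CLASSICAL_MDS ARG=" + arg + " NLOW_DIM=" + std::to_string(ndim));
  // Dissimilarities transformed by the high-dimensional switching function
  lines.push_back(label + "_hdmat: MORE_THAN ARG=" + label + "_mds_mat SQUARED SWITCH={" + hdfunc + "}");
  // Weights: sum, normalise, then form the matrix of pairwise products
  lines.push_back(wvec + "_sum: SUM ARG=" + wvec + " PERIODIC=NO");
  lines.push_back(wvec + "_normed: CUSTOM ARG=" + wvec + "," + wvec + "_sum FUNC=x/y PERIODIC=NO");
  lines.push_back(wvec + "_mat: OUTER_PRODUCT ARG=" + wvec + "_normed," + wvec + "_normed");
  // Stress minimisation by conjugate gradients, starting from the MDS coordinates
  lines.push_back(label + "_ap: ARRANGE_POINTS " + coordArgs(label + "_mds-", ndim) +
                  " MINTYPE=conjgrad TARGET1=" + label + "_hdmat FUNC1={" + ldfunc +
                  "} WEIGHTS1=" + wvec + "_mat CGTOL=" + cgtol);
  // Final projection assembled as a matrix with one column per low dimension
  lines.push_back(label + ": VSTACK " + coordArgs(label + "_ap.coord-", ndim));
  return lines;
}
```

Calling `expandSketchMap("smap", "ff", 2, "SMAP R_0=2 A=3 B=9", "SMAP R_0=2 A=2 B=2", "lwe")` yields seven lines, starting with the CLASSICAL_MDS action and ending with the VSTACK action labelled `smap`. Factoring the repeated coordinate loop into a single helper is a design choice of this sketch only; the original source writes the loop out each time.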

Generated by: LCOV version 1.16