Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 : Copyright (c) 2016-2018 The plumed team 3 : (see the PEOPLE file at the root of the distribution for a list of names) 4 : 5 : See http://www.plumed.org for more information. 6 : 7 : This file is part of plumed, version 2. 8 : 9 : plumed is free software: you can redistribute it and/or modify 10 : it under the terms of the GNU Lesser General Public License as published by 11 : the Free Software Foundation, either version 3 of the License, or 12 : (at your option) any later version. 13 : 14 : plumed is distributed in the hope that it will be useful, 15 : but WITHOUT ANY WARRANTY; without even the implied warranty of 16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 17 : GNU Lesser General Public License for more details. 18 : 19 : You should have received a copy of the GNU Lesser General Public License 20 : along with plumed. If not, see <http://www.gnu.org/licenses/>. 21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */ 22 : #include "core/ActionRegister.h" 23 : #include "core/PlumedMain.h" 24 : #include "core/ActionSet.h" 25 : #include "core/ActionShortcut.h" 26 : #include "core/ActionWithValue.h" 27 : 28 : //+PLUMEDOC FUNCTION MAHALANOBIS_DISTANCE 29 : /* 30 : Calculate the mahalanobis distance between two points in CV space 31 : 32 : \par Examples 33 : 34 : */ 35 : //+ENDPLUMEDOC 36 : 37 : namespace PLMD { 38 : namespace refdist { 39 : 40 : class MahalanobisDistance : public ActionShortcut { 41 : public: 42 : static void registerKeywords( Keywords& keys ); 43 : explicit MahalanobisDistance(const ActionOptions&ao); 44 : }; 45 : 46 : PLUMED_REGISTER_ACTION(MahalanobisDistance,"MAHALANOBIS_DISTANCE") 47 : 48 21 : void MahalanobisDistance::registerKeywords( Keywords& keys ) { 49 21 : ActionShortcut::registerKeywords(keys); 50 42 : keys.add("compulsory","ARG1","The point that we are calculating the distance from"); 51 42 : keys.add("compulsory","ARG2","The point that we are calculating the distance to"); 52 42 : keys.add("compulsory","METRIC","The inverse covariance matrix that should be used when calculating the distance"); 53 42 : keys.addFlag("SQUARED",false,"The squared distance should be calculated"); 54 42 : keys.addFlag("VON_MISSES",false,"Compute the mahalanobis distance in a way that is more sympathetic to the periodic boundary conditions"); 55 21 : keys.setValueDescription("the Mahalanobis distances between the input vectors"); 56 21 : keys.needsAction("DISPLACEMENT"); 57 21 : keys.needsAction("CUSTOM"); 58 21 : keys.needsAction("OUTER_PRODUCT"); 59 21 : keys.needsAction("TRANSPOSE"); 60 21 : keys.needsAction("MATRIX_PRODUCT_DIAGONAL"); 61 21 : keys.needsAction("CONSTANT"); 62 21 : keys.needsAction("MATRIX_VECTOR_PRODUCT"); 63 21 : keys.needsAction("MATRIX_PRODUCT"); 64 21 : keys.needsAction("COMBINE"); 65 21 : } 66 : 67 12 : MahalanobisDistance::MahalanobisDistance( const ActionOptions& ao): 68 : Action(ao), 69 12 : ActionShortcut(ao) { 70 : std::string arg1, arg2, metstr; 71 12 : parse("ARG1",arg1); 72 12 : parse("ARG2",arg2); 73 12 : parse("METRIC",metstr); 74 : // Check on input metric 75 12 : ActionWithValue* mav=plumed.getActionSet().selectWithLabel<ActionWithValue*>( metstr ); 76 12 : if( !mav ) { 77 0 : error("could not find action named " + metstr + " to use for metric"); 78 : } 79 12 : if( mav->copyOutput(0)->getRank()!=2 ) { 80 0 : error("metric has incorrect rank"); 81 : } 82 : 83 24 : readInputLine( getShortcutLabel() + "_diff: DISPLACEMENT ARG1=" + arg1 + " ARG2=" + arg2 ); 84 24 : readInputLine( getShortcutLabel() + "_diffT: TRANSPOSE ARG=" + getShortcutLabel() + "_diff"); 85 : bool von_miss, squared; 86 12 : parseFlag("VON_MISSES",von_miss); 87 12 : parseFlag("SQUARED",squared); 88 12 : if( von_miss ) { 89 7 : unsigned nrows = mav->copyOutput(0)->getShape()[0]; 90 7 : if( mav->copyOutput(0)->getShape()[1]!=nrows ) { 91 0 : error("metric is not symmetric"); 92 : } 93 : // Create a matrix that can be used to compute the off diagonal elements 94 : std::string valstr, nrstr; 95 7 : Tools::convert( mav->copyOutput(0)->get(0), valstr ); 96 7 : Tools::convert( nrows, nrstr ); 97 7 : std::string diagmet = getShortcutLabel() + "_diagmet: CONSTANT VALUES=" + valstr; 98 14 : std::string offdiagmet = getShortcutLabel() + "_offdiagmet: CONSTANT NROWS=" + nrstr + " NCOLS=" + nrstr + " VALUES=0"; 99 21 : for(unsigned i=0; i<nrows; ++i) { 100 42 : for(unsigned j=0; j<nrows; ++j) { 101 28 : Tools::convert( mav->copyOutput(0)->get(i*nrows+j), valstr ); 102 28 : if( i==j && i>0 ) { 103 : offdiagmet += ",0"; 104 14 : diagmet += "," + valstr; 105 21 : } else if( i!=j ) { 106 28 : offdiagmet += "," + valstr; 107 : } 108 : } 109 : } 110 7 : readInputLine( diagmet ); 111 7 : readInputLine( offdiagmet ); 112 : // Compute distances scaled by periods 113 7 : ActionWithValue* av=plumed.getActionSet().selectWithLabel<ActionWithValue*>( getShortcutLabel() + "_diff" ); 114 7 : plumed_assert( av ); 115 7 : if( !av->copyOutput(0)->isPeriodic() ) { 116 0 : error("VON_MISSES only works with periodic variables"); 117 : } 118 : std::string min, max; 119 7 : av->copyOutput(0)->getDomain(min,max); 120 14 : readInputLine( getShortcutLabel() + "_scaled: CUSTOM ARG=" + getShortcutLabel() + "_diffT FUNC=2*pi*x/(" + max +"-" + min + ") PERIODIC=NO"); 121 : // We start calculating off-diagonal elements by computing the sines of the scaled differences (this is a column vector) 122 14 : readInputLine( getShortcutLabel() + "_sinediffT: CUSTOM ARG=" + getShortcutLabel() + "_scaled FUNC=sin(x) PERIODIC=NO"); 123 : // Transpose sines to get a row vector 124 14 : readInputLine( getShortcutLabel() + "_sinediff: TRANSPOSE ARG=" + getShortcutLabel() + "_sinediffT"); 125 : // Compute the off diagonal elements 126 14 : readInputLine( getShortcutLabel() + "_matvec: MATRIX_PRODUCT ARG=" + getShortcutLabel() + "_offdiagmet," + getShortcutLabel() +"_sinediffT"); 127 14 : readInputLine( getShortcutLabel() + "_offdiag: MATRIX_PRODUCT_DIAGONAL ARG=" + getShortcutLabel() + "_sinediff," + getShortcutLabel() +"_matvec"); 128 : // Sort out the metric for the diagonal elements 129 7 : std::string metstr2 = getShortcutLabel() + "_diagmet"; 130 : // If this is a matrix we need create a matrix to multiply by 131 7 : if( av->copyOutput(0)->getShape()[0]>1 ) { 132 : // Create some ones 133 7 : std::string ones=" VALUES=1"; 134 21 : for(unsigned i=1; i<av->copyOutput(0)->getShape()[0]; ++i ) { 135 : ones += ",1"; 136 : } 137 14 : readInputLine( getShortcutLabel() + "_ones: CONSTANT " + ones ); 138 : // Now do some multiplication to create a matrix that can be multiplied by our "inverse variance" vector 139 14 : readInputLine( getShortcutLabel() + "_" + metstr + ": OUTER_PRODUCT ARG=" + metstr2 + "," + getShortcutLabel() + "_ones"); 140 14 : metstr2 = getShortcutLabel() + "_" + metstr; 141 : } 142 : // Compute the diagonal elements 143 14 : readInputLine( getShortcutLabel() + "_prod: CUSTOM ARG=" + getShortcutLabel() + "_scaled," + metstr2 + " FUNC=2*(1-cos(x))*y PERIODIC=NO"); 144 : std::string ncstr; 145 7 : Tools::convert( nrows, ncstr ); 146 7 : Tools::convert( av->copyOutput(0)->getShape()[0], nrstr ); 147 7 : std::string ones=" VALUES=1"; 148 42 : for(unsigned i=1; i<av->copyOutput(0)->getNumberOfValues(); ++i) { 149 : ones += ",1"; 150 : } 151 14 : readInputLine( getShortcutLabel() + "_matones: CONSTANT NROWS=" + nrstr + " NCOLS=" + ncstr + ones ); 152 14 : readInputLine( getShortcutLabel() + "_diag: MATRIX_PRODUCT_DIAGONAL ARG=" + getShortcutLabel() + "_matones," + getShortcutLabel() + "_prod"); 153 : // Sum everything 154 7 : if( !squared ) { 155 0 : readInputLine( getShortcutLabel() + "_2: COMBINE ARG=" + getShortcutLabel() + "_offdiag," + getShortcutLabel() + "_diag PERIODIC=NO"); 156 : } else { 157 14 : readInputLine( getShortcutLabel() + ": COMBINE ARG=" + getShortcutLabel() + "_offdiag," + getShortcutLabel() + "_diag PERIODIC=NO"); 158 : } 159 : } else { 160 5 : ActionWithValue* av=plumed.getActionSet().selectWithLabel<ActionWithValue*>( getShortcutLabel() + "_diffT" ); 161 5 : plumed_assert( av && av->getNumberOfComponents()==1 ); 162 5 : if( (av->copyOutput(0))->getRank()==1 ) { 163 8 : readInputLine( getShortcutLabel() + "_matvec: MATRIX_VECTOR_PRODUCT ARG=" + metstr + "," + getShortcutLabel() +"_diffT"); 164 : } else { 165 2 : readInputLine( getShortcutLabel() + "_matvec: MATRIX_PRODUCT ARG=" + metstr + "," + getShortcutLabel() +"_diffT"); 166 : } 167 5 : std::string olab = getShortcutLabel(); 168 5 : if( !squared ) { 169 : olab += "_2"; 170 : } 171 10 : readInputLine( olab + ": MATRIX_PRODUCT_DIAGONAL ARG=" + getShortcutLabel() + "_diff," + getShortcutLabel() +"_matvec"); 172 : } 173 12 : if( !squared ) { 174 10 : readInputLine( getShortcutLabel() + ": CUSTOM ARG=" + getShortcutLabel() + "_2 FUNC=sqrt(x) PERIODIC=NO"); 175 : } 176 12 : } 177 : 178 : } 179 : }