Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2011-2017 The plumed team
3 : (see the PEOPLE file at the root of the distribution for a list of names)
4 :
5 : See http://www.plumed.org for more information.
6 :
7 : This file is part of plumed, version 2.
8 :
9 : plumed is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : plumed is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with plumed. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 : #include "core/ActionShortcut.h"
23 : #include "core/PlumedMain.h"
24 : #include "core/ActionSet.h"
25 : #include "core/ActionRegister.h"
26 : #include "core/ActionWithValue.h"
27 : #include "tools/IFile.h"
28 :
29 : #include <cmath>
30 :
31 : namespace PLMD {
32 : namespace refdist {
33 :
34 : //+PLUMEDOC FUNCTION KERNEL
35 : /*
36 : Use a switching function to determine how many of the input variables are less than a certain cutoff.
37 :
38 : \par Examples
39 :
40 : */
41 : //+ENDPLUMEDOC
42 :
43 :
44 : class Kernel : public ActionShortcut {
45 : public:
46 : static std::string fixArgumentDot( const std::string& argin );
47 : explicit Kernel(const ActionOptions&);
48 : static void registerKeywords(Keywords& keys);
49 : };
50 :
51 :
52 : PLUMED_REGISTER_ACTION(Kernel,"KERNEL")
53 :
54 20 : void Kernel::registerKeywords(Keywords& keys) {
55 20 : ActionShortcut::registerKeywords( keys );
56 40 : keys.add("numbered","ARG","the arguments that should be used as input to this method");
57 40 : keys.add("compulsory","TYPE","gaussian","the type of kernel to use");
58 40 : keys.add("compulsory","CENTER","the position of the center of the kernel");
59 40 : keys.add("optional","SIGMA","square root of variance of the cluster");
60 40 : keys.add("compulsory","COVAR","the covariance of the kernel");
61 40 : keys.add("compulsory","WEIGHT","1.0","the weight to multiply this kernel function by");
62 40 : keys.add("optional","REFERENCE","the file from which to read the kernel parameters");
63 40 : keys.add("compulsory","NUMBER","1","if there are multiple sets of kernel parameters in the input file which set of kernel parameters would you like to read in here");
64 40 : keys.addFlag("NORMALIZED",false,"would you like the kernel function to be normalized");
65 20 : keys.setValueDescription("the value of the kernel evaluated at the argument values");
66 20 : keys.needsAction("CONSTANT");
67 20 : keys.needsAction("CUSTOM");
68 20 : keys.needsAction("NORMALIZED_EUCLIDEAN_DISTANCE");
69 20 : keys.needsAction("PRODUCT");
70 20 : keys.needsAction("INVERT_MATRIX");
71 20 : keys.needsAction("MAHALANOBIS_DISTANCE");
72 20 : keys.needsAction("DIAGONALIZE");
73 20 : keys.needsAction("CONCATENATE");
74 20 : keys.needsAction("DETERMINANT");
75 20 : keys.needsAction("BESSEL");
76 20 : }
77 :
78 32 : std::string Kernel::fixArgumentDot( const std::string& argin ) {
79 32 : std::string argout = argin;
80 32 : std::size_t dot=argin.find(".");
81 32 : if( dot!=std::string::npos ) {
82 0 : argout = argin.substr(0,dot) + "_" + argin.substr(dot+1);
83 : }
84 32 : return argout;
85 : }
86 :
87 9 : Kernel::Kernel(const ActionOptions&ao):
88 : Action(ao),
89 9 : ActionShortcut(ao) {
90 : // Read in the arguments
91 : std::vector<std::string> argnames;
92 18 : parseVector("ARG",argnames);
93 9 : if( argnames.size()==0 ) {
94 0 : error("no arguments were specified");
95 : }
96 : // Now sort out the parameters
97 : double weight;
98 : std::string fname;
99 18 : parse("REFERENCE",fname);
100 : bool usemahalanobis=false;
101 9 : if( fname.length()>0 ) {
102 9 : IFile ifile;
103 9 : ifile.open(fname);
104 9 : ifile.allowIgnoredFields();
105 : unsigned number;
106 9 : parse("NUMBER",number);
107 : bool readline=false;
108 : // Create actions to hold the position of the center
109 31 : for(unsigned line=0; line<number; ++line) {
110 90 : for(unsigned i=0; i<argnames.size(); ++i) {
111 : std::string val;
112 59 : ifile.scanField(argnames[i], val);
113 59 : if( line==number-1 ) {
114 32 : readInputLine( getShortcutLabel() + "_" + fixArgumentDot(argnames[i]) + "_ref: CONSTANT VALUES=" + val );
115 : }
116 : }
117 62 : if( ifile.FieldExist("sigma_" + argnames[0]) ) {
118 : std::string varstr;
119 0 : for(unsigned i=0; i<argnames.size(); ++i) {
120 : std::string val;
121 0 : ifile.scanField("sigma_" + argnames[i], val);
122 0 : if( i==0 ) {
123 : varstr = val;
124 : } else {
125 0 : varstr += "," + val;
126 : }
127 : }
128 0 : if( line==number-1 ) {
129 0 : readInputLine( getShortcutLabel() + "_var: CONSTANT VALUES=" + varstr );
130 : }
131 : } else {
132 : std::string varstr, nvals;
133 31 : Tools::convert( argnames.size(), nvals );
134 31 : usemahalanobis=(argnames.size()>1);
135 90 : for(unsigned i=0; i<argnames.size(); ++i) {
136 174 : for(unsigned j=0; j<argnames.size(); j++) {
137 : std::string val;
138 230 : ifile.scanField("sigma_" +argnames[i] + "_" + argnames[j], val );
139 115 : if(i==0 && j==0 ) {
140 : varstr = val;
141 : } else {
142 168 : varstr += "," + val;
143 : }
144 : }
145 : }
146 31 : if( line==number-1 ) {
147 9 : if( !usemahalanobis ) {
148 4 : readInputLine( getShortcutLabel() + "_var: CONSTANT VALUES=" + varstr );
149 : } else {
150 14 : readInputLine( getShortcutLabel() + "_cov: CONSTANT NCOLS=" + nvals + " NROWS=" + nvals + " VALUES=" + varstr );
151 : }
152 : }
153 : }
154 31 : if( line==number-1 ) {
155 : readline=true;
156 : break;
157 : }
158 22 : ifile.scanField();
159 : }
160 9 : if( !readline ) {
161 0 : error("could not read reference configuration");
162 : }
163 9 : ifile.scanField();
164 9 : ifile.close();
165 9 : } else {
166 : // Create actions to hold the position of the center
167 0 : std::vector<std::string> center(argnames.size());
168 0 : parseVector("CENTER",center);
169 0 : for(unsigned i=0; i<argnames.size(); ++i) {
170 0 : readInputLine( getShortcutLabel() + "_" + fixArgumentDot(argnames[i]) + "_ref: CONSTANT VALUES=" + center[i] );
171 : }
172 : std::vector<std::string> sig;
173 0 : parseVector("SIGMA",sig);
174 0 : if( sig.size()==0 ) {
175 : // Create actions to hold the covariance
176 : std::string cov;
177 0 : parse("COVAR",cov);
178 0 : usemahalanobis=(argnames.size()>1);
179 0 : if( !usemahalanobis ) {
180 0 : readInputLine( getShortcutLabel() + "_var: CONSTANT VALUES=" + cov );
181 : } else {
182 : std::string nvals;
183 0 : Tools::convert( argnames.size(), nvals );
184 0 : readInputLine( getShortcutLabel() + "_cov: CONSTANT NCOLS=" + nvals + " NROWS=" + nvals + " VALUES=" + cov );
185 : }
186 0 : } else if( sig.size()==argnames.size() ) {
187 : // And actions to hold the standard deviation
188 0 : std::string valstr = sig[0];
189 0 : for(unsigned i=1; i<sig.size(); ++i) {
190 0 : valstr += "," + sig[i];
191 : }
192 0 : readInputLine( getShortcutLabel() + "_sigma: CONSTANT VALUES=" + valstr );
193 0 : readInputLine( getShortcutLabel() + "_var: CUSTOM ARG=" + getShortcutLabel() + "_sigma FUNC=x*x PERIODIC=NO");
194 : } else {
195 0 : error("sigma has wrong length");
196 : }
197 0 : }
198 :
199 : // Create the reference point and arguments
200 : std::string refpoint, argstr;
201 25 : for(unsigned i=0; i<argnames.size(); ++i) {
202 16 : if( i==0 ) {
203 : argstr = argnames[0];
204 18 : refpoint = getShortcutLabel() + "_" + fixArgumentDot(argnames[i]) + "_ref";
205 : } else {
206 14 : argstr += "," + argnames[1];
207 14 : refpoint += "," + getShortcutLabel() + "_" + fixArgumentDot(argnames[i]) + "_ref";
208 : }
209 : }
210 :
211 : // Get the information on the kernel type
212 : std::string func_str, ktype;
213 18 : parse("TYPE",ktype);
214 16 : if( ktype=="gaussian" || ktype=="von-misses" ) {
215 : func_str = "exp(-x/2)";
216 0 : } else if( ktype=="triangular" ) {
217 : func_str = "step(1.-sqrt(x))*(1.-sqrt(x))";
218 : } else {
219 : func_str = ktype;
220 : }
221 9 : std::string vm_str="";
222 9 : if( ktype=="von-misses" ) {
223 : vm_str=" VON_MISSES";
224 : }
225 :
226 9 : unsigned nvals = argnames.size();
227 : bool norm;
228 9 : parseFlag("NORMALIZED",norm);
229 9 : if( !usemahalanobis ) {
230 : // Invert the variance
231 4 : readInputLine( getShortcutLabel() + "_icov: CUSTOM ARG=" + getShortcutLabel() + "_var FUNC=1/x PERIODIC=NO");
232 : // Compute the distance between the center of the basin and the current configuration
233 4 : readInputLine( getShortcutLabel() + "_dist_2: NORMALIZED_EUCLIDEAN_DISTANCE SQUARED" + vm_str +" ARG1=" + argstr + " ARG2=" + refpoint + " METRIC=" + getShortcutLabel() + "_icov");
234 : // And compute a determinent for the input covariance matrix if it is required
235 2 : if( norm ) {
236 2 : if( ktype=="von-misses" ) {
237 0 : readInputLine( getShortcutLabel() + "_vec: CUSTOM ARG=" + getShortcutLabel() + "_icov FUNC=x PERIODIC=NO" );
238 : } else {
239 4 : readInputLine( getShortcutLabel() + "_det: PRODUCT ARG=" + getShortcutLabel() + "_var");
240 : }
241 : }
242 : } else {
243 : // Invert the input covariance matrix
244 14 : readInputLine( getShortcutLabel() + "_icov: INVERT_MATRIX ARG=" + getShortcutLabel() + "_cov" );
245 : // Compute the distance between the center of the basin and the current configuration
246 14 : readInputLine( getShortcutLabel() + "_dist_2: MAHALANOBIS_DISTANCE SQUARED ARG1=" + argstr + " ARG2=" + refpoint + " METRIC=" + getShortcutLabel() + "_icov " + vm_str );
247 : // And compute a determinent for the input covariance matrix if it is required
248 7 : if( norm ) {
249 7 : if( ktype=="von-misses" ) {
250 14 : readInputLine( getShortcutLabel() + "_det: DIAGONALIZE ARG=" + getShortcutLabel() + "_cov VECTORS=all" );
251 7 : std::string num, argnames= getShortcutLabel() + "_det.vals-1";
252 14 : for(unsigned i=1; i<nvals; ++i) {
253 7 : Tools::convert( i+1, num );
254 14 : argnames += "," + getShortcutLabel() + "_det.vals-" + num;
255 : }
256 14 : readInputLine( getShortcutLabel() + "_comp: CONCATENATE ARG=" + argnames );
257 14 : readInputLine( getShortcutLabel() + "_vec: CUSTOM ARG=" + getShortcutLabel() + "_comp FUNC=1/x PERIODIC=NO");
258 : } else {
259 0 : readInputLine( getShortcutLabel() + "_det: DETERMINANT ARG=" + getShortcutLabel() + "_cov");
260 : }
261 : }
262 : }
263 :
264 : // Compute the Gaussian
265 : std::string wstr;
266 9 : parse("WEIGHT",wstr);
267 9 : if( norm ) {
268 9 : if( ktype=="gaussian" ) {
269 : std::string pstr;
270 2 : Tools::convert( sqrt(pow(2*pi,nvals)), pstr );
271 4 : readInputLine( getShortcutLabel() + "_vol: CUSTOM ARG=" + getShortcutLabel() + "_det FUNC=(sqrt(x)*" + pstr + ") PERIODIC=NO");
272 7 : } else if( ktype=="von-misses" ) {
273 : std::string wstr, min, max;
274 7 : ActionWithValue* av=plumed.getActionSet().selectWithLabel<ActionWithValue*>( getShortcutLabel() + "_dist_2_diff" );
275 7 : plumed_assert( av );
276 7 : if( !av->copyOutput(0)->isPeriodic() ) {
277 0 : error("VON_MISSES only works with periodic variables");
278 : }
279 7 : av->copyOutput(0)->getDomain(min,max);
280 14 : readInputLine( getShortcutLabel() + "_bes: BESSEL ORDER=0 ARG=" + getShortcutLabel() + "_vec");
281 14 : readInputLine( getShortcutLabel() + "_cc: CUSTOM ARG=" + getShortcutLabel() + "_bes FUNC=("+max+"-"+min+")*x PERIODIC=NO");
282 14 : readInputLine( getShortcutLabel() + "_vol: PRODUCT ARG=" + getShortcutLabel() + "_cc");
283 : } else {
284 0 : error("only gaussian and von-misses kernels are normalizable");
285 : }
286 : // And the (suitably normalized) kernel
287 18 : readInputLine( getShortcutLabel() + ": CUSTOM ARG=" + getShortcutLabel() + "_dist_2," + getShortcutLabel() + "_vol FUNC=" + wstr + "*exp(-x/2)/y PERIODIC=NO");
288 : } else {
289 0 : readInputLine( getShortcutLabel() + ": CUSTOM ARG1=" + getShortcutLabel() + "_dist_2 FUNC=" + wstr + "*" + func_str + " PERIODIC=NO");
290 : }
291 9 : checkRead();
292 :
293 9 : }
294 :
295 : }
296 : }
297 :
298 :
|