Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2016-2021 The VES code team
3 : (see the PEOPLE-VES file at the root of this folder for a list of names)
4 :
5 : See http://www.ves-code.org for more information.
6 :
7 : This file is part of VES code module.
8 :
9 : The VES code module is free software: you can redistribute it and/or modify
10 : it under the terms of the GNU Lesser General Public License as published by
11 : the Free Software Foundation, either version 3 of the License, or
12 : (at your option) any later version.
13 :
14 : The VES code module is distributed in the hope that it will be useful,
15 : but WITHOUT ANY WARRANTY; without even the implied warranty of
16 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 : GNU Lesser General Public License for more details.
18 :
19 : You should have received a copy of the GNU Lesser General Public License
20 : along with the VES code module. If not, see <http://www.gnu.org/licenses/>.
21 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
22 :
23 :
24 : #include "BasisFunctions.h"
25 : #include "GridLinearInterpolation.h"
26 : #include "tools/Grid.h"
27 : #include "VesTools.h"
28 : #include "WaveletGrid.h"
29 : #include "core/ActionRegister.h"
30 : #include "tools/Exception.h"
31 : #include "core/PlumedMain.h"
32 :
33 :
34 : namespace PLMD {
35 : namespace ves {
36 :
37 :
38 : //+PLUMEDOC VES_BASISF BF_WAVELETS
39 : /*
40 : Daubechies Wavelets basis functions.
41 :
42 : Note: at the moment only bases with a single level of scaling functions are usable, as multiscale optimization is not yet implemented.
43 :
44 : This basis set uses the Daubechies Wavelets that are discussed in the first article cited below to construct a complete and orthogonal basis. See the second paper cited below for full details.
45 :
46 : The basis set is based on using a pair of functions, the scaling function (or father wavelet) $\phi$ and the wavelet function (or mother wavelet) $\psi$.
47 : They are defined via the two-scale relations for scale $j$ and shift $k$:
48 :
49 : $$
50 : \begin{aligned}
51 : \phi_k^j \left(x\right) = 2^{-j/2} \phi \left( 2^{-j} x - k\right)\\
52 : \psi_k^j \left(x\right) = 2^{-j/2} \psi \left( 2^{-j} x - k\right)
53 : \end{aligned}
54 : $$
55 :
56 : The exact properties are set by choosing filter coefficients, e.g. choosing $h_k$ for the father wavelet:
57 :
58 : $$
59 : \phi\left(x\right) = \sqrt{2} \sum_k h_k\, \phi \left( 2 x - k\right)
60 : $$
61 :
62 : The filter coefficients by Daubechies result in an orthonormal basis of all integer shifted functions:
63 :
64 : $$
65 : \int \phi(x+i) \phi(x+j) \mathop{}\!\mathrm{d}x = \delta_{ij} \quad \text{for} \quad i,j \in \mathbb{Z}
66 : $$
67 :
68 : Because no analytic formula for these wavelets exists, they are instead constructed iteratively on a grid.
69 : The method of construction is close to the "Vector cascade algorithm" described in [this book](https://epubs.siam.org/doi/book/10.1137/1.9780961408879?mi=0&af=R&pubType=book&sortBy=EpubDate&target=browse).
70 : The needed filter coefficients of the scaling function are hardcoded, and were previously generated via a python script.
71 : Currently the "maximum phase" type (Db) and the "least asymmetric" (Sym) type are implemented.
72 : We recommend to use Symlets.
73 :
74 : As an example, two adjacent basis functions of both Sym8 (ORDER=8, TYPE=SYMLETS) and Db8 (ORDER=8, TYPE=DAUBECHIES) are shown in the figure.
75 : The full basis consists of shifted wavelets in the full specified interval.
76 :
77 : 
78 :
79 : ## Specify the wavelet type
80 :
81 : The TYPE keyword sets the type of Wavelet, at the moment "DAUBECHIES" and "SYMLETS" are available.
82 : The specified ORDER of the basis corresponds to the number of vanishing moments of the wavelet, i.e. if TYPE was specified as "DAUBECHIES" an order of 8 results in Db8 wavelets.
83 :
84 :
85 : ## Specify the number of functions
86 :
87 : The resulting basis set consists of integer shifts of the wavelet with some scaling $j$,
88 :
89 : $$
90 : V(x) = \sum_i \alpha_i * \phi_i (x) = \sum_i \alpha_i * \phi(\frac{x+i}{j})
91 : $$
92 :
93 : with the variational parameters $\alpha$.
94 : Additionally a constant basis function is included.
95 :
96 : Two different ways of specifying the number of basis functions are implemented.
97 : You can either specify the scale or alternatively a fixed number of basis functions.
98 :
99 : Coming from the multiresolution aspect of wavelets, you can set the scale of the father wavelets, i.e. the largest scale used for approximation.
100 : This can be done with the FUNCTION_LENGTH keyword.
101 : It should be given in the same units as the used CV and specifies the length (of the domain interval) of the individual father wavelet functions.
102 :
103 : Alternatively a fixed number of basis functions for the bias expansion can be specified with the NUM_BF keyword, which will set the scale automatically to match the desired number of functions.
104 : Note that this also includes the constant function.
105 :
106 : If you do not specify anything, it is assumed that the range of the bias should match the scale of the wavelet functions.
107 : More precisely, the basis functions are scaled to match the specified size of the CV space (MINIMUM and MAXIMUM keywords).
108 : This has so far been a good initial choice.
109 :
110 : If the wavelets are scaled to match the CV range exactly there would be $4*\text{ORDER} -3$ basis functions whose domain is at least partially in this region.
111 : This number is adjusted if FUNCTION_LENGTH or NUM_BF is specified.
112 : Additionally, some of the shifted basis functions will not have significant contributions because of their function values being close to zero over the full range of the bias.
113 : These 'tail wavelets' can be omitted by using the TAILS_THRESHOLD keyword.
114 : This omits all shifted functions that have only function values smaller than a fraction of their maximum value inside the bias range.
115 : Using a value of e.g. 0.01 will already reduce the number of basis functions significantly.
116 : The default setting will not omit any tail wavelets (i.e. TAILS_THRESHOLD=0).
117 :
118 : The number of basis functions is then not easily determinable a priori but will be given in the logfile.
119 : Additionally the starting point (leftmost defined point) of the individual basis functions is printed.
120 :
121 :
122 : With the PERIODIC keyword the basis set can also be used to bias periodic CVs.
123 : Then the shift between the functions will be chosen such that the function at the left border and right border coincide.
124 : If the FUNCTION_LENGTH keyword is used together with PERIODIC, a smaller length might be chosen to satisfy this requirement.
125 :
126 :
127 : ## Grid
128 :
129 : The values of the wavelet function are generated on a grid.
130 : Using the cascade algorithm results in doubling the grid values for each iteration.
131 : This means that the grid size will always be a power of two multiplied by the number of coefficients ($ 2*\text{ORDER} -1$) for the specified wavelet.
132 : Using the MIN_GRID_SIZE keyword a lower bound for the number of grid points can be specified.
133 : By default at least 1,000 grid points are used.
134 : Function values in between grid points are calculated by linear interpolation.
135 :
136 : ## Optimization notes
137 :
138 : To avoid 'blind' optimization of the basis functions outside the currently sampled area, it is often beneficial to use the OPTIMIZATION_THRESHOLD keyword of the [VES_LINEAR_EXPANSION](VES_LINEAR_EXPANSION.md) (set it to a small value, e.g. 1e-6)
139 :
140 : ## Examples
141 :
142 :
143 : First a very simple example that relies on the default values.
144 : We want to bias some CV in the range of 0 to 4.
145 : The wavelets will therefore be scaled to match that range.
146 : Using Db8 wavelets this results in 30 basis functions (including the constant one), with their starting points given by $ -14*\frac{4}{15}, -13*\frac{4}{15}, \cdots , 0 , \cdots, 13*\frac{4}{15}, 14*\frac{4}{15}$.
147 :
148 : ```plumed
149 : BF_WAVELETS ...
150 : ORDER=8
151 : TYPE=DAUBECHIES
152 : MINIMUM=0.0
153 : MAXIMUM=4.0
154 : LABEL=bf
155 : ... BF_WAVELETS
156 : ```
157 :
158 :
159 : By omitting wavelets with only insignificant parts, we can reduce the number of basis functions. Using a threshold of 0.01 will in this example remove the 8 leftmost shifts, which we can check in the logfile.
160 :
161 : ```plumed
162 : BF_WAVELETS ...
163 : ORDER=8
164 : TYPE=DAUBECHIES
165 : MINIMUM=0.0
166 : MAXIMUM=4.0
167 : TAILS_THRESHOLD=0.01
168 : LABEL=bf
169 : ... BF_WAVELETS
170 : ```
171 :
172 : The length of the individual basis functions can also be adjusted to fit the specific problem.
173 : If for example the wavelets are instead scaled to length 3, there will be 35 basis functions, with leftmost points at $ -14*\frac{3}{15}, -13*\frac{3}{15}, \cdots, 0, \cdots, 18*\frac{3}{15}, 19*\frac{3}{15} $.
174 :
175 : ```plumed
176 : BF_WAVELETS ...
177 : ORDER=8
178 : TYPE=DAUBECHIES
179 : MINIMUM=0.0
180 : MAXIMUM=4.0
181 : FUNCTION_LENGTH=3
182 : LABEL=bf
183 : ... BF_WAVELETS
184 : ```
185 :
186 : Alternatively you can also specify the number of basis functions. Here we specify the usage of 40 Sym10 wavelet functions. We also used a custom minimum size for the grid and want it to be printed to a file with a specific numerical format.
187 :
188 : ```plumed
189 : BF_WAVELETS ...
190 : ORDER=10
191 : TYPE=SYMLETS
192 : MINIMUM=0.0
193 : MAXIMUM=4.0
194 : NUM_BF=40
195 : MIN_GRID_SIZE=500
196 : DUMP_WAVELET_GRID
197 : WAVELET_FILE_FMT=%11.4f
198 : LABEL=bf
199 : ... BF_WAVELETS
200 : ```
201 :
202 : */
203 : //+ENDPLUMEDOC
204 :
205 :
206 : class BF_Wavelets : public BasisFunctions {
207 : private:
208 : void setupLabels() override;
209 : /// ptr to Grid that holds the Wavelet values and its derivative
210 : std::unique_ptr<Grid> waveletGrid_;
211 : /// calculate threshold for omitted tail wavelets
212 : std::vector<double> getCutoffPoints(const double& threshold);
213 : /// scale factor of the individual BFs to match specified length
214 : double scale_;
215 : /// shift of the individual BFs
216 : std::vector<double> shifts_;
217 : public:
218 : static void registerKeywords( Keywords&);
219 : explicit BF_Wavelets(const ActionOptions&);
220 : void getAllValues(const double, double&, bool&, std::vector<double>&, std::vector<double>&) const override;
221 : };
222 :
223 :
224 : PLUMED_REGISTER_ACTION(BF_Wavelets,"BF_WAVELETS")
225 :
226 :
227 49 : void BF_Wavelets::registerKeywords(Keywords& keys) {
228 49 : BasisFunctions::registerKeywords(keys);
229 49 : keys.add("compulsory","TYPE","Specify the wavelet type. Currently available are DAUBECHIES Wavelets with minimum phase and the more symmetric SYMLETS");
230 49 : keys.add("optional","FUNCTION_LENGTH","The domain size of the individual basis functions. (length) This is used to alter the scaling of the basis functions. By default it is set to the total size of the interval. This also influences the number of actually used basis functions, as all shifted functions that are partially supported in the CV space are used.");
231 49 : keys.add("optional","NUM_BF","The number of basis functions that should be used. Includes the constant one and N-1 shifted wavelets within the specified range. Cannot be used together with FUNCTION_LENGTH.");
232 49 : keys.add("optional","TAILS_THRESHOLD","The threshold for cutting off tail wavelets as a fraction of the maximum value. All shifted wavelet functions that only have values smaller than the threshold in the bias range will be excluded from the basis set. Defaults to 0 (include all).");
233 49 : keys.addFlag("MOTHER_WAVELET", false, "If this flag is set mother wavelets will be used instead of the scaling function (father wavelet). Makes only sense for multiresolution, which is at the moment not usable.");
234 49 : keys.add("optional","MIN_GRID_SIZE","The minimal number of grid bins of the Wavelet function. The true number depends also on the used wavelet type and will probably be larger. Defaults to 1000.");
235 49 : keys.addFlag("DUMP_WAVELET_GRID", false, "If this flag is set the grid with the wavelet values will be written to a file. This file is called wavelet_grid.data.");
236 49 : keys.add("optional","WAVELET_FILE_FMT","The number format of the wavelet grid values and derivatives written to file. By default it is %15.8f.\n");
237 49 : keys.addFlag("PERIODIC", false, "Use periodic version of basis set.");
238 49 : keys.remove("NUMERICAL_INTEGRALS");
239 49 : keys.addDOI("10.1137/1.9781611970104");
240 49 : keys.addDOI("10.1021/acs.jctc.2c00197");
241 49 : }
242 :
243 :
244 47 : BF_Wavelets::BF_Wavelets(const ActionOptions& ao):
245 : PLUMED_VES_BASISFUNCTIONS_INIT(ao),
246 47 : waveletGrid_(nullptr),
247 47 : scale_(0.0) {
248 47 : log.printf(" Wavelet basis functions, see and cite ");
249 94 : log << plumed.cite("Pampel and Valsson, J. Chem. Theory Comput. 18, 4127-4141 (2022) - DOI:10.1021/acs.jctc.2c00197");
250 :
251 : // parse properties for waveletGrid and set it up
252 : bool use_mother_wavelet;
253 94 : parseFlag("MOTHER_WAVELET", use_mother_wavelet);
254 :
255 : std::string wavelet_type_str;
256 47 : parse("TYPE", wavelet_type_str);
257 94 : addKeywordToList("TYPE", wavelet_type_str);
258 :
259 47 : unsigned min_grid_size = 1000;
260 47 : parse("MIN_GRID_SIZE", min_grid_size);
261 47 : if(min_grid_size != 1000) {
262 72 : addKeywordToList("MIN_GRID_SIZE",min_grid_size);
263 : }
264 :
265 94 : waveletGrid_ = WaveletGrid::setupGrid(getOrder(), min_grid_size, use_mother_wavelet, WaveletGrid::stringToType(wavelet_type_str));
266 47 : bool dump_wavelet_grid=false;
267 47 : parseFlag("DUMP_WAVELET_GRID", dump_wavelet_grid);
268 47 : if (dump_wavelet_grid) {
269 36 : OFile wavelet_gridfile;
270 36 : std::string fmt = "%13.6f";
271 72 : parse("WAVELET_FILE_FMT",fmt);
272 : waveletGrid_->setOutputFmt(fmt); // property of grid not OFile determines fmt
273 36 : wavelet_gridfile.link(*this);
274 36 : wavelet_gridfile.enforceBackup();
275 72 : wavelet_gridfile.open(getLabel()+".wavelet_grid.data");
276 36 : waveletGrid_->writeToFile(wavelet_gridfile);
277 36 : }
278 :
279 47 : bool periodic = false;
280 47 : parseFlag("PERIODIC",periodic);
281 47 : if (periodic) {
282 8 : addKeywordToList("PERIODIC",periodic);
283 : }
284 :
285 : // now set up properties of basis set
286 47 : unsigned intrinsic_length = 2*getOrder() - 1; // length of unscaled wavelet
287 47 : double bias_length = intervalMax() - intervalMin(); // intervalRange() is not yet set
288 :
289 : // parse threshold for tail wavelets and get respective cutoff points
290 47 : double threshold = 0.0;
291 47 : std::vector<double> cutoffpoints (2);
292 47 : parse("TAILS_THRESHOLD",threshold);
293 47 : plumed_massert(threshold < 1, "TAILS_THRESHOLD should be significantly smaller than 1.");
294 47 : if(threshold == 0.0) {
295 45 : cutoffpoints = {0.0, static_cast<double>(intrinsic_length)};
296 : } else {
297 2 : plumed_massert(!periodic, "TAILS_THRESHOLD can't be used with the periodic wavelet variant");
298 2 : addKeywordToList("TAILS_THRESHOLD",threshold);
299 4 : cutoffpoints = getCutoffPoints(threshold);
300 : };
301 :
302 47 : double function_length = bias_length;
303 47 : parse("FUNCTION_LENGTH",function_length);
304 47 : if(function_length != bias_length) {
305 4 : if (periodic) { // shifted functions need to fit into interval exactly -> reduce size if not
306 2 : unsigned num_shifts = ceil(bias_length * intrinsic_length / function_length);
307 2 : function_length = bias_length * intrinsic_length / num_shifts;
308 : }
309 8 : addKeywordToList("FUNCTION_LENGTH",function_length);
310 : }
311 :
312 : // determine number of BFs and needed scaling
313 47 : unsigned num_BFs = 0;
314 47 : parse("NUM_BF",num_BFs);
315 47 : if(num_BFs == 0) { // get from function length
316 43 : scale_ = intrinsic_length / function_length;
317 43 : if (periodic) {
318 : // this is the same value as num_shifts above + constant
319 2 : num_BFs = static_cast<unsigned>(bias_length * scale_) + 1;
320 : } else {
321 41 : num_BFs = 1; // constant one
322 : // left shifts (w/o left cutoff) + right shifts - right cutoff - 1
323 41 : num_BFs += static_cast<unsigned>(ceil(cutoffpoints[1] + (bias_length)*scale_ - cutoffpoints[0]) - 1);
324 : }
325 : } else {
326 : plumed_massert(num_BFs > 0, "The number of basis functions has to be positive (NUM_BF > 0)");
327 : // check does not work if function length was given as intrinsic length, but can't check for keyword use directly
328 4 : plumed_massert(function_length==bias_length,"The keywords \"NUM_BF\" and \"FUNCTION_LENGTH\" cannot be used at the same time");
329 4 : addKeywordToList("NUM_BF",num_BFs);
330 :
331 4 : if (periodic) { // inverted num_BFs calculation from where FUNCTION_LENGTH is specified
332 2 : scale_ = (num_BFs - 1) / bias_length ;
333 : } else {
334 2 : double cutoff_length = cutoffpoints[1] - cutoffpoints [0];
335 2 : double intrinsic_bias_length = num_BFs - cutoff_length + 1; // length of bias in intrinsic scale of wavelets
336 2 : scale_ = intrinsic_bias_length / bias_length;
337 : }
338 : }
339 :
340 47 : setNumberOfBasisFunctions(num_BFs);
341 :
342 : // now set up the starting points of the basis functions
343 47 : shifts_.push_back(0.0); // constant BF – never used, just for clearer notation
344 1908 : for(unsigned int i = 1; i < getNumberOfBasisFunctions(); ++i) {
345 1861 : shifts_.push_back(-intervalMin()*scale_ + cutoffpoints[1] - i);
346 : }
347 :
348 : // set some properties
349 47 : setIntrinsicInterval(0.0,intrinsic_length);
350 47 : periodic ? setPeriodic() : setNonPeriodic();
351 : setIntervalBounded();
352 : setType(wavelet_type_str);
353 47 : setDescription("Wavelets as localized basis functions");
354 47 : setupBF();
355 47 : checkRead();
356 :
357 47 : log.printf(" Each basisfunction spans %f in CV space\n", intrinsic_length/scale_);
358 47 : }
359 :
360 :
361 62249 : void BF_Wavelets::getAllValues(const double arg, double& argT, bool& inside_range, std::vector<double>& values, std::vector<double>& derivs) const {
362 62249 : argT=checkIfArgumentInsideInterval(arg,inside_range);
363 : //
364 62249 : values[0]=1.0;
365 62249 : derivs[0]=0.0;
366 2315762 : for(unsigned int i = 1; i < getNumberOfBasisFunctions(); ++i) {
367 : // scale and shift argument to match current wavelet
368 2253513 : double x = shifts_[i] + argT*scale_;
369 2253513 : if (arePeriodic()) { // periodic interval [0,intervalRange*scale]
370 171766 : x = x - floor(x/(intervalRange()*scale_))*intervalRange()*scale_;
371 : }
372 :
373 2253513 : if (x < 0 || x >= intrinsicIntervalMax()) { // Wavelets are 0 outside the defined range
374 989659 : values[i] = 0.0;
375 989659 : derivs[i] = 0.0;
376 : } else {
377 1263854 : std::vector<double> temp_deriv (1);
378 1263854 : values[i] = GridLinearInterpolation::getGridValueAndDerivativesWithLinearInterpolation(waveletGrid_.get(), {x}, temp_deriv);
379 1263854 : derivs[i] = temp_deriv[0] * scale_; // scale derivative
380 : }
381 : }
382 62249 : if(!inside_range) {
383 5818 : for(auto& deriv : derivs) {
384 5636 : deriv=0.0;
385 : }
386 : }
387 62249 : }
388 :
389 :
390 : // returns left and right cutoff point of Wavelet
391 : // threshold is a percent value of maximum
392 2 : std::vector<double> BF_Wavelets::getCutoffPoints(const double& threshold) {
393 2 : double threshold_value = threshold * waveletGrid_->getMaxValue();
394 : std::vector<double> cutoffpoints;
395 :
396 475 : for (size_t i = 0; i < waveletGrid_->getSize(); ++i) {
397 475 : if (fabs(waveletGrid_->getValue(i)) >= threshold_value) {
398 2 : cutoffpoints.push_back(waveletGrid_->getPoint(i)[0]);
399 2 : break;
400 : }
401 : }
402 :
403 1073 : for (int i = waveletGrid_->getSize() - 1; i >= 0; --i) {
404 1073 : if (fabs(waveletGrid_->getValue(i)) >= threshold_value) {
405 2 : cutoffpoints.push_back(waveletGrid_->getPoint(i)[0]);
406 2 : break;
407 : }
408 : }
409 :
410 2 : return cutoffpoints;
411 : }
412 :
413 :
414 : // labels according to minimum position in CV space
415 47 : void BF_Wavelets::setupLabels() {
416 47 : setLabel(0,"const");
417 1908 : for(unsigned int i=1; i < getNumberOfBasisFunctions(); i++) {
418 1861 : double pos = -shifts_[i]/scale_;
419 1861 : if (arePeriodic()) {
420 88 : pos = pos - floor((pos-intervalMin())/intervalRange())*intervalRange();
421 : }
422 : std::string is;
423 1861 : Tools::convert(pos, is);
424 3722 : setLabel(i,"i="+is);
425 : }
426 47 : }
427 :
428 :
429 : }
430 : }
|