Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2024 The METATOMIC-PLUMED team
3 : (see the PEOPLE-METATOMIC file at the root of this folder for a list of names)
4 :
5 : See https://docs.metatensor.org/metatomic/ for more information about the
6 : metatomic package that this module allows you to call from PLUMED.
7 :
8 : This file is part of METATOMIC-PLUMED module.
9 :
10 : The METATOMIC-PLUMED module is free software: you can redistribute it and/or modify
11 : it under the terms of the GNU Lesser General Public License as published by
12 : the Free Software Foundation, either version 3 of the License, or
13 : (at your option) any later version.
14 :
15 : The METATOMIC-PLUMED module is distributed in the hope that it will be useful,
16 : but WITHOUT ANY WARRANTY; without even the implied warranty of
17 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 : GNU Lesser General Public License for more details.
19 :
20 : You should have received a copy of the GNU Lesser General Public License
21 : along with the METATOMIC-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
22 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
23 :
24 : #include "core/ActionAtomistic.h"
25 : #include "core/ActionWithValue.h"
26 : #include "core/ActionRegister.h"
27 : #include "core/PlumedMain.h"
28 :
29 : //+PLUMEDOC METATOMICMOD_COLVAR METATOMIC
30 : /*
31 : Use arbitrary machine learning models as collective variables.
32 :
33 : \note This action requires the metatomic-torch library. Check the
34 : instructions in the \ref METATOMICMOD page to enable this module.
35 :
36 : This action enables the use of fully custom machine learning models — based on
37 : the [metatomic] models interface — as collective variables in PLUMED. Such
38 : machine learning models are typically written and customized using Python code,
39 : and then exported to run within PLUMED as [TorchScript], which is a subset of
40 : Python that can be executed by the C++ torch library.
41 :
42 : Metatomic offers a way to define such models and pass data from PLUMED (or any
43 : other simulation engine) to the model and back. For more information on how to
44 : define such model, have a look at the [corresponding tutorials][mta_tutorials],
45 : or at the code in `regtest/metatomic/`. Each of the Python scripts in this
46 : directory defines a custom machine learning CV that can be used with PLUMED.
47 :
48 : \par Examples
49 :
50 : The following input shows how you can call metatomic and evaluate the model that
51 : is described in the file `custom_cv.pt` from PLUMED.
52 :
53 : \plumedfile metatomic_cv: METATOMIC ... MODEL=custom_cv.pt
54 :
55 : SPECIES1=1-26
56 : SPECIES2=27-62
57 : SPECIES3=63-76
58 : SPECIES_TO_TYPES=6,1,8
59 : ...
60 : \endplumedfile
61 :
62 : The numbered `SPECIES` labels are used to indicate the list of atoms that belong
63 : to each atomic species in the system. The `SPECIES_TO_TYPES` keyword then
64 : provides information on the atom type for each species. The first number here is
65 : the atomic type of the atoms that have been specified using the `SPECIES1` flag,
66 : the second number is the atomic type of the atoms that have been specified
67 : using the `SPECIES2` flag and so on.
68 :
69 : `METATOMIC` action also accepts the following options:
70 :
71 : - `EXTENSIONS_DIRECTORY` should be the path to a directory containing
72 : TorchScript extensions (as shared libraries) that are required to load and
73 : execute the model. This matches the `collect_extensions` argument to
74 : `AtomisticModel.export` in Python.
75 : - `CHECK_CONSISTENCY` can be used to enable internal consistency checks;
76 : - `SELECTED_ATOMS` can be used to signal to the metatomic model that it should
77 : only run its calculation for the selected subset of atoms. The model still
78 : needs to know about all the atoms in the system (through the `SPECIES`
79 : keyword); but this can be used to reduce the calculation cost. Note that the
80 : indices of the selected atoms should start at 1 in the PLUMED input file, but
81 : they will be translated to start at 0 when given to the model (i.e. in
82 : Python/TorchScript, the `forward` method will receive a `selected_atoms` which
83 : starts at 0)
84 :
85 : Here is another example with all the possible keywords:
86 :
87 : \plumedfile soap: METATOMIC ... MODEL=soap.pt EXTENSIONS_DIRECTORY=extensions
88 : CHECK_CONSISTENCY
89 :
90 : SPECIES1=1-10
91 : SPECIES2=11-20
92 : SPECIES_TO_TYPES=8,13
93 :
94 : # only run the calculation for the Aluminium (type 13) atoms, but
95 : # include the Oxygen (type 8) as potential neighbors.
96 : SELECTED_ATOMS=11-20
97 : ...
98 : \endplumedfile
99 :
100 : \par Collective variables and metatomic models
101 :
102 : PLUMED can use the [`"features"` output][features_output] of metatomic models as
103 : collective variables.
104 :
105 : */ /*
106 :
107 : [TorchScript]: https://pytorch.org/docs/stable/jit.html
108 : [metatomic]: https://docs.metatensor.org/metatomic/
109 : [mta_tutorials]: https://docs.metatensor.org/metatomic/latest/examples/
110 : [features_output]: https://docs.metatensor.org/metatomic/latest/outputs/features.html
111 : */
112 : //+ENDPLUMEDOC
113 :
114 : /*INDENT-OFF*/
115 : #if !defined(__PLUMED_HAS_LIBMETATOMIC) || !defined(__PLUMED_HAS_LIBTORCH)
116 :
117 : namespace PLMD { namespace metatomic {
118 : class MetatomicPlumedAction: public ActionAtomistic, public ActionWithValue {
119 : public:
120 : static void registerKeywords(Keywords& keys);
121 : explicit MetatomicPlumedAction(const ActionOptions& options):
122 : Action(options),
123 : ActionAtomistic(options),
124 : ActionWithValue(options)
125 : {
126 : throw std::runtime_error(
127 : "Can not use metatomic action without the corresponding libraries. \n"
128 : "Make sure to configure with `--enable-libmetatomic --enable-libtorch` "
129 : "and that the corresponding libraries are found"
130 : );
131 : }
132 :
133 : void calculate() override {}
134 : void apply() override {}
135 : unsigned getNumberOfDerivatives() override {return 0;}
136 : };
137 :
138 : }} // namespace PLMD::metatomic
139 :
140 : #else
141 :
142 : #include <type_traits>
143 :
144 : #pragma GCC diagnostic push
145 : #pragma GCC diagnostic ignored "-Wpedantic"
146 : #pragma GCC diagnostic ignored "-Wunused-parameter"
147 : #pragma GCC diagnostic ignored "-Wfloat-equal"
148 : #pragma GCC diagnostic ignored "-Wfloat-conversion"
149 : #pragma GCC diagnostic ignored "-Wimplicit-float-conversion"
150 : #pragma GCC diagnostic ignored "-Wimplicit-int-conversion"
151 : #pragma GCC diagnostic ignored "-Wshorten-64-to-32"
152 : #pragma GCC diagnostic ignored "-Wsign-conversion"
153 : #pragma GCC diagnostic ignored "-Wold-style-cast"
154 :
155 : #include <torch/script.h>
156 : #include <torch/version.h>
157 : #include <torch/cuda.h>
158 : #if TORCH_VERSION_MAJOR >= 2
159 : #include <torch/mps.h>
160 : #endif
161 :
162 : #pragma GCC diagnostic pop
163 :
164 : #include <metatensor/torch.hpp>
165 : #include <metatomic/torch.hpp>
166 :
167 : #include "vesin.h"
168 :
169 :
170 : namespace PLMD {
171 : namespace metatomic {
172 :
173 : // We will cast Vector/Tensor to pointers to arrays and doubles, so let's make
174 : // sure this is legal to do
175 : static_assert(std::is_standard_layout<PLMD::Vector>::value);
176 : static_assert(sizeof(PLMD::Vector) == sizeof(std::array<double, 3>));
177 : static_assert(alignof(PLMD::Vector) == alignof(std::array<double, 3>));
178 :
179 : static_assert(std::is_standard_layout<PLMD::Tensor>::value);
180 : static_assert(sizeof(PLMD::Tensor) == sizeof(std::array<std::array<double, 3>, 3>));
181 : static_assert(alignof(PLMD::Tensor) == alignof(std::array<std::array<double, 3>, 3>));
182 :
183 : /// Small helper class to compute one neighbor list requested by the metatomc
184 : /// model using vesin
185 : class NeighborListCalculator {
186 : public:
187 : NeighborListCalculator(
188 : metatomic_torch::NeighborListOptions options,
189 : const std::string& engine_length_unit
190 : );
191 : ~NeighborListCalculator();
192 :
193 : NeighborListCalculator(const NeighborListCalculator& other) = delete;
194 : NeighborListCalculator& operator=(const NeighborListCalculator& other) = delete;
195 :
196 : NeighborListCalculator(NeighborListCalculator&& other) noexcept;
197 : NeighborListCalculator& operator=(NeighborListCalculator&& other) noexcept;
198 :
199 : // compute the neighbor list following metatomic format, using data from PLUMED
200 : metatensor_torch::TensorBlock compute(
201 : const std::vector<PLMD::Vector>& positions,
202 : const PLMD::Tensor& cell,
203 : std::array<bool, 3> periodic,
204 : torch::ScalarType dtype,
205 : torch::Device device
206 : );
207 :
208 : metatomic_torch::NeighborListOptions options;
209 : private:
210 : double engine_cutoff_;
211 : vesin::VesinNeighborList neighbors_;
212 : };
213 :
214 8 : NeighborListCalculator::NeighborListCalculator(
215 : metatomic_torch::NeighborListOptions options_,
216 : const std::string& engine_length_unit
217 8 : ):
218 8 : options(options_),
219 8 : engine_cutoff_(options_->engine_cutoff(engine_length_unit))
220 : {
221 : memset(&this->neighbors_, 0, sizeof(vesin::VesinNeighborList));
222 8 : }
223 :
224 0 : NeighborListCalculator::NeighborListCalculator(NeighborListCalculator&& other) noexcept {
225 0 : this->options = other.options;
226 0 : this->engine_cutoff_ = other.engine_cutoff_;
227 0 : this->neighbors_ = other.neighbors_;
228 :
229 0 : memset(&other.neighbors_, 0, sizeof(vesin::VesinNeighborList));
230 0 : }
231 :
232 0 : NeighborListCalculator& NeighborListCalculator::operator=(NeighborListCalculator&& other) noexcept {
233 0 : if (this != &other) {
234 0 : vesin::vesin_free(&this->neighbors_);
235 :
236 0 : this->options = other.options;
237 0 : this->engine_cutoff_ = other.engine_cutoff_;
238 0 : this->neighbors_ = other.neighbors_;
239 :
240 0 : memset(&other.neighbors_, 0, sizeof(vesin::VesinNeighborList));
241 : }
242 0 : return *this;
243 : }
244 :
245 8 : NeighborListCalculator::~NeighborListCalculator() {
246 8 : vesin::vesin_free(&this->neighbors_);
247 8 : }
248 :
249 17 : metatensor_torch::TensorBlock NeighborListCalculator::compute(
250 : const std::vector<PLMD::Vector>& positions,
251 : const PLMD::Tensor& cell,
252 : std::array<bool, 3> periodic,
253 : torch::ScalarType dtype,
254 : torch::Device device
255 : ) {
256 17 : auto labels_options = torch::TensorOptions().dtype(torch::kInt32).device(device);
257 : auto neighbor_component = torch::make_intrusive<metatensor_torch::LabelsHolder>(
258 : "xyz",
259 102 : torch::tensor({0, 1, 2}, labels_options).reshape({3, 1})
260 : );
261 : auto neighbor_properties = torch::make_intrusive<metatensor_torch::LabelsHolder>(
262 17 : "distance", torch::zeros({1, 1}, labels_options)
263 : );
264 :
265 : // use https://github.com/Luthaf/vesin to compute the requested neighbor
266 : // lists since we can not get these from PLUMED
267 : vesin::VesinOptions vesin_options;
268 17 : vesin_options.cutoff = this->engine_cutoff_;
269 17 : vesin_options.full = this->options->full_list();
270 17 : vesin_options.return_shifts = true;
271 17 : vesin_options.return_distances = false;
272 17 : vesin_options.return_vectors = true;
273 17 : vesin_options.algorithm = vesin::VesinAutoAlgorithm;
274 :
275 17 : const char* error_message = nullptr;
276 17 : int status = vesin_neighbors(
277 : reinterpret_cast<const double (*)[3]>(positions.data()),
278 : positions.size(),
279 17 : reinterpret_cast<const double (*)[3]>(&cell(0, 0)),
280 : periodic.data(),
281 : {vesin::VesinCPU, 0},
282 : vesin_options,
283 : &this->neighbors_,
284 : &error_message
285 : );
286 :
287 17 : if (status != EXIT_SUCCESS) {
288 0 : plumed_merror(
289 : "failed to compute neighbor list (cutoff=" + std::to_string(this->engine_cutoff_) +
290 : ", full=" + (this->options->full_list() ? "true" : "false") + "): " + error_message
291 : );
292 : }
293 :
294 : // transform from vesin to metatomic format
295 17 : auto n_pairs = static_cast<int64_t>(this->neighbors_.length);
296 :
297 : auto pair_vectors = torch::from_blob(
298 17 : this->neighbors_.vectors,
299 : {n_pairs, 3, 1},
300 34 : torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU)
301 17 : );
302 :
303 17 : auto pair_samples_values = torch::empty({n_pairs, 5}, labels_options.device(torch::kCPU));
304 17 : auto pair_samples_values_ptr = pair_samples_values.accessor<int32_t, 2>();
305 15407 : for (unsigned i=0; i<n_pairs; i++) {
306 15390 : pair_samples_values_ptr[i][0] = static_cast<int32_t>(this->neighbors_.pairs[i][0]);
307 15390 : pair_samples_values_ptr[i][1] = static_cast<int32_t>(this->neighbors_.pairs[i][1]);
308 15390 : pair_samples_values_ptr[i][2] = this->neighbors_.shifts[i][0];
309 15390 : pair_samples_values_ptr[i][3] = this->neighbors_.shifts[i][1];
310 15390 : pair_samples_values_ptr[i][4] = this->neighbors_.shifts[i][2];
311 : }
312 :
313 : auto neighbor_samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
314 136 : std::vector<std::string>{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"},
315 17 : pair_samples_values.to(device),
316 : // vesin should create unique pairs
317 17 : metatensor::assume_unique{}
318 : );
319 :
320 : auto neighbors = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
321 51 : pair_vectors.to(dtype).to(device),
322 : neighbor_samples,
323 68 : std::vector<metatensor_torch::Labels>{neighbor_component},
324 : neighbor_properties
325 : );
326 :
327 17 : return neighbors;
328 : }
329 :
330 : class MetatomicPlumedAction: public ActionAtomistic, public ActionWithValue {
331 : public:
332 : static void registerKeywords(Keywords& keys);
333 : explicit MetatomicPlumedAction(const ActionOptions&);
334 :
335 : void calculate() override;
336 : void apply() override;
337 : unsigned getNumberOfDerivatives() override;
338 :
339 : private:
340 : // fill this->system_ according to the current PLUMED data
341 : void createSystem();
342 :
343 :
344 : // execute the model for the given system
345 : metatensor_torch::TensorBlock executeModel(metatomic_torch::System system);
346 :
347 : metatensor_torch::Module model_;
348 :
349 : metatomic_torch::ModelCapabilities capabilities_;
350 : // name of the output we request
351 : std::string features_key;
352 :
353 : // neighbor lists requests made by the model and the corresponding data
354 : std::vector<NeighborListCalculator> neighbor_lists_;
355 :
356 : // dtype/device to use to execute the model
357 : torch::ScalarType dtype_;
358 : torch::Device device_;
359 :
360 : torch::Tensor atomic_types_;
361 : // store the strain to be able to compute the virial with autograd
362 : torch::Tensor strain_;
363 :
364 : metatomic_torch::System system_;
365 : metatomic_torch::ModelEvaluationOptions evaluations_options_;
366 : bool check_consistency_;
367 :
368 : metatensor_torch::TensorMap output_;
369 : // shape of the output of this model
370 : unsigned n_samples_;
371 : unsigned n_properties_;
372 : };
373 :
374 8 : MetatomicPlumedAction::MetatomicPlumedAction(const ActionOptions& options):
375 : Action(options),
376 : ActionAtomistic(options),
377 : ActionWithValue(options),
378 8 : model_(torch::jit::Module()),
379 16 : device_(torch::kCPU)
380 : {
381 16 : if (metatomic_torch::version().find("0.1.") != 0) {
382 0 : this->error(
383 0 : "this code requires version 0.1.x of metatomic-torch, got version " +
384 0 : metatomic_torch::version()
385 : );
386 : }
387 :
388 : // first, load the model
389 : std::string extensions_directory_str;
390 16 : this->parse("EXTENSIONS_DIRECTORY", extensions_directory_str);
391 :
392 : torch::optional<std::string> extensions_directory = torch::nullopt;
393 8 : if (!extensions_directory_str.empty()) {
394 3 : extensions_directory = std::move(extensions_directory_str);
395 : }
396 :
397 : std::string model_path;
398 16 : this->parse("MODEL", model_path);
399 :
400 : try {
401 16 : this->model_ = metatomic_torch::load_atomistic_model(model_path, extensions_directory);
402 0 : } catch (const std::exception& e) {
403 0 : this->error("failed to load model at '" + model_path + "': " + e.what());
404 0 : }
405 :
406 : // extract information from the model
407 16 : auto metadata = this->model_.run_method("metadata").toCustomClass<metatomic_torch::ModelMetadataHolder>();
408 16 : this->capabilities_ = this->model_.run_method("capabilities").toCustomClass<metatomic_torch::ModelCapabilitiesHolder>();
409 8 : auto nl_requests_ivalue = this->model_.run_method("requested_neighbor_lists");
410 16 : for (auto nl_request_ivalue: nl_requests_ivalue.toList()) {
411 8 : auto nl_request = nl_request_ivalue.get().toCustomClass<metatomic_torch::NeighborListOptionsHolder>();
412 8 : this->neighbor_lists_.emplace_back(nl_request, this->getUnits().getLengthString());
413 : }
414 :
415 16 : auto extra_inputs = this->model_.run_method("requested_inputs").toGenericDict();
416 : auto standard_inputs = std::vector<std::string>{};
417 : auto custom_inputs = std::vector<std::string>{};
418 8 : for (const auto& item: extra_inputs) {
419 0 : auto key = item.key().toStringRef();
420 0 : if (key.find("::") != std::string::npos) {
421 0 : custom_inputs.push_back(key);
422 : } else {
423 0 : standard_inputs.push_back(key);
424 : }
425 : }
426 :
427 8 : if (!standard_inputs.empty()) {
428 0 : this->error(
429 : "The model requested extra inputs that are not yet supported in PLUMED. "
430 0 : "Please open an issue to request support for the following inputs: " +
431 0 : torch::str(standard_inputs)
432 : );
433 : }
434 :
435 8 : if (!custom_inputs.empty()) {
436 0 : this->error(
437 0 : "The model requested custom inputs (" + torch::str(custom_inputs) + ") "
438 : "that can not be provided by PLUMED. Please change your model to use "
439 : "standard inputs only."
440 : );
441 : }
442 :
443 16 : log.printf("\n%s\n", metadata->print().c_str());
444 : // add the model references to PLUMED citation handling mechanism
445 8 : for (const auto& it: metadata->references) {
446 2 : for (const auto& ref: it.value()) {
447 2 : this->cite(ref);
448 1 : }
449 : }
450 :
451 : // parse the atomic types from the input file
452 : std::vector<int32_t> atomic_types;
453 : std::vector<int32_t> species_to_types;
454 16 : this->parseVector("SPECIES_TO_TYPES", species_to_types);
455 : bool has_custom_types = !species_to_types.empty();
456 :
457 : std::vector<AtomNumber> all_atoms;
458 16 : this->parseAtomList("SPECIES", all_atoms);
459 :
460 : size_t n_species = 0;
461 8 : if (all_atoms.empty()) {
462 : // first parse each of the 'SPECIES' entry
463 : std::vector<std::vector<AtomNumber>> atoms_per_species;
464 : int i = 0;
465 : while (true) {
466 27 : i += 1;
467 : auto atoms = std::vector<AtomNumber>();
468 54 : this->parseAtomList("SPECIES", i, atoms);
469 :
470 27 : if (atoms.empty()) {
471 : break;
472 : }
473 :
474 : int32_t type = i;
475 19 : if (has_custom_types) {
476 19 : if (species_to_types.size() < static_cast<size_t>(i)) {
477 0 : this->error(
478 : "SPECIES_TO_TYPES is too small, it should have one entry "
479 0 : "for each species (we have at least " + std::to_string(i) +
480 0 : " species and " + std::to_string(species_to_types.size()) +
481 : "entries in SPECIES_TO_TYPES)"
482 : );
483 : }
484 :
485 19 : type = species_to_types[static_cast<size_t>(i - 1)];
486 : }
487 :
488 19 : log.printf(" atoms with type %d are: ", type);
489 1296 : for(unsigned j=0; j<atoms.size(); j++) {
490 1277 : log.printf("%d ", atoms[j]);
491 : }
492 19 : log.printf("\n");
493 :
494 19 : n_species += 1;
495 19 : atoms_per_species.emplace_back(std::move(atoms));
496 : }
497 :
498 : size_t n_atoms = 0;
499 27 : for (const auto& atoms: atoms_per_species) {
500 19 : n_atoms += atoms.size();
501 : }
502 :
503 : // then fill the atomic_types as required
504 8 : atomic_types.resize(n_atoms, 0);
505 : i = 0;
506 27 : for (const auto& atoms: atoms_per_species) {
507 19 : i += 1;
508 :
509 : int32_t type = i;
510 19 : if (has_custom_types) {
511 19 : type = species_to_types[static_cast<size_t>(i - 1)];
512 : }
513 :
514 1296 : for (const auto& atom: atoms) {
515 1277 : atomic_types[atom.index()] = type;
516 : }
517 : }
518 8 : } else {
519 : n_species = 1;
520 :
521 0 : int32_t type = 1;
522 0 : if (has_custom_types) {
523 0 : type = species_to_types[0];
524 : }
525 0 : atomic_types.resize(all_atoms.size(), type);
526 : }
527 :
528 8 : if (has_custom_types && species_to_types.size() != n_species) {
529 0 : this->warning(
530 0 : "SPECIES_TO_TYPES contains more entries (" +
531 0 : std::to_string(species_to_types.size()) +
532 0 : ") than there where species (" + std::to_string(n_species) + ")"
533 : );
534 : }
535 :
536 : // request atoms in order
537 : all_atoms.clear();
538 1285 : for (size_t i=0; i<atomic_types.size(); i++) {
539 1277 : all_atoms.push_back(AtomNumber::index(i));
540 : }
541 8 : this->requestAtoms(all_atoms);
542 :
543 16 : this->atomic_types_ = torch::tensor(atomic_types);
544 :
545 8 : this->check_consistency_ = false;
546 8 : this->parseFlag("CHECK_CONSISTENCY", this->check_consistency_);
547 8 : if (this->check_consistency_) {
548 0 : log.printf(" checking for internal consistency of the model\n");
549 : }
550 :
551 : // create evaluation options for the model. These won't change during the
552 : // simulation, so we initialize them once here.
553 8 : evaluations_options_ = torch::make_intrusive<metatomic_torch::ModelEvaluationOptionsHolder>();
554 16 : evaluations_options_->set_length_unit(getUnits().getLengthString());
555 :
556 : auto outputs = this->capabilities_->outputs();
557 :
558 : std::string requested_variant;
559 : torch::optional<std::string> requested_variant_opt = torch::nullopt;
560 16 : this->parse("VARIANT", requested_variant);
561 8 : if (!requested_variant.empty()) {
562 1 : requested_variant_opt = requested_variant;
563 : }
564 :
565 25 : this->features_key = metatomic_torch::pick_output("features", outputs, requested_variant_opt);
566 :
567 : auto output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
568 : // this output has no quantity or unit to set
569 :
570 24 : output->set_sample_kind(this->capabilities_->outputs().at(this->features_key)->sample_kind());
571 : // we are using torch autograd system to compute gradients,
572 : // so we don't need any explicit gradients.
573 8 : output->explicit_gradients = {};
574 8 : evaluations_options_->outputs.insert(this->features_key, output);
575 :
576 : std::string requested_device;
577 : torch::optional<std::string> requested_device_opt = torch::nullopt;
578 16 : this->parse("DEVICE", requested_device);
579 8 : if (!requested_device.empty()) {
580 5 : requested_device_opt = requested_device;
581 : }
582 :
583 8 : this->device_ = torch::Device(
584 8 : metatomic_torch::pick_device(this->capabilities_->supported_devices, requested_device_opt),
585 : /*index=*/ 0
586 : );
587 :
588 8 : this->model_.to(this->device_);
589 8 : this->atomic_types_ = this->atomic_types_.to(this->device_);
590 :
591 8 : log.printf(
592 : " running model on %s device with %s data\n",
593 16 : this->device_.str().c_str(),
594 : this->capabilities_->dtype().c_str()
595 : );
596 :
597 8 : if (this->capabilities_->dtype() == "float64") {
598 3 : this->dtype_ = torch::kFloat64;
599 5 : } else if (this->capabilities_->dtype() == "float32") {
600 5 : this->dtype_ = torch::kFloat32;
601 : } else {
602 0 : this->error(
603 0 : "the model requested an unsupported dtype '" + this->capabilities_->dtype() + "'"
604 : );
605 : }
606 :
607 8 : auto tensor_options = torch::TensorOptions().dtype(this->dtype_).device(this->device_);
608 16 : this->strain_ = torch::eye(3, tensor_options.requires_grad(true));
609 :
610 : // determine how many properties there will be in the output by running the
611 : // model once on a dummy system
612 : auto dummy_system = torch::make_intrusive<metatomic_torch::SystemHolder>(
613 16 : /*types = */ torch::zeros({0}, tensor_options.dtype(torch::kInt32)),
614 16 : /*positions = */ torch::zeros({0, 3}, tensor_options),
615 16 : /*cell = */ torch::zeros({3, 3}, tensor_options),
616 8 : /*pbc = */ torch::zeros({3}, tensor_options.dtype(torch::kBool))
617 : );
618 :
619 8 : log.printf(" the following neighbor lists have been requested:\n");
620 8 : auto length_unit = this->getUnits().getLengthString();
621 8 : auto model_length_unit = this->capabilities_->length_unit();
622 16 : for (auto& nl: this->neighbor_lists_) {
623 11 : log.printf(" - %s list, %g %s cutoff (requested %g %s)\n",
624 : nl.options->full_list() ? "full" : "half",
625 : nl.options->engine_cutoff(length_unit),
626 : length_unit.c_str(),
627 : nl.options->cutoff(),
628 : model_length_unit.c_str()
629 : );
630 :
631 : auto neighbors = nl.compute(
632 : {PLMD::Vector(0, 0, 0)},
633 16 : PLMD::Tensor(0, 0, 0, 0, 0, 0, 0, 0, 0),
634 : {false, false, false},
635 : this->dtype_,
636 : this->device_
637 8 : );
638 32 : metatomic_torch::register_autograd_neighbors(dummy_system, neighbors, this->check_consistency_);
639 16 : dummy_system->add_neighbor_list(nl.options, neighbors);
640 : }
641 :
642 8 : this->n_properties_ = static_cast<unsigned>(
643 32 : this->executeModel(dummy_system)->properties()->count()
644 : );
645 :
646 : // parse and handle atom sub-selection. This is done AFTER determining the
647 : // output size, since the selection might not be valid for the dummy system
648 : std::vector<AtomNumber> selected_atoms;
649 16 : this->parseAtomList("SELECTED_ATOMS", selected_atoms);
650 8 : if (!selected_atoms.empty()) {
651 : auto selection_value = torch::zeros(
652 : {static_cast<int64_t>(selected_atoms.size()), 2},
653 2 : torch::TensorOptions().dtype(torch::kInt32).device(this->device_)
654 2 : );
655 :
656 9 : for (unsigned i=0; i<selected_atoms.size(); i++) {
657 : auto n_atoms = this->atomic_types_.size(0);
658 7 : if (selected_atoms[i].index() > n_atoms) {
659 0 : this->error(
660 : "Values in metatomic's SELECTED_ATOMS should be between 1 "
661 0 : "and the number of atoms (" + std::to_string(n_atoms) + "), "
662 0 : "got " + std::to_string(selected_atoms[i].serial()));
663 : }
664 21 : selection_value[i][1] = static_cast<int32_t>(selected_atoms[i].index());
665 : }
666 :
667 4 : evaluations_options_->set_selected_atoms(
668 2 : torch::make_intrusive<metatensor_torch::LabelsHolder>(
669 8 : std::vector<std::string>{"system", "atom"}, selection_value
670 : )
671 : );
672 : }
673 :
674 : // Now that we now both n_samples and n_properties, we can setup the
675 : // PLUMED-side storage for the computed CV
676 16 : if (output->sample_kind() == "atom") {
677 4 : if (selected_atoms.empty()) {
678 2 : this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
679 : } else {
680 2 : this->n_samples_ = static_cast<unsigned>(selected_atoms.size());
681 : }
682 : } else {
683 : assert(output->sample_kind() == "system");
684 4 : this->n_samples_ = 1;
685 : }
686 :
687 8 : if (n_samples_ == 1 && n_properties_ == 1) {
688 3 : log.printf(" the output of this model is a scalar\n");
689 :
690 6 : this->addValue();
691 5 : } else if (n_samples_ == 1) {
692 1 : log.printf(" the output of this model is 1x%d vector\n", n_properties_);
693 :
694 1 : this->addValue({this->n_properties_});
695 1 : this->getPntrToComponent(0)->buildDataStore();
696 4 : } else if (n_properties_ == 1) {
697 1 : log.printf(" the output of this model is %dx1 vector\n", n_samples_);
698 :
699 1 : this->addValue({this->n_samples_});
700 1 : this->getPntrToComponent(0)->buildDataStore();
701 : } else {
702 3 : log.printf(" the output of this model is a %dx%d matrix\n", n_samples_, n_properties_);
703 :
704 3 : this->addValue({this->n_samples_, this->n_properties_});
705 3 : this->getPntrToComponent(0)->buildDataStore();
706 3 : this->getPntrToComponent(0)->reshapeMatrixStore(n_properties_);
707 : }
708 :
709 8 : this->setNotPeriodic();
710 16 : }
711 :
712 14 : unsigned MetatomicPlumedAction::getNumberOfDerivatives() {
713 : // gradients w.r.t. positions (3 x N values) + gradients w.r.t. strain (9 values)
714 14 : return 3 * this->getNumberOfAtoms() + 9;
715 : }
716 :
717 :
718 9 : void MetatomicPlumedAction::createSystem() {
719 18 : if (this->getTotAtoms() != static_cast<unsigned>(this->atomic_types_.size(0))) {
720 0 : std::ostringstream oss;
721 0 : oss << "METATOMIC action needs to know about all atoms in the system. ";
722 0 : oss << "There are " << this->getTotAtoms() << " atoms overall, ";
723 0 : oss << "but we only have atomic types for " << this->atomic_types_.size(0) << " of them.";
724 0 : plumed_merror(oss.str());
725 0 : }
726 :
727 : // this->getTotAtoms()
728 :
729 9 : const auto& cell = this->getPbc().getBox();
730 :
731 9 : auto cpu_f64_tensor = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
732 9 : auto torch_cell = torch::zeros({3, 3}, cpu_f64_tensor);
733 :
734 27 : torch_cell[0][0] = cell(0, 0);
735 27 : torch_cell[0][1] = cell(0, 1);
736 27 : torch_cell[0][2] = cell(0, 2);
737 :
738 27 : torch_cell[1][0] = cell(1, 0);
739 27 : torch_cell[1][1] = cell(1, 1);
740 27 : torch_cell[1][2] = cell(1, 2);
741 :
742 27 : torch_cell[2][0] = cell(2, 0);
743 27 : torch_cell[2][1] = cell(2, 1);
744 27 : torch_cell[2][2] = cell(2, 2);
745 :
746 : using torch::indexing::Slice;
747 :
748 63 : auto norm_a = torch_cell.index({0, Slice()}).norm().abs().item<double>();
749 63 : auto norm_b = torch_cell.index({1, Slice()}).norm().abs().item<double>();
750 63 : auto norm_c = torch_cell.index({2, Slice()}).norm().abs().item<double>();
751 :
752 9 : auto periodic = std::array<bool, 3>{true, true, true};
753 :
754 : // make sure the cell and pbc argument agree with each other
755 9 : if (norm_a < 1e-9) {
756 0 : if (norm_a > 1e-30) {
757 0 : this->warning(
758 0 : "the cell vector A has a very small norm (" + std::to_string(norm_a) + "), "
759 : "this direction will be treated as non periodic. If this is intentional, ensure "
760 : "the vector is exactly zero to silence this warning"
761 : );
762 : }
763 0 : torch_cell.index({0, Slice()}).fill_(0);
764 0 : periodic[0] = false;
765 : }
766 :
767 9 : if (norm_b < 1e-9) {
768 0 : if (norm_b > 1e-30) {
769 0 : this->warning(
770 0 : "the cell vector B has a very small norm (" + std::to_string(norm_b) + "), "
771 : "this direction will be treated as non periodic. If this is intentional, ensure "
772 : "the vector is exactly zero to silence this warning"
773 : );
774 : }
775 0 : torch_cell.index({1, Slice()}).fill_(0);
776 0 : periodic[1] = false;
777 : }
778 :
779 9 : if (norm_c < 1e-9) {
780 0 : if (norm_c > 1e-30) {
781 0 : this->warning(
782 0 : "the cell vector C has a very small norm (" + std::to_string(norm_c) + "), "
783 : "this direction will be treated as non periodic. If this is intentional, ensure "
784 : "the vector is exactly zero to silence this warning"
785 : );
786 : }
787 0 : torch_cell.index({2, Slice()}).fill_(0);
788 0 : periodic[2] = false;
789 : }
790 :
791 9 : auto torch_pbc = torch::zeros({3}, torch::TensorOptions().dtype(torch::kBool).device(torch::kCPU));
792 36 : for (unsigned i=0; i<3; i++) {
793 54 : torch_pbc[i] = periodic[i];
794 : }
795 :
796 : const auto& positions = this->getPositions();
797 :
798 : auto torch_positions = torch::from_blob(
799 : const_cast<PLMD::Vector*>(positions.data()),
800 : {static_cast<int64_t>(positions.size()), 3},
801 : cpu_f64_tensor
802 9 : );
803 :
804 18 : torch_positions = torch_positions.to(this->dtype_).to(this->device_);
805 27 : torch_cell = torch_cell.to(this->dtype_).to(this->device_);
806 9 : torch_pbc = torch_pbc.to(this->device_);
807 :
808 : // setup torch's automatic gradient tracking
809 9 : if (!this->doNotCalculateDerivatives()) {
810 : torch_positions.requires_grad_(true);
811 :
812 : // pretend to scale positions/cell by the strain so that it enters the
813 : // computational graph.
814 8 : torch_positions = torch_positions.matmul(this->strain_);
815 8 : torch_positions.retain_grad();
816 :
817 8 : torch_cell = torch_cell.matmul(this->strain_);
818 : }
819 :
820 9 : this->system_ = torch::make_intrusive<metatomic_torch::SystemHolder>(
821 9 : this->atomic_types_,
822 : torch_positions,
823 : torch_cell,
824 : torch_pbc
825 : );
826 :
827 : // compute the neighbors list requested by the model, and register them with
828 : // the system
829 18 : for (auto& nl: this->neighbor_lists_) {
830 9 : auto neighbors = nl.compute(positions, cell, periodic, this->dtype_, this->device_);
831 36 : metatomic_torch::register_autograd_neighbors(this->system_, neighbors, this->check_consistency_);
832 18 : this->system_->add_neighbor_list(nl.options, neighbors);
833 : }
834 9 : }
835 :
836 :
// Evaluate the model on `system` and return the single block of its output.
//
// The model's `forward` is called with a one-element list of systems together
// with `evaluations_options_` and `check_consistency_`. The entry stored under
// `features_key` in the returned dictionary is converted to a TensorMap and
// kept in `this->output_`, which `apply()` reads again later.
//
// Any `std::exception` raised by the call or by the conversions is re-reported
// through `plumed_merror`, with the original `what()` appended to the message.
// The output must hold exactly one block, and that block must not have any
// components; both conditions are enforced with `plumed_massert`.
//
// Coverage note: the `catch` handler (source lines 848-850) has a hit count of
// 0 in this listing, so the failure path was not exercised by this run.
837 17 : metatensor_torch::TensorBlock MetatomicPlumedAction::executeModel(metatomic_torch::System system) {
838 : try {
839 51 : auto ivalue_output = this->model_.forward({
840 68 : std::vector<metatomic_torch::System>{system},
841 : evaluations_options_,
842 17 : this->check_consistency_,
843 85 : });
844 :
// the model returns a dictionary of outputs: select the one registered under
// `features_key` and store it so that it outlives this call
845 17 : auto dict_output = ivalue_output.toGenericDict();
846 34 : auto cv = dict_output.at(this->features_key);
847 34 : this->output_ = cv.toCustomClass<metatensor_torch::TensorMapHolder>();
848 0 : } catch (const std::exception& e) {
849 0 : plumed_merror("failed to evaluate the model: " + std::string(e.what()));
850 0 : }
851 :
// only a single block without components can be mapped to a PLUMED value
852 34 : plumed_massert(this->output_->keys()->count() == 1, "output should have a single block");
853 34 : auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
854 17 : plumed_massert(block->components().empty(), "components are not yet supported in the output");
855 :
856 17 : return block;
857 : }
858 :
859 :
// Compute the collective variable for the current step.
//
// Steps performed here:
//   1. rebuild `system_` with `createSystem()` and run the model on it;
//   2. bring the block values to the CPU as float64 and check that their two
//      dimensions match `n_samples_` and `n_properties_` (a mismatch is a fatal
//      `plumed_merror`);
//   3. copy the values into component 0 of this action, using one of three
//      layouts: a scalar (1 sample, 1 property), a vector over properties
//      (1 sample), or per-sample data stored at flat index
//      `output_i * n_properties_ + j` (just `output_i` when there is a single
//      property).
//
// Coverage note: both size-mismatch branches (source lines 867 and 873) have a
// hit count of 0 in this listing.
860 9 : void MetatomicPlumedAction::calculate() {
861 9 : this->createSystem();
862 :
863 18 : auto block = this->executeModel(this->system_);
864 45 : auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
865 :
// validate the shape of the model output against what this action expects
866 9 : if (static_cast<unsigned>(torch_values.size(0)) != this->n_samples_) {
867 0 : plumed_merror(
868 : "expected the model to return a TensorBlock with " +
869 : std::to_string(this->n_samples_) + " samples, got " +
870 : std::to_string(torch_values.size(0)) + " instead"
871 : );
872 9 : } else if (static_cast<unsigned>(torch_values.size(1)) != this->n_properties_) {
873 0 : plumed_merror(
874 : "expected the model to return a TensorBlock with " +
875 : std::to_string(this->n_properties_) + " properties, got " +
876 : std::to_string(torch_values.size(1)) + " instead"
877 : );
878 : }
879 :
880 9 : Value* value = this->getPntrToComponent(0);
881 : // reshape the plumed `Value` to hold the data returned by the model
882 9 : if (n_samples_ == 1) {
883 5 : if (n_properties_ == 1) {
// single sample and single property: the value is a plain scalar
884 4 : value->set(torch_values.item<double>());
885 : } else {
886 : // we have multiple CV describing a single thing (atom or full system)
887 3 : for (unsigned i=0; i<n_properties_; i++) {
888 6 : value->set(i, torch_values[0][i].item<double>());
889 : }
890 : }
891 : } else {
// several samples: the code below relies on the samples having exactly the
// two dimensions "system" and "atom", in this order
892 : auto samples = block->samples();
893 16 : plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
894 :
895 8 : auto samples_values = samples->values().to(torch::kCPU);
896 : auto selected_atoms = this->evaluations_options_->get_selected_atoms();
897 :
898 : // handle the possibility that samples are returned in
899 : // a non-sorted order.
// Maps row `i` of the model output to the slot of the PLUMED value it belongs
// to: the position of that sample inside the selection when one was given,
// otherwise the entry in column 1 (the "atom" dimension) of the sample.
900 92 : auto get_output_location = [&](unsigned i) {
901 92 : if (selected_atoms.has_value()) {
902 : // If the users picked some selected atoms, then we store the
903 : // output in the same order as the selection was given
904 28 : auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
905 7 : auto position = selected_atoms.value()->position(sample);
906 7 : plumed_assert(position.has_value());
907 7 : return static_cast<unsigned>(position.value());
908 : } else {
// NOTE(review): the sample entry is read as int32; presumably this is the
// dtype of the samples values -- confirm against the metatensor Labels API
909 340 : return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
910 : }
911 4 : };
912 :
913 4 : if (n_properties_ == 1) {
914 : // we have a single CV describing multiple things (i.e. atoms)
915 5 : for (unsigned i=0; i<n_samples_; i++) {
916 4 : auto output_i = get_output_location(i);
917 16 : value->set(output_i, torch_values[i][0].item<double>());
918 : }
919 : } else {
920 : // the CV is a matrix
// element (output_i, j) is written at flat index output_i * n_properties_ + j
921 91 : for (unsigned i=0; i<n_samples_; i++) {
922 88 : auto output_i = get_output_location(i);
923 343 : for (unsigned j=0; j<n_properties_; j++) {
924 765 : value->set(output_i * n_properties_ + j, torch_values[i][j].item<double>());
925 : }
926 : }
927 : }
928 : }
929 9 : }
930 :
931 :
// Propagate the forces stored on the output value back to atoms and cell.
//
// Nothing is done when no force was added to component 0. Otherwise the forces
// are arranged into a tensor with the same shape as the model output (using
// the same sample -> slot mapping as `calculate()`), and this tensor is used as
// the upstream gradient of a backward pass through block 0 of `output_`. The
// gradients accumulated on the system positions and on `strain_` are then
// flattened into a single vector handed to `setForcesOnAtoms`: first
// `3 * system_->size()` entries taken from the positions gradient, followed by
// the nine entries of the negated strain gradient ([0][0], [0][1], ... [2][2]).
//
// If the model output does not require gradients, a warning is emitted and the
// function returns without applying anything.
//
// Coverage note: the "no gradient" warning branch (source line 942) has a hit
// count of 0 in this listing.
932 9 : void MetatomicPlumedAction::apply() {
933 9 : const auto* value = this->getPntrToComponent(0);
934 9 : if (!value->forcesWereAdded()) {
935 1 : return;
936 : }
937 :
// re-use the output stored in `output_` rather than running the model again
938 16 : auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
939 40 : auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
940 :
// NOTE(review): the message below reads "does not requires"; it is a runtime
// string and is intentionally left untouched in this listing
941 8 : if (!torch_values.requires_grad()) {
942 0 : this->warning(
943 : "the output of the model does not requires gradients, this might "
944 : "indicate a problem"
945 : );
946 : return;
947 : }
948 :
// build the upstream gradient: one force per entry of the model output
949 8 : auto output_grad = torch::zeros_like(torch_values);
950 8 : if (n_samples_ == 1) {
951 5 : if (n_properties_ == 1) {
952 8 : output_grad[0][0] = value->getForce();
953 : } else {
954 3 : for (unsigned i=0; i<n_properties_; i++) {
955 6 : output_grad[0][i] = value->getForce(i);
956 : }
957 : }
958 : } else {
959 : auto samples = block->samples();
960 12 : plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
961 :
962 6 : auto samples_values = samples->values().to(torch::kCPU);
963 : auto selected_atoms = this->evaluations_options_->get_selected_atoms();
964 :
965 : // see above for an explanation of why we use this function
// Maps row `i` of the model output to the slot of the PLUMED value that holds
// the matching force (position in the selection, or column 1 of the sample).
966 89 : auto get_output_location = [&](unsigned i) {
967 89 : if (selected_atoms.has_value()) {
968 16 : auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
969 4 : auto position = selected_atoms.value()->position(sample);
970 4 : plumed_assert(position.has_value());
971 4 : return static_cast<unsigned>(position.value());
972 : } else {
973 340 : return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
974 : }
975 3 : };
976 :
977 3 : if (n_properties_ == 1) {
978 5 : for (unsigned i=0; i<n_samples_; i++) {
979 4 : auto output_i = get_output_location(i);
980 12 : output_grad[i][0] = value->getForce(output_i);
981 : }
982 : } else {
983 87 : for (unsigned i=0; i<n_samples_; i++) {
984 85 : auto output_i = get_output_location(i);
985 331 : for (unsigned j=0; j<n_properties_; j++) {
986 738 : output_grad[i][j] = value->getForce(output_i * n_properties_ + j);
987 : }
988 : }
989 : }
990 : }
991 :
// reset the stored gradients to undefined tensors before the backward pass
// (presumably so that values from a previous call are not accumulated into
// the new ones -- confirm against the torch autograd documentation)
992 8 : this->system_->positions().mutable_grad() = torch::Tensor();
993 : this->strain_.mutable_grad() = torch::Tensor();
994 :
995 8 : torch_values.backward(output_grad);
996 8 : auto positions_grad = this->system_->positions().grad();
997 8 : auto strain_grad = this->strain_.grad();
998 :
999 32 : positions_grad = positions_grad.to(torch::kCPU).to(torch::kFloat64);
1000 32 : strain_grad = strain_grad.to(torch::kCPU).to(torch::kFloat64);
1001 :
// the gradients are read below through raw `double` pointers and 2-D indexing,
// hence the checks on rank and contiguity
1002 8 : plumed_assert(positions_grad.sizes().size() == 2);
1003 8 : plumed_assert(positions_grad.is_contiguous());
1004 :
1005 8 : plumed_assert(strain_grad.sizes().size() == 2);
1006 8 : plumed_assert(strain_grad.is_contiguous());
1007 :
// copy the first 3 * system_->size() doubles of the positions gradient
1008 : auto derivatives = std::vector<double>(
1009 : positions_grad.data_ptr<double>(),
1010 8 : positions_grad.data_ptr<double>() + 3 * this->system_->size()
1011 8 : );
1012 :
1013 : // add virials to the derivatives
// the nine strain-gradient entries are appended with their sign flipped
1014 24 : derivatives.push_back(-strain_grad[0][0].item<double>());
1015 24 : derivatives.push_back(-strain_grad[0][1].item<double>());
1016 24 : derivatives.push_back(-strain_grad[0][2].item<double>());
1017 :
1018 24 : derivatives.push_back(-strain_grad[1][0].item<double>());
1019 24 : derivatives.push_back(-strain_grad[1][1].item<double>());
1020 24 : derivatives.push_back(-strain_grad[1][2].item<double>());
1021 :
1022 24 : derivatives.push_back(-strain_grad[2][0].item<double>());
1023 24 : derivatives.push_back(-strain_grad[2][1].item<double>());
1024 16 : derivatives.push_back(-strain_grad[2][2].item<double>());
1025 :
// `index` is passed as the starting offset into `derivatives`
// NOTE(review): whether setForcesOnAtoms updates `index` is not visible here
1026 8 : unsigned index = 0;
1027 8 : this->setForcesOnAtoms(derivatives, index);
1028 : }
1029 :
1030 : } // namespace metatomic
1031 : } // namespace PLMD
1032 :
1033 :
1034 : #endif
1035 :
1036 :
// Keyword registration for the METATOMIC action. This part sits after the
// `#endif` that closes the conditional section above it, so it is compiled
// whichever implementation of the action was selected.
1037 : namespace PLMD {
1038 : namespace metatomic {
1039 :
1040 : // use the same implementation for both the actual action and the dummy one
1041 : // (when libtorch and libmetatomic could not be found).
// Declares every keyword accepted by METATOMIC in the input file: the keywords
// of the three base classes first, then the action-specific ones (model path,
// extensions directory, device, consistency-check flag, species definitions,
// atom selection, species -> types mapping and output variant), and finally
// the output component and the description of the resulting value.
1042 12 : void MetatomicPlumedAction::registerKeywords(Keywords& keys) {
1043 12 : Action::registerKeywords(keys);
1044 12 : ActionAtomistic::registerKeywords(keys);
1045 12 : ActionWithValue::registerKeywords(keys);
1046 :
1047 24 : keys.add("compulsory", "MODEL", "path to the exported metatomic model");
1048 24 : keys.add("optional", "EXTENSIONS_DIRECTORY", "path to the directory containing TorchScript extensions to load");
1049 24 : keys.add("optional", "DEVICE", "Torch device to use for the calculations");
1050 :
// flag keyword, off by default
1051 24 : keys.addFlag("CHECK_CONSISTENCY", false, "should we enable internal consistency checks when executing the model");
1052 :
// SPECIES is a numbered keyword whose style is then switched to "atoms"
1053 24 : keys.add("numbered", "SPECIES", "the indices of atoms in each PLUMED species");
1054 24 : keys.reset_style("SPECIES", "atoms");
1055 :
1056 24 : keys.add("optional", "SELECTED_ATOMS", "subset of atoms that should be used for the calculation");
1057 24 : keys.reset_style("SELECTED_ATOMS", "atoms");
1058 :
1059 24 : keys.add("optional", "SPECIES_TO_TYPES", "mapping from PLUMED SPECIES to metatomic's atom types");
1060 :
1061 24 : keys.add("optional", "VARIANT", "which variant of the 'features' output to pick");
1062 :
1063 24 : keys.addOutputComponent("outputs", "default", "collective variable created by the metatomic model");
1064 :
1065 12 : keys.setValueDescription("collective variable created by the metatomic model");
1066 12 : }
1067 :
// make the action available in PLUMED input files under the name "METATOMIC"
1068 : PLUMED_REGISTER_ACTION(MetatomicPlumedAction, "METATOMIC")
1069 :
1070 : } // namespace metatomic
1071 : } // namespace PLMD
|