Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : Copyright (c) 2024 The METATOMIC-PLUMED team
3 : (see the PEOPLE-METATOMIC file at the root of this folder for a list of names)
4 :
5 : See https://docs.metatensor.org/metatomic/ for more information about the
6 : metatomic package that this module allows you to call from PLUMED.
7 :
8 : This file is part of the METATOMIC-PLUMED module.
9 :
10 : The METATOMIC-PLUMED module is free software: you can redistribute it and/or modify
11 : it under the terms of the GNU Lesser General Public License as published by
12 : the Free Software Foundation, either version 3 of the License, or
13 : (at your option) any later version.
14 :
15 : The METATOMIC-PLUMED module is distributed in the hope that it will be useful,
16 : but WITHOUT ANY WARRANTY; without even the implied warranty of
17 : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
18 : GNU Lesser General Public License for more details.
19 :
20 : You should have received a copy of the GNU Lesser General Public License
21 : along with the METATOMIC-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
22 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
23 :
24 : #include "core/ActionAtomistic.h"
25 : #include "core/ActionWithValue.h"
26 : #include "core/ActionRegister.h"
27 : #include "core/PlumedMain.h"
28 :
29 : //+PLUMEDOC METATOMICMOD_COLVAR METATOMIC
30 : /*
31 : Use arbitrary machine learning models as collective variables.
32 :
33 : \note This action requires the metatomic-torch library. Check the
34 : instructions in the \ref METATOMICMOD page to enable this module.
35 :
36 : This action enables the use of fully custom machine learning models — based on
37 : the [metatomic] models interface — as collective variables in PLUMED. Such
38 : machine learning models are typically written and customized using Python code,
39 : and then exported to run within PLUMED as [TorchScript], which is a subset of
40 : Python that can be executed by the C++ torch library.
41 :
42 : Metatomic offers a way to define such models and pass data from PLUMED (or any
43 : other simulation engine) to the model and back. For more information on how to
44 : define such models, have a look at the [corresponding tutorials][mta_tutorials],
45 : or at the code in `regtest/metatomic/`. Each of the Python scripts in this
46 : directory defines a custom machine learning CV that can be used with PLUMED.
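
Below is a rough, hypothetical sketch of what such a Python model could look like: a
TorchScript-compatible module exposing a per-atom `"features"` output that simply
computes the distance of each atom from the origin. It only illustrates the general
shape of the interface; the exact class names, constructor arguments and method
signatures should be checked against the metatomic documentation and the tutorials
linked above.

\verbatim
from typing import Dict, List, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatomic.torch import ModelOutput, System


class CustomCV(torch.nn.Module):
    """Toy per-atom CV: the distance of each atom from the origin."""

    def forward(
        self,
        systems: List[System],
        outputs: Dict[str, ModelOutput],
        selected_atoms: Optional[Labels],
    ) -> Dict[str, TensorMap]:
        # only compute the output if the engine actually requested it
        if "features" not in outputs:
            return {}

        # PLUMED sends a single system at a time
        system = systems[0]
        n_atoms = system.positions.shape[0]

        # one row per atom, one column per property (a single property here)
        values = torch.sqrt((system.positions ** 2).sum(dim=1, keepdim=True))

        # per-atom outputs are labeled by ("system", "atom") samples; a real
        # model should also honor `selected_atoms` (which uses 0-based indices)
        samples = Labels(
            ["system", "atom"],
            torch.stack(
                [
                    torch.zeros(n_atoms, dtype=torch.int32),
                    torch.arange(n_atoms, dtype=torch.int32),
                ],
                dim=1,
            ),
        )

        block = TensorBlock(
            values=values,
            samples=samples,
            components=[],
            properties=Labels.range("features", values.shape[1]),
        )
        return {"features": TensorMap(keys=Labels.range("_", 1), blocks=[block])}
\endverbatim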
47 :
48 : \par Examples
49 :
50 : The following input shows how you can call metatomic from PLUMED and evaluate the
51 : model stored in the file `custom_cv.pt`.
52 :
53 : \plumedfile
metatomic_cv: METATOMIC ...
MODEL=custom_cv.pt
54 :
55 : SPECIES1=1-26
56 : SPECIES2=27-62
57 : SPECIES3=63-76
58 : SPECIES_TO_TYPES=6,1,8
59 : ...
60 : \endplumedfile
61 :
62 : The numbered `SPECIES` keywords indicate the list of atoms that belong
63 : to each atomic species in the system. The `SPECIES_TO_TYPES` keyword then
64 : provides the atomic type for each species. The first number here is
65 : the atomic type of the atoms that have been specified using the `SPECIES1` keyword,
66 : the second number is the atomic type of the atoms that have been specified
67 : using the `SPECIES2` keyword, and so on. In the example above, atoms 1-26 are assigned type 6, atoms 27-62 type 1, and atoms 63-76 type 8.
68 :
69 : The `METATOMIC` action also accepts the following options:
70 :
71 : - `EXTENSIONS_DIRECTORY` should be the path to a directory containing
72 : TorchScript extensions (as shared libraries) that are required to load and
73 : execute the model. This matches the `collect_extensions` argument to
74 : `AtomisticModel.export` in Python (see the export sketch after this list).
75 : - `CHECK_CONSISTENCY` can be used to enable internal consistency checks;
76 : - `SELECTED_ATOMS` can be used to signal to the metatomic model that it should
77 : only run its calculation for the selected subset of atoms. The model still
78 : needs to know about all the atoms in the system (through the `SPECIES`
79 : keyword), but this can be used to reduce the calculation cost. Note that the
80 : indices of the selected atoms should start at 1 in the PLUMED input file, but
81 : they will be translated to start at 0 when given to the model (i.e. in
82 : Python/TorchScript, the `forward` method will receive a `selected_atoms` which
83 : starts at 0).
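
Continuing the hypothetical sketch from above (and with the same caveat that the exact
names should be checked against the metatomic documentation), wrapping and exporting
such a module could look roughly like this; the directory passed to `collect_extensions`
is the one to give to the `EXTENSIONS_DIRECTORY` keyword:

\verbatim
from metatomic.torch import (
    AtomisticModel,
    ModelCapabilities,
    ModelMetadata,
    ModelOutput,
)

capabilities = ModelCapabilities(
    outputs={"features": ModelOutput(per_atom=True)},
    atomic_types=[1, 6, 8],   # atomic types this model accepts
    interaction_range=0.0,    # this toy CV does not request neighbor lists
    length_unit="angstrom",
    supported_devices=["cpu"],
    dtype="float64",
)

model = AtomisticModel(CustomCV().eval(), ModelMetadata(), capabilities)

# `collect_extensions` copies the shared libraries of any TorchScript extensions
# the model needs into this directory; point EXTENSIONS_DIRECTORY to the same path
model.export("custom_cv.pt", collect_extensions="extensions/")
\endverbatim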
84 :
85 : Here is another example with all the possible keywords:
86 :
87 : \plumedfile
soap: METATOMIC ...
MODEL=soap.pt
EXTENSIONS_DIRECTORY=extensions
88 : CHECK_CONSISTENCY
89 :
90 : SPECIES1=1-10
91 : SPECIES2=11-20
92 : SPECIES_TO_TYPES=8,13
93 :
94 : # only run the calculation for the Aluminium (type 13) atoms, but
95 : # include the Oxygen (type 8) as potential neighbors.
96 : SELECTED_ATOMS=11-20
97 : ...
98 : \endplumedfile
99 :
100 : \par Collective variables and metatomic models
101 :
102 : PLUMED can use the [`"features"` output][features_output] of metatomic models as
103 : collective variables.
104 :
105 : */ /*
106 :
107 : [TorchScript]: https://pytorch.org/docs/stable/jit.html
108 : [metatomic]: https://docs.metatensor.org/metatomic/
109 : [mta_tutorials]: https://docs.metatensor.org/metatomic/latest/examples/
110 : [features_output]: https://docs.metatensor.org/metatomic/latest/outputs/features.html
111 : */
112 : //+ENDPLUMEDOC
113 :
114 : /*INDENT-OFF*/
115 : #if !defined(__PLUMED_HAS_LIBMETATOMIC) || !defined(__PLUMED_HAS_LIBTORCH)
116 :
117 : namespace PLMD { namespace metatomic {
118 : class MetatomicPlumedAction: public ActionAtomistic, public ActionWithValue {
119 : public:
120 : static void registerKeywords(Keywords& keys);
121 : explicit MetatomicPlumedAction(const ActionOptions& options):
122 : Action(options),
123 : ActionAtomistic(options),
124 : ActionWithValue(options)
125 : {
126 : throw std::runtime_error(
127 : "Cannot use the metatomic action without the corresponding libraries.\n"
128 : "Make sure to configure with `--enable-libmetatomic --enable-libtorch` "
129 : "and that the corresponding libraries are found"
130 : );
131 : }
132 :
133 : void calculate() override {}
134 : void apply() override {}
135 : unsigned getNumberOfDerivatives() override {return 0;}
136 : };
137 :
138 : }} // namespace PLMD::metatomic
139 :
140 : #else
141 :
142 : #include <type_traits>
143 :
144 : #pragma GCC diagnostic push
145 : #pragma GCC diagnostic ignored "-Wpedantic"
146 : #pragma GCC diagnostic ignored "-Wunused-parameter"
147 : #pragma GCC diagnostic ignored "-Wfloat-equal"
148 : #pragma GCC diagnostic ignored "-Wfloat-conversion"
149 : #pragma GCC diagnostic ignored "-Wimplicit-float-conversion"
150 : #pragma GCC diagnostic ignored "-Wimplicit-int-conversion"
151 : #pragma GCC diagnostic ignored "-Wshorten-64-to-32"
152 : #pragma GCC diagnostic ignored "-Wsign-conversion"
153 : #pragma GCC diagnostic ignored "-Wold-style-cast"
154 :
155 : #include <torch/script.h>
156 : #include <torch/version.h>
157 : #include <torch/cuda.h>
158 : #if TORCH_VERSION_MAJOR >= 2
159 : #include <torch/mps.h>
160 : #endif
161 :
162 : #pragma GCC diagnostic pop
163 :
164 : #include <metatensor/torch.hpp>
165 : #include <metatomic/torch.hpp>
166 :
167 : #include "vesin.h"
168 :
169 :
170 : namespace PLMD {
171 : namespace metatomic {
172 :
173 : // We will cast Vector/Tensor to pointers to arrays and doubles, so let's make
174 : // sure this is legal to do
175 : static_assert(std::is_standard_layout<PLMD::Vector>::value);
176 : static_assert(sizeof(PLMD::Vector) == sizeof(std::array<double, 3>));
177 : static_assert(alignof(PLMD::Vector) == alignof(std::array<double, 3>));
178 :
179 : static_assert(std::is_standard_layout<PLMD::Tensor>::value);
180 : static_assert(sizeof(PLMD::Tensor) == sizeof(std::array<std::array<double, 3>, 3>));
181 : static_assert(alignof(PLMD::Tensor) == alignof(std::array<std::array<double, 3>, 3>));
182 :
183 : class MetatomicPlumedAction: public ActionAtomistic, public ActionWithValue {
184 : public:
185 : static void registerKeywords(Keywords& keys);
186 : explicit MetatomicPlumedAction(const ActionOptions&);
187 :
188 : void calculate() override;
189 : void apply() override;
190 : unsigned getNumberOfDerivatives() override;
191 :
192 : private:
193 : // fill this->system_ according to the current PLUMED data
194 : void createSystem();
195 : // compute a neighbor list in the metatomic format, using data from PLUMED
196 : metatensor_torch::TensorBlock computeNeighbors(
197 : metatomic_torch::NeighborListOptions request,
198 : const std::vector<PLMD::Vector>& positions,
199 : const PLMD::Tensor& cell,
200 : bool periodic
201 : );
202 :
203 : // execute the model for the given system
204 : metatensor_torch::TensorBlock executeModel(metatomic_torch::System system);
205 :
206 : torch::jit::Module model_;
207 :
208 : metatomic_torch::ModelCapabilities capabilities_;
209 :
210 : // neighbor lists requests made by the model
211 : std::vector<metatomic_torch::NeighborListOptions> nl_requests_;
212 :
213 : // dtype/device to use to execute the model
214 : torch::ScalarType dtype_;
215 : torch::Device device_;
216 :
217 : torch::Tensor atomic_types_;
218 : // store the strain to be able to compute the virial with autograd
219 : torch::Tensor strain_;
220 :
221 : metatomic_torch::System system_;
222 : metatomic_torch::ModelEvaluationOptions evaluations_options_;
223 : bool check_consistency_;
224 :
225 : metatensor_torch::TensorMap output_;
226 : // shape of the output of this model
227 : unsigned n_samples_;
228 : unsigned n_properties_;
229 : };
230 :
231 :
232 7 : MetatomicPlumedAction::MetatomicPlumedAction(const ActionOptions& options):
233 : Action(options),
234 : ActionAtomistic(options),
235 : ActionWithValue(options),
236 7 : device_(torch::kCPU)
237 : {
238 14 : if (metatomic_torch::version().find("0.1.") != 0) {
239 0 : this->error(
240 0 : "this code requires version 0.1.x of metatomic-torch, got version " +
241 0 : metatomic_torch::version()
242 : );
243 : }
244 :
245 : // first, load the model
246 : std::string extensions_directory_str;
247 14 : this->parse("EXTENSIONS_DIRECTORY", extensions_directory_str);
248 :
249 : torch::optional<std::string> extensions_directory = torch::nullopt;
250 7 : if (!extensions_directory_str.empty()) {
251 3 : extensions_directory = std::move(extensions_directory_str);
252 : }
253 :
254 : std::string model_path;
255 14 : this->parse("MODEL", model_path);
256 :
257 : try {
258 7 : this->model_ = metatomic_torch::load_atomistic_model(model_path, extensions_directory);
259 0 : } catch (const std::exception& e) {
260 0 : this->error("failed to load model at '" + model_path + "': " + e.what());
261 0 : }
262 :
263 : // extract information from the model
264 14 : auto metadata = this->model_.run_method("metadata").toCustomClass<metatomic_torch::ModelMetadataHolder>();
265 14 : this->capabilities_ = this->model_.run_method("capabilities").toCustomClass<metatomic_torch::ModelCapabilitiesHolder>();
266 7 : auto requests_ivalue = this->model_.run_method("requested_neighbor_lists");
267 14 : for (auto request_ivalue: requests_ivalue.toList()) {
268 7 : auto request = request_ivalue.get().toCustomClass<metatomic_torch::NeighborListOptionsHolder>();
269 7 : this->nl_requests_.push_back(request);
270 : }
271 :
272 14 : log.printf("\n%s\n", metadata->print().c_str());
273 : // add the model references to PLUMED citation handling mechanism
274 8 : for (const auto& it: metadata->references) {
275 2 : for (const auto& ref: it.value()) {
276 2 : this->cite(ref);
277 1 : }
278 : }
279 :
280 : // parse the atomic types from the input file
281 : std::vector<int32_t> atomic_types;
282 : std::vector<int32_t> species_to_types;
283 14 : this->parseVector("SPECIES_TO_TYPES", species_to_types);
284 : bool has_custom_types = !species_to_types.empty();
285 :
286 : std::vector<AtomNumber> all_atoms;
287 14 : this->parseAtomList("SPECIES", all_atoms);
288 :
289 : size_t n_species = 0;
290 7 : if (all_atoms.empty()) {
291 : // first parse each of the 'SPECIES' entry
292 : std::vector<std::vector<AtomNumber>> atoms_per_species;
293 : int i = 0;
294 : while (true) {
295 24 : i += 1;
296 : auto atoms = std::vector<AtomNumber>();
297 48 : this->parseAtomList("SPECIES", i, atoms);
298 :
299 24 : if (atoms.empty()) {
300 : break;
301 : }
302 :
303 : int32_t type = i;
304 17 : if (has_custom_types) {
305 17 : if (species_to_types.size() < static_cast<size_t>(i)) {
306 0 : this->error(
307 : "SPECIES_TO_TYPES is too small, it should have one entry "
308 0 : "for each species (we have at least " + std::to_string(i) +
309 0 : " species and " + std::to_string(species_to_types.size()) +
310 : " entries in SPECIES_TO_TYPES)"
311 : );
312 : }
313 :
314 17 : type = species_to_types[static_cast<size_t>(i - 1)];
315 : }
316 :
317 17 : log.printf(" atoms with type %d are: ", type);
318 1285 : for(unsigned j=0; j<atoms.size(); j++) {
319 1268 : log.printf("%d ", atoms[j]);
320 : }
321 17 : log.printf("\n");
322 :
323 17 : n_species += 1;
324 17 : atoms_per_species.emplace_back(std::move(atoms));
325 : }
326 :
327 : size_t n_atoms = 0;
328 24 : for (const auto& atoms: atoms_per_species) {
329 17 : n_atoms += atoms.size();
330 : }
331 :
332 : // then fill the atomic_types as required
333 7 : atomic_types.resize(n_atoms, 0);
334 : i = 0;
335 24 : for (const auto& atoms: atoms_per_species) {
336 17 : i += 1;
337 :
338 : int32_t type = i;
339 17 : if (has_custom_types) {
340 17 : type = species_to_types[static_cast<size_t>(i - 1)];
341 : }
342 :
343 1285 : for (const auto& atom: atoms) {
344 1268 : atomic_types[atom.index()] = type;
345 : }
346 : }
347 7 : } else {
348 : n_species = 1;
349 :
350 0 : int32_t type = 1;
351 0 : if (has_custom_types) {
352 0 : type = species_to_types[0];
353 : }
354 0 : atomic_types.resize(all_atoms.size(), type);
355 : }
356 :
357 7 : if (has_custom_types && species_to_types.size() != n_species) {
358 0 : this->warning(
359 0 : "SPECIES_TO_TYPES contains more entries (" +
360 0 : std::to_string(species_to_types.size()) +
361 0 : ") than there were species (" + std::to_string(n_species) + ")"
362 : );
363 : }
364 :
365 : // request atoms in order
366 : all_atoms.clear();
367 1275 : for (size_t i=0; i<atomic_types.size(); i++) {
368 1268 : all_atoms.push_back(AtomNumber::index(i));
369 : }
370 7 : this->requestAtoms(all_atoms);
371 :
372 14 : this->atomic_types_ = torch::tensor(std::move(atomic_types));
373 :
374 7 : this->check_consistency_ = false;
375 7 : this->parseFlag("CHECK_CONSISTENCY", this->check_consistency_);
376 7 : if (this->check_consistency_) {
377 0 : log.printf(" checking for internal consistency of the model\n");
378 : }
379 :
380 : // create evaluation options for the model. These won't change during the
381 : // simulation, so we initialize them once here.
382 7 : evaluations_options_ = torch::make_intrusive<metatomic_torch::ModelEvaluationOptionsHolder>();
383 14 : evaluations_options_->set_length_unit(getUnits().getLengthString());
384 :
385 : auto outputs = this->capabilities_->outputs();
386 14 : if (!outputs.contains("features")) {
387 : auto existing_outputs = std::vector<std::string>();
388 0 : for (const auto& it: this->capabilities_->outputs()) {
389 0 : existing_outputs.push_back(it.key());
390 : }
391 :
392 0 : this->error(
393 : "expected a 'features' output in the capabilities of the model, "
394 0 : "could not find it. the following outputs exist: " + torch::str(existing_outputs)
395 : );
396 0 : }
397 :
398 : auto output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
399 : // this output has no quantity or unit to set
400 :
401 7 : output->per_atom = this->capabilities_->outputs().at("features")->per_atom;
402 : // we are using torch autograd system to compute gradients,
403 : // so we don't need any explicit gradients.
404 7 : output->explicit_gradients = {};
405 7 : evaluations_options_->outputs.insert("features", output);
406 :
407 : // Determine which device we should use based on user input, what the model
408 : // supports and what's available
409 : auto available_devices = std::vector<torch::Device>();
410 22 : for (const auto& device: this->capabilities_->supported_devices) {
411 15 : if (device == "cpu") {
412 7 : available_devices.push_back(torch::kCPU);
413 8 : } else if (device == "cuda") {
414 4 : if (torch::cuda::is_available()) {
415 0 : available_devices.push_back(torch::Device("cuda"));
416 : }
417 4 : } else if (device == "mps") {
418 : #if TORCH_VERSION_MAJOR >= 2
419 4 : if (torch::mps::is_available()) {
420 0 : available_devices.push_back(torch::Device("mps"));
421 : }
422 : #endif
423 : } else {
424 0 : this->warning(
425 0 : "the model declared support for unknown device '" + device +
426 : "', it will be ignored"
427 : );
428 : }
429 : }
430 :
431 7 : if (available_devices.empty()) {
432 0 : this->error(
433 0 : "failed to find a valid device for the model at '" + model_path + "': "
434 0 : "the model supports " + torch::str(this->capabilities_->supported_devices) +
435 : ", none of these were available"
436 : );
437 : }
438 :
439 : std::string requested_device;
440 14 : this->parse("DEVICE", requested_device);
441 7 : if (requested_device.empty()) {
442 : // no user request, pick the device the model prefers
443 3 : this->device_ = available_devices[0];
444 : } else {
445 : bool found_requested_device = false;
446 4 : for (const auto& device: available_devices) {
447 8 : if (device.is_cpu() && requested_device == "cpu") {
448 4 : this->device_ = device;
449 : found_requested_device = true;
450 : break;
451 0 : } else if (device.is_cuda() && requested_device == "cuda") {
452 0 : this->device_ = device;
453 : found_requested_device = true;
454 : break;
455 0 : } else if (device.is_mps() && requested_device == "mps") {
456 0 : this->device_ = device;
457 : found_requested_device = true;
458 : break;
459 : }
460 : }
461 :
462 : if (!found_requested_device) {
463 0 : this->error(
464 0 : "failed to find requested device (" + requested_device + "): it is either "
465 : "not supported by this model or not available on this machine"
466 : );
467 : }
468 : }
469 :
470 7 : this->model_.to(this->device_);
471 7 : this->atomic_types_ = this->atomic_types_.to(this->device_);
472 :
473 7 : log.printf(
474 : " running model on %s device with %s data\n",
475 14 : this->device_.str().c_str(),
476 : this->capabilities_->dtype().c_str()
477 : );
478 :
479 7 : if (this->capabilities_->dtype() == "float64") {
480 3 : this->dtype_ = torch::kFloat64;
481 4 : } else if (this->capabilities_->dtype() == "float32") {
482 4 : this->dtype_ = torch::kFloat32;
483 : } else {
484 0 : this->error(
485 0 : "the model requested an unsupported dtype '" + this->capabilities_->dtype() + "'"
486 : );
487 : }
488 :
489 7 : auto tensor_options = torch::TensorOptions().dtype(this->dtype_).device(this->device_);
490 14 : this->strain_ = torch::eye(3, tensor_options.requires_grad(true));
491 :
492 : // determine how many properties there will be in the output by running the
493 : // model once on a dummy system
494 : auto dummy_system = torch::make_intrusive<metatomic_torch::SystemHolder>(
495 14 : /*types = */ torch::zeros({0}, tensor_options.dtype(torch::kInt32)),
496 14 : /*positions = */ torch::zeros({0, 3}, tensor_options),
497 14 : /*cell = */ torch::zeros({3, 3}, tensor_options),
498 7 : /*pbc = */ torch::zeros({3}, tensor_options.dtype(torch::kBool))
499 : );
500 :
501 7 : log.printf(" the following neighbor lists have been requested:\n");
502 7 : auto length_unit = this->getUnits().getLengthString();
503 7 : auto model_length_unit = this->capabilities_->length_unit();
504 14 : for (auto request: this->nl_requests_) {
505 10 : log.printf(" - %s list, %g %s cutoff (requested %g %s)\n",
506 : request->full_list() ? "full" : "half",
507 : request->engine_cutoff(length_unit),
508 : length_unit.c_str(),
509 : request->cutoff(),
510 : model_length_unit.c_str()
511 : );
512 :
513 : auto neighbors = this->computeNeighbors(
514 : request,
515 : {PLMD::Vector(0, 0, 0)},
516 7 : PLMD::Tensor(0, 0, 0, 0, 0, 0, 0, 0, 0),
517 : false
518 21 : );
519 21 : metatomic_torch::register_autograd_neighbors(dummy_system, neighbors, this->check_consistency_);
520 14 : dummy_system->add_neighbor_list(request, neighbors);
521 : }
522 :
523 7 : this->n_properties_ = static_cast<unsigned>(
524 14 : this->executeModel(dummy_system)->properties()->count()
525 : );
526 :
527 : // parse and handle atom sub-selection. This is done AFTER determining the
528 : // output size, since the selection might not be valid for the dummy system
529 : std::vector<int32_t> selected_atoms;
530 14 : this->parseVector("SELECTED_ATOMS", selected_atoms);
531 7 : if (!selected_atoms.empty()) {
532 : auto selection_value = torch::zeros(
533 : {static_cast<int64_t>(selected_atoms.size()), 2},
534 2 : torch::TensorOptions().dtype(torch::kInt32).device(this->device_)
535 2 : );
536 :
537 9 : for (unsigned i=0; i<selected_atoms.size(); i++) {
538 7 : auto n_atoms = static_cast<int32_t>(this->atomic_types_.size(0));
539 7 : if (selected_atoms[i] <= 0 || selected_atoms[i] > n_atoms) {
540 0 : this->error(
541 : "Values in metatomic's SELECTED_ATOMS should be between 1 "
542 0 : "and the number of atoms (" + std::to_string(n_atoms) + "), "
543 0 : "got " + std::to_string(selected_atoms[i]));
544 : }
545 : // PLUMED input uses 1-based indexes, but metatomic wants 0-based
546 14 : selection_value[i][1] = selected_atoms[i] - 1;
547 : }
548 :
549 4 : evaluations_options_->set_selected_atoms(
550 2 : torch::make_intrusive<metatensor_torch::LabelsHolder>(
551 8 : std::vector<std::string>{"system", "atom"}, selection_value
552 : )
553 : );
554 : }
555 :
556 : // Now that we know both n_samples and n_properties, we can set up the
557 : // PLUMED-side storage for the computed CV
558 7 : if (output->per_atom) {
559 4 : if (selected_atoms.empty()) {
560 2 : this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
561 : } else {
562 2 : this->n_samples_ = static_cast<unsigned>(selected_atoms.size());
563 : }
564 : } else {
565 3 : this->n_samples_ = 1;
566 : }
567 :
568 7 : if (n_samples_ == 1 && n_properties_ == 1) {
569 2 : log.printf(" the output of this model is a scalar\n");
570 :
571 4 : this->addValue();
572 5 : } else if (n_samples_ == 1) {
573 1 : log.printf(" the output of this model is a 1x%d vector\n", n_properties_);
574 :
575 1 : this->addValue({this->n_properties_});
576 1 : this->getPntrToComponent(0)->buildDataStore();
577 4 : } else if (n_properties_ == 1) {
578 1 : log.printf(" the output of this model is a %dx1 vector\n", n_samples_);
579 :
580 1 : this->addValue({this->n_samples_});
581 1 : this->getPntrToComponent(0)->buildDataStore();
582 : } else {
583 3 : log.printf(" the output of this model is a %dx%d matrix\n", n_samples_, n_properties_);
584 :
585 3 : this->addValue({this->n_samples_, this->n_properties_});
586 3 : this->getPntrToComponent(0)->buildDataStore();
587 3 : this->getPntrToComponent(0)->reshapeMatrixStore(n_properties_);
588 : }
589 :
590 7 : this->setNotPeriodic();
591 7 : }
592 :
593 12 : unsigned MetatomicPlumedAction::getNumberOfDerivatives() {
594 : // gradients w.r.t. positions (3 x N values) + gradients w.r.t. strain (9 values)
595 12 : return 3 * this->getNumberOfAtoms() + 9;
596 : }
597 :
598 :
599 8 : void MetatomicPlumedAction::createSystem() {
600 16 : if (this->getTotAtoms() != static_cast<unsigned>(this->atomic_types_.size(0))) {
601 0 : std::ostringstream oss;
602 0 : oss << "METATOMIC action needs to know about all atoms in the system. ";
603 0 : oss << "There are " << this->getTotAtoms() << " atoms overall, ";
604 0 : oss << "but we only have atomic types for " << this->atomic_types_.size(0) << " of them.";
605 0 : plumed_merror(oss.str());
606 0 : }
607 :
609 :
610 8 : const auto& cell = this->getPbc().getBox();
611 :
612 8 : auto cpu_f64_tensor = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
613 8 : auto torch_cell = torch::zeros({3, 3}, cpu_f64_tensor);
614 :
615 16 : torch_cell[0][0] = cell(0, 0);
616 16 : torch_cell[0][1] = cell(0, 1);
617 16 : torch_cell[0][2] = cell(0, 2);
618 :
619 16 : torch_cell[1][0] = cell(1, 0);
620 16 : torch_cell[1][1] = cell(1, 1);
621 16 : torch_cell[1][2] = cell(1, 2);
622 :
623 16 : torch_cell[2][0] = cell(2, 0);
624 16 : torch_cell[2][1] = cell(2, 1);
625 16 : torch_cell[2][2] = cell(2, 2);
626 :
627 : using torch::indexing::Slice;
628 :
629 32 : auto pbc_a = torch_cell.index({0, Slice()}).norm().abs().item<double>() > 1e-9;
630 32 : auto pbc_b = torch_cell.index({1, Slice()}).norm().abs().item<double>() > 1e-9;
631 32 : auto pbc_c = torch_cell.index({2, Slice()}).norm().abs().item<double>() > 1e-9;
632 48 : auto torch_pbc = torch::tensor({pbc_a, pbc_b, pbc_c});
633 :
634 : const auto& positions = this->getPositions();
635 :
636 : auto torch_positions = torch::from_blob(
637 : const_cast<PLMD::Vector*>(positions.data()),
638 : {static_cast<int64_t>(positions.size()), 3},
639 : cpu_f64_tensor
640 8 : );
641 :
642 16 : torch_positions = torch_positions.to(this->dtype_).to(this->device_);
643 16 : torch_cell = torch_cell.to(this->dtype_).to(this->device_);
644 8 : torch_pbc = torch_pbc.to(this->device_);
645 :
646 : // setup torch's automatic gradient tracking
647 8 : if (!this->doNotCalculateDerivatives()) {
648 : torch_positions.requires_grad_(true);
649 :
650 : // pretend to scale positions/cell by the strain so that it enters the
651 : // computational graph.
652 7 : torch_positions = torch_positions.matmul(this->strain_);
653 7 : torch_positions.retain_grad();
654 :
655 7 : torch_cell = torch_cell.matmul(this->strain_);
656 : }
657 :
658 8 : this->system_ = torch::make_intrusive<metatomic_torch::SystemHolder>(
659 8 : this->atomic_types_,
660 : torch_positions,
661 : torch_cell,
662 : torch_pbc
663 : );
664 :
665 8 : auto periodic = torch::all(torch_pbc).item<bool>();
666 8 : if (!periodic && torch::any(torch_pbc).item<bool>()) {
667 : std::string periodic_directions;
668 : std::string non_periodic_directions;
669 0 : if (pbc_a) {
670 : periodic_directions += "A";
671 : } else {
672 : non_periodic_directions += "A";
673 : }
674 :
675 0 : if (pbc_b) {
676 : periodic_directions += "B";
677 : } else {
678 : non_periodic_directions += "B";
679 : }
680 :
681 0 : if (pbc_c) {
682 : periodic_directions += "C";
683 : } else {
684 : non_periodic_directions += "C";
685 : }
686 :
687 0 : plumed_merror(
688 : "mixed periodic boundary conditions are not supported, this system "
689 : "is periodic along the " + periodic_directions + " cell vector(s), "
690 : "but not along the " + non_periodic_directions + " cell vector(s)."
691 : );
692 : }
693 :
694 : // compute the neighbor lists requested by the model, and register them with
695 : // the system
696 16 : for (auto request: this->nl_requests_) {
697 8 : auto neighbors = this->computeNeighbors(request, positions, cell, periodic);
698 24 : metatomic_torch::register_autograd_neighbors(this->system_, neighbors, this->check_consistency_);
699 16 : this->system_->add_neighbor_list(request, neighbors);
700 : }
701 8 : }
702 :
703 :
704 15 : metatensor_torch::TensorBlock MetatomicPlumedAction::computeNeighbors(
705 : metatomic_torch::NeighborListOptions request,
706 : const std::vector<PLMD::Vector>& positions,
707 : const PLMD::Tensor& cell,
708 : bool periodic
709 : ) {
710 15 : auto labels_options = torch::TensorOptions().dtype(torch::kInt32).device(this->device_);
711 : auto neighbor_component = torch::make_intrusive<metatensor_torch::LabelsHolder>(
712 : "xyz",
713 75 : torch::tensor({0, 1, 2}, labels_options).reshape({3, 1})
714 : );
715 : auto neighbor_properties = torch::make_intrusive<metatensor_torch::LabelsHolder>(
716 30 : "distance", torch::zeros({1, 1}, labels_options)
717 : );
718 :
719 15 : auto cutoff = request->engine_cutoff(this->getUnits().getLengthString());
720 :
721 : // use https://github.com/Luthaf/vesin to compute the requested neighbor
722 : // lists since we can not get these from PLUMED
723 : vesin::VesinOptions options;
724 15 : options.cutoff = cutoff;
725 15 : options.full = request->full_list();
726 15 : options.return_shifts = true;
727 15 : options.return_distances = false;
728 15 : options.return_vectors = true;
729 :
730 15 : vesin::VesinNeighborList* vesin_neighbor_list = new vesin::VesinNeighborList();
731 : memset(vesin_neighbor_list, 0, sizeof(vesin::VesinNeighborList));
732 :
733 15 : const char* error_message = NULL;
734 15 : int status = vesin_neighbors(
735 : reinterpret_cast<const double (*)[3]>(positions.data()),
736 : positions.size(),
737 15 : reinterpret_cast<const double (*)[3]>(&cell(0, 0)),
738 : periodic,
739 : vesin::VesinCPU,
740 : options,
741 : vesin_neighbor_list,
742 : &error_message
743 : );
744 :
745 15 : if (status != EXIT_SUCCESS) {
746 0 : plumed_merror(
747 : "failed to compute neighbor list (cutoff=" + std::to_string(cutoff) +
748 : ", full=" + (request->full_list() ? "true" : "false") + "): " + error_message
749 : );
750 : }
751 :
752 : // transform from vesin to metatomic format
753 15 : auto n_pairs = static_cast<int64_t>(vesin_neighbor_list->length);
754 :
755 : auto pair_vectors = torch::from_blob(
756 15 : vesin_neighbor_list->vectors,
757 : {n_pairs, 3, 1},
758 15 : /*deleter*/ [=](void*) {
759 15 : vesin_free(vesin_neighbor_list);
760 15 : delete vesin_neighbor_list;
761 15 : },
762 15 : torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU)
763 15 : );
764 :
765 15 : auto pair_samples_values = torch::empty({n_pairs, 5}, labels_options.device(torch::kCPU));
766 15 : auto pair_samples_values_ptr = pair_samples_values.accessor<int32_t, 2>();
767 15189 : for (unsigned i=0; i<n_pairs; i++) {
768 15174 : pair_samples_values_ptr[i][0] = static_cast<int32_t>(vesin_neighbor_list->pairs[i][0]);
769 15174 : pair_samples_values_ptr[i][1] = static_cast<int32_t>(vesin_neighbor_list->pairs[i][1]);
770 15174 : pair_samples_values_ptr[i][2] = vesin_neighbor_list->shifts[i][0];
771 15174 : pair_samples_values_ptr[i][3] = vesin_neighbor_list->shifts[i][1];
772 15174 : pair_samples_values_ptr[i][4] = vesin_neighbor_list->shifts[i][2];
773 : }
774 :
775 : auto neighbor_samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
776 120 : std::vector<std::string>{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"},
777 30 : pair_samples_values.to(this->device_)
778 : );
779 :
780 : auto neighbors = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
781 30 : pair_vectors.to(this->dtype_).to(this->device_),
782 : neighbor_samples,
783 60 : std::vector<metatensor_torch::Labels>{neighbor_component},
784 : neighbor_properties
785 : );
786 :
787 15 : return neighbors;
788 : }
789 :
790 15 : metatensor_torch::TensorBlock MetatomicPlumedAction::executeModel(metatomic_torch::System system) {
791 : try {
792 45 : auto ivalue_output = this->model_.forward({
793 60 : std::vector<metatomic_torch::System>{system},
794 : evaluations_options_,
795 15 : this->check_consistency_,
796 75 : });
797 :
798 15 : auto dict_output = ivalue_output.toGenericDict();
799 15 : auto cv = dict_output.at("features");
800 30 : this->output_ = cv.toCustomClass<metatensor_torch::TensorMapHolder>();
801 0 : } catch (const std::exception& e) {
802 0 : plumed_merror("failed to evaluate the model: " + std::string(e.what()));
803 0 : }
804 :
805 30 : plumed_massert(this->output_->keys()->count() == 1, "output should have a single block");
806 30 : auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
807 15 : plumed_massert(block->components().empty(), "components are not yet supported in the output");
808 :
809 15 : return block;
810 : }
811 :
812 :
813 8 : void MetatomicPlumedAction::calculate() {
814 8 : this->createSystem();
815 :
816 16 : auto block = this->executeModel(this->system_);
817 24 : auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
818 :
819 8 : if (static_cast<unsigned>(torch_values.size(0)) != this->n_samples_) {
820 0 : plumed_merror(
821 : "expected the model to return a TensorBlock with " +
822 : std::to_string(this->n_samples_) + " samples, got " +
823 : std::to_string(torch_values.size(0)) + " instead"
824 : );
825 8 : } else if (static_cast<unsigned>(torch_values.size(1)) != this->n_properties_) {
826 0 : plumed_merror(
827 : "expected the model to return a TensorBlock with " +
828 : std::to_string(this->n_properties_) + " properties, got " +
829 : std::to_string(torch_values.size(1)) + " instead"
830 : );
831 : }
832 :
833 8 : Value* value = this->getPntrToComponent(0);
834 : // copy the data returned by the model into the PLUMED `Value`
835 8 : if (n_samples_ == 1) {
836 4 : if (n_properties_ == 1) {
837 3 : value->set(torch_values.item<double>());
838 : } else {
839 : // we have multiple CV describing a single thing (atom or full system)
840 3 : for (unsigned i=0; i<n_properties_; i++) {
841 2 : value->set(i, torch_values[0][i].item<double>());
842 : }
843 : }
844 : } else {
845 : auto samples = block->samples();
846 16 : plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
847 :
848 4 : auto samples_values = samples->values().to(torch::kCPU);
849 : auto selected_atoms = this->evaluations_options_->get_selected_atoms();
850 :
851 : // handle the possibility that samples are returned in
852 : // a non-sorted order.
853 92 : auto get_output_location = [&](unsigned i) {
854 92 : if (selected_atoms.has_value()) {
855 : // If the user picked some selected atoms, then we store the
856 : // output in the same order as the selection was given
857 28 : auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
858 7 : auto position = selected_atoms.value()->position(sample);
859 7 : plumed_assert(position.has_value());
860 7 : return static_cast<unsigned>(position.value());
861 : } else {
862 170 : return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
863 : }
864 4 : };
865 :
866 4 : if (n_properties_ == 1) {
867 : // we have a single CV describing multiple things (i.e. atoms)
868 5 : for (unsigned i=0; i<n_samples_; i++) {
869 4 : auto output_i = get_output_location(i);
870 8 : value->set(output_i, torch_values[i][0].item<double>());
871 : }
872 : } else {
873 : // the CV is a matrix
874 91 : for (unsigned i=0; i<n_samples_; i++) {
875 88 : auto output_i = get_output_location(i);
876 343 : for (unsigned j=0; j<n_properties_; j++) {
877 255 : value->set(output_i * n_properties_ + j, torch_values[i][j].item<double>());
878 : }
879 : }
880 : }
881 : }
882 8 : }
883 :
884 :
885 8 : void MetatomicPlumedAction::apply() {
886 8 : const auto* value = this->getPntrToComponent(0);
887 8 : if (!value->forcesWereAdded()) {
888 1 : return;
889 : }
890 :
891 14 : auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
892 21 : auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
893 :
894 7 : if (!torch_values.requires_grad()) {
895 0 : this->warning(
896 : "the output of the model does not require gradients, this might "
897 : "indicate a problem"
898 : );
899 : return;
900 : }
901 :
902 7 : auto output_grad = torch::zeros_like(torch_values);
903 7 : if (n_samples_ == 1) {
904 4 : if (n_properties_ == 1) {
905 3 : output_grad[0][0] = value->getForce();
906 : } else {
907 3 : for (unsigned i=0; i<n_properties_; i++) {
908 4 : output_grad[0][i] = value->getForce(i);
909 : }
910 : }
911 : } else {
912 : auto samples = block->samples();
913 12 : plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
914 :
915 3 : auto samples_values = samples->values().to(torch::kCPU);
916 : auto selected_atoms = this->evaluations_options_->get_selected_atoms();
917 :
918 : // see above for an explanation of why we use this function
919 89 : auto get_output_location = [&](unsigned i) {
920 89 : if (selected_atoms.has_value()) {
921 16 : auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
922 4 : auto position = selected_atoms.value()->position(sample);
923 4 : plumed_assert(position.has_value());
924 4 : return static_cast<unsigned>(position.value());
925 : } else {
926 170 : return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
927 : }
928 3 : };
929 :
930 3 : if (n_properties_ == 1) {
931 5 : for (unsigned i=0; i<n_samples_; i++) {
932 4 : auto output_i = get_output_location(i);
933 8 : output_grad[i][0] = value->getForce(output_i);
934 : }
935 : } else {
936 87 : for (unsigned i=0; i<n_samples_; i++) {
937 85 : auto output_i = get_output_location(i);
938 331 : for (unsigned j=0; j<n_properties_; j++) {
939 492 : output_grad[i][j] = value->getForce(output_i * n_properties_ + j);
940 : }
941 : }
942 : }
943 : }
944 :
945 7 : this->system_->positions().mutable_grad() = torch::Tensor();
946 7 : this->strain_.mutable_grad() = torch::Tensor();
947 :
948 7 : torch_values.backward(output_grad);
949 7 : auto positions_grad = this->system_->positions().grad();
950 7 : auto strain_grad = this->strain_.grad();
951 :
952 21 : positions_grad = positions_grad.to(torch::kCPU).to(torch::kFloat64);
953 21 : strain_grad = strain_grad.to(torch::kCPU).to(torch::kFloat64);
954 :
955 7 : plumed_assert(positions_grad.sizes().size() == 2);
956 7 : plumed_assert(positions_grad.is_contiguous());
957 :
958 7 : plumed_assert(strain_grad.sizes().size() == 2);
959 7 : plumed_assert(strain_grad.is_contiguous());
960 :
961 : auto derivatives = std::vector<double>(
962 : positions_grad.data_ptr<double>(),
963 7 : positions_grad.data_ptr<double>() + 3 * this->system_->size()
964 7 : );
965 :
966 : // add virials to the derivatives
967 7 : derivatives.push_back(-strain_grad[0][0].item<double>());
968 7 : derivatives.push_back(-strain_grad[0][1].item<double>());
969 7 : derivatives.push_back(-strain_grad[0][2].item<double>());
970 :
971 7 : derivatives.push_back(-strain_grad[1][0].item<double>());
972 7 : derivatives.push_back(-strain_grad[1][1].item<double>());
973 7 : derivatives.push_back(-strain_grad[1][2].item<double>());
974 :
975 7 : derivatives.push_back(-strain_grad[2][0].item<double>());
976 7 : derivatives.push_back(-strain_grad[2][1].item<double>());
977 7 : derivatives.push_back(-strain_grad[2][2].item<double>());
978 :
979 7 : unsigned index = 0;
980 7 : this->setForcesOnAtoms(derivatives, index);
981 : }
982 :
983 : } // namespace metatomic
984 : } // namespace PLMD
985 :
986 :
987 : #endif
988 :
989 :
990 : namespace PLMD {
991 : namespace metatomic {
992 :
993 : // use the same implementation for both the actual action and the dummy one
994 : // (when libtorch and libmetatomic could not be found).
995 9 : void MetatomicPlumedAction::registerKeywords(Keywords& keys) {
996 9 : Action::registerKeywords(keys);
997 9 : ActionAtomistic::registerKeywords(keys);
998 9 : ActionWithValue::registerKeywords(keys);
999 :
1000 18 : keys.add("compulsory", "MODEL", "path to the exported metatomic model");
1001 18 : keys.add("optional", "EXTENSIONS_DIRECTORY", "path to the directory containing TorchScript extensions to load");
1002 18 : keys.add("optional", "DEVICE", "Torch device to use for the calculations");
1003 :
1004 18 : keys.addFlag("CHECK_CONSISTENCY", false, "should we enable internal consistency checks when executing the model");
1005 :
1006 18 : keys.add("numbered", "SPECIES", "the indices of atoms in each PLUMED species");
1007 18 : keys.reset_style("SPECIES", "atoms");
1008 :
1009 18 : keys.add("optional", "SELECTED_ATOMS", "subset of atoms that should be used for the calculation");
1010 18 : keys.reset_style("SELECTED_ATOMS", "atoms");
1011 :
1012 18 : keys.add("optional", "SPECIES_TO_TYPES", "mapping from PLUMED SPECIES to metatomic's atom types");
1013 :
1014 18 : keys.addOutputComponent("outputs", "default", "collective variable created by the metatomic model");
1015 :
1016 9 : keys.setValueDescription("collective variable created by the metatomic model");
1017 9 : }
1018 :
1019 : PLUMED_REGISTER_ACTION(MetatomicPlumedAction, "METATOMIC")
1020 :
1021 : } // namespace metatomic
1022 : } // namespace PLMD
|