LCOV - code coverage report
Current view: top level - metatomic - metatomic.cpp
Test: plumed test coverage
Date: 2025-12-04 11:19:34

             Hit   Total   Coverage
Lines:       313     371     84.4 %
Functions:    11      12     91.7 %
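
The constructor in the listing below turns the 1-based atom numbers from the `SPECIESn`, `SPECIES_TO_TYPES` and `SELECTED_ATOMS` keywords into the 0-based arrays that metatomic consumes. The following is a minimal standalone sketch of that index bookkeeping, assuming plain `std::vector` inputs instead of PLUMED's `AtomNumber` and torch tensors. `build_atomic_types` and `to_zero_based_selection` are hypothetical helpers written for illustration; they are not part of the PLUMED or metatomic API.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper, not part of PLUMED or metatomic.
// species[k] holds the 1-based atom numbers given to SPECIES<k+1>.
// When species_to_types is empty, atoms in SPECIES<k+1> get type k+1,
// otherwise they get species_to_types[k].
std::vector<int32_t> build_atomic_types(
    const std::vector<std::vector<int>>& species,
    const std::vector<int32_t>& species_to_types
) {
    if (!species_to_types.empty() && species_to_types.size() < species.size()) {
        throw std::invalid_argument("SPECIES_TO_TYPES needs one entry per species");
    }

    // the total number of atoms is the sum of the sizes of all species
    std::size_t n_atoms = 0;
    for (const auto& atoms: species) {
        n_atoms += atoms.size();
    }

    std::vector<int32_t> types(n_atoms, 0);
    for (std::size_t k = 0; k < species.size(); k++) {
        int32_t type = species_to_types.empty()
            ? static_cast<int32_t>(k + 1)
            : species_to_types[k];
        for (int atom: species[k]) {
            // 1-based PLUMED atom number -> 0-based array index; at() throws
            // std::out_of_range if the number is <= 0 or larger than n_atoms
            types.at(static_cast<std::size_t>(atom - 1)) = type;
        }
    }
    return types;
}

// Hypothetical helper: validates 1-based SELECTED_ATOMS values and shifts them
// to the 0-based indices expected in the "atom" column of the selection.
std::vector<int32_t> to_zero_based_selection(
    const std::vector<int32_t>& selected_atoms,
    int32_t n_atoms
) {
    std::vector<int32_t> result;
    result.reserve(selected_atoms.size());
    for (int32_t atom: selected_atoms) {
        if (atom <= 0 || atom > n_atoms) {
            throw std::invalid_argument(
                "SELECTED_ATOMS values should be between 1 and " + std::to_string(n_atoms)
            );
        }
        result.push_back(atom - 1);
    }
    return result;
}
```

For the second PLUMED example in the documentation listed below (`SPECIES1=1-10`, `SPECIES2=11-20`, `SPECIES_TO_TYPES=8,13`, `SELECTED_ATOMS=11-20`), this sketch gives type 8 at indices 0 to 9, type 13 at indices 10 to 19, and a selection of indices 10 to 19.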

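The device selection performed by the constructor in the listing below (and controlled by the `DEVICE` keyword) can be summarized as follows. This is a simplified sketch that assumes device names are plain strings and replaces the `torch::cuda::is_available()` / `torch::mps::is_available()` checks with a hypothetical `DeviceAvailability` struct; `select_device` is likewise hypothetical and not part of PLUMED.

```cpp
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for the runtime checks done with
// torch::cuda::is_available() and torch::mps::is_available().
struct DeviceAvailability {
    bool cuda = false;
    bool mps = false;
};

// Keep the devices the model supports (in the model's order of preference)
// that are usable on this machine, then return either the first of them or
// the one requested through DEVICE (empty string when DEVICE is not given).
std::string select_device(
    const std::vector<std::string>& supported_devices,
    const DeviceAvailability& availability,
    const std::string& requested_device
) {
    std::vector<std::string> available;
    for (const auto& device: supported_devices) {
        if (device == "cpu"
            || (device == "cuda" && availability.cuda)
            || (device == "mps" && availability.mps)) {
            available.push_back(device);
        }
        // unavailable or unknown devices are skipped; for unknown names
        // PLUMED also prints a warning
    }

    if (available.empty()) {
        throw std::runtime_error("none of the devices supported by the model are available");
    }
    if (requested_device.empty()) {
        return available.front();
    }
    for (const auto& device: available) {
        if (device == requested_device) {
            return device;
        }
    }
    throw std::runtime_error(
        "requested device '" + requested_device + "' is not supported or not available"
    );
}
```

The order of the model's supported devices acts as a preference list, so a model declaring `cuda` before `cpu` runs on the GPU by default whenever one is available.
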
          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2024 The METATOMIC-PLUMED team
       3             : (see the PEOPLE-METATOMIC file at the root of this folder for a list of names)
       4             : 
       5             : See https://docs.metatensor.org/metatomic/ for more information about the
       6             : metatomic package that this module allows you to call from PLUMED.
       7             : 
       8             : This file is part of METATOMIC-PLUMED module.
       9             : 
      10             : The METATOMIC-PLUMED module is free software: you can redistribute it and/or modify
      11             : it under the terms of the GNU Lesser General Public License as published by
      12             : the Free Software Foundation, either version 3 of the License, or
      13             : (at your option) any later version.
      14             : 
      15             : The METATOMIC-PLUMED module is distributed in the hope that it will be useful,
      16             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      17             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      18             : GNU Lesser General Public License for more details.
      19             : 
      20             : You should have received a copy of the GNU Lesser General Public License
      21             : along with the METATOMIC-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
      22             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      23             : 
      24             : #include "core/ActionAtomistic.h"
      25             : #include "core/ActionWithValue.h"
      26             : #include "core/ActionRegister.h"
      27             : #include "core/PlumedMain.h"
      28             : 
      29             : //+PLUMEDOC COLVAR METATOMIC
      30             : /*
      31             : Use arbitrary machine learning models as collective variables.
      32             : 
      33             : !!! note
      34             : 
      35             :     This action requires `libmetatomic` and `libtorch` to be enabled and found. Check the
      36             :     instructions in [the module page](module_metatomic.md) for more information
      37             :     on how to enable this module.
      38             : 
      39             : This action enables the use of fully custom machine learning models — based on
      40             : the [metatomic] models interface — as collective variables in PLUMED. Such
      41             : machine learning models are typically written and customized using Python code,
      42             : and then exported to run within PLUMED as [TorchScript], which is a subset of
      43             : Python that can be executed by the C++ torch library.
      44             : 
      45             : Metatomic offers a way to define such models and pass data from PLUMED (or any
      46             : other simulation engine) to the model and back. For more information on how to
      47             : define such a model, have a look at the [corresponding tutorials][mta_tutorials],
      48             : or at the code in `regtest/metatomic/`. Each of the Python scripts in this
      49             : directory defines a custom machine learning CV that can be used with PLUMED.
      50             : 
      51             : ## Examples
      52             : 
      53             : The following input shows how to call metatomic from PLUMED to evaluate the
      54             : model stored in the file `custom_cv.pt`.
      55             : 
      56             : ```plumed
      57             : metatomic_cv: METATOMIC ...
      58             :     MODEL=custom_cv.pt
      59             : 
      60             :     SPECIES1=1-26
      61             :     SPECIES2=27-62
      62             :     SPECIES3=63-76
      63             :     SPECIES_TO_TYPES=6,1,8
      64             : ...
      65             : ```
      66             : 
      67             : The numbered `SPECIES` keywords list the atoms that belong to each atomic
      68             : species in the system. The `SPECIES_TO_TYPES` keyword then gives the atom type
      69             : of each species: the first number is the atomic type of the atoms specified
      70             : with `SPECIES1`, the second number is the atomic type of the atoms specified
      71             : with `SPECIES2`, and so on. Without `SPECIES_TO_TYPES`, the atoms in `SPECIESn`
      72             : get type `n`.
      73             : 
      74             : The `METATOMIC` action also accepts the following options:
      75             : 
      76             : - `EXTENSIONS_DIRECTORY` should be the path to a directory containing
      77             :   TorchScript extensions (as shared libraries) that are required to load and
      78             :   execute the model. This matches the `collect_extensions` argument to
      79             :   `AtomisticModel.export` in Python.
      80             : - `CHECK_CONSISTENCY` can be used to enable internal consistency checks.
      81             : - `DEVICE` selects the device the model runs on (`cpu`, `cuda` or `mps`). By
      82             :   default, the first device supported by the model that is available is used.
      83             : - `SELECTED_ATOMS` can be used to tell the metatomic model that it should only
      84             :   run its calculation for the selected subset of atoms, reducing the cost of
      85             :   the calculation. The model still needs to know about all the atoms in the
      86             :   system (through the `SPECIES` keywords). The indices of the selected atoms
      87             :   start at 1 in the PLUMED input file, but are translated to start at 0 when
      88             :   given to the model (i.e. its `forward` method receives 0-based `selected_atoms`).
      89             : 
      90             : Here is another example with all the possible keywords:
      91             : 
      92             : ```plumed
      93             : soap: METATOMIC ...
      94             :     MODEL=path/to/model.pt
      95             :     EXTENSIONS_DIRECTORY=path/to/extensions/
      96             :     DEVICE=cuda
      97             :     CHECK_CONSISTENCY
      98             : 
      99             :     SPECIES1=1-10
     100             :     SPECIES2=11-20
     101             :     SPECIES_TO_TYPES=8,13
     102             : 
     103             :     # only run the calculation for the Aluminium (type 13) atoms,
     104             :     # still including all other atoms as potential neighbors.
     105             :     SELECTED_ATOMS=11-20
     106             : ...
     107             : ```
     108             : 
     109             : [TorchScript]: https://pytorch.org/docs/stable/jit.html
     110             : [metatomic]: https://docs.metatensor.org/metatomic/
     111             : [mta_tutorials]: https://docs.metatensor.org/metatomic/latest/examples/
     112             : [features_output]: https://docs.metatensor.org/metatomic/latest/outputs/features.html
     113             : 
     114             : */
     115             : //+ENDPLUMEDOC
     116             : 
     117             : /*INDENT-OFF*/
     118             : #if !defined(__PLUMED_HAS_LIBMETATOMIC) || !defined(__PLUMED_HAS_LIBTORCH)
     119             : 
     120             : namespace PLMD { namespace metatomic {
     121             : class MetatomicPlumedAction: public ActionAtomistic, public ActionWithValue {
     122             : public:
     123             :     static void registerKeywords(Keywords& keys);
     124             :     explicit MetatomicPlumedAction(const ActionOptions& options):
     125             :         Action(options),
     126             :         ActionAtomistic(options),
     127             :         ActionWithValue(options)
     128             :     {
     129             :         throw std::runtime_error(
     130             :             "Cannot use metatomic action without the corresponding libraries. \n"
     131             :             "Make sure to configure with `--enable-libmetatomic --enable-libtorch` "
     132             :             "and that the corresponding libraries are found"
     133             :         );
     134             :     }
     135             : 
     136             :     void calculate() override {}
     137             :     void apply() override {}
     138             :     unsigned getNumberOfDerivatives() override {return 0;}
     139             : };
     140             : 
     141             : }} // namespace PLMD::metatomic
     142             : 
     143             : #else
     144             : 
     145             : #include <type_traits>
     146             : 
     147             : #pragma GCC diagnostic push
     148             : #pragma GCC diagnostic ignored "-Wpedantic"
     149             : #pragma GCC diagnostic ignored "-Wunused-parameter"
     150             : #pragma GCC diagnostic ignored "-Wfloat-equal"
     151             : #pragma GCC diagnostic ignored "-Wfloat-conversion"
     152             : #pragma GCC diagnostic ignored "-Wimplicit-float-conversion"
     153             : #pragma GCC diagnostic ignored "-Wimplicit-int-conversion"
     154             : #pragma GCC diagnostic ignored "-Wshorten-64-to-32"
     155             : #pragma GCC diagnostic ignored "-Wsign-conversion"
     156             : #pragma GCC diagnostic ignored "-Wold-style-cast"
     157             : 
     158             : #include <torch/script.h>
     159             : #include <torch/version.h>
     160             : #include <torch/cuda.h>
     161             : #if TORCH_VERSION_MAJOR >= 2
     162             : #include <torch/mps.h>
     163             : #endif
     164             : 
     165             : #pragma GCC diagnostic pop
     166             : 
     167             : #include <metatensor/torch.hpp>
     168             : #include <metatomic/torch.hpp>
     169             : 
     170             : #include "vesin.h"
     171             : 
     172             : 
     173             : namespace PLMD {
     174             : namespace metatomic {
     175             : 
     176             : // We will cast Vector/Tensor to pointers to arrays and doubles, so let's make
     177             : // sure this is legal to do
     178             : static_assert(std::is_standard_layout<PLMD::Vector>::value);
     179             : static_assert(sizeof(PLMD::Vector) == sizeof(std::array<double, 3>));
     180             : static_assert(alignof(PLMD::Vector) == alignof(std::array<double, 3>));
     181             : 
     182             : static_assert(std::is_standard_layout<PLMD::Tensor>::value);
     183             : static_assert(sizeof(PLMD::Tensor) == sizeof(std::array<std::array<double, 3>, 3>));
     184             : static_assert(alignof(PLMD::Tensor) == alignof(std::array<std::array<double, 3>, 3>));
     185             : 
     186             : class MetatomicPlumedAction: public ActionAtomistic, public ActionWithValue {
     187             : public:
     188             :     static void registerKeywords(Keywords& keys);
     189             :     explicit MetatomicPlumedAction(const ActionOptions&);
     190             : 
     191             :     void calculate() override;
     192             :     void apply() override;
     193             :     unsigned getNumberOfDerivatives() override;
     194             : 
     195             : private:
     196             :     // fill this->system_ according to the current PLUMED data
     197             :     void createSystem();
     198             :     // compute a neighbor list following metatomic format, using data from PLUMED
     199             :     metatensor_torch::TensorBlock computeNeighbors(
     200             :         metatomic_torch::NeighborListOptions request,
     201             :         const std::vector<PLMD::Vector>& positions,
     202             :         const PLMD::Tensor& cell,
     203             :         bool periodic
     204             :     );
     205             : 
     206             :     // execute the model for the given system
     207             :     metatensor_torch::TensorBlock executeModel(metatomic_torch::System system);
     208             : 
     209             :     torch::jit::Module model_;
     210             : 
     211             :     metatomic_torch::ModelCapabilities capabilities_;
     212             : 
     213             :     // neighbor list requests made by the model
     214             :     std::vector<metatomic_torch::NeighborListOptions> nl_requests_;
     215             : 
     216             :     // dtype/device to use to execute the model
     217             :     torch::ScalarType dtype_;
     218             :     torch::Device device_;
     219             : 
     220             :     torch::Tensor atomic_types_;
     221             :     // store the strain to be able to compute the virial with autograd
     222             :     torch::Tensor strain_;
     223             : 
     224             :     metatomic_torch::System system_;
     225             :     metatomic_torch::ModelEvaluationOptions evaluations_options_;
     226             :     bool check_consistency_;
     227             : 
     228             :     metatensor_torch::TensorMap output_;
     229             :     // shape of the output of this model
     230             :     unsigned n_samples_;
     231             :     unsigned n_properties_;
     232             : };
     233             : 
     234             : 
     235           7 : MetatomicPlumedAction::MetatomicPlumedAction(const ActionOptions& options):
     236             :     Action(options),
     237             :     ActionAtomistic(options),
     238             :     ActionWithValue(options),
     239           7 :     device_(torch::kCPU)
     240             : {
     241          14 :     if (metatomic_torch::version().find("0.1.") != 0) {
     242           0 :         this->error(
     243           0 :             "this code requires version 0.1.x of metatomic-torch, got version " +
     244           0 :             metatomic_torch::version()
     245             :         );
     246             :     }
     247             : 
     248             :     // first, load the model
     249             :     std::string extensions_directory_str;
     250          14 :     this->parse("EXTENSIONS_DIRECTORY", extensions_directory_str);
     251             : 
     252             :     torch::optional<std::string> extensions_directory = torch::nullopt;
     253           7 :     if (!extensions_directory_str.empty()) {
     254           3 :         extensions_directory = std::move(extensions_directory_str);
     255             :     }
     256             : 
     257             :     std::string model_path;
     258          14 :     this->parse("MODEL", model_path);
     259             : 
     260             :     try {
     261           7 :         this->model_ = metatomic_torch::load_atomistic_model(model_path, extensions_directory);
     262           0 :     } catch (const std::exception& e) {
     263           0 :         this->error("failed to load model at '" + model_path + "': " + e.what());
     264           0 :     }
     265             : 
     266             :     // extract information from the model
     267          14 :     auto metadata = this->model_.run_method("metadata").toCustomClass<metatomic_torch::ModelMetadataHolder>();
     268          14 :     this->capabilities_ = this->model_.run_method("capabilities").toCustomClass<metatomic_torch::ModelCapabilitiesHolder>();
     269           7 :     auto requests_ivalue = this->model_.run_method("requested_neighbor_lists");
     270          14 :     for (auto request_ivalue: requests_ivalue.toList()) {
     271           7 :         auto request = request_ivalue.get().toCustomClass<metatomic_torch::NeighborListOptionsHolder>();
     272           7 :         this->nl_requests_.push_back(request);
     273             :     }
     274             : 
     275          14 :     log.printf("\n%s\n", metadata->print().c_str());
     276             :     // add the model references to PLUMED citation handling mechanism
     277           8 :     for (const auto& it: metadata->references) {
     278           2 :         for (const auto& ref: it.value()) {
     279           2 :             this->cite(ref);
     280           1 :         }
     281             :     }
     282             : 
     283             :     // parse the atomic types from the input file
     284             :     std::vector<int32_t> atomic_types;
     285             :     std::vector<int32_t> species_to_types;
     286          14 :     this->parseVector("SPECIES_TO_TYPES", species_to_types);
     287             :     bool has_custom_types = !species_to_types.empty();
     288             : 
     289             :     std::vector<AtomNumber> all_atoms;
     290          14 :     this->parseAtomList("SPECIES", all_atoms);
     291             : 
     292             :     size_t n_species = 0;
     293           7 :     if (all_atoms.empty()) {
     294             :         // first parse each of the 'SPECIES' entry
     295             :         std::vector<std::vector<AtomNumber>> atoms_per_species;
     296             :         int i = 0;
     297             :         while (true) {
     298          24 :             i += 1;
     299             :             auto atoms = std::vector<AtomNumber>();
     300          48 :             this->parseAtomList("SPECIES", i, atoms);
     301             : 
     302          24 :             if (atoms.empty()) {
     303             :                 break;
     304             :             }
     305             : 
     306             :             int32_t type = i;
     307          17 :             if (has_custom_types) {
     308          17 :                 if (species_to_types.size() < static_cast<size_t>(i)) {
     309           0 :                     this->error(
     310             :                         "SPECIES_TO_TYPES is too small, it should have one entry "
     311           0 :                         "for each species (we have at least " + std::to_string(i) +
     312           0 :                         " species and " + std::to_string(species_to_types.size()) +
     313             :                         " entries in SPECIES_TO_TYPES)"
     314             :                     );
     315             :                 }
     316             : 
     317          17 :                 type = species_to_types[static_cast<size_t>(i - 1)];
     318             :             }
     319             : 
     320          17 :             log.printf("  atoms with type %d are: ", type);
     321        1285 :             for(unsigned j=0; j<atoms.size(); j++) {
     322        1268 :                 log.printf("%d ", atoms[j]);
     323             :             }
     324          17 :             log.printf("\n");
     325             : 
     326          17 :             n_species += 1;
     327          17 :             atoms_per_species.emplace_back(std::move(atoms));
     328             :         }
     329             : 
     330             :         size_t n_atoms = 0;
     331          24 :         for (const auto& atoms: atoms_per_species) {
     332          17 :             n_atoms += atoms.size();
     333             :         }
     334             : 
     335             :         // then fill the atomic_types as required
     336           7 :         atomic_types.resize(n_atoms, 0);
     337             :         i = 0;
     338          24 :         for (const auto& atoms: atoms_per_species) {
     339          17 :             i += 1;
     340             : 
     341             :             int32_t type = i;
     342          17 :             if (has_custom_types) {
     343          17 :                 type = species_to_types[static_cast<size_t>(i - 1)];
     344             :             }
     345             : 
     346        1285 :             for (const auto& atom: atoms) {
     347        1268 :                 atomic_types[atom.index()] = type;
     348             :             }
     349             :         }
     350           7 :     } else {
     351             :         n_species = 1;
     352             : 
     353           0 :         int32_t type = 1;
     354           0 :         if (has_custom_types) {
     355           0 :             type = species_to_types[0];
     356             :         }
     357           0 :         atomic_types.resize(all_atoms.size(), type);
     358             :     }
     359             : 
     360           7 :     if (has_custom_types && species_to_types.size() != n_species) {
     361           0 :         this->warning(
     362           0 :             "SPECIES_TO_TYPES contains more entries (" +
     363           0 :             std::to_string(species_to_types.size()) +
     364           0 :             ") than there were species (" + std::to_string(n_species) + ")"
     365             :         );
     366             :     }
     367             : 
     368             :     // request atoms in order
     369             :     all_atoms.clear();
     370        1275 :     for (size_t i=0; i<atomic_types.size(); i++) {
     371        1268 :         all_atoms.push_back(AtomNumber::index(i));
     372             :     }
     373           7 :     this->requestAtoms(all_atoms);
     374             : 
     375          14 :     this->atomic_types_ = torch::tensor(std::move(atomic_types));
     376             : 
     377           7 :     this->check_consistency_ = false;
     378           7 :     this->parseFlag("CHECK_CONSISTENCY", this->check_consistency_);
     379           7 :     if (this->check_consistency_) {
     380           0 :         log.printf("  checking for internal consistency of the model\n");
     381             :     }
     382             : 
     383             :     // create evaluation options for the model. These won't change during the
     384             :     // simulation, so we initialize them once here.
     385           7 :     evaluations_options_ = torch::make_intrusive<metatomic_torch::ModelEvaluationOptionsHolder>();
     386          14 :     evaluations_options_->set_length_unit(getUnits().getLengthString());
     387             : 
     388             :     auto outputs = this->capabilities_->outputs();
     389          14 :     if (!outputs.contains("features")) {
     390             :         auto existing_outputs = std::vector<std::string>();
     391           0 :         for (const auto& it: this->capabilities_->outputs()) {
     392           0 :             existing_outputs.push_back(it.key());
     393             :         }
     394             : 
     395           0 :         this->error(
     396             :             "expected a 'features' output in the capabilities of the model, "
     397           0 :             "could not find it. The following outputs exist: " + torch::str(existing_outputs)
     398             :         );
     399           0 :     }
     400             : 
     401             :     auto output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
     402             :     // this output has no quantity or unit to set
     403             : 
     404           7 :     output->per_atom = this->capabilities_->outputs().at("features")->per_atom;
     405             :     // we are using torch autograd system to compute gradients,
     406             :     // so we don't need any explicit gradients.
     407           7 :     output->explicit_gradients = {};
     408           7 :     evaluations_options_->outputs.insert("features", output);
     409             : 
     410             :     // Determine which device we should use based on user input, what the model
     411             :     // supports and what's available
     412             :     auto available_devices = std::vector<torch::Device>();
     413          22 :     for (const auto& device: this->capabilities_->supported_devices) {
     414          15 :         if (device == "cpu") {
     415           7 :             available_devices.push_back(torch::kCPU);
     416           8 :         } else if (device == "cuda") {
     417           4 :             if (torch::cuda::is_available()) {
     418           0 :                 available_devices.push_back(torch::Device("cuda"));
     419             :             }
     420           4 :         } else if (device == "mps") {
     421             :             #if TORCH_VERSION_MAJOR >= 2
     422           4 :             if (torch::mps::is_available()) {
     423           0 :                 available_devices.push_back(torch::Device("mps"));
     424             :             }
     425             :             #endif
     426             :         } else {
     427           0 :             this->warning(
     428           0 :                 "the model declared support for unknown device '" + device +
     429             :                 "', it will be ignored"
     430             :             );
     431             :         }
     432             :     }
     433             : 
     434           7 :     if (available_devices.empty()) {
     435           0 :         this->error(
     436           0 :             "failed to find a valid device for the model at '" + model_path + "': "
     437           0 :             "the model supports " + torch::str(this->capabilities_->supported_devices) +
     438             :             ", none of these were available"
     439             :         );
     440             :     }
     441             : 
     442             :     std::string requested_device;
     443          14 :     this->parse("DEVICE", requested_device);
     444           7 :     if (requested_device.empty()) {
     445             :         // no user request, pick the device the model prefers
     446           3 :         this->device_ = available_devices[0];
     447             :     } else {
     448             :         bool found_requested_device = false;
     449           4 :         for (const auto& device: available_devices) {
     450           8 :             if (device.is_cpu() && requested_device == "cpu") {
     451           4 :                 this->device_ = device;
     452             :                 found_requested_device = true;
     453             :                 break;
     454           0 :             } else if (device.is_cuda() && requested_device == "cuda") {
     455           0 :                 this->device_ = device;
     456             :                 found_requested_device = true;
     457             :                 break;
     458           0 :             } else if (device.is_mps() && requested_device == "mps") {
     459           0 :                 this->device_ = device;
     460             :                 found_requested_device = true;
     461             :                 break;
     462             :             }
     463             :         }
     464             : 
     465             :         if (!found_requested_device) {
     466           0 :             this->error(
     467           0 :                 "failed to find requested device (" + requested_device + "): it is either "
     468             :                 "not supported by this model or not available on this machine"
     469             :             );
     470             :         }
     471             :     }
     472             : 
     473           7 :     this->model_.to(this->device_);
     474           7 :     this->atomic_types_ = this->atomic_types_.to(this->device_);
     475             : 
     476           7 :     log.printf(
     477             :         "  running model on %s device with %s data\n",
     478          14 :         this->device_.str().c_str(),
     479             :         this->capabilities_->dtype().c_str()
     480             :     );
     481             : 
     482           7 :     if (this->capabilities_->dtype() == "float64") {
     483           3 :         this->dtype_ = torch::kFloat64;
     484           4 :     } else if (this->capabilities_->dtype() == "float32") {
     485           4 :         this->dtype_ = torch::kFloat32;
     486             :     } else {
     487           0 :         this->error(
     488           0 :             "the model requested an unsupported dtype '" + this->capabilities_->dtype() + "'"
     489             :         );
     490             :     }
     491             : 
     492           7 :     auto tensor_options = torch::TensorOptions().dtype(this->dtype_).device(this->device_);
     493          14 :     this->strain_ = torch::eye(3, tensor_options.requires_grad(true));
     494             : 
     495             :     // determine how many properties there will be in the output by running the
     496             :     // model once on a dummy system
     497             :     auto dummy_system = torch::make_intrusive<metatomic_torch::SystemHolder>(
     498          14 :         /*types = */ torch::zeros({0}, tensor_options.dtype(torch::kInt32)),
     499          14 :         /*positions = */ torch::zeros({0, 3}, tensor_options),
     500          14 :         /*cell = */ torch::zeros({3, 3}, tensor_options),
     501           7 :         /*pbc = */ torch::zeros({3}, tensor_options.dtype(torch::kBool))
     502             :     );
     503             : 
     504           7 :     log.printf("  the following neighbor lists have been requested:\n");
     505           7 :     auto length_unit = this->getUnits().getLengthString();
     506           7 :     auto model_length_unit = this->capabilities_->length_unit();
     507          14 :     for (auto request: this->nl_requests_) {
     508          10 :         log.printf("    - %s list, %g %s cutoff (requested %g %s)\n",
     509             :             request->full_list() ? "full" : "half",
     510             :             request->engine_cutoff(length_unit),
     511             :             length_unit.c_str(),
     512             :             request->cutoff(),
     513             :             model_length_unit.c_str()
     514             :         );
     515             : 
     516             :         auto neighbors = this->computeNeighbors(
     517             :             request,
     518             :             {PLMD::Vector(0, 0, 0)},
     519           7 :             PLMD::Tensor(0, 0, 0, 0, 0, 0, 0, 0, 0),
     520             :             false
     521          21 :         );
     522          21 :         metatomic_torch::register_autograd_neighbors(dummy_system, neighbors, this->check_consistency_);
     523          14 :         dummy_system->add_neighbor_list(request, neighbors);
     524             :     }
     525             : 
     526           7 :     this->n_properties_ = static_cast<unsigned>(
     527          14 :         this->executeModel(dummy_system)->properties()->count()
     528             :     );
     529             : 
     530             :     // parse and handle atom sub-selection. This is done AFTER determining the
     531             :     // output size, since the selection might not be valid for the dummy system
     532             :     std::vector<int32_t> selected_atoms;
     533          14 :     this->parseVector("SELECTED_ATOMS", selected_atoms);
     534           7 :     if (!selected_atoms.empty()) {
     535             :         auto selection_value = torch::zeros(
     536             :             {static_cast<int64_t>(selected_atoms.size()), 2},
     537           2 :             torch::TensorOptions().dtype(torch::kInt32).device(this->device_)
     538           2 :         );
     539             : 
     540           9 :         for (unsigned i=0; i<selected_atoms.size(); i++) {
     541           7 :             auto n_atoms = static_cast<int32_t>(this->atomic_types_.size(0));
     542           7 :             if (selected_atoms[i] <= 0 || selected_atoms[i] > n_atoms) {
     543           0 :                 this->error(
     544             :                     "Values in metatomic's SELECTED_ATOMS should be between 1 "
     545           0 :                     "and the number of atoms (" + std::to_string(n_atoms) + "), "
     546           0 :                     "got " + std::to_string(selected_atoms[i]));
     547             :             }
     548             :             // PLUMED input uses 1-based indexes, but metatomic wants 0-based
     549          14 :             selection_value[i][1] = selected_atoms[i] - 1;
     550             :         }
     551             : 
     552           4 :         evaluations_options_->set_selected_atoms(
     553           2 :             torch::make_intrusive<metatensor_torch::LabelsHolder>(
     554           8 :                 std::vector<std::string>{"system", "atom"}, selection_value
     555             :             )
     556             :         );
     557             :     }
     558             : 
     559             :     // Now that we know both n_samples and n_properties, we can set up the
     560             :     // PLUMED-side storage for the computed CV
     561           7 :     if (output->per_atom) {
     562           4 :         if (selected_atoms.empty()) {
     563           2 :             this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
     564             :         } else {
     565           2 :             this->n_samples_ = static_cast<unsigned>(selected_atoms.size());
     566             :         }
     567             :     } else {
     568           3 :         this->n_samples_ = 1;
     569             :     }
     570             : 
     571           7 :     if (n_samples_ == 1 && n_properties_ == 1) {
     572           2 :         log.printf("  the output of this model is a scalar\n");
     573             : 
     574           4 :         this->addValue();
     575           5 :     } else if (n_samples_ == 1) {
     576           1 :         log.printf("  the output of this model is a 1x%d vector\n", n_properties_);
     577             : 
     578           2 :         this->addValue({this->n_properties_});
     579           4 :     } else if (n_properties_ == 1) {
     580           1 :         log.printf("  the output of this model is a %dx1 vector\n", n_samples_);
     581             : 
     582           2 :         this->addValue({this->n_samples_});
     583             :     } else {
     584           3 :         log.printf("  the output of this model is a %dx%d matrix\n", n_samples_, n_properties_);
     585             : 
     586           3 :         this->addValue({this->n_samples_, this->n_properties_});
     587           3 :         this->getPntrToComponent(0)->reshapeMatrixStore(n_properties_);
     588             :     }
     589             : 
     590           7 :     this->setNotPeriodic();
     591           7 : }
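The constructor steps above can be sketched without libtorch or PLUMED. These steps validate SELECTED_ATOMS and convert the indices to 0-based, count the samples, and choose between a scalar, vector or matrix output. This is a minimal standard-library sketch under stated assumptions. `to_zero_based`, `count_samples`, `OutputShape` and `classify` are hypothetical names that do not exist in the module, and a thrown `std::out_of_range` stands in for PLUMED's `error()`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Converts 1-based PLUMED atom indices to the 0-based indices metatomic
// expects, rejecting values outside [1, n_atoms].
std::vector<int32_t> to_zero_based(const std::vector<int32_t>& selected, int32_t n_atoms) {
    std::vector<int32_t> result;
    result.reserve(selected.size());
    for (int32_t index : selected) {
        if (index <= 0 || index > n_atoms) {
            throw std::out_of_range(
                "values in SELECTED_ATOMS should be between 1 and the number of atoms (" +
                std::to_string(n_atoms) + "), got " + std::to_string(index));
        }
        result.push_back(index - 1);
    }
    return result;
}

// Per-atom outputs have one sample per atom (or per selected atom);
// global outputs have a single sample.
unsigned count_samples(bool per_atom, unsigned n_atoms, std::size_t n_selected) {
    if (!per_atom) {
        return 1;
    }
    return n_selected == 0 ? n_atoms : static_cast<unsigned>(n_selected);
}

// Hypothetical names for the four cases distinguished when creating the value.
enum class OutputShape { Scalar, RowVector, ColumnVector, Matrix };

OutputShape classify(unsigned n_samples, unsigned n_properties) {
    assert(n_samples > 0 && n_properties > 0);
    if (n_samples == 1 && n_properties == 1) return OutputShape::Scalar;
    if (n_samples == 1) return OutputShape::RowVector;        // 1 x n_properties
    if (n_properties == 1) return OutputShape::ColumnVector;  // n_samples x 1
    return OutputShape::Matrix;                               // n_samples x n_properties
}
```

The shape decision is kept in a pure function. This makes the four-way branch testable independently of any model.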
     592             : 
     593          12 : unsigned MetatomicPlumedAction::getNumberOfDerivatives() {
     594             :     // gradients w.r.t. positions (3 x N values) + gradients w.r.t. strain (9 values)
     595          12 :     return 3 * this->getNumberOfAtoms() + 9;
     596             : }
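getNumberOfDerivatives() reserves 3 values per atom plus 9 values for the box. The helpers below assume the layout that apply() fills: position derivatives first, then the 3x3 virial in row-major order. Under that assumption, the index arithmetic can be written as follows. The helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Total size: x, y, z for every atom, then 9 virial entries.
std::size_t number_of_derivatives(std::size_t n_atoms) {
    return 3 * n_atoms + 9;
}

// Index of d(CV)/d(position[atom][k]), with k = 0, 1, 2 for x, y, z.
std::size_t position_derivative_index(std::size_t atom, std::size_t k) {
    assert(k < 3);
    return 3 * atom + k;
}

// Index of virial component (a, b), stored row-major after the positions.
std::size_t virial_index(std::size_t n_atoms, std::size_t a, std::size_t b) {
    assert(a < 3 && b < 3);
    return 3 * n_atoms + 3 * a + b;
}
```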
     597             : 
     598             : 
     599           8 : void MetatomicPlumedAction::createSystem() {
     600          16 :     if (this->getTotAtoms() != static_cast<unsigned>(this->atomic_types_.size(0))) {
     601           0 :         std::ostringstream oss;
     602           0 :         oss << "METATOMIC action needs to know about all atoms in the system. ";
     603           0 :         oss << "There are " << this->getTotAtoms() << " atoms overall, ";
     604           0 :         oss << "but we only have atomic types for " << this->atomic_types_.size(0) << " of them.";
     605           0 :         plumed_merror(oss.str());
     606           0 :     }
     607             : 
     608             :     // this->getTotAtoms()
     609             : 
     610           8 :     const auto& cell = this->getPbc().getBox();
     611             : 
     612           8 :     auto cpu_f64_tensor = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
     613           8 :     auto torch_cell = torch::zeros({3, 3}, cpu_f64_tensor);
     614             : 
     615          16 :     torch_cell[0][0] = cell(0, 0);
     616          16 :     torch_cell[0][1] = cell(0, 1);
     617          16 :     torch_cell[0][2] = cell(0, 2);
     618             : 
     619          16 :     torch_cell[1][0] = cell(1, 0);
     620          16 :     torch_cell[1][1] = cell(1, 1);
     621          16 :     torch_cell[1][2] = cell(1, 2);
     622             : 
     623          16 :     torch_cell[2][0] = cell(2, 0);
     624          16 :     torch_cell[2][1] = cell(2, 1);
     625          16 :     torch_cell[2][2] = cell(2, 2);
     626             : 
     627             :     using torch::indexing::Slice;
     628             : 
     629          32 :     auto pbc_a = torch_cell.index({0, Slice()}).norm().abs().item<double>() > 1e-9;
     630          32 :     auto pbc_b = torch_cell.index({1, Slice()}).norm().abs().item<double>() > 1e-9;
     631          32 :     auto pbc_c = torch_cell.index({2, Slice()}).norm().abs().item<double>() > 1e-9;
     632          48 :     auto torch_pbc = torch::tensor({pbc_a, pbc_b, pbc_c});
     633             : 
     634             :     const auto& positions = this->getPositions();
     635             : 
     636             :     auto torch_positions = torch::from_blob(
     637             :         const_cast<PLMD::Vector*>(positions.data()),
     638             :         {static_cast<int64_t>(positions.size()), 3},
     639             :         cpu_f64_tensor
     640           8 :     );
     641             : 
     642          16 :     torch_positions = torch_positions.to(this->dtype_).to(this->device_);
     643          16 :     torch_cell = torch_cell.to(this->dtype_).to(this->device_);
     644           8 :     torch_pbc = torch_pbc.to(this->device_);
     645             : 
     646             :     // setup torch's automatic gradient tracking
     647           8 :     if (!this->doNotCalculateDerivatives()) {
     648             :         torch_positions.requires_grad_(true);
     649             : 
     650             :         // pretend to scale positions/cell by the strain so that it enters the
     651             :         // computational graph.
     652           7 :         torch_positions = torch_positions.matmul(this->strain_);
     653           7 :         torch_positions.retain_grad();
     654             : 
     655           7 :         torch_cell = torch_cell.matmul(this->strain_);
     656             :     }
     657             : 
     658           8 :     this->system_ = torch::make_intrusive<metatomic_torch::SystemHolder>(
     659           8 :         this->atomic_types_,
     660             :         torch_positions,
     661             :         torch_cell,
     662             :         torch_pbc
     663             :     );
     664             : 
     665           8 :     auto periodic = torch::all(torch_pbc).item<bool>();
     666           8 :     if (!periodic && torch::any(torch_pbc).item<bool>()) {
     667             :         std::string periodic_directions;
     668             :         std::string non_periodic_directions;
     669           0 :         if (pbc_a) {
     670             :             periodic_directions += "A";
     671             :         } else {
     672             :             non_periodic_directions += "A";
     673             :         }
     674             : 
     675           0 :         if (pbc_b) {
     676             :             periodic_directions += "B";
     677             :         } else {
     678             :             non_periodic_directions += "B";
     679             :         }
     680             : 
     681           0 :         if (pbc_c) {
     682             :             periodic_directions += "C";
     683             :         } else {
     684             :             non_periodic_directions += "C";
     685             :         }
     686             : 
     687           0 :         plumed_merror(
     688             :             "mixed periodic boundary conditions are not supported, this system "
     689             :             "is periodic along the " + periodic_directions + " cell vector(s), "
     690             :             "but not along the " + non_periodic_directions + " cell vector(s)."
     691             :         );
     692             :     }
     693             : 
     694             :     // compute the neighbors list requested by the model, and register them with
     695             :     // the system
     696          16 :     for (auto request: this->nl_requests_) {
     697           8 :         auto neighbors = this->computeNeighbors(request, positions, cell, periodic);
     698          24 :         metatomic_torch::register_autograd_neighbors(this->system_, neighbors, this->check_consistency_);
     699          16 :         this->system_->add_neighbor_list(request, neighbors);
     700             :     }
     701           8 : }
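createSystem() treats a cell direction as periodic when the matching cell vector has a norm above 1e-9. It also rejects mixed boundary conditions. A standard-library sketch of that decision follows. `Cell`, `periodic_directions` and `check_pbc` are hypothetical names. As in the code above, the rows of the cell are assumed to be the A, B and C vectors, and an exception replaces `plumed_merror`.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <stdexcept>
#include <string>

// Rows are the cell vectors A, B and C.
using Cell = std::array<std::array<double, 3>, 3>;

// A direction is periodic when its cell vector has a non-negligible norm.
std::array<bool, 3> periodic_directions(const Cell& cell) {
    std::array<bool, 3> pbc{};
    for (int i = 0; i < 3; i++) {
        double norm = std::sqrt(cell[i][0] * cell[i][0] +
                                cell[i][1] * cell[i][1] +
                                cell[i][2] * cell[i][2]);
        pbc[i] = norm > 1e-9;
    }
    return pbc;
}

// true: fully periodic, false: fully non-periodic, throws on a mix of both.
bool check_pbc(const std::array<bool, 3>& pbc) {
    const char names[3] = {'A', 'B', 'C'};
    std::string periodic;
    std::string non_periodic;
    for (int i = 0; i < 3; i++) {
        (pbc[i] ? periodic : non_periodic) += names[i];
    }
    if (non_periodic.empty()) return true;
    if (periodic.empty()) return false;
    throw std::runtime_error(
        "mixed periodic boundary conditions are not supported, this system "
        "is periodic along the " + periodic + " cell vector(s), "
        "but not along the " + non_periodic + " cell vector(s).");
}
```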
     702             : 
     703             : 
     704          15 : metatensor_torch::TensorBlock MetatomicPlumedAction::computeNeighbors(
     705             :     metatomic_torch::NeighborListOptions request,
     706             :     const std::vector<PLMD::Vector>& positions,
     707             :     const PLMD::Tensor& cell,
     708             :     bool periodic
     709             : ) {
     710          15 :     auto labels_options = torch::TensorOptions().dtype(torch::kInt32).device(this->device_);
     711             :     auto neighbor_component = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     712             :         "xyz",
     713          75 :         torch::tensor({0, 1, 2}, labels_options).reshape({3, 1})
     714             :     );
     715             :     auto neighbor_properties = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     716          30 :         "distance", torch::zeros({1, 1}, labels_options)
     717             :     );
     718             : 
     719          15 :     auto cutoff = request->engine_cutoff(this->getUnits().getLengthString());
     720             : 
     721             :     // use https://github.com/Luthaf/vesin to compute the requested neighbor
     722             :     // lists since we can not get these from PLUMED
     723             :     vesin::VesinOptions options;
     724          15 :     options.cutoff = cutoff;
     725          15 :     options.full = request->full_list();
     726          15 :     options.return_shifts = true;
     727          15 :     options.return_distances = false;
     728          15 :     options.return_vectors = true;
     729             : 
     730          15 :     vesin::VesinNeighborList* vesin_neighbor_list = new vesin::VesinNeighborList();
     731             :     memset(vesin_neighbor_list, 0, sizeof(vesin::VesinNeighborList));
     732             : 
     733          15 :     const char* error_message = NULL;
     734          15 :     int status = vesin_neighbors(
     735             :         reinterpret_cast<const double (*)[3]>(positions.data()),
     736             :         positions.size(),
     737             :         reinterpret_cast<const double (*)[3]>(&cell(0, 0)),
     738             :         periodic,
     739             :         vesin::VesinCPU,
     740             :         options,
     741             :         vesin_neighbor_list,
     742             :         &error_message
     743             :     );
     744             : 
     745          15 :     if (status != EXIT_SUCCESS) {
     746           0 :         plumed_merror(
     747             :             "failed to compute neighbor list (cutoff=" + std::to_string(cutoff) +
     748             :             ", full=" + (request->full_list() ? "true" : "false") + "): " + error_message
     749             :         );
     750             :     }
     751             : 
     752             :     // transform from vesin to metatomic format
     753          15 :     auto n_pairs = static_cast<int64_t>(vesin_neighbor_list->length);
     754             : 
     755             :     auto pair_vectors = torch::from_blob(
     756          15 :         vesin_neighbor_list->vectors,
     757             :         {n_pairs, 3, 1},
     758          15 :         /*deleter*/ [=](void*) {
     759          15 :             vesin_free(vesin_neighbor_list);
     760          15 :             delete vesin_neighbor_list;
     761          15 :         },
     762          15 :         torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU)
     763          15 :     );
     764             : 
     765          15 :     auto pair_samples_values = torch::empty({n_pairs, 5}, labels_options.device(torch::kCPU));
     766          15 :     auto pair_samples_values_ptr = pair_samples_values.accessor<int32_t, 2>();
     767       15189 :     for (unsigned i=0; i<n_pairs; i++) {
     768       15174 :         pair_samples_values_ptr[i][0] = static_cast<int32_t>(vesin_neighbor_list->pairs[i][0]);
     769       15174 :         pair_samples_values_ptr[i][1] = static_cast<int32_t>(vesin_neighbor_list->pairs[i][1]);
     770       15174 :         pair_samples_values_ptr[i][2] = vesin_neighbor_list->shifts[i][0];
     771       15174 :         pair_samples_values_ptr[i][3] = vesin_neighbor_list->shifts[i][1];
     772       15174 :         pair_samples_values_ptr[i][4] = vesin_neighbor_list->shifts[i][2];
     773             :     }
     774             : 
     775             :     auto neighbor_samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     776         120 :         std::vector<std::string>{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"},
     777          30 :         pair_samples_values.to(this->device_)
     778             :     );
     779             : 
     780             :     auto neighbors = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
     781          30 :         pair_vectors.to(this->dtype_).to(this->device_),
     782             :         neighbor_samples,
     783          60 :         std::vector<metatensor_torch::Labels>{neighbor_component},
     784             :         neighbor_properties
     785             :     );
     786             : 
     787          15 :     return neighbors;
     788             : }
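computeNeighbors() copies each vesin pair into a 5-column sample row: first_atom, second_atom, cell_shift_a, cell_shift_b, cell_shift_c. The sketch below reproduces only that repacking step. `PairList` and `to_sample_rows` are hypothetical. `PairList` is a plain stand-in for the vesin output arrays, not vesin's actual `VesinNeighborList` type.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the neighbor list output: for each pair,
// the two atom indices and the integer cell shift.
struct PairList {
    std::vector<std::array<std::size_t, 2>> pairs;
    std::vector<std::array<int32_t, 3>> shifts;
};

// One row per pair, with the same column order as the samples above.
std::vector<std::array<int32_t, 5>> to_sample_rows(const PairList& list) {
    assert(list.pairs.size() == list.shifts.size());
    std::vector<std::array<int32_t, 5>> rows;
    rows.reserve(list.pairs.size());
    for (std::size_t i = 0; i < list.pairs.size(); i++) {
        rows.push_back({
            static_cast<int32_t>(list.pairs[i][0]),  // first_atom
            static_cast<int32_t>(list.pairs[i][1]),  // second_atom
            list.shifts[i][0],                       // cell_shift_a
            list.shifts[i][1],                       // cell_shift_b
            list.shifts[i][2],                       // cell_shift_c
        });
    }
    return rows;
}
```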
     789             : 
     790          15 : metatensor_torch::TensorBlock MetatomicPlumedAction::executeModel(metatomic_torch::System system) {
     791             :     try {
     792          45 :         auto ivalue_output = this->model_.forward({
     793          60 :             std::vector<metatomic_torch::System>{system},
     794             :             evaluations_options_,
     795          15 :             this->check_consistency_,
     796          75 :         });
     797             : 
     798          15 :         auto dict_output = ivalue_output.toGenericDict();
     799          15 :         auto cv = dict_output.at("features");
     800          30 :         this->output_ = cv.toCustomClass<metatensor_torch::TensorMapHolder>();
     801           0 :     } catch (const std::exception& e) {
     802           0 :         plumed_merror("failed to evaluate the model: " + std::string(e.what()));
     803           0 :     }
     804             : 
     805          30 :     plumed_massert(this->output_->keys()->count() == 1, "output should have a single block");
     806          30 :     auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
     807          15 :     plumed_massert(block->components().empty(), "components are not yet supported in the output");
     808             : 
     809          15 :     return block;
     810             : }
     811             : 
     812             : 
     813           8 : void MetatomicPlumedAction::calculate() {
     814           8 :     this->createSystem();
     815             : 
     816          16 :     auto block = this->executeModel(this->system_);
     817          24 :     auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
     818             : 
     819           8 :     if (static_cast<unsigned>(torch_values.size(0)) != this->n_samples_) {
     820           0 :         plumed_merror(
     821             :             "expected the model to return a TensorBlock with " +
     822             :             std::to_string(this->n_samples_) + " samples, got " +
     823             :             std::to_string(torch_values.size(0)) + " instead"
     824             :         );
     825           8 :     } else if (static_cast<unsigned>(torch_values.size(1)) != this->n_properties_) {
     826           0 :         plumed_merror(
     827             :             "expected the model to return a TensorBlock with " +
     828             :             std::to_string(this->n_properties_) + " properties, got " +
     829             :             std::to_string(torch_values.size(1)) + " instead"
     830             :         );
     831             :     }
     832             : 
     833           8 :     Value* value = this->getPntrToComponent(0);
     834             :     // reshape the plumed `Value` to hold the data returned by the model
     835           8 :     if (n_samples_ == 1) {
     836           4 :         if (n_properties_ == 1) {
     837           3 :             value->set(torch_values.item<double>());
     838             :         } else {
     839             :             // we have multiple CVs describing a single thing (atom or full system)
     840           3 :             for (unsigned i=0; i<n_properties_; i++) {
     841           2 :                 value->set(i, torch_values[0][i].item<double>());
     842             :             }
     843             :         }
     844             :     } else {
     845             :         auto samples = block->samples();
     846          16 :         plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
     847             : 
     848           4 :         auto samples_values = samples->values().to(torch::kCPU);
     849             :         auto selected_atoms = this->evaluations_options_->get_selected_atoms();
     850             : 
     851             :         // handle the possibility that samples are returned in
     852             :         // a non-sorted order.
     853          92 :         auto get_output_location = [&](unsigned i) {
     854          92 :             if (selected_atoms.has_value()) {
     855             :                 // If the users picked some selected atoms, then we store the
     856             :                 // output in the same order as the selection was given
     857          28 :                 auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
     858           7 :                 auto position = selected_atoms.value()->position(sample);
     859           7 :                 plumed_assert(position.has_value());
     860           7 :                 return static_cast<unsigned>(position.value());
     861             :             } else {
     862         170 :                 return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
     863             :             }
     864           4 :         };
     865             : 
     866           4 :         if (n_properties_ == 1) {
     867             :             // we have a single CV describing multiple things (i.e. atoms)
     868           5 :             for (unsigned i=0; i<n_samples_; i++) {
     869           4 :                 auto output_i = get_output_location(i);
     870           8 :                 value->set(output_i, torch_values[i][0].item<double>());
     871             :             }
     872             :         } else {
     873             :             // the CV is a matrix
     874          91 :             for (unsigned i=0; i<n_samples_; i++) {
     875          88 :                 auto output_i = get_output_location(i);
     876         343 :                 for (unsigned j=0; j<n_properties_; j++) {
     877         255 :                     value->set(output_i * n_properties_ + j, torch_values[i][j].item<double>());
     878             :                 }
     879             :             }
     880             :         }
     881             :     }
     882           8 : }
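calculate() and apply() use get_output_location to map each returned sample to a storage row. Entry (row, j) is then stored at `row * n_properties + j`. The sketch below shows that mapping for unsorted samples, with and without an atom selection. It is simplified under stated assumptions: samples are identified by their atom index only, whereas the real code matches full (system, atom) labels. `output_row` and `scatter` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Row where the sample describing `atom` is stored. Without a selection,
// rows follow atom indices; with a selection, rows follow the order in
// which atoms were listed (0-based here).
std::size_t output_row(int32_t atom, const std::optional<std::vector<int32_t>>& selection) {
    if (!selection.has_value()) {
        return static_cast<std::size_t>(atom);
    }
    for (std::size_t k = 0; k < selection->size(); k++) {
        if ((*selection)[k] == atom) {
            return k;
        }
    }
    throw std::out_of_range("sample is not part of the atom selection");
}

// `values` holds n_samples x n_properties entries (row-major), and sample i
// describes atom sample_atoms[i], in whatever order the model returned them.
std::vector<double> scatter(const std::vector<double>& values,
                            const std::vector<int32_t>& sample_atoms,
                            std::size_t n_properties,
                            const std::optional<std::vector<int32_t>>& selection) {
    const std::size_t n_samples = sample_atoms.size();
    assert(values.size() == n_samples * n_properties);

    std::vector<double> storage(n_samples * n_properties, 0.0);
    for (std::size_t i = 0; i < n_samples; i++) {
        const std::size_t row = output_row(sample_atoms[i], selection);
        if (row >= n_samples) {
            throw std::out_of_range("output row is larger than the number of samples");
        }
        for (std::size_t j = 0; j < n_properties; j++) {
            storage[row * n_properties + j] = values[i * n_properties + j];
        }
    }
    return storage;
}
```

apply() uses the same mapping in reverse. It reads the force of entry `row * n_properties + j` back into position (i, j) of the gradient passed to `backward`.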
     883             : 
     884             : 
     885           8 : void MetatomicPlumedAction::apply() {
     886           8 :     const auto* value = this->getPntrToComponent(0);
     887           8 :     if (!value->forcesWereAdded()) {
     888           1 :         return;
     889             :     }
     890             : 
     891          14 :     auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
     892          21 :     auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
     893             : 
     894           7 :     if (!torch_values.requires_grad()) {
     895           0 :         this->warning(
     896             :             "the output of the model does not requires gradients, this might "
     897             :             "indicate a problem"
     898             :         );
     899             :         return;
     900             :     }
     901             : 
     902           7 :     auto output_grad = torch::zeros_like(torch_values);
     903           7 :     if (n_samples_ == 1) {
     904           4 :         if (n_properties_ == 1) {
     905           3 :             output_grad[0][0] = value->getForce();
     906             :         } else {
     907           3 :             for (unsigned i=0; i<n_properties_; i++) {
     908           4 :                 output_grad[0][i] = value->getForce(i);
     909             :             }
     910             :         }
     911             :     } else {
     912             :         auto samples = block->samples();
     913          12 :         plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
     914             : 
     915           3 :         auto samples_values = samples->values().to(torch::kCPU);
     916             :         auto selected_atoms = this->evaluations_options_->get_selected_atoms();
     917             : 
     918             :         // see above for an explanation of why we use this function
     919          89 :         auto get_output_location = [&](unsigned i) {
     920          89 :             if (selected_atoms.has_value()) {
     921          16 :                 auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
     922           4 :                 auto position = selected_atoms.value()->position(sample);
     923           4 :                 plumed_assert(position.has_value());
     924           4 :                 return static_cast<unsigned>(position.value());
     925             :             } else {
     926         170 :                 return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
     927             :             }
     928           3 :         };
     929             : 
     930           3 :         if (n_properties_ == 1) {
     931           5 :             for (unsigned i=0; i<n_samples_; i++) {
     932           4 :                 auto output_i = get_output_location(i);
     933           8 :                 output_grad[i][0] = value->getForce(output_i);
     934             :             }
     935             :         } else {
     936          87 :             for (unsigned i=0; i<n_samples_; i++) {
     937          85 :                 auto output_i = get_output_location(i);
     938         331 :                 for (unsigned j=0; j<n_properties_; j++) {
     939         492 :                     output_grad[i][j] = value->getForce(output_i * n_properties_ + j);
     940             :                 }
     941             :             }
     942             :         }
     943             :     }
     944             : 
     945           7 :     this->system_->positions().mutable_grad() = torch::Tensor();
     946           7 :     this->strain_.mutable_grad() = torch::Tensor();
     947             : 
     948           7 :     torch_values.backward(output_grad);
     949           7 :     auto positions_grad = this->system_->positions().grad();
     950           7 :     auto strain_grad = this->strain_.grad();
     951             : 
     952          21 :     positions_grad = positions_grad.to(torch::kCPU).to(torch::kFloat64);
     953          21 :     strain_grad = strain_grad.to(torch::kCPU).to(torch::kFloat64);
     954             : 
     955           7 :     plumed_assert(positions_grad.sizes().size() == 2);
     956           7 :     plumed_assert(positions_grad.is_contiguous());
     957             : 
     958           7 :     plumed_assert(strain_grad.sizes().size() == 2);
     959           7 :     plumed_assert(strain_grad.is_contiguous());
     960             : 
     961             :     auto derivatives = std::vector<double>(
     962             :         positions_grad.data_ptr<double>(),
     963           7 :         positions_grad.data_ptr<double>() + 3 * this->system_->size()
     964           7 :     );
     965             : 
     966             :     // add virials to the derivatives
     967           7 :     derivatives.push_back(-strain_grad[0][0].item<double>());
     968           7 :     derivatives.push_back(-strain_grad[0][1].item<double>());
     969           7 :     derivatives.push_back(-strain_grad[0][2].item<double>());
     970             : 
     971           7 :     derivatives.push_back(-strain_grad[1][0].item<double>());
     972           7 :     derivatives.push_back(-strain_grad[1][1].item<double>());
     973           7 :     derivatives.push_back(-strain_grad[1][2].item<double>());
     974             : 
     975           7 :     derivatives.push_back(-strain_grad[2][0].item<double>());
     976           7 :     derivatives.push_back(-strain_grad[2][1].item<double>());
     977           7 :     derivatives.push_back(-strain_grad[2][2].item<double>());
     978             : 
     979           7 :     unsigned index = 0;
     980           7 :     this->setForcesOnAtoms(derivatives, index);
     981             : }
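apply() takes the virial part of the derivatives from the gradient with respect to `strain_`, which multiplies the positions and the cell in createSystem(). This sketch assumes `strain_` holds the identity matrix, as the "pretend to scale" comment suggests. It checks the chain-rule identity behind the trick with plain C++ and finite differences: at ε = identity, ∂F(Rε)/∂ε_ab = Σ_i r_ia ∂F/∂r_ib. Only the position contribution is shown; the cell term that the module also includes is left out. The function names and the toy CV are hypothetical.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

using Vec3 = std::array<double, 3>;
using Mat3 = std::array<Vec3, 3>;

// Chain rule at strain = identity, for row-vector positions strained as
// r' = r * eps:  dF/d(eps_ab) = sum_i r_ia * g_ib,  with g_i = dF/dr_i.
Mat3 strain_gradient(const std::vector<Vec3>& positions, const std::vector<Vec3>& position_grad) {
    assert(positions.size() == position_grad.size());
    Mat3 result{};
    for (std::size_t i = 0; i < positions.size(); i++) {
        for (int a = 0; a < 3; a++) {
            for (int b = 0; b < 3; b++) {
                result[a][b] += positions[i][a] * position_grad[i][b];
            }
        }
    }
    return result;
}

// Central finite differences of F(R * eps), eps = identity + delta * e_a e_b^T.
Mat3 strain_gradient_fd(const std::function<double(const std::vector<Vec3>&)>& f,
                        const std::vector<Vec3>& positions, double h = 1e-6) {
    Mat3 result{};
    for (int a = 0; a < 3; a++) {
        for (int b = 0; b < 3; b++) {
            auto strained = [&](double delta) {
                std::vector<Vec3> moved = positions;
                for (std::size_t i = 0; i < positions.size(); i++) {
                    // (r * eps)_b = r_b + delta * r_a, other components unchanged
                    moved[i][b] += delta * positions[i][a];
                }
                return f(moved);
            };
            result[a][b] = (strained(h) - strained(-h)) / (2.0 * h);
        }
    }
    return result;
}

double max_abs_difference(const Mat3& x, const Mat3& y) {
    double max = 0.0;
    for (int a = 0; a < 3; a++) {
        for (int b = 0; b < 3; b++) {
            max = std::fmax(max, std::fabs(x[a][b] - y[a][b]));
        }
    }
    return max;
}
```

In apply(), the autograd version of this matrix also includes the cell term. Its entries are negated and appended, row-major, after the position derivatives.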
     982             : 
     983             : } // namespace metatomic
     984             : } // namespace PLMD
     985             : 
     986             : 
     987             : #endif
     988             : 
     989             : 
     990             : namespace PLMD {
     991             : namespace metatomic {
     992             : 
     993             : // use the same implementation for both the actual action and the dummy one
     994             : // (when libtorch and libmetatomic could not be found).
     995           9 : void MetatomicPlumedAction::registerKeywords(Keywords& keys) {
     996           9 :     Action::registerKeywords(keys);
     997           9 :     ActionAtomistic::registerKeywords(keys);
     998           9 :     ActionWithValue::registerKeywords(keys);
     999             : 
    1000           9 :     keys.add("compulsory", "MODEL", "path to the exported metatomic model");
    1001           9 :     keys.add("optional", "EXTENSIONS_DIRECTORY", "path to the directory containing TorchScript extensions to load");
    1002           9 :     keys.add("optional", "DEVICE", "Torch device to use for the calculations");
    1003             : 
    1004           9 :     keys.addFlag("CHECK_CONSISTENCY", false, "should we enable internal consistency checks when executing the model");
    1005             : 
    1006           9 :     keys.add("numbered", "SPECIES", "the indices of atoms in each PLUMED species");
    1007          18 :     keys.reset_style("SPECIES", "atoms");
    1008             : 
    1009           9 :     keys.add("optional", "SELECTED_ATOMS", "subset of atoms that should be used for the calculation");
    1010          18 :     keys.reset_style("SELECTED_ATOMS", "atoms");
    1011             : 
    1012           9 :     keys.add("optional", "SPECIES_TO_TYPES", "mapping from PLUMED SPECIES to metatomic's atom types");
    1013             : 
    1014          18 :     keys.addOutputComponent("outputs", "default", "scalar/vector/matrix", "collective variable created by the metatomic model");
    1015          18 :     keys.setValueDescription("scalar/vector/matrix", "collective variable created by the metatomic model");
    1016           9 : }
    1017             : 
    1018             : PLUMED_REGISTER_ACTION(MetatomicPlumedAction, "METATOMIC")
    1019             : 
    1020             : } // namespace metatomic
    1021             : } // namespace PLMD

Generated by: LCOV version 1.16