LCOV - code coverage report
Current view: top level - metatomic - metatomic.cpp
Test: plumed test coverage
Date: 2026-06-05 17:04:24

             Hit   Total   Coverage
Lines:       323     397     81.4 %
Functions:    12      15     80.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2024 The METATOMIC-PLUMED team
       3             : (see the PEOPLE-METATOMIC file at the root of this folder for a list of names)
       4             : 
       5             : See https://docs.metatensor.org/metatomic/ for more information about the
       6             : metatomic package that this module allows you to call from PLUMED.
       7             : 
       8             : This file is part of METATOMIC-PLUMED module.
       9             : 
      10             : The METATOMIC-PLUMED module is free software: you can redistribute it and/or modify
      11             : it under the terms of the GNU Lesser General Public License as published by
      12             : the Free Software Foundation, either version 3 of the License, or
      13             : (at your option) any later version.
      14             : 
      15             : The METATOMIC-PLUMED module is distributed in the hope that it will be useful,
      16             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      17             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      18             : GNU Lesser General Public License for more details.
      19             : 
      20             : You should have received a copy of the GNU Lesser General Public License
      21             : along with the METATOMIC-PLUMED module. If not, see <http://www.gnu.org/licenses/>.
      22             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      23             : 
      24             : #include "core/ActionAtomistic.h"
      25             : #include "core/ActionWithValue.h"
      26             : #include "core/ActionRegister.h"
      27             : #include "core/PlumedMain.h"
      28             : 
      29             : //+PLUMEDOC METATOMICMOD_COLVAR METATOMIC
      30             : /*
      31             : Use arbitrary machine learning models as collective variables.
      32             : 
      33             : \note This action requires the metatomic-torch library. Check the
      34             : instructions in the \ref METATOMICMOD page to enable this module.
      35             : 
      36             : This action enables the use of fully custom machine learning models — based on
      37             : the [metatomic] models interface — as collective variables in PLUMED. Such
      38             : machine learning models are typically written and customized using Python code,
      39             : and then exported to run within PLUMED as [TorchScript], which is a subset of
      40             : Python that can be executed by the C++ torch library.
      41             : 
      42             : Metatomic offers a way to define such models and pass data from PLUMED (or any
      43             : other simulation engine) to the model and back. For more information on how to
      44             : define such models, have a look at the [corresponding tutorials][mta_tutorials],
      45             : or at the code in `regtest/metatomic/`. Each of the Python scripts in this
      46             : directory defines a custom machine learning CV that can be used with PLUMED.
      47             : 
      48             : \par Examples
      49             : 
      50             : The following input shows how to call metatomic from PLUMED and evaluate the
      51             : model stored in the file `custom_cv.pt`.
      52             : 
      53             : \plumedfile metatomic_cv: METATOMIC ... MODEL=custom_cv.pt
      54             : 
      55             :     SPECIES1=1-26
      56             :     SPECIES2=27-62
      57             :     SPECIES3=63-76
      58             :     SPECIES_TO_TYPES=6,1,8
      59             : ...
      60             : \endplumedfile
      61             : 
      62             : The numbered `SPECIES` keywords list the atoms that belong to each atomic
      63             : species in the system. The `SPECIES_TO_TYPES` keyword then gives the atomic
      64             : type of each species: the first number is the atomic type (typically the
      65             : atomic number) of the atoms specified with `SPECIES1`, the second number is
      66             : the atomic type of the atoms specified with `SPECIES2`, and so on. Without
      67             : `SPECIES_TO_TYPES`, the atoms in `SPECIESn` are given type `n`.
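The type assignment described in this paragraph can be sketched in plain C++ as follows. This is a minimal illustration, not PLUMED's code: `build_atomic_types` is a hypothetical helper that mirrors the logic of the `MetatomicPlumedAction` constructor further down, and it assumes the `SPECIESn` atom lists have already been converted to 0-based indices (as `AtomNumber::index()` does in PLUMED).

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of PLUMED or metatomic): assigns one atomic
// type per atom from 0-based SPECIESn atom lists. Species n (1-based) gets
// type n, or species_to_types[n - 1] when SPECIES_TO_TYPES was given.
std::vector<int32_t> build_atomic_types(
    const std::vector<std::vector<std::size_t>>& atoms_per_species,
    const std::vector<int32_t>& species_to_types
) {
    const bool has_custom_types = !species_to_types.empty();
    if (has_custom_types && species_to_types.size() < atoms_per_species.size()) {
        throw std::invalid_argument("SPECIES_TO_TYPES needs one entry per species");
    }

    // the total number of atoms is the sum of the sizes of all species
    std::size_t n_atoms = 0;
    for (const auto& atoms : atoms_per_species) {
        n_atoms += atoms.size();
    }

    std::vector<int32_t> atomic_types(n_atoms, 0);
    for (std::size_t s = 0; s < atoms_per_species.size(); s++) {
        const int32_t type = has_custom_types
            ? species_to_types[s]
            : static_cast<int32_t>(s + 1);
        for (std::size_t atom : atoms_per_species[s]) {
            // .at() throws std::out_of_range if an index falls outside [0, n_atoms)
            atomic_types.at(atom) = type;
        }
    }
    return atomic_types;
}
```

For instance, a reduced input with `SPECIES1=1-2`, `SPECIES2=3` and `SPECIES_TO_TYPES=6,1` corresponds to `build_atomic_types({{0, 1}, {2}}, {6, 1})`, which returns `{6, 6, 1}`; dropping `SPECIES_TO_TYPES` gives `{1, 1, 2}` instead.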
      68             : 
      69             : The `METATOMIC` action also accepts the following options:
      70             : 
      71             : - `EXTENSIONS_DIRECTORY` should be the path to a directory containing
      72             :   TorchScript extensions (as shared libraries) that are required to load and
      73             :   execute the model. This matches the `collect_extensions` argument to
      74             :   `AtomisticModel.export` in Python.
      75             : - `CHECK_CONSISTENCY` can be used to enable internal consistency checks.
      76             : - `SELECTED_ATOMS` can be used to tell the metatomic model that it should
      77             :   only run its calculation for the selected subset of atoms. The model still
      78             :   needs to know about all the atoms in the system (through the `SPECIES`
      79             :   keywords), but this can be used to reduce the calculation cost. Note that the
      80             :   indices of the selected atoms should start at 1 in the PLUMED input file, but
      81             :   they will be translated to start at 0 when given to the model (i.e. in
      82             :   Python/TorchScript, the `forward` method will receive a `selected_atoms` which
      83             :   starts at 0).
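The index translation for `SELECTED_ATOMS` is a shift by one. The sketch below uses a hypothetical `to_model_selection` helper to make the convention concrete; it is not the code PLUMED runs, which works with its `AtomNumber` class (1-based `serial()`, 0-based `index()`).

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not PLUMED code): converts the 1-based atom indices
// written after SELECTED_ATOMS into the 0-based indices the model receives.
std::vector<int64_t> to_model_selection(const std::vector<int64_t>& plumed_indices) {
    std::vector<int64_t> selected;
    selected.reserve(plumed_indices.size());
    for (int64_t index : plumed_indices) {
        if (index < 1) {
            throw std::invalid_argument("atom indices in PLUMED input start at 1");
        }
        selected.push_back(index - 1);
    }
    return selected;
}
```

With the input `SELECTED_ATOMS=11-20` used in the next example, the model would receive the indices 10 to 19.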
      84             : 
      85             : Here is another example using all the options described above:
      86             : 
      87             : \plumedfile soap: METATOMIC ... MODEL=soap.pt EXTENSIONS_DIRECTORY=extensions
      88             : CHECK_CONSISTENCY
      89             : 
      90             :     SPECIES1=1-10
      91             :     SPECIES2=11-20
      92             :     SPECIES_TO_TYPES=8,13
      93             : 
      94             :     # only run the calculation for the Aluminium (type 13) atoms, but
      95             :     # include the Oxygen (type 8) as potential neighbors.
      96             :     SELECTED_ATOMS=11-20
      97             : ...
      98             : \endplumedfile
      99             : 
     100             : \par Collective variables and metatomic models
     101             : 
     102             : PLUMED can use the [`"features"` output][features_output] of metatomic models as
     103             : collective variables.
     104             : 
     105             : */ /*
     106             : 
     107             : [TorchScript]: https://pytorch.org/docs/stable/jit.html
     108             : [metatomic]: https://docs.metatensor.org/metatomic/
     109             : [mta_tutorials]: https://docs.metatensor.org/metatomic/latest/examples/
     110             : [features_output]: https://docs.metatensor.org/metatomic/latest/outputs/features.html
     111             : */
     112             : //+ENDPLUMEDOC
     113             : 
     114             : /*INDENT-OFF*/
     115             : #if !defined(__PLUMED_HAS_LIBMETATOMIC) || !defined(__PLUMED_HAS_LIBTORCH)
     116             : 
     117             : namespace PLMD { namespace metatomic {
     118             : class MetatomicPlumedAction: public ActionAtomistic, public ActionWithValue {
     119             : public:
     120             :     static void registerKeywords(Keywords& keys);
     121             :     explicit MetatomicPlumedAction(const ActionOptions& options):
     122             :         Action(options),
     123             :         ActionAtomistic(options),
     124             :         ActionWithValue(options)
     125             :     {
     126             :         throw std::runtime_error(
     127             :             "Can not use metatomic action without the corresponding libraries. \n"
     128             :             "Make sure to configure with `--enable-libmetatomic --enable-libtorch` "
     129             :             "and that the corresponding libraries are found"
     130             :         );
     131             :     }
     132             : 
     133             :     void calculate() override {}
     134             :     void apply() override {}
     135             :     unsigned getNumberOfDerivatives() override {return 0;}
     136             : };
     137             : 
     138             : }} // namespace PLMD::metatomic
     139             : 
     140             : #else
     141             : 
     142             : #include <type_traits>
     143             : 
     144             : #pragma GCC diagnostic push
     145             : #pragma GCC diagnostic ignored "-Wpedantic"
     146             : #pragma GCC diagnostic ignored "-Wunused-parameter"
     147             : #pragma GCC diagnostic ignored "-Wfloat-equal"
     148             : #pragma GCC diagnostic ignored "-Wfloat-conversion"
     149             : #pragma GCC diagnostic ignored "-Wimplicit-float-conversion"
     150             : #pragma GCC diagnostic ignored "-Wimplicit-int-conversion"
     151             : #pragma GCC diagnostic ignored "-Wshorten-64-to-32"
     152             : #pragma GCC diagnostic ignored "-Wsign-conversion"
     153             : #pragma GCC diagnostic ignored "-Wold-style-cast"
     154             : 
     155             : #include <torch/script.h>
     156             : #include <torch/version.h>
     157             : #include <torch/cuda.h>
     158             : #if TORCH_VERSION_MAJOR >= 2
     159             : #include <torch/mps.h>
     160             : #endif
     161             : 
     162             : #pragma GCC diagnostic pop
     163             : 
     164             : #include <metatensor/torch.hpp>
     165             : #include <metatomic/torch.hpp>
     166             : 
     167             : #include "vesin.h"
     168             : 
     169             : 
     170             : namespace PLMD {
     171             : namespace metatomic {
     172             : 
     173             : // We will cast Vector/Tensor to pointers to arrays and doubles, so let's make
     174             : // sure this is legal to do
     175             : static_assert(std::is_standard_layout<PLMD::Vector>::value);
     176             : static_assert(sizeof(PLMD::Vector) == sizeof(std::array<double, 3>));
     177             : static_assert(alignof(PLMD::Vector) == alignof(std::array<double, 3>));
     178             : 
     179             : static_assert(std::is_standard_layout<PLMD::Tensor>::value);
     180             : static_assert(sizeof(PLMD::Tensor) == sizeof(std::array<std::array<double, 3>, 3>));
     181             : static_assert(alignof(PLMD::Tensor) == alignof(std::array<std::array<double, 3>, 3>));
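The `static_assert`s above guard the casts in `NeighborListCalculator::compute`, where `positions.data()` and `&cell(0, 0)` are reinterpreted as `const double (*)[3]` for vesin. The sketch below reproduces that pattern with a hypothetical `Vec3` type standing in for `PLMD::Vector`. Strictly speaking, stepping from one struct to the next through a single `double (*)[3]` pointer is not something the C++ standard guarantees; the checks make it hold in practice on mainstream compilers rather than prove it portable.

```cpp
#include <array>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for PLMD::Vector: a struct holding exactly three doubles.
struct Vec3 {
    double d[3];
};

// The same three checks as in metatomic.cpp: standard layout, no padding and
// matching alignment, so in practice a std::vector<Vec3> is laid out like
// a double[n][3] array.
static_assert(std::is_standard_layout<Vec3>::value, "Vec3 must be standard layout");
static_assert(sizeof(Vec3) == sizeof(std::array<double, 3>), "Vec3 must not be padded");
static_assert(alignof(Vec3) == alignof(std::array<double, 3>), "alignment must match");

// Reads the positions through a `const double (*)[3]`, the pointer type
// handed to vesin_neighbors, and sums the x components.
double sum_x(const std::vector<Vec3>& positions) {
    auto raw = reinterpret_cast<const double (*)[3]>(positions.data());
    double total = 0.0;
    for (std::size_t i = 0; i < positions.size(); i++) {
        total += raw[i][0];
    }
    return total;
}
```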
     182             : 
     183             : /// Small helper class to compute one neighbor list requested by the metatomic
     184             : /// model using vesin
     185             : class NeighborListCalculator {
     186             : public:
     187             :     NeighborListCalculator(
     188             :         metatomic_torch::NeighborListOptions options,
     189             :         const std::string& engine_length_unit
     190             :     );
     191             :     ~NeighborListCalculator();
     192             : 
     193             :     NeighborListCalculator(const NeighborListCalculator& other) = delete;
     194             :     NeighborListCalculator& operator=(const NeighborListCalculator& other) = delete;
     195             : 
     196             :     NeighborListCalculator(NeighborListCalculator&& other) noexcept;
     197             :     NeighborListCalculator& operator=(NeighborListCalculator&& other) noexcept;
     198             : 
     199             :     // compute the neighbor list following metatomic format, using data from PLUMED
     200             :     metatensor_torch::TensorBlock compute(
     201             :         const std::vector<PLMD::Vector>& positions,
     202             :         const PLMD::Tensor& cell,
     203             :         std::array<bool, 3> periodic,
     204             :         torch::ScalarType dtype,
     205             :         torch::Device device
     206             :     );
     207             : 
     208             :     metatomic_torch::NeighborListOptions options;
     209             : private:
     210             :     double engine_cutoff_;
     211             :     vesin::VesinNeighborList neighbors_;
     212             : };
     213             : 
     214           8 : NeighborListCalculator::NeighborListCalculator(
     215             :     metatomic_torch::NeighborListOptions options_,
     216             :     const std::string& engine_length_unit
     217           8 : ):
     218           8 :     options(options_),
     219           8 :     engine_cutoff_(options_->engine_cutoff(engine_length_unit))
     220             : {
     221             :     memset(&this->neighbors_, 0, sizeof(vesin::VesinNeighborList));
     222           8 : }
     223             : 
     224           0 : NeighborListCalculator::NeighborListCalculator(NeighborListCalculator&& other) noexcept {
     225           0 :     this->options = other.options;
     226           0 :     this->engine_cutoff_ = other.engine_cutoff_;
     227           0 :     this->neighbors_ = other.neighbors_;
     228             : 
     229           0 :     memset(&other.neighbors_, 0, sizeof(vesin::VesinNeighborList));
     230           0 : }
     231             : 
     232           0 : NeighborListCalculator& NeighborListCalculator::operator=(NeighborListCalculator&& other) noexcept {
     233           0 :     if (this != &other) {
     234           0 :         vesin::vesin_free(&this->neighbors_);
     235             : 
     236           0 :         this->options = other.options;
     237           0 :         this->engine_cutoff_ = other.engine_cutoff_;
     238           0 :         this->neighbors_ = other.neighbors_;
     239             : 
     240           0 :         memset(&other.neighbors_, 0, sizeof(vesin::VesinNeighborList));
     241             :     }
     242           0 :     return *this;
     243             : }
     244             : 
     245           8 : NeighborListCalculator::~NeighborListCalculator() {
     246           8 :     vesin::vesin_free(&this->neighbors_);
     247           8 : }
     248             : 
     249          17 : metatensor_torch::TensorBlock NeighborListCalculator::compute(
     250             :     const std::vector<PLMD::Vector>& positions,
     251             :     const PLMD::Tensor& cell,
     252             :     std::array<bool, 3> periodic,
     253             :     torch::ScalarType dtype,
     254             :     torch::Device device
     255             : ) {
     256          17 :     auto labels_options = torch::TensorOptions().dtype(torch::kInt32).device(device);
     257             :     auto neighbor_component = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     258             :         "xyz",
     259         102 :         torch::tensor({0, 1, 2}, labels_options).reshape({3, 1})
     260             :     );
     261             :     auto neighbor_properties = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     262          17 :         "distance", torch::zeros({1, 1}, labels_options)
     263             :     );
     264             : 
     265             :     // use https://github.com/Luthaf/vesin to compute the requested neighbor
     266             :     // lists since we can not get these from PLUMED
     267             :     vesin::VesinOptions vesin_options;
     268          17 :     vesin_options.cutoff = this->engine_cutoff_;
     269          17 :     vesin_options.full = this->options->full_list();
     270          17 :     vesin_options.return_shifts = true;
     271          17 :     vesin_options.return_distances = false;
     272          17 :     vesin_options.return_vectors = true;
     273          17 :     vesin_options.algorithm = vesin::VesinAutoAlgorithm;
     274             : 
     275          17 :     const char* error_message = nullptr;
     276          17 :     int status = vesin_neighbors(
     277             :         reinterpret_cast<const double (*)[3]>(positions.data()),
     278             :         positions.size(),
     279          17 :         reinterpret_cast<const double (*)[3]>(&cell(0, 0)),
     280             :         periodic.data(),
     281             :         {vesin::VesinCPU, 0},
     282             :         vesin_options,
     283             :         &this->neighbors_,
     284             :         &error_message
     285             :     );
     286             : 
     287          17 :     if (status != EXIT_SUCCESS) {
     288           0 :         plumed_merror(
     289             :             "failed to compute neighbor list (cutoff=" + std::to_string(this->engine_cutoff_) +
     290             :             ", full=" + (this->options->full_list() ? "true" : "false") + "): " + error_message
     291             :         );
     292             :     }
     293             : 
     294             :     // transform from vesin to metatomic format
     295          17 :     auto n_pairs = static_cast<int64_t>(this->neighbors_.length);
     296             : 
     297             :     auto pair_vectors = torch::from_blob(
     298          17 :         this->neighbors_.vectors,
     299             :         {n_pairs, 3, 1},
     300          34 :         torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU)
     301          17 :     );
     302             : 
     303          17 :     auto pair_samples_values = torch::empty({n_pairs, 5}, labels_options.device(torch::kCPU));
     304          17 :     auto pair_samples_values_ptr = pair_samples_values.accessor<int32_t, 2>();
     305       15407 :     for (unsigned i=0; i<n_pairs; i++) {
     306       15390 :         pair_samples_values_ptr[i][0] = static_cast<int32_t>(this->neighbors_.pairs[i][0]);
     307       15390 :         pair_samples_values_ptr[i][1] = static_cast<int32_t>(this->neighbors_.pairs[i][1]);
     308       15390 :         pair_samples_values_ptr[i][2] = this->neighbors_.shifts[i][0];
     309       15390 :         pair_samples_values_ptr[i][3] = this->neighbors_.shifts[i][1];
     310       15390 :         pair_samples_values_ptr[i][4] = this->neighbors_.shifts[i][2];
     311             :     }
     312             : 
     313             :     auto neighbor_samples = torch::make_intrusive<metatensor_torch::LabelsHolder>(
     314         136 :         std::vector<std::string>{"first_atom", "second_atom", "cell_shift_a", "cell_shift_b", "cell_shift_c"},
     315          17 :         pair_samples_values.to(device),
     316             :         // vesin should create unique pairs
     317          17 :         metatensor::assume_unique{}
     318             :     );
     319             : 
     320             :     auto neighbors = torch::make_intrusive<metatensor_torch::TensorBlockHolder>(
     321          51 :         pair_vectors.to(dtype).to(device),
     322             :         neighbor_samples,
     323          68 :         std::vector<metatensor_torch::Labels>{neighbor_component},
     324             :         neighbor_properties
     325             :     );
     326             : 
     327          17 :     return neighbors;
     328             : }
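The loop in `compute` flattens each vesin pair into one row of the `first_atom, second_atom, cell_shift_a, cell_shift_b, cell_shift_c` sample labels. The sketch below shows that reshaping with standard containers only; `pairs_to_samples` is hypothetical, and its plain `std::vector` inputs stand in for vesin's `pairs` and `shifts` arrays, while the real code writes into a `torch::Tensor` through an accessor.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: turns each neighbor pair (i, j) and its periodic
// cell shift (a, b, c) into one sample row
// [first_atom, second_atom, cell_shift_a, cell_shift_b, cell_shift_c].
// Assumes `pairs` and `shifts` have the same length.
std::vector<std::array<int32_t, 5>> pairs_to_samples(
    const std::vector<std::array<std::size_t, 2>>& pairs,
    const std::vector<std::array<int32_t, 3>>& shifts
) {
    std::vector<std::array<int32_t, 5>> samples;
    samples.reserve(pairs.size());
    for (std::size_t i = 0; i < pairs.size(); i++) {
        // atom indices are narrowed to int32_t, as in the original loop
        samples.push_back({
            static_cast<int32_t>(pairs[i][0]),
            static_cast<int32_t>(pairs[i][1]),
            shifts[i][0],
            shifts[i][1],
            shifts[i][2],
        });
    }
    return samples;
}
```

In a full neighbor list, a pair found across a periodic boundary appears twice with opposite shifts, for example the pairs `(0, 1)` and `(1, 0)` with shifts `(0, 0, 1)` and `(0, 0, -1)` become the rows `{0, 1, 0, 0, 1}` and `{1, 0, 0, 0, -1}`.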
     329             : 
     330             : class MetatomicPlumedAction: public ActionAtomistic, public ActionWithValue {
     331             : public:
     332             :     static void registerKeywords(Keywords& keys);
     333             :     explicit MetatomicPlumedAction(const ActionOptions&);
     334             : 
     335             :     void calculate() override;
     336             :     void apply() override;
     337             :     unsigned getNumberOfDerivatives() override;
     338             : 
     339             : private:
     340             :     // fill this->system_ according to the current PLUMED data
     341             :     void createSystem();
     342             : 
     343             : 
     344             :     // execute the model for the given system
     345             :     metatensor_torch::TensorBlock executeModel(metatomic_torch::System system);
     346             : 
     347             :     metatensor_torch::Module model_;
     348             : 
     349             :     metatomic_torch::ModelCapabilities capabilities_;
     350             :     // name of the output we request
     351             :     std::string features_key;
     352             : 
     353             :     // neighbor lists requests made by the model and the corresponding data
     354             :     std::vector<NeighborListCalculator> neighbor_lists_;
     355             : 
     356             :     // dtype/device to use to execute the model
     357             :     torch::ScalarType dtype_;
     358             :     torch::Device device_;
     359             : 
     360             :     torch::Tensor atomic_types_;
     361             :     // store the strain to be able to compute the virial with autograd
     362             :     torch::Tensor strain_;
     363             : 
     364             :     metatomic_torch::System system_;
     365             :     metatomic_torch::ModelEvaluationOptions evaluations_options_;
     366             :     bool check_consistency_;
     367             : 
     368             :     metatensor_torch::TensorMap output_;
     369             :     // shape of the output of this model
     370             :     unsigned n_samples_;
     371             :     unsigned n_properties_;
     372             : };
     373             : 
     374           8 : MetatomicPlumedAction::MetatomicPlumedAction(const ActionOptions& options):
     375             :     Action(options),
     376             :     ActionAtomistic(options),
     377             :     ActionWithValue(options),
     378           8 :     model_(torch::jit::Module()),
     379          16 :     device_(torch::kCPU)
     380             : {
     381          16 :     if (metatomic_torch::version().find("0.1.") != 0) {
     382           0 :         this->error(
     383           0 :             "this code requires version 0.1.x of metatomic-torch, got version " +
     384           0 :             metatomic_torch::version()
     385             :         );
     386             :     }
     387             : 
     388             :     // first, load the model
     389             :     std::string extensions_directory_str;
     390          16 :     this->parse("EXTENSIONS_DIRECTORY", extensions_directory_str);
     391             : 
     392             :     torch::optional<std::string> extensions_directory = torch::nullopt;
     393           8 :     if (!extensions_directory_str.empty()) {
     394           3 :         extensions_directory = std::move(extensions_directory_str);
     395             :     }
     396             : 
     397             :     std::string model_path;
     398          16 :     this->parse("MODEL", model_path);
     399             : 
     400             :     try {
     401          16 :         this->model_ = metatomic_torch::load_atomistic_model(model_path, extensions_directory);
     402           0 :     } catch (const std::exception& e) {
     403           0 :         this->error("failed to load model at '" + model_path + "': " + e.what());
     404           0 :     }
     405             : 
     406             :     // extract information from the model
     407          16 :     auto metadata = this->model_.run_method("metadata").toCustomClass<metatomic_torch::ModelMetadataHolder>();
     408          16 :     this->capabilities_ = this->model_.run_method("capabilities").toCustomClass<metatomic_torch::ModelCapabilitiesHolder>();
     409           8 :     auto nl_requests_ivalue = this->model_.run_method("requested_neighbor_lists");
     410          16 :     for (auto nl_request_ivalue: nl_requests_ivalue.toList()) {
     411           8 :         auto nl_request = nl_request_ivalue.get().toCustomClass<metatomic_torch::NeighborListOptionsHolder>();
     412           8 :         this->neighbor_lists_.emplace_back(nl_request, this->getUnits().getLengthString());
     413             :     }
     414             : 
     415          16 :     auto extra_inputs = this->model_.run_method("requested_inputs").toGenericDict();
     416             :     auto standard_inputs = std::vector<std::string>{};
     417             :     auto custom_inputs = std::vector<std::string>{};
     418           8 :     for (const auto& item: extra_inputs) {
     419           0 :         auto key = item.key().toStringRef();
     420           0 :         if (key.find("::") != std::string::npos) {
     421           0 :             custom_inputs.push_back(key);
     422             :         } else {
     423           0 :             standard_inputs.push_back(key);
     424             :         }
     425             :     }
     426             : 
     427           8 :     if (!standard_inputs.empty()) {
     428           0 :         this->error(
     429             :             "The model requested extra inputs that are not yet supported in PLUMED. "
     430           0 :             "Please open an issue to request support for the following inputs: " +
     431           0 :             torch::str(standard_inputs)
     432             :         );
     433             :     }
     434             : 
     435           8 :     if (!custom_inputs.empty()) {
     436           0 :         this->error(
     437           0 :             "The model requested custom inputs (" + torch::str(custom_inputs) + ") "
     438             :             "that can not be provided by PLUMED. Please change your model to use "
     439             :             "standard inputs only."
     440             :         );
     441             :     }
     442             : 
     443          16 :     log.printf("\n%s\n", metadata->print().c_str());
     444             :     // add the model references to PLUMED citation handling mechanism
     445           8 :     for (const auto& it: metadata->references) {
     446           2 :         for (const auto& ref: it.value()) {
     447           2 :             this->cite(ref);
     448           1 :         }
     449             :     }
     450             : 
     451             :     // parse the atomic types from the input file
     452             :     std::vector<int32_t> atomic_types;
     453             :     std::vector<int32_t> species_to_types;
     454          16 :     this->parseVector("SPECIES_TO_TYPES", species_to_types);
     455             :     bool has_custom_types = !species_to_types.empty();
     456             : 
     457             :     std::vector<AtomNumber> all_atoms;
     458          16 :     this->parseAtomList("SPECIES", all_atoms);
     459             : 
     460             :     size_t n_species = 0;
     461           8 :     if (all_atoms.empty()) {
     462             :         // first parse each of the 'SPECIES' entry
     463             :         std::vector<std::vector<AtomNumber>> atoms_per_species;
     464             :         int i = 0;
     465             :         while (true) {
     466          27 :             i += 1;
     467             :             auto atoms = std::vector<AtomNumber>();
     468          54 :             this->parseAtomList("SPECIES", i, atoms);
     469             : 
     470          27 :             if (atoms.empty()) {
     471             :                 break;
     472             :             }
     473             : 
     474             :             int32_t type = i;
     475          19 :             if (has_custom_types) {
     476          19 :                 if (species_to_types.size() < static_cast<size_t>(i)) {
     477           0 :                     this->error(
     478             :                         "SPECIES_TO_TYPES is too small, it should have one entry "
     479           0 :                         "for each species (we have at least " + std::to_string(i) +
     480           0 :                         " species and " + std::to_string(species_to_types.size()) +
     481             :                         " entries in SPECIES_TO_TYPES)"
     482             :                     );
     483             :                 }
     484             : 
     485          19 :                 type = species_to_types[static_cast<size_t>(i - 1)];
     486             :             }
     487             : 
     488          19 :             log.printf("  atoms with type %d are: ", type);
     489        1296 :             for(unsigned j=0; j<atoms.size(); j++) {
     490        1277 :                 log.printf("%d ", atoms[j]);
     491             :             }
     492          19 :             log.printf("\n");
     493             : 
     494          19 :             n_species += 1;
     495          19 :             atoms_per_species.emplace_back(std::move(atoms));
     496             :         }
     497             : 
     498             :         size_t n_atoms = 0;
     499          27 :         for (const auto& atoms: atoms_per_species) {
     500          19 :             n_atoms += atoms.size();
     501             :         }
     502             : 
     503             :         // then fill the atomic_types as required
     504           8 :         atomic_types.resize(n_atoms, 0);
     505             :         i = 0;
     506          27 :         for (const auto& atoms: atoms_per_species) {
     507          19 :             i += 1;
     508             : 
     509             :             int32_t type = i;
     510          19 :             if (has_custom_types) {
     511          19 :                 type = species_to_types[static_cast<size_t>(i - 1)];
     512             :             }
     513             : 
     514        1296 :             for (const auto& atom: atoms) {
     515        1277 :                 atomic_types[atom.index()] = type;
     516             :             }
     517             :         }
     518           8 :     } else {
     519             :         n_species = 1;
     520             : 
     521           0 :         int32_t type = 1;
     522           0 :         if (has_custom_types) {
     523           0 :             type = species_to_types[0];
     524             :         }
     525           0 :         atomic_types.resize(all_atoms.size(), type);
     526             :     }
     527             : 
     528           8 :     if (has_custom_types && species_to_types.size() != n_species) {
     529           0 :         this->warning(
     530           0 :             "SPECIES_TO_TYPES contains more entries (" +
     531           0 :             std::to_string(species_to_types.size()) +
     532           0 :             ") than there were species (" + std::to_string(n_species) + ")"
     533             :         );
     534             :     }
     535             : 
     536             :     // request atoms in order
     537             :     all_atoms.clear();
     538        1285 :     for (size_t i=0; i<atomic_types.size(); i++) {
     539        1277 :         all_atoms.push_back(AtomNumber::index(i));
     540             :     }
     541           8 :     this->requestAtoms(all_atoms);
     542             : 
     543          16 :     this->atomic_types_ = torch::tensor(atomic_types);
     544             : 
     545           8 :     this->check_consistency_ = false;
     546           8 :     this->parseFlag("CHECK_CONSISTENCY", this->check_consistency_);
     547           8 :     if (this->check_consistency_) {
     548           0 :         log.printf("  checking for internal consistency of the model\n");
     549             :     }
     550             : 
     551             :     // create evaluation options for the model. These won't change during the
     552             :     // simulation, so we initialize them once here.
     553           8 :     evaluations_options_ = torch::make_intrusive<metatomic_torch::ModelEvaluationOptionsHolder>();
     554          16 :     evaluations_options_->set_length_unit(getUnits().getLengthString());
     555             : 
     556             :     auto outputs = this->capabilities_->outputs();
     557             : 
     558             :     std::string requested_variant;
     559             :     torch::optional<std::string> requested_variant_opt = torch::nullopt;
     560          16 :     this->parse("VARIANT", requested_variant);
     561           8 :     if (!requested_variant.empty()) {
     562           1 :         requested_variant_opt = requested_variant;
     563             :     }
     564             : 
     565          25 :     this->features_key = metatomic_torch::pick_output("features", outputs, requested_variant_opt);
     566             : 
     567             :     auto output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
     568             :     // this output has no quantity or unit to set
     569             : 
     570          24 :     output->set_sample_kind(this->capabilities_->outputs().at(this->features_key)->sample_kind());
     571             :     // we are using torch autograd system to compute gradients,
     572             :     // so we don't need any explicit gradients.
     573           8 :     output->explicit_gradients = {};
     574           8 :     evaluations_options_->outputs.insert(this->features_key, output);
     575             : 
     576             :     std::string requested_device;
     577             :     torch::optional<std::string> requested_device_opt = torch::nullopt;
     578          16 :     this->parse("DEVICE", requested_device);
     579           8 :     if (!requested_device.empty()) {
     580           5 :         requested_device_opt = requested_device;
     581             :     }
     582             : 
     583           8 :     this->device_ = torch::Device(
     584           8 :         metatomic_torch::pick_device(this->capabilities_->supported_devices, requested_device_opt),
     585             :         /*index=*/ 0
     586             :     );
     587             : 
     588           8 :     this->model_.to(this->device_);
     589           8 :     this->atomic_types_ = this->atomic_types_.to(this->device_);
     590             : 
     591           8 :     log.printf(
     592             :         "  running model on %s device with %s data\n",
     593          16 :         this->device_.str().c_str(),
     594             :         this->capabilities_->dtype().c_str()
     595             :     );
     596             : 
     597           8 :     if (this->capabilities_->dtype() == "float64") {
     598           3 :         this->dtype_ = torch::kFloat64;
     599           5 :     } else if (this->capabilities_->dtype() == "float32") {
     600           5 :         this->dtype_ = torch::kFloat32;
     601             :     } else {
     602           0 :         this->error(
     603           0 :             "the model requested an unsupported dtype '" + this->capabilities_->dtype() + "'"
     604             :         );
     605             :     }
     606             : 
     607           8 :     auto tensor_options = torch::TensorOptions().dtype(this->dtype_).device(this->device_);
     608          16 :     this->strain_ = torch::eye(3, tensor_options.requires_grad(true));
     609             : 
     610             :     // determine how many properties there will be in the output by running the
     611             :     // model once on a dummy system
     612             :     auto dummy_system = torch::make_intrusive<metatomic_torch::SystemHolder>(
     613          16 :         /*types = */ torch::zeros({0}, tensor_options.dtype(torch::kInt32)),
     614          16 :         /*positions = */ torch::zeros({0, 3}, tensor_options),
     615          16 :         /*cell = */ torch::zeros({3, 3}, tensor_options),
     616           8 :         /*pbc = */ torch::zeros({3}, tensor_options.dtype(torch::kBool))
     617             :     );
     618             : 
     619           8 :     log.printf("  the following neighbor lists have been requested:\n");
     620           8 :     auto length_unit = this->getUnits().getLengthString();
     621           8 :     auto model_length_unit = this->capabilities_->length_unit();
     622          16 :     for (auto& nl: this->neighbor_lists_) {
     623          11 :         log.printf("    - %s list, %g %s cutoff (requested %g %s)\n",
     624             :             nl.options->full_list() ? "full" : "half",
     625             :             nl.options->engine_cutoff(length_unit),
     626             :             length_unit.c_str(),
     627             :             nl.options->cutoff(),
     628             :             model_length_unit.c_str()
     629             :         );
     630             : 
     631             :         auto neighbors = nl.compute(
     632             :             {PLMD::Vector(0, 0, 0)},
     633          16 :             PLMD::Tensor(0, 0, 0, 0, 0, 0, 0, 0, 0),
     634             :             {false, false, false},
     635             :             this->dtype_,
     636             :             this->device_
     637           8 :         );
     638          32 :         metatomic_torch::register_autograd_neighbors(dummy_system, neighbors, this->check_consistency_);
     639          16 :         dummy_system->add_neighbor_list(nl.options, neighbors);
     640             :     }
     641             : 
     642           8 :     this->n_properties_ = static_cast<unsigned>(
     643          32 :         this->executeModel(dummy_system)->properties()->count()
     644             :     );
     645             : 
     646             :     // parse and handle atom sub-selection. This is done AFTER determining the
     647             :     // output size, since the selection might not be valid for the dummy system
     648             :     std::vector<AtomNumber> selected_atoms;
     649          16 :     this->parseAtomList("SELECTED_ATOMS", selected_atoms);
     650           8 :     if (!selected_atoms.empty()) {
     651             :         auto selection_value = torch::zeros(
     652             :             {static_cast<int64_t>(selected_atoms.size()), 2},
     653           2 :             torch::TensorOptions().dtype(torch::kInt32).device(this->device_)
     654           2 :         );
     655             : 
     656           9 :         for (unsigned i=0; i<selected_atoms.size(); i++) {
     657             :             auto n_atoms = this->atomic_types_.size(0);
     658           7 :             if (selected_atoms[i].index() >= n_atoms) {
     659           0 :                 this->error(
     660             :                     "Values in metatomic's SELECTED_ATOMS should be between 1 "
     661           0 :                     "and the number of atoms (" + std::to_string(n_atoms) + "), "
     662           0 :                     "got " + std::to_string(selected_atoms[i].serial()));
     663             :             }
     664          21 :             selection_value[i][1] = static_cast<int32_t>(selected_atoms[i].index());
     665             :         }
     666             : 
     667           4 :         evaluations_options_->set_selected_atoms(
     668           2 :             torch::make_intrusive<metatensor_torch::LabelsHolder>(
     669           8 :                 std::vector<std::string>{"system", "atom"}, selection_value
     670             :             )
     671             :         );
     672             :     }
     673             : 
     674             :     // Now that we know both n_samples and n_properties, we can set up the
     675             :     // PLUMED-side storage for the computed CV
     676          16 :     if (output->sample_kind() == "atom") {
     677           4 :         if (selected_atoms.empty()) {
     678           2 :             this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
     679             :         } else {
     680           2 :             this->n_samples_ = static_cast<unsigned>(selected_atoms.size());
     681             :         }
     682             :     } else {
     683             :         assert(output->sample_kind() == "system");
     684           4 :         this->n_samples_ = 1;
     685             :     }
     686             : 
     687           8 :     if (n_samples_ == 1 && n_properties_ == 1) {
     688           3 :         log.printf("  the output of this model is a scalar\n");
     689             : 
     690           6 :         this->addValue();
     691           5 :     } else if (n_samples_ == 1) {
     692           1 :         log.printf("  the output of this model is a 1x%u vector\n", n_properties_);
     693             : 
     694           1 :         this->addValue({this->n_properties_});
     695           1 :         this->getPntrToComponent(0)->buildDataStore();
     696           4 :     } else if (n_properties_ == 1) {
     697           1 :         log.printf("  the output of this model is a %ux1 vector\n", n_samples_);
     698             : 
     699           1 :         this->addValue({this->n_samples_});
     700           1 :         this->getPntrToComponent(0)->buildDataStore();
     701             :     } else {
     702           3 :         log.printf("  the output of this model is a %ux%u matrix\n", n_samples_, n_properties_);
     703             : 
     704           3 :         this->addValue({this->n_samples_, this->n_properties_});
     705           3 :         this->getPntrToComponent(0)->buildDataStore();
     706           3 :         this->getPntrToComponent(0)->reshapeMatrixStore(n_properties_);
     707             :     }
     708             : 
     709           8 :     this->setNotPeriodic();
     710          16 : }
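The type assignment in the constructor (the `atoms_per_species` loop) can be sketched without torch or PLUMED. Group k of SPECIES (counting from 1) becomes metatomic type k, unless SPECIES_TO_TYPES is given, in which case its k-th entry is used instead. In this minimal sketch the helper `assign_types` is hypothetical, the groups are plain vectors of 0-based atom indices, and, as in the action, the groups are assumed to cover atoms 0 to N-1 exactly once.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: species[k] lists the 0-based indices of the atoms in
// PLUMED species k + 1; species_to_types is empty when SPECIES_TO_TYPES is
// not given.
std::vector<std::int32_t> assign_types(
    const std::vector<std::vector<std::size_t>>& species,
    const std::vector<std::int32_t>& species_to_types
) {
    std::size_t n_atoms = 0;
    for (const auto& atoms: species) {
        n_atoms += atoms.size();
    }

    std::vector<std::int32_t> types(n_atoms, 0);
    for (std::size_t k = 0; k < species.size(); k++) {
        // default numbering: first species is type 1, second is type 2, ...
        auto type = static_cast<std::int32_t>(k + 1);
        if (!species_to_types.empty()) {
            // .at() throws std::out_of_range if the mapping is too short
            type = species_to_types.at(k);
        }
        for (auto index: species[k]) {
            // the groups are assumed to cover atoms 0..N-1 exactly once
            assert(index < n_atoms);
            types[index] = type;
        }
    }
    return types;
}
```

For example, if atoms 0 and 2 form the first species and atoms 1 and 3 the second, `assign_types({{0, 2}, {1, 3}}, {8, 1})` returns {8, 1, 8, 1}, while an empty mapping returns the default numbering {1, 2, 1, 2}.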
     711             : 
     712          14 : unsigned MetatomicPlumedAction::getNumberOfDerivatives() {
     713             :     // gradients w.r.t. positions (3 x N values) + gradients w.r.t. strain (9 values)
     714          14 :     return 3 * this->getNumberOfAtoms() + 9;
     715             : }
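The 9 strain derivatives come from the strain trick in createSystem(): positions and cell are right-multiplied by `strain_`, a 3x3 matrix $\varepsilon$ that always equals the identity, so it changes no values but lets autograd report $\partial s / \partial \varepsilon$. A short derivation of what that gradient contains, written with the row-vector convention of `torch_positions.matmul(this->strain_)`, where $r_i$ are the atomic positions, $h_k$ the rows of the cell, and $s$ the CV (our own sketch of the reasoning, not taken from the metatomic documentation):

$$
r'_{i\alpha} = \sum_{\beta} r_{i\beta}\,\varepsilon_{\beta\alpha},
\qquad
h'_{k\alpha} = \sum_{\beta} h_{k\beta}\,\varepsilon_{\beta\alpha}
$$

$$
\left.\frac{\partial s}{\partial \varepsilon_{\beta\alpha}}\right|_{\varepsilon = I}
= \sum_{i} r_{i\beta}\,\frac{\partial s}{\partial r'_{i\alpha}}
+ \sum_{k} h_{k\beta}\,\frac{\partial s}{\partial h'_{k\alpha}}
$$

In apply() the backward pass is seeded with the forces PLUMED applied to the CV rather than with 1, so `strain_grad[β][α]` holds this quantity contracted with those forces; its negation is appended row by row as the last 9 of the 3N + 9 derivatives counted here.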
     716             : 
     717             : 
     718           9 : void MetatomicPlumedAction::createSystem() {
     719          18 :     if (this->getTotAtoms() != static_cast<unsigned>(this->atomic_types_.size(0))) {
     720           0 :         std::ostringstream oss;
     721           0 :         oss << "METATOMIC action needs to know about all atoms in the system. ";
     722           0 :         oss << "There are " << this->getTotAtoms() << " atoms overall, ";
     723           0 :         oss << "but we only have atomic types for " << this->atomic_types_.size(0) << " of them.";
     724           0 :         plumed_merror(oss.str());
     725           0 :     }
     726             : 
     727             :     // this->getTotAtoms()
     728             : 
     729           9 :     const auto& cell = this->getPbc().getBox();
     730             : 
     731           9 :     auto cpu_f64_tensor = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
     732           9 :     auto torch_cell = torch::zeros({3, 3}, cpu_f64_tensor);
     733             : 
     734          27 :     torch_cell[0][0] = cell(0, 0);
     735          27 :     torch_cell[0][1] = cell(0, 1);
     736          27 :     torch_cell[0][2] = cell(0, 2);
     737             : 
     738          27 :     torch_cell[1][0] = cell(1, 0);
     739          27 :     torch_cell[1][1] = cell(1, 1);
     740          27 :     torch_cell[1][2] = cell(1, 2);
     741             : 
     742          27 :     torch_cell[2][0] = cell(2, 0);
     743          27 :     torch_cell[2][1] = cell(2, 1);
     744          27 :     torch_cell[2][2] = cell(2, 2);
     745             : 
     746             :     using torch::indexing::Slice;
     747             : 
     748          63 :     auto norm_a = torch_cell.index({0, Slice()}).norm().abs().item<double>();
     749          63 :     auto norm_b = torch_cell.index({1, Slice()}).norm().abs().item<double>();
     750          63 :     auto norm_c = torch_cell.index({2, Slice()}).norm().abs().item<double>();
     751             : 
     752           9 :     auto periodic = std::array<bool, 3>{true, true, true};
     753             : 
     754             :     // make sure the cell and pbc argument agree with each other
     755           9 :     if (norm_a < 1e-9) {
     756           0 :         if (norm_a > 1e-30) {
     757           0 :             this->warning(
     758           0 :                 "the cell vector A has a very small norm (" + std::to_string(norm_a) + "), "
     759             :                 "this direction will be treated as non periodic. If this is intentional, ensure "
     760             :                 "the vector is exactly zero to silence this warning"
     761             :             );
     762             :         }
     763           0 :         torch_cell.index({0, Slice()}).fill_(0);
     764           0 :         periodic[0] = false;
     765             :     }
     766             : 
     767           9 :     if (norm_b < 1e-9) {
     768           0 :         if (norm_b > 1e-30) {
     769           0 :             this->warning(
     770           0 :                 "the cell vector B has a very small norm (" + std::to_string(norm_b) + "), "
     771             :                 "this direction will be treated as non periodic. If this is intentional, ensure "
     772             :                 "the vector is exactly zero to silence this warning"
     773             :             );
     774             :         }
     775           0 :         torch_cell.index({1, Slice()}).fill_(0);
     776           0 :         periodic[1] = false;
     777             :     }
     778             : 
     779           9 :     if (norm_c < 1e-9) {
     780           0 :         if (norm_c > 1e-30) {
     781           0 :             this->warning(
     782           0 :                 "the cell vector C has a very small norm (" + std::to_string(norm_c) + "), "
     783             :                 "this direction will be treated as non periodic. If this is intentional, ensure "
     784             :                 "the vector is exactly zero to silence this warning"
     785             :             );
     786             :         }
     787           0 :         torch_cell.index({2, Slice()}).fill_(0);
     788           0 :         periodic[2] = false;
     789             :     }
     790             : 
     791           9 :     auto torch_pbc = torch::zeros({3}, torch::TensorOptions().dtype(torch::kBool).device(torch::kCPU));
     792          36 :     for (unsigned i=0; i<3; i++) {
     793          54 :         torch_pbc[i] = periodic[i];
     794             :     }
     795             : 
     796             :     const auto& positions = this->getPositions();
     797             : 
     798             :     auto torch_positions = torch::from_blob(
     799             :         const_cast<PLMD::Vector*>(positions.data()),
     800             :         {static_cast<int64_t>(positions.size()), 3},
     801             :         cpu_f64_tensor
     802           9 :     );
     803             : 
     804          18 :     torch_positions = torch_positions.to(this->dtype_).to(this->device_);
     805          27 :     torch_cell = torch_cell.to(this->dtype_).to(this->device_);
     806           9 :     torch_pbc = torch_pbc.to(this->device_);
     807             : 
     808             :     // setup torch's automatic gradient tracking
     809           9 :     if (!this->doNotCalculateDerivatives()) {
     810             :         torch_positions.requires_grad_(true);
     811             : 
     812             :         // pretend to scale positions/cell by the strain so that it enters the
     813             :         // computational graph.
     814           8 :         torch_positions = torch_positions.matmul(this->strain_);
     815           8 :         torch_positions.retain_grad();
     816             : 
     817           8 :         torch_cell = torch_cell.matmul(this->strain_);
     818             :     }
     819             : 
     820           9 :     this->system_ = torch::make_intrusive<metatomic_torch::SystemHolder>(
     821           9 :         this->atomic_types_,
     822             :         torch_positions,
     823             :         torch_cell,
     824             :         torch_pbc
     825             :     );
     826             : 
     827             :     // compute the neighbors list requested by the model, and register them with
     828             :     // the system
     829          18 :     for (auto& nl: this->neighbor_lists_) {
     830           9 :         auto neighbors = nl.compute(positions, cell, periodic, this->dtype_, this->device_);
     831          36 :         metatomic_torch::register_autograd_neighbors(this->system_, neighbors, this->check_consistency_);
     832          18 :         this->system_->add_neighbor_list(nl.options, neighbors);
     833             :     }
     834           9 : }
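The three near-identical blocks that check the cell vectors in createSystem() amount to one loop over the rows of the box. A minimal sketch with a plain 3x3 array in place of the torch tensor, using the same 1e-9 and 1e-30 thresholds; the helper `periodic_directions` and the `Cell` alias are hypothetical, and the warning goes to stderr instead of PLUMED's `warning()`.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>

// rows are the cell vectors A, B and C, as in getPbc().getBox()
using Cell = std::array<std::array<double, 3>, 3>;

// Hypothetical helper: a cell vector with norm below 1e-9 is set exactly to
// zero and its direction is marked non-periodic; norms between 1e-30 and
// 1e-9 also print a warning, since they are probably unintentional.
std::array<bool, 3> periodic_directions(Cell& cell) {
    std::array<bool, 3> periodic = {true, true, true};
    for (std::size_t k = 0; k < 3; k++) {
        double norm = std::sqrt(
            cell[k][0] * cell[k][0] + cell[k][1] * cell[k][1] + cell[k][2] * cell[k][2]
        );
        if (norm < 1e-9) {
            if (norm > 1e-30) {
                std::fprintf(stderr, "warning: cell vector %zu has a very small norm (%g)\n", k, norm);
            }
            cell[k] = {0.0, 0.0, 0.0};
            periodic[k] = false;
        }
    }
    return periodic;
}
```

Folding the three copies into a loop keeps the thresholds in one place; the action itself spells the checks out per vector so that each warning can name A, B or C.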
     835             : 
     836             : 
     837          17 : metatensor_torch::TensorBlock MetatomicPlumedAction::executeModel(metatomic_torch::System system) {
     838             :     try {
     839          51 :         auto ivalue_output = this->model_.forward({
     840          68 :             std::vector<metatomic_torch::System>{system},
     841             :             evaluations_options_,
     842          17 :             this->check_consistency_,
     843          85 :         });
     844             : 
     845          17 :         auto dict_output = ivalue_output.toGenericDict();
     846          34 :         auto cv = dict_output.at(this->features_key);
     847          34 :         this->output_ = cv.toCustomClass<metatensor_torch::TensorMapHolder>();
     848           0 :     } catch (const std::exception& e) {
     849           0 :         plumed_merror("failed to evaluate the model: " + std::string(e.what()));
     850           0 :     }
     851             : 
     852          34 :     plumed_massert(this->output_->keys()->count() == 1, "output should have a single block");
     853          34 :     auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
     854          17 :     plumed_massert(block->components().empty(), "components are not yet supported in the output");
     855             : 
     856          17 :     return block;
     857             : }
     858             : 
     859             : 
     860           9 : void MetatomicPlumedAction::calculate() {
     861           9 :     this->createSystem();
     862             : 
     863          18 :     auto block = this->executeModel(this->system_);
     864          45 :     auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
     865             : 
     866           9 :     if (static_cast<unsigned>(torch_values.size(0)) != this->n_samples_) {
     867           0 :         plumed_merror(
     868             :             "expected the model to return a TensorBlock with " +
     869             :             std::to_string(this->n_samples_) + " samples, got " +
     870             :             std::to_string(torch_values.size(0)) + " instead"
     871             :         );
     872           9 :     } else if (static_cast<unsigned>(torch_values.size(1)) != this->n_properties_) {
     873           0 :         plumed_merror(
     874             :             "expected the model to return a TensorBlock with " +
     875             :             std::to_string(this->n_properties_) + " properties, got " +
     876             :             std::to_string(torch_values.size(1)) + " instead"
     877             :         );
     878             :     }
     879             : 
     880           9 :     Value* value = this->getPntrToComponent(0);
     881             :     // reshape the plumed `Value` to hold the data returned by the model
     882           9 :     if (n_samples_ == 1) {
     883           5 :         if (n_properties_ == 1) {
     884           4 :             value->set(torch_values.item<double>());
     885             :         } else {
     886             :             // we have multiple CV describing a single thing (atom or full system)
     887           3 :             for (unsigned i=0; i<n_properties_; i++) {
     888           6 :                 value->set(i, torch_values[0][i].item<double>());
     889             :             }
     890             :         }
     891             :     } else {
     892             :         auto samples = block->samples();
     893          16 :         plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
     894             : 
     895           8 :         auto samples_values = samples->values().to(torch::kCPU);
     896             :         auto selected_atoms = this->evaluations_options_->get_selected_atoms();
     897             : 
     898             :         // handle the possibility that samples are returned in
     899             :         // a non-sorted order.
     900          92 :         auto get_output_location = [&](unsigned i) {
     901          92 :             if (selected_atoms.has_value()) {
     902             :                 // If the users picked some selected atoms, then we store the
     903             :                 // output in the same order as the selection was given
     904          28 :                 auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
     905           7 :                 auto position = selected_atoms.value()->position(sample);
     906           7 :                 plumed_assert(position.has_value());
     907           7 :                 return static_cast<unsigned>(position.value());
     908             :             } else {
     909         340 :                 return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
     910             :             }
     911           4 :         };
     912             : 
     913           4 :         if (n_properties_ == 1) {
     914             :             // we have a single CV describing multiple things (i.e. atoms)
     915           5 :             for (unsigned i=0; i<n_samples_; i++) {
     916           4 :                 auto output_i = get_output_location(i);
     917          16 :                 value->set(output_i, torch_values[i][0].item<double>());
     918             :             }
     919             :         } else {
     920             :             // the CV is a matrix
     921          91 :             for (unsigned i=0; i<n_samples_; i++) {
     922          88 :                 auto output_i = get_output_location(i);
     923         343 :                 for (unsigned j=0; j<n_properties_; j++) {
     924         765 :                     value->set(output_i * n_properties_ + j, torch_values[i][j].item<double>());
     925             :                 }
     926             :             }
     927             :         }
     928             :     }
     929           9 : }
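Both calculate() and apply() place row i of the model output at the index returned by `get_output_location()`, then flatten the (sample, property) pair row-major as `output_i * n_properties_ + j`. A minimal sketch of that scatter with plain vectors, assuming each sample is a (system, atom) pair; the `Sample` alias and `scatter_output` helper are hypothetical, the optional selection stands in for the SELECTED_ATOMS labels, and a linear search replaces the `position()` lookup on those labels.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

// a (system, atom) pair, one per row of the model output
using Sample = std::pair<std::int32_t, std::int32_t>;

// Hypothetical helper: scatter a row-major n_samples x n_properties block
// into flat PLUMED-style storage, whatever order the samples arrive in.
std::vector<double> scatter_output(
    const std::vector<Sample>& samples,
    const std::vector<double>& values,
    std::size_t n_properties,
    const std::optional<std::vector<Sample>>& selection
) {
    const std::size_t n_samples = samples.size();
    std::vector<double> storage(n_samples * n_properties, 0.0);

    for (std::size_t i = 0; i < n_samples; i++) {
        std::size_t output_i = 0;
        if (selection.has_value()) {
            // with a selection, keep the order in which atoms were selected
            const auto& selected = selection.value();
            auto it = std::find(selected.begin(), selected.end(), samples[i]);
            if (it == selected.end()) {
                throw std::runtime_error("sample missing from the selection");
            }
            output_i = static_cast<std::size_t>(it - selected.begin());
        } else {
            // without a selection, the atom index is the output row
            output_i = static_cast<std::size_t>(samples[i].second);
        }

        for (std::size_t j = 0; j < n_properties; j++) {
            // row-major flattening, as in output_i * n_properties_ + j
            storage.at(output_i * n_properties + j) = values.at(i * n_properties + j);
        }
    }
    return storage;
}
```

The linear search keeps the sketch short but costs O(n²) for large selections; a hash map from sample to position would avoid that.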
     930             : 
     931             : 
     932           9 : void MetatomicPlumedAction::apply() {
     933           9 :     const auto* value = this->getPntrToComponent(0);
     934           9 :     if (!value->forcesWereAdded()) {
     935           1 :         return;
     936             :     }
     937             : 
     938          16 :     auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
     939          40 :     auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);
     940             : 
     941           8 :     if (!torch_values.requires_grad()) {
     942           0 :         this->warning(
     943             :                     "the output of the model does not require gradients; this might "
     944             :             "indicate a problem"
     945             :         );
     946             :         return;
     947             :     }
     948             : 
     949           8 :     auto output_grad = torch::zeros_like(torch_values);
     950           8 :     if (n_samples_ == 1) {
     951           5 :         if (n_properties_ == 1) {
     952           8 :             output_grad[0][0] = value->getForce();
     953             :         } else {
     954           3 :             for (unsigned i=0; i<n_properties_; i++) {
     955           6 :                 output_grad[0][i] = value->getForce(i);
     956             :             }
     957             :         }
     958             :     } else {
     959             :         auto samples = block->samples();
     960          12 :         plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));
     961             : 
     962           6 :         auto samples_values = samples->values().to(torch::kCPU);
     963             :         auto selected_atoms = this->evaluations_options_->get_selected_atoms();
     964             : 
     965             :         // see above for an explanation of why we use this function
     966          89 :         auto get_output_location = [&](unsigned i) {
     967          89 :             if (selected_atoms.has_value()) {
     968          16 :                 auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
     969           4 :                 auto position = selected_atoms.value()->position(sample);
     970           4 :                 plumed_assert(position.has_value());
     971           4 :                 return static_cast<unsigned>(position.value());
     972             :             } else {
     973         340 :                 return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
     974             :             }
     975           3 :         };
     976             : 
     977           3 :         if (n_properties_ == 1) {
     978           5 :             for (unsigned i=0; i<n_samples_; i++) {
     979           4 :                 auto output_i = get_output_location(i);
     980          12 :                 output_grad[i][0] = value->getForce(output_i);
     981             :             }
     982             :         } else {
     983          87 :             for (unsigned i=0; i<n_samples_; i++) {
     984          85 :                 auto output_i = get_output_location(i);
     985         331 :                 for (unsigned j=0; j<n_properties_; j++) {
     986         738 :                     output_grad[i][j] = value->getForce(output_i * n_properties_ + j);
     987             :                 }
     988             :             }
     989             :         }
     990             :     }
     991             : 
     992           8 :     this->system_->positions().mutable_grad() = torch::Tensor();
     993             :     this->strain_.mutable_grad() = torch::Tensor();
     994             : 
     995           8 :     torch_values.backward(output_grad);
     996           8 :     auto positions_grad = this->system_->positions().grad();
     997           8 :     auto strain_grad = this->strain_.grad();
     998             : 
     999          32 :     positions_grad = positions_grad.to(torch::kCPU).to(torch::kFloat64);
    1000          32 :     strain_grad = strain_grad.to(torch::kCPU).to(torch::kFloat64);
    1001             : 
    1002           8 :     plumed_assert(positions_grad.sizes().size() == 2);
    1003           8 :     plumed_assert(positions_grad.is_contiguous());
    1004             : 
    1005           8 :     plumed_assert(strain_grad.sizes().size() == 2);
    1006           8 :     plumed_assert(strain_grad.is_contiguous());
    1007             : 
    1008             :     auto derivatives = std::vector<double>(
    1009             :         positions_grad.data_ptr<double>(),
    1010           8 :         positions_grad.data_ptr<double>() + 3 * this->system_->size()
    1011           8 :     );
    1012             : 
    1013             :     // add virials to the derivatives
    1014          24 :     derivatives.push_back(-strain_grad[0][0].item<double>());
    1015          24 :     derivatives.push_back(-strain_grad[0][1].item<double>());
    1016          24 :     derivatives.push_back(-strain_grad[0][2].item<double>());
    1017             : 
    1018          24 :     derivatives.push_back(-strain_grad[1][0].item<double>());
    1019          24 :     derivatives.push_back(-strain_grad[1][1].item<double>());
    1020          24 :     derivatives.push_back(-strain_grad[1][2].item<double>());
    1021             : 
    1022          24 :     derivatives.push_back(-strain_grad[2][0].item<double>());
    1023          24 :     derivatives.push_back(-strain_grad[2][1].item<double>());
    1024          16 :     derivatives.push_back(-strain_grad[2][2].item<double>());
    1025             : 
    1026           8 :     unsigned index = 0;
    1027           8 :     this->setForcesOnAtoms(derivatives, index);
    1028             : }
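The end of apply() packs the gradients into the flat layout announced by getNumberOfDerivatives(): 3N position entries (x, y, z per atom) followed by the 9 negated strain-gradient entries, row by row. A minimal sketch with plain containers instead of torch tensors; `assemble_derivatives` is a hypothetical name, and the inputs are assumed to be already on the CPU in float64, as the action ensures before copying.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: build the vector handed to setForcesOnAtoms(), i.e.
// the position gradient followed by minus the strain gradient (row-major).
std::vector<double> assemble_derivatives(
    const std::vector<double>& positions_grad,  // 3 * n_atoms values
    const std::array<std::array<double, 3>, 3>& strain_grad
) {
    assert(positions_grad.size() % 3 == 0);
    std::vector<double> derivatives(positions_grad.begin(), positions_grad.end());
    for (std::size_t a = 0; a < 3; a++) {
        for (std::size_t b = 0; b < 3; b++) {
            derivatives.push_back(-strain_grad[a][b]);
        }
    }
    return derivatives;
}
```

With two atoms this yields 3 * 2 + 9 = 15 entries, the same count getNumberOfDerivatives() reports for N = 2.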
    1029             : 
    1030             : } // namespace metatomic
    1031             : } // namespace PLMD
    1032             : 
    1033             : 
    1034             : #endif
    1035             : 
    1036             : 
    1037             : namespace PLMD {
    1038             : namespace metatomic {
    1039             : 
    1040             : // use the same implementation for both the actual action and the dummy one
    1041             : // (when libtorch and libmetatomic could not be found).
    1042          12 : void MetatomicPlumedAction::registerKeywords(Keywords& keys) {
    1043          12 :     Action::registerKeywords(keys);
    1044          12 :     ActionAtomistic::registerKeywords(keys);
    1045          12 :     ActionWithValue::registerKeywords(keys);
    1046             : 
    1047          24 :     keys.add("compulsory", "MODEL", "path to the exported metatomic model");
    1048          24 :     keys.add("optional", "EXTENSIONS_DIRECTORY", "path to the directory containing TorchScript extensions to load");
    1049          24 :     keys.add("optional", "DEVICE", "Torch device to use for the calculations");
    1050             : 
    1051          24 :     keys.addFlag("CHECK_CONSISTENCY", false, "enable internal consistency checks when executing the model");
    1052             : 
    1053          24 :     keys.add("numbered", "SPECIES", "the indices of atoms in each PLUMED species");
    1054          24 :     keys.reset_style("SPECIES", "atoms");
    1055             : 
    1056          24 :     keys.add("optional", "SELECTED_ATOMS", "subset of atoms that should be used for the calculation");
    1057          24 :     keys.reset_style("SELECTED_ATOMS", "atoms");
    1058             : 
    1059          24 :     keys.add("optional", "SPECIES_TO_TYPES", "mapping from PLUMED SPECIES to metatomic's atom types");
    1060             : 
    1061          24 :     keys.add("optional", "VARIANT", "which variant of the 'features' output to pick");
    1062             : 
    1063          24 :     keys.addOutputComponent("outputs", "default", "collective variable created by the metatomic model");
    1064             : 
    1065          12 :     keys.setValueDescription("collective variable created by the metatomic model");
    1066          12 : }
    1067             : 
    1068             : PLUMED_REGISTER_ACTION(MetatomicPlumedAction, "METATOMIC")
    1069             : 
    1070             : } // namespace metatomic
    1071             : } // namespace PLMD

Generated by: LCOV version 1.16