LCOV - code coverage report
Current view: top level - pytorch - PytorchModel.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 62 69 89.9 %
Date: 2026-03-30 11:13:23 Functions: 4 5 80.0 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2022-2023 of Luigi Bonati and Enrico Trizio.
       3             : 
       4             : The pytorch module is free software: you can redistribute it and/or modify
       5             : it under the terms of the GNU Lesser General Public License as published by
       6             : the Free Software Foundation, either version 3 of the License, or
       7             : (at your option) any later version.
       8             : 
       9             : The pytorch module is distributed in the hope that it will be useful,
      10             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      11             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      12             : GNU Lesser General Public License for more details.
      13             : 
      14             : You should have received a copy of the GNU Lesser General Public License
      15             : along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      16             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      17             : 
      18             : #ifdef __PLUMED_HAS_LIBTORCH
      19             : // convert LibTorch version to string
//#define STRINGIFY(x) #x
//#define TOSTR(x) STRINGIFY(x)
//#define LIBTORCH_VERSION TOSTR(TORCH_VERSION_MAJOR) "." TOSTR(TORCH_VERSION_MINOR) "." TOSTR(TORCH_VERSION_PATCH)
      23             : 
      24             : #include "core/PlumedMain.h"
      25             : #include "function/Function.h"
      26             : #include "core/ActionRegister.h"
      27             : 
      28             : #include <torch/torch.h>
      29             : #include <torch/script.h>
      30             : 
      31             : #include <fstream>
      32             : #include <cmath>
      33             : 
      34             : // Note: Freezing a ScriptModule (torch::jit::freeze) works only in >=1.11
      35             : // For 1.8 <= versions <=1.10 we need a hack
      36             : // (see https://discuss.pytorch.org/t/how-to-check-libtorch-version/77709/4 and also
      37             : // https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp#L479)
      38             : // adapted from NequIP https://github.com/mir-group/nequip
      39             : #if ( TORCH_VERSION_MAJOR == 2 || TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10 )
      40             : #define DO_TORCH_FREEZE_HACK
      41             : // For the hack, need more headers:
      42             : #include <torch/csrc/jit/passes/freeze_module.h>
      43             : #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
      44             : #endif
      45             : 
      46             : using namespace std;
      47             : 
      48             : namespace PLMD {
      49             : namespace function {
      50             : namespace pytorch {
      51             : 
      52             : //+PLUMEDOC PYTORCH_FUNCTION PYTORCH_MODEL
      53             : /*
      54             : Load a PyTorch model compiled with TorchScript.
      55             : 
      56             : This can be a function defined in Python or a more complex model, such as a neural network optimized on a set of data. In both cases the derivatives of the outputs with respect to the inputs are computed using the automatic differentiation (autograd) feature of Pytorch.
      57             : 
      58             : By default it is assumed that the model is saved as: `model.ptc`, unless otherwise indicated by the `FILE` keyword. The function automatically checks for the number of output dimensions and creates a component for each of them. The outputs are called node-i with i between 0 and N-1 for N outputs.
      59             : 
Note that this function requires the \ref installation-libtorch LibTorch C++ library. Check the instructions in the \ref PYTORCH page to enable the module.
      61             : 
      62             : \par Examples
      63             : Load a model called `torch_model.ptc` that takes as input two dihedral angles and returns two outputs.
      64             : 
      65             : \plumedfile
      66             : #SETTINGS AUXFILE=regtest/pytorch/rt-pytorch_model_2d/torch_model.ptc
      67             : phi: TORSION ATOMS=5,7,9,15
      68             : psi: TORSION ATOMS=7,9,15,17
      69             : model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi,psi
      70             : PRINT FILE=COLVAR ARG=model.node-0,model.node-1
      71             : \endplumedfile
      72             : 
      73             : */
      74             : //+ENDPLUMEDOC
      75             : 
      76             : 
// Function action that wraps a TorchScript model: it evaluates the model on
// the PLUMED arguments and exposes each model output as a "node-i" component,
// with derivatives obtained via Torch autograd.
class PytorchModel :
  public Function {
  unsigned _n_in;   // number of model inputs (= number of ARG arguments)
  unsigned _n_out;  // number of model outputs (detected in the constructor)
  torch::jit::script::Module _model;   // deserialized (and frozen) TorchScript module
  torch::Device device = torch::kCPU;  // model and tensors are kept on the CPU

public:
  explicit PytorchModel(const ActionOptions&);
  // Computes the value and the input derivatives of every component.
  void calculate();
  static void registerKeywords(Keywords& keys);

  // Copies the data of a float tensor into a std::vector<float>.
  std::vector<float> tensor_to_vector(const torch::Tensor& x);
};
      91             : 
      92             : PLUMED_REGISTER_ACTION(PytorchModel,"PYTORCH_MODEL")
      93             : 
// Declares the keywords accepted by PYTORCH_MODEL (ARG plus the optional FILE
// with the serialized model) and the "node" output components.
void PytorchModel::registerKeywords(Keywords& keys) {
  Function::registerKeywords(keys);
  // the ARG values are the inputs of the model
  keys.use("ARG");
  keys.add("optional","FILE","Filename of the PyTorch compiled model");
  // one "node-i" component is created per model output (see constructor)
  keys.addOutputComponent("node", "default", "Model outputs");
}
     100             : 
     101             : // Auxiliary function to transform torch tensors in std vectors
     102         103 : std::vector<float> PytorchModel::tensor_to_vector(const torch::Tensor& x) {
     103         206 :   return std::vector<float>(x.data_ptr<float>(), x.data_ptr<float>() + x.numel());
     104             : }
     105             : 
     106           4 : PytorchModel::PytorchModel(const ActionOptions&ao):
     107             :   Action(ao),
     108           4 :   Function(ao) {
     109             :   // print libtorch version
     110           4 :   std::stringstream ss;
     111           4 :   ss << TORCH_VERSION_MAJOR << "." << TORCH_VERSION_MINOR << "." << TORCH_VERSION_PATCH;
     112             :   std::string version;
     113           4 :   ss >> version; // extract into the string.
     114           8 :   log.printf(("  LibTorch version: "+version+"\n").data());
     115             : 
     116             :   //number of inputs of the model
     117           4 :   _n_in=getNumberOfArguments();
     118             : 
     119             :   //parse model name
     120           4 :   std::string fname="model.ptc";
     121           8 :   parse("FILE",fname);
     122             : 
     123             :   //deserialize the model from file
     124             :   try {
     125           4 :     _model = torch::jit::load(fname, device);
     126             :   }
     127             : 
     128             :   //if an error is thrown check if the file exists or not
     129           0 :   catch (const c10::Error& e) {
     130           0 :     std::ifstream infile(fname);
     131             :     bool exist = infile.good();
     132           0 :     infile.close();
     133           0 :     if (exist) {
     134           0 :       plumed_merror("Cannot load FILE: '"+fname+"'. Please check that it is a Pytorch compiled model (exported with 'torch.jit.trace' or 'torch.jit.script').");
     135             :     } else {
     136           0 :       plumed_merror("The FILE: '"+fname+"' does not exist.");
     137             :     }
     138           0 :   }
     139           4 :   checkRead();
     140             : 
     141             : // Optimize model
     142             :   _model.eval();
     143             : #ifdef DO_TORCH_FREEZE_HACK
     144             :   // Do the hack
     145             :   // Copied from the implementation of torch::jit::freeze,
     146             :   // except without the broken check
     147             :   // See https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp
     148             :   bool optimize_numerics = true;  // the default
     149             :   // the {} is preserved_attrs
     150             :   auto out_mod = torch::jit::freeze_module(
     151             :                    _model, {}
     152           4 :                  );
     153             :   // See 1.11 bugfix in https://github.com/pytorch/pytorch/pull/71436
     154           8 :   auto graph = out_mod.get_method("forward").graph();
     155           4 :   OptimizeFrozenGraph(graph, optimize_numerics);
     156           4 :   _model = out_mod;
     157             : #else
     158             :   // Do it normally
     159             :   _model = torch::jit::freeze(_model);
     160             : #endif
     161             : 
     162             : // Optimize model for inference
     163             : #if (TORCH_VERSION_MAJOR == 2 || TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 10)
     164           4 :   _model = torch::jit::optimize_for_inference(_model);
     165             : #endif
     166             : 
     167             :   //check the dimension of the output
     168           4 :   log.printf("  Checking output dimension:\n");
     169           4 :   std::vector<float> input_test (_n_in);
     170           4 :   torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in});
     171           8 :   single_input = single_input.to(device);
     172             :   std::vector<torch::jit::IValue> inputs;
     173           4 :   inputs.push_back( single_input );
     174           8 :   torch::Tensor output = _model.forward( inputs ).toTensor();
     175           4 :   vector<float> cvs = this->tensor_to_vector (output);
     176           4 :   _n_out=cvs.size();
     177             : 
     178             :   //create components
     179           9 :   for(unsigned j=0; j<_n_out; j++) {
     180           5 :     string name_comp = "node-"+std::to_string(j);
     181           5 :     addComponentWithDerivatives( name_comp );
     182           5 :     componentIsNotPeriodic( name_comp );
     183             :   }
     184             : 
     185             :   //print log
     186           4 :   log.printf("  Number of input: %d \n",_n_in);
     187           4 :   log.printf("  Number of outputs: %d \n",_n_out);
     188           4 :   log.printf("  Bibliography: ");
     189           8 :   log<<plumed.cite("Bonati, Trizio, Rizzi and Parrinello, J. Chem. Phys. 159, 014801 (2023)");
     190           8 :   log<<plumed.cite("Bonati, Rizzi and Parrinello, J. Phys. Chem. Lett. 11, 2998-3004 (2020)");
     191           4 :   log.printf("\n");
     192             : 
     193          12 : }
     194             : 
     195             : 
     196          44 : void PytorchModel::calculate() {
     197             : 
     198             :   // retrieve arguments
     199          44 :   vector<float> current_S(_n_in);
     200          99 :   for(unsigned i=0; i<_n_in; i++) {
     201          55 :     current_S[i]=getArgument(i);
     202             :   }
     203             :   //convert to tensor
     204          44 :   torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in}).to(device);
     205             :   input_S.set_requires_grad(true);
     206             :   //convert to Ivalue
     207             :   std::vector<torch::jit::IValue> inputs;
     208          44 :   inputs.push_back( input_S );
     209             :   //calculate output
     210          88 :   torch::Tensor output = _model.forward( inputs ).toTensor();
     211             : 
     212             : 
     213          99 :   for(unsigned j=0; j<_n_out; j++) {
     214          55 :     auto grad_output = torch::ones({1}).expand({1, 1}).to(device);
     215         440 :     auto gradient = torch::autograd::grad({output.slice(/*dim=*/1, /*start=*/j, /*end=*/j+1)},
     216             :     {input_S},
     217             :     /*grad_outputs=*/ {grad_output},
     218             :     /*retain_graph=*/true,
     219             :     /*create_graph=*/false)[0]; // the [0] is to get a tensor and not a vector<at::tensor>
     220             : 
     221          55 :     vector<float> der = this->tensor_to_vector ( gradient );
     222          55 :     string name_comp = "node-"+std::to_string(j);
     223             :     //set derivatives of component j
     224         132 :     for(unsigned i=0; i<_n_in; i++) {
     225          77 :       setDerivative( getPntrToComponent(name_comp),i, der[i] );
     226             :     }
     227             :   }
     228             : 
     229             :   //set CV values
     230          44 :   vector<float> cvs = this->tensor_to_vector (output);
     231          99 :   for(unsigned j=0; j<_n_out; j++) {
     232          55 :     string name_comp = "node-"+std::to_string(j);
     233          55 :     getPntrToComponent(name_comp)->set(cvs[j]);
     234             :   }
     235             : 
     236          88 : }
     237             : 
     238             : 
     239             : } //PLMD
     240             : } //function
     241             : } //pytorch
     242             : 
     243             : #endif //PLUMED_HAS_LIBTORCH

Generated by: LCOV version 1.16