LCOV - code coverage report
Current view: top level - pytorch - PytorchModel.cpp (source / functions) Hit Total Coverage
Test: plumed test coverage Lines: 54 65 83.1 %
Date: 2026-03-30 13:16:06 Functions: 7 8 87.5 %

          Line data    Source code
       1             : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
       2             : Copyright (c) 2022-2023 of Luigi Bonati.
       3             : 
       4             : The pytorch module is free software: you can redistribute it and/or modify
       5             : it under the terms of the GNU Lesser General Public License as published by
       6             : the Free Software Foundation, either version 3 of the License, or
       7             : (at your option) any later version.
       8             : 
       9             : The pytorch module is distributed in the hope that it will be useful,
      10             : but WITHOUT ANY WARRANTY; without even the implied warranty of
      11             : MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
      12             : GNU Lesser General Public License for more details.
      13             : 
      14             : You should have received a copy of the GNU Lesser General Public License
      15             : along with plumed.  If not, see <http://www.gnu.org/licenses/>.
      16             : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
      17             : 
      18             : #ifdef __PLUMED_HAS_LIBTORCH
      19             : // convert LibTorch version to string
      20             : //#define STRINGIFY(x) #x
      21             : //#define TOSTR(x) STRINGIFY(x)
      22             : //#define LIBTORCH_VERSION TO_STR(TORCH_VERSION_MAJOR) "." TO_STR(TORCH_VERSION_MINOR) "." TO_STR(TORCH_VERSION_PATCH)
      23             : 
      24             : #include "core/PlumedMain.h"
      25             : #include "function/Function.h"
      26             : #include "function/ActionRegister.h"
      27             : 
      28             : #include <torch/torch.h>
      29             : #include <torch/script.h>
      30             : 
      31             : #include <fstream>
      32             : #include <cmath>
      33             : 
      34             : using namespace std;
      35             : 
      36             : namespace PLMD {
      37             : namespace function {
      38             : namespace pytorch {
      39             : 
      40             : //+PLUMEDOC PYTORCH_FUNCTION PYTORCH_MODEL
      41             : /*
      42             : Load a PyTorch model compiled with TorchScript.
      43             : 
      44             : This can be a function defined in Python or a more complex model, such as a neural network optimized on a set of data. In both cases the derivatives of the outputs with respect to the inputs are computed using the automatic differentiation (autograd) feature of Pytorch.
      45             : 
      46             : By default it is assumed that the model is saved as: `model.ptc`, unless otherwise indicated by the `FILE` keyword. The function automatically checks for the number of output dimensions and creates a component for each of them. The outputs are called node-i with i between 0 and N-1 for N outputs.
      47             : 
      48             : Note that this function is active only if LibTorch is correctly linked against PLUMED. Please check the instructions in the \ref PYTORCH page.
      49             : 
      50             : \par Examples
      51             : Load a model called `torch_model.ptc` that takes as input two dihedral angles and returns two outputs.
      52             : 
      53             : \plumedfile
      54             : #SETTINGS AUXFILE=regtest/pytorch/rt-pytorch_model_2d/torch_model.ptc
      55             : phi: TORSION ATOMS=5,7,9,15
      56             : psi: TORSION ATOMS=7,9,15,17
      57             : model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi,psi
      58             : PRINT FILE=COLVAR ARG=model.node-0,model.node-1
      59             : \endplumedfile
      60             : 
      61             : */
      62             : //+ENDPLUMEDOC
      63             : 
      64             : 
      65             : class PytorchModel :
      66             :   public Function {
      67             :   unsigned _n_in;
      68             :   unsigned _n_out;
      69             :   torch::jit::script::Module _model;
      70             : 
      71             : public:
      72             :   explicit PytorchModel(const ActionOptions&);
      73             :   void calculate();
      74             :   static void registerKeywords(Keywords& keys);
      75             : 
      76             :   std::vector<float> tensor_to_vector(const torch::Tensor& x);
      77             : };
      78             : 
      79       13793 : PLUMED_REGISTER_ACTION(PytorchModel,"PYTORCH_MODEL")
      80             : 
// Register the keywords accepted by PYTORCH_MODEL.
// The generic Function keywords are registered first; keys.use("ARG") then
// enables ARG (presumably a keyword reserved by Function::registerKeywords --
// confirm in function/Function.h before reordering these calls).
void PytorchModel::registerKeywords(Keywords& keys) {
  Function::registerKeywords(keys);
  keys.use("ARG");
  // Optional: when FILE is omitted the constructor falls back to "model.ptc".
  keys.add("optional","FILE","Filename of the PyTorch compiled model");
  // The constructor creates one component per model output, named node-0 ... node-(N-1).
  keys.addOutputComponent("node", "default", "Model outputs");
}
      87             : 
      88             : // Auxiliary function to transform torch tensors in std vectors
      89         103 : std::vector<float> PytorchModel::tensor_to_vector(const torch::Tensor& x) {
      90         206 :   return std::vector<float>(x.data_ptr<float>(), x.data_ptr<float>() + x.numel());
      91             : }
      92             : 
      93           4 : PytorchModel::PytorchModel(const ActionOptions&ao):
      94             :   Action(ao),
      95           4 :   Function(ao) {
      96             :   //print pytorch version
      97             : 
      98             :   //number of inputs of the model
      99           4 :   _n_in=getNumberOfArguments();
     100             : 
     101             :   //parse model name
     102           4 :   std::string fname="model.ptc";
     103           8 :   parse("FILE",fname);
     104             : 
     105             :   //deserialize the model from file
     106             :   try {
     107           4 :     _model = torch::jit::load(fname);
     108             :   }
     109             :   //if an error is thrown check if the file exists or not
     110           0 :   catch (const c10::Error& e) {
     111           0 :     std::ifstream infile(fname);
     112             :     bool exist = infile.good();
     113           0 :     infile.close();
     114           0 :     if (exist) {
     115             :       // print libtorch version
     116           0 :       std::stringstream ss;
     117           0 :       ss << TORCH_VERSION_MAJOR << "." << TORCH_VERSION_MINOR << "." << TORCH_VERSION_PATCH;
     118             :       std::string version;
     119           0 :       ss >> version; // extract into the string.
     120           0 :       plumed_merror("Cannot load FILE: '"+fname+"'. Please check that it is a Pytorch compiled model (exported with 'torch.jit.trace' or 'torch.jit.script') and that the Pytorch version matches the LibTorch one ("+version+").");
     121           0 :     } else {
     122           0 :       plumed_merror("The FILE: '"+fname+"' does not exist.");
     123             :     }
     124           0 :   }
     125             : 
     126           4 :   checkRead();
     127             : 
     128             :   //check the dimension of the output
     129           4 :   log.printf("Checking output dimension:\n");
     130           4 :   std::vector<float> input_test (_n_in);
     131           4 :   torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in});
     132             :   std::vector<torch::jit::IValue> inputs;
     133           4 :   inputs.push_back( single_input );
     134           8 :   torch::Tensor output = _model.forward( inputs ).toTensor();
     135           4 :   vector<float> cvs = this->tensor_to_vector (output);
     136           4 :   _n_out=cvs.size();
     137             : 
     138             :   //create components
     139           9 :   for(unsigned j=0; j<_n_out; j++) {
     140           5 :     string name_comp = "node-"+std::to_string(j);
     141           5 :     addComponentWithDerivatives( name_comp );
     142           5 :     componentIsNotPeriodic( name_comp );
     143             :   }
     144             : 
     145             :   //print log
     146             :   //log.printf("Pytorch Model Loaded: %s \n",fname);
     147           4 :   log.printf("Number of input: %d \n",_n_in);
     148           4 :   log.printf("Number of outputs: %d \n",_n_out);
     149           4 :   log.printf("  Bibliography: ");
     150           8 :   log<<plumed.cite("Bonati, Rizzi and Parrinello, J. Phys. Chem. Lett. 11, 2998-3004 (2020)");
     151           4 :   log.printf("\n");
     152             : 
     153           8 : }
     154             : 
     155          44 : void PytorchModel::calculate() {
     156             : 
     157             :   //retrieve arguments
     158          44 :   vector<float> current_S(_n_in);
     159          99 :   for(unsigned i=0; i<_n_in; i++) {
     160          55 :     current_S[i]=getArgument(i);
     161             :   }
     162             :   //convert to tensor
     163          44 :   torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in});
     164             :   input_S.set_requires_grad(true);
     165             :   //convert to Ivalue
     166             :   std::vector<torch::jit::IValue> inputs;
     167          44 :   inputs.push_back( input_S );
     168             :   //calculate output
     169          88 :   torch::Tensor output = _model.forward( inputs ).toTensor();
     170             :   //set CV values
     171          44 :   vector<float> cvs = this->tensor_to_vector (output);
     172          99 :   for(unsigned j=0; j<_n_out; j++) {
     173          55 :     string name_comp = "node-"+std::to_string(j);
     174          55 :     getPntrToComponent(name_comp)->set(cvs[j]);
     175             :   }
     176             :   //derivatives
     177          99 :   for(unsigned j=0; j<_n_out; j++) {
     178             :     // expand dim to have shape (1,_n_out)
     179             :     int batch_size = 1;
     180          55 :     auto grad_output = torch::ones({1}).expand({batch_size, 1});
     181             :     // calculate derivatives with automatic differentiation
     182         220 :     auto gradient = torch::autograd::grad({output.slice(/*dim=*/1, /*start=*/j, /*end=*/j+1)},
     183             :     {input_S},
     184             :     /*grad_outputs=*/ {grad_output},
     185             :     /*retain_graph=*/true,
     186         330 :     /*create_graph=*/false);
     187             :     // add dimension
     188             :     auto grad = gradient[0].unsqueeze(/*dim=*/1);
     189             :     //convert to vector
     190          55 :     vector<float> der = this->tensor_to_vector ( grad );
     191             : 
     192          55 :     string name_comp = "node-"+std::to_string(j);
     193             :     //set derivatives of component j
     194         132 :     for(unsigned i=0; i<_n_in; i++) {
     195          77 :       setDerivative( getPntrToComponent(name_comp),i, der[i] );
     196             :     }
     197          55 :   }
     198          88 : }
     199             : }
     200             : }
     201             : }
     202             : 
     203             : #endif

Generated by: LCOV version 1.16