Line data Source code
/* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Copyright (c) 2022-2023 of Luigi Bonati.

The pytorch module is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

The pytorch module is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with plumed. If not, see <http://www.gnu.org/licenses/>.
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */

#ifdef __PLUMED_HAS_LIBTORCH

#include "core/PlumedMain.h"
#include "function/Function.h"
#include "function/ActionRegister.h"

#include <torch/torch.h>
#include <torch/script.h>

#include <fstream>
#include <cmath>
#include <string>
#include <vector>

namespace PLMD {
namespace function {
namespace pytorch {

//+PLUMEDOC PYTORCH_FUNCTION PYTORCH_MODEL
/*
Load a PyTorch model compiled with TorchScript.

This can be a function defined in Python or a more complex model, such as a neural network optimized on a set of data. In both cases the derivatives of the outputs with respect to the inputs are computed using the automatic differentiation (autograd) feature of PyTorch.

By default it is assumed that the model is saved as `model.ptc`, unless otherwise indicated by the `FILE` keyword. The function automatically checks the number of output dimensions and creates a component for each of them. The outputs are called node-i with i between 0 and N-1 for N outputs.

The model is loaded onto the CPU and its inputs are passed as single-precision (float32) tensors of shape (1, number of ARG).

Note that this function is active only if LibTorch is correctly linked against PLUMED. Please check the instructions in the \ref PYTORCH page.

\par Examples
Load a model called `torch_model.ptc` that takes as input two dihedral angles and returns two outputs.

\plumedfile
#SETTINGS AUXFILE=regtest/pytorch/rt-pytorch_model_2d/torch_model.ptc
phi: TORSION ATOMS=5,7,9,15
psi: TORSION ATOMS=7,9,15,17
model: PYTORCH_MODEL FILE=torch_model.ptc ARG=phi,psi
PRINT FILE=COLVAR ARG=model.node-0,model.node-1
\endplumedfile

*/
//+ENDPLUMEDOC


class PytorchModel :
  public Function {
  // number of inputs of the model (= number of ARG)
  unsigned _n_in;
  // number of outputs of the model (= number of node-i components)
  unsigned _n_out;
  // TorchScript module deserialized from FILE
  torch::jit::script::Module _model;

public:
  explicit PytorchModel(const ActionOptions&);
  void calculate() override;
  static void registerKeywords(Keywords& keys);

  /// Copy the values of a tensor into a flat std::vector<float>, in row-major order.
  /// The tensor may require grad, be non-contiguous, or have a non-float32 dtype.
  std::vector<float> tensor_to_vector(const torch::Tensor& x);
};

PLUMED_REGISTER_ACTION(PytorchModel,"PYTORCH_MODEL")

void PytorchModel::registerKeywords(Keywords& keys) {
  Function::registerKeywords(keys);
  keys.use("ARG");
  keys.add("optional","FILE","Filename of the PyTorch compiled model");
  keys.addOutputComponent("node", "default", "Model outputs");
}

// Auxiliary function to transform torch tensors in std vectors.
// The raw-pointer copy below is only valid for a dense, row-major float buffer:
// detach() drops the autograd history (only values are needed), cpu()/to(kFloat32)
// guarantee data_ptr<float>() is legal, and contiguous() fixes the memory layout
// (e.g. for transposed or sliced model outputs).
std::vector<float> PytorchModel::tensor_to_vector(const torch::Tensor& x) {
  const torch::Tensor flat = x.detach().cpu().to(torch::kFloat32).contiguous();
  const float* begin = flat.data_ptr<float>();
  return std::vector<float>(begin, begin + flat.numel());
}

PytorchModel::PytorchModel(const ActionOptions&ao):
  Action(ao),
  Function(ao) {

  // number of inputs of the model
  _n_in=getNumberOfArguments();

  // parse model name
  std::string fname="model.ptc";
  parse("FILE",fname);

  // deserialize the model from file, mapping all its tensors to the CPU:
  // the input tensors built below live on the CPU, so a model saved from a GPU
  // would otherwise fail at forward() with a device mismatch
  try {
    _model = torch::jit::load(fname, torch::Device(torch::kCPU));
  }
  // if an error is thrown check if the file exists or not, to report a helpful message
  catch (const c10::Error& e) {
    std::ifstream infile(fname);
    const bool exist = infile.good();
    infile.close();
    if (exist) {
      // libtorch version, reported to help diagnose PyTorch/LibTorch mismatches
      const std::string version = std::to_string(TORCH_VERSION_MAJOR) + "." +
                                  std::to_string(TORCH_VERSION_MINOR) + "." +
                                  std::to_string(TORCH_VERSION_PATCH);
      plumed_merror("Cannot load FILE: '"+fname+"'. Please check that it is a Pytorch compiled model (exported with 'torch.jit.trace' or 'torch.jit.script') and that the Pytorch version matches the LibTorch one ("+version+").");
    } else {
      plumed_merror("The FILE: '"+fname+"' does not exist.");
    }
  }

  checkRead();

  // check the dimension of the output by evaluating the model once on an all-zero input
  log.printf("Checking output dimension:\n");
  std::vector<float> input_test (_n_in);
  torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in});
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back( single_input );
  torch::Tensor output;
  try {
    // no derivatives are needed for this probe: do not record the autograd graph
    torch::NoGradGuard no_grad;
    output = _model.forward( inputs ).toTensor();
  } catch (const std::exception& e) {
    // covers c10::Error (shape mismatch, non-tensor output) and TorchScript runtime errors
    plumed_merror("Cannot evaluate the model in FILE: '"+fname+"' on an input of shape (1,"+std::to_string(_n_in)+"). Please check that the number of ARG matches the number of inputs expected by the model and that the model returns a single tensor. Error: "+std::string(e.what()));
  }
  // every element of the output tensor becomes one component
  _n_out = static_cast<unsigned>(output.numel());

  // create components
  for(unsigned j=0; j<_n_out; j++) {
    const std::string name_comp = "node-"+std::to_string(j);
    addComponentWithDerivatives( name_comp );
    componentIsNotPeriodic( name_comp );
  }

  // print log
  log.printf("Number of input: %u \n",_n_in);
  log.printf("Number of outputs: %u \n",_n_out);
  log.printf(" Bibliography: ");
  log<<plumed.cite("Bonati, Rizzi and Parrinello, J. Phys. Chem. Lett. 11, 2998-3004 (2020)");
  log.printf("\n");

}

void PytorchModel::calculate() {

  // retrieve arguments (PLUMED provides doubles; the model is fed float32, as in the constructor)
  std::vector<float> current_S(_n_in);
  for(unsigned i=0; i<_n_in; i++) {
    current_S[i]=static_cast<float>(getArgument(i));
  }
  // convert to a (1,_n_in) tensor tracked by autograd, so that the derivatives
  // of the outputs with respect to the arguments can be computed below
  torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in});
  input_S.set_requires_grad(true);
  // convert to IValue
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back( input_S );
  // calculate output; reshaped to (1,N) so that slicing along dim 1 is consistent
  // with the constructor, which counted the outputs as output.numel()
  torch::Tensor output = _model.forward( inputs ).toTensor().reshape({1,-1});
  const std::vector<float> cvs = tensor_to_vector(output);
  // guard against reading past cvs if the model output size differs from the probe
  plumed_massert(cvs.size()==_n_out, "the number of outputs of the model changed from "+std::to_string(_n_out)+" to "+std::to_string(cvs.size()));

  // seed for the vector-Jacobian product of a single (1,1) output slice
  const torch::Tensor grad_output = torch::ones({1,1});
  for(unsigned j=0; j<_n_out; j++) {
    // look the component up once and reuse it for the value and all its derivatives
    Value* comp = getPntrToComponent("node-"+std::to_string(j));
    // set CV value
    comp->set(cvs[j]);
    // d(output_j)/d(input) via automatic differentiation; retain_graph=true keeps
    // the graph alive for the next output
    auto gradient = torch::autograd::grad({output.slice(/*dim=*/1, /*start=*/j, /*end=*/j+1)},
                                          {input_S},
                                          /*grad_outputs=*/ {grad_output},
                                          /*retain_graph=*/true,
                                          /*create_graph=*/false);
    // gradient[0] has the shape of input_S, i.e. (1,_n_in)
    const std::vector<float> der = tensor_to_vector(gradient[0]);
    // set derivatives of component j
    for(unsigned i=0; i<_n_in; i++) {
      setDerivative( comp, i, der[i] );
    }
  }
}
}
}
}

#endif