Line data Source code
1 : /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 : * -------------------------------------------------------------------------- *
3 : * Lepton *
4 : * -------------------------------------------------------------------------- *
5 : * This is part of the Lepton expression parser originating from *
6 : * Simbios, the NIH National Center for Physics-Based Simulation of *
7 : * Biological Structures at Stanford, funded under the NIH Roadmap for *
8 : * Medical Research, grant U54 GM072970. See https://simtk.org. *
9 : * *
10 : * Portions copyright (c) 2013-2016 Stanford University and the Authors. *
11 : * Authors: Peter Eastman *
12 : * Contributors: *
13 : * *
14 : * Permission is hereby granted, free of charge, to any person obtaining a *
15 : * copy of this software and associated documentation files (the "Software"), *
16 : * to deal in the Software without restriction, including without limitation *
17 : * the rights to use, copy, modify, merge, publish, distribute, sublicense, *
18 : * and/or sell copies of the Software, and to permit persons to whom the *
19 : * Software is furnished to do so, subject to the following conditions: *
20 : * *
21 : * The above copyright notice and this permission notice shall be included in *
22 : * all copies or substantial portions of the Software. *
23 : * *
24 : * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
25 : * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
26 : * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
27 : * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
28 : * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
29 : * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
30 : * USE OR OTHER DEALINGS IN THE SOFTWARE. *
31 : * -------------------------------------------------------------------------- *
32 : +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ */
33 : /* -------------------------------------------------------------------------- *
34 : * lepton *
35 : * -------------------------------------------------------------------------- *
36 : * This is part of the lepton expression parser originating from *
37 : * Simbios, the NIH National Center for Physics-Based Simulation of *
38 : * Biological Structures at Stanford, funded under the NIH Roadmap for *
39 : * Medical Research, grant U54 GM072970. See https://simtk.org. *
40 : * *
41 : * Portions copyright (c) 2013-2016 Stanford University and the Authors. *
42 : * Authors: Peter Eastman *
43 : * Contributors: *
44 : * *
45 : * Permission is hereby granted, free of charge, to any person obtaining a *
46 : * copy of this software and associated documentation files (the "Software"), *
47 : * to deal in the Software without restriction, including without limitation *
48 : * the rights to use, copy, modify, merge, publish, distribute, sublicense, *
49 : * and/or sell copies of the Software, and to permit persons to whom the *
50 : * Software is furnished to do so, subject to the following conditions: *
51 : * *
52 : * The above copyright notice and this permission notice shall be included in *
53 : * all copies or substantial portions of the Software. *
54 : * *
55 : * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
56 : * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
57 : * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
58 : * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
59 : * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
60 : * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
61 : * USE OR OTHER DEALINGS IN THE SOFTWARE. *
62 : * -------------------------------------------------------------------------- */
63 :
64 : #include "CompiledExpression.h"
65 : #include "Operation.h"
66 : #include "ParsedExpression.h"
67 : #ifdef __PLUMED_HAS_ASMJIT
68 : #include "asmjit/asmjit.h"
69 : #endif
70 : #include <utility>
71 :
72 : namespace PLMD {
73 : using namespace lepton;
74 : using namespace std;
75 : #ifdef __PLUMED_HAS_ASMJIT
76 : using namespace asmjit;
77 : #endif
78 :
79 1664 : AsmJitRuntimePtr::AsmJitRuntimePtr()
80 : #ifdef __PLUMED_HAS_ASMJIT
81 : : ptr(new asmjit::JitRuntime)
82 : #endif
83 1664 : {}
84 :
85 1664 : AsmJitRuntimePtr::~AsmJitRuntimePtr()
86 : {
87 : #ifdef __PLUMED_HAS_ASMJIT
88 : delete static_cast<asmjit::JitRuntime*>(ptr);
89 : #endif
90 1664 : }
91 :
92 3384 : CompiledExpression::CompiledExpression() : jitCode(NULL) {
93 846 : }
94 :
95 4090 : CompiledExpression::CompiledExpression(const ParsedExpression& expression) : jitCode(NULL) {
96 818 : ParsedExpression expr = expression.optimize(); // Just in case it wasn't already optimized.
97 818 : vector<pair<ExpressionTreeNode, int> > temps;
98 818 : compileExpression(expr.getRootNode(), temps);
99 : int maxArguments = 1;
100 10740 : for (int i = 0; i < (int) operation.size(); i++)
101 9922 : if (operation[i]->getNumArguments() > maxArguments)
102 538 : maxArguments = operation[i]->getNumArguments();
103 818 : argValues.resize(maxArguments);
104 : #ifdef __PLUMED_HAS_ASMJIT
105 : generateJitCode();
106 : #endif
107 818 : }
108 :
109 6656 : CompiledExpression::~CompiledExpression() {
110 21508 : for (int i = 0; i < (int) operation.size(); i++)
111 19844 : if (operation[i] != NULL)
112 9922 : delete operation[i];
113 1664 : }
114 :
115 0 : CompiledExpression::CompiledExpression(const CompiledExpression& expression) : jitCode(NULL) {
116 0 : *this = expression;
117 0 : }
118 :
119 818 : CompiledExpression& CompiledExpression::operator=(const CompiledExpression& expression) {
120 818 : arguments = expression.arguments;
121 818 : target = expression.target;
122 : variableIndices = expression.variableIndices;
123 : variableNames = expression.variableNames;
124 1636 : workspace.resize(expression.workspace.size());
125 1636 : argValues.resize(expression.argValues.size());
126 1636 : operation.resize(expression.operation.size());
127 10740 : for (int i = 0; i < (int) operation.size(); i++)
128 14883 : operation[i] = expression.operation[i]->clone();
129 818 : setVariableLocations(variablePointers);
130 818 : return *this;
131 : }
132 :
133 7544 : void CompiledExpression::compileExpression(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps) {
134 7544 : if (findTempIndex(node, temps) != -1)
135 1663 : return; // We have already processed a node identical to this one.
136 :
137 : // Process the child nodes.
138 :
139 : vector<int> args;
140 31940 : for (int i = 0; i < node.getChildren().size(); i++) {
141 13452 : compileExpression(node.getChildren()[i], temps);
142 20178 : args.push_back(findTempIndex(node.getChildren()[i], temps));
143 : }
144 :
145 : // Process this node.
146 :
147 5881 : if (node.getOperation().getId() == Operation::VARIABLE) {
148 1840 : variableIndices[node.getOperation().getName()] = (int) workspace.size();
149 1840 : variableNames.insert(node.getOperation().getName());
150 : }
151 : else {
152 4961 : int stepIndex = (int) arguments.size();
153 9922 : arguments.push_back(vector<int>());
154 14883 : target.push_back((int) workspace.size());
155 9922 : operation.push_back(node.getOperation().clone());
156 4961 : if (args.size() == 0)
157 231 : arguments[stepIndex].push_back(0); // The value won't actually be used. We just need something there.
158 : else {
159 : // If the arguments are sequential, we can just pass a pointer to the first one.
160 :
161 : bool sequential = true;
162 8568 : for (int i = 1; i < args.size(); i++)
163 3684 : if (args[i] != args[i-1]+1)
164 : sequential = false;
165 4884 : if (sequential)
166 7590 : arguments[stepIndex].push_back(args[0]);
167 : else
168 2178 : arguments[stepIndex] = args;
169 : }
170 : }
171 11762 : temps.push_back(make_pair(node, (int) workspace.size()));
172 11762 : workspace.push_back(0.0);
173 : }
174 :
175 14270 : int CompiledExpression::findTempIndex(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, int> >& temps) {
176 95716 : for (int i = 0; i < (int) temps.size(); i++)
177 98224 : if (temps[i].first == node)
178 : return i;
179 : return -1;
180 : }
181 :
182 347 : const set<string>& CompiledExpression::getVariables() const {
183 347 : return variableNames;
184 : }
185 :
186 1389 : double& CompiledExpression::getVariableReference(const string& name) {
187 : map<string, double*>::iterator pointer = variablePointers.find(name);
188 1389 : if (pointer != variablePointers.end())
189 0 : return *pointer->second;
190 : map<string, int>::iterator index = variableIndices.find(name);
191 1389 : if (index == variableIndices.end())
192 1407 : throw Exception("getVariableReference: Unknown variable '"+name+"'");
193 1840 : return workspace[index->second];
194 : }
195 :
196 818 : void CompiledExpression::setVariableLocations(map<string, double*>& variableLocations) {
197 : variablePointers = variableLocations;
198 : #ifdef __PLUMED_HAS_ASMJIT
199 : // Rebuild the JIT code.
200 :
201 : if (workspace.size() > 0)
202 : generateJitCode();
203 : #else
204 : // Make a list of all variables we will need to copy before evaluating the expression.
205 :
206 818 : variablesToCopy.clear();
207 1738 : for (map<string, int>::const_iterator iter = variableIndices.begin(); iter != variableIndices.end(); ++iter) {
208 920 : map<string, double*>::iterator pointer = variablePointers.find(iter->first);
209 920 : if (pointer != variablePointers.end())
210 0 : variablesToCopy.push_back(make_pair(&workspace[iter->second], pointer->second));
211 : }
212 : #endif
213 818 : }
214 :
215 12891770 : double CompiledExpression::evaluate() const {
216 : #ifdef __PLUMED_HAS_ASMJIT
217 : return ((double (*)()) jitCode)();
218 : #else
219 25783540 : for (int i = 0; i < variablesToCopy.size(); i++)
220 0 : *variablesToCopy[i].first = *variablesToCopy[i].second;
221 :
222 : // Loop over the operations and evaluate each one.
223 :
224 208982041 : for (int step = 0; step < operation.size(); step++) {
225 : const vector<int>& args = arguments[step];
226 61066167 : if (args.size() == 1)
227 164141169 : workspace[target[step]] = operation[step]->evaluate(&workspace[args[0]], dummyVariables);
228 : else {
229 50819552 : for (int i = 0; i < args.size(); i++)
230 25409776 : argValues[i] = workspace[args[i]];
231 19057332 : workspace[target[step]] = operation[step]->evaluate(&argValues[0], dummyVariables);
232 : }
233 : }
234 25783540 : return workspace[workspace.size()-1];
235 : #endif
236 : }
237 :
238 : #ifdef __PLUMED_HAS_ASMJIT
239 : static double evaluateOperation(Operation* op, double* args) {
240 : static map<string, double> dummyVariables;
241 : return op->evaluate(args, dummyVariables);
242 : }
243 :
244 : static void generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double));
245 :
/// Generate native x86 code for the expression with asmjit and store its
/// entry point in jitCode.
///
/// The generated function takes no arguments. Variable values are loaded from
/// the addresses returned by getVariableReference(), each workspace slot lives
/// in its own virtual XMM register, and the last workspace slot is returned.
/// Because those addresses are baked into the code, it has to be regenerated
/// whenever they change (setVariableLocations() does this).
///
/// NOTE(review): if runtime.add() fails the error is ignored. jitCode keeps
/// its previous value (NULL on the first call), and evaluate() would then
/// call it. Functions generated by earlier calls are not released here.
void CompiledExpression::generateJitCode() {
  CodeHolder code;
  auto & runtime(*static_cast<asmjit::JitRuntime*>(runtimeptr.get()));
  code.init(runtime.getCodeInfo());
  // NOTE(review): `a` is never used below; confirm whether attaching an
  // assembler to `code` is required by the asmjit version in use.
  X86Assembler a(&code);
  X86Compiler c(&code);
  c.addFunc(FuncSignature0<double>());
  // One virtual scalar-double register per workspace slot.
  vector<X86Xmm> workspaceVar(workspace.size());
  for (int i = 0; i < (int) workspaceVar.size(); i++)
    workspaceVar[i] = c.newXmmSd();
  // Holds &argValues[0]; the fallback path stores arguments for evaluateOperation() there.
  X86Gp argsPointer = c.newIntPtr();
  c.mov(argsPointer, imm_ptr(&argValues[0]));

  // Load the arguments into variables.

  for (set<string>::const_iterator iter = variableNames.begin(); iter != variableNames.end(); ++iter) {
    map<string, int>::iterator index = variableIndices.find(*iter);
    X86Gp variablePointer = c.newIntPtr();
    // The address is an external location if one was registered, else the workspace slot.
    c.mov(variablePointer, imm_ptr(&getVariableReference(index->first)));
    c.movsd(workspaceVar[index->second], x86::ptr(variablePointer, 0, 0));
  }

  // Make a list of all constants that will be needed for evaluation.
  // `constants` is a data member and is not cleared here, so values found by
  // earlier calls are reused by the search below.

  vector<int> operationConstantIndex(operation.size(), -1);
  for (int step = 0; step < (int) operation.size(); step++) {
    // Find the constant value (if any) used by this operation.

    Operation& op = *operation[step];
    double value;
    if (op.getId() == Operation::CONSTANT)
      value = dynamic_cast<Operation::Constant&>(op).getValue();
    else if (op.getId() == Operation::ADD_CONSTANT)
      value = dynamic_cast<Operation::AddConstant&>(op).getValue();
    else if (op.getId() == Operation::MULTIPLY_CONSTANT)
      value = dynamic_cast<Operation::MultiplyConstant&>(op).getValue();
    else if (op.getId() == Operation::RECIPROCAL)
      value = 1.0; // numerator of 1/x
    else if (op.getId() == Operation::STEP)
      value = 1.0; // ANDed with the comparison mask below
    else if (op.getId() == Operation::DELTA)
      value = 1.0/0.0; // +infinity, ANDed with the equality mask below
    else
      continue;

    // See if we already have a variable for this constant.

    for (int i = 0; i < (int) constants.size(); i++)
      if (value == constants[i]) {
        operationConstantIndex[step] = i;
        break;
      }
    if (operationConstantIndex[step] == -1) {
      operationConstantIndex[step] = constants.size();
      constants.push_back(value);
    }
  }

  // Load constants into variables.

  vector<X86Xmm> constantVar(constants.size());
  if (constants.size() > 0) {
    X86Gp constantsPointer = c.newIntPtr();
    c.mov(constantsPointer, imm_ptr(&constants[0]));
    for (int i = 0; i < (int) constants.size(); i++) {
      constantVar[i] = c.newXmmSd();
      c.movsd(constantVar[i], x86::ptr(constantsPointer, 8*i, 0));
    }
  }

  // Evaluate the operations.

  for (int step = 0; step < (int) operation.size(); step++) {
    Operation& op = *operation[step];
    vector<int> args = arguments[step];
    if (args.size() == 1) {
      // One or more sequential arguments. Fill out the list.

      for (int i = 1; i < op.getNumArguments(); i++)
        args.push_back(args[0]+i);
    }

    // Generate instructions to execute this operation.

    switch (op.getId()) {
    case Operation::CONSTANT:
      c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
      break;
    case Operation::ADD:
      c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.addsd(workspaceVar[target[step]], workspaceVar[args[1]]);
      break;
    case Operation::SUBTRACT:
      c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.subsd(workspaceVar[target[step]], workspaceVar[args[1]]);
      break;
    case Operation::MULTIPLY:
      c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.mulsd(workspaceVar[target[step]], workspaceVar[args[1]]);
      break;
    case Operation::DIVIDE:
      c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.divsd(workspaceVar[target[step]], workspaceVar[args[1]]);
      break;
    case Operation::NEGATE:
      // Computed as 0 - x.
      c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
      c.subsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      break;
    case Operation::SQRT:
      c.sqrtsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      break;
    case Operation::EXP:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], exp);
      break;
    case Operation::LOG:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], log);
      break;
    case Operation::SIN:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sin);
      break;
    case Operation::COS:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cos);
      break;
    case Operation::TAN:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tan);
      break;
    case Operation::ASIN:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], asin);
      break;
    case Operation::ACOS:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], acos);
      break;
    case Operation::ATAN:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], atan);
      break;
    case Operation::SINH:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], sinh);
      break;
    case Operation::COSH:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], cosh);
      break;
    case Operation::TANH:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], tanh);
      break;
    case Operation::STEP:
      // mask = (0 <= x), then AND with the constant 1.0 to get 1.0 or 0.0.
      c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
      c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(18)); // Comparison mode is _CMP_LE_OQ = 18
      c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
      break;
    case Operation::DELTA:
      // mask = (0 == x), then AND with +infinity to get +inf or 0.0.
      c.xorps(workspaceVar[target[step]], workspaceVar[target[step]]);
      c.cmpsd(workspaceVar[target[step]], workspaceVar[args[0]], imm(16)); // Comparison mode is _CMP_EQ_OS = 16
      c.andps(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
      break;
    case Operation::SQUARE:
      c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      break;
    case Operation::CUBE:
      c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.mulsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      break;
    case Operation::RECIPROCAL:
      c.movsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
      c.divsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      break;
    case Operation::ADD_CONSTANT:
      c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.addsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
      break;
    case Operation::MULTIPLY_CONSTANT:
      c.movsd(workspaceVar[target[step]], workspaceVar[args[0]]);
      c.mulsd(workspaceVar[target[step]], constantVar[operationConstantIndex[step]]);
      break;
    case Operation::ABS:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], fabs);
      break;
    case Operation::FLOOR:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], floor);
      break;
    case Operation::CEIL:
      generateSingleArgCall(c, workspaceVar[target[step]], workspaceVar[args[0]], ceil);
      break;
    default:
      // Just invoke evaluateOperation().
      // Spill the arguments into argValues, then call evaluateOperation(&op, &argValues[0]).

      for (int i = 0; i < (int) args.size(); i++)
        c.movsd(x86::ptr(argsPointer, 8*i, 0), workspaceVar[args[i]]);
      X86Gp fn = c.newIntPtr();
      c.mov(fn, imm_ptr((void*) evaluateOperation));
      CCFuncCall* call = c.call(fn, FuncSignature2<double, Operation*, double*>(CallConv::kIdHost));
      call->setArg(0, imm_ptr(&op));
      call->setArg(1, imm_ptr(&argValues[0]));
      call->setRet(0, workspaceVar[target[step]]);
    }
  }
  // The root node occupies the last workspace slot.
  c.ret(workspaceVar[workspace.size()-1]);
  c.endFunc();
  c.finalize();
  typedef double (*Func0)(void);
  Func0 func0;
  Error err = runtime.add(&func0,&code);
  if(err) return;
  jitCode = (void*) func0;
}
452 :
453 : void generateSingleArgCall(X86Compiler& c, X86Xmm& dest, X86Xmm& arg, double (*function)(double)) {
454 : X86Gp fn = c.newIntPtr();
455 : c.mov(fn, imm_ptr((void*) function));
456 : CCFuncCall* call = c.call(fn, FuncSignature1<double, double>(CallConv::kIdHost));
457 : call->setArg(0, arg);
458 : call->setRet(0, dest);
459 : }
460 : #endif
461 : }
|