SeExpr
Evaluator.h
Go to the documentation of this file.
1 /*
2  Copyright Disney Enterprises, Inc. All rights reserved.
3 
4  Licensed under the Apache License, Version 2.0 (the "License");
5  you may not use this file except in compliance with the License
6  and the following modification to it: Section 6 Trademarks.
7  deleted and replaced with:
8 
9  6. Trademarks. This License does not grant permission to use the
10  trade names, trademarks, service marks, or product names of the
11  Licensor and its affiliates, except as required for reproducing
12  the content of the NOTICE file.
13 
14  You may obtain a copy of the License at
15  http://www.apache.org/licenses/LICENSE-2.0
16 */
17 
18 #include "ExprConfig.h"
19 #include "ExprLLVMAll.h"
20 #include "VarBlock.h"
21 
22 #ifdef SEEXPR_ENABLE_LLVM
23 #include <llvm/Config/llvm-config.h>
24 #include <llvm/Support/Compiler.h>
25 #endif
26 
// C-linkage runtime helpers that JIT-compiled expression code calls back into.
// LLVMEvaluator::prepLLVM() declares functions with these exact names in the
// generated module and binds them to these addresses with addGlobalMapping().
//
// Fetches the value of a floating-point variable reference into `result`.
27 extern "C" void SeExpr2LLVMEvalFPVarRef(SeExpr2::ExprVarRef *seVR, double *result);
// String counterpart of the helper above.
// NOTE(review): declared here with `double *result`, while the IR prototype
// built in prepLLVM() gives the second parameter type i8** (char **) --
// confirm against the definition which one is intended.
28 extern "C" void SeExpr2LLVMEvalStrVarRef(SeExpr2::ExprVarRef *seVR, double *result);
// Dispatches a call to a custom (non built-in) expression function.
// NOTE(review): the meaning of each argument array is defined by the
// implementation, which is not part of this header -- see the definition.
29 extern "C" void SeExpr2LLVMEvalCustomFunction(int *opDataArg,
30  double *fpArg,
31  char **strArg,
32  void **funcdata,
33  const SeExpr2::ExprFuncNode *node);
34 
35 namespace SeExpr2 {
36 #ifdef SEEXPR_ENABLE_LLVM
37 
// Declared here, defined elsewhere. prepLLVM() calls it on the value produced
// by codegen, passing the desired result dimension, and then extracts that
// many elements from the returned value.
38 LLVM_VALUE promoteToDim(LLVM_VALUE val, unsigned dim, llvm::IRBuilder<> &Builder);
39 
40 class LLVMEvaluator {
41  // TODO: this seems needlessly complex, let's fix it
42  // TODO: let the dev code allocate memory?
43  // FP is the native function for this expression.
44  template <class T>
45  class LLVMEvaluationContext {
46  private:
47  typedef void (*FunctionPtr)(T *, char **, uint32_t);
48  typedef void (*FunctionPtrMultiple)(char **, uint32_t, uint32_t, uint32_t);
49  FunctionPtr functionPtr;
50  FunctionPtrMultiple functionPtrMultiple;
51  T *resultData;
52 
53  public:
54  LLVMEvaluationContext(const LLVMEvaluationContext &) = delete;
55  LLVMEvaluationContext &operator=(const LLVMEvaluationContext &) = delete;
56  ~LLVMEvaluationContext() { delete[] resultData; }
57  LLVMEvaluationContext() : functionPtr(nullptr), resultData(nullptr) {}
58  void init(void *fp, void *fpLoop, int dim) {
59  reset();
60  functionPtr = reinterpret_cast<FunctionPtr>(fp);
61  functionPtrMultiple = reinterpret_cast<FunctionPtrMultiple>(fpLoop);
62  resultData = new T[dim];
63  }
64  void reset() {
65  if (resultData) delete[] resultData;
66  functionPtr = nullptr;
67  resultData = nullptr;
68  }
69  const T *operator()(VarBlock *varBlock) {
70  assert(functionPtr && resultData);
71  functionPtr(resultData, varBlock ? varBlock->data() : nullptr, varBlock ? varBlock->indirectIndex : 0);
72  return resultData;
73  }
74  void operator()(VarBlock *varBlock, size_t outputVarBlockOffset, size_t rangeStart, size_t rangeEnd) {
75  assert(functionPtr && resultData);
76  functionPtrMultiple(varBlock ? varBlock->data() : nullptr, outputVarBlockOffset, rangeStart, rangeEnd);
77  }
78  };
79  std::unique_ptr<LLVMEvaluationContext<double>> _llvmEvalFP;
80  std::unique_ptr<LLVMEvaluationContext<char *>> _llvmEvalStr;
81 
82  std::unique_ptr<llvm::LLVMContext> _llvmContext;
83  std::unique_ptr<llvm::ExecutionEngine> TheExecutionEngine;
84 
85  public:
86  LLVMEvaluator() {}
87 
88  const char *evalStr(VarBlock *varBlock) { return *(*_llvmEvalStr)(varBlock); }
89  const double *evalFP(VarBlock *varBlock) { return (*_llvmEvalFP)(varBlock); }
90 
91  void evalMultiple(VarBlock *varBlock, uint32_t outputVarBlockOffset, uint32_t rangeStart, uint32_t rangeEnd) {
92  return (*_llvmEvalFP)(varBlock, outputVarBlockOffset, rangeStart, rangeEnd);
93  }
94 
95  void debugPrint() {
96  // TheModule->print(llvm::errs(), nullptr);
97  }
98 
99  bool prepLLVM(ExprNode *parseTree, ExprType desiredReturnType) {
100  using namespace llvm;
101  InitializeNativeTarget();
102  InitializeNativeTargetAsmPrinter();
103  InitializeNativeTargetAsmParser();
104 
105  std::string uniqueName = getUniqueName();
106 
107  // create Module
108  _llvmContext.reset(new LLVMContext());
109 
110  std::unique_ptr<Module> TheModule(new Module(uniqueName + "_module", *_llvmContext));
111 
112  // create all needed types
113  Type *i8PtrTy = Type::getInt8PtrTy(*_llvmContext); // char *
114  PointerType *i8PtrPtrTy = PointerType::getUnqual(i8PtrTy); // char **
115  PointerType *i8PtrPtrPtrTy = PointerType::getUnqual(i8PtrPtrTy); // char ***
116  Type *i32Ty = Type::getInt32Ty(*_llvmContext); // int
117  Type *i32PtrTy = Type::getInt32PtrTy(*_llvmContext); // int *
118  Type *i64Ty = Type::getInt64Ty(*_llvmContext); // int64 *
119  Type *doublePtrTy = Type::getDoublePtrTy(*_llvmContext); // double *
120  PointerType *doublePtrPtrTy = PointerType::getUnqual(doublePtrTy); // double **
121  Type *voidTy = Type::getVoidTy(*_llvmContext); // void
122 
123  // create bindings to helper functions for variables and fucntions
124  Function *SeExpr2LLVMEvalCustomFunctionFunc = nullptr;
125  Function *SeExpr2LLVMEvalFPVarRefFunc = nullptr;
126  Function *SeExpr2LLVMEvalStrVarRefFunc = nullptr;
127  Function *SeExpr2LLVMEvalstrlenFunc = nullptr;
128  Function *SeExpr2LLVMEvalmallocFunc = nullptr;
129  Function *SeExpr2LLVMEvalfreeFunc = nullptr;
130  Function *SeExpr2LLVMEvalmemsetFunc = nullptr;
131  Function *SeExpr2LLVMEvalstrcatFunc = nullptr;
132  {
133  {
134  FunctionType *FT = FunctionType::get(voidTy, {i32PtrTy, doublePtrTy, i8PtrPtrTy, i8PtrPtrTy, i64Ty}, false);
135  SeExpr2LLVMEvalCustomFunctionFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalCustomFunction", TheModule.get());
136  }
137  {
138  FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, doublePtrTy}, false);
139  SeExpr2LLVMEvalFPVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalFPVarRef", TheModule.get());
140  }
141  {
142  FunctionType *FT = FunctionType::get(voidTy, {i8PtrTy, i8PtrPtrTy}, false);
143  SeExpr2LLVMEvalStrVarRefFunc = Function::Create(FT, GlobalValue::ExternalLinkage, "SeExpr2LLVMEvalStrVarRef", TheModule.get());
144  }
145  {
146  FunctionType *FT = FunctionType::get(i32Ty, { i8PtrTy }, false);
147  SeExpr2LLVMEvalstrlenFunc = Function::Create(FT, Function::ExternalLinkage, "strlen", TheModule.get());
148  }
149  {
150  FunctionType *FT = FunctionType::get(i8PtrTy, { i32Ty }, false);
151  SeExpr2LLVMEvalmallocFunc = Function::Create(FT, Function::ExternalLinkage, "malloc", TheModule.get());
152  }
153  {
154  FunctionType *FT = FunctionType::get(voidTy, { i8PtrTy }, false);
155  SeExpr2LLVMEvalfreeFunc = Function::Create(FT, Function::ExternalLinkage, "free", TheModule.get());
156  }
157  {
158  FunctionType *FT = FunctionType::get(voidTy, { i8PtrTy, i32Ty, i32Ty }, false);
159  SeExpr2LLVMEvalmemsetFunc = Function::Create(FT, Function::ExternalLinkage, "memset", TheModule.get());
160  }
161  {
162  FunctionType *FT = FunctionType::get(i8PtrTy, { i8PtrTy, i8PtrTy }, false);
163  SeExpr2LLVMEvalstrcatFunc = Function::Create(FT, Function::ExternalLinkage, "strcat", TheModule.get());
164  }
165  }
166 
167  // create function and entry BB
168  bool desireFP = desiredReturnType.isFP();
169  Type *ParamTys[] = {
170  desireFP ? doublePtrTy : i8PtrPtrTy,
171  doublePtrPtrTy,
172  i32Ty
173  };
174  FunctionType *FT = FunctionType::get(voidTy, ParamTys, false);
175  Function *F = Function::Create(FT, Function::ExternalLinkage, uniqueName + "_func", TheModule.get());
176 #if LLVM_VERSION_MAJOR > 4
177  F->addAttribute(llvm::AttributeList::FunctionIndex, llvm::Attribute::AlwaysInline);
178 #else
179  F->addAttribute(llvm::AttributeSet::FunctionIndex, llvm::Attribute::AlwaysInline);
180 #endif
181  {
182  // label the function with names
183  const char *names[] = {"outputPointer", "dataBlock", "indirectIndex"};
184  int idx = 0;
185  for (auto &arg : F->args()) arg.setName(names[idx++]);
186  }
187 
188  unsigned int dimDesired = (unsigned)desiredReturnType.dim();
189  unsigned int dimGenerated = parseTree->type().dim();
190  {
191  BasicBlock *BB = BasicBlock::Create(*_llvmContext, "entry", F);
192  IRBuilder<> Builder(BB);
193 
194  // codegen
195  Value *lastVal = parseTree->codegen(Builder);
196 
197  // return values through parameter.
198  Value *firstArg = &*F->arg_begin();
199  if (desireFP) {
200  if (dimGenerated > 1) {
201  Value *newLastVal = promoteToDim(lastVal, dimDesired, Builder);
202  assert(newLastVal->getType()->getVectorNumElements() >= dimDesired);
203  for (unsigned i = 0; i < dimDesired; ++i) {
204  Value *idx = ConstantInt::get(Type::getInt64Ty(*_llvmContext), i);
205  Value *val = Builder.CreateExtractElement(newLastVal, idx);
206  Value *ptr = Builder.CreateInBoundsGEP(firstArg, idx);
207  Builder.CreateStore(val, ptr);
208  }
209  } else if (dimGenerated == 1) {
210  for (unsigned i = 0; i < dimDesired; ++i) {
211  Value *ptr = Builder.CreateConstInBoundsGEP1_32(nullptr, firstArg, i);
212  Builder.CreateStore(lastVal, ptr);
213  }
214  } else {
215  assert(false && "error. dim of FP is less than 1.");
216  }
217  } else {
218  Builder.CreateStore(lastVal, firstArg);
219  }
220 
221  Builder.CreateRetVoid();
222  }
223 
224  // write a new function
225  FunctionType *FTLOOP = FunctionType::get(voidTy, {i8PtrTy, i32Ty, i32Ty, i32Ty}, false);
226  Function *FLOOP = Function::Create(FTLOOP, Function::ExternalLinkage, uniqueName + "_loopfunc", TheModule.get());
227  {
228  // label the function with names
229  const char *names[] = {"dataBlock", "outputVarBlockOffset", "rangeStart", "rangeEnd"};
230  int idx = 0;
231  for (auto &arg : FLOOP->args()) {
232  arg.setName(names[idx++]);
233  }
234  }
235  {
236  // Local variables
237  Value *dimValue = ConstantInt::get(i32Ty, dimDesired);
238  Value *oneValue = ConstantInt::get(i32Ty, 1);
239 
240  // Basic blocks
241  BasicBlock *entryBlock = BasicBlock::Create(*_llvmContext, "entry", FLOOP);
242  BasicBlock *loopCmpBlock = BasicBlock::Create(*_llvmContext, "loopCmp", FLOOP);
243  BasicBlock *loopRepeatBlock = BasicBlock::Create(*_llvmContext, "loopRepeat", FLOOP);
244  BasicBlock *loopIncBlock = BasicBlock::Create(*_llvmContext, "loopInc", FLOOP);
245  BasicBlock *loopEndBlock = BasicBlock::Create(*_llvmContext, "loopEnd", FLOOP);
246  IRBuilder<> Builder(entryBlock);
247  Builder.SetInsertPoint(entryBlock);
248 
249  // Get arguments
250  Function::arg_iterator argIterator = FLOOP->arg_begin();
251  Value *varBlockCharPtrPtrArg = &*argIterator; ++argIterator;
252  Value *outputVarBlockOffsetArg = &*argIterator; ++argIterator;
253  Value *rangeStartArg = &*argIterator; ++argIterator;
254  Value *rangeEndArg = &*argIterator; ++argIterator;
255 
256  // Allocate Variables
257  Value *rangeStartVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "rangeStartVar");
258  Value *rangeEndVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "rangeEndVar");
259  Value *indexVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "indexVar");
260  Value *outputVarBlockOffsetVar = Builder.CreateAlloca(Type::getInt32Ty(*_llvmContext), oneValue, "outputVarBlockOffsetVar");
261  Value *varBlockDoublePtrPtrVar = Builder.CreateAlloca(doublePtrPtrTy, oneValue, "varBlockDoublePtrPtrVar");
262  Value *varBlockTPtrPtrVar = Builder.CreateAlloca(desireFP == true ? doublePtrPtrTy : i8PtrPtrPtrTy, oneValue, "varBlockTPtrPtrVar");
263 
264  // Copy variables from args
265  Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, doublePtrPtrTy, "varBlockAsDoublePtrPtr"), varBlockDoublePtrPtrVar);
266  Builder.CreateStore(Builder.CreatePointerCast(varBlockCharPtrPtrArg, desireFP ? doublePtrPtrTy : i8PtrPtrPtrTy, "varBlockAsTPtrPtr"), varBlockTPtrPtrVar);
267  Builder.CreateStore(rangeStartArg, rangeStartVar);
268  Builder.CreateStore(rangeEndArg, rangeEndVar);
269  Builder.CreateStore(outputVarBlockOffsetArg, outputVarBlockOffsetVar);
270 
271  // Set output pointer
272  Value *outputBasePtrPtr = Builder.CreateGEP(nullptr, Builder.CreateLoad(varBlockTPtrPtrVar), outputVarBlockOffsetArg, "outputBasePtrPtr");
273  Value *outputBasePtr = Builder.CreateLoad(outputBasePtrPtr, "outputBasePtr");
274  Builder.CreateStore(Builder.CreateLoad(rangeStartVar), indexVar);
275 
276  Builder.CreateBr(loopCmpBlock);
277  Builder.SetInsertPoint(loopCmpBlock);
278  Value *cond = Builder.CreateICmpULT(Builder.CreateLoad(indexVar), Builder.CreateLoad(rangeEndVar));
279  Builder.CreateCondBr(cond, loopRepeatBlock, loopEndBlock);
280 
281  Builder.SetInsertPoint(loopRepeatBlock);
282  Value *myOutputPtr = Builder.CreateGEP(nullptr, outputBasePtr, Builder.CreateMul(dimValue, Builder.CreateLoad(indexVar)));
283  Builder.CreateCall(F, {myOutputPtr, Builder.CreateLoad(varBlockDoublePtrPtrVar), Builder.CreateLoad(indexVar)});
284 
285  Builder.CreateBr(loopIncBlock);
286 
287  Builder.SetInsertPoint(loopIncBlock);
288  Builder.CreateStore(Builder.CreateAdd(Builder.CreateLoad(indexVar), oneValue), indexVar);
289  Builder.CreateBr(loopCmpBlock);
290 
291  Builder.SetInsertPoint(loopEndBlock);
292  Builder.CreateRetVoid();
293  }
294 
295  if (Expression::debugging) {
296  #ifdef DEBUG
297  std::cerr << "Pre verified LLVM byte code " << std::endl;
298  TheModule->print(llvm::errs(), nullptr);
299  #endif
300  }
301 
302  // TODO: Find out if there is a new way to veirfy
303  // if (verifyModule(*TheModule)) {
304  // std::cerr << "Logic error in code generation of LLVM alert developers" << std::endl;
305  // TheModule->print(llvm::errs(), nullptr);
306  // }
307  Module *altModule = TheModule.get();
308  std::string ErrStr;
309  TheExecutionEngine.reset(EngineBuilder(std::move(TheModule))
310  .setErrorStr(&ErrStr)
311  // .setUseMCJIT(true)
312  .setOptLevel(CodeGenOpt::Aggressive)
313  .create());
314 
315  altModule->setDataLayout(TheExecutionEngine->getDataLayout());
316 
317  // Add bindings to C linkage helper functions
318  TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalFPVarRefFunc, (void *)SeExpr2LLVMEvalFPVarRef);
319  TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalStrVarRefFunc, (void *)SeExpr2LLVMEvalStrVarRef);
320  TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalCustomFunctionFunc, (void *)SeExpr2LLVMEvalCustomFunction);
321  TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalstrlenFunc, (void *)strlen);
322  TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalstrcatFunc, (void *)strcat);
323  TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalmemsetFunc, (void *)memset);
324  TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalmallocFunc, (void *)malloc);
325  TheExecutionEngine->addGlobalMapping(SeExpr2LLVMEvalfreeFunc, (void *)free);
326 
327  // [verify]
328  std::string errorStr;
329  llvm::raw_string_ostream raw(errorStr);
330  if (llvm::verifyModule(*altModule, &raw)) {
331  parseTree->addError(raw.str());
332  return false;
333  }
334 
335  // Setup optimization
336  llvm::PassManagerBuilder builder;
337  std::unique_ptr<llvm::legacy::PassManager> pm(new llvm::legacy::PassManager);
338  std::unique_ptr<llvm::legacy::FunctionPassManager> fpm(new llvm::legacy::FunctionPassManager(altModule));
339  builder.OptLevel = 3;
340 #if (LLVM_VERSION_MAJOR >= 4)
341  builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
342 #else
343  builder.Inliner = llvm::createAlwaysInlinerPass();
344 #endif
345  builder.populateModulePassManager(*pm);
346  // fpm->add(new llvm::DataLayoutPass());
347  builder.populateFunctionPassManager(*fpm);
348  fpm->run(*F);
349  fpm->run(*FLOOP);
350  pm->run(*altModule);
351 
352  // Create the JIT. This takes ownership of the module.
353 
354  if (!TheExecutionEngine) {
355  fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
356  exit(1);
357  }
358 
359  TheExecutionEngine->finalizeObject();
360  void *fp = TheExecutionEngine->getPointerToFunction(F);
361  void *fpLoop = TheExecutionEngine->getPointerToFunction(FLOOP);
362  if (desireFP) {
363  _llvmEvalFP.reset(new LLVMEvaluationContext<double>);
364  _llvmEvalFP->init(fp, fpLoop, dimDesired);
365  } else {
366  _llvmEvalStr.reset(new LLVMEvaluationContext<char *>);
367  _llvmEvalStr->init(fp, fpLoop, dimDesired);
368  }
369 
370  if (Expression::debugging) {
371  #ifdef DEBUG
372  std::cerr << "Pre verified LLVM byte code " << std::endl;
373  altModule->print(llvm::errs(), nullptr);
374  #endif
375  }
376 
377  return true;
378  }
379 
380  std::string getUniqueName() const {
381  std::ostringstream o;
382  o << std::setbase(16) << (uint64_t)(this);
383  return ("_" + o.str());
384  }
385 };
386 
387 #else // no LLVM support
389  public:
390  void unsupported() { throw std::runtime_error("LLVM is not enabled in build"); }
391  const char *evalStr(VarBlock *varBlock) {
392  unsupported();
393  return "";
394  }
395  const double *evalFP(VarBlock *varBlock) {
396  unsupported();
397  return 0;
398  }
399  bool prepLLVM(ExprNode *parseTree, ExprType desiredReturnType) {
400  unsupported();
401  return false;
402  }
403  void evalMultiple(VarBlock *varBlock, int outputVarBlockOffset, size_t rangeStart, size_t rangeEnd) {
404  unsupported();
405  }
406  void debugPrint() {}
407 };
408 #endif
409 
410 } // end namespace SeExpr2
SeExpr2::Expression::debugging
static bool debugging
Whether to debug expressions.
Definition: Expression.h:86
llvm
Definition: Expression.h:31
SeExpr2::LLVMEvaluator::evalFP
const double * evalFP(VarBlock *varBlock)
Definition: Evaluator.h:395
SeExpr2::VarBlock
A thread local evaluation context. Just allocate and fill in with data.
Definition: VarBlock.h:33
SeExpr2::LLVMEvaluator::debugPrint
void debugPrint()
Definition: Evaluator.h:406
SeExpr2::LLVMEvaluator::evalMultiple
void evalMultiple(VarBlock *varBlock, int outputVarBlockOffset, size_t rangeStart, size_t rangeEnd)
Definition: Evaluator.h:403
SeExpr2LLVMEvalStrVarRef
void SeExpr2LLVMEvalStrVarRef(SeExpr2::ExprVarRef *seVR, double *result)
SeExpr2::ExprNode
Definition: ExprNode.h:72
ExprLLVMAll.h
SeExpr2::ExprType
Definition: ExprType.h:39
SeExpr2::LLVMEvaluator::unsupported
void unsupported()
Definition: Evaluator.h:390
SeExpr2::ExprFuncNode
Node that calls a function.
Definition: ExprNode.h:517
LLVM_VALUE
double LLVM_VALUE
Definition: ExprLLVM.h:33
SeExpr2
Definition: Context.h:22
SeExpr2::LLVMEvaluator::prepLLVM
bool prepLLVM(ExprNode *parseTree, ExprType desiredReturnType)
Definition: Evaluator.h:399
SeExpr2::LLVMEvaluator
Definition: Evaluator.h:388
SeExpr2LLVMEvalFPVarRef
void SeExpr2LLVMEvalFPVarRef(SeExpr2::ExprVarRef *seVR, double *result)
SeExpr2LLVMEvalCustomFunction
void SeExpr2LLVMEvalCustomFunction(int *opDataArg, double *fpArg, char **strArg, void **funcdata, const SeExpr2::ExprFuncNode *node)
Definition: ExprFuncX.cpp:110
SeExpr2::ExprVarRef
abstract class for implementing variable references
Definition: Expression.h:45
SeExpr2::LLVMEvaluator::evalStr
const char * evalStr(VarBlock *varBlock)
Definition: Evaluator.h:391
VarBlock.h