Kaleidoscope Language: JIT

This chapter does two main things:

  1. Some simple optimizations
  2. JIT compilation

Compiler Optimizations

Take this function, for example:

ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double 3.000000e+00, %x
        %addtmp1 = fadd double %x, 3.000000e+00
        %multmp = fmul double %addtmp, %addtmp1
        ret double %multmp
}

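Notice that (1+2+x) and (x+(1+2)) compute the same value, yet the IR performs the addition twice. IRBuilder has already folded 1+2 into 3.0, but nothing has merged the two additions. LLVM's optimization passes take care of this, so we create a FunctionPassManager along with each new module: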

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"

void InitializeModuleAndPassManager () {
    // Create a new context and module
    TheContext = std::make_unique<llvm::LLVMContext> ();
    TheModule  = std::make_unique<llvm::Module> ("My awesome JIT", *TheContext);
    TheModule->setDataLayout (TheJIT->getDataLayout ());
    // Create a builder for the module
    Builder = std::make_unique<llvm::IRBuilder<>> (*TheContext);

    // Create a function pass manager attached to the module
    TheFPM = std::make_unique<llvm::legacy::FunctionPassManager> (TheModule.get ());
    // Peephole optimizations: https://zh.wikipedia.org/zh-sg/%E7%AA%A5%E5%AD%94%E4%BC%98%E5%8C%96
    TheFPM->add (llvm::createInstructionCombiningPass ());
    // Reassociate expressions.
    TheFPM->add (llvm::createReassociatePass ());
    // Eliminate Common SubExpressions.
    TheFPM->add (llvm::createGVNPass ());
    // Simplify the control flow graph (deleting unreachable blocks, etc).
    TheFPM->add (llvm::createCFGSimplificationPass ());

    TheFPM->doInitialization ();
}

Then run the pass manager on each function right after codegen has generated it:

if (llvm::Value *RetVal = Body->codegen()) {
  // Finish off the function. Builder is a std::unique_ptr, so use ->.
  Builder->CreateRet(RetVal);

  // Validate the generated code, checking for consistency.
  llvm::verifyFunction(*TheFunction);

  // Optimize the function.
  TheFPM->run(*TheFunction);

  return TheFunction;
}

The result:

ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
        %addtmp = fadd double %x, 3.000000e+00
        %multmp = fmul double %addtmp, %addtmp
        ret double %multmp
}

JIT Compilation

LLVM ships a small JIT for this tutorial in the header KaleidoscopeJIT.h. First initialize the native target, then create the JIT:

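void InitJIT () {
    // Register the host target so the JIT can emit native machine code.
    llvm::InitializeNativeTarget ();
    llvm::InitializeNativeTargetAsmPrinter ();
    llvm::InitializeNativeTargetAsmParser ();

    // ExitOnErr: presumably a static llvm::ExitOnError, as in the official tutorial.
    TheJIT = ExitOnErr (llvm::orc::KaleidoscopeJIT::Create ());
}

InitializeModuleAndPassManager is unchanged; the line that matters for the JIT is setDataLayout, which gives every new module the JIT's data layout: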


void InitializeModuleAndPassManager () {
    // Open a new context and module.
    TheContext = std::make_unique<llvm::LLVMContext> ();
    TheModule  = std::make_unique<llvm::Module> ("My awesome JIT", *TheContext);
    TheModule->setDataLayout (TheJIT->getDataLayout ());
    // Create a new builder for the module.
    Builder = std::make_unique<llvm::IRBuilder<>> (*TheContext);

    TheFPM = std::make_unique<llvm::legacy::FunctionPassManager> (TheModule.get ());
    // Do simple "peephole" optimizations and bit-twiddling optzns.
    TheFPM->add (llvm::createInstructionCombiningPass ());
    // Reassociate expressions.
    TheFPM->add (llvm::createReassociatePass ());
    // Eliminate Common SubExpressions.
    TheFPM->add (llvm::createGVNPass ());
    // Simplify the control flow graph (deleting unreachable blocks, etc).
    TheFPM->add (llvm::createCFGSimplificationPass ());

    TheFPM->doInitialization ();
}

Run it again:

ready> 4+5;
Read top-level expression:
define double @0() {
entry:
  ret double 9.000000e+00
}

Evaluated to 9.000000

Extern Functions

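A correction: with LLVM 14, our own host functions (putchard and printd) have to be registered with the JIT explicitly as shown below, otherwise the JIT reports them as not found (see llvm-jit-symbols-not-found):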

void InitJIT () {
    llvm::InitializeNativeTarget ();
    llvm::InitializeNativeTargetAsmPrinter ();
    llvm::InitializeNativeTargetAsmParser ();

    TheJIT      = ExitOnErr (llvm::orc::KaleidoscopeJIT::Create ());
    auto &jd    = TheJIT->getMainJITDylib ();
    // Applies the target's symbol mangling (e.g. a leading '_' on macOS).
    auto mangle = llvm::orc::MangleAndInterner (
        jd.getExecutionSession (), TheJIT->getDataLayout ());

    // Map each mangled name to the fixed address of the host function.
    auto s = [] (llvm::orc::MangleAndInterner interner) {
        llvm::orc::SymbolMap symbolMap;
        symbolMap[interner ("putchard")] = {
            llvm::pointerToJITTargetAddress (&putchard),
            llvm::JITSymbolFlags (),
        };
        symbolMap[interner ("printd")] = {
            llvm::pointerToJITTargetAddress (&printd),
            llvm::JITSymbolFlags (),
        };
        return llvm::orc::absoluteSymbols (symbolMap);
    }(mangle);

    // Define the symbols in the main JITDylib so JIT'd code can resolve them.
    ExitOnErr (jd.define (s));
}

At this point, Kaleidoscope code can call functions in the host process:

ready> extern sin(x);
Read extern:
declare double @sin(double)

ready> extern cos(x);
Read extern:
declare double @cos(double)

ready> sin(1.0);
Read top-level expression:
define double @2() {
entry:
  ret double 0x3FEAED548F090CEE
}