This commit moves FuncOp out of the builtin dialect, and into the Func dialect. This move has been planned in some capacity from the moment we made FuncOp an operation (years ago). This commit handles the functional aspects of the move, but various aspects are left untouched to ease migration: func::FuncOp is re-exported into mlir to reduce the actual API churn, the assembly format still accepts the unqualified `func`. These temporary measures will remain for a little while to simplify migration before being removed. Differential Revision: https://reviews.llvm.org/D121266
279 lines
9.3 KiB
C++
279 lines
9.3 KiB
C++
//===- Invoke.cpp ------------------------------------*- C++ -*-===//
|
|
//
|
|
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
|
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
|
|
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
|
|
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
|
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
|
|
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
|
|
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/ExecutionEngine/CRunnerUtils.h"
|
|
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
#include "mlir/ExecutionEngine/MemRefUtils.h"
|
|
#include "mlir/ExecutionEngine/RunnerUtils.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/InitAllDialects.h"
|
|
#include "mlir/Parser/Parser.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
|
#include "mlir/Target/LLVMIR/Export.h"
|
|
#include "llvm/Support/TargetSelect.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
#include "gmock/gmock.h"
|
|
|
|
using namespace mlir;
|
|
|
|
// The JIT isn't supported on Windows at that time
|
|
#ifndef _WIN32
|
|
|
|
static struct LLVMInitializer {
|
|
LLVMInitializer() {
|
|
llvm::InitializeNativeTarget();
|
|
llvm::InitializeNativeTargetAsmPrinter();
|
|
}
|
|
} initializer;
|
|
|
|
/// Simple conversion pipeline for the purpose of testing sources written in
|
|
/// dialects lowering to LLVM Dialect.
|
|
static LogicalResult lowerToLLVMDialect(ModuleOp module) {
|
|
PassManager pm(module.getContext());
|
|
pm.addPass(mlir::createMemRefToLLVMPass());
|
|
pm.addNestedPass<FuncOp>(mlir::arith::createConvertArithmeticToLLVMPass());
|
|
pm.addPass(mlir::createConvertFuncToLLVMPass());
|
|
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
|
|
return pm.run(module);
|
|
}
|
|
|
|
TEST(MLIRExecutionEngine, AddInteger) {
|
|
std::string moduleStr = R"mlir(
|
|
func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
|
|
%res = arith.addi %arg0, %arg0 : i32
|
|
return %res : i32
|
|
}
|
|
)mlir";
|
|
DialectRegistry registry;
|
|
registerAllDialects(registry);
|
|
registerLLVMDialectTranslation(registry);
|
|
MLIRContext context(registry);
|
|
OwningOpRef<ModuleOp> module =
|
|
parseSourceString<ModuleOp>(moduleStr, &context);
|
|
ASSERT_TRUE(!!module);
|
|
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
|
|
auto jitOrError = ExecutionEngine::create(*module);
|
|
ASSERT_TRUE(!!jitOrError);
|
|
std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
|
|
// The result of the function must be passed as output argument.
|
|
int result = 0;
|
|
llvm::Error error =
|
|
jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
|
|
ASSERT_TRUE(!error);
|
|
ASSERT_EQ(result, 42 + 42);
|
|
}
|
|
|
|
TEST(MLIRExecutionEngine, SubtractFloat) {
|
|
std::string moduleStr = R"mlir(
|
|
func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
|
|
%res = arith.subf %arg0, %arg1 : f32
|
|
return %res : f32
|
|
}
|
|
)mlir";
|
|
DialectRegistry registry;
|
|
registerAllDialects(registry);
|
|
registerLLVMDialectTranslation(registry);
|
|
MLIRContext context(registry);
|
|
OwningOpRef<ModuleOp> module =
|
|
parseSourceString<ModuleOp>(moduleStr, &context);
|
|
ASSERT_TRUE(!!module);
|
|
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
|
|
auto jitOrError = ExecutionEngine::create(*module);
|
|
ASSERT_TRUE(!!jitOrError);
|
|
std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
|
|
// The result of the function must be passed as output argument.
|
|
float result = -1;
|
|
llvm::Error error =
|
|
jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
|
|
ASSERT_TRUE(!error);
|
|
ASSERT_EQ(result, 42.f);
|
|
}
|
|
|
|
TEST(NativeMemRefJit, ZeroRankMemref) {
|
|
OwningMemRef<float, 0> a({});
|
|
a[{}] = 42.;
|
|
ASSERT_EQ(*a->data, 42);
|
|
a[{}] = 0;
|
|
std::string moduleStr = R"mlir(
|
|
func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
|
|
%cst42 = arith.constant 42.0 : f32
|
|
memref.store %cst42, %arg0[] : memref<f32>
|
|
return
|
|
}
|
|
)mlir";
|
|
DialectRegistry registry;
|
|
registerAllDialects(registry);
|
|
registerLLVMDialectTranslation(registry);
|
|
MLIRContext context(registry);
|
|
auto module = parseSourceString<ModuleOp>(moduleStr, &context);
|
|
ASSERT_TRUE(!!module);
|
|
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
|
|
auto jitOrError = ExecutionEngine::create(*module);
|
|
ASSERT_TRUE(!!jitOrError);
|
|
auto jit = std::move(jitOrError.get());
|
|
|
|
llvm::Error error = jit->invoke("zero_ranked", &*a);
|
|
ASSERT_TRUE(!error);
|
|
EXPECT_EQ((a[{}]), 42.);
|
|
for (float &elt : *a)
|
|
EXPECT_EQ(&elt, &(a[{}]));
|
|
}
|
|
|
|
TEST(NativeMemRefJit, RankOneMemref) {
|
|
int64_t shape[] = {9};
|
|
OwningMemRef<float, 1> a(shape);
|
|
int count = 1;
|
|
for (float &elt : *a) {
|
|
EXPECT_EQ(&elt, &(a[{count - 1}]));
|
|
elt = count++;
|
|
}
|
|
|
|
std::string moduleStr = R"mlir(
|
|
func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
|
|
%cst42 = arith.constant 42.0 : f32
|
|
%cst5 = arith.constant 5 : index
|
|
memref.store %cst42, %arg0[%cst5] : memref<?xf32>
|
|
return
|
|
}
|
|
)mlir";
|
|
DialectRegistry registry;
|
|
registerAllDialects(registry);
|
|
registerLLVMDialectTranslation(registry);
|
|
MLIRContext context(registry);
|
|
auto module = parseSourceString<ModuleOp>(moduleStr, &context);
|
|
ASSERT_TRUE(!!module);
|
|
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
|
|
auto jitOrError = ExecutionEngine::create(*module);
|
|
ASSERT_TRUE(!!jitOrError);
|
|
auto jit = std::move(jitOrError.get());
|
|
|
|
llvm::Error error = jit->invoke("one_ranked", &*a);
|
|
ASSERT_TRUE(!error);
|
|
count = 1;
|
|
for (float &elt : *a) {
|
|
if (count == 6)
|
|
EXPECT_EQ(elt, 42.);
|
|
else
|
|
EXPECT_EQ(elt, count);
|
|
count++;
|
|
}
|
|
}
|
|
|
|
TEST(NativeMemRefJit, BasicMemref) {
|
|
constexpr int k = 3;
|
|
constexpr int m = 7;
|
|
// Prepare arguments beforehand.
|
|
auto init = [=](float &elt, ArrayRef<int64_t> indices) {
|
|
assert(indices.size() == 2);
|
|
elt = m * indices[0] + indices[1];
|
|
};
|
|
int64_t shape[] = {k, m};
|
|
int64_t shapeAlloc[] = {k + 1, m + 1};
|
|
OwningMemRef<float, 2> a(shape, shapeAlloc, init);
|
|
ASSERT_EQ(a->sizes[0], k);
|
|
ASSERT_EQ(a->sizes[1], m);
|
|
ASSERT_EQ(a->strides[0], m + 1);
|
|
ASSERT_EQ(a->strides[1], 1);
|
|
for (int i = 0; i < k; ++i) {
|
|
for (int j = 0; j < m; ++j) {
|
|
EXPECT_EQ((a[{i, j}]), i * m + j);
|
|
EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
|
|
}
|
|
}
|
|
std::string moduleStr = R"mlir(
|
|
func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
|
|
%x = arith.constant 2 : index
|
|
%y = arith.constant 1 : index
|
|
%cst42 = arith.constant 42.0 : f32
|
|
memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
|
|
memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
|
|
return
|
|
}
|
|
)mlir";
|
|
DialectRegistry registry;
|
|
registerAllDialects(registry);
|
|
registerLLVMDialectTranslation(registry);
|
|
MLIRContext context(registry);
|
|
OwningOpRef<ModuleOp> module =
|
|
parseSourceString<ModuleOp>(moduleStr, &context);
|
|
ASSERT_TRUE(!!module);
|
|
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
|
|
auto jitOrError = ExecutionEngine::create(*module);
|
|
ASSERT_TRUE(!!jitOrError);
|
|
std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
|
|
|
|
llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
|
|
ASSERT_TRUE(!error);
|
|
EXPECT_EQ(((*a)[1][2]), 42.);
|
|
EXPECT_EQ((a[{2, 1}]), 42.);
|
|
}
|
|
|
|
// A helper function that will be called from the JIT
|
|
static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
|
|
int32_t coefficient) {
|
|
for (float &elt : *memref)
|
|
elt *= coefficient;
|
|
}
|
|
|
|
TEST(NativeMemRefJit, JITCallback) {
|
|
constexpr int k = 2;
|
|
constexpr int m = 2;
|
|
int64_t shape[] = {k, m};
|
|
int64_t shapeAlloc[] = {k + 1, m + 1};
|
|
OwningMemRef<float, 2> a(shape, shapeAlloc);
|
|
int count = 1;
|
|
for (float &elt : *a)
|
|
elt = count++;
|
|
|
|
std::string moduleStr = R"mlir(
|
|
func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
|
|
func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
|
|
%unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
|
|
call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
|
|
return
|
|
}
|
|
)mlir";
|
|
DialectRegistry registry;
|
|
registerAllDialects(registry);
|
|
registerLLVMDialectTranslation(registry);
|
|
MLIRContext context(registry);
|
|
auto module = parseSourceString<ModuleOp>(moduleStr, &context);
|
|
ASSERT_TRUE(!!module);
|
|
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
|
|
auto jitOrError = ExecutionEngine::create(*module);
|
|
ASSERT_TRUE(!!jitOrError);
|
|
auto jit = std::move(jitOrError.get());
|
|
// Define any extra symbols so they're available at runtime.
|
|
jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
|
|
llvm::orc::SymbolMap symbolMap;
|
|
symbolMap[interner("_mlir_ciface_callback")] =
|
|
llvm::JITEvaluatedSymbol::fromPointer(memrefMultiply);
|
|
return symbolMap;
|
|
});
|
|
|
|
int32_t coefficient = 3.;
|
|
llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient);
|
|
ASSERT_TRUE(!error);
|
|
count = 1;
|
|
for (float elt : *a)
|
|
ASSERT_EQ(elt, coefficient * count++);
|
|
}
|
|
|
|
#endif // _WIN32
|