River Riddle 3655069234 [mlir] Move the Builtin FuncOp to the Func dialect
This commit moves FuncOp out of the builtin dialect, and into the Func
dialect. This move has been planned in some capacity from the moment
we made FuncOp an operation (years ago). This commit handles the
functional aspects of the move, but various aspects are left untouched
to ease migration: func::FuncOp is re-exported into mlir to reduce
the actual API churn, the assembly format still accepts the unqualified
`func`. These temporary measures will remain for a little while to
simplify migration before being removed.

Differential Revision: https://reviews.llvm.org/D121266
2022-03-16 17:07:03 -07:00

279 lines
9.3 KiB
C++

//===- Invoke.cpp ------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "gmock/gmock.h"
using namespace mlir;
// The JIT isn't supported on Windows at that time
#ifndef _WIN32
static struct LLVMInitializer {
LLVMInitializer() {
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
}
} initializer;
/// Simple conversion pipeline for the purpose of testing sources written in
/// dialects lowering to LLVM Dialect.
static LogicalResult lowerToLLVMDialect(ModuleOp module) {
PassManager pm(module.getContext());
pm.addPass(mlir::createMemRefToLLVMPass());
pm.addNestedPass<FuncOp>(mlir::arith::createConvertArithmeticToLLVMPass());
pm.addPass(mlir::createConvertFuncToLLVMPass());
pm.addPass(mlir::createReconcileUnrealizedCastsPass());
return pm.run(module);
}
TEST(MLIRExecutionEngine, AddInteger) {
std::string moduleStr = R"mlir(
func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
%res = arith.addi %arg0, %arg0 : i32
return %res : i32
}
)mlir";
DialectRegistry registry;
registerAllDialects(registry);
registerLLVMDialectTranslation(registry);
MLIRContext context(registry);
OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
auto jitOrError = ExecutionEngine::create(*module);
ASSERT_TRUE(!!jitOrError);
std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
// The result of the function must be passed as output argument.
int result = 0;
llvm::Error error =
jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
ASSERT_TRUE(!error);
ASSERT_EQ(result, 42 + 42);
}
TEST(MLIRExecutionEngine, SubtractFloat) {
std::string moduleStr = R"mlir(
func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
%res = arith.subf %arg0, %arg1 : f32
return %res : f32
}
)mlir";
DialectRegistry registry;
registerAllDialects(registry);
registerLLVMDialectTranslation(registry);
MLIRContext context(registry);
OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
auto jitOrError = ExecutionEngine::create(*module);
ASSERT_TRUE(!!jitOrError);
std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
// The result of the function must be passed as output argument.
float result = -1;
llvm::Error error =
jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
ASSERT_TRUE(!error);
ASSERT_EQ(result, 42.f);
}
TEST(NativeMemRefJit, ZeroRankMemref) {
OwningMemRef<float, 0> a({});
a[{}] = 42.;
ASSERT_EQ(*a->data, 42);
a[{}] = 0;
std::string moduleStr = R"mlir(
func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
%cst42 = arith.constant 42.0 : f32
memref.store %cst42, %arg0[] : memref<f32>
return
}
)mlir";
DialectRegistry registry;
registerAllDialects(registry);
registerLLVMDialectTranslation(registry);
MLIRContext context(registry);
auto module = parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
auto jitOrError = ExecutionEngine::create(*module);
ASSERT_TRUE(!!jitOrError);
auto jit = std::move(jitOrError.get());
llvm::Error error = jit->invoke("zero_ranked", &*a);
ASSERT_TRUE(!error);
EXPECT_EQ((a[{}]), 42.);
for (float &elt : *a)
EXPECT_EQ(&elt, &(a[{}]));
}
TEST(NativeMemRefJit, RankOneMemref) {
int64_t shape[] = {9};
OwningMemRef<float, 1> a(shape);
int count = 1;
for (float &elt : *a) {
EXPECT_EQ(&elt, &(a[{count - 1}]));
elt = count++;
}
std::string moduleStr = R"mlir(
func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
%cst42 = arith.constant 42.0 : f32
%cst5 = arith.constant 5 : index
memref.store %cst42, %arg0[%cst5] : memref<?xf32>
return
}
)mlir";
DialectRegistry registry;
registerAllDialects(registry);
registerLLVMDialectTranslation(registry);
MLIRContext context(registry);
auto module = parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
auto jitOrError = ExecutionEngine::create(*module);
ASSERT_TRUE(!!jitOrError);
auto jit = std::move(jitOrError.get());
llvm::Error error = jit->invoke("one_ranked", &*a);
ASSERT_TRUE(!error);
count = 1;
for (float &elt : *a) {
if (count == 6)
EXPECT_EQ(elt, 42.);
else
EXPECT_EQ(elt, count);
count++;
}
}
TEST(NativeMemRefJit, BasicMemref) {
constexpr int k = 3;
constexpr int m = 7;
// Prepare arguments beforehand.
auto init = [=](float &elt, ArrayRef<int64_t> indices) {
assert(indices.size() == 2);
elt = m * indices[0] + indices[1];
};
int64_t shape[] = {k, m};
int64_t shapeAlloc[] = {k + 1, m + 1};
OwningMemRef<float, 2> a(shape, shapeAlloc, init);
ASSERT_EQ(a->sizes[0], k);
ASSERT_EQ(a->sizes[1], m);
ASSERT_EQ(a->strides[0], m + 1);
ASSERT_EQ(a->strides[1], 1);
for (int i = 0; i < k; ++i) {
for (int j = 0; j < m; ++j) {
EXPECT_EQ((a[{i, j}]), i * m + j);
EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
}
}
std::string moduleStr = R"mlir(
func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
%x = arith.constant 2 : index
%y = arith.constant 1 : index
%cst42 = arith.constant 42.0 : f32
memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
return
}
)mlir";
DialectRegistry registry;
registerAllDialects(registry);
registerLLVMDialectTranslation(registry);
MLIRContext context(registry);
OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
auto jitOrError = ExecutionEngine::create(*module);
ASSERT_TRUE(!!jitOrError);
std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
ASSERT_TRUE(!error);
EXPECT_EQ(((*a)[1][2]), 42.);
EXPECT_EQ((a[{2, 1}]), 42.);
}
// A helper function that will be called from the JIT
static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
int32_t coefficient) {
for (float &elt : *memref)
elt *= coefficient;
}
TEST(NativeMemRefJit, JITCallback) {
constexpr int k = 2;
constexpr int m = 2;
int64_t shape[] = {k, m};
int64_t shapeAlloc[] = {k + 1, m + 1};
OwningMemRef<float, 2> a(shape, shapeAlloc);
int count = 1;
for (float &elt : *a)
elt = count++;
std::string moduleStr = R"mlir(
func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
%unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
return
}
)mlir";
DialectRegistry registry;
registerAllDialects(registry);
registerLLVMDialectTranslation(registry);
MLIRContext context(registry);
auto module = parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);
ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
auto jitOrError = ExecutionEngine::create(*module);
ASSERT_TRUE(!!jitOrError);
auto jit = std::move(jitOrError.get());
// Define any extra symbols so they're available at runtime.
jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
llvm::orc::SymbolMap symbolMap;
symbolMap[interner("_mlir_ciface_callback")] =
llvm::JITEvaluatedSymbol::fromPointer(memrefMultiply);
return symbolMap;
});
int32_t coefficient = 3.;
llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient);
ASSERT_TRUE(!error);
count = 1;
for (float elt : *a)
ASSERT_EQ(elt, coefficient * count++);
}
#endif // _WIN32