//===- LowerTestPass.cpp - Test pass for lowering EDSC --------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/EDSC/MLIREmitter.h"
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <cassert>
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
// Testing pass to lower EDSC.
|
|
struct LowerEDSCTestPass : public FunctionPass {
|
|
LowerEDSCTestPass() : FunctionPass(&LowerEDSCTestPass::passID) {}
|
|
PassResult runOnFunction(Function *f) override;
|
|
|
|
constexpr static PassID passID = {};
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
#include "mlir/EDSC/reference-impl.inc"
|
|
|
|
// Injects EDSC-constructed IR into `f`, selected by the name of `f`, and then
// processes every "print" instruction found in `f`.
//
// Recognized function names:
//   - "blocks": emits two blocks, each taking two 32-bit integer arguments.
//   - "dynamic_for_func_args": emits a `for` loop whose bounds are the two
//     function arguments.
//   - "dynamic_for": emits a `for` loop whose bounds are expressions over the
//     four function arguments.
//   - "max_min_for": emits a `MaxMinFor` loop over the four function arguments.
//   - "call_indirect": emits direct, chained and argument-passing calls through
//     function-typed values.
//   - any name containing "assignments": emits a loop that loads from the first
//     two arguments, multiplies, and stores into the third.
//
// The "blocks" and "assignments" branches do not return early, so control
// continues to the "print" walk at the end; all other branches return
// `success()` right after emission.
//
// Returns `success()` on every path; malformed "print" instructions are
// reported through `emitOpError` but do not change the returned result.
PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
  // None of the branches below rename the function, so read the name once.
  const auto funcName = f->getName().strref();

  // Inject a EDSC-constructed list of blocks.
  if (funcName == "blocks") {
    using namespace edsc::op;

    FuncBuilder builder(f);
    edsc::ScopedEDSCContext context;
    auto type = builder.getIntegerType(32);
    edsc::Expr arg1(type), arg2(type), arg3(type), arg4(type);

    auto b1 =
        edsc::block({arg1, arg2}, {type, type}, {arg1 + arg2, edsc::Return()});
    auto b2 =
        edsc::block({arg3, arg4}, {type, type}, {arg3 - arg4, edsc::Return()});

    edsc::MLIREmitter(&builder, f->getLoc()).emitBlock(b1).emitBlock(b2);
    // Intentionally no return: fall through to the remaining checks.
  }

  // Inject a EDSC-constructed `for` loop with bounds coming from function
  // arguments.
  if (funcName == "dynamic_for_func_args") {
    assert(!f->getBlocks().empty() &&
           "dynamic_for_func_args should not be empty");
    FuncBuilder builder(&f->getBlocks().front(),
                        f->getBlocks().front().begin());
    assert(f->getNumArguments() == 2 &&
           "dynamic_for_func_args expected 2 arguments");
    assert(std::all_of(f->args_begin(), f->args_end(),
                       [](const Value *s) { return s->getType().isIndex(); }) &&
           "dynamic_for_func_args expected index arguments");

    using namespace edsc::op;
    Type index = IndexType::get(f->getContext());
    edsc::ScopedEDSCContext context;
    edsc::Expr lb(index), ub(index), step(index);
    // The step is a constant rather than a bound value.
    step = edsc::constantInteger(index, 3);
    auto loop = edsc::For(lb, ub, step, {lb * step + ub, step + lb});
    edsc::MLIREmitter(&builder, f->getLoc())
        .bind(edsc::Bindable(lb), f->getArgument(0))
        .bind(edsc::Bindable(ub), f->getArgument(1))
        .emitStmt(loop);
    return success();
  }

  // Inject a EDSC-constructed `for` loop with non-constant bounds that are
  // obtained from AffineApplyOp (also constructed using EDSC operator
  // overloads).
  if (funcName == "dynamic_for") {
    assert(!f->getBlocks().empty() && "dynamic_for should not be empty");
    FuncBuilder builder(&f->getBlocks().front(),
                        f->getBlocks().front().begin());
    assert(f->getNumArguments() == 4 && "dynamic_for expected 4 arguments");
    assert(std::all_of(f->args_begin(), f->args_end(),
                       [](const Value *s) { return s->getType().isIndex(); }) &&
           "dynamic_for expected index arguments");

    Type index = IndexType::get(f->getContext());
    edsc::ScopedEDSCContext context;
    edsc::Expr lb1(index), lb2(index), ub1(index), ub2(index), step(index);
    using namespace edsc::op;
    auto lb = lb1 - lb2;
    auto ub = ub1 + ub2;
    auto loop = edsc::For(lb, ub, step, {});
    edsc::MLIREmitter(&builder, f->getLoc())
        .bind(edsc::Bindable(lb1), f->getArgument(0))
        .bind(edsc::Bindable(lb2), f->getArgument(1))
        .bind(edsc::Bindable(ub1), f->getArgument(2))
        .bind(edsc::Bindable(ub2), f->getArgument(3))
        .bindConstant<ConstantIndexOp>(edsc::Bindable(step), 2)
        .emitStmt(loop);

    return success();
  }

  // Inject a EDSC-constructed `MaxMinFor` loop whose two lower bounds and two
  // upper bounds are bound to the four function arguments.
  if (funcName == "max_min_for") {
    assert(!f->getBlocks().empty() && "max_min_for should not be empty");
    FuncBuilder builder(&f->getBlocks().front(),
                        f->getBlocks().front().begin());
    assert(f->getNumArguments() == 4 && "max_min_for expected 4 arguments");
    assert(std::all_of(f->args_begin(), f->args_end(),
                       [](const Value *s) { return s->getType().isIndex(); }) &&
           "max_min_for expected index arguments");

    edsc::ScopedEDSCContext context;
    edsc::Expr lb1(f->getArgument(0)->getType());
    edsc::Expr lb2(f->getArgument(1)->getType());
    edsc::Expr ub1(f->getArgument(2)->getType());
    edsc::Expr ub2(f->getArgument(3)->getType());
    edsc::Expr iv(builder.getIndexType());
    edsc::Expr step = edsc::constantInteger(builder.getIndexType(), 1);
    auto loop =
        edsc::MaxMinFor(edsc::Bindable(iv), {lb1, lb2}, {ub1, ub2}, step, {});
    edsc::MLIREmitter(&builder, f->getLoc())
        .bind(edsc::Bindable(lb1), f->getArgument(0))
        .bind(edsc::Bindable(lb2), f->getArgument(1))
        .bind(edsc::Bindable(ub1), f->getArgument(2))
        .bind(edsc::Bindable(ub2), f->getArgument(3))
        .emitStmt(loop);

    return success();
  }

  // Inject EDSC-constructed calls through function-typed values. The callees
  // are looked up by name in the enclosing module and must already exist.
  if (funcName == "call_indirect") {
    assert(!f->getBlocks().empty() && "call_indirect should not be empty");
    FuncBuilder builder(&f->getBlocks().front(),
                        f->getBlocks().front().begin());
    Function *callee = f->getModule()->getNamedFunction("callee");
    Function *calleeArgs = f->getModule()->getNamedFunction("callee_args");
    Function *secondOrderCallee =
        f->getModule()->getNamedFunction("second_order_callee");
    assert(callee && calleeArgs && secondOrderCallee &&
           "could not find required declarations");

    auto funcRetIndexType = builder.getFunctionType({}, builder.getIndexType());

    edsc::ScopedEDSCContext context;
    edsc::Expr func(callee->getType()), funcArgs(calleeArgs->getType()),
        secondOrderFunc(secondOrderCallee->getType());
    auto stmt = edsc::call(func, {});
    // Call `secondOrderFunc` with `func`, then call the function it returns.
    auto chainedCallResult =
        edsc::call(edsc::call(secondOrderFunc, funcRetIndexType, {func}),
                   builder.getIndexType(), {});
    auto argsCall =
        edsc::call(funcArgs, {chainedCallResult, chainedCallResult});
    edsc::MLIREmitter(&builder, f->getLoc())
        .bindConstant<ConstantOp>(edsc::Bindable(func),
                                  builder.getFunctionAttr(callee))
        .bindConstant<ConstantOp>(edsc::Bindable(funcArgs),
                                  builder.getFunctionAttr(calleeArgs))
        .bindConstant<ConstantOp>(edsc::Bindable(secondOrderFunc),
                                  builder.getFunctionAttr(secondOrderCallee))
        .emitStmt(stmt)
        .emitStmt(chainedCallResult)
        .emitStmt(argsCall);

    return success();
  }

  // Inject an EDSC-constructed computation that assigns Stmt and uses the LHS.
  // Note this matches any function whose name *contains* "assignments".
  if (funcName.contains("assignments")) {
    FuncBuilder builder(f);
    edsc::ScopedEDSCContext context;
    edsc::MLIREmitter emitter(&builder, f->getLoc());

    edsc::Expr zero = emitter.zero();
    edsc::Expr one = emitter.one();
    auto args = emitter.makeBoundFunctionArguments(f);
    auto views = emitter.makeBoundMemRefViews(args.begin(), args.end());

    Type indexType = builder.getIndexType();
    edsc::Expr i(indexType);
    edsc::Expr A = args[0], B = args[1], C = args[2];
    // The loop upper bound is dimension 0 of the view of the first argument.
    edsc::Expr M = views[0].dim(0);
    // clang-format off
    using namespace edsc::op;
    edsc::Stmt scalarA, scalarB, tmp;
    auto block = edsc::block({
      For(i, zero, M, one, {
        scalarA = load(A, {i}),
        scalarB = load(B, {i}),
        tmp = scalarA * scalarB,
        store(tmp, C, {i})
      }),
    });
    // clang-format on

    emitter.emitStmts(block.getBody());
    // Intentionally no return: fall through to the "print" walk below.
  }

  // For every instruction named "print", read its "op" (string) and "fn"
  // (function) attributes and hand them to `printRefImplementation`. A missing
  // attribute is reported on the instruction and that instruction is skipped.
  f->walk([](Instruction *op) {
    if (op->getName().getStringRef() == "print") {
      auto opName = op->getAttrOfType<StringAttr>("op");
      if (!opName) {
        op->emitOpError("no 'op' attribute provided for print");
        return;
      }
      auto function = op->getAttrOfType<FunctionAttr>("fn");
      if (!function) {
        op->emitOpError("no 'fn' attribute provided for print");
        return;
      }
      printRefImplementation(opName.getValue(), function.getValue());
    }
  });
  return success();
}
|
|
|
|
// File-local registration object for LowerEDSCTestPass, constructed with the
// argument string "lower-edsc-test" and the description "Lower EDSC test pass".
// NOTE(review): presumably the first string is the command-line flag that
// selects the pass -- confirm against PassRegistration's documentation.
static PassRegistration<LowerEDSCTestPass> pass("lower-edsc-test",
                                                "Lower EDSC test pass");
|