River Riddle faf42264e5 [PDLL] Add support for user defined constraint and rewrite functions
These functions allow for defining pattern fragments usable within the `match` and `rewrite` sections of a pattern. The main structure of Constraint and Rewrite functions is the same, and is similar to functions in other languages; they contain a signature (i.e. name, argument list, result list) and a body:

```pdll
// Constraint that takes a value as an input, and produces a value:
Constraint Cst(arg: Value) -> Value { ... }

// Constraint that returns multiple values:
Constraint Cst() -> (result1: Value, result2: ValueRange);
```

When returning multiple results, each result can optionally be named (the result of a Constraint/Rewrite in the case of multiple results is a tuple).

The body of a Constraint/Rewrite function can be specified in several ways:

* Externally
In this case we are importing an external function (registered by the user outside of PDLL):

```pdll
Constraint Foo(op: Op);
Rewrite Bar();
```

* In PDLL (using PDLL constructs)
In this case, the body is defined using PDLL constructs:

```pdll
Rewrite BuildFooOp() {
  // The result type of the Rewrite is inferred from the return.
  return op<my_dialect.foo>;
}
// Constraints/Rewrites can also implement a lambda/expression
// body for simple one line bodies.
Rewrite BuildFooOp() => op<my_dialect.foo>;
```

* In PDLL (using a native/C++ code block)
In this case the body is specified using a C++ (or potentially another language at some point) code block. When building PDLL in AOT mode this will generate a native constraint/rewrite and register it with the PDL bytecode.

```pdll
Rewrite BuildFooOp() -> Op<my_dialect.foo> [{
  return rewriter.create<my_dialect::FooOp>(...);
}];
```

Differential Revision: https://reviews.llvm.org/D115836
2022-02-10 12:48:59 -08:00

410 lines
17 KiB
C++

//===- Nodes.cpp ----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/PDLL/AST/Nodes.h"
#include "mlir/Tools/PDLL/AST/Context.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
using namespace mlir::pdll::ast;
/// Duplicate `str` into storage obtained from the context allocator and place
/// a NUL byte directly after the copied characters. The returned reference
/// covers only the characters, not the terminator. An empty `str` is handed
/// back unchanged and nothing is allocated for it.
static StringRef copyStringWithNull(Context &ctx, StringRef str) {
  const size_t length = str.size();
  if (length == 0)
    return str;

  // One extra byte past the characters is reserved for the terminator.
  char *buffer = ctx.getAllocator().Allocate<char>(length + 1);
  // `std::copy` returns the position just past the last character written.
  char *terminator = std::copy(str.begin(), str.end(), buffer);
  *terminator = 0;
  return StringRef(buffer, length);
}
//===----------------------------------------------------------------------===//
// Name
//===----------------------------------------------------------------------===//
/// Construct a `Name` in storage obtained from the context allocator. The
/// identifier text is first duplicated via `copyStringWithNull`, so the node
/// refers to that copy rather than to the caller's `name` buffer.
const Name &Name::create(Context &ctx, StringRef name, SMRange location) {
  void *storage = ctx.getAllocator().Allocate<Name>();
  StringRef ownedText = copyStringWithNull(ctx, name);
  return *new (storage) Name(ownedText, location);
}
//===----------------------------------------------------------------------===//
// DeclScope
//===----------------------------------------------------------------------===//
/// Register `decl` in this scope under its own name. The decl must be named,
/// and no decl with the same name may already be present in this scope; both
/// conditions are checked with assertions only.
void DeclScope::add(Decl *decl) {
  const Name *declName = decl->getName();
  assert(declName && "expected a named decl");

  StringRef key = declName->getName();
  assert(decls.find(key) == decls.end() &&
         "decl with this name already exists");
  decls.try_emplace(key, decl);
}
/// Find the decl registered under `name`, searching this scope first and then
/// each enclosing parent scope in turn. Returns nullptr when no scope in the
/// chain has a matching entry.
Decl *DeclScope::lookup(StringRef name) {
  for (DeclScope *scope = this; scope; scope = scope->parent) {
    if (Decl *found = scope->decls.lookup(name))
      return found;
  }
  return nullptr;
}
//===----------------------------------------------------------------------===//
// CompoundStmt
//===----------------------------------------------------------------------===//
/// Construct a `CompoundStmt` holding `children`. The allocation is sized for
/// the node plus one `Stmt *` slot per child, and the child pointers are
/// copied into the node's child array after construction.
CompoundStmt *CompoundStmt::create(Context &ctx, SMRange loc,
                                   ArrayRef<Stmt *> children) {
  const size_t numChildren = children.size();
  unsigned numBytes = CompoundStmt::totalSizeToAlloc<Stmt *>(numChildren);
  void *storage = ctx.getAllocator().Allocate(numBytes, alignof(CompoundStmt));

  CompoundStmt *result = new (storage) CompoundStmt(loc, numChildren);
  auto childSlots = result->getChildren().begin();
  std::uninitialized_copy(children.begin(), children.end(), childSlots);
  return result;
}
//===----------------------------------------------------------------------===//
// LetStmt
//===----------------------------------------------------------------------===//
/// Construct a `LetStmt` wrapping `varDecl` in storage obtained from the
/// context allocator.
LetStmt *LetStmt::create(Context &ctx, SMRange loc, VariableDecl *varDecl) {
  void *storage = ctx.getAllocator().Allocate<LetStmt>();
  return new (storage) LetStmt(loc, varDecl);
}
//===----------------------------------------------------------------------===//
// OpRewriteStmt
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// EraseStmt
/// Construct an `EraseStmt` for `rootOp` in storage obtained from the context
/// allocator.
EraseStmt *EraseStmt::create(Context &ctx, SMRange loc, Expr *rootOp) {
  void *storage = ctx.getAllocator().Allocate<EraseStmt>();
  return new (storage) EraseStmt(loc, rootOp);
}
//===----------------------------------------------------------------------===//
// ReplaceStmt
/// Construct a `ReplaceStmt` for `rootOp` with the replacement expressions
/// `replExprs`. The allocation is sized for the node plus one `Expr *` slot
/// per replacement, and the pointers are copied in after construction.
ReplaceStmt *ReplaceStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
                                 ArrayRef<Expr *> replExprs) {
  const size_t numRepls = replExprs.size();
  unsigned numBytes = ReplaceStmt::totalSizeToAlloc<Expr *>(numRepls);
  void *storage = ctx.getAllocator().Allocate(numBytes, alignof(ReplaceStmt));

  ReplaceStmt *result = new (storage) ReplaceStmt(loc, rootOp, numRepls);
  auto replSlots = result->getReplExprs().begin();
  std::uninitialized_copy(replExprs.begin(), replExprs.end(), replSlots);
  return result;
}
//===----------------------------------------------------------------------===//
// RewriteStmt
/// Construct a `RewriteStmt` pairing `rootOp` with `rewriteBody`, in storage
/// obtained from the context allocator.
RewriteStmt *RewriteStmt::create(Context &ctx, SMRange loc, Expr *rootOp,
                                 CompoundStmt *rewriteBody) {
  void *storage = ctx.getAllocator().Allocate<RewriteStmt>();
  return new (storage) RewriteStmt(loc, rootOp, rewriteBody);
}
//===----------------------------------------------------------------------===//
// ReturnStmt
//===----------------------------------------------------------------------===//
/// Construct a `ReturnStmt` yielding `resultExpr`, in storage obtained from
/// the context allocator.
ReturnStmt *ReturnStmt::create(Context &ctx, SMRange loc, Expr *resultExpr) {
  void *storage = ctx.getAllocator().Allocate<ReturnStmt>();
  return new (storage) ReturnStmt(loc, resultExpr);
}
//===----------------------------------------------------------------------===//
// AttributeExpr
//===----------------------------------------------------------------------===//
/// Construct an `AttributeExpr` from the attribute text `value`. The text is
/// duplicated via `copyStringWithNull` before being handed to the node.
AttributeExpr *AttributeExpr::create(Context &ctx, SMRange loc,
                                     StringRef value) {
  void *storage = ctx.getAllocator().Allocate<AttributeExpr>();
  StringRef ownedValue = copyStringWithNull(ctx, value);
  return new (storage) AttributeExpr(ctx, loc, ownedValue);
}
//===----------------------------------------------------------------------===//
// CallExpr
//===----------------------------------------------------------------------===//
/// Construct a `CallExpr` invoking `callable` with `arguments` and producing
/// `resultType`. The allocation is sized for the node plus one `Expr *` slot
/// per argument, and the argument pointers are copied in after construction.
CallExpr *CallExpr::create(Context &ctx, SMRange loc, Expr *callable,
                           ArrayRef<Expr *> arguments, Type resultType) {
  const size_t numArgs = arguments.size();
  unsigned numBytes = CallExpr::totalSizeToAlloc<Expr *>(numArgs);
  void *storage = ctx.getAllocator().Allocate(numBytes, alignof(CallExpr));

  CallExpr *result = new (storage) CallExpr(loc, resultType, callable, numArgs);
  auto argSlots = result->getArguments().begin();
  std::uninitialized_copy(arguments.begin(), arguments.end(), argSlots);
  return result;
}
//===----------------------------------------------------------------------===//
// DeclRefExpr
//===----------------------------------------------------------------------===//
/// Construct a `DeclRefExpr` referring to `decl` with expression type `type`,
/// in storage obtained from the context allocator.
DeclRefExpr *DeclRefExpr::create(Context &ctx, SMRange loc, Decl *decl,
                                 Type type) {
  void *storage = ctx.getAllocator().Allocate<DeclRefExpr>();
  return new (storage) DeclRefExpr(loc, decl, type);
}
//===----------------------------------------------------------------------===//
// MemberAccessExpr
//===----------------------------------------------------------------------===//
/// Construct a `MemberAccessExpr` that accesses `memberName` on `parentExpr`.
/// The node is given a copy of `memberName` made with the context allocator
/// (via `StringRef::copy`) rather than the caller's buffer.
MemberAccessExpr *MemberAccessExpr::create(Context &ctx, SMRange loc,
                                           const Expr *parentExpr,
                                           StringRef memberName, Type type) {
  void *storage = ctx.getAllocator().Allocate<MemberAccessExpr>();
  StringRef ownedMember = memberName.copy(ctx.getAllocator());
  return new (storage) MemberAccessExpr(loc, parentExpr, ownedMember, type);
}
//===----------------------------------------------------------------------===//
// OperationExpr
//===----------------------------------------------------------------------===//
/// Construct an `OperationExpr` for the operation named by `name`, with the
/// given operands, result type expressions and attributes. `name` is
/// dereferenced unconditionally and so must be non-null. The expression's
/// type is the `OperationType` built from `name->getName()`.
OperationExpr *OperationExpr::create(
    Context &ctx, SMRange loc, const OpNameDecl *name,
    ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes,
    ArrayRef<NamedAttributeDecl *> attributes) {
  const size_t numOperands = operands.size();
  const size_t numResultTypes = resultTypes.size();
  const size_t numAttributes = attributes.size();

  // Operands and result type expressions are both counted against the
  // `Expr *` trailing array; attributes are counted separately.
  unsigned numBytes =
      OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
          numOperands + numResultTypes, numAttributes);
  void *storage =
      ctx.getAllocator().Allocate(numBytes, alignof(OperationExpr));

  Type opType = OperationType::get(ctx, name->getName());
  OperationExpr *result = new (storage)
      OperationExpr(loc, opType, name, numOperands, numResultTypes,
                    numAttributes, name->getLoc());

  // Fill each trailing range from its corresponding input list.
  auto operandSlots = result->getOperands().begin();
  std::uninitialized_copy(operands.begin(), operands.end(), operandSlots);
  auto resultTypeSlots = result->getResultTypes().begin();
  std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
                          resultTypeSlots);
  auto attributeSlots = result->getAttributes().begin();
  std::uninitialized_copy(attributes.begin(), attributes.end(),
                          attributeSlots);
  return result;
}
/// Return the operation name by forwarding to the name decl's `getName()`.
Optional<StringRef> OperationExpr::getName() const {
  const auto *nameDecl = getNameDecl();
  return nameDecl->getName();
}
//===----------------------------------------------------------------------===//
// TupleExpr
//===----------------------------------------------------------------------===//
/// Construct a `TupleExpr` over `elements`. The tuple's type is obtained from
/// `TupleType::get` using each element's own type, in order, together with
/// `names`. The element pointers are copied into the node after construction.
TupleExpr *TupleExpr::create(Context &ctx, SMRange loc,
                             ArrayRef<Expr *> elements,
                             ArrayRef<StringRef> names) {
  unsigned numBytes = TupleExpr::totalSizeToAlloc<Expr *>(elements.size());
  void *storage = ctx.getAllocator().Allocate(numBytes, alignof(TupleExpr));

  // Collect the type of every element to form the tuple type.
  auto typeOf = [](const Expr *element) { return element->getType(); };
  auto elementTypes = llvm::to_vector(llvm::map_range(elements, typeOf));
  TupleType tupleType = TupleType::get(ctx, elementTypes, names);

  TupleExpr *result = new (storage) TupleExpr(loc, tupleType);
  auto elementSlots = result->getElements().begin();
  std::uninitialized_copy(elements.begin(), elements.end(), elementSlots);
  return result;
}
//===----------------------------------------------------------------------===//
// TypeExpr
//===----------------------------------------------------------------------===//
/// Construct a `TypeExpr` from the type text `value`. The text is duplicated
/// via `copyStringWithNull` before being handed to the node.
TypeExpr *TypeExpr::create(Context &ctx, SMRange loc, StringRef value) {
  void *storage = ctx.getAllocator().Allocate<TypeExpr>();
  StringRef ownedValue = copyStringWithNull(ctx, value);
  return new (storage) TypeExpr(ctx, loc, ownedValue);
}
//===----------------------------------------------------------------------===//
// AttrConstraintDecl
//===----------------------------------------------------------------------===//
/// Construct an `AttrConstraintDecl` carrying `typeExpr`, in storage obtained
/// from the context allocator.
AttrConstraintDecl *AttrConstraintDecl::create(Context &ctx, SMRange loc,
                                               Expr *typeExpr) {
  void *storage = ctx.getAllocator().Allocate<AttrConstraintDecl>();
  return new (storage) AttrConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// OpConstraintDecl
//===----------------------------------------------------------------------===//
/// Construct an `OpConstraintDecl` for `nameDecl`. When `nameDecl` is null, a
/// fresh `OpNameDecl` built from a default-constructed `SMRange` is used in
/// its place, so the resulting decl always has a name decl attached.
OpConstraintDecl *OpConstraintDecl::create(Context &ctx, SMRange loc,
                                           const OpNameDecl *nameDecl) {
  const OpNameDecl *effectiveName =
      nameDecl ? nameDecl : OpNameDecl::create(ctx, SMRange());
  void *storage = ctx.getAllocator().Allocate<OpConstraintDecl>();
  return new (storage) OpConstraintDecl(loc, effectiveName);
}
/// Return the constrained operation name by forwarding to the name decl's
/// `getName()`.
Optional<StringRef> OpConstraintDecl::getName() const {
  const auto *nameDecl = getNameDecl();
  return nameDecl->getName();
}
//===----------------------------------------------------------------------===//
// TypeConstraintDecl
//===----------------------------------------------------------------------===//
/// Construct a `TypeConstraintDecl` at `loc`, in storage obtained from the
/// context allocator.
TypeConstraintDecl *TypeConstraintDecl::create(Context &ctx, SMRange loc) {
  void *storage = ctx.getAllocator().Allocate<TypeConstraintDecl>();
  return new (storage) TypeConstraintDecl(loc);
}
//===----------------------------------------------------------------------===//
// TypeRangeConstraintDecl
//===----------------------------------------------------------------------===//
/// Construct a `TypeRangeConstraintDecl` at `loc`, in storage obtained from
/// the context allocator.
TypeRangeConstraintDecl *TypeRangeConstraintDecl::create(Context &ctx,
                                                         SMRange loc) {
  void *storage = ctx.getAllocator().Allocate<TypeRangeConstraintDecl>();
  return new (storage) TypeRangeConstraintDecl(loc);
}
//===----------------------------------------------------------------------===//
// ValueConstraintDecl
//===----------------------------------------------------------------------===//
/// Construct a `ValueConstraintDecl` carrying `typeExpr`, in storage obtained
/// from the context allocator.
ValueConstraintDecl *ValueConstraintDecl::create(Context &ctx, SMRange loc,
                                                 Expr *typeExpr) {
  void *storage = ctx.getAllocator().Allocate<ValueConstraintDecl>();
  return new (storage) ValueConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// ValueRangeConstraintDecl
//===----------------------------------------------------------------------===//
/// Construct a `ValueRangeConstraintDecl` carrying `typeExpr`, in storage
/// obtained from the context allocator.
ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
                                                           SMRange loc,
                                                           Expr *typeExpr) {
  void *storage = ctx.getAllocator().Allocate<ValueRangeConstraintDecl>();
  return new (storage) ValueRangeConstraintDecl(loc, typeExpr);
}
//===----------------------------------------------------------------------===//
// UserConstraintDecl
//===----------------------------------------------------------------------===//
/// Construct a `UserConstraintDecl` with the given signature (`inputs`,
/// `results`, `resultType`) and optional `codeBlock` / `body`. The allocation
/// is sized for the node plus one `VariableDecl *` slot per input and per
/// result. A provided `codeBlock` is replaced by a copy made with the context
/// allocator before it is passed to the node.
UserConstraintDecl *UserConstraintDecl::createImpl(
    Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
    ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
    const CompoundStmt *body, Type resultType) {
  const size_t numInputs = inputs.size();
  const size_t numResults = results.size();

  unsigned numBytes = UserConstraintDecl::totalSizeToAlloc<VariableDecl *>(
      numInputs + numResults);
  void *storage =
      ctx.getAllocator().Allocate(numBytes, alignof(UserConstraintDecl));

  // Re-home the native code text, if any, so the decl does not reference the
  // caller's buffer.
  if (codeBlock)
    codeBlock = codeBlock->copy(ctx.getAllocator());

  UserConstraintDecl *result = new (storage) UserConstraintDecl(
      name, numInputs, numResults, codeBlock, body, resultType);
  auto inputSlots = result->getInputs().begin();
  std::uninitialized_copy(inputs.begin(), inputs.end(), inputSlots);
  auto resultSlots = result->getResults().begin();
  std::uninitialized_copy(results.begin(), results.end(), resultSlots);
  return result;
}
//===----------------------------------------------------------------------===//
// NamedAttributeDecl
//===----------------------------------------------------------------------===//
/// Construct a `NamedAttributeDecl` binding `name` to `value`, in storage
/// obtained from the context allocator.
NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
                                               Expr *value) {
  void *storage = ctx.getAllocator().Allocate<NamedAttributeDecl>();
  return new (storage) NamedAttributeDecl(name, value);
}
//===----------------------------------------------------------------------===//
// OpNameDecl
//===----------------------------------------------------------------------===//
/// Construct an `OpNameDecl` from an explicit `name`, in storage obtained
/// from the context allocator.
OpNameDecl *OpNameDecl::create(Context &ctx, const Name &name) {
  void *storage = ctx.getAllocator().Allocate<OpNameDecl>();
  return new (storage) OpNameDecl(name);
}
/// Construct an `OpNameDecl` from a source range only (no `Name` supplied),
/// in storage obtained from the context allocator.
OpNameDecl *OpNameDecl::create(Context &ctx, SMRange loc) {
  void *storage = ctx.getAllocator().Allocate<OpNameDecl>();
  return new (storage) OpNameDecl(loc);
}
//===----------------------------------------------------------------------===//
// PatternDecl
//===----------------------------------------------------------------------===//
/// Construct a `PatternDecl` with the given optional `name`, optional
/// `benefit`, bounded-recursion flag and `body`, in storage obtained from the
/// context allocator. All arguments are forwarded to the node unchanged.
PatternDecl *PatternDecl::create(Context &ctx, SMRange loc, const Name *name,
                                 Optional<uint16_t> benefit,
                                 bool hasBoundedRecursion,
                                 const CompoundStmt *body) {
  void *storage = ctx.getAllocator().Allocate<PatternDecl>();
  return new (storage)
      PatternDecl(loc, name, benefit, hasBoundedRecursion, body);
}
//===----------------------------------------------------------------------===//
// UserRewriteDecl
//===----------------------------------------------------------------------===//
/// Construct a `UserRewriteDecl` with the given signature (`inputs`,
/// `results`, `resultType`) and optional `codeBlock` / `body`. The allocation
/// is sized for the node plus one `VariableDecl *` slot per input and per
/// result. A provided `codeBlock` is replaced by a copy made with the context
/// allocator before it is passed to the node.
UserRewriteDecl *UserRewriteDecl::createImpl(Context &ctx, const Name &name,
                                             ArrayRef<VariableDecl *> inputs,
                                             ArrayRef<VariableDecl *> results,
                                             Optional<StringRef> codeBlock,
                                             const CompoundStmt *body,
                                             Type resultType) {
  const size_t numInputs = inputs.size();
  const size_t numResults = results.size();

  unsigned numBytes = UserRewriteDecl::totalSizeToAlloc<VariableDecl *>(
      numInputs + numResults);
  void *storage =
      ctx.getAllocator().Allocate(numBytes, alignof(UserRewriteDecl));

  // Re-home the native code text, if any, so the decl does not reference the
  // caller's buffer.
  if (codeBlock)
    codeBlock = codeBlock->copy(ctx.getAllocator());

  UserRewriteDecl *result = new (storage) UserRewriteDecl(
      name, numInputs, numResults, codeBlock, body, resultType);
  auto inputSlots = result->getInputs().begin();
  std::uninitialized_copy(inputs.begin(), inputs.end(), inputSlots);
  auto resultSlots = result->getResults().begin();
  std::uninitialized_copy(results.begin(), results.end(), resultSlots);
  return result;
}
//===----------------------------------------------------------------------===//
// VariableDecl
//===----------------------------------------------------------------------===//
/// Construct a `VariableDecl` named `name` of type `type`, with an optional
/// initializer `initExpr` and the constraints listed in `constraints`. The
/// allocation is sized for the node plus one `ConstraintRef` slot per
/// constraint, and the constraints are copied in after construction.
VariableDecl *VariableDecl::create(Context &ctx, const Name &name, Type type,
                                   Expr *initExpr,
                                   ArrayRef<ConstraintRef> constraints) {
  const size_t numConstraints = constraints.size();
  unsigned numBytes =
      VariableDecl::totalSizeToAlloc<ConstraintRef>(numConstraints);
  void *storage = ctx.getAllocator().Allocate(numBytes, alignof(VariableDecl));

  VariableDecl *result =
      new (storage) VariableDecl(name, type, initExpr, numConstraints);
  auto constraintSlots = result->getConstraints().begin();
  std::uninitialized_copy(constraints.begin(), constraints.end(),
                          constraintSlots);
  return result;
}
//===----------------------------------------------------------------------===//
// Module
//===----------------------------------------------------------------------===//
/// Construct a `Module` holding the top-level decls in `children`. The
/// allocation is sized for the node plus one `Decl *` slot per child, and the
/// child pointers are copied in after construction.
Module *Module::create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children) {
  const size_t numChildren = children.size();
  unsigned numBytes = Module::totalSizeToAlloc<Decl *>(numChildren);
  void *storage = ctx.getAllocator().Allocate(numBytes, alignof(Module));

  Module *result = new (storage) Module(loc, numChildren);
  auto childSlots = result->getChildren().begin();
  std::uninitialized_copy(children.begin(), children.end(), childSlots);
  return result;
}