Reapply "[mlir][PDL] Add support for native constraints with results (#82760)"
with a small stack-use-after-scope fix in getConstraintPredicates() This reverts commit c80e6edba4a9593f0587e27fa0ac825ebe174afd.
This commit is contained in:
parent
da591d390e
commit
8ec28af8ea
@ -35,20 +35,25 @@ def PDL_ApplyNativeConstraintOp
|
|||||||
let description = [{
|
let description = [{
|
||||||
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
|
`pdl.apply_native_constraint` operations apply a native C++ constraint, that
|
||||||
has been registered externally with the consumer of PDL, to a given set of
|
has been registered externally with the consumer of PDL, to a given set of
|
||||||
entities.
|
entities and optionally return a number of values.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```mlir
|
```mlir
|
||||||
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
|
// Apply `myConstraint` to the entities defined by `input`, `attr`, and `op`.
|
||||||
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
|
pdl.apply_native_constraint "myConstraint"(%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation)
|
||||||
|
// Apply constraint `with_result` to `root`. This constraint returns an attribute.
|
||||||
|
%attr = pdl.apply_native_constraint "with_result"(%root : !pdl.operation) : !pdl.attribute
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins StrAttr:$name,
|
let arguments = (ins StrAttr:$name,
|
||||||
Variadic<PDL_AnyType>:$args,
|
Variadic<PDL_AnyType>:$args,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
|
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
|
||||||
let assemblyFormat = "$name `(` $args `:` type($args) `)` attr-dict";
|
let results = (outs Variadic<PDL_AnyType>:$results);
|
||||||
|
let assemblyFormat = [{
|
||||||
|
$name `(` $args `:` type($args) `)` (`:` type($results)^ )? attr-dict
|
||||||
|
}];
|
||||||
let hasVerifier = 1;
|
let hasVerifier = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -88,7 +88,9 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
|
|||||||
let description = [{
|
let description = [{
|
||||||
`pdl_interp.apply_constraint` operations apply a generic constraint, that
|
`pdl_interp.apply_constraint` operations apply a generic constraint, that
|
||||||
has been registered with the interpreter, with a given set of positional
|
has been registered with the interpreter, with a given set of positional
|
||||||
values. On success, this operation branches to the true destination,
|
values.
|
||||||
|
The constraint function may return any number of results.
|
||||||
|
On success, this operation branches to the true destination,
|
||||||
otherwise the false destination is taken. This behavior can be reversed
|
otherwise the false destination is taken. This behavior can be reversed
|
||||||
by setting the attribute `isNegated` to true.
|
by setting the attribute `isNegated` to true.
|
||||||
|
|
||||||
@ -104,8 +106,10 @@ def PDLInterp_ApplyConstraintOp : PDLInterp_PredicateOp<"apply_constraint"> {
|
|||||||
let arguments = (ins StrAttr:$name,
|
let arguments = (ins StrAttr:$name,
|
||||||
Variadic<PDL_AnyType>:$args,
|
Variadic<PDL_AnyType>:$args,
|
||||||
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
|
DefaultValuedAttr<BoolAttr, "false">:$isNegated);
|
||||||
|
let results = (outs Variadic<PDL_AnyType>:$results);
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$name `(` $args `:` type($args) `)` attr-dict `->` successors
|
$name `(` $args `:` type($args) `)` (`:` type($results)^)? attr-dict
|
||||||
|
`->` successors
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -318,8 +318,9 @@ protected:
|
|||||||
/// A generic PDL pattern constraint function. This function applies a
|
/// A generic PDL pattern constraint function. This function applies a
|
||||||
/// constraint to a given set of opaque PDLValue entities. Returns success if
|
/// constraint to a given set of opaque PDLValue entities. Returns success if
|
||||||
/// the constraint successfully held, failure otherwise.
|
/// the constraint successfully held, failure otherwise.
|
||||||
using PDLConstraintFunction =
|
using PDLConstraintFunction = std::function<LogicalResult(
|
||||||
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
|
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
|
||||||
|
|
||||||
/// A native PDL rewrite function. This function performs a rewrite on the
|
/// A native PDL rewrite function. This function performs a rewrite on the
|
||||||
/// given set of values. Any results from this rewrite that should be passed
|
/// given set of values. Any results from this rewrite that should be passed
|
||||||
/// back to PDL should be added to the provided result list. This method is only
|
/// back to PDL should be added to the provided result list. This method is only
|
||||||
@ -726,7 +727,7 @@ std::enable_if_t<
|
|||||||
PDLConstraintFunction>
|
PDLConstraintFunction>
|
||||||
buildConstraintFn(ConstraintFnT &&constraintFn) {
|
buildConstraintFn(ConstraintFnT &&constraintFn) {
|
||||||
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
|
return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
|
||||||
PatternRewriter &rewriter,
|
PatternRewriter &rewriter, PDLResultList &,
|
||||||
ArrayRef<PDLValue> values) -> LogicalResult {
|
ArrayRef<PDLValue> values) -> LogicalResult {
|
||||||
auto argIndices = std::make_index_sequence<
|
auto argIndices = std::make_index_sequence<
|
||||||
llvm::function_traits<ConstraintFnT>::num_args - 1>();
|
llvm::function_traits<ConstraintFnT>::num_args - 1>();
|
||||||
@ -842,10 +843,13 @@ public:
|
|||||||
/// Register a constraint function with PDL. A constraint function may be
|
/// Register a constraint function with PDL. A constraint function may be
|
||||||
/// specified in one of two ways:
|
/// specified in one of two ways:
|
||||||
///
|
///
|
||||||
/// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
|
/// * `LogicalResult (PatternRewriter &,
|
||||||
|
/// PDLResultList &,
|
||||||
|
/// ArrayRef<PDLValue>)`
|
||||||
///
|
///
|
||||||
/// In this overload the arguments of the constraint function are passed via
|
/// In this overload the arguments of the constraint function are passed via
|
||||||
/// the low-level PDLValue form.
|
/// the low-level PDLValue form, and the results are manually appended to
|
||||||
|
/// the given result list.
|
||||||
///
|
///
|
||||||
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
|
/// * `LogicalResult (PatternRewriter &, ValueTs... values)`
|
||||||
///
|
///
|
||||||
@ -960,8 +964,8 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
class PDLResultList {};
|
class PDLResultList {};
|
||||||
using PDLConstraintFunction =
|
using PDLConstraintFunction = std::function<LogicalResult(
|
||||||
std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
|
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
|
||||||
using PDLRewriteFunction = std::function<LogicalResult(
|
using PDLRewriteFunction = std::function<LogicalResult(
|
||||||
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
|
PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
|
||||||
|
|
||||||
|
@ -50,7 +50,8 @@ private:
|
|||||||
|
|
||||||
/// Generate interpreter operations for the tree rooted at the given matcher
|
/// Generate interpreter operations for the tree rooted at the given matcher
|
||||||
/// node, in the specified region.
|
/// node, in the specified region.
|
||||||
Block *generateMatcher(MatcherNode &node, Region ®ion);
|
Block *generateMatcher(MatcherNode &node, Region ®ion,
|
||||||
|
Block *block = nullptr);
|
||||||
|
|
||||||
/// Get or create an access to the provided positional value in the current
|
/// Get or create an access to the provided positional value in the current
|
||||||
/// block. This operation may mutate the provided block pointer if nested
|
/// block. This operation may mutate the provided block pointer if nested
|
||||||
@ -148,6 +149,10 @@ private:
|
|||||||
/// A mapping between pattern operations and the corresponding configuration
|
/// A mapping between pattern operations and the corresponding configuration
|
||||||
/// set.
|
/// set.
|
||||||
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
|
DenseMap<Operation *, PDLPatternConfigSet *> *configMap;
|
||||||
|
|
||||||
|
/// A mapping from a constraint question to the ApplyConstraintOp
|
||||||
|
/// that implements it.
|
||||||
|
DenseMap<ConstraintQuestion *, pdl_interp::ApplyConstraintOp> constraintOpMap;
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -182,9 +187,11 @@ void PatternLowering::lower(ModuleOp module) {
|
|||||||
firstMatcherBlock->erase();
|
firstMatcherBlock->erase();
|
||||||
}
|
}
|
||||||
|
|
||||||
Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) {
|
Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion,
|
||||||
|
Block *block) {
|
||||||
// Push a new scope for the values used by this matcher.
|
// Push a new scope for the values used by this matcher.
|
||||||
Block *block = ®ion.emplaceBlock();
|
if (!block)
|
||||||
|
block = ®ion.emplaceBlock();
|
||||||
ValueMapScope scope(values);
|
ValueMapScope scope(values);
|
||||||
|
|
||||||
// If this is the return node, simply insert the corresponding interpreter
|
// If this is the return node, simply insert the corresponding interpreter
|
||||||
@ -364,6 +371,15 @@ Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
|
|||||||
loc, cast<ArrayAttr>(rawTypeAttr));
|
loc, cast<ArrayAttr>(rawTypeAttr));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case Predicates::ConstraintResultPos: {
|
||||||
|
// Due to the order of traversal, the ApplyConstraintOp has already been
|
||||||
|
// created and we can find it in constraintOpMap.
|
||||||
|
auto *constrResPos = cast<ConstraintPosition>(pos);
|
||||||
|
auto i = constraintOpMap.find(constrResPos->getQuestion());
|
||||||
|
assert(i != constraintOpMap.end());
|
||||||
|
value = i->second->getResult(constrResPos->getIndex());
|
||||||
|
break;
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
llvm_unreachable("Generating unknown Position getter");
|
llvm_unreachable("Generating unknown Position getter");
|
||||||
break;
|
break;
|
||||||
@ -390,12 +406,11 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
|||||||
args.push_back(getValueAt(currentBlock, position));
|
args.push_back(getValueAt(currentBlock, position));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate the matcher in the current (potentially nested) region
|
// Generate a new block as success successor and get the failure successor.
|
||||||
// and get the failure successor.
|
Block *success = ®ion->emplaceBlock();
|
||||||
Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
|
|
||||||
Block *failure = failureBlockStack.back();
|
Block *failure = failureBlockStack.back();
|
||||||
|
|
||||||
// Finally, create the predicate.
|
// Create the predicate.
|
||||||
builder.setInsertionPointToEnd(currentBlock);
|
builder.setInsertionPointToEnd(currentBlock);
|
||||||
Predicates::Kind kind = question->getKind();
|
Predicates::Kind kind = question->getKind();
|
||||||
switch (kind) {
|
switch (kind) {
|
||||||
@ -447,14 +462,20 @@ void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
|
|||||||
}
|
}
|
||||||
case Predicates::ConstraintQuestion: {
|
case Predicates::ConstraintQuestion: {
|
||||||
auto *cstQuestion = cast<ConstraintQuestion>(question);
|
auto *cstQuestion = cast<ConstraintQuestion>(question);
|
||||||
builder.create<pdl_interp::ApplyConstraintOp>(
|
auto applyConstraintOp = builder.create<pdl_interp::ApplyConstraintOp>(
|
||||||
loc, cstQuestion->getName(), args, cstQuestion->getIsNegated(), success,
|
loc, cstQuestion->getResultTypes(), cstQuestion->getName(), args,
|
||||||
failure);
|
cstQuestion->getIsNegated(), success, failure);
|
||||||
|
|
||||||
|
constraintOpMap.insert({cstQuestion, applyConstraintOp});
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
llvm_unreachable("Generating unknown Predicate operation");
|
llvm_unreachable("Generating unknown Predicate operation");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate the matcher in the current (potentially nested) region.
|
||||||
|
// This might use the results of the current predicate.
|
||||||
|
generateMatcher(*boolNode->getSuccessNode(), *region, success);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
|
template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
|
||||||
|
@ -47,6 +47,7 @@ enum Kind : unsigned {
|
|||||||
OperandPos,
|
OperandPos,
|
||||||
OperandGroupPos,
|
OperandGroupPos,
|
||||||
AttributePos,
|
AttributePos,
|
||||||
|
ConstraintResultPos,
|
||||||
ResultPos,
|
ResultPos,
|
||||||
ResultGroupPos,
|
ResultGroupPos,
|
||||||
TypePos,
|
TypePos,
|
||||||
@ -279,6 +280,28 @@ struct OperationPosition : public PredicateBase<OperationPosition, Position,
|
|||||||
bool isOperandDefiningOp() const;
|
bool isOperandDefiningOp() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// ConstraintPosition
|
||||||
|
|
||||||
|
struct ConstraintQuestion;
|
||||||
|
|
||||||
|
/// A position describing the result of a native constraint. It saves the
|
||||||
|
/// corresponding ConstraintQuestion and result index to enable referring
|
||||||
|
/// back to them
|
||||||
|
struct ConstraintPosition
|
||||||
|
: public PredicateBase<ConstraintPosition, Position,
|
||||||
|
std::pair<ConstraintQuestion *, unsigned>,
|
||||||
|
Predicates::ConstraintResultPos> {
|
||||||
|
using PredicateBase::PredicateBase;
|
||||||
|
|
||||||
|
/// Returns the ConstraintQuestion to enable keeping track of the native
|
||||||
|
/// constraint this position stems from.
|
||||||
|
ConstraintQuestion *getQuestion() const { return key.first; }
|
||||||
|
|
||||||
|
// Returns the result index of this position
|
||||||
|
unsigned getIndex() const { return key.second; }
|
||||||
|
};
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ResultPosition
|
// ResultPosition
|
||||||
|
|
||||||
@ -447,10 +470,12 @@ struct AttributeQuestion
|
|||||||
: public PredicateBase<AttributeQuestion, Qualifier, void,
|
: public PredicateBase<AttributeQuestion, Qualifier, void,
|
||||||
Predicates::AttributeQuestion> {};
|
Predicates::AttributeQuestion> {};
|
||||||
|
|
||||||
/// Apply a parameterized constraint to multiple position values.
|
/// Apply a parameterized constraint to multiple position values and possibly
|
||||||
|
/// produce results.
|
||||||
struct ConstraintQuestion
|
struct ConstraintQuestion
|
||||||
: public PredicateBase<ConstraintQuestion, Qualifier,
|
: public PredicateBase<
|
||||||
std::tuple<StringRef, ArrayRef<Position *>, bool>,
|
ConstraintQuestion, Qualifier,
|
||||||
|
std::tuple<StringRef, ArrayRef<Position *>, ArrayRef<Type>, bool>,
|
||||||
Predicates::ConstraintQuestion> {
|
Predicates::ConstraintQuestion> {
|
||||||
using Base::Base;
|
using Base::Base;
|
||||||
|
|
||||||
@ -460,15 +485,19 @@ struct ConstraintQuestion
|
|||||||
/// Return the arguments of the constraint.
|
/// Return the arguments of the constraint.
|
||||||
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
|
ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
|
||||||
|
|
||||||
|
/// Return the result types of the constraint.
|
||||||
|
ArrayRef<Type> getResultTypes() const { return std::get<2>(key); }
|
||||||
|
|
||||||
/// Return the negation status of the constraint.
|
/// Return the negation status of the constraint.
|
||||||
bool getIsNegated() const { return std::get<2>(key); }
|
bool getIsNegated() const { return std::get<3>(key); }
|
||||||
|
|
||||||
/// Construct an instance with the given storage allocator.
|
/// Construct an instance with the given storage allocator.
|
||||||
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
|
static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
|
||||||
KeyTy key) {
|
KeyTy key) {
|
||||||
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
|
return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
|
||||||
alloc.copyInto(std::get<1>(key)),
|
alloc.copyInto(std::get<1>(key)),
|
||||||
std::get<2>(key)});
|
alloc.copyInto(std::get<2>(key)),
|
||||||
|
std::get<3>(key)});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a hash suitable for the given keytype.
|
/// Returns a hash suitable for the given keytype.
|
||||||
@ -526,6 +555,7 @@ public:
|
|||||||
// Register the types of Positions with the uniquer.
|
// Register the types of Positions with the uniquer.
|
||||||
registerParametricStorageType<AttributePosition>();
|
registerParametricStorageType<AttributePosition>();
|
||||||
registerParametricStorageType<AttributeLiteralPosition>();
|
registerParametricStorageType<AttributeLiteralPosition>();
|
||||||
|
registerParametricStorageType<ConstraintPosition>();
|
||||||
registerParametricStorageType<ForEachPosition>();
|
registerParametricStorageType<ForEachPosition>();
|
||||||
registerParametricStorageType<OperandPosition>();
|
registerParametricStorageType<OperandPosition>();
|
||||||
registerParametricStorageType<OperandGroupPosition>();
|
registerParametricStorageType<OperandGroupPosition>();
|
||||||
@ -588,6 +618,12 @@ public:
|
|||||||
return OperationPosition::get(uniquer, p);
|
return OperationPosition::get(uniquer, p);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a position for a new value created by a constraint.
|
||||||
|
ConstraintPosition *getConstraintPosition(ConstraintQuestion *q,
|
||||||
|
unsigned index) {
|
||||||
|
return ConstraintPosition::get(uniquer, std::make_pair(q, index));
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns an attribute position for an attribute of the given operation.
|
/// Returns an attribute position for an attribute of the given operation.
|
||||||
Position *getAttribute(OperationPosition *p, StringRef name) {
|
Position *getAttribute(OperationPosition *p, StringRef name) {
|
||||||
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
|
return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
|
||||||
@ -673,10 +709,10 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Create a predicate that applies a generic constraint.
|
/// Create a predicate that applies a generic constraint.
|
||||||
Predicate getConstraint(StringRef name, ArrayRef<Position *> pos,
|
Predicate getConstraint(StringRef name, ArrayRef<Position *> args,
|
||||||
bool isNegated) {
|
ArrayRef<Type> resultTypes, bool isNegated) {
|
||||||
return {
|
return {ConstraintQuestion::get(
|
||||||
ConstraintQuestion::get(uniquer, std::make_tuple(name, pos, isNegated)),
|
uniquer, std::make_tuple(name, args, resultTypes, isNegated)),
|
||||||
TrueAnswer::get(uniquer)};
|
TrueAnswer::get(uniquer)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
#include "mlir/IR/BuiltinOps.h"
|
#include "mlir/IR/BuiltinOps.h"
|
||||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||||
#include "llvm/ADT/MapVector.h"
|
#include "llvm/ADT/MapVector.h"
|
||||||
|
#include "llvm/ADT/SmallPtrSet.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
#include "llvm/Support/Debug.h"
|
#include "llvm/Support/Debug.h"
|
||||||
#include <queue>
|
#include <queue>
|
||||||
@ -49,15 +50,16 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList,
|
|||||||
DenseMap<Value, Position *> &inputs,
|
DenseMap<Value, Position *> &inputs,
|
||||||
AttributePosition *pos) {
|
AttributePosition *pos) {
|
||||||
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
|
assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type");
|
||||||
pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
|
|
||||||
predList.emplace_back(pos, builder.getIsNotNull());
|
predList.emplace_back(pos, builder.getIsNotNull());
|
||||||
|
|
||||||
|
if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) {
|
||||||
// If the attribute has a type or value, add a constraint.
|
// If the attribute has a type or value, add a constraint.
|
||||||
if (Value type = attr.getValueType())
|
if (Value type = attr.getValueType())
|
||||||
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
|
getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
|
||||||
else if (Attribute value = attr.getValueAttr())
|
else if (Attribute value = attr.getValueAttr())
|
||||||
predList.emplace_back(pos, builder.getAttributeConstraint(value));
|
predList.emplace_back(pos, builder.getAttributeConstraint(value));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Collect all of the predicates for the given operand position.
|
/// Collect all of the predicates for the given operand position.
|
||||||
static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
|
static void getOperandTreePredicates(std::vector<PositionalPredicate> &predList,
|
||||||
@ -272,8 +274,27 @@ static void getConstraintPredicates(pdl::ApplyNativeConstraintOp op,
|
|||||||
// Push the constraint to the furthest position.
|
// Push the constraint to the furthest position.
|
||||||
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
|
Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
|
||||||
comparePosDepth);
|
comparePosDepth);
|
||||||
PredicateBuilder::Predicate pred =
|
ResultRange results = op.getResults();
|
||||||
builder.getConstraint(op.getName(), allPositions, op.getIsNegated());
|
PredicateBuilder::Predicate pred = builder.getConstraint(
|
||||||
|
op.getName(), allPositions, SmallVector<Type>(results.getTypes()),
|
||||||
|
op.getIsNegated());
|
||||||
|
|
||||||
|
// For each result register a position so it can be used later
|
||||||
|
for (auto [i, result] : llvm::enumerate(results)) {
|
||||||
|
ConstraintQuestion *q = cast<ConstraintQuestion>(pred.first);
|
||||||
|
ConstraintPosition *pos = builder.getConstraintPosition(q, i);
|
||||||
|
auto [it, inserted] = inputs.try_emplace(result, pos);
|
||||||
|
// If this is an input value that has been visited in the tree, add a
|
||||||
|
// constraint to ensure that both instances refer to the same value.
|
||||||
|
if (!inserted) {
|
||||||
|
Position *first = pos;
|
||||||
|
Position *second = it->second;
|
||||||
|
if (comparePosDepth(second, first))
|
||||||
|
std::tie(second, first) = std::make_pair(first, second);
|
||||||
|
|
||||||
|
predList.emplace_back(second, builder.getEqualTo(first));
|
||||||
|
}
|
||||||
|
}
|
||||||
predList.emplace_back(pos, pred);
|
predList.emplace_back(pos, pred);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -875,6 +896,49 @@ static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
|
|||||||
*root = std::make_unique<ExitNode>();
|
*root = std::make_unique<ExitNode>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sorts the range begin/end with the partial order given by cmp.
|
||||||
|
template <typename Iterator, typename Compare>
|
||||||
|
static void stableTopologicalSort(Iterator begin, Iterator end, Compare cmp) {
|
||||||
|
while (begin != end) {
|
||||||
|
// Cannot compute sortBeforeOthers in the predicate of stable_partition
|
||||||
|
// because stable_partition will not keep the [begin, end) range intact
|
||||||
|
// while it runs.
|
||||||
|
llvm::SmallPtrSet<typename Iterator::value_type, 16> sortBeforeOthers;
|
||||||
|
for (auto i = begin; i != end; ++i) {
|
||||||
|
if (std::none_of(begin, end, [&](auto const &b) { return cmp(b, *i); }))
|
||||||
|
sortBeforeOthers.insert(*i);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const next = std::stable_partition(begin, end, [&](auto const &a) {
|
||||||
|
return sortBeforeOthers.contains(a);
|
||||||
|
});
|
||||||
|
assert(next != begin && "not a partial ordering");
|
||||||
|
begin = next;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if 'b' depends on a result of 'a'.
|
||||||
|
static bool dependsOn(OrderedPredicate *a, OrderedPredicate *b) {
|
||||||
|
auto *cqa = dyn_cast<ConstraintQuestion>(a->question);
|
||||||
|
if (!cqa)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
auto positionDependsOnA = [&](Position *p) {
|
||||||
|
auto *cp = dyn_cast<ConstraintPosition>(p);
|
||||||
|
return cp && cp->getQuestion() == cqa;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (auto *cqb = dyn_cast<ConstraintQuestion>(b->question)) {
|
||||||
|
// Does any argument of b use a?
|
||||||
|
return llvm::any_of(cqb->getArgs(), positionDependsOnA);
|
||||||
|
}
|
||||||
|
if (auto *equalTo = dyn_cast<EqualToQuestion>(b->question)) {
|
||||||
|
return positionDependsOnA(b->position) ||
|
||||||
|
positionDependsOnA(equalTo->getValue());
|
||||||
|
}
|
||||||
|
return positionDependsOnA(b->position);
|
||||||
|
}
|
||||||
|
|
||||||
/// Given a module containing PDL pattern operations, generate a matcher tree
|
/// Given a module containing PDL pattern operations, generate a matcher tree
|
||||||
/// using the patterns within the given module and return the root matcher node.
|
/// using the patterns within the given module and return the root matcher node.
|
||||||
std::unique_ptr<MatcherNode>
|
std::unique_ptr<MatcherNode>
|
||||||
@ -955,6 +1019,10 @@ MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
|
|||||||
return *lhs < *rhs;
|
return *lhs < *rhs;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Mostly keep the now established order, but also ensure that
|
||||||
|
// ConstraintQuestions come after the results they use.
|
||||||
|
stableTopologicalSort(ordered.begin(), ordered.end(), dependsOn);
|
||||||
|
|
||||||
// Build the matchers for each of the pattern predicate lists.
|
// Build the matchers for each of the pattern predicate lists.
|
||||||
std::unique_ptr<MatcherNode> root;
|
std::unique_ptr<MatcherNode> root;
|
||||||
for (OrderedPredicateList &list : lists)
|
for (OrderedPredicateList &list : lists)
|
||||||
|
@ -94,6 +94,12 @@ static void visit(Operation *op, DenseSet<Operation *> &visited) {
|
|||||||
LogicalResult ApplyNativeConstraintOp::verify() {
|
LogicalResult ApplyNativeConstraintOp::verify() {
|
||||||
if (getNumOperands() == 0)
|
if (getNumOperands() == 0)
|
||||||
return emitOpError("expected at least one argument");
|
return emitOpError("expected at least one argument");
|
||||||
|
if (llvm::any_of(getResults(), [](OpResult result) {
|
||||||
|
return isa<OperationType>(result.getType());
|
||||||
|
})) {
|
||||||
|
return emitOpError(
|
||||||
|
"returning an operation from a constraint is not supported");
|
||||||
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -769,11 +769,25 @@ void Generator::generate(Operation *op, ByteCodeWriter &writer) {
|
|||||||
|
|
||||||
void Generator::generate(pdl_interp::ApplyConstraintOp op,
|
void Generator::generate(pdl_interp::ApplyConstraintOp op,
|
||||||
ByteCodeWriter &writer) {
|
ByteCodeWriter &writer) {
|
||||||
assert(constraintToMemIndex.count(op.getName()) &&
|
// Constraints that should return a value have to be registered as rewrites.
|
||||||
"expected index for constraint function");
|
// If a constraint and a rewrite of similar name are registered the
|
||||||
|
// constraint takes precedence
|
||||||
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
|
writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
|
||||||
writer.appendPDLValueList(op.getArgs());
|
writer.appendPDLValueList(op.getArgs());
|
||||||
writer.append(ByteCodeField(op.getIsNegated()));
|
writer.append(ByteCodeField(op.getIsNegated()));
|
||||||
|
ResultRange results = op.getResults();
|
||||||
|
writer.append(ByteCodeField(results.size()));
|
||||||
|
for (Value result : results) {
|
||||||
|
// We record the expected kind of the result, so that we can provide extra
|
||||||
|
// verification of the native rewrite function and handle the failure case
|
||||||
|
// of constraints accordingly.
|
||||||
|
writer.appendPDLValueKind(result);
|
||||||
|
|
||||||
|
// Range results also need to append the range storage index.
|
||||||
|
if (isa<pdl::RangeType>(result.getType()))
|
||||||
|
writer.append(getRangeStorageIndex(result));
|
||||||
|
writer.append(result);
|
||||||
|
}
|
||||||
writer.append(op.getSuccessors());
|
writer.append(op.getSuccessors());
|
||||||
}
|
}
|
||||||
void Generator::generate(pdl_interp::ApplyRewriteOp op,
|
void Generator::generate(pdl_interp::ApplyRewriteOp op,
|
||||||
@ -786,11 +800,9 @@ void Generator::generate(pdl_interp::ApplyRewriteOp op,
|
|||||||
ResultRange results = op.getResults();
|
ResultRange results = op.getResults();
|
||||||
writer.append(ByteCodeField(results.size()));
|
writer.append(ByteCodeField(results.size()));
|
||||||
for (Value result : results) {
|
for (Value result : results) {
|
||||||
// In debug mode we also record the expected kind of the result, so that we
|
// We record the expected kind of the result, so that we
|
||||||
// can provide extra verification of the native rewrite function.
|
// can provide extra verification of the native rewrite function.
|
||||||
#ifndef NDEBUG
|
|
||||||
writer.appendPDLValueKind(result);
|
writer.appendPDLValueKind(result);
|
||||||
#endif
|
|
||||||
|
|
||||||
// Range results also need to append the range storage index.
|
// Range results also need to append the range storage index.
|
||||||
if (isa<pdl::RangeType>(result.getType()))
|
if (isa<pdl::RangeType>(result.getType()))
|
||||||
@ -1076,6 +1088,28 @@ void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
|
|||||||
// ByteCode Execution
|
// ByteCode Execution
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
/// This class is an instantiation of the PDLResultList that provides access to
|
||||||
|
/// the returned results. This API is not on `PDLResultList` to avoid
|
||||||
|
/// overexposing access to information specific solely to the ByteCode.
|
||||||
|
class ByteCodeRewriteResultList : public PDLResultList {
|
||||||
|
public:
|
||||||
|
ByteCodeRewriteResultList(unsigned maxNumResults)
|
||||||
|
: PDLResultList(maxNumResults) {}
|
||||||
|
|
||||||
|
/// Return the list of PDL results.
|
||||||
|
MutableArrayRef<PDLValue> getResults() { return results; }
|
||||||
|
|
||||||
|
/// Return the type ranges allocated by this list.
|
||||||
|
MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
|
||||||
|
return allocatedTypeRanges;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the value ranges allocated by this list.
|
||||||
|
MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
|
||||||
|
return allocatedValueRanges;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// This class provides support for executing a bytecode stream.
|
/// This class provides support for executing a bytecode stream.
|
||||||
class ByteCodeExecutor {
|
class ByteCodeExecutor {
|
||||||
public:
|
public:
|
||||||
@ -1152,6 +1186,9 @@ private:
|
|||||||
void executeSwitchResultCount();
|
void executeSwitchResultCount();
|
||||||
void executeSwitchType();
|
void executeSwitchType();
|
||||||
void executeSwitchTypes();
|
void executeSwitchTypes();
|
||||||
|
void processNativeFunResults(ByteCodeRewriteResultList &results,
|
||||||
|
unsigned numResults,
|
||||||
|
LogicalResult &rewriteResult);
|
||||||
|
|
||||||
/// Pushes a code iterator to the stack.
|
/// Pushes a code iterator to the stack.
|
||||||
void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
|
void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
|
||||||
@ -1225,6 +1262,8 @@ private:
|
|||||||
return T::getFromOpaquePointer(pointer);
|
return T::getFromOpaquePointer(pointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void skip(size_t skipN) { curCodeIt += skipN; }
|
||||||
|
|
||||||
/// Jump to a specific successor based on a predicate value.
|
/// Jump to a specific successor based on a predicate value.
|
||||||
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
|
void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
|
||||||
/// Jump to a specific successor based on a destination index.
|
/// Jump to a specific successor based on a destination index.
|
||||||
@ -1381,33 +1420,11 @@ private:
|
|||||||
ArrayRef<PDLConstraintFunction> constraintFunctions;
|
ArrayRef<PDLConstraintFunction> constraintFunctions;
|
||||||
ArrayRef<PDLRewriteFunction> rewriteFunctions;
|
ArrayRef<PDLRewriteFunction> rewriteFunctions;
|
||||||
};
|
};
|
||||||
|
|
||||||
/// This class is an instantiation of the PDLResultList that provides access to
|
|
||||||
/// the returned results. This API is not on `PDLResultList` to avoid
|
|
||||||
/// overexposing access to information specific solely to the ByteCode.
|
|
||||||
class ByteCodeRewriteResultList : public PDLResultList {
|
|
||||||
public:
|
|
||||||
ByteCodeRewriteResultList(unsigned maxNumResults)
|
|
||||||
: PDLResultList(maxNumResults) {}
|
|
||||||
|
|
||||||
/// Return the list of PDL results.
|
|
||||||
MutableArrayRef<PDLValue> getResults() { return results; }
|
|
||||||
|
|
||||||
/// Return the type ranges allocated by this list.
|
|
||||||
MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
|
|
||||||
return allocatedTypeRanges;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return the value ranges allocated by this list.
|
|
||||||
MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
|
|
||||||
return allocatedValueRanges;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
|
void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
|
||||||
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
|
LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
|
||||||
const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
|
ByteCodeField fun_idx = read();
|
||||||
SmallVector<PDLValue, 16> args;
|
SmallVector<PDLValue, 16> args;
|
||||||
readList<PDLValue>(args);
|
readList<PDLValue>(args);
|
||||||
|
|
||||||
@ -1422,8 +1439,29 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
|
|||||||
llvm::dbgs() << " * isNegated: " << isNegated << "\n";
|
llvm::dbgs() << " * isNegated: " << isNegated << "\n";
|
||||||
llvm::interleaveComma(args, llvm::dbgs());
|
llvm::interleaveComma(args, llvm::dbgs());
|
||||||
});
|
});
|
||||||
// Invoke the constraint and jump to the proper destination.
|
|
||||||
selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
|
ByteCodeField numResults = read();
|
||||||
|
const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
|
||||||
|
ByteCodeRewriteResultList results(numResults);
|
||||||
|
LogicalResult rewriteResult = constraintFn(rewriter, results, args);
|
||||||
|
ArrayRef<PDLValue> constraintResults = results.getResults();
|
||||||
|
LLVM_DEBUG({
|
||||||
|
if (succeeded(rewriteResult)) {
|
||||||
|
llvm::dbgs() << " * Constraint succeeded\n";
|
||||||
|
llvm::dbgs() << " * Results: ";
|
||||||
|
llvm::interleaveComma(constraintResults, llvm::dbgs());
|
||||||
|
llvm::dbgs() << "\n";
|
||||||
|
} else {
|
||||||
|
llvm::dbgs() << " * Constraint failed\n";
|
||||||
|
}
|
||||||
|
});
|
||||||
|
assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
|
||||||
|
"native PDL rewrite function succeeded but returned "
|
||||||
|
"unexpected number of results");
|
||||||
|
processNativeFunResults(results, numResults, rewriteResult);
|
||||||
|
|
||||||
|
// Depending on the constraint jump to the proper destination.
|
||||||
|
selectJump(isNegated != succeeded(rewriteResult));
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
||||||
@ -1445,16 +1483,39 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
|||||||
assert(results.getResults().size() == numResults &&
|
assert(results.getResults().size() == numResults &&
|
||||||
"native PDL rewrite function returned unexpected number of results");
|
"native PDL rewrite function returned unexpected number of results");
|
||||||
|
|
||||||
// Store the results in the bytecode memory.
|
processNativeFunResults(results, numResults, rewriteResult);
|
||||||
for (PDLValue &result : results.getResults()) {
|
|
||||||
|
if (failed(rewriteResult)) {
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << " - Failed");
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ByteCodeExecutor::processNativeFunResults(
|
||||||
|
ByteCodeRewriteResultList &results, unsigned numResults,
|
||||||
|
LogicalResult &rewriteResult) {
|
||||||
|
// Store the results in the bytecode memory or handle missing results on
|
||||||
|
// failure.
|
||||||
|
for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
|
||||||
|
PDLValue::Kind resultKind = read<PDLValue::Kind>();
|
||||||
|
|
||||||
|
// Skip the according number of values on the buffer on failure and exit
|
||||||
|
// early as there are no results to process.
|
||||||
|
if (failed(rewriteResult)) {
|
||||||
|
if (resultKind == PDLValue::Kind::TypeRange ||
|
||||||
|
resultKind == PDLValue::Kind::ValueRange) {
|
||||||
|
skip(2);
|
||||||
|
} else {
|
||||||
|
skip(1);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PDLValue result = results.getResults()[resultIdx];
|
||||||
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
|
LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
|
||||||
|
assert(result.getKind() == resultKind &&
|
||||||
// In debug mode we also verify the expected kind of the result.
|
"native PDL rewrite function returned an unexpected type of "
|
||||||
#ifndef NDEBUG
|
"result");
|
||||||
assert(result.getKind() == read<PDLValue::Kind>() &&
|
|
||||||
"native PDL rewrite function returned an unexpected type of result");
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// If the result is a range, we need to copy it over to the bytecodes
|
// If the result is a range, we need to copy it over to the bytecodes
|
||||||
// range memory.
|
// range memory.
|
||||||
if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
|
if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
|
||||||
@ -1476,13 +1537,6 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
|
|||||||
allocatedTypeRangeMemory.push_back(std::move(it));
|
allocatedTypeRangeMemory.push_back(std::move(it));
|
||||||
for (auto &it : results.getAllocatedValueRanges())
|
for (auto &it : results.getAllocatedValueRanges())
|
||||||
allocatedValueRangeMemory.push_back(std::move(it));
|
allocatedValueRangeMemory.push_back(std::move(it));
|
||||||
|
|
||||||
// Process the result of the rewrite.
|
|
||||||
if (failed(rewriteResult)) {
|
|
||||||
LLVM_DEBUG(llvm::dbgs() << " - Failed");
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ByteCodeExecutor::executeAreEqual() {
|
void ByteCodeExecutor::executeAreEqual() {
|
||||||
|
@ -1362,12 +1362,6 @@ FailureOr<T *> Parser::parseUserNativeConstraintOrRewriteDecl(
|
|||||||
if (failed(parseToken(Token::semicolon,
|
if (failed(parseToken(Token::semicolon,
|
||||||
"expected `;` after native declaration")))
|
"expected `;` after native declaration")))
|
||||||
return failure();
|
return failure();
|
||||||
// TODO: PDL should be able to support constraint results in certain
|
|
||||||
// situations, we should revise this.
|
|
||||||
if (std::is_same<ast::UserConstraintDecl, T>::value && !results.empty()) {
|
|
||||||
return emitError(
|
|
||||||
"native Constraints currently do not support returning results");
|
|
||||||
}
|
|
||||||
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
|
return T::createNative(ctx, name, arguments, results, optCodeStr, resultType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,6 +79,57 @@ module @constraints {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: module @constraint_with_result
|
||||||
|
module @constraint_with_result {
|
||||||
|
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
||||||
|
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
|
||||||
|
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%root = operation
|
||||||
|
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
|
||||||
|
rewrite %root with "rewriter"(%attr : !pdl.attribute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: module @constraint_with_unused_result
|
||||||
|
module @constraint_with_unused_result {
|
||||||
|
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
||||||
|
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
|
||||||
|
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]] : !pdl.operation)
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%root = operation
|
||||||
|
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
|
||||||
|
rewrite %root with "rewriter"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: module @constraint_with_result_multiple
|
||||||
|
module @constraint_with_result_multiple {
|
||||||
|
// check that native constraints work as expected even when multiple identical constraints are fused
|
||||||
|
|
||||||
|
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
||||||
|
// CHECK: %[[ATTR:.*]] = pdl_interp.apply_constraint "check_op_and_get_attr_constr"(%[[ROOT]]
|
||||||
|
// CHECK-NOT: pdl_interp.apply_constraint "check_op_and_get_attr_constr"
|
||||||
|
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter_0(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
|
||||||
|
// CHECK: pdl_interp.record_match @rewriters::@pdl_generated_rewriter(%[[ROOT]], %[[ATTR]] : !pdl.operation, !pdl.attribute)
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%root = operation
|
||||||
|
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
|
||||||
|
rewrite %root with "rewriter"(%attr : !pdl.attribute)
|
||||||
|
}
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%root = operation
|
||||||
|
%attr = pdl.apply_native_constraint "check_op_and_get_attr_constr"(%root : !pdl.operation) : !pdl.attribute
|
||||||
|
rewrite %root with "rewriter"(%attr : !pdl.attribute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: module @negated_constraint
|
// CHECK-LABEL: module @negated_constraint
|
||||||
module @negated_constraint {
|
module @negated_constraint {
|
||||||
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
// CHECK: func @matcher(%[[ROOT:.*]]: !pdl.operation)
|
||||||
|
@ -0,0 +1,77 @@
|
|||||||
|
// RUN: mlir-opt -split-input-file -convert-pdl-to-pdl-interp %s | FileCheck %s
|
||||||
|
|
||||||
|
// Ensuse that the dependency between add & less
|
||||||
|
// causes them to be in the correct order.
|
||||||
|
// CHECK-LABEL: matcher
|
||||||
|
// CHECK: apply_constraint "return_attr_constraint"
|
||||||
|
// CHECK: apply_constraint "use_attr_constraint"
|
||||||
|
|
||||||
|
module {
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%0 = attribute
|
||||||
|
%1 = types
|
||||||
|
%2 = operation "tosa.mul" {"shift" = %0} -> (%1 : !pdl.range<type>)
|
||||||
|
%3 = attribute = 0 : i32
|
||||||
|
%4 = attribute = 1 : i32
|
||||||
|
%5 = apply_native_constraint "return_attr_constraint"(%3, %4 : !pdl.attribute, !pdl.attribute) : !pdl.attribute
|
||||||
|
apply_native_constraint "use_attr_constraint"(%0, %5 : !pdl.attribute, !pdl.attribute)
|
||||||
|
rewrite %2 with "rewriter"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: matcher
|
||||||
|
// CHECK: %[[ATTR:.*]] = pdl_interp.get_attribute "attr" of
|
||||||
|
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_attr_constraint"
|
||||||
|
// CHECK: pdl_interp.are_equal %[[ATTR:.*]], %[[CONSTRAINT:.*]]
|
||||||
|
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%inputOp = operation
|
||||||
|
%result = result 0 of %inputOp
|
||||||
|
%attr = pdl.apply_native_constraint "return_attr_constraint"(%inputOp : !pdl.operation) : !pdl.attribute
|
||||||
|
%root = operation(%result : !pdl.value) {"attr" = %attr}
|
||||||
|
rewrite %root with "rewriter"(%attr : !pdl.attribute)
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: matcher
|
||||||
|
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_value_constr"
|
||||||
|
// CHECK: %[[VALUE:.*]] = pdl_interp.get_operand 0
|
||||||
|
// CHECK: pdl_interp.are_equal %[[VALUE:.*]], %[[CONSTRAINT:.*]]
|
||||||
|
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%attr = attribute = 10
|
||||||
|
%value = pdl.apply_native_constraint "return_value_constr"(%attr: !pdl.attribute) : !pdl.value
|
||||||
|
%root = operation(%value : !pdl.value)
|
||||||
|
rewrite %root with "rewriter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: matcher
|
||||||
|
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_constr"
|
||||||
|
// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
|
||||||
|
// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
|
||||||
|
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%attr = attribute = 10
|
||||||
|
%type = pdl.apply_native_constraint "return_type_constr"(%attr: !pdl.attribute) : !pdl.type
|
||||||
|
%root = operation -> (%type : !pdl.type)
|
||||||
|
rewrite %root with "rewriter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: matcher
|
||||||
|
// CHECK: %[[CONSTRAINT:.*]] = pdl_interp.apply_constraint "return_type_range_constr"
|
||||||
|
// CHECK: %[[TYPE:.*]] = pdl_interp.get_value_type of
|
||||||
|
// CHECK: pdl_interp.are_equal %[[TYPE:.*]], %[[CONSTRAINT:.*]]
|
||||||
|
|
||||||
|
pdl.pattern : benefit(1) {
|
||||||
|
%attr = attribute = 10
|
||||||
|
%types = pdl.apply_native_constraint "return_type_range_constr"(%attr: !pdl.attribute) : !pdl.range<type>
|
||||||
|
%root = operation -> (%types : !pdl.range<type>)
|
||||||
|
rewrite %root with "rewriter"
|
||||||
|
}
|
@ -134,6 +134,24 @@ pdl.pattern @apply_rewrite_with_no_results : benefit(1) {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
pdl.pattern @apply_constraint_with_no_results : benefit(1) {
|
||||||
|
%root = operation
|
||||||
|
apply_native_constraint "NativeConstraint"(%root : !pdl.operation)
|
||||||
|
rewrite %root with "rewriter"
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
pdl.pattern @apply_constraint_with_results : benefit(1) {
|
||||||
|
%root = operation
|
||||||
|
%attr = apply_native_constraint "NativeConstraint"(%root : !pdl.operation) : !pdl.attribute
|
||||||
|
rewrite %root {
|
||||||
|
apply_native_rewrite "NativeRewrite"(%attr : !pdl.attribute)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
pdl.pattern @attribute_with_dict : benefit(1) {
|
pdl.pattern @attribute_with_dict : benefit(1) {
|
||||||
%root = operation
|
%root = operation
|
||||||
rewrite %root {
|
rewrite %root {
|
||||||
|
@ -109,6 +109,74 @@ module @ir attributes { test.apply_constraint_3 } {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Test returning a type from a native constraint.
|
||||||
|
module @patterns {
|
||||||
|
pdl_interp.func @matcher(%root : !pdl.operation) {
|
||||||
|
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
|
||||||
|
|
||||||
|
^pat:
|
||||||
|
%new_type = pdl_interp.apply_constraint "op_constr_return_type"(%root : !pdl.operation) : !pdl.type -> ^pat2, ^end
|
||||||
|
|
||||||
|
^pat2:
|
||||||
|
pdl_interp.record_match @rewriters::@success(%root, %new_type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^end
|
||||||
|
|
||||||
|
^end:
|
||||||
|
pdl_interp.finalize
|
||||||
|
}
|
||||||
|
|
||||||
|
module @rewriters {
|
||||||
|
pdl_interp.func @success(%root : !pdl.operation, %new_type : !pdl.type) {
|
||||||
|
%op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%new_type : !pdl.type)
|
||||||
|
pdl_interp.erase %root
|
||||||
|
pdl_interp.finalize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test.apply_constraint_4
|
||||||
|
// CHECK-NOT: "test.replaced_by_pattern"
|
||||||
|
// CHECK: "test.replaced_by_pattern"() : () -> f32
|
||||||
|
module @ir attributes { test.apply_constraint_4 } {
|
||||||
|
"test.failure_op"() : () -> ()
|
||||||
|
"test.success_op"() : () -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Test success and failure cases of native constraints with pdl.range results.
|
||||||
|
module @patterns {
|
||||||
|
pdl_interp.func @matcher(%root : !pdl.operation) {
|
||||||
|
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
|
||||||
|
|
||||||
|
^pat:
|
||||||
|
%num_results = pdl_interp.create_attribute 2 : i32
|
||||||
|
%types = pdl_interp.apply_constraint "op_constr_return_type_range"(%root, %num_results : !pdl.operation, !pdl.attribute) : !pdl.range<type> -> ^pat1, ^end
|
||||||
|
|
||||||
|
^pat1:
|
||||||
|
pdl_interp.record_match @rewriters::@success(%root, %types : !pdl.operation, !pdl.range<type>) : benefit(1), loc([%root]) -> ^end
|
||||||
|
|
||||||
|
^end:
|
||||||
|
pdl_interp.finalize
|
||||||
|
}
|
||||||
|
|
||||||
|
module @rewriters {
|
||||||
|
pdl_interp.func @success(%root : !pdl.operation, %types : !pdl.range<type>) {
|
||||||
|
%op = pdl_interp.create_operation "test.replaced_by_pattern" -> (%types : !pdl.range<type>)
|
||||||
|
pdl_interp.erase %root
|
||||||
|
pdl_interp.finalize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: test.apply_constraint_5
|
||||||
|
// CHECK-NOT: "test.replaced_by_pattern"
|
||||||
|
// CHECK: "test.replaced_by_pattern"() : () -> (f32, f32)
|
||||||
|
module @ir attributes { test.apply_constraint_5 } {
|
||||||
|
"test.failure_op"() : () -> ()
|
||||||
|
"test.success_op"() : () -> ()
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// pdl_interp::ApplyRewriteOp
|
// pdl_interp::ApplyRewriteOp
|
||||||
|
@ -887,7 +887,7 @@ public:
|
|||||||
#include "TestTransformDialectExtensionTypes.cpp.inc"
|
#include "TestTransformDialectExtensionTypes.cpp.inc"
|
||||||
>();
|
>();
|
||||||
|
|
||||||
auto verboseConstraint = [](PatternRewriter &rewriter,
|
auto verboseConstraint = [](PatternRewriter &rewriter, PDLResultList &,
|
||||||
ArrayRef<PDLValue> pdlValues) {
|
ArrayRef<PDLValue> pdlValues) {
|
||||||
for (const PDLValue &pdlValue : pdlValues) {
|
for (const PDLValue &pdlValue : pdlValues) {
|
||||||
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
|
if (Operation *op = pdlValue.dyn_cast<Operation *>()) {
|
||||||
|
@ -30,6 +30,50 @@ static LogicalResult customMultiEntityVariadicConstraint(
|
|||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Custom constraint that returns a value if the op is named test.success_op
|
||||||
|
static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
|
||||||
|
PDLResultList &results,
|
||||||
|
ArrayRef<PDLValue> args) {
|
||||||
|
auto *op = args[0].cast<Operation *>();
|
||||||
|
if (op->getName().getStringRef() == "test.success_op") {
|
||||||
|
StringAttr customAttr = rewriter.getStringAttr("test.success");
|
||||||
|
results.push_back(customAttr);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom constraint that returns a type if the op is named test.success_op
|
||||||
|
static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
|
||||||
|
PDLResultList &results,
|
||||||
|
ArrayRef<PDLValue> args) {
|
||||||
|
auto *op = args[0].cast<Operation *>();
|
||||||
|
if (op->getName().getStringRef() == "test.success_op") {
|
||||||
|
results.push_back(rewriter.getF32Type());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Custom constraint that returns a type range of variable length if the op is
|
||||||
|
// named test.success_op
|
||||||
|
static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
|
||||||
|
PDLResultList &results,
|
||||||
|
ArrayRef<PDLValue> args) {
|
||||||
|
auto *op = args[0].cast<Operation *>();
|
||||||
|
int numTypes = args[1].cast<Attribute>().cast<IntegerAttr>().getInt();
|
||||||
|
|
||||||
|
if (op->getName().getStringRef() == "test.success_op") {
|
||||||
|
SmallVector<Type> types;
|
||||||
|
for (int i = 0; i < numTypes; i++) {
|
||||||
|
types.push_back(rewriter.getF32Type());
|
||||||
|
}
|
||||||
|
results.push_back(TypeRange(types));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
// Custom creator invoked from PDL.
|
// Custom creator invoked from PDL.
|
||||||
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
|
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
|
||||||
return rewriter.create(OperationState(op->getLoc(), "test.success"));
|
return rewriter.create(OperationState(op->getLoc(), "test.success"));
|
||||||
@ -102,6 +146,12 @@ struct TestPDLByteCodePass
|
|||||||
customMultiEntityConstraint);
|
customMultiEntityConstraint);
|
||||||
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
|
pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
|
||||||
customMultiEntityVariadicConstraint);
|
customMultiEntityVariadicConstraint);
|
||||||
|
pdlPattern.registerConstraintFunction("op_constr_return_attr",
|
||||||
|
customValueResultConstraint);
|
||||||
|
pdlPattern.registerConstraintFunction("op_constr_return_type",
|
||||||
|
customTypeResultConstraint);
|
||||||
|
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
|
||||||
|
customTypeRangeResultConstraint);
|
||||||
pdlPattern.registerRewriteFunction("creator", customCreate);
|
pdlPattern.registerRewriteFunction("creator", customCreate);
|
||||||
pdlPattern.registerRewriteFunction("var_creator",
|
pdlPattern.registerRewriteFunction("var_creator",
|
||||||
customVariadicResultCreate);
|
customVariadicResultCreate);
|
||||||
|
@ -158,8 +158,3 @@ Pattern {
|
|||||||
|
|
||||||
// CHECK: expected `;` after native declaration
|
// CHECK: expected `;` after native declaration
|
||||||
Constraint Foo() [{}]
|
Constraint Foo() [{}]
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
// CHECK: native Constraints currently do not support returning results
|
|
||||||
Constraint Foo() -> Op;
|
|
||||||
|
@ -12,6 +12,14 @@ Constraint Foo() [{ /* Native Code */ }];
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Test that native constraints support returning results.
|
||||||
|
|
||||||
|
// CHECK: Module
|
||||||
|
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Attr>
|
||||||
|
Constraint Foo() -> Attr;
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK: Module
|
// CHECK: Module
|
||||||
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
|
// CHECK: `-UserConstraintDecl {{.*}} Name<Foo> ResultType<Value>
|
||||||
// CHECK: `Inputs`
|
// CHECK: `Inputs`
|
||||||
|
@ -298,6 +298,6 @@ def test_apply_native_constraint():
|
|||||||
pattern = PatternOp(1)
|
pattern = PatternOp(1)
|
||||||
with InsertionPoint(pattern.body):
|
with InsertionPoint(pattern.body):
|
||||||
resultType = TypeOp()
|
resultType = TypeOp()
|
||||||
ApplyNativeConstraintOp("typeConstraint", args=[resultType])
|
ApplyNativeConstraintOp([], "typeConstraint", args=[resultType])
|
||||||
root = OperationOp(types=[resultType])
|
root = OperationOp(types=[resultType])
|
||||||
RewriteOp(root, name="rewrite")
|
RewriteOp(root, name="rewrite")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user