[mlir][PDL] Support running pdl_interp.foreach on ranges of values and types (#173161)
The foreach execution only works for operation ranges, typically stemming from pdl_interp.get_users. Custom rewrites/constraints can return ranges of types and values as well, however. This pr adds support for executing `pdl_interp.foreach` in those cases.
This commit is contained in:
parent
2d1110a70f
commit
f3df4b9292
@ -1741,6 +1741,36 @@ void ByteCodeExecutor::executeForEach() {
|
||||
selectJump(size_t(0));
|
||||
return;
|
||||
}
|
||||
case PDLValue::Kind::Value: {
|
||||
unsigned &index = loopIndex[read()];
|
||||
ValueRange range = valueRangeMemory[rangeIndex];
|
||||
assert(index <= range.size() && "iterated past the end");
|
||||
if (index < range.size()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
|
||||
value = range[index].getAsOpaquePointer();
|
||||
break;
|
||||
}
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << " * Done\n");
|
||||
index = 0;
|
||||
selectJump(size_t(0));
|
||||
return;
|
||||
}
|
||||
case PDLValue::Kind::Type: {
|
||||
unsigned &index = loopIndex[read()];
|
||||
TypeRange range = typeRangeMemory[rangeIndex];
|
||||
assert(index <= range.size() && "iterated past the end");
|
||||
if (index < range.size()) {
|
||||
LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
|
||||
value = range[index].getAsOpaquePointer();
|
||||
break;
|
||||
}
|
||||
|
||||
LLVM_DEBUG(llvm::dbgs() << " * Done\n");
|
||||
index = 0;
|
||||
selectJump(size_t(0));
|
||||
return;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("unexpected `ForEach` value kind");
|
||||
}
|
||||
|
||||
@ -956,6 +956,92 @@ module @ir attributes { test.foreach } {
|
||||
|
||||
// -----
|
||||
|
||||
// Test pdl_interp.foreach over a range of types.
|
||||
module @patterns {
|
||||
pdl_interp.func @matcher(%root : !pdl.operation) {
|
||||
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
|
||||
|
||||
^pat:
|
||||
%results = pdl_interp.get_results of %root : !pdl.range<value>
|
||||
%types = pdl_interp.get_value_type of %results : !pdl.range<type>
|
||||
// Iterate over the types of the results of the root op
|
||||
pdl_interp.foreach %type : !pdl.type in %types {
|
||||
// Only match if the type is i64, verifying we introspect all types
|
||||
// but only trigger one rewrite
|
||||
pdl_interp.check_type %type is i64 -> ^record, ^cont
|
||||
^record:
|
||||
pdl_interp.record_match @rewriters::@success(%root, %type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^cont
|
||||
^cont:
|
||||
pdl_interp.continue
|
||||
} -> ^end
|
||||
|
||||
^end:
|
||||
pdl_interp.finalize
|
||||
}
|
||||
|
||||
module @rewriters {
|
||||
pdl_interp.func @success(%root : !pdl.operation, %type : !pdl.type) {
|
||||
// Create an op for the matched i64 type
|
||||
pdl_interp.create_operation "test.type_found" -> (%type : !pdl.type)
|
||||
pdl_interp.erase %root
|
||||
pdl_interp.finalize
|
||||
}
|
||||
}
|
||||
}
|
||||
// CHECK-LABEL: test.foreach_type
|
||||
// CHECK: "test.type_found"() : () -> i64
|
||||
// CHECK-NOT: "test.type_found"
|
||||
module @ir attributes { test.foreach_type } {
|
||||
"test.success_op"() : () -> (i32, i64)
|
||||
}
|
||||
// -----
|
||||
|
||||
// Test pdl_interp.foreach over a range of values returned by native constraint.
|
||||
module @patterns {
|
||||
pdl_interp.func @matcher(%root : !pdl.operation) {
|
||||
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
|
||||
|
||||
^pat:
|
||||
%values = pdl_interp.apply_constraint "op_constr_return_value_range"(%root : !pdl.operation) : !pdl.range<value> -> ^loop, ^end
|
||||
|
||||
^loop:
|
||||
pdl_interp.foreach %val : !pdl.value in %values {
|
||||
%type = pdl_interp.get_value_type of %val : !pdl.type
|
||||
// Only match if the type is f16, verifying we introspect all values
|
||||
// but only trigger one rewrite
|
||||
pdl_interp.check_type %type is f16 -> ^record, ^cont
|
||||
^record:
|
||||
pdl_interp.record_match @rewriters::@success(%root, %val : !pdl.operation, !pdl.value) : benefit(1), loc([%root]) -> ^cont
|
||||
^cont:
|
||||
pdl_interp.continue
|
||||
} -> ^end
|
||||
|
||||
^end:
|
||||
pdl_interp.finalize
|
||||
}
|
||||
|
||||
module @rewriters {
|
||||
pdl_interp.func @success(%root : !pdl.operation, %val : !pdl.value) {
|
||||
%type = pdl_interp.get_value_type of %val : !pdl.type
|
||||
pdl_interp.create_operation "test.value_found"(%val : !pdl.value) -> (%type : !pdl.type)
|
||||
pdl_interp.erase %root
|
||||
pdl_interp.finalize
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test.foreach_value
|
||||
// CHECK: %[[VAL1:.*]] = "test.input1"
|
||||
// CHECK: "test.value_found"(%[[VAL1]]) : (f16) -> f16
|
||||
// CHECK-NOT: "test.value_found"
|
||||
module @ir attributes { test.foreach_value } {
|
||||
%0 = "test.input0"() : () -> f32
|
||||
%1 = "test.input1"() : () -> f16
|
||||
"test.success_op"(%0, %1) : (f32, f16) -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// pdl_interp::GetUsersOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -81,6 +81,19 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Custom constraint that returns a value range if the op is named
|
||||
// test.success_op
|
||||
static LogicalResult customValueRangeResultConstraint(PatternRewriter &rewriter,
|
||||
PDLResultList &results,
|
||||
ArrayRef<PDLValue> args) {
|
||||
auto *op = args[0].cast<Operation *>();
|
||||
if (op->getName().getStringRef() == "test.success_op") {
|
||||
results.push_back(op->getOperands()); // Returns ValueRange
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Custom creator invoked from PDL.
|
||||
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
|
||||
return rewriter.create(OperationState(op->getLoc(), "test.success"));
|
||||
@ -161,6 +174,8 @@ struct TestPDLByteCodePass
|
||||
customConstraintFailure);
|
||||
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
|
||||
customTypeRangeResultConstraint);
|
||||
pdlPattern.registerConstraintFunction("op_constr_return_value_range",
|
||||
customValueRangeResultConstraint);
|
||||
pdlPattern.registerRewriteFunction("creator", customCreate);
|
||||
pdlPattern.registerRewriteFunction("var_creator",
|
||||
customVariadicResultCreate);
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user