[mlir][PDL] Support running pdl_interp.foreach on ranges of values and types (#173161)

The foreach execution only works for operation ranges, typically
stemming from pdl_interp.get_users.
Custom rewrites/constraints can return ranges of types and values as
well, however.
This pr adds support for executing `pdl_interp.foreach` in those cases.
This commit is contained in:
jumerckx 2026-02-02 17:58:48 +01:00 committed by GitHub
parent 2d1110a70f
commit f3df4b9292
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 131 additions and 0 deletions

View File

@ -1741,6 +1741,36 @@ void ByteCodeExecutor::executeForEach() {
selectJump(size_t(0));
return;
}
case PDLValue::Kind::Value: {
unsigned &index = loopIndex[read()];
ValueRange range = valueRangeMemory[rangeIndex];
assert(index <= range.size() && "iterated past the end");
if (index < range.size()) {
LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
value = range[index].getAsOpaquePointer();
break;
}
LLVM_DEBUG(llvm::dbgs() << " * Done\n");
index = 0;
selectJump(size_t(0));
return;
}
case PDLValue::Kind::Type: {
unsigned &index = loopIndex[read()];
TypeRange range = typeRangeMemory[rangeIndex];
assert(index <= range.size() && "iterated past the end");
if (index < range.size()) {
LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n");
value = range[index].getAsOpaquePointer();
break;
}
LLVM_DEBUG(llvm::dbgs() << " * Done\n");
index = 0;
selectJump(size_t(0));
return;
}
default:
llvm_unreachable("unexpected `ForEach` value kind");
}

View File

@ -956,6 +956,92 @@ module @ir attributes { test.foreach } {
// -----
// Test pdl_interp.foreach over a range of types.
module @patterns {
pdl_interp.func @matcher(%root : !pdl.operation) {
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
^pat:
%results = pdl_interp.get_results of %root : !pdl.range<value>
%types = pdl_interp.get_value_type of %results : !pdl.range<type>
// Iterate over the types of the results of the root op
pdl_interp.foreach %type : !pdl.type in %types {
// Only match if the type is i64, verifying we introspect all types
// but only trigger one rewrite
pdl_interp.check_type %type is i64 -> ^record, ^cont
^record:
pdl_interp.record_match @rewriters::@success(%root, %type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^cont
^cont:
pdl_interp.continue
} -> ^end
^end:
pdl_interp.finalize
}
module @rewriters {
pdl_interp.func @success(%root : !pdl.operation, %type : !pdl.type) {
// Create an op for the matched i64 type
pdl_interp.create_operation "test.type_found" -> (%type : !pdl.type)
pdl_interp.erase %root
pdl_interp.finalize
}
}
}
// CHECK-LABEL: test.foreach_type
// CHECK: "test.type_found"() : () -> i64
// CHECK-NOT: "test.type_found"
module @ir attributes { test.foreach_type } {
"test.success_op"() : () -> (i32, i64)
}
// -----
// Test pdl_interp.foreach over a range of values returned by native constraint.
module @patterns {
pdl_interp.func @matcher(%root : !pdl.operation) {
pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end
^pat:
%values = pdl_interp.apply_constraint "op_constr_return_value_range"(%root : !pdl.operation) : !pdl.range<value> -> ^loop, ^end
^loop:
pdl_interp.foreach %val : !pdl.value in %values {
%type = pdl_interp.get_value_type of %val : !pdl.type
// Only match if the type is f16, verifying we introspect all values
// but only trigger one rewrite
pdl_interp.check_type %type is f16 -> ^record, ^cont
^record:
pdl_interp.record_match @rewriters::@success(%root, %val : !pdl.operation, !pdl.value) : benefit(1), loc([%root]) -> ^cont
^cont:
pdl_interp.continue
} -> ^end
^end:
pdl_interp.finalize
}
module @rewriters {
pdl_interp.func @success(%root : !pdl.operation, %val : !pdl.value) {
%type = pdl_interp.get_value_type of %val : !pdl.type
pdl_interp.create_operation "test.value_found"(%val : !pdl.value) -> (%type : !pdl.type)
pdl_interp.erase %root
pdl_interp.finalize
}
}
}
// CHECK-LABEL: test.foreach_value
// CHECK: %[[VAL1:.*]] = "test.input1"
// CHECK: "test.value_found"(%[[VAL1]]) : (f16) -> f16
// CHECK-NOT: "test.value_found"
module @ir attributes { test.foreach_value } {
%0 = "test.input0"() : () -> f32
%1 = "test.input1"() : () -> f16
"test.success_op"(%0, %1) : (f32, f16) -> ()
}
// -----
//===----------------------------------------------------------------------===//
// pdl_interp::GetUsersOp
//===----------------------------------------------------------------------===//

View File

@ -81,6 +81,19 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
return failure();
}
// Custom constraint that returns a value range if the op is named
// test.success_op
static LogicalResult customValueRangeResultConstraint(PatternRewriter &rewriter,
PDLResultList &results,
ArrayRef<PDLValue> args) {
auto *op = args[0].cast<Operation *>();
if (op->getName().getStringRef() == "test.success_op") {
results.push_back(op->getOperands()); // Returns ValueRange
return success();
}
return failure();
}
// Custom creator invoked from PDL.
static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
return rewriter.create(OperationState(op->getLoc(), "test.success"));
@ -161,6 +174,8 @@ struct TestPDLByteCodePass
customConstraintFailure);
pdlPattern.registerConstraintFunction("op_constr_return_type_range",
customTypeRangeResultConstraint);
pdlPattern.registerConstraintFunction("op_constr_return_value_range",
customValueRangeResultConstraint);
pdlPattern.registerRewriteFunction("creator", customCreate);
pdlPattern.registerRewriteFunction("var_creator",
customVariadicResultCreate);