diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index ede7d8a4006f..cf0021628811 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -1741,6 +1741,36 @@ void ByteCodeExecutor::executeForEach() { selectJump(size_t(0)); return; } + case PDLValue::Kind::Value: { + unsigned &index = loopIndex[read()]; + ValueRange range = valueRangeMemory[rangeIndex]; + assert(index <= range.size() && "iterated past the end"); + if (index < range.size()) { + LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n"); + value = range[index].getAsOpaquePointer(); + break; + } + + LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + index = 0; + selectJump(size_t(0)); + return; + } + case PDLValue::Kind::Type: { + unsigned &index = loopIndex[read()]; + TypeRange range = typeRangeMemory[rangeIndex]; + assert(index <= range.size() && "iterated past the end"); + if (index < range.size()) { + LLVM_DEBUG(llvm::dbgs() << " * Result: " << range[index] << "\n"); + value = range[index].getAsOpaquePointer(); + break; + } + + LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + index = 0; + selectJump(size_t(0)); + return; + } default: llvm_unreachable("unexpected `ForEach` value kind"); } diff --git a/mlir/test/Rewrite/pdl-bytecode.mlir b/mlir/test/Rewrite/pdl-bytecode.mlir index 8221f009a659..844f832cd22c 100644 --- a/mlir/test/Rewrite/pdl-bytecode.mlir +++ b/mlir/test/Rewrite/pdl-bytecode.mlir @@ -956,6 +956,92 @@ module @ir attributes { test.foreach } { // ----- +// Test pdl_interp.foreach over a range of types. +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end + + ^pat: + %results = pdl_interp.get_results of %root : !pdl.range + %types = pdl_interp.get_value_type of %results : !pdl.range + // Iterate over the types of the results of the root op + pdl_interp.foreach %type : !pdl.type in %types { + // Only match if the type is i64, verifying we introspect all types + // but only trigger one rewrite + pdl_interp.check_type %type is i64 -> ^record, ^cont + ^record: + pdl_interp.record_match @rewriters::@success(%root, %type : !pdl.operation, !pdl.type) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %type : !pdl.type) { + // Create an op for the matched i64 type + pdl_interp.create_operation "test.type_found" -> (%type : !pdl.type) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} +// CHECK-LABEL: test.foreach_type +// CHECK: "test.type_found"() : () -> i64 +// CHECK-NOT: "test.type_found" +module @ir attributes { test.foreach_type } { + "test.success_op"() : () -> (i32, i64) +} +// ----- + +// Test pdl_interp.foreach over a range of values returned by native constraint. +module @patterns { + pdl_interp.func @matcher(%root : !pdl.operation) { + pdl_interp.check_operation_name of %root is "test.success_op" -> ^pat, ^end + + ^pat: + %values = pdl_interp.apply_constraint "op_constr_return_value_range"(%root : !pdl.operation) : !pdl.range -> ^loop, ^end + + ^loop: + pdl_interp.foreach %val : !pdl.value in %values { + %type = pdl_interp.get_value_type of %val : !pdl.type + // Only match if the type is f16, verifying we introspect all values + // but only trigger one rewrite + pdl_interp.check_type %type is f16 -> ^record, ^cont + ^record: + pdl_interp.record_match @rewriters::@success(%root, %val : !pdl.operation, !pdl.value) : benefit(1), loc([%root]) -> ^cont + ^cont: + pdl_interp.continue + } -> ^end + + ^end: + pdl_interp.finalize + } + + module @rewriters { + pdl_interp.func @success(%root : !pdl.operation, %val : !pdl.value) { + %type = pdl_interp.get_value_type of %val : !pdl.type + pdl_interp.create_operation "test.value_found"(%val : !pdl.value) -> (%type : !pdl.type) + pdl_interp.erase %root + pdl_interp.finalize + } + } +} + +// CHECK-LABEL: test.foreach_value +// CHECK: %[[VAL1:.*]] = "test.input1" +// CHECK: "test.value_found"(%[[VAL1]]) : (f16) -> f16 +// CHECK-NOT: "test.value_found" +module @ir attributes { test.foreach_value } { + %0 = "test.input0"() : () -> f32 + %1 = "test.input1"() : () -> f16 + "test.success_op"(%0, %1) : (f32, f16) -> () +} + +// ----- + //===----------------------------------------------------------------------===// // pdl_interp::GetUsersOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index e5783c96f44e..1e3d0186eb1c 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -81,6 +81,19 @@ static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter, return failure(); } +// Custom constraint that returns a value range if the op is named +// test.success_op +static LogicalResult customValueRangeResultConstraint(PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef args) { + auto *op = args[0].cast(); + if (op->getName().getStringRef() == "test.success_op") { + results.push_back(op->getOperands()); // Returns ValueRange + return success(); + } + return failure(); +} + // Custom creator invoked from PDL. static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { return rewriter.create(OperationState(op->getLoc(), "test.success")); @@ -161,6 +174,8 @@ struct TestPDLByteCodePass customConstraintFailure); pdlPattern.registerConstraintFunction("op_constr_return_type_range", customTypeRangeResultConstraint); + pdlPattern.registerConstraintFunction("op_constr_return_value_range", + customValueRangeResultConstraint); pdlPattern.registerRewriteFunction("creator", customCreate); pdlPattern.registerRewriteFunction("var_creator", customVariadicResultCreate);