Support for constant parameters on PDL native constraints and rewrites has never really worked well: it is incredibly clunky to use (it effectively creates two argument APIs) and clunky to generate (it isn't clear how PDL frontends should actually expose it). Treating these values as ordinary attribute arguments is much cleaner at every level of the stack. If we need to optimize patterns with many constant parameters, it would be better to investigate internal representation optimizations (e.g. batch attribute creation) that do not affect the user, since we want a clean external API.
Differential Revision: https://reviews.llvm.org/D121569
291 lines
9.3 KiB
Python
291 lines
9.3 KiB
Python
# RUN: %PYTHON %s | FileCheck %s
|
|
|
|
from mlir.ir import *
|
|
from mlir.dialects.pdl import *
|
|
|
|
|
|
def constructAndPrintInModule(f):
  """Decorator that builds a test's IR in a fresh module and prints it.

  A "TEST:" banner carrying the function's name is printed first.  Then,
  with a new ``Context`` and ``Location.unknown()`` active, a ``Module`` is
  created, ``f`` is called with the insertion point set to the module body,
  and the resulting module is printed.  Because ``f`` is invoked here, the
  test body runs at the moment the decorator is applied.

  Args:
    f: zero-argument callable that creates operations at the current
      insertion point.

  Returns:
    ``f`` itself, unmodified.
  """
  test_name = f.__name__
  print("\nTEST:", test_name)
  with Context(), Location.unknown():
    container = Module.create()
    body_ip = InsertionPoint(container.body)
    with body_ip:
      f()
    print(container)
  return f
|
|
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @operations : benefit(1) {
|
|
# CHECK: %0 = attribute
|
|
# CHECK: %1 = type
|
|
# CHECK: %2 = operation {"attr" = %0} -> (%1 : !pdl.type)
|
|
# CHECK: %3 = result 0 of %2
|
|
# CHECK: %4 = operand
|
|
# CHECK: %5 = operation(%3, %4 : !pdl.value, !pdl.value)
|
|
# CHECK: rewrite %5 with "rewriter"
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_operations():
  """Match an op fed by another op's result and a free operand.

  The first matched operation carries an ``"attr"`` attribute and a single
  result type.  Its result 0 and an independent operand become the two
  operands of the root operation, which is handed to the rewriter named
  ``"rewriter"``.
  """
  pattern = PatternOp(1, "operations")
  with InsertionPoint(pattern.body):
    attr = AttributeOp()
    ty = TypeOp()
    op0 = OperationOp(attributes={"attr": attr}, types=[ty])
    op0_result = ResultOp(op0, 0)
    # Named `operand` rather than `input` so the builtin is not shadowed.
    operand = OperandOp()
    root = OperationOp(args=[op0_result, operand])
    RewriteOp(root, "rewriter")
|
|
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @rewrite_with_args : benefit(1) {
|
|
# CHECK: %0 = operand
|
|
# CHECK: %1 = operation(%0 : !pdl.value)
|
|
# CHECK: rewrite %1 with "rewriter"(%0 : !pdl.value)
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_rewrite_with_args():
  """Forward a matched value to a named rewriter as an extra argument.

  The root operation takes a single operand.  That same operand is also
  passed through ``args`` to the rewriter named ``"rewriter"``.
  """
  pattern = PatternOp(1, "rewrite_with_args")
  with InsertionPoint(pattern.body):
    # Named `operand` rather than `input` so the builtin is not shadowed.
    operand = OperandOp()
    root = OperationOp(args=[operand])
    RewriteOp(root, "rewriter", args=[operand])
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @rewrite_multi_root_optimal : benefit(1) {
|
|
# CHECK: %0 = operand
|
|
# CHECK: %1 = operand
|
|
# CHECK: %2 = type
|
|
# CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type)
|
|
# CHECK: %4 = result 0 of %3
|
|
# CHECK: %5 = operation(%4 : !pdl.value)
|
|
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
|
|
# CHECK: %7 = result 0 of %6
|
|
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
|
|
# CHECK: rewrite with "rewriter"(%5, %8 : !pdl.operation, !pdl.operation)
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_rewrite_multi_root_optimal():
  """Rewrite two roots without passing an explicit root to ``RewriteOp``.

  Two producer operations share one result type, and each is fed by its own
  operand.  The first root consumes the first producer's result, while the
  second root consumes both producers' results.  Both roots are given to
  ``"rewriter"`` only through ``args``.  The test name suggests that the
  root is then chosen optimally downstream, but that is not visible here.
  """
  pattern = PatternOp(1, "rewrite_multi_root_optimal")
  with InsertionPoint(pattern.body):
    lhs_operand = OperandOp()
    rhs_operand = OperandOp()
    shared_type = TypeOp()
    lhs_producer = OperationOp(args=[lhs_operand], types=[shared_type])
    lhs_value = ResultOp(lhs_producer, 0)
    first_root = OperationOp(args=[lhs_value])
    rhs_producer = OperationOp(args=[rhs_operand], types=[shared_type])
    rhs_value = ResultOp(rhs_producer, 0)
    second_root = OperationOp(args=[lhs_value, rhs_value])
    RewriteOp(name="rewriter", args=[first_root, second_root])
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @rewrite_multi_root_forced : benefit(1) {
|
|
# CHECK: %0 = operand
|
|
# CHECK: %1 = operand
|
|
# CHECK: %2 = type
|
|
# CHECK: %3 = operation(%0 : !pdl.value) -> (%2 : !pdl.type)
|
|
# CHECK: %4 = result 0 of %3
|
|
# CHECK: %5 = operation(%4 : !pdl.value)
|
|
# CHECK: %6 = operation(%1 : !pdl.value) -> (%2 : !pdl.type)
|
|
# CHECK: %7 = result 0 of %6
|
|
# CHECK: %8 = operation(%4, %7 : !pdl.value, !pdl.value)
|
|
# CHECK: rewrite %5 with "rewriter"(%8 : !pdl.operation)
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_rewrite_multi_root_forced():
  """Rewrite two roots while explicitly naming the first as the root.

  The IR built here is the same as in the "optimal" multi-root test, except
  that ``RewriteOp`` receives the first root directly.  The second root is
  passed only as an argument to ``"rewriter"``.
  """
  pattern = PatternOp(1, "rewrite_multi_root_forced")
  with InsertionPoint(pattern.body):
    lhs_operand = OperandOp()
    rhs_operand = OperandOp()
    shared_type = TypeOp()

    def produce(operand):
      # Creates a one-result op at the current insertion point and returns
      # a handle to its result 0.
      producer = OperationOp(args=[operand], types=[shared_type])
      return ResultOp(producer, 0)

    lhs_value = produce(lhs_operand)
    first_root = OperationOp(args=[lhs_value])
    rhs_value = produce(rhs_operand)
    second_root = OperationOp(args=[lhs_value, rhs_value])
    RewriteOp(first_root, name="rewriter", args=[second_root])
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @rewrite_add_body : benefit(1) {
|
|
# CHECK: %0 = type : i32
|
|
# CHECK: %1 = type
|
|
# CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type)
|
|
# CHECK: rewrite %2 {
|
|
# CHECK: %3 = type
|
|
# CHECK: %4 = operation "foo.op" -> (%0, %3 : !pdl.type, !pdl.type)
|
|
# CHECK: replace %2 with %4
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_rewrite_add_body():
  """Give ``RewriteOp`` an inline body that replaces the matched root.

  The root has two result types: one fixed to ``i32`` and one
  unconstrained.  Inside the rewrite body, a new ``"foo.op"`` is built.
  It reuses the ``i32`` type plus a freshly created type, and the root is
  replaced by it.
  """
  pattern = PatternOp(1, "rewrite_add_body")
  with InsertionPoint(pattern.body):
    i32 = IntegerType.get_signless(32)
    fixed_type = TypeOp(i32)
    free_type = TypeOp()
    root = OperationOp(types=[fixed_type, free_type])
    rewrite_op = RewriteOp(root)
    with InsertionPoint(rewrite_op.add_body()):
      fresh_type = TypeOp()
      replacement = OperationOp(name="foo.op", types=[fixed_type, fresh_type])
      ReplaceOp(root, with_op=replacement)
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @rewrite_type : benefit(1) {
|
|
# CHECK: %0 = type : i32
|
|
# CHECK: %1 = type
|
|
# CHECK: %2 = operation -> (%0, %1 : !pdl.type, !pdl.type)
|
|
# CHECK: rewrite %2 {
|
|
# CHECK: %3 = operation "foo.op" -> (%0, %1 : !pdl.type, !pdl.type)
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_rewrite_type():
  """Reuse matched ``pdl.type`` values inside a rewrite body.

  The ``"foo.op"`` created in the rewrite body takes both of its result
  types from the matched root: ``i32`` and an unconstrained type.
  """
  pattern = PatternOp(1, "rewrite_type")
  with InsertionPoint(pattern.body):
    ty1 = TypeOp(IntegerType.get_signless(32))
    ty2 = TypeOp()
    root = OperationOp(types=[ty1, ty2])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      # The op is inserted at the current insertion point, so no Python
      # handle to it is needed (the old `newOp` binding was unused).
      OperationOp(name="foo.op", types=[ty1, ty2])
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @rewrite_types : benefit(1) {
|
|
# CHECK: %0 = types
|
|
# CHECK: %1 = operation -> (%0 : !pdl.range<type>)
|
|
# CHECK: rewrite %1 {
|
|
# CHECK: %2 = types : [i32, i64]
|
|
# CHECK: %3 = operation "foo.op" -> (%0, %2 : !pdl.range<type>, !pdl.range<type>)
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_rewrite_types():
  """Combine a matched type range with a constant type range.

  The root's results are matched as a ``pdl.types`` range.  Inside the
  rewrite body, a constant range ``[i32, i64]`` is created, and both ranges
  become the result types of a new ``"foo.op"``.
  """
  pattern = PatternOp(1, "rewrite_types")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    root = OperationOp(types=[types])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      otherTypes = TypesOp(
          [IntegerType.get_signless(32), IntegerType.get_signless(64)])
      # Inserted at the current insertion point; the handle is not needed.
      OperationOp(name="foo.op", types=[types, otherTypes])
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @rewrite_operands : benefit(1) {
|
|
# CHECK: %0 = types
|
|
# CHECK: %1 = operands : %0
|
|
# CHECK: %2 = operation(%1 : !pdl.range<value>)
|
|
# CHECK: rewrite %2 {
|
|
# CHECK: %3 = operation "foo.op" -> (%0 : !pdl.range<type>)
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_rewrite_operands():
  """Match an operand range typed by a type range, then reuse the types.

  The root takes a ``pdl.operands`` range whose types are the ``pdl.types``
  range.  Inside the rewrite body, that type range becomes the result types
  of a new ``"foo.op"``.
  """
  pattern = PatternOp(1, "rewrite_operands")
  with InsertionPoint(pattern.body):
    types = TypesOp()
    operands = OperandsOp(types)
    root = OperationOp(args=[operands])
    rewrite = RewriteOp(root)
    with InsertionPoint(rewrite.add_body()):
      # Inserted at the current insertion point; the handle is not needed.
      OperationOp(name="foo.op", types=[types])
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @native_rewrite : benefit(1) {
|
|
# CHECK: %0 = operation
|
|
# CHECK: rewrite %0 {
|
|
# CHECK: apply_native_rewrite "NativeRewrite"(%0 : !pdl.operation)
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_native_rewrite():
  """Invoke a native rewrite on the matched root from a rewrite body.

  ``ApplyNativeRewriteOp`` receives an empty list as its first argument,
  and the printed op defines no results.  It also receives the name
  ``"NativeRewrite"`` and the root operation as its only argument.
  """
  pattern = PatternOp(1, "native_rewrite")
  with InsertionPoint(pattern.body):
    matched = OperationOp()
    rewrite_op = RewriteOp(matched)
    with InsertionPoint(rewrite_op.add_body()):
      no_results = []
      ApplyNativeRewriteOp(no_results, "NativeRewrite", args=[matched])
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @attribute_with_value : benefit(1) {
|
|
# CHECK: %0 = operation
|
|
# CHECK: rewrite %0 {
|
|
# CHECK: %1 = attribute "value"
|
|
# CHECK: apply_native_rewrite "NativeRewrite"(%1 : !pdl.attribute)
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_attribute_with_value():
  """Pass a constant attribute to a native rewrite as a plain argument.

  A ``pdl.attribute`` holding the constant string ``"value"`` is created
  inside the rewrite body.  It is passed as the only argument of the native
  rewrite ``"NativeRewrite"``.
  """
  pattern = PatternOp(1, "attribute_with_value")
  with InsertionPoint(pattern.body):
    root = OperationOp()
    rewrite_op = RewriteOp(root)
    with InsertionPoint(rewrite_op.add_body()):
      string_value = Attribute.parse('"value"')
      const_attr = AttributeOp(value=string_value)
      ApplyNativeRewriteOp([], "NativeRewrite", args=[const_attr])
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @erase : benefit(1) {
|
|
# CHECK: %0 = operation
|
|
# CHECK: rewrite %0 {
|
|
# CHECK: erase %0
|
|
# CHECK: }
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_erase():
  """Erase the matched root operation from within a rewrite body."""
  pattern = PatternOp(1, "erase")
  with InsertionPoint(pattern.body):
    victim = OperationOp()
    rewrite_op = RewriteOp(victim)
    rewrite_block = rewrite_op.add_body()
    with InsertionPoint(rewrite_block):
      EraseOp(victim)
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern @operation_results : benefit(1) {
|
|
# CHECK: %0 = types
|
|
# CHECK: %1 = operation -> (%0 : !pdl.range<type>)
|
|
# CHECK: %2 = results of %1
|
|
# CHECK: %3 = operation(%2 : !pdl.range<value>)
|
|
# CHECK: rewrite %3 with "rewriter"
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_operation_results():
  """Feed every result of one matched op into the root as a single range.

  ``ResultsOp`` is built with an explicit ``!pdl.range<value>`` result type
  and no index, so it yields all results of the producer op.  The root
  consumes that range and is handed to the rewriter ``"rewriter"``.
  """
  value_type = ValueType.get()
  value_range_type = RangeType.get(value_type)
  pattern = PatternOp(1, "operation_results")
  with InsertionPoint(pattern.body):
    result_types = TypesOp()
    producer = OperationOp(types=[result_types])
    all_results = ResultsOp(value_range_type, producer)
    root = OperationOp(args=[all_results])
    RewriteOp(root, name="rewriter")
|
|
|
|
# CHECK: module {
|
|
# CHECK: pdl.pattern : benefit(1) {
|
|
# CHECK: %0 = type
|
|
# CHECK: apply_native_constraint "typeConstraint"(%0 : !pdl.type)
|
|
# CHECK: %1 = operation -> (%0 : !pdl.type)
|
|
# CHECK: rewrite %1 with "rewrite"
|
|
# CHECK: }
|
|
# CHECK: }
|
|
@constructAndPrintInModule
def test_apply_native_constraint():
  """Constrain a matched type with a native constraint in an unnamed pattern.

  The ``pdl.type`` is first passed to the native constraint
  ``"typeConstraint"`` and then used as the root's result type.
  ``PatternOp`` receives only a benefit, so the pattern prints without a
  symbol name.
  """
  pattern = PatternOp(1)
  with InsertionPoint(pattern.body):
    constrained_type = TypeOp()
    ApplyNativeConstraintOp("typeConstraint", args=[constrained_type])
    root = OperationOp(types=[constrained_type])
    RewriteOp(root, name="rewrite")
|