[mlir][linalg-transform] dyn_cast DestinationStyleOpInterface and early return (#166299)
Use `dyn_cast` instead of `cast` and return early if the op does not
implement the `DestinationStyleOpInterface`. Before this change, the
following IR caused a segfault when the transform interpreter was run,
because `myop.a` and `myop.b` implement the `TilingInterface` but not
the `DestinationStyleOpInterface`. I looked for ops in the upstream
dialects that implement the `TilingInterface` but not the
`DestinationStyleOpInterface` to add a test, but could not find any.
```mlir
module {
func.func @fuse(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
%mul = "myop.a"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
%add = "myop.b"(%mul, %mul) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
return %add : tensor<4x4x4xf32>
}
transform.sequence failures(propagate) {
^bb0(%func: !transform.any_op):
%mul = transform.structured.match ops{["myop.a"]} in %func : (!transform.any_op) -> !transform.any_op
%add = transform.structured.match ops{["myop.b"]} in %func : (!transform.any_op) -> !transform.any_op
%loop, %tiled = transform.structured.tile_using_forall %add tile_sizes [1, 2, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
%mul_fused, %mul_containing = transform.structured.fuse_into_containing_op %mul into %tiled : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
}
```
This commit is contained in:
parent
1a34007f5f
commit
a257a063c6
@ -997,8 +997,11 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
|
||||
// Iterate over the outputs of the producer and over the loop bbArgs and
|
||||
// check if any bbArg points to the same value as the producer output. In
|
||||
// such case, make the producer output point to the bbArg directly.
|
||||
for (OpOperand &initOperandPtr :
|
||||
cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
|
||||
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
|
||||
if (!dpsInterface)
|
||||
return;
|
||||
|
||||
for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
|
||||
Value producerOperand =
|
||||
clone->getOperand(initOperandPtr.getOperandNumber());
|
||||
for (BlockArgument containerIterArg :
|
||||
@ -1060,7 +1063,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
|
||||
resultNumber, offsets, sizes);
|
||||
|
||||
// Cleanup clone.
|
||||
if (dyn_cast<LoopLikeOpInterface>(containingOp))
|
||||
if (isa<LoopLikeOpInterface>(containingOp))
|
||||
rewriter.eraseOp(tileableProducer);
|
||||
|
||||
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);
|
||||
|
||||
@ -253,6 +253,40 @@ module {
|
||||
|
||||
// -----
|
||||
|
||||
#map = affine_map<(d0) -> (d0 * 2)>
|
||||
#map1 = affine_map<(d0) -> (d0 * 4)>
|
||||
// Exercises transform.structured.fuse_into_containing_op with a producer op
// that implements TilingInterface but not DestinationStyleOpInterface. The
// producer result is consumed through a tensor.extract_slice inside an
// scf.forall, and the transform sequence below targets that loop.
module {
|
||||
// CHECK-LABEL: func.func @fuse_tileable_op_no_dps
|
||||
func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
|
||||
%0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
|
||||
%1 = tensor.empty() : tensor<4x4x4xf32>
|
||||
// CHECK: scf.forall
|
||||
%2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) {
|
||||
%3 = affine.apply #map(%arg3)
|
||||
%4 = affine.apply #map1(%arg4)
|
||||
// CHECK: "test.tiling_no_dps_op"
|
||||
// CHECK: "test.unregistered_op"
|
||||
%extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32>
|
||||
%5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
|
||||
scf.forall.in_parallel {
|
||||
tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32>
|
||||
}
|
||||
}
|
||||
return %2 : tensor<4x4x4xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
|
||||
%op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
|
||||
%fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
|
||||
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>
|
||||
|
||||
@ -1051,6 +1051,32 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TilingNoDpsOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Stub implementation for the test op: ignores `builder` and returns an
/// empty list of ranges.
SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Stub implementation for the test op: returns an empty list of iterator
/// types.
SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Stub implementation for the test op: unconditionally returns failure(),
/// ignoring `builder`, `offsets` and `sizes`. No tiled ops are created.
FailureOr<TilingResult>
|
||||
TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
|
||||
ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
/// Stub implementation for the test op: unconditionally returns failure().
/// All inputs are ignored and `resultOffsets` / `resultSizes` are left
/// unmodified.
LogicalResult TilingNoDpsOp::getResultTilePosition(
|
||||
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
|
||||
SmallVector<OpFoldResult> &resultSizes) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpWithShapedTypeInferTypeAdaptorInterfaceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/LoopLikeInterface.td"
|
||||
include "mlir/Interfaces/MemorySlotInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/TilingInterface.td"
|
||||
include "mlir/Interfaces/ValueBoundsOpInterface.td"
|
||||
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
|
||||
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
|
||||
@ -2887,6 +2888,20 @@ def TestLinalgFillOp :
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test TilingInterface.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Test op "tiling_no_dps_op": declared Pure and declares the listed
// TilingInterface methods, with no DestinationStyleOpInterface in its trait
// list. Takes two ranked tensor operands ($lhs, $rhs) and produces one ranked
// tensor result.
def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op",
|
||||
[Pure, DeclareOpInterfaceMethods<TilingInterface,
|
||||
["getIterationDomain",
|
||||
"getLoopIteratorTypes",
|
||||
"getResultTilePosition",
|
||||
"getTiledImplementation"]>]> {
|
||||
let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
|
||||
let results = (outs AnyRankedTensor:$result);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test NVVM RequiresSM trait.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user