[mlir][linalg-transform] dyn_cast DestinationStyleOpInterface and early return (#166299)

Use `dyn_cast` instead of `cast` and return early if the op does not
implement the `DestinationStyleOpInterface`. Before this change, the
following IR caused a segfault when the transform interpreter was run,
where `myop.a` and `myop.b` implement the `TilingInterface` but not the
`DestinationStyleOpInterface`. I looked for ops in the upstream
dialects that implement the `TilingInterface` but not the
`DestinationStyleOpInterface` to add a test, but could not find any.

```mlir
module {
func.func @fuse(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
  %mul = "myop.a"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
  %add = "myop.b"(%mul, %mul) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
  return %add : tensor<4x4x4xf32>
}

transform.sequence failures(propagate) {
^bb0(%func: !transform.any_op):
  %mul = transform.structured.match ops{["myop.a"]} in %func : (!transform.any_op) -> !transform.any_op
  %add = transform.structured.match ops{["myop.b"]} in %func : (!transform.any_op) -> !transform.any_op
  %loop, %tiled = transform.structured.tile_using_forall %add tile_sizes [1, 2, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
  %mul_fused, %mul_containing = transform.structured.fuse_into_containing_op %mul into %tiled : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
}
}
```
This commit is contained in:
Hsiang-Chieh Tsou 2025-11-07 00:32:39 -08:00 committed by GitHub
parent 1a34007f5f
commit a257a063c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 81 additions and 3 deletions

View File

@ -997,8 +997,11 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
// Iterate over the outputs of the producer and over the loop bbArgs and
// check if any bbArg points to the same value as the producer output. In
// such case, make the producer output point to the bbArg directly.
for (OpOperand &initOperandPtr :
cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
if (!dpsInterface)
return;
for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
Value producerOperand =
clone->getOperand(initOperandPtr.getOperandNumber());
for (BlockArgument containerIterArg :
@ -1060,7 +1063,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
resultNumber, offsets, sizes);
// Cleanup clone.
if (dyn_cast<LoopLikeOpInterface>(containingOp))
if (isa<LoopLikeOpInterface>(containingOp))
rewriter.eraseOp(tileableProducer);
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);

View File

@ -253,6 +253,40 @@ module {
// -----
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
// Test case: fuse a producer ("test.tiling_no_dps_op") into an scf.forall
// that consumes it through a tensor.extract_slice. The FileCheck directives
// below expect the producer to be printed after the scf.forall line and
// before "test.unregistered_op".
module {
// CHECK-LABEL: func.func @fuse_tileable_op_no_dps
func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
// Producer defined outside the loop in the input IR.
%0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
%1 = tensor.empty() : tensor<4x4x4xf32>
// CHECK: scf.forall
%2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) {
%3 = affine.apply #map(%arg3)
%4 = affine.apply #map1(%arg4)
// CHECK: "test.tiling_no_dps_op"
// CHECK: "test.unregistered_op"
// The loop body reads a 1x2x4 slice of the producer result.
%extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32>
%5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32>
}
}
return %2 : tensor<4x4x4xf32>
}
// Transform script: match the producer and the loop, then request fusion of
// the producer into the loop.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
}
// -----
module {
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>

View File

@ -1051,6 +1051,32 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
return success();
}
//===----------------------------------------------------------------------===//
// TilingNoDpsOp
//===----------------------------------------------------------------------===//
/// Test stub: always reports an empty iteration domain. `builder` is unused.
SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
  SmallVector<Range> iterationDomain;
  return iterationDomain;
}
/// Test stub: always reports an empty list of loop iterator types.
SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes;
  return iteratorTypes;
}
/// Test stub: never produces a tiled implementation and unconditionally
/// returns failure. `builder`, `offsets` and `sizes` are all ignored.
FailureOr<TilingResult>
TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes) {
  return FailureOr<TilingResult>(failure());
}
/// Test stub: unconditionally returns failure. None of the arguments are
/// read, and `resultOffsets` / `resultSizes` are left untouched.
LogicalResult TilingNoDpsOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  return LogicalResult::failure();
}
//===----------------------------------------------------------------------===//
// OpWithShapedTypeInferTypeAdaptorInterfaceOp
//===----------------------------------------------------------------------===//

View File

@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/TilingInterface.td"
include "mlir/Interfaces/ValueBoundsOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@ -2887,6 +2888,20 @@ def TestLinalgFillOp :
}];
}
//===----------------------------------------------------------------------===//
// Test TilingInterface.
//===----------------------------------------------------------------------===//
// Test op "tiling_no_dps_op": its trait list contains only `Pure` and
// `TilingInterface` (with the four listed interface methods declared for
// out-of-line definition). It takes two ranked tensors and returns one.
def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op",
[Pure, DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
let results = (outs AnyRankedTensor:$result);
}
//===----------------------------------------------------------------------===//
// Test NVVM RequiresSM trait.
//===----------------------------------------------------------------------===//