From c2ec012098df051e698d57b3ec9b58c625761bf6 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 3 Apr 2026 13:09:26 +0200 Subject: [PATCH] [mlir][linalg] Fix crash in tile_reduction when output map has constant exprs (#189166) `generateInitialTensorForPartialReduction` and the `getInitSliceInfo*` helpers unconditionally cast every result expression of the partial result AffineMap to `AffineDimExpr`. When the original output indexing map contains a constant (e.g. `affine_map<(d0,d1,d2)->(d0,0,d2)>`), the constant expression propagates into the partial map and the cast triggers an assertion. Fixes #173025 Assisted-by: Claude Code --- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 80 +++-- .../Linalg/transform-tile-reduction.mlir | 301 ++++++++++++++++++ 2 files changed, 363 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 558ebdebd65c..7ed07e1ec9a0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -434,13 +434,22 @@ struct InitSliceInfo { static InitSliceInfo getInitSliceInfoForOuterReduction( MLIRContext *context, ArrayRef offsets, ArrayRef sizes, const SetVector &reductionDims, - ArrayRef splitReductionIvs, AffineMap partialReductionMap) { + ArrayRef splitReductionIvs, AffineMap partialReductionMap, + ArrayRef initOperandShape) { int64_t initRank = partialReductionMap.getNumResults(); SmallVector initOffsets, initSizes; Attribute zero = IntegerAttr::get(IndexType::get(context), 0); Attribute one = IntegerAttr::get(IndexType::get(context), 1); SmallVector initStrides(initRank, one); - for (AffineExpr dimExpr : partialReductionMap.getResults()) { + for (auto [resultIdx, dimExpr] : + llvm::enumerate(partialReductionMap.getResults())) { + if (isa(dimExpr)) { + // A constant index in the output map accesses a fixed position; keep + // the full output dimension to match the original output operand shape. + initOffsets.push_back(zero); + initSizes.push_back(initOperandShape[resultIdx]); + continue; + } unsigned dim = cast(dimExpr).getPosition(); if (reductionDims.contains(dim)) { initOffsets.push_back(zero); @@ -460,13 +469,24 @@ static InitSliceInfo getInitSliceInfoForOuterReduction( static InitSliceInfo getInitSliceInfoForOuterParallel( MLIRContext *context, ArrayRef offsets, ArrayRef sizes, const SetVector &reductionDims, - ArrayRef splitReductionIvs, AffineMap partialReductionMap) { + ArrayRef splitReductionIvs, AffineMap partialReductionMap, + ArrayRef initOperandShape) { int64_t initRank = partialReductionMap.getNumResults(); SmallVector initOffsets, initSizes; + Attribute zero = IntegerAttr::get(IndexType::get(context), 0); Attribute one = IntegerAttr::get(IndexType::get(context), 1); SmallVector initStrides(initRank, one); SmallVector resultShape; - for (AffineExpr dimExpr : partialReductionMap.getResults()) { + for (auto [resultIdx, dimExpr] : + llvm::enumerate(partialReductionMap.getResults())) { + if (isa(dimExpr)) { + // A constant index accesses a fixed position; keep the full output + // dimension to match the original output operand shape. + initOffsets.push_back(zero); + initSizes.push_back(initOperandShape[resultIdx]); + resultShape.push_back(initOperandShape[resultIdx]); + continue; + } unsigned dim = cast(dimExpr).getPosition(); if (std::optional dimPos = getPositionIn(reductionDims, dim)) { initOffsets.push_back(splitReductionIvs[dimPos.value()]); @@ -490,17 +510,18 @@ static InitSliceInfo getInitSliceInfo(MLIRContext *context, ArrayRef sizes, const SetVector &reductionDims, ArrayRef splitReductionIvs, - AffineMap partialReductionMap) { + AffineMap partialReductionMap, + ArrayRef initOperandShape) { if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) { - return getInitSliceInfoForOuterReduction(context, offsets, sizes, - reductionDims, splitReductionIvs, - partialReductionMap); + return getInitSliceInfoForOuterReduction( + context, offsets, sizes, reductionDims, splitReductionIvs, + partialReductionMap, initOperandShape); } assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel && "unexpected ReductionTilingStrategy"); - return getInitSliceInfoForOuterParallel(context, offsets, sizes, - reductionDims, splitReductionIvs, - partialReductionMap); + return getInitSliceInfoForOuterParallel( + context, offsets, sizes, reductionDims, splitReductionIvs, + partialReductionMap, initOperandShape); } /// External model implementation of PartialReductionInterface for @@ -538,7 +559,17 @@ struct LinalgOpPartialReductionInterface // Append the new partial result dimensions. SmallVector partialResultShape; - for (AffineExpr dimExpr : partialMap.getResults()) { + Value initValue = linalgOp.getDpsInits()[initIdx]; + SmallVector initShape = + tensor::getMixedSizes(b, loc, initValue); + for (auto [resultIdx, dimExpr] : + llvm::enumerate(partialMap.getResults())) { + if (isa(dimExpr)) { + // A constant index in the output map accesses a fixed position; use + // the actual output dimension size (not a hardcoded 1). + partialResultShape.push_back(initShape[resultIdx]); + continue; + } auto dim = cast(dimExpr); partialResultShape.push_back(sizes[dim.getPosition()]); } @@ -591,11 +622,15 @@ struct LinalgOpPartialReductionInterface // Step 2b: Extract a slice of the init operands. SmallVector tiledInits; - for (auto [partialReductionMap, valueToTile] : - llvm::zip_equal(partialReductionMaps, init)) { + for (auto [partialReductionMap, valueToTile, initOperandValue] : + llvm::zip_equal(partialReductionMaps, init, linalgOp.getDpsInits())) { + // Compute the actual shape of the original init operand for handling + // constant expressions in the partial reduction map. + SmallVector initOperandShape = + tensor::getMixedSizes(b, loc, initOperandValue); InitSliceInfo sliceInfo = getInitSliceInfo( b.getContext(), tilingStrategy, offsets, sizes, reductionDims, - splitReductionIvs, partialReductionMap); + splitReductionIvs, partialReductionMap, initOperandShape); auto valueToTileType = cast(valueToTile.getType()); RankedTensorType sliceResultType = RankedTensorType::get( sliceInfo.resultShape, valueToTileType.getElementType(), @@ -670,6 +705,8 @@ struct LinalgOpPartialReductionInterface SmallVector partialReductionDims; for (auto [resultNum, dimExpr] : llvm::enumerate(partialMap.getResults())) { + if (isa(dimExpr)) + continue; // Constant dims are never reduction dims. unsigned dim = cast(dimExpr).getPosition(); if (llvm::is_contained(reductionDims, dim)) { partialReductionDims.push_back(resultNum); @@ -707,9 +744,16 @@ struct LinalgOpPartialReductionInterface auto linalgOp = cast(op); SmallVector partialReductionMaps = getPartialResultAffineMaps(linalgOp, reductionDims); - InitSliceInfo sliceInfo = getInitSliceInfo( - b.getContext(), tilingStrategy, offsets, sizes, reductionDims, - splitReductionIvs, partialReductionMaps[resultNumber]); + // Compute the actual shape of the init operand for handling constant + // expressions in the partial reduction map. + Value initOperandValue = linalgOp.getDpsInits()[resultNumber]; + Location loc = op->getLoc(); + SmallVector initOperandShape = + tensor::getMixedSizes(b, loc, initOperandValue); + InitSliceInfo sliceInfo = + getInitSliceInfo(b.getContext(), tilingStrategy, offsets, sizes, + reductionDims, splitReductionIvs, + partialReductionMaps[resultNumber], initOperandShape); std::swap(resultOffsets, sliceInfo.offsets); std::swap(resultSizes, sliceInfo.sizes); diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir index e31d4f333557..25af6796d1f6 100644 --- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir @@ -728,3 +728,304 @@ module attributes {transform.with_named_sequence} { // CHECK: linalg.generic // CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index // CHECK: affine.apply #[[$INDEX_MAP]](%[[IV]])[%[[LOCAL_IDX]]] + +// ----- + +// Verify that tile_reduction_using_forall handles output indexing maps that +// contain constant expressions (e.g. `affine_map<(d0,d1,d2)->(d0,0,d2)>`) +// without crashing. Previously, generateInitialTensorForPartialReduction +// unconditionally cast every map result to AffineDimExpr, triggering an +// assertion when a constant expression was present (issue #173025). + +func.func @reduction_tile_with_constant_in_output_map( + %arg0: tensor<1x4096x64xf32>, + %arg1: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, 0, d2)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x1x64xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<1x1x64xf32> + return %0 : tensor<1x1x64xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg + : (!transform.any_op) -> !transform.any_op + %1:4 = transform.structured.tile_reduction_using_forall %0 + by tile_sizes = [0, 4, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func @reduction_tile_with_constant_in_output_map +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x1x64x1024xf32> +// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x1x64x1024xf32>) -> tensor<1x1x64x1024xf32> +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x1x64x1024xf32>) { +// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1] +// CHECK: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [1, 1, 64, 1] [1, 1, 1, 1] +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: ins(%[[IN_SLICE]] : +// CHECK-SAME: outs(%[[INIT_SLICE]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, 0, {{.*}}] [1, 1, 64, 1] [1, 1, 1, 1] +// CHECK: } +// CHECK: } +// CHECK: linalg.reduce ins(%[[L]] : tensor<1x1x64x1024xf32>) outs(%arg1 : tensor<1x1x64xf32>) dimensions = [3] +// CHECK: return %{{.*}} : tensor<1x1x64xf32> + +// ----- + +// Verify tile_reduction_using_forall with a constant in the output map when +// the constant-indexed dimension size is greater than 1 (K=3). The partial +// init tensor must use the actual output dim size (3), not a hardcoded 1. + +func.func @reduction_tile_forall_constant_dim_k_gt_1( + %arg0: tensor<1x4096x64xf32>, + %arg1: tensor<1x3x64xf32>) -> tensor<1x3x64xf32> { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, 0, d2)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x3x64xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<1x3x64xf32> + return %0 : tensor<1x3x64xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg + : (!transform.any_op) -> !transform.any_op + %1:4 = transform.structured.tile_reduction_using_forall %0 + by tile_sizes = [0, 4, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func @reduction_tile_forall_constant_dim_k_gt_1 +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x3x64x1024xf32> +// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x3x64x1024xf32>) -> tensor<1x3x64x1024xf32> +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x3x64x1024xf32>) { +// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1] +// CHECK: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [1, 3, 64, 1] [1, 1, 1, 1] +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: ins(%[[IN_SLICE]] : +// CHECK-SAME: outs(%[[INIT_SLICE]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, 0, {{.*}}] [1, 3, 64, 1] [1, 1, 1, 1] +// CHECK: } +// CHECK: } +// CHECK: linalg.reduce ins(%[[L]] : tensor<1x3x64x1024xf32>) outs(%arg1 : tensor<1x3x64xf32>) dimensions = [3] +// CHECK: return %{{.*}} : tensor<1x3x64xf32> + +// ----- + +// Verify tile_reduction_using_for with a constant in the output map when +// the constant-indexed dimension size is greater than 1 (K=3). The partial +// init tensor must use the actual output dim size (3), not a hardcoded 1. + +func.func @reduction_tile_for_constant_dim_k_gt_1( + %arg0: tensor<1x4096x64xf32>, + %arg1: tensor<1x3x64xf32>) -> tensor<1x3x64xf32> { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, 0, d2)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x3x64xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<1x3x64xf32> + return %0 : tensor<1x3x64xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg + : (!transform.any_op) -> !transform.any_op + %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0 + by tile_sizes = [0, 4, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func @reduction_tile_for_constant_dim_k_gt_1 +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x3x64x4xf32> +// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x3x64x4xf32>) -> tensor<1x3x64x4xf32> +// CHECK: %[[L:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x3x64x4xf32>) { +// CHECK: %[[PARTIAL:.+]] = linalg.generic +// CHECK-SAME: outs(%[[ARG3]] : +// CHECK: scf.yield %[[PARTIAL]] +// CHECK: } +// CHECK: linalg.reduce ins(%[[L]] : tensor<1x3x64x4xf32>) outs(%arg1 : tensor<1x3x64xf32>) dimensions = [3] +// CHECK: return %{{.*}} : tensor<1x3x64xf32> + +// ----- + +// Verify tile_reduction_using_forall handles dynamic output shapes combined +// with a constant expression in the output map. The partial init tensor must +// use tensor.dim to query the dynamic dimension at the constant-indexed +// position rather than a hardcoded 1. + +// CHECK-LABEL: func @reduction_tile_dynamic_constant_map +func.func @reduction_tile_dynamic_constant_map( + %arg0: tensor, + %arg1: tensor) -> tensor { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, 0, d2)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%arg0 : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg + : (!transform.any_op) -> !transform.any_op + %1:4 = transform.structured.tile_reduction_using_forall %0 + by tile_sizes = [0, 4, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op, !transform.any_op) + transform.yield + } +} + +// Verify that the partial init tensor uses the correct dynamic shape. The +// constant-indexed dim 1 is static (size 3) so the partial tensor is +// tensor. The extract_slice within the forall body uses +// size 3 at the constant-indexed position, not a hardcoded 1. +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[E:.*]] = tensor.empty({{.*}}) : tensor +// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor) +// CHECK: scf.forall +// CHECK: tensor.extract_slice {{.*}} [{{.*}}, 3, {{.*}}, 1] +// CHECK: linalg.reduce ins({{.*}} : tensor) {{.*}} dimensions = [3] +// CHECK: return %{{.*}} + +// ----- + +// Verify tile_reduction_using_forall handles two consecutive constant +// expressions in the same output map (e.g. `(d0,d1,d2)->(0,0,d2)`). +// Both constant-indexed dimensions must use the actual output dim size. + +// CHECK-LABEL: func @reduction_tile_two_constants_in_map +func.func @reduction_tile_two_constants_in_map( + %arg0: tensor<1x4096x64xf32>, + %arg1: tensor<3x5x64xf32>) -> tensor<3x5x64xf32> { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (0, 0, d2)> + ], + iterator_types = ["reduction", "reduction", "parallel"] + } ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<3x5x64xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<3x5x64xf32> + return %0 : tensor<3x5x64xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg + : (!transform.any_op) -> !transform.any_op + %1:4 = transform.structured.tile_reduction_using_forall %0 + by tile_sizes = [0, 4, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<3x5x64x1024xf32> +// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<3x5x64x1024xf32>) -> tensor<3x5x64x1024xf32> +// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<3x5x64x1024xf32>) { +// CHECK: tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1] +// CHECK: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [3, 5, 64, 1] [1, 1, 1, 1] +// CHECK: linalg.generic +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice {{.*}} into %[[ARG3]][0, 0, 0, {{.*}}] [3, 5, 64, 1] [1, 1, 1, 1] +// CHECK: } +// CHECK: } +// CHECK: linalg.reduce ins(%[[L]] : tensor<3x5x64x1024xf32>) outs(%arg1 : tensor<3x5x64xf32>) dimensions = [3] +// CHECK: return %{{.*}} : tensor<3x5x64xf32> + +// ----- + +// Verify tile_reduction_using_for handles dynamic output shapes combined with +// a constant expression in the output map. The partial init tensor must use +// tensor.dim to query the dynamic dimensions rather than hardcoding 1. + +// CHECK-LABEL: func @reduction_tile_for_dynamic_constant_map +func.func @reduction_tile_for_dynamic_constant_map( + %arg0: tensor, + %arg1: tensor) -> tensor { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d0, 0, d2)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%arg0 : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg + : (!transform.any_op) -> !transform.any_op + %1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0 + by tile_sizes = [0, 4, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op, !transform.any_op) + transform.yield + } +} + +// Verify the partial init tensor uses tensor.dim for dynamic dims and the +// static constant-indexed position uses size 3. The extract_slice inside +// the for loop body must use size 3 at the constant-indexed position. +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[E:.*]] = tensor.empty({{.*}}) : tensor +// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor) +// CHECK: %[[L:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.+]] = %[[F]]) -> (tensor) { +// CHECK: tensor.extract_slice %arg0[{{.*}}] [{{.*}}, 4, {{.*}}] [1, 1, 1] +// CHECK: tensor.extract_slice %[[ARG3]][{{.*}}] [{{.*}}, 3, {{.*}}, 4] [1, 1, 1, 1] +// CHECK: } +// CHECK: linalg.reduce ins(%[[L]] : tensor) outs(%arg1 : tensor) dimensions = [3] +// CHECK: return %{{.*}}