[mlir][linalg] Fix crash in tile_reduction when output map has constant exprs (#189166)
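`generateInitialTensorForPartialReduction` and the `getInitSliceInfo*` helpers unconditionally cast every result expression of the partial result AffineMap to `AffineDimExpr`. When the original output indexing map contains a constant (e.g. `affine_map<(d0,d1,d2)->(d0,0,d2)>`), the constant expression propagates into the partial map and the cast triggers an assertion. Constant result expressions are now handled explicitly: the corresponding dimension uses offset 0 and the full size of the original init operand's dimension, and it is never treated as a reduction dimension.

Fixes #173025

Assisted-by: Claude Code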
This commit is contained in:
parent
273e8d85fe
commit
c2ec012098
@ -434,13 +434,22 @@ struct InitSliceInfo {
|
||||
static InitSliceInfo getInitSliceInfoForOuterReduction(
|
||||
MLIRContext *context, ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
|
||||
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
|
||||
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap,
|
||||
ArrayRef<OpFoldResult> initOperandShape) {
|
||||
int64_t initRank = partialReductionMap.getNumResults();
|
||||
SmallVector<OpFoldResult> initOffsets, initSizes;
|
||||
Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
|
||||
Attribute one = IntegerAttr::get(IndexType::get(context), 1);
|
||||
SmallVector<OpFoldResult> initStrides(initRank, one);
|
||||
for (AffineExpr dimExpr : partialReductionMap.getResults()) {
|
||||
for (auto [resultIdx, dimExpr] :
|
||||
llvm::enumerate(partialReductionMap.getResults())) {
|
||||
if (isa<AffineConstantExpr>(dimExpr)) {
|
||||
// A constant index in the output map accesses a fixed position; keep
|
||||
// the full output dimension to match the original output operand shape.
|
||||
initOffsets.push_back(zero);
|
||||
initSizes.push_back(initOperandShape[resultIdx]);
|
||||
continue;
|
||||
}
|
||||
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
|
||||
if (reductionDims.contains(dim)) {
|
||||
initOffsets.push_back(zero);
|
||||
@ -460,13 +469,24 @@ static InitSliceInfo getInitSliceInfoForOuterReduction(
|
||||
static InitSliceInfo getInitSliceInfoForOuterParallel(
|
||||
MLIRContext *context, ArrayRef<OpFoldResult> offsets,
|
||||
ArrayRef<OpFoldResult> sizes, const SetVector<unsigned> &reductionDims,
|
||||
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap) {
|
||||
ArrayRef<OpFoldResult> splitReductionIvs, AffineMap partialReductionMap,
|
||||
ArrayRef<OpFoldResult> initOperandShape) {
|
||||
int64_t initRank = partialReductionMap.getNumResults();
|
||||
SmallVector<OpFoldResult> initOffsets, initSizes;
|
||||
Attribute zero = IntegerAttr::get(IndexType::get(context), 0);
|
||||
Attribute one = IntegerAttr::get(IndexType::get(context), 1);
|
||||
SmallVector<OpFoldResult> initStrides(initRank, one);
|
||||
SmallVector<OpFoldResult> resultShape;
|
||||
for (AffineExpr dimExpr : partialReductionMap.getResults()) {
|
||||
for (auto [resultIdx, dimExpr] :
|
||||
llvm::enumerate(partialReductionMap.getResults())) {
|
||||
if (isa<AffineConstantExpr>(dimExpr)) {
|
||||
// A constant index accesses a fixed position; keep the full output
|
||||
// dimension to match the original output operand shape.
|
||||
initOffsets.push_back(zero);
|
||||
initSizes.push_back(initOperandShape[resultIdx]);
|
||||
resultShape.push_back(initOperandShape[resultIdx]);
|
||||
continue;
|
||||
}
|
||||
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
|
||||
if (std::optional<unsigned> dimPos = getPositionIn(reductionDims, dim)) {
|
||||
initOffsets.push_back(splitReductionIvs[dimPos.value()]);
|
||||
@ -490,17 +510,18 @@ static InitSliceInfo getInitSliceInfo(MLIRContext *context,
|
||||
ArrayRef<OpFoldResult> sizes,
|
||||
const SetVector<unsigned> &reductionDims,
|
||||
ArrayRef<OpFoldResult> splitReductionIvs,
|
||||
AffineMap partialReductionMap) {
|
||||
AffineMap partialReductionMap,
|
||||
ArrayRef<OpFoldResult> initOperandShape) {
|
||||
if (strategy == ReductionTilingStrategy::PartialReductionOuterReduction) {
|
||||
return getInitSliceInfoForOuterReduction(context, offsets, sizes,
|
||||
reductionDims, splitReductionIvs,
|
||||
partialReductionMap);
|
||||
return getInitSliceInfoForOuterReduction(
|
||||
context, offsets, sizes, reductionDims, splitReductionIvs,
|
||||
partialReductionMap, initOperandShape);
|
||||
}
|
||||
assert(strategy == ReductionTilingStrategy::PartialReductionOuterParallel &&
|
||||
"unexpected ReductionTilingStrategy");
|
||||
return getInitSliceInfoForOuterParallel(context, offsets, sizes,
|
||||
reductionDims, splitReductionIvs,
|
||||
partialReductionMap);
|
||||
return getInitSliceInfoForOuterParallel(
|
||||
context, offsets, sizes, reductionDims, splitReductionIvs,
|
||||
partialReductionMap, initOperandShape);
|
||||
}
|
||||
|
||||
/// External model implementation of PartialReductionInterface for
|
||||
@ -538,7 +559,17 @@ struct LinalgOpPartialReductionInterface
|
||||
|
||||
// Append the new partial result dimensions.
|
||||
SmallVector<OpFoldResult> partialResultShape;
|
||||
for (AffineExpr dimExpr : partialMap.getResults()) {
|
||||
Value initValue = linalgOp.getDpsInits()[initIdx];
|
||||
SmallVector<OpFoldResult> initShape =
|
||||
tensor::getMixedSizes(b, loc, initValue);
|
||||
for (auto [resultIdx, dimExpr] :
|
||||
llvm::enumerate(partialMap.getResults())) {
|
||||
if (isa<AffineConstantExpr>(dimExpr)) {
|
||||
// A constant index in the output map accesses a fixed position; use
|
||||
// the actual output dimension size (not a hardcoded 1).
|
||||
partialResultShape.push_back(initShape[resultIdx]);
|
||||
continue;
|
||||
}
|
||||
auto dim = cast<AffineDimExpr>(dimExpr);
|
||||
partialResultShape.push_back(sizes[dim.getPosition()]);
|
||||
}
|
||||
@ -591,11 +622,15 @@ struct LinalgOpPartialReductionInterface
|
||||
|
||||
// Step 2b: Extract a slice of the init operands.
|
||||
SmallVector<Value, 1> tiledInits;
|
||||
for (auto [partialReductionMap, valueToTile] :
|
||||
llvm::zip_equal(partialReductionMaps, init)) {
|
||||
for (auto [partialReductionMap, valueToTile, initOperandValue] :
|
||||
llvm::zip_equal(partialReductionMaps, init, linalgOp.getDpsInits())) {
|
||||
// Compute the actual shape of the original init operand for handling
|
||||
// constant expressions in the partial reduction map.
|
||||
SmallVector<OpFoldResult> initOperandShape =
|
||||
tensor::getMixedSizes(b, loc, initOperandValue);
|
||||
InitSliceInfo sliceInfo = getInitSliceInfo(
|
||||
b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
|
||||
splitReductionIvs, partialReductionMap);
|
||||
splitReductionIvs, partialReductionMap, initOperandShape);
|
||||
auto valueToTileType = cast<RankedTensorType>(valueToTile.getType());
|
||||
RankedTensorType sliceResultType = RankedTensorType::get(
|
||||
sliceInfo.resultShape, valueToTileType.getElementType(),
|
||||
@ -670,6 +705,8 @@ struct LinalgOpPartialReductionInterface
|
||||
SmallVector<int64_t> partialReductionDims;
|
||||
for (auto [resultNum, dimExpr] :
|
||||
llvm::enumerate(partialMap.getResults())) {
|
||||
if (isa<AffineConstantExpr>(dimExpr))
|
||||
continue; // Constant dims are never reduction dims.
|
||||
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
|
||||
if (llvm::is_contained(reductionDims, dim)) {
|
||||
partialReductionDims.push_back(resultNum);
|
||||
@ -707,9 +744,16 @@ struct LinalgOpPartialReductionInterface
|
||||
auto linalgOp = cast<LinalgOp>(op);
|
||||
SmallVector<AffineMap> partialReductionMaps =
|
||||
getPartialResultAffineMaps(linalgOp, reductionDims);
|
||||
InitSliceInfo sliceInfo = getInitSliceInfo(
|
||||
b.getContext(), tilingStrategy, offsets, sizes, reductionDims,
|
||||
splitReductionIvs, partialReductionMaps[resultNumber]);
|
||||
// Compute the actual shape of the init operand for handling constant
|
||||
// expressions in the partial reduction map.
|
||||
Value initOperandValue = linalgOp.getDpsInits()[resultNumber];
|
||||
Location loc = op->getLoc();
|
||||
SmallVector<OpFoldResult> initOperandShape =
|
||||
tensor::getMixedSizes(b, loc, initOperandValue);
|
||||
InitSliceInfo sliceInfo =
|
||||
getInitSliceInfo(b.getContext(), tilingStrategy, offsets, sizes,
|
||||
reductionDims, splitReductionIvs,
|
||||
partialReductionMaps[resultNumber], initOperandShape);
|
||||
std::swap(resultOffsets, sliceInfo.offsets);
|
||||
std::swap(resultSizes, sliceInfo.sizes);
|
||||
|
||||
|
||||
@ -728,3 +728,304 @@ module attributes {transform.with_named_sequence} {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: %[[LOCAL_IDX:.+]] = linalg.index 1 : index
|
||||
// CHECK: affine.apply #[[$INDEX_MAP]](%[[IV]])[%[[LOCAL_IDX]]]
|
||||
|
||||
// -----
|
||||
|
||||
// Verify that tile_reduction_using_forall handles output indexing maps that
|
||||
// contain constant expressions (e.g. `affine_map<(d0,d1,d2)->(d0,0,d2)>`)
|
||||
// without crashing. Previously, generateInitialTensorForPartialReduction
|
||||
// unconditionally cast every map result to AffineDimExpr, triggering an
|
||||
// assertion when a constant expression was present (issue #173025).
|
||||
|
||||
func.func @reduction_tile_with_constant_in_output_map(
|
||||
%arg0: tensor<1x4096x64xf32>,
|
||||
%arg1: tensor<1x1x64xf32>) -> tensor<1x1x64xf32> {
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, 0, d2)>
|
||||
],
|
||||
iterator_types = ["parallel", "reduction", "parallel"]
|
||||
} ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x1x64xf32>) {
|
||||
^bb0(%in: f32, %out: f32):
|
||||
%1 = arith.addf %in, %out : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<1x1x64xf32>
|
||||
return %0 : tensor<1x1x64xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.generic"]} in %arg
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%1:4 = transform.structured.tile_reduction_using_forall %0
|
||||
by tile_sizes = [0, 4, 0]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op,
|
||||
!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reduction_tile_with_constant_in_output_map
|
||||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x1x64x1024xf32>
|
||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x1x64x1024xf32>) -> tensor<1x1x64x1024xf32>
|
||||
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x1x64x1024xf32>) {
|
||||
// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1]
|
||||
// CHECK: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [1, 1, 64, 1] [1, 1, 1, 1]
|
||||
// CHECK: %[[PARTIAL:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[IN_SLICE]] :
|
||||
// CHECK-SAME: outs(%[[INIT_SLICE]] :
|
||||
// CHECK: scf.forall.in_parallel {
|
||||
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, 0, {{.*}}] [1, 1, 64, 1] [1, 1, 1, 1]
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: linalg.reduce ins(%[[L]] : tensor<1x1x64x1024xf32>) outs(%arg1 : tensor<1x1x64xf32>) dimensions = [3]
|
||||
// CHECK: return %{{.*}} : tensor<1x1x64xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// Verify tile_reduction_using_forall with a constant in the output map when
|
||||
// the constant-indexed dimension size is greater than 1 (K=3). The partial
|
||||
// init tensor must use the actual output dim size (3), not a hardcoded 1.
|
||||
|
||||
func.func @reduction_tile_forall_constant_dim_k_gt_1(
|
||||
%arg0: tensor<1x4096x64xf32>,
|
||||
%arg1: tensor<1x3x64xf32>) -> tensor<1x3x64xf32> {
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, 0, d2)>
|
||||
],
|
||||
iterator_types = ["parallel", "reduction", "parallel"]
|
||||
} ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x3x64xf32>) {
|
||||
^bb0(%in: f32, %out: f32):
|
||||
%1 = arith.addf %in, %out : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<1x3x64xf32>
|
||||
return %0 : tensor<1x3x64xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.generic"]} in %arg
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%1:4 = transform.structured.tile_reduction_using_forall %0
|
||||
by tile_sizes = [0, 4, 0]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op,
|
||||
!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reduction_tile_forall_constant_dim_k_gt_1
|
||||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x3x64x1024xf32>
|
||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x3x64x1024xf32>) -> tensor<1x3x64x1024xf32>
|
||||
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x3x64x1024xf32>) {
|
||||
// CHECK: %[[IN_SLICE:.+]] = tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1]
|
||||
// CHECK: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [1, 3, 64, 1] [1, 1, 1, 1]
|
||||
// CHECK: %[[PARTIAL:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[IN_SLICE]] :
|
||||
// CHECK-SAME: outs(%[[INIT_SLICE]] :
|
||||
// CHECK: scf.forall.in_parallel {
|
||||
// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, 0, {{.*}}] [1, 3, 64, 1] [1, 1, 1, 1]
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: linalg.reduce ins(%[[L]] : tensor<1x3x64x1024xf32>) outs(%arg1 : tensor<1x3x64xf32>) dimensions = [3]
|
||||
// CHECK: return %{{.*}} : tensor<1x3x64xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// Verify tile_reduction_using_for with a constant in the output map when
|
||||
// the constant-indexed dimension size is greater than 1 (K=3). The partial
|
||||
// init tensor must use the actual output dim size (3), not a hardcoded 1.
|
||||
|
||||
func.func @reduction_tile_for_constant_dim_k_gt_1(
|
||||
%arg0: tensor<1x4096x64xf32>,
|
||||
%arg1: tensor<1x3x64xf32>) -> tensor<1x3x64xf32> {
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, 0, d2)>
|
||||
],
|
||||
iterator_types = ["parallel", "reduction", "parallel"]
|
||||
} ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<1x3x64xf32>) {
|
||||
^bb0(%in: f32, %out: f32):
|
||||
%1 = arith.addf %in, %out : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<1x3x64xf32>
|
||||
return %0 : tensor<1x3x64xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.generic"]} in %arg
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
|
||||
by tile_sizes = [0, 4, 0]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op,
|
||||
!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reduction_tile_for_constant_dim_k_gt_1
|
||||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<1x3x64x4xf32>
|
||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<1x3x64x4xf32>) -> tensor<1x3x64x4xf32>
|
||||
// CHECK: %[[L:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.+]] = %[[F]]) -> (tensor<1x3x64x4xf32>) {
|
||||
// CHECK: %[[PARTIAL:.+]] = linalg.generic
|
||||
// CHECK-SAME: outs(%[[ARG3]] :
|
||||
// CHECK: scf.yield %[[PARTIAL]]
|
||||
// CHECK: }
|
||||
// CHECK: linalg.reduce ins(%[[L]] : tensor<1x3x64x4xf32>) outs(%arg1 : tensor<1x3x64xf32>) dimensions = [3]
|
||||
// CHECK: return %{{.*}} : tensor<1x3x64xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// Verify tile_reduction_using_forall handles dynamic output shapes combined
|
||||
// with a constant expression in the output map. The partial init tensor must
|
||||
// use the actual size (3) of the constant-indexed dimension, which is static
|
||||
// here (only d0 and d2 are dynamic), rather than a hardcoded 1.
|
||||
|
||||
// CHECK-LABEL: func @reduction_tile_dynamic_constant_map
|
||||
func.func @reduction_tile_dynamic_constant_map(
|
||||
%arg0: tensor<?x4096x?xf32>,
|
||||
%arg1: tensor<?x3x?xf32>) -> tensor<?x3x?xf32> {
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, 0, d2)>
|
||||
],
|
||||
iterator_types = ["parallel", "reduction", "parallel"]
|
||||
} ins(%arg0 : tensor<?x4096x?xf32>) outs(%arg1 : tensor<?x3x?xf32>) {
|
||||
^bb0(%in: f32, %out: f32):
|
||||
%1 = arith.addf %in, %out : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<?x3x?xf32>
|
||||
return %0 : tensor<?x3x?xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.generic"]} in %arg
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%1:4 = transform.structured.tile_reduction_using_forall %0
|
||||
by tile_sizes = [0, 4, 0]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op,
|
||||
!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that the partial init tensor uses the correct dynamic shape. The
|
||||
// constant-indexed dim 1 is static (size 3) so the partial tensor is
|
||||
// tensor<?x3x?x1024xf32>. The extract_slice within the forall body uses
|
||||
// size 3 at the constant-indexed position, not a hardcoded 1.
|
||||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[E:.*]] = tensor.empty({{.*}}) : tensor<?x3x?x1024xf32>
|
||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<?x3x?x1024xf32>)
|
||||
// CHECK: scf.forall
|
||||
// CHECK: tensor.extract_slice {{.*}} [{{.*}}, 3, {{.*}}, 1]
|
||||
// CHECK: linalg.reduce ins({{.*}} : tensor<?x3x?x1024xf32>) {{.*}} dimensions = [3]
|
||||
// CHECK: return %{{.*}}
|
||||
|
||||
// -----
|
||||
|
||||
// Verify tile_reduction_using_forall handles two consecutive constant
|
||||
// expressions in the same output map (e.g. `(d0,d1,d2)->(0,0,d2)`).
|
||||
// Both constant-indexed dimensions must use the actual output dim size.
|
||||
|
||||
// CHECK-LABEL: func @reduction_tile_two_constants_in_map
|
||||
func.func @reduction_tile_two_constants_in_map(
|
||||
%arg0: tensor<1x4096x64xf32>,
|
||||
%arg1: tensor<3x5x64xf32>) -> tensor<3x5x64xf32> {
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (0, 0, d2)>
|
||||
],
|
||||
iterator_types = ["reduction", "reduction", "parallel"]
|
||||
} ins(%arg0 : tensor<1x4096x64xf32>) outs(%arg1 : tensor<3x5x64xf32>) {
|
||||
^bb0(%in: f32, %out: f32):
|
||||
%1 = arith.addf %in, %out : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<3x5x64xf32>
|
||||
return %0 : tensor<3x5x64xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.generic"]} in %arg
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%1:4 = transform.structured.tile_reduction_using_forall %0
|
||||
by tile_sizes = [0, 4, 0]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op,
|
||||
!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[E:.*]] = tensor.empty() : tensor<3x5x64x1024xf32>
|
||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<3x5x64x1024xf32>) -> tensor<3x5x64x1024xf32>
|
||||
// CHECK: %[[L:.*]] = scf.forall (%[[IV:.+]]) = (0) to (4096) step (4) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<3x5x64x1024xf32>) {
|
||||
// CHECK: tensor.extract_slice %arg0[0, %[[IV]], 0] [1, 4, 64] [1, 1, 1]
|
||||
// CHECK: %[[INIT_SLICE:.+]] = tensor.extract_slice %[[ARG3]][0, 0, 0, {{.*}}] [3, 5, 64, 1] [1, 1, 1, 1]
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: scf.forall.in_parallel {
|
||||
// CHECK: tensor.parallel_insert_slice {{.*}} into %[[ARG3]][0, 0, 0, {{.*}}] [3, 5, 64, 1] [1, 1, 1, 1]
|
||||
// CHECK: }
|
||||
// CHECK: }
|
||||
// CHECK: linalg.reduce ins(%[[L]] : tensor<3x5x64x1024xf32>) outs(%arg1 : tensor<3x5x64xf32>) dimensions = [3]
|
||||
// CHECK: return %{{.*}} : tensor<3x5x64xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// Verify tile_reduction_using_for handles dynamic output shapes combined with
|
||||
// a constant expression in the output map. The partial init tensor must use
|
||||
// tensor.dim to query the dynamic dimensions rather than hardcoding 1.
|
||||
|
||||
// CHECK-LABEL: func @reduction_tile_for_dynamic_constant_map
|
||||
func.func @reduction_tile_for_dynamic_constant_map(
|
||||
%arg0: tensor<?x4096x?xf32>,
|
||||
%arg1: tensor<?x3x?xf32>) -> tensor<?x3x?xf32> {
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [
|
||||
affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
|
||||
affine_map<(d0, d1, d2) -> (d0, 0, d2)>
|
||||
],
|
||||
iterator_types = ["parallel", "reduction", "parallel"]
|
||||
} ins(%arg0 : tensor<?x4096x?xf32>) outs(%arg1 : tensor<?x3x?xf32>) {
|
||||
^bb0(%in: f32, %out: f32):
|
||||
%1 = arith.addf %in, %out : f32
|
||||
linalg.yield %1 : f32
|
||||
} -> tensor<?x3x?xf32>
|
||||
return %0 : tensor<?x3x?xf32>
|
||||
}
|
||||
|
||||
module attributes {transform.with_named_sequence} {
|
||||
transform.named_sequence @__transform_main(%arg: !transform.any_op {transform.readonly}) {
|
||||
%0 = transform.structured.match ops{["linalg.generic"]} in %arg
|
||||
: (!transform.any_op) -> !transform.any_op
|
||||
%1, %2, %3, %loop = transform.structured.tile_reduction_using_for %0
|
||||
by tile_sizes = [0, 4, 0]
|
||||
: (!transform.any_op) -> (!transform.any_op, !transform.any_op,
|
||||
!transform.any_op, !transform.any_op)
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the partial init tensor uses tensor.dim for dynamic dims and the
|
||||
// static constant-indexed position uses size 3. The extract_slice inside
|
||||
// the for loop body must use size 3 at the constant-indexed position.
|
||||
// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK-DAG: %[[E:.*]] = tensor.empty({{.*}}) : tensor<?x3x?x4xf32>
|
||||
// CHECK: %[[F:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[E]] : tensor<?x3x?x4xf32>)
|
||||
// CHECK: %[[L:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x3x?x4xf32>) {
|
||||
// CHECK: tensor.extract_slice %arg0[{{.*}}] [{{.*}}, 4, {{.*}}] [1, 1, 1]
|
||||
// CHECK: tensor.extract_slice %[[ARG3]][{{.*}}] [{{.*}}, 3, {{.*}}, 4] [1, 1, 1, 1]
|
||||
// CHECK: }
|
||||
// CHECK: linalg.reduce ins(%[[L]] : tensor<?x3x?x4xf32>) outs(%arg1 : tensor<?x3x?xf32>) dimensions = [3]
|
||||
// CHECK: return %{{.*}}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user