[mlir][tensor] Fix getReassociationForCollapse for tensor/scalar re… (#144118)

…shapes

Commit 6e5a142 changed the behavior of the function when computing
reassociations between tensors (consisting of unit/dynamic dimensions)
and scalars/0d vectors. The IR representation for such reshapes actually
expects an empty reassociation, like so:
```
func.func @example(%arg0 : tensor<?x?x?xf32>) -> tensor<f32> {
  %0 = tensor.collapse_shape %arg0 [] : tensor<?x?x?xf32> into tensor<f32>
}
```

Restore the original behavior - the routine should resort to reporting
failures when compile time-known non-unit dimensions are part of the
attempted reassociation.

Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
This commit is contained in:
Artem Gindinson 2025-06-13 20:03:24 +02:00 committed by GitHub
parent 0c7ce6883a
commit f82cf74420
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 10 deletions

View File

@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
// this utility). // this utility).
if (numSourceDims <= numTargetDims) if (numSourceDims <= numTargetDims)
return std::nullopt; return std::nullopt;
// Early handling for scalar target types. // Early handling for scalar target types. We should report an invalid
// reassociation for non-unit static dimensions - no chance to collapse these
// into a scalar.
if (numTargetDims == 0) { if (numTargetDims == 0) {
ReassociationIndices allSourceIndices;
allSourceIndices.reserve(numSourceDims);
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims; for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
++sourceDimIdx) { ++sourceDimIdx) {
int64_t sourceSize = sourceShape[sourceDimIdx]; int64_t sourceSize = sourceShape[sourceDimIdx];
// All source dimensions must be unit or dynamic.
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic) if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
return std::nullopt; return std::nullopt;
allSourceIndices.push_back(sourceDimIdx);
} }
return SmallVector<ReassociationIndices>{allSourceIndices}; return SmallVector<ReassociationIndices>{};
} }
// Collect source ranges by iterating over the target shape left-to-right. // Collect source ranges by iterating over the target shape left-to-right.

View File

@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
TEST(ReassociationIndicesForCollapse, ScalarTest) { TEST(ReassociationIndicesForCollapse, ScalarTest) {
EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}), EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
makeOptionalIndices({{0}})); makeOptionalIndices({}));
EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}), EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
makeOptionalIndices({{0, 1}})); makeOptionalIndices({}));
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}), EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
makeOptionalIndices({{0}})); makeOptionalIndices({}));
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
ShapedType::kDynamic, 1, ShapedType::kDynamic, 1,
ShapedType::kDynamic}, ShapedType::kDynamic},
{}), {}),
makeOptionalIndices({{0, 1, 2, 3, 4}})); makeOptionalIndices({}));
} }
TEST(ReassociationIndicesForCollapse, ScalarTestFailure) { TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {