[mlir][tensor] Fix getReassociationForCollapse
for tensor/scalar re… (#144118)
…shapes Commit 6e5a142 changed the behavior of the function when computing reassociations between tensors (consisting of unit/dynamic dimensions) and scalars/0d vectors. The IR representation for such reshapes actually expects an empty reassociation, like so: ``` func.func @example(%arg0 : tensor<?x?x?xf32>) -> tensor<f32> { %0 = tensor.collapse_shape %arg0 [] : tensor<?x?x?xf32> into tensor<f32> } ``` Restore the original behavior - the routine should resort to reporting failures when compile time-known non-unit dimensions are part of the attempted reassociation. Signed-off-by: Artem Gindinson <gindinson@roofline.ai>
This commit is contained in:
parent
0c7ce6883a
commit
f82cf74420
@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
|
||||
// this utility).
|
||||
if (numSourceDims <= numTargetDims)
|
||||
return std::nullopt;
|
||||
// Early handling for scalar target types.
|
||||
// Early handling for scalar target types. We should report an invalid
|
||||
// reassociation for non-unit static dimensions - no chance to collapse these
|
||||
// into a scalar.
|
||||
if (numTargetDims == 0) {
|
||||
ReassociationIndices allSourceIndices;
|
||||
allSourceIndices.reserve(numSourceDims);
|
||||
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
|
||||
++sourceDimIdx) {
|
||||
int64_t sourceSize = sourceShape[sourceDimIdx];
|
||||
// All source dimensions must be unit or dynamic.
|
||||
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
|
||||
return std::nullopt;
|
||||
allSourceIndices.push_back(sourceDimIdx);
|
||||
}
|
||||
return SmallVector<ReassociationIndices>{allSourceIndices};
|
||||
return SmallVector<ReassociationIndices>{};
|
||||
}
|
||||
|
||||
// Collect source ranges by iterating over the target shape left-to-right.
|
||||
|
@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, ScalarTest) {
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
|
||||
makeOptionalIndices({{0}}));
|
||||
makeOptionalIndices({}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
|
||||
makeOptionalIndices({{0, 1}}));
|
||||
makeOptionalIndices({}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
|
||||
makeOptionalIndices({{0}}));
|
||||
makeOptionalIndices({}));
|
||||
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
|
||||
ShapedType::kDynamic, 1,
|
||||
ShapedType::kDynamic},
|
||||
{}),
|
||||
makeOptionalIndices({{0, 1, 2, 3, 4}}));
|
||||
makeOptionalIndices({}));
|
||||
}
|
||||
|
||||
TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user