From af73aeaa19929127655d544b48a5145105e9e28c Mon Sep 17 00:00:00 2001 From: Nishant Patel Date: Wed, 19 Nov 2025 16:16:44 -0800 Subject: [PATCH] [MLIR][Vector] Add unroll pattern for vector.shape_cast (#167738) This PR adds pattern for unrolling shape_cast given a targetShape. This PR is a follow up of #164010 which was very general and was using inserts and extracts on each element (which is also LowerVectorShapeCast.cpp is doing). After doing some more research on use cases, we (me and @Jianhui-Li ) realized that the previous version in #164010 is unnecessarily generic and doesn't fit our performance needs. Our use case requires that targetShape is contiguous in both source and result vector. This pattern only applies when contiguous slices can be extracted from the source vector and inserted into the result vector such that each slice remains in vector form with targetShape (and not decompose to scalars). In these cases, the unrolling proceeds as: vector.extract_strided_slice -> vector.shape_cast (on the slice unrolled) -> vector.insert_strided_slice --- .../mlir/Dialect/Vector/IR/VectorOps.td | 1 + mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 4 + .../Vector/Transforms/VectorUnroll.cpp | 193 +++++++++++++++++- .../Dialect/Vector/vector-unroll-options.mlir | 79 +++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 22 ++ 5 files changed, 297 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index f91d2b6404c9..43ebcaa03a47 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2424,6 +2424,7 @@ def Vector_CompressStoreOp : def Vector_ShapeCastOp : Vector_Op<"shape_cast", [Pure, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ]>, Arguments<(ins AnyVectorOfAnyRank:$source)>, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a97d0cd7f755..2789f6355552 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6243,6 +6243,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef argRanges, setResultRanges(getResult(), argRanges.front()); } +std::optional> ShapeCastOp::getShapeForUnroll() { + return llvm::to_vector<4>(getResultVectorType().getShape()); +} + LogicalResult ShapeCastOp::verify() { VectorType sourceType = getSourceVectorType(); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae0989bed2..b60f80534bfb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,6 +1003,195 @@ private: vector::UnrollVectorOptions options; }; +/// Checks whether extractShape is a contiguous slice of shape. +/// For extractShape to be contiguous in shape: +/// 1) All but the leading dimension of extractShape and shape must match +/// exactly. 2) The total number of elements in shape must be evenly divisible +/// by +/// the total number of elements in extractShape. +/// Examples: +/// isContiguous([4, 4], [8, 4]) == true +/// isContiguous([2, 4], [8, 4]) == true +/// isContiguous([2, 2], [8, 4]) == false +/// Removes leading unit dimensions to handle cases like: +/// isContiguous([1, 16], [1, 32]) == true +static bool isContiguous(ArrayRef extractShape, + ArrayRef shape) { + + if (extractShape.size() > shape.size()) + return false; + + while (!extractShape.empty() && extractShape.front() == 1) { + extractShape = extractShape.drop_front(); + } + + while (!shape.empty() && shape.front() == 1) { + shape = shape.drop_front(); + } + + size_t rankDiff = shape.size() - extractShape.size(); + if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1))) + return false; + + int64_t extractElements = ShapedType::getNumElements(extractShape); + int64_t shapeElements = ShapedType::getNumElements(shape); + return shapeElements % extractElements == 0; +} + +/// Determines what shape to use with `vector.extract_strided_slice` to extract +/// a contiguous memory region from a source vector. The extraction must be +/// contiguous and contain exactly the specified number of elements. If such an +/// extraction shape cannot be determined, returns std::nullopt. +/// EXAMPLE 1: +/// sourceShape = [16], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 16) = 8 from only dim → extractShape = [8], +/// remaining = 8/8 = 1 +/// Result: [8] +/// +/// EXAMPLE 2: +/// sourceShape = [4, 4], targetElements = 8 +/// Working right-to-left: +/// - Take min(8, 4) = 4 from last dim → extractShape = [4], +/// remaining = 8/4 = 2 +/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4], +/// remaining = 2/2 = 1 +/// Result: [2, 4] +static std::optional> +calculateSourceExtractShape(ArrayRef sourceShape, + int64_t targetElements) { + SmallVector extractShape; + int64_t remainingElements = targetElements; + + // Build extract shape from innermost dimension outward to ensure contiguity. + for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) { + int64_t takeFromDim = std::min(remainingElements, sourceShape[i]); + extractShape.insert(extractShape.begin(), takeFromDim); + + if (remainingElements % takeFromDim != 0) + return std::nullopt; // Not evenly divisible. + remainingElements /= takeFromDim; + } + + // Fill remaining dimensions with 1. + while (extractShape.size() < sourceShape.size()) + extractShape.insert(extractShape.begin(), 1); + + if (ShapedType::getNumElements(extractShape) != targetElements) + return std::nullopt; + + return extractShape; +} + +// Convert result offsets to source offsets via linear position. +static SmallVector +calculateSourceOffsets(ArrayRef resultOffsets, + ArrayRef sourceShape, + ArrayRef resultShape) { + // Convert result offsets to linear position. + int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape)); + // Convert linear position to source offsets. + return delinearize(linearIndex, computeStrides(sourceShape)); +} + +/// This pattern unrolls `vector.shape_cast` operations according to the +/// provided target unroll shape. It unrolls a large shape cast into smaller +/// shape casts by extracting contiguous slices from the source vector, casting +/// each slice to the target shape, and assembling the result by inserting each +/// computed segment into the appropriate offset of the result vector. +/// +/// This pattern only applies when contiguous slices can be extracted from the +/// source vector and inserted into the result vector such that each slice +/// remains a valid vector (and not decompose to scalars). In these cases, the +/// unrolling proceeds as: +/// vector.extract_strided_slice -> vector.shape_cast (on the slice) -> +/// vector.insert_strided_slice. +/// +/// Example: +/// Given a shape cast operation: +/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32> +/// +/// and a target unroll shape of <2x4>, the pattern produces: +/// +/// %zero = arith.constant dense<0.0> : vector<4x4xf32> +/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32> +/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1] +/// : vector<8x2xf32> to vector<4x2xf32> +/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32> +/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1] +/// : vector<2x4xf32> into vector<4x4xf32> +/// +struct UnrollShapeCastPattern : public OpRewritePattern { + UnrollShapeCastPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { + std::optional> targetShape = + getTargetShape(options, shapeCastOp); + if (!targetShape) + return failure(); + + VectorType sourceType = shapeCastOp.getSourceVectorType(); + VectorType resultType = shapeCastOp.getResultVectorType(); + ArrayRef sourceShape = sourceType.getShape(); + ArrayRef resultShape = resultType.getShape(); + + if (!isContiguous(*targetShape, resultShape)) + return rewriter.notifyMatchFailure( + shapeCastOp, "Only supports cases where target shape is " + "contiguous in result vector shape"); + + int64_t targetElements = ShapedType::getNumElements(*targetShape); + + // Calculate the shape to extract from source. + std::optional> extractShape = + calculateSourceExtractShape(sourceShape, targetElements); + if (!extractShape) + return rewriter.notifyMatchFailure( + shapeCastOp, + "cannot extract target number of elements contiguously from source"); + + Location loc = shapeCastOp.getLoc(); + + // Create result vector initialized to zero. + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + + VectorType targetType = + VectorType::get(*targetShape, sourceType.getElementType()); + + SmallVector extractStrides(extractShape->size(), 1); + SmallVector insertStrides(targetShape->size(), 1); + + for (SmallVector resultOffsets : + StaticTileOffsetRange(resultShape, *targetShape)) { + SmallVector sourceOffsets = + calculateSourceOffsets(resultOffsets, sourceShape, resultShape); + Value sourceChunk = rewriter.createOrFold( + loc, shapeCastOp.getSource(), sourceOffsets, *extractShape, + extractStrides); + Value targetChunk = rewriter.createOrFold( + loc, targetType, sourceChunk); + result = rewriter.createOrFold( + loc, targetChunk, result, resultOffsets, insertStrides); + } + + rewriter.replaceOp(shapeCastOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -1013,8 +1202,8 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>( + patterns.getContext(), options, benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e5a98b5c67f3..dec32e1c61a9 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -496,3 +496,82 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> // CHECK-NOT: arith.addf // CHECK: return + + +func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { + %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32> + return %0 : vector<2x2x4xf32> +} + +// CHECK-LABEL: func @shape_cast_1D +// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32> +// CHECK: return %[[I1]] : vector<2x2x4xf32> + + +func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> { + %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// CHECK-LABEL: func @shape_cast_2D +// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32> +// CHECK: return %[[I1]] : vector<4x4xf32> + + +// This is a negative test case to ensure that such shape casts are not unrolled +// because the targetShape (2x4) is not contiguous in result vector +func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> { + %0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32> + return %0 : vector<8x8xf32> +} + +// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous +// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32> +// CHECK: return %[[SC]] : vector<8x8xf32> + + +// This is negative test case to ensure that such shape casts are not unrolled +// because it cannot determine the extractShape from source vector (8x3) +// to extract conitguous targetShape (2x4) +func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> { + %0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32> + return %0 : vector<6x4xf32> +} + +// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable +// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> { +// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32> +// CHECK: return %[[SC]] : vector<6x4xf32> + + +// TargetShape is [1x16] +func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> { + %0 = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32> + return %0 : vector<1x32xf32> +} + +// CHECK-LABEL: func @shape_cast_leading_unit_dim +// CHECK-SAME: (%[[V:.*]]: vector<32xf32>) -> vector<1x32xf32> { +// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32> +// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32> +// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<16xf32> to vector<1x16xf32> +// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32> +// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32> +// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32> +// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32> +// CHECK: return %[[I1]] : vector<1x32xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bbcda7..e8ea0cc02d7f 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -178,6 +178,28 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success(isa(op)); })); + populateVectorUnrollPatterns( + patterns, + UnrollVectorOptions() + .setNativeShapeFn( + [](Operation *op) -> std::optional> { + auto shapeCast = dyn_cast(op); + if (!shapeCast) + return std::nullopt; + + auto resultShape = shapeCast.getResultVectorType().getShape(); + // Special case with leading unit dims and different inner dim + // for result and target shape. + if (resultShape.size() == 2 && resultShape[0] == 1 && + resultShape[1] == 32) { + return SmallVector{1, 16}; + } + // Default case: [2,4] for all tests. + return SmallVector{2, 4}; + }) + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef{1, 3, 4, 2})