From cfda27d0fbda65de4a7f482d46231933c9f2c678 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Thu, 20 Nov 2025 13:39:52 +0000 Subject: [PATCH] [mlir][Vector] Add support for scalable vectors to `ScanToArithOps` (#123117) Note, scalable reductions dims are left as a TODO. --- .../Vector/Transforms/LowerVectorScan.cpp | 14 ++- .../Vector/vector-scan-transforms.mlir | 94 ++++++++++++++++++- 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index 258f2cbc7773..1af552362a26 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -111,7 +111,7 @@ struct ScanToArithOps : public OpRewritePattern { if (!isValidKind(isInt, scanOp.getKind())) return failure(); - VectorType resType = VectorType::get(destShape, elType); + VectorType resType = destType; Value result = arith::ConstantOp::create(rewriter, loc, resType, rewriter.getZeroAttr(resType)); int64_t reductionDim = scanOp.getReductionDim(); @@ -121,8 +121,18 @@ struct ScanToArithOps : public OpRewritePattern { int64_t initialValueRank = initialValueType.getRank(); SmallVector reductionShape(destShape); + SmallVector reductionScalableDims(destType.getScalableDims()); + + if (reductionScalableDims[reductionDim]) + return rewriter.notifyMatchFailure( + scanOp, "Trying to reduce scalable dimension - not yet supported!"); + + // The reduction dimension, after reducing, becomes 1. It's a fixed-width + // dimension - no need to touch the scalability flag. 
reductionShape[reductionDim] = 1; - VectorType reductionType = VectorType::get(reductionShape, elType); + VectorType reductionType = + VectorType::get(reductionShape, elType, reductionScalableDims); + SmallVector offsets(destRank, 0); SmallVector strides(destRank, 1); SmallVector sizes(destShape); diff --git a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir index 1d8f440e0fb0..27a36538095e 100644 --- a/mlir/test/Dialect/Vector/vector-scan-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-scan-transforms.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s --test-vector-scan-lowering | FileCheck %s +// RUN: mlir-opt %s -split-input-file --test-vector-scan-lowering | FileCheck %s // CHECK-LABEL: func @scan1d_inc // CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>, @@ -18,6 +18,20 @@ func.func @scan1d_inc(%arg0 : vector<2xi32>, %arg1 : vector) -> (vector<2xi return %0#0, %0#1 : vector<2xi32>, vector } +// ----- + +// Reducing scalable dims is not yet supported! + +// CHECK-LABEL: func @scan1d_inc_scalable +// CHECK: vector.scan +func.func @scan1d_inc_scalable(%arg0 : vector<[2]xi32>, %arg1 : vector) -> (vector<[2]xi32>, vector) { + %0:2 = vector.scan , %arg0, %arg1 {inclusive = true, reduction_dim = 0} : + vector<[2]xi32>, vector + return %0#0, %0#1 : vector<[2]xi32>, vector +} + +// ----- + // CHECK-LABEL: func @scan1d_exc // CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector @@ -36,6 +50,20 @@ func.func @scan1d_exc(%arg0 : vector<2xi32>, %arg1 : vector) -> (vector<2xi return %0#0, %0#1 : vector<2xi32>, vector } +// ----- + +// Reducing scalable dims is not yet supported! 
+ +// CHECK-LABEL: func @scan1d_exc_scalable +// CHECK: vector.scan +func.func @scan1d_exc_scalable(%arg0 : vector<[2]xi32>, %arg1 : vector) -> (vector<[2]xi32>, vector) { + %0:2 = vector.scan , %arg0, %arg1 {inclusive = false, reduction_dim = 0} : + vector<[2]xi32>, vector + return %0#0, %0#1 : vector<[2]xi32>, vector +} + +// ----- + // CHECK-LABEL: func @scan2d_mul_dim0 // CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<3xi32> @@ -53,6 +81,27 @@ func.func @scan2d_mul_dim0(%arg0 : vector<2x3xi32>, %arg1 : vector<3xi32>) -> (v return %0#0, %0#1 : vector<2x3xi32>, vector<3xi32> } +// ----- + +// CHECK-LABEL: func @scan2d_mul_dim0_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<2x[3]xi32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<[3]xi32> +// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2x[3]xi32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32> +// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32> +// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32> +// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<1x[3]xi32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32> +// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<1x[3]xi32> to vector<[3]xi32> +// CHECK: return %[[F]], %[[G]] : vector<2x[3]xi32>, vector<[3]xi32> +func.func @scan2d_mul_dim0_scalable(%arg0 : vector<2x[3]xi32>, %arg1 : vector<[3]xi32>) -> (vector<2x[3]xi32>, vector<[3]xi32>) { + %0:2 = vector.scan , %arg0, %arg1 {inclusive = true, reduction_dim = 0} : + vector<2x[3]xi32>, vector<[3]xi32> + return %0#0, %0#1 : vector<2x[3]xi32>, vector<[3]xi32> +} + +// ----- + // 
CHECK-LABEL: func @scan2d_mul_dim1 // CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>, // CHECK-SAME: %[[ARG1:.*]]: vector<2xi32> @@ -73,6 +122,30 @@ func.func @scan2d_mul_dim1(%arg0 : vector<2x3xi32>, %arg1 : vector<2xi32>) -> (v return %0#0, %0#1 : vector<2x3xi32>, vector<2xi32> } +// ----- + +// CHECK-LABEL: func @scan2d_mul_dim1_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<[2]x3xi32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<[2]xi32> +// CHECK: %[[A:.*]] = arith.constant dense<0> : vector<[2]x3xi32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32> +// CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<[2]x1xi32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [0, 1], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32> +// CHECK: %[[G:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32> +// CHECK: %[[H:.*]] = arith.muli %[[E]], %[[G]] : vector<[2]x1xi32> +// CHECK: %[[I:.*]] = vector.insert_strided_slice %[[H]], %[[F]] {offsets = [0, 2], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32> +// CHECK: %[[J:.*]] = vector.shape_cast %[[H]] : vector<[2]x1xi32> to vector<[2]xi32> +// CHECK: return %[[I]], %[[J]] : vector<[2]x3xi32>, vector<[2]xi32> +func.func @scan2d_mul_dim1_scalable(%arg0 : vector<[2]x3xi32>, %arg1 : vector<[2]xi32>) -> (vector<[2]x3xi32>, vector<[2]xi32>) { + %0:2 = vector.scan , %arg0, %arg1 {inclusive = true, reduction_dim = 1} : + vector<[2]x3xi32>, vector<[2]xi32> + return %0#0, %0#1 : vector<[2]x3xi32>, 
vector<[2]xi32> +} + +// ----- + // CHECK-LABEL: func @scan3d_mul_dim1 // CHECK-SAME: %[[ARG0:.*]]: vector<4x2x3xf32>, // CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32> @@ -89,3 +162,22 @@ func.func @scan3d_mul_dim1(%arg0 : vector<4x2x3xf32>, %arg1 : vector<4x3xf32>) - vector<4x2x3xf32>, vector<4x3xf32> return %0#0, %0#1 : vector<4x2x3xf32>, vector<4x3xf32> } + +// ----- + +// CHECK-LABEL: func @scan3d_mul_dim1_scalable +// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x[3]xf32>, +// CHECK-SAME: %[[ARG1:.*]]: vector<4x[3]xf32> +// CHECK: %[[A:.*]] = arith.constant dense<0.000000e+00> : vector<4x2x[3]xf32> +// CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x[3]xf32> to vector<4x1x[3]xf32> +// CHECK: %[[C:.*]] = vector.shape_cast %[[ARG1]] : vector<4x[3]xf32> to vector<4x1x[3]xf32> +// CHECK: %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32> +// CHECK: %[[E:.*]] = arith.mulf %[[C]], %[[B]] : vector<4x1x[3]xf32> +// CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32> +// CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<4x1x[3]xf32> to vector<4x[3]xf32> +// CHECK: return %[[F]], %[[G]] : vector<4x2x[3]xf32>, vector<4x[3]xf32> +func.func @scan3d_mul_dim1_scalable(%arg0 : vector<4x2x[3]xf32>, %arg1 : vector<4x[3]xf32>) -> (vector<4x2x[3]xf32>, vector<4x[3]xf32>) { + %0:2 = vector.scan , %arg0, %arg1 {inclusive = false, reduction_dim = 1} : + vector<4x2x[3]xf32>, vector<4x[3]xf32> + return %0#0, %0#1 : vector<4x2x[3]xf32>, vector<4x[3]xf32> +}