[mlir][vector] Add support for linearizing Extract, ExtractStridedSlice, Shuffle VectorOps in VectorLinearize (#88204)
This PR adds support for converting `vector.extract_strided_slice` and `vector.extract` operations to equivalent `vector.shuffle` operations that operates on linearized (1-D) vectors. `vector.shuffle` operations operating on n-D (n > 1) are also converted to equivalent shuffle operations working on linearized vectors.
This commit is contained in:
parent
44713f15f9
commit
c577f91d26
@ -389,6 +389,13 @@ void populateVectorLinearizeTypeConversionsAndLegality(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, unsigned targetBitWidth);
|
||||
|
||||
/// Populates patterns for linearizing ND (N >= 2) vector operations to 1D
|
||||
/// vector shuffle operations.
|
||||
void populateVectorLinearizeShuffleLikeOpsPatterns(TypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns,
|
||||
ConversionTarget &target,
|
||||
unsigned targetBitWidth);
|
||||
|
||||
} // namespace vector
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -13,9 +13,16 @@
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include <cstdint>
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
@ -103,6 +110,251 @@ public:
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned targetVectorBitWidth;
|
||||
};
|
||||
|
||||
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
|
||||
/// on a linearized vector.
|
||||
/// Following,
|
||||
/// vector.extract_strided_slice %source
|
||||
/// { offsets = [..], strides = [..], sizes = [..] }
|
||||
/// is converted to :
|
||||
/// %source_1d = vector.shape_cast %source
|
||||
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
|
||||
/// %out_nd = vector.shape_cast %out_1d
|
||||
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
|
||||
/// extraction.
|
||||
struct LinearizeVectorExtractStridedSlice final
|
||||
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LinearizeVectorExtractStridedSlice(
|
||||
const TypeConverter &typeConverter, MLIRContext *context,
|
||||
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
|
||||
PatternBenefit benefit = 1)
|
||||
: OpConversionPattern(typeConverter, context, benefit),
|
||||
targetVectorBitWidth(targetVectBitWidth) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type dstType = getTypeConverter()->convertType(extractOp.getType());
|
||||
assert(!(extractOp.getVector().getType().isScalable() ||
|
||||
dstType.cast<VectorType>().isScalable()) &&
|
||||
"scalable vectors are not supported.");
|
||||
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
|
||||
return rewriter.notifyMatchFailure(
|
||||
extractOp, "Can't flatten since targetBitWidth <= OpSize");
|
||||
|
||||
ArrayAttr offsets = extractOp.getOffsets();
|
||||
ArrayAttr sizes = extractOp.getSizes();
|
||||
ArrayAttr strides = extractOp.getStrides();
|
||||
if (!isConstantIntValue(strides[0], 1))
|
||||
return rewriter.notifyMatchFailure(
|
||||
extractOp, "Strided slice with stride != 1 is not supported.");
|
||||
Value srcVector = adaptor.getVector();
|
||||
// If kD offsets are specified for nD source vector (n > k), the granularity
|
||||
// of the extraction is greater than 1. In this case last (n-k) dimensions
|
||||
// form the extraction granularity.
|
||||
// Example :
|
||||
// vector.extract_strided_slice %src {
|
||||
// offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
|
||||
// vector<4x8x8xf32> to vector<2x2x8xf32>
|
||||
// Here, extraction granularity is 8.
|
||||
int64_t extractGranularitySize = 1;
|
||||
int64_t nD = extractOp.getSourceVectorType().getRank();
|
||||
int64_t kD = (int64_t)offsets.size();
|
||||
int64_t k = kD;
|
||||
while (k < nD) {
|
||||
extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
|
||||
++k;
|
||||
}
|
||||
// Get total number of extracted slices.
|
||||
int64_t nExtractedSlices = 1;
|
||||
for (Attribute size : sizes) {
|
||||
nExtractedSlices *= size.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
// Compute the strides of the source vector considering first k dimensions.
|
||||
llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
|
||||
for (int i = kD - 2; i >= 0; --i) {
|
||||
sourceStrides[i] = sourceStrides[i + 1] *
|
||||
extractOp.getSourceVectorType().getShape()[i + 1];
|
||||
}
|
||||
// Final shuffle indices has nExtractedSlices * extractGranularitySize
|
||||
// elements.
|
||||
llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
|
||||
extractGranularitySize);
|
||||
// Compute the strides of the extracted kD vector.
|
||||
llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
|
||||
// Compute extractedStrides.
|
||||
for (int i = kD - 2; i >= 0; --i) {
|
||||
extractedStrides[i] =
|
||||
extractedStrides[i + 1] * sizes[i + 1].cast<IntegerAttr>().getInt();
|
||||
}
|
||||
// Iterate over all extracted slices from 0 to nExtractedSlices - 1
|
||||
// and compute the multi-dimensional index and the corresponding linearized
|
||||
// index within the source vector.
|
||||
for (int64_t i = 0; i < nExtractedSlices; ++i) {
|
||||
int64_t index = i;
|
||||
// Compute the corresponding multi-dimensional index.
|
||||
llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
|
||||
for (int64_t j = 0; j < kD; ++j) {
|
||||
multiDimIndex[j] = (index / extractedStrides[j]);
|
||||
index -= multiDimIndex[j] * extractedStrides[j];
|
||||
}
|
||||
// Compute the corresponding linearized index in the source vector
|
||||
// i.e. shift the multiDimIndex by the offsets.
|
||||
int64_t linearizedIndex = 0;
|
||||
for (int64_t j = 0; j < kD; ++j) {
|
||||
linearizedIndex +=
|
||||
(offsets[j].cast<IntegerAttr>().getInt() + multiDimIndex[j]) *
|
||||
sourceStrides[j];
|
||||
}
|
||||
// Fill the indices array form linearizedIndex to linearizedIndex +
|
||||
// extractGranularitySize.
|
||||
for (int64_t j = 0; j < extractGranularitySize; ++j) {
|
||||
indices[i * extractGranularitySize + j] = linearizedIndex + j;
|
||||
}
|
||||
}
|
||||
// Perform a shuffle to extract the kD vector.
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
|
||||
extractOp, dstType, srcVector, srcVector,
|
||||
rewriter.getI64ArrayAttr(indices));
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned targetVectorBitWidth;
|
||||
};
|
||||
|
||||
/// This pattern converts the ShuffleOp that works on nD (n > 1)
|
||||
/// vectors to a ShuffleOp that works on linearized vectors.
|
||||
/// Following,
|
||||
/// vector.shuffle %v1, %v2 [ shuffle_indices ]
|
||||
/// is converted to :
|
||||
/// %v1_1d = vector.shape_cast %v1
|
||||
/// %v2_1d = vector.shape_cast %v2
|
||||
/// %out_1d = vector.shuffle %v1_1d, %v2_1d [ shuffle_indices_1d ]
|
||||
/// %out_nd = vector.shape_cast %out_1d
|
||||
// `shuffle_indices_1d` is computed using the sizes and `shuffle_indices`
|
||||
/// of the original shuffle operation.
|
||||
struct LinearizeVectorShuffle final
|
||||
: public OpConversionPattern<vector::ShuffleOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LinearizeVectorShuffle(
|
||||
const TypeConverter &typeConverter, MLIRContext *context,
|
||||
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
|
||||
PatternBenefit benefit = 1)
|
||||
: OpConversionPattern(typeConverter, context, benefit),
|
||||
targetVectorBitWidth(targetVectBitWidth) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
|
||||
assert(!(shuffleOp.getV1VectorType().isScalable() ||
|
||||
shuffleOp.getV2VectorType().isScalable() ||
|
||||
dstType.cast<VectorType>().isScalable()) &&
|
||||
"scalable vectors are not supported.");
|
||||
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
|
||||
return rewriter.notifyMatchFailure(
|
||||
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
|
||||
|
||||
Value vec1 = adaptor.getV1();
|
||||
Value vec2 = adaptor.getV2();
|
||||
int shuffleSliceLen = 1;
|
||||
int rank = shuffleOp.getV1().getType().getRank();
|
||||
|
||||
// If rank > 1, we need to do the shuffle in the granularity of slices
|
||||
// instead of scalars. Size of the slice is equal to the rank-1 innermost
|
||||
// dims. Mask of the shuffle op specifies which slice to take from the
|
||||
// outermost dim.
|
||||
if (rank > 1) {
|
||||
llvm::ArrayRef<int64_t> shape = shuffleOp.getV1().getType().getShape();
|
||||
for (unsigned i = 1; i < shape.size(); ++i) {
|
||||
shuffleSliceLen *= shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
// For each value in the mask, we generate the indices of the source vectors
|
||||
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
|
||||
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
|
||||
// elements) instead of scalars.
|
||||
ArrayAttr mask = shuffleOp.getMask();
|
||||
int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
|
||||
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
|
||||
for (auto [i, value] :
|
||||
llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
|
||||
|
||||
int64_t v = value.getZExtValue();
|
||||
std::iota(indices.begin() + shuffleSliceLen * i,
|
||||
indices.begin() + shuffleSliceLen * (i + 1),
|
||||
shuffleSliceLen * v);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
|
||||
shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned targetVectorBitWidth;
|
||||
};
|
||||
|
||||
/// This pattern converts the ExtractOp to a ShuffleOp that works on a
|
||||
/// linearized vector.
|
||||
/// Following,
|
||||
/// vector.extract %source [ position ]
|
||||
/// is converted to :
|
||||
/// %source_1d = vector.shape_cast %source
|
||||
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
|
||||
/// %out_nd = vector.shape_cast %out_1d
|
||||
/// `shuffle_indices_1d` is computed using the position of the original extract.
|
||||
struct LinearizeVectorExtract final
|
||||
: public OpConversionPattern<vector::ExtractOp> {
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LinearizeVectorExtract(
|
||||
const TypeConverter &typeConverter, MLIRContext *context,
|
||||
unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
|
||||
PatternBenefit benefit = 1)
|
||||
: OpConversionPattern(typeConverter, context, benefit),
|
||||
targetVectorBitWidth(targetVectBitWidth) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
|
||||
assert(!(extractOp.getVector().getType().isScalable() ||
|
||||
dstTy.cast<VectorType>().isScalable()) &&
|
||||
"scalable vectors are not supported.");
|
||||
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
|
||||
return rewriter.notifyMatchFailure(
|
||||
extractOp, "Can't flatten since targetBitWidth <= OpSize");
|
||||
|
||||
// Dynamic position is not supported.
|
||||
if (extractOp.hasDynamicPosition())
|
||||
return rewriter.notifyMatchFailure(extractOp,
|
||||
"dynamic position is not supported.");
|
||||
|
||||
llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
|
||||
int64_t size = extractOp.getVector().getType().getNumElements();
|
||||
|
||||
// Compute linearized offset.
|
||||
int64_t linearizedOffset = 0;
|
||||
llvm::ArrayRef<int64_t> offsets = extractOp.getStaticPosition();
|
||||
for (auto [i, off] : llvm::enumerate(offsets)) {
|
||||
size /= shape[i];
|
||||
linearizedOffset += offsets[i] * size;
|
||||
}
|
||||
|
||||
llvm::SmallVector<int64_t, 2> indices(size);
|
||||
std::iota(indices.begin(), indices.end(), linearizedOffset);
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
|
||||
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
|
||||
rewriter.getI64ArrayAttr(indices));
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned targetVectorBitWidth;
|
||||
};
|
||||
@ -145,3 +397,21 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
|
||||
patterns.add<LinearizeConstant, LinearizeVectorizable>(
|
||||
typeConverter, patterns.getContext(), targetBitWidth);
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
|
||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
ConversionTarget &target, unsigned int targetBitWidth) {
|
||||
target.addDynamicallyLegalOp<vector::ShuffleOp>(
|
||||
[=](vector::ShuffleOp shuffleOp) -> bool {
|
||||
return isLessThanTargetBitWidth(shuffleOp, targetBitWidth)
|
||||
? (typeConverter.isLegal(shuffleOp) &&
|
||||
shuffleOp.getResult()
|
||||
.getType()
|
||||
.cast<mlir::VectorType>()
|
||||
.getRank() == 1)
|
||||
: true;
|
||||
});
|
||||
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
|
||||
LinearizeVectorExtractStridedSlice>(
|
||||
typeConverter, patterns.getContext(), targetBitWidth);
|
||||
}
|
||||
|
@ -153,3 +153,95 @@ func.func @test_0d_vector() -> vector<f32> {
|
||||
// ALL: return %[[CST]]
|
||||
return %0 : vector<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// ALL-LABEL: test_extract_strided_slice_1
|
||||
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x8xf32>) -> vector<2x2xf32> {
|
||||
func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf32> {
|
||||
// DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
|
||||
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
|
||||
// DEFAULT-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
|
||||
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
|
||||
// DEFAULT: return %[[RES]] : vector<2x2xf32
|
||||
|
||||
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x8xf32> to vector<32xf32>
|
||||
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
|
||||
// BW-128-SAME: [4, 5, 12, 13] : vector<32xf32>, vector<32xf32>
|
||||
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<4xf32> to vector<2x2xf32>
|
||||
// BW-128: return %[[RES]] : vector<2x2xf32>
|
||||
|
||||
// BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ARG:.*]] {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
|
||||
// BW-0: return %[[RES]] : vector<2x2xf32>
|
||||
%0 = vector.extract_strided_slice %arg0 { sizes = [2, 2], strides = [1, 1], offsets = [0, 4]}
|
||||
: vector<4x8xf32> to vector<2x2xf32>
|
||||
return %0 : vector<2x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// ALL-LABEL: test_extract_strided_slice_2
|
||||
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<1x4x2xf32> {
|
||||
func.func @test_extract_strided_slice_2(%arg0 : vector<2x8x2xf32>) -> vector<1x4x2xf32> {
|
||||
// DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
|
||||
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
|
||||
// DEFAULT-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
|
||||
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
|
||||
// DEFAULT: return %[[RES]] : vector<1x4x2xf32>
|
||||
|
||||
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
|
||||
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
|
||||
// BW-128-SAME: [20, 21, 22, 23, 24, 25, 26, 27] : vector<32xf32>, vector<32xf32>
|
||||
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<8xf32> to vector<1x4x2xf32>
|
||||
// BW-128: return %[[RES]] : vector<1x4x2xf32>
|
||||
|
||||
// BW-0: %[[RES:.*]] = vector.extract_strided_slice %[[ORIG_ARG]] {offsets = [1, 2], sizes = [1, 4], strides = [1, 1]} : vector<2x8x2xf32> to vector<1x4x2xf32>
|
||||
// BW-0: return %[[RES]] : vector<1x4x2xf32>
|
||||
%0 = vector.extract_strided_slice %arg0 { offsets = [1, 2], strides = [1, 1], sizes = [1, 4] }
|
||||
: vector<2x8x2xf32> to vector<1x4x2xf32>
|
||||
return %0 : vector<1x4x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// ALL-LABEL: test_vector_shuffle
|
||||
// ALL-SAME: (%[[ORIG_ARG0:.*]]: vector<4x2xf32>, %[[ORIG_ARG1:.*]]: vector<4x2xf32>) -> vector<8x2xf32> {
|
||||
func.func @test_vector_shuffle(%arg0: vector<4x2xf32>, %arg1: vector<4x2xf32>) -> vector<8x2xf32> {
|
||||
// DEFAULT: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
|
||||
// DEFAULT: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
|
||||
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
|
||||
// DEFAULT-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
|
||||
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
|
||||
// DEFAULT: return %[[RES]] : vector<8x2xf32>
|
||||
|
||||
// BW-128: %[[ARG0:.*]] = vector.shape_cast %[[ORIG_ARG0]] : vector<4x2xf32> to vector<8xf32>
|
||||
// BW-128: %[[ARG1:.*]] = vector.shape_cast %[[ORIG_ARG1]] : vector<4x2xf32> to vector<8xf32>
|
||||
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG0]], %[[ARG1]]
|
||||
// BW-128-SAME: [0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15] : vector<8xf32>, vector<8xf32>
|
||||
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
|
||||
// BW-128: return %[[RES]] : vector<8x2xf32>
|
||||
|
||||
// BW-0: %[[RES:.*]] = vector.shuffle %[[ORIG_ARG0]], %[[ORIG_ARG1]] [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
|
||||
// BW-0: return %[[RES]] : vector<8x2xf32>
|
||||
%0 = vector.shuffle %arg0, %arg1 [0, 4, 1, 5, 2, 6, 3, 7] : vector<4x2xf32>, vector<4x2xf32>
|
||||
return %0 : vector<8x2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// ALL-LABEL: test_vector_extract
|
||||
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<2x8x2xf32>) -> vector<8x2xf32> {
|
||||
func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
|
||||
// DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
|
||||
// DEFAULT: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
|
||||
// DEFAULT-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
|
||||
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
|
||||
// DEFAULT: return %[[RES]] : vector<8x2xf32>
|
||||
|
||||
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8x2xf32> to vector<32xf32>
|
||||
// BW-128: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
|
||||
// BW-128-SAME: [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<32xf32>, vector<32xf32>
|
||||
// BW-128: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
|
||||
// BW-128: return %[[RES]] : vector<8x2xf32>
|
||||
|
||||
// BW-0: %[[RES:.*]] = vector.extract %[[ORIG_ARG]][1] : vector<8x2xf32> from vector<2x8x2xf32>
|
||||
// BW-0: return %[[RES]] : vector<8x2xf32>
|
||||
%0 = vector.extract %arg0[1]: vector<8x2xf32> from vector<2x8x2xf32>
|
||||
return %0 : vector<8x2xf32>
|
||||
}
|
||||
|
@ -867,6 +867,8 @@ struct TestVectorLinearize final
|
||||
|
||||
vector::populateVectorLinearizeTypeConversionsAndLegality(
|
||||
typeConverter, patterns, target, targetVectorBitwidth);
|
||||
vector::populateVectorLinearizeShuffleLikeOpsPatterns(
|
||||
typeConverter, patterns, target, targetVectorBitwidth);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
Loading…
x
Reference in New Issue
Block a user