[MLIR][Vector] Remove vector.splat (#162167)
vector.splat has been deprecated (user: please use the very similar vector.broadcast instead) with the last PR landing about 6 weeks ago. The discourse discussion is at https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/1 The last PR was #152230 This PR completely removes vector.splat. In addition to removing vector.splat from VectorOps.td, it - Updates the few remaining places where vector::SplatOp is created (now vector::BroadcastOp is created) - Removes temporary patterns where vector.splat is replaced by vector.broadcast The only place 'vector.splat' appears is now the files https://github.com/llvm/llvm-project/blob/main/mlir/utils/tree-sitter-mlir/test/corpus/op.txt and https://github.com/llvm/llvm-project/blob/main/mlir/utils/tree-sitter-mlir/dialect/vector.js --------- Signed-off-by: James Newling <james.newling@gmail.com>
This commit is contained in:
parent
6ed18d8525
commit
ea291d0e8c
@ -2383,7 +2383,7 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
|
||||
auto context{builder.getContext()};
|
||||
auto argBases{getBasesForArgs(args)};
|
||||
|
||||
mlir::vector::SplatOp splatOp{nullptr};
|
||||
mlir::vector::BroadcastOp splatOp{nullptr};
|
||||
mlir::Type retTy{nullptr};
|
||||
switch (vop) {
|
||||
case VecOp::Splat: {
|
||||
@ -2391,9 +2391,9 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
|
||||
auto vecTyInfo{getVecTypeFromFir(argBases[0])};
|
||||
|
||||
auto extractOp{genVecExtract(resultType, args)};
|
||||
splatOp =
|
||||
mlir::vector::SplatOp::create(builder, loc, *(extractOp.getUnboxed()),
|
||||
vecTyInfo.toMlirVectorType(context));
|
||||
splatOp = mlir::vector::BroadcastOp::create(
|
||||
builder, loc, vecTyInfo.toMlirVectorType(context),
|
||||
*(extractOp.getUnboxed()));
|
||||
retTy = vecTyInfo.toFirVectorType();
|
||||
break;
|
||||
}
|
||||
@ -2401,8 +2401,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
|
||||
assert(args.size() == 1);
|
||||
auto vecTyInfo{getVecTypeFromEle(argBases[0])};
|
||||
|
||||
splatOp = mlir::vector::SplatOp::create(
|
||||
builder, loc, argBases[0], vecTyInfo.toMlirVectorType(context));
|
||||
splatOp = mlir::vector::BroadcastOp::create(
|
||||
builder, loc, vecTyInfo.toMlirVectorType(context), argBases[0]);
|
||||
retTy = vecTyInfo.toFirVectorType();
|
||||
break;
|
||||
}
|
||||
@ -2412,8 +2412,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
|
||||
auto intOp{builder.createConvert(loc, eleTy, argBases[0])};
|
||||
|
||||
// the intrinsic always returns vector(integer(4))
|
||||
splatOp = mlir::vector::SplatOp::create(builder, loc, intOp,
|
||||
mlir::VectorType::get(4, eleTy));
|
||||
splatOp = mlir::vector::BroadcastOp::create(
|
||||
builder, loc, mlir::VectorType::get(4, eleTy), intOp);
|
||||
retTy = fir::VectorType::get(4, eleTy);
|
||||
break;
|
||||
}
|
||||
@ -2444,7 +2444,8 @@ PPCIntrinsicLibrary::genVecXlds(mlir::Type resultType,
|
||||
auto addrConv{fir::ConvertOp::create(builder, loc, i64RefTy, addr)};
|
||||
|
||||
auto addrVal{fir::LoadOp::create(builder, loc, addrConv)};
|
||||
auto splatRes{mlir::vector::SplatOp::create(builder, loc, addrVal, i64VecTy)};
|
||||
auto splatRes{
|
||||
mlir::vector::BroadcastOp::create(builder, loc, i64VecTy, addrVal)};
|
||||
|
||||
mlir::Value result{nullptr};
|
||||
if (mlirTy != splatRes.getType()) {
|
||||
|
||||
@ -125,7 +125,7 @@ Some existing Arith and Vector Dialect on `n-D` `vector` types comprise:
|
||||
// Produces a vector<3x7x8xf32>
|
||||
%b = arith.mulf %0, %1 : vector<3x7x8xf32>
|
||||
// Produces a vector<3x7x8xf32>
|
||||
%c = vector.splat %1 : vector<3x7x8xf32>
|
||||
%c = vector.broadcast %1 : f32 to vector<3x7x8xf32>
|
||||
|
||||
%d = vector.extract %0[1]: vector<7x8xf32> from vector<3x7x8xf32>
|
||||
%e = vector.extract %0[1, 5]: vector<8xf32> from vector<3x7x8xf32>
|
||||
@ -176,8 +176,6 @@ infrastructure can apply iteratively.
|
||||
### Virtual Vector to Hardware Vector Lowering
|
||||
|
||||
For now, `VV -> HWV` are specified in C++ (see for instance the
|
||||
[SplatOpLowering for n-D vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d)
|
||||
or the
|
||||
[VectorOuterProductOp lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)).
|
||||
|
||||
Simple
|
||||
|
||||
@ -2881,53 +2881,6 @@ def Vector_PrintOp :
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Vector_SplatOp : Vector_Op<"splat", [
|
||||
Pure,
|
||||
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
|
||||
TypesMatchWith<"operand type matches element type of result",
|
||||
"aggregate", "input",
|
||||
"::llvm::cast<VectorType>($_self).getElementType()">
|
||||
]> {
|
||||
let summary = "vector splat or broadcast operation";
|
||||
let description = [{
|
||||
Note: This operation is deprecated. Please use vector.broadcast.
|
||||
|
||||
Broadcast the operand to all elements of the result vector. The type of the
|
||||
operand must match the element type of the vector type.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%s = arith.constant 10.1 : f32
|
||||
%t = vector.splat %s : vector<8x16xf32>
|
||||
```
|
||||
|
||||
This operation is deprecated, the preferred representation of the above is:
|
||||
|
||||
```mlir
|
||||
%s = arith.constant 10.1 : f32
|
||||
%t = vector.broadcast %s : f32 to vector<8x16xf32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins AnyType:$input);
|
||||
let results = (outs AnyVectorOfAnyRank:$aggregate);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
|
||||
[{ build($_builder, $_state, aggregateType, element); }]>];
|
||||
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
|
||||
|
||||
let hasFolder = 1;
|
||||
|
||||
// vector.splat is deprecated, and vector.broadcast should be used instead.
|
||||
// Canonicalize vector.splat to vector.broadcast.
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorScaleOp
|
||||
|
||||
@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
|
||||
current = op.getSource();
|
||||
return false;
|
||||
})
|
||||
.Case<vector::SplatOp>([¤t](auto op) {
|
||||
current = op.getInput();
|
||||
return false;
|
||||
})
|
||||
.Default([](Operation *) { return false; });
|
||||
|
||||
if (!skipOp) {
|
||||
|
||||
@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
|
||||
/// AFTER:
|
||||
/// ```mlir
|
||||
/// ...
|
||||
/// %pad_1d = vector.splat %pad : vector<[4]xi32>
|
||||
/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
|
||||
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
|
||||
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
|
||||
/// ...
|
||||
|
||||
@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
|
||||
}
|
||||
};
|
||||
|
||||
// Convert all `vector.splat` to `vector.broadcast`. There is a path from
|
||||
// `vector.broadcast` to ArmSME via another pattern.
|
||||
struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
|
||||
using Base::Base;
|
||||
|
||||
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
|
||||
PatternRewriter &rewriter) const final {
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
|
||||
splatOp.getInput());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext &ctx) {
|
||||
patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
|
||||
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
|
||||
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
|
||||
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
|
||||
patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
|
||||
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
|
||||
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
|
||||
VectorOuterProductToArmSMELowering,
|
||||
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
|
||||
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
|
||||
ExtractFromCreateMaskToPselLowering>(&ctx);
|
||||
|
||||
@ -2161,19 +2161,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
|
||||
/// `vector.broadcast` through other patterns.
|
||||
struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
|
||||
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
|
||||
adaptor.getInput());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void mlir::vector::populateVectorRankReducingFMAPattern(
|
||||
@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
||||
VectorInsertOpConversion, VectorPrintOpConversion,
|
||||
VectorTypeCastOpConversion, VectorScaleOpConversion,
|
||||
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
|
||||
VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
|
||||
VectorBroadcastScalarToLowRankLowering,
|
||||
VectorBroadcastScalarToNdLowering,
|
||||
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
|
||||
MaskedReductionOpConversion, VectorInterleaveOpLowering,
|
||||
|
||||
@ -22,7 +22,6 @@
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
|
||||
}
|
||||
};
|
||||
|
||||
// Convert `vector.splat` to `vector.broadcast`. There is a path from
|
||||
// `vector.broadcast` to SPIRV via other patterns.
|
||||
struct VectorSplatToBroadcast final
|
||||
: public OpConversionPattern<vector::SplatOp> {
|
||||
using Base::Base;
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
|
||||
adaptor.getInput());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct VectorBitcastConvert final
|
||||
: public OpConversionPattern<vector::BitCastOp> {
|
||||
using Base::Base;
|
||||
@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
|
||||
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
|
||||
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
|
||||
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
|
||||
VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
|
||||
VectorShuffleOpConvert, VectorInterleaveOpConvert,
|
||||
VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
|
||||
VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
|
||||
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
|
||||
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
|
||||
VectorScalarBroadcastPattern, VectorLoadOpConverter,
|
||||
VectorStoreOpConverter, VectorStepOpConvert>(
|
||||
typeConverter, patterns.getContext(), PatternBenefit(1));
|
||||
|
||||
// Make sure that the more specialized dot product pattern has higher benefit
|
||||
|
||||
@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
|
||||
vector::OuterProductOp, vector::ScanOp>(
|
||||
[&](Operation *op) { return converter.isLegal(op); });
|
||||
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
|
||||
arith::ConstantOp, arith::SelectOp, vector::SplatOp,
|
||||
vector::BroadcastOp>();
|
||||
arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
|
||||
}
|
||||
|
||||
void EmulateUnsupportedFloatsPass::runOnOperation() {
|
||||
|
||||
@ -1665,10 +1665,10 @@ static bool hasZeroDimVectors(Operation *op) {
|
||||
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
|
||||
}
|
||||
|
||||
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
|
||||
/// 1s, are considered to be 'broadcastlike'.
|
||||
/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
|
||||
/// considered to be 'broadcastlike'.
|
||||
static bool isBroadcastLike(Operation *op) {
|
||||
if (isa<BroadcastOp, SplatOp>(op))
|
||||
if (isa<BroadcastOp>(op))
|
||||
return true;
|
||||
|
||||
auto shapeCast = dyn_cast<ShapeCastOp>(op);
|
||||
@ -3249,12 +3249,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
|
||||
};
|
||||
|
||||
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
|
||||
/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
|
||||
/// value that is splatted. Otherwise return null.
|
||||
/// vector.broadcast with a scalar operand, return the scalar value that is
|
||||
/// splatted. Otherwise return null.
|
||||
///
|
||||
/// Examples:
|
||||
/// Example:
|
||||
///
|
||||
/// scalar_source --> vector.splat --> value - return scalar_source
|
||||
/// scalar_source --> vector.broadcast --> value - return scalar_source
|
||||
static Value getScalarSplatSource(Value value) {
|
||||
// Block argument:
|
||||
@ -3262,10 +3261,6 @@ static Value getScalarSplatSource(Value value) {
|
||||
if (!defOp)
|
||||
return {};
|
||||
|
||||
// Splat:
|
||||
if (auto splat = dyn_cast<vector::SplatOp>(defOp))
|
||||
return splat.getInput();
|
||||
|
||||
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
|
||||
|
||||
// Not broadcast (and not splat):
|
||||
@ -7511,41 +7506,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
|
||||
patterns.getContext(), benefit);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplatOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
|
||||
auto constOperand = adaptor.getInput();
|
||||
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
|
||||
return {};
|
||||
|
||||
// SplatElementsAttr::get treats single value for second arg as being a splat.
|
||||
return SplatElementsAttr::get(getType(), {constOperand});
|
||||
}
|
||||
|
||||
// Canonicalizer for vector.splat. It always gets canonicalized to a
|
||||
// vector.broadcast.
|
||||
class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
|
||||
public:
|
||||
using Base::Base;
|
||||
LogicalResult matchAndRewrite(SplatOp splatOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
|
||||
splatOp.getOperand());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<SplatToBroadcastPattern>(context);
|
||||
}
|
||||
|
||||
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
|
||||
SetIntRangeFn setResultRanges) {
|
||||
setResultRanges(getResult(), argRanges.front());
|
||||
}
|
||||
|
||||
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
|
||||
CombiningKind kind, Value v1, Value acc,
|
||||
arith::FastMathFlagsAttr fastmath,
|
||||
|
||||
@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
|
||||
|
||||
Operation *maskOp = mask.getDefiningOp();
|
||||
SmallVector<vector::ExtractOp, 2> extractOps;
|
||||
// TODO: add support to `vector.splat`.
|
||||
// TODO: add support to `vector.broadcast`.
|
||||
// Finding the mask creation operation.
|
||||
while (maskOp &&
|
||||
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
|
||||
|
||||
@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
|
||||
}
|
||||
};
|
||||
|
||||
/// This pattern converts the SplatOp to work on a linearized vector.
|
||||
/// Following,
|
||||
/// vector.splat %value : vector<4x4xf32>
|
||||
/// is converted to:
|
||||
/// %out_1d = vector.splat %value : vector<16xf32>
|
||||
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
|
||||
struct LinearizeVectorSplat final
|
||||
: public OpConversionPattern<vector::SplatOp> {
|
||||
using Base::Base;
|
||||
|
||||
LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
|
||||
PatternBenefit benefit = 1)
|
||||
: OpConversionPattern(typeConverter, context, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto dstTy = getTypeConverter()->convertType(splatOp.getType());
|
||||
if (!dstTy)
|
||||
return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
|
||||
rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
|
||||
dstTy);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// This pattern converts the CreateMaskOp to work on a linearized vector.
|
||||
/// It currently supports only 2D masks with a unit outer dimension.
|
||||
/// Following,
|
||||
@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns
|
||||
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
|
||||
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
|
||||
LinearizeVectorStore, LinearizeVectorFromElements,
|
||||
LinearizeVectorToElements>(typeConverter, patterns.getContext());
|
||||
LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
|
||||
LinearizeVectorFromElements, LinearizeVectorToElements>(
|
||||
typeConverter, patterns.getContext());
|
||||
}
|
||||
|
||||
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
|
||||
|
||||
@ -878,7 +878,7 @@ struct BubbleUpBitCastForStridedSliceInsert
|
||||
// This transforms IR like:
|
||||
// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
|
||||
// Into:
|
||||
// %cst = vector.splat %c0_f32 : vector<4xf32>
|
||||
// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
|
||||
// %1 = vector.extract_strided_slice %0 {
|
||||
// offsets = [0], sizes = [4], strides = [1]
|
||||
// } : vector<8xf16> to vector<4xf16>
|
||||
@ -987,8 +987,8 @@ static Type cloneOrReplace(Type type, Type newElementType) {
|
||||
return newElementType;
|
||||
}
|
||||
|
||||
/// If `value` is the result of a splat or broadcast operation, return the input
|
||||
/// of the splat/broadcast operation.
|
||||
/// If `value` is the result of a broadcast operation, return the input
|
||||
/// of the broadcast operation.
|
||||
static Value getBroadcastLikeSource(Value value) {
|
||||
|
||||
Operation *op = value.getDefiningOp();
|
||||
@ -998,13 +998,10 @@ static Value getBroadcastLikeSource(Value value) {
|
||||
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
|
||||
return broadcast.getSource();
|
||||
|
||||
if (auto splat = dyn_cast<vector::SplatOp>(op))
|
||||
return splat.getInput();
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
|
||||
/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
@ -1017,9 +1014,6 @@ static Value getBroadcastLikeSource(Value value) {
|
||||
/// %r = arith.addi %arg0, %arg1 : index
|
||||
/// %b = vector.broadcast %r : index to vector<1x4xindex>
|
||||
/// ```
|
||||
///
|
||||
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
|
||||
/// ops.
|
||||
struct ReorderElementwiseOpsOnBroadcast final
|
||||
: public OpTraitRewritePattern<OpTrait::Elementwise> {
|
||||
using OpTraitRewritePattern::OpTraitRewritePattern;
|
||||
@ -1045,29 +1039,29 @@ struct ReorderElementwiseOpsOnBroadcast final
|
||||
Type resultElemType = resultType.getElementType();
|
||||
|
||||
// Get the type of the first non-constant operand
|
||||
Value splatSource;
|
||||
Value broadcastSource;
|
||||
for (Value operand : op->getOperands()) {
|
||||
Operation *definingOp = operand.getDefiningOp();
|
||||
if (!definingOp)
|
||||
return failure();
|
||||
if (definingOp->hasTrait<OpTrait::ConstantLike>())
|
||||
continue;
|
||||
splatSource = getBroadcastLikeSource(operand);
|
||||
broadcastSource = getBroadcastLikeSource(operand);
|
||||
break;
|
||||
}
|
||||
if (!splatSource)
|
||||
if (!broadcastSource)
|
||||
return failure();
|
||||
Type unbroadcastResultType =
|
||||
cloneOrReplace(splatSource.getType(), resultElemType);
|
||||
cloneOrReplace(broadcastSource.getType(), resultElemType);
|
||||
|
||||
// Make sure that all operands are broadcast from identically-shaped types:
|
||||
// * scalar (`vector.broadcast` + `vector.splat`), or
|
||||
// * scalar (`vector.broadcast`), or
|
||||
// * vector (`vector.broadcast`).
|
||||
// Otherwise the re-ordering wouldn't be safe.
|
||||
if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
|
||||
if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
|
||||
if (auto source = getBroadcastLikeSource(val))
|
||||
return haveSameShapeAndScaling(source.getType(),
|
||||
splatSource.getType());
|
||||
broadcastSource.getType());
|
||||
SplatElementsAttr splatConst;
|
||||
return matchPattern(val, m_Constant(&splatConst));
|
||||
})) {
|
||||
@ -1271,19 +1265,18 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
|
||||
/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
|
||||
///
|
||||
/// Example:
|
||||
/// ```
|
||||
/// %0 = vector.splat %arg2 : vector<1xf32>
|
||||
/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
|
||||
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
|
||||
/// ```
|
||||
/// Gets converted to:
|
||||
/// ```
|
||||
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
|
||||
/// ```
|
||||
class StoreOpFromSplatOrBroadcast final
|
||||
: public OpRewritePattern<vector::StoreOp> {
|
||||
class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
@ -1308,9 +1301,9 @@ public:
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "value to store is not from a broadcast");
|
||||
|
||||
// Checking for single use so we can remove splat.
|
||||
Operation *splat = toStore.getDefiningOp();
|
||||
if (!splat->hasOneUse())
|
||||
// Checking for single use so we can remove broadcast.
|
||||
Operation *broadcast = toStore.getDefiningOp();
|
||||
if (!broadcast->hasOneUse())
|
||||
return rewriter.notifyMatchFailure(op, "expected single op use");
|
||||
|
||||
Value base = op.getBase();
|
||||
@ -1321,7 +1314,7 @@ public:
|
||||
} else {
|
||||
rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
|
||||
}
|
||||
rewriter.eraseOp(splat);
|
||||
rewriter.eraseOp(broadcast);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -2391,8 +2384,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
|
||||
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
|
||||
PatternBenefit benefit) {
|
||||
// TODO: Consider converting these patterns to canonicalizations.
|
||||
patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
|
||||
patterns.getContext(), benefit);
|
||||
patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
|
||||
benefit);
|
||||
}
|
||||
|
||||
void mlir::vector::populateChainedVectorReductionFoldingPatterns(
|
||||
|
||||
@ -86,7 +86,7 @@ func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf
|
||||
// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
|
||||
// CHECK: spirv.ReturnValue %[[VAL]] : vector<4xf32>
|
||||
func.func @splat(%f : f32) -> vector<4xf32> {
|
||||
%splat = vector.splat %f : vector<4xf32>
|
||||
%splat = vector.broadcast %f : f32 to vector<4xf32>
|
||||
return %splat : vector<4xf32>
|
||||
}
|
||||
|
||||
|
||||
@ -429,38 +429,6 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
|
||||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.splat
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @splat_vec2d_from_i32(
|
||||
// CHECK-SAME: %[[SRC:.*]]: i32) {
|
||||
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
|
||||
// CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
|
||||
// CHECK: %[[VSCALE:.*]] = vector.vscale
|
||||
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
|
||||
// CHECK: scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
|
||||
// CHECK: arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
|
||||
func.func @splat_vec2d_from_i32(%arg0: i32) {
|
||||
%0 = vector.splat %arg0 : vector<[4]x[4]xi32>
|
||||
"prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @splat_vec2d_from_f16(
|
||||
// CHECK-SAME: %[[SRC:.*]]: f16) {
|
||||
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
|
||||
// CHECK: scf.for
|
||||
// CHECK: arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
|
||||
func.func @splat_vec2d_from_f16(%arg0: f16) {
|
||||
%0 = vector.splat %arg0 : vector<[8]x[8]xf16>
|
||||
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.transpose
|
||||
|
||||
@ -2216,23 +2216,6 @@ func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vecto
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.splat
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// vector.splat is converted to vector.broadcast. Then, vector.broadcast is converted to LLVM.
|
||||
// CHECK-LABEL: @splat_0d
|
||||
// CHECK-NOT: splat
|
||||
// CHECK: return
|
||||
func.func @splat_0d(%elt: f32) -> (vector<f32>, vector<4xf32>, vector<[4]xf32>) {
|
||||
%a = vector.splat %elt : vector<f32>
|
||||
%b = vector.splat %elt : vector<4xf32>
|
||||
%c = vector.splat %elt : vector<[4]xf32>
|
||||
return %a, %b, %c : vector<f32>, vector<4xf32>, vector<[4]xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.scalable_insert
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -105,9 +105,9 @@ func.func @ipowi32_fold(%result : memref<?xi32>) {
|
||||
|
||||
// --- Test vector folding ---
|
||||
%arg11_base = arith.constant 2 : i32
|
||||
%arg11_base_vec = vector.splat %arg11_base : vector<2x2xi32>
|
||||
%arg11_base_vec = vector.broadcast %arg11_base : i32 to vector<2x2xi32>
|
||||
%arg11_power = arith.constant 30 : i32
|
||||
%arg11_power_vec = vector.splat %arg11_power : vector<2x2xi32>
|
||||
%arg11_power_vec = vector.broadcast %arg11_power : i32 to vector<2x2xi32>
|
||||
%res11_vec = math.ipowi %arg11_base_vec, %arg11_power_vec : vector<2x2xi32>
|
||||
%i11 = arith.constant 11 : index
|
||||
%res11 = vector.extract %res11_vec[1, 1] : i32 from vector<2x2xi32>
|
||||
|
||||
@ -837,7 +837,7 @@ func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2
|
||||
// CHECK-LABEL: fold_extract_vector_from_splat
|
||||
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
|
||||
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
|
||||
%b = vector.splat %a : vector<1x2x4xf32>
|
||||
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
|
||||
return %r : vector<4xf32>
|
||||
}
|
||||
|
||||
@ -1,126 +0,0 @@
|
||||
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
|
||||
|
||||
// This file should be removed when vector.splat is removed.
|
||||
// This file tests canonicalization/folding with vector.splat.
|
||||
// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir
|
||||
|
||||
|
||||
// CHECK-LABEL: fold_extract_splat
|
||||
// CHECK-SAME: %[[A:.*]]: f32
|
||||
// CHECK: return %[[A]] : f32
|
||||
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
|
||||
%b = vector.splat %a : vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
|
||||
return %r : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: extract_strided_splat
|
||||
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
|
||||
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
|
||||
func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
|
||||
%0 = vector.splat %arg0 : vector<16x4xf16>
|
||||
%1 = vector.extract_strided_slice %0
|
||||
{offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
|
||||
vector<16x4xf16> to vector<2x4xf16>
|
||||
return %1 : vector<2x4xf16>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @splat_fold
|
||||
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
|
||||
// CHECK-NEXT: return [[V]] : vector<4xf32>
|
||||
func.func @splat_fold() -> vector<4xf32> {
|
||||
%c = arith.constant 1.0 : f32
|
||||
%v = vector.splat %c : vector<4xf32>
|
||||
return %v : vector<4xf32>
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @transpose_splat2(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
|
||||
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
|
||||
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
|
||||
func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
|
||||
%splat = vector.splat %arg : vector<4x3xf32>
|
||||
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
|
||||
return %0 : vector<3x4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @insert_strided_slice_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: f32)
|
||||
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
|
||||
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
|
||||
func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
|
||||
%splat0 = vector.splat %x : vector<4x4xf32>
|
||||
%splat1 = vector.splat %x : vector<8x16xf32>
|
||||
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
|
||||
: vector<4x4xf32> into vector<8x16xf32>
|
||||
return %0 : vector<8x16xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @shuffle_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i32)
|
||||
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
|
||||
// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
|
||||
func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
|
||||
%v0 = vector.splat %x : vector<4xi32>
|
||||
%v1 = vector.splat %x : vector<2xi32>
|
||||
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
|
||||
return %shuffle : vector<4xi32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @insert_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i32)
|
||||
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
|
||||
// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
|
||||
func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
|
||||
%v0 = vector.splat %x : vector<4x3xi32>
|
||||
%v1 = vector.splat %x : vector<2x4x3xi32>
|
||||
%insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
|
||||
return %insert : vector<2x4x3xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
|
||||
// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>)
|
||||
func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
|
||||
// Splat scalar to 0D and extract scalar.
|
||||
%0 = vector.splat %a : vector<f32>
|
||||
%1 = vector.extract %0[] : f32 from vector<f32>
|
||||
|
||||
// Broadcast scalar to 0D and extract scalar.
|
||||
%2 = vector.splat %a : vector<f32>
|
||||
%3 = vector.extract %2[] : f32 from vector<f32>
|
||||
|
||||
// Splat scalar to 2D and extract scalar.
|
||||
%6 = vector.splat %a : vector<2x3xf32>
|
||||
%7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
|
||||
|
||||
// Broadcast scalar to 3D and extract scalar.
|
||||
%8 = vector.splat %a : vector<5x6x7xf32>
|
||||
%9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
|
||||
|
||||
// Extract 2D from 3D that was broadcasted from a scalar.
|
||||
// CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
|
||||
%10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
|
||||
|
||||
// Extract 1D from 2D that was splat'ed from a scalar.
|
||||
// CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
|
||||
%11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
|
||||
|
||||
// CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
|
||||
return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
|
||||
}
|
||||
@ -28,7 +28,7 @@ func.func @float_constant_splat() -> vector<8xf32> {
|
||||
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
|
||||
func.func @vector_splat() -> vector<4xindex> {
|
||||
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
|
||||
%1 = vector.splat %0 : vector<4xindex>
|
||||
%1 = vector.broadcast %0 : index to vector<4xindex>
|
||||
%2 = test.reflect_bounds %1 : vector<4xindex>
|
||||
func.return %2 : vector<4xindex>
|
||||
}
|
||||
|
||||
@ -320,7 +320,7 @@ func.func @test_vector.transfer_write(%m: memref<1xi32>, %2: vector<1x32xi32>)
|
||||
func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%vf0 = vector.splat %f0 : vector<4x3xf32>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
|
||||
// expected-error@+1 {{ requires memref or ranked tensor type}}
|
||||
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
|
||||
}
|
||||
@ -330,7 +330,7 @@ func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
|
||||
func.func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%vf0 = vector.splat %f0 : vector<4x3xf32>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
|
||||
// expected-error@+1 {{ requires vector type}}
|
||||
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<4x3xf32>, f32
|
||||
}
|
||||
@ -414,7 +414,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%cst = arith.constant 3.0 : f32
|
||||
// expected-note@+1 {{prior use here}}
|
||||
%mask = vector.splat %c1 : vector<3x8x7xi1>
|
||||
%mask = vector.broadcast %c1 : i1 to vector<3x8x7xi1>
|
||||
// expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
|
||||
%0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
|
||||
}
|
||||
@ -424,7 +424,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
|
||||
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%vf0 = vector.splat %f0 : vector<4x3xf32>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
|
||||
// expected-error@+1 {{requires source vector element and vector result ranks to match}}
|
||||
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
|
||||
}
|
||||
@ -434,7 +434,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
|
||||
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%vf0 = vector.splat %f0 : vector<6xf32>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<6xf32>
|
||||
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
|
||||
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
|
||||
}
|
||||
@ -444,7 +444,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
|
||||
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%vf0 = vector.splat %f0 : vector<2x3xf32>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
|
||||
// expected-error@+1 {{ expects the in_bounds attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
|
||||
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
|
||||
}
|
||||
@ -454,8 +454,8 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
|
||||
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%vf0 = vector.splat %f0 : vector<2x3xf32>
|
||||
%mask = vector.splat %c1 : vector<2x3xi1>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
|
||||
%mask = vector.broadcast %c1 : f32 to vector<2x3xi1>
|
||||
// expected-error@+1 {{does not support masks with vector element type}}
|
||||
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
|
||||
}
|
||||
@ -492,7 +492,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
|
||||
func.func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%vf0 = vector.splat %f0 : vector<4x3xf32>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
|
||||
// expected-error@+1 {{ requires vector type}}
|
||||
vector.transfer_write %arg0, %arg0[%c3, %c3] : memref<vector<4x3xf32>>, vector<4x3xf32>
|
||||
}
|
||||
@ -502,7 +502,7 @@ func.func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
|
||||
func.func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
|
||||
%c3 = arith.constant 3 : index
|
||||
%f0 = arith.constant 0.0 : f32
|
||||
%vf0 = vector.splat %f0 : vector<4x3xf32>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
|
||||
// expected-error@+1 {{ requires memref or ranked tensor type}}
|
||||
vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
|
||||
}
|
||||
@ -1980,29 +1980,6 @@ func.func @invalid_step_2d() {
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.splat
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// -----
|
||||
|
||||
func.func @vector_splat_invalid_result(%v : f32) {
|
||||
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'memref<8xf32>'}}
|
||||
vector.splat %v : memref<8xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-note @+1 {{prior use here}}
|
||||
func.func @vector_splat_type_mismatch(%a: f32) {
|
||||
// expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
|
||||
%0 = vector.splat %a : vector<1xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// vector.load
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -428,33 +428,6 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: linearize_vector_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
|
||||
func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
|
||||
|
||||
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
|
||||
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
|
||||
// CHECK: return %[[CAST]] : vector<4x2xi32>
|
||||
%0 = vector.splat %arg0 : vector<4x2xi32>
|
||||
return %0 : vector<4x2xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: linearize_scalable_vector_splat
|
||||
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
|
||||
func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
|
||||
|
||||
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
|
||||
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
|
||||
// CHECK: return %[[CAST]] : vector<4x[2]xi32>
|
||||
%0 = vector.splat %arg0 : vector<4x[2]xi32>
|
||||
return %0 : vector<4x[2]xi32>
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: linearize_create_mask
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
|
||||
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {
|
||||
|
||||
@ -45,11 +45,11 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
|
||||
%i0 = arith.constant 0 : index
|
||||
%i1 = arith.constant 1 : i1
|
||||
|
||||
%vf0 = vector.splat %f0 : vector<4x3xf32>
|
||||
%v0 = vector.splat %c0 : vector<4x3xi32>
|
||||
%vi0 = vector.splat %i0 : vector<4x3xindex>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
|
||||
%v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
|
||||
%vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
|
||||
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
|
||||
%m2 = vector.splat %i1 : vector<4x5xi1>
|
||||
%m2 = vector.broadcast %i1 : i1 to vector<4x5xi1>
|
||||
//
|
||||
// CHECK: vector.transfer_read
|
||||
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
|
||||
@ -106,9 +106,9 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
|
||||
%c0 = arith.constant 0 : i32
|
||||
%i0 = arith.constant 0 : index
|
||||
|
||||
%vf0 = vector.splat %f0 : vector<4x3xf32>
|
||||
%v0 = vector.splat %c0 : vector<4x3xi32>
|
||||
%vi0 = vector.splat %i0 : vector<4x3xindex>
|
||||
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
|
||||
%v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
|
||||
%vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
|
||||
|
||||
//
|
||||
// CHECK: vector.transfer_read
|
||||
@ -922,28 +922,6 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
|
||||
return %2#0 : vector<4x8x16x32xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_splat_op
|
||||
// CHECK-SAME: %[[s:.*]]: f32, %[[s2:.*]]: !llvm.ptr<1>
|
||||
func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
|
||||
// CHECK: vector.splat %[[s]] : vector<8xf32>
|
||||
%v = vector.splat %s : vector<8xf32>
|
||||
|
||||
// CHECK: vector.splat %[[s]] : vector<4xf32>
|
||||
%u = "vector.splat"(%s) : (f32) -> vector<4xf32>
|
||||
|
||||
// CHECK: vector.splat %[[s2]] : vector<16x!llvm.ptr<1>>
|
||||
%w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @vector_splat_0d(
|
||||
func.func @vector_splat_0d(%a: f32) -> vector<f32> {
|
||||
// CHECK: vector.splat %{{.*}} : vector<f32>
|
||||
%0 = vector.splat %a : vector<f32>
|
||||
return %0 : vector<f32>
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @vector_mask
|
||||
func.func @vector_mask(%a: vector<8xi32>, %m0: vector<8xi1>) -> i32 {
|
||||
// CHECK-NEXT: %{{.*}} = vector.mask %{{.*}} { vector.reduction <add>, %{{.*}} : vector<8xi32> into i32 } : vector<8xi1> -> i32
|
||||
|
||||
@ -49,7 +49,7 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
|
||||
%idx_4 = arith.constant 4 : index
|
||||
%mask = vector.create_mask %idx_1 : vector<4xi1>
|
||||
%s = arith.constant 0.0 : f32
|
||||
%pass_thru = vector.splat %s : vector<4xf32>
|
||||
%pass_thru = vector.broadcast %s : f32 to vector<4xf32>
|
||||
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
|
||||
return %0: vector<4xf32>
|
||||
}
|
||||
@ -65,7 +65,7 @@ func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4
|
||||
%idx_4 = arith.constant 4 : index
|
||||
%mask = vector.create_mask %idx_1 : vector<4xi1>
|
||||
%s = arith.constant 0.0 : f32
|
||||
%pass_thru = vector.splat %s : vector<4xf32>
|
||||
%pass_thru = vector.broadcast %s : f32 to vector<4xf32>
|
||||
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
|
||||
return %0: vector<4xf32>
|
||||
}
|
||||
|
||||
@ -107,7 +107,7 @@ func.func @return_not_in_function() {
|
||||
// -----
|
||||
|
||||
func.func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
|
||||
vector.splat %v : vector<8xf64>
|
||||
vector.broadcast %v : f64 to vector<8xf64>
|
||||
// expected-error@-1 {{expects different type than prior uses}}
|
||||
return
|
||||
}
|
||||
|
||||
@ -21,13 +21,6 @@ func.func @print_vector_0d(%a: vector<f32>) {
|
||||
return
|
||||
}
|
||||
|
||||
func.func @splat_0d(%a: f32) {
|
||||
%1 = vector.splat %a : vector<f32>
|
||||
// CHECK: ( 42 )
|
||||
vector.print %1: vector<f32>
|
||||
return
|
||||
}
|
||||
|
||||
func.func @broadcast_0d(%a: f32) {
|
||||
%1 = vector.broadcast %a : f32 to vector<f32>
|
||||
// CHECK: ( 42 )
|
||||
|
||||
@ -56,7 +56,7 @@ func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interf
|
||||
func.func @vector_splat_2d() {
|
||||
%c0 = arith.constant 0 : index
|
||||
%f10 = arith.constant 10.0 : f32
|
||||
%vf10 = vector.splat %f10: !vector_type_C
|
||||
%vf10 = vector.broadcast %f10: f32 to !vector_type_C
|
||||
%C = memref.alloc() : !matrix_type_CC
|
||||
memref.store %vf10, %C[%c0, %c0]: !matrix_type_CC
|
||||
|
||||
|
||||
@ -181,7 +181,6 @@
|
||||
"vector.insert_strided_slice"
|
||||
"vector.matrix_multiply"
|
||||
"vector.print"
|
||||
"vector.splat"
|
||||
"vector.transfer_read"
|
||||
"vector.transfer_write"
|
||||
"vector.yield"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user