[MLIR][Vector] Remove vector.splat (#162167)

vector.splat has been deprecated (user: please use the very similar vector.broadcast instead) 
with the last PR landing about 6 weeks ago.

The discourse discussion is at
https://discourse.llvm.org/t/rfc-mlir-vector-deprecate-then-remove-vector-splat/87143/1
The last PR was #152230

This PR completely removes vector.splat. In addition to removing vector.splat from VectorOps.td, it

- Updates the few remaining places where vector::SplatOp is created (now vector::BroadcastOp is created)
- Removes temporary patterns where vector.splat is replaced by vector.broadcast

The only place 'vector.splat' appears is now the files

https://github.com/llvm/llvm-project/blob/main/mlir/utils/tree-sitter-mlir/test/corpus/op.txt
 and

https://github.com/llvm/llvm-project/blob/main/mlir/utils/tree-sitter-mlir/dialect/vector.js

---------

Signed-off-by: James Newling <james.newling@gmail.com>
This commit is contained in:
James Newling 2025-10-10 09:58:18 -07:00 committed by GitHub
parent 6ed18d8525
commit ea291d0e8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 78 additions and 501 deletions

View File

@ -2383,7 +2383,7 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
auto context{builder.getContext()};
auto argBases{getBasesForArgs(args)};
mlir::vector::SplatOp splatOp{nullptr};
mlir::vector::BroadcastOp splatOp{nullptr};
mlir::Type retTy{nullptr};
switch (vop) {
case VecOp::Splat: {
@ -2391,9 +2391,9 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
auto vecTyInfo{getVecTypeFromFir(argBases[0])};
auto extractOp{genVecExtract(resultType, args)};
splatOp =
mlir::vector::SplatOp::create(builder, loc, *(extractOp.getUnboxed()),
vecTyInfo.toMlirVectorType(context));
splatOp = mlir::vector::BroadcastOp::create(
builder, loc, vecTyInfo.toMlirVectorType(context),
*(extractOp.getUnboxed()));
retTy = vecTyInfo.toFirVectorType();
break;
}
@ -2401,8 +2401,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
assert(args.size() == 1);
auto vecTyInfo{getVecTypeFromEle(argBases[0])};
splatOp = mlir::vector::SplatOp::create(
builder, loc, argBases[0], vecTyInfo.toMlirVectorType(context));
splatOp = mlir::vector::BroadcastOp::create(
builder, loc, vecTyInfo.toMlirVectorType(context), argBases[0]);
retTy = vecTyInfo.toFirVectorType();
break;
}
@ -2412,8 +2412,8 @@ PPCIntrinsicLibrary::genVecSplat(mlir::Type resultType,
auto intOp{builder.createConvert(loc, eleTy, argBases[0])};
// the intrinsic always returns vector(integer(4))
splatOp = mlir::vector::SplatOp::create(builder, loc, intOp,
mlir::VectorType::get(4, eleTy));
splatOp = mlir::vector::BroadcastOp::create(
builder, loc, mlir::VectorType::get(4, eleTy), intOp);
retTy = fir::VectorType::get(4, eleTy);
break;
}
@ -2444,7 +2444,8 @@ PPCIntrinsicLibrary::genVecXlds(mlir::Type resultType,
auto addrConv{fir::ConvertOp::create(builder, loc, i64RefTy, addr)};
auto addrVal{fir::LoadOp::create(builder, loc, addrConv)};
auto splatRes{mlir::vector::SplatOp::create(builder, loc, addrVal, i64VecTy)};
auto splatRes{
mlir::vector::BroadcastOp::create(builder, loc, i64VecTy, addrVal)};
mlir::Value result{nullptr};
if (mlirTy != splatRes.getType()) {

View File

@ -125,7 +125,7 @@ Some existing Arith and Vector Dialect on `n-D` `vector` types comprise:
// Produces a vector<3x7x8xf32>
%b = arith.mulf %0, %1 : vector<3x7x8xf32>
// Produces a vector<3x7x8xf32>
%c = vector.splat %1 : vector<3x7x8xf32>
%c = vector.broadcast %1 : f32 to vector<3x7x8xf32>
%d = vector.extract %0[1]: vector<7x8xf32> from vector<3x7x8xf32>
%e = vector.extract %0[1, 5]: vector<8xf32> from vector<3x7x8xf32>
@ -176,8 +176,6 @@ infrastructure can apply iteratively.
### Virtual Vector to Hardware Vector Lowering
For now, `VV -> HWV` are specified in C++ (see for instance the
[SplatOpLowering for n-D vectors](https://github.com/tensorflow/mlir/commit/0a0c4867c6a6fcb0a2f17ef26a791c1d551fe33d)
or the
[VectorOuterProductOp lowering](https://github.com/tensorflow/mlir/commit/957b1ca9680b4aacabb3a480fbc4ebd2506334b8)).
Simple

View File

@ -2881,53 +2881,6 @@ def Vector_PrintOp :
}];
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
def Vector_SplatOp : Vector_Op<"splat", [
Pure,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
TypesMatchWith<"operand type matches element type of result",
"aggregate", "input",
"::llvm::cast<VectorType>($_self).getElementType()">
]> {
let summary = "vector splat or broadcast operation";
let description = [{
Note: This operation is deprecated. Please use vector.broadcast.
Broadcast the operand to all elements of the result vector. The type of the
operand must match the element type of the vector type.
Example:
```mlir
%s = arith.constant 10.1 : f32
%t = vector.splat %s : vector<8x16xf32>
```
This operation is deprecated, the preferred representation of the above is:
```mlir
%s = arith.constant 10.1 : f32
%t = vector.broadcast %s : f32 to vector<8x16xf32>
```
}];
let arguments = (ins AnyType:$input);
let results = (outs AnyVectorOfAnyRank:$aggregate);
let builders = [
OpBuilder<(ins "Value":$element, "Type":$aggregateType),
[{ build($_builder, $_state, aggregateType, element); }]>];
let assemblyFormat = "$input attr-dict `:` type($aggregate)";
let hasFolder = 1;
// vector.splat is deprecated, and vector.broadcast should be used instead.
// Canonicalize vector.splat to vector.broadcast.
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// VectorScaleOp

View File

@ -432,10 +432,6 @@ static Value getOriginalVectorValue(Value value) {
current = op.getSource();
return false;
})
.Case<vector::SplatOp>([&current](auto op) {
current = op.getInput();
return false;
})
.Default([](Operation *) { return false; });
if (!skipOp) {

View File

@ -236,7 +236,7 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
/// AFTER:
/// ```mlir
/// ...
/// %pad_1d = vector.splat %pad : vector<[4]xi32>
/// %pad_1d = vector.broadcast %pad : i32 to vector<[4]xi32>
/// %tile = scf.for %tile_slice_idx = %c0 to %svl_s step %c1
/// iter_args(%iter_tile = %init_tile) -> (vector<[4]x[4]xi32>) {
/// ...

View File

@ -731,28 +731,14 @@ struct ExtractFromCreateMaskToPselLowering
}
};
// Convert all `vector.splat` to `vector.broadcast`. There is a path from
// `vector.broadcast` to ArmSME via another pattern.
struct ConvertSplatToBroadcast : public OpRewritePattern<vector::SplatOp> {
using Base::Base;
LogicalResult matchAndRewrite(vector::SplatOp splatOp,
PatternRewriter &rewriter) const final {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
splatOp.getInput());
return success();
}
};
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
patterns.add<BroadcastOpToArmSMELowering, ConvertSplatToBroadcast,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
patterns.add<BroadcastOpToArmSMELowering, TransferReadToArmSMELowering,
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
VectorOuterProductToArmSMELowering,
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice,
ExtractFromCreateMaskToPselLowering>(&ctx);

View File

@ -2161,19 +2161,6 @@ public:
}
};
/// Convert `vector.splat` to `vector.broadcast`. There is a path to LLVM from
/// `vector.broadcast` through other patterns.
struct VectorSplatToBroadcast : public ConvertOpToLLVMPattern<vector::SplatOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
LogicalResult
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
adaptor.getInput());
return success();
}
};
} // namespace
void mlir::vector::populateVectorRankReducingFMAPattern(
@ -2212,7 +2199,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
VectorSplatToBroadcast, VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToLowRankLowering,
VectorBroadcastScalarToNdLowering,
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,

View File

@ -22,7 +22,6 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@ -79,20 +78,6 @@ struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
}
};
// Convert `vector.splat` to `vector.broadcast`. There is a path from
// `vector.broadcast` to SPIRV via other patterns.
struct VectorSplatToBroadcast final
: public OpConversionPattern<vector::SplatOp> {
using Base::Base;
LogicalResult
matchAndRewrite(vector::SplatOp splat, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splat, splat.getType(),
adaptor.getInput());
return success();
}
};
struct VectorBitcastConvert final
: public OpConversionPattern<vector::BitCastOp> {
using Base::Base;
@ -1092,10 +1077,10 @@ void mlir::populateVectorToSPIRVPatterns(
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorSplatToBroadcast, VectorInsertStridedSliceOpConvert,
VectorShuffleOpConvert, VectorInterleaveOpConvert,
VectorDeinterleaveOpConvert, VectorScalarBroadcastPattern,
VectorLoadOpConverter, VectorStoreOpConverter, VectorStepOpConvert>(
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
VectorScalarBroadcastPattern, VectorLoadOpConverter,
VectorStoreOpConverter, VectorStepOpConvert>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit

View File

@ -123,8 +123,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
arith::ConstantOp, arith::SelectOp, vector::SplatOp,
vector::BroadcastOp>();
arith::ConstantOp, arith::SelectOp, vector::BroadcastOp>();
}
void EmulateUnsupportedFloatsPass::runOnOperation() {

View File

@ -1665,10 +1665,10 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
/// 1s, are considered to be 'broadcastlike'.
/// All BroadcastOps, as well as ShapeCastOps that only prepend 1s, are
/// considered to be 'broadcastlike'.
static bool isBroadcastLike(Operation *op) {
if (isa<BroadcastOp, SplatOp>(op))
if (isa<BroadcastOp>(op))
return true;
auto shapeCast = dyn_cast<ShapeCastOp>(op);
@ -3249,12 +3249,11 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
};
/// Consider the defining operation `defOp` of `value`. If `defOp` is a
/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
/// value that is splatted. Otherwise return null.
/// vector.broadcast with a scalar operand, return the scalar value that is
/// splatted. Otherwise return null.
///
/// Examples:
/// Example:
///
/// scalar_source --> vector.splat --> value - return scalar_source
/// scalar_source --> vector.broadcast --> value - return scalar_source
static Value getScalarSplatSource(Value value) {
// Block argument:
@ -3262,10 +3261,6 @@ static Value getScalarSplatSource(Value value) {
if (!defOp)
return {};
// Splat:
if (auto splat = dyn_cast<vector::SplatOp>(defOp))
return splat.getInput();
auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
// Not broadcast (and not splat):
@ -7511,41 +7506,6 @@ void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
patterns.getContext(), benefit);
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getInput();
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
return {};
// SplatElementsAttr::get treats single value for second arg as being a splat.
return SplatElementsAttr::get(getType(), {constOperand});
}
// Canonicalizer for vector.splat. It always gets canonicalized to a
// vector.broadcast.
class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
public:
using Base::Base;
LogicalResult matchAndRewrite(SplatOp splatOp,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
splatOp.getOperand());
return success();
}
};
void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<SplatToBroadcastPattern>(context);
}
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
CombiningKind kind, Value v1, Value acc,
arith::FastMathFlagsAttr fastmath,

View File

@ -90,7 +90,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Operation *maskOp = mask.getDefiningOp();
SmallVector<vector::ExtractOp, 2> extractOps;
// TODO: add support to `vector.splat`.
// TODO: add support to `vector.broadcast`.
// Finding the mask creation operation.
while (maskOp &&
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(

View File

@ -590,32 +590,6 @@ struct LinearizeVectorBitCast final
}
};
/// This pattern converts the SplatOp to work on a linearized vector.
/// Following,
/// vector.splat %value : vector<4x4xf32>
/// is converted to:
/// %out_1d = vector.splat %value : vector<16xf32>
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
struct LinearizeVectorSplat final
: public OpConversionPattern<vector::SplatOp> {
using Base::Base;
LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
LogicalResult
matchAndRewrite(vector::SplatOp splatOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto dstTy = getTypeConverter()->convertType(splatOp.getType());
if (!dstTy)
return rewriter.notifyMatchFailure(splatOp, "cannot convert type.");
rewriter.replaceOpWithNewOp<vector::SplatOp>(splatOp, adaptor.getInput(),
dstTy);
return success();
}
};
/// This pattern converts the CreateMaskOp to work on a linearized vector.
/// It currently supports only 2D masks with a unit outer dimension.
/// Following,
@ -934,9 +908,9 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
RewritePatternSet &patterns) {
patterns
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
LinearizeVectorStore, LinearizeVectorFromElements,
LinearizeVectorToElements>(typeConverter, patterns.getContext());
LinearizeVectorCreateMask, LinearizeVectorLoad, LinearizeVectorStore,
LinearizeVectorFromElements, LinearizeVectorToElements>(
typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(

View File

@ -878,7 +878,7 @@ struct BubbleUpBitCastForStridedSliceInsert
// This transforms IR like:
// %1 = vector.bitcast %0: vector<8xf16> to vector<4xf32>
// Into:
// %cst = vector.splat %c0_f32 : vector<4xf32>
// %cst = vector.broadcast %c0_f32 : f32 to vector<4xf32>
// %1 = vector.extract_strided_slice %0 {
// offsets = [0], sizes = [4], strides = [1]
// } : vector<8xf16> to vector<4xf16>
@ -987,8 +987,8 @@ static Type cloneOrReplace(Type type, Type newElementType) {
return newElementType;
}
/// If `value` is the result of a splat or broadcast operation, return the input
/// of the splat/broadcast operation.
/// If `value` is the result of a broadcast operation, return the input
/// of the broadcast operation.
static Value getBroadcastLikeSource(Value value) {
Operation *op = value.getDefiningOp();
@ -998,13 +998,10 @@ static Value getBroadcastLikeSource(Value value) {
if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
return broadcast.getSource();
if (auto splat = dyn_cast<vector::SplatOp>(op))
return splat.getInput();
return {};
}
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
///
/// Example:
/// ```
@ -1017,9 +1014,6 @@ static Value getBroadcastLikeSource(Value value) {
/// %r = arith.addi %arg0, %arg1 : index
/// %b = vector.broadcast %r : index to vector<1x4xindex>
/// ```
///
/// Both `vector.broadcast` and `vector.splat` are supported as broadcasting
/// ops.
struct ReorderElementwiseOpsOnBroadcast final
: public OpTraitRewritePattern<OpTrait::Elementwise> {
using OpTraitRewritePattern::OpTraitRewritePattern;
@ -1045,29 +1039,29 @@ struct ReorderElementwiseOpsOnBroadcast final
Type resultElemType = resultType.getElementType();
// Get the type of the first non-constant operand
Value splatSource;
Value broadcastSource;
for (Value operand : op->getOperands()) {
Operation *definingOp = operand.getDefiningOp();
if (!definingOp)
return failure();
if (definingOp->hasTrait<OpTrait::ConstantLike>())
continue;
splatSource = getBroadcastLikeSource(operand);
broadcastSource = getBroadcastLikeSource(operand);
break;
}
if (!splatSource)
if (!broadcastSource)
return failure();
Type unbroadcastResultType =
cloneOrReplace(splatSource.getType(), resultElemType);
cloneOrReplace(broadcastSource.getType(), resultElemType);
// Make sure that all operands are broadcast from identically-shaped types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * scalar (`vector.broadcast`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
if (!llvm::all_of(op->getOperands(), [broadcastSource](Value val) {
if (auto source = getBroadcastLikeSource(val))
return haveSameShapeAndScaling(source.getType(),
splatSource.getType());
broadcastSource.getType());
SplatElementsAttr splatConst;
return matchPattern(val, m_Constant(&splatConst));
})) {
@ -1271,19 +1265,18 @@ public:
}
};
/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
/// Pattern to rewrite vector.store(vector.broadcast) -> vector/memref.store.
///
/// Example:
/// ```
/// %0 = vector.splat %arg2 : vector<1xf32>
/// %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
/// ```
/// Gets converted to:
/// ```
/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
/// ```
class StoreOpFromSplatOrBroadcast final
: public OpRewritePattern<vector::StoreOp> {
class StoreOpFromBroadcast final : public OpRewritePattern<vector::StoreOp> {
public:
using Base::Base;
@ -1308,9 +1301,9 @@ public:
return rewriter.notifyMatchFailure(
op, "value to store is not from a broadcast");
// Checking for single use so we can remove splat.
Operation *splat = toStore.getDefiningOp();
if (!splat->hasOneUse())
// Checking for single use so we can remove broadcast.
Operation *broadcast = toStore.getDefiningOp();
if (!broadcast->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
Value base = op.getBase();
@ -1321,7 +1314,7 @@ public:
} else {
rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
}
rewriter.eraseOp(splat);
rewriter.eraseOp(broadcast);
return success();
}
};
@ -2391,8 +2384,8 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
PatternBenefit benefit) {
// TODO: Consider converting these patterns to canonicalizations.
patterns.add<ExtractOpFromLoad, StoreOpFromSplatOrBroadcast>(
patterns.getContext(), benefit);
patterns.add<ExtractOpFromLoad, StoreOpFromBroadcast>(patterns.getContext(),
benefit);
}
void mlir::vector::populateChainedVectorReductionFoldingPatterns(

View File

@ -86,7 +86,7 @@ func.func @fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<1xf
// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[A]], %[[A]], %[[A]], %[[A]]
// CHECK: spirv.ReturnValue %[[VAL]] : vector<4xf32>
func.func @splat(%f : f32) -> vector<4xf32> {
%splat = vector.splat %f : vector<4xf32>
%splat = vector.broadcast %f : f32 to vector<4xf32>
return %splat : vector<4xf32>
}

View File

@ -429,38 +429,6 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<[8]xi16>) {
return
}
//===----------------------------------------------------------------------===//
// vector.splat
//===----------------------------------------------------------------------===//
// -----
// CHECK-LABEL: func.func @splat_vec2d_from_i32(
// CHECK-SAME: %[[SRC:.*]]: i32) {
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : i32 to vector<[4]xi32>
// CHECK: arm_sme.get_tile : vector<[4]x[4]xi32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %{{.*}} : index
// CHECK: scf.for {{.*}} to %[[NUM_TILE_SLICES]] {{.*}} {
// CHECK: arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[4]xi32> into vector<[4]x[4]xi32>
func.func @splat_vec2d_from_i32(%arg0: i32) {
%0 = vector.splat %arg0 : vector<[4]x[4]xi32>
"prevent.dce"(%0) : (vector<[4]x[4]xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func.func @splat_vec2d_from_f16(
// CHECK-SAME: %[[SRC:.*]]: f16) {
// CHECK: %[[BCST:.*]] = vector.broadcast %[[SRC]] : f16 to vector<[8]xf16>
// CHECK: scf.for
// CHECK: arm_sme.insert_tile_slice %[[BCST]], {{.*}} : vector<[8]xf16> into vector<[8]x[8]xf16>
func.func @splat_vec2d_from_f16(%arg0: f16) {
%0 = vector.splat %arg0 : vector<[8]x[8]xf16>
"prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> ()
return
}
//===----------------------------------------------------------------------===//
// vector.transpose

View File

@ -2216,23 +2216,6 @@ func.func @compress_store_op_with_alignment(%arg0: memref<?xindex>, %arg1: vecto
// -----
//===----------------------------------------------------------------------===//
// vector.splat
//===----------------------------------------------------------------------===//
// vector.splat is converted to vector.broadcast. Then, vector.broadcast is converted to LLVM.
// CHECK-LABEL: @splat_0d
// CHECK-NOT: splat
// CHECK: return
func.func @splat_0d(%elt: f32) -> (vector<f32>, vector<4xf32>, vector<[4]xf32>) {
%a = vector.splat %elt : vector<f32>
%b = vector.splat %elt : vector<4xf32>
%c = vector.splat %elt : vector<[4]xf32>
return %a, %b, %c : vector<f32>, vector<4xf32>, vector<[4]xf32>
}
// -----
//===----------------------------------------------------------------------===//
// vector.scalable_insert
//===----------------------------------------------------------------------===//

View File

@ -105,9 +105,9 @@ func.func @ipowi32_fold(%result : memref<?xi32>) {
// --- Test vector folding ---
%arg11_base = arith.constant 2 : i32
%arg11_base_vec = vector.splat %arg11_base : vector<2x2xi32>
%arg11_base_vec = vector.broadcast %arg11_base : i32 to vector<2x2xi32>
%arg11_power = arith.constant 30 : i32
%arg11_power_vec = vector.splat %arg11_power : vector<2x2xi32>
%arg11_power_vec = vector.broadcast %arg11_power : i32 to vector<2x2xi32>
%res11_vec = math.ipowi %arg11_base_vec, %arg11_power_vec : vector<2x2xi32>
%i11 = arith.constant 11 : index
%res11 = vector.extract %res11_vec[1, 1] : i32 from vector<2x2xi32>

View File

@ -837,7 +837,7 @@ func.func @fold_extract_splatlike(%a : f32, %idx0 : index, %idx1 : index, %idx2
// CHECK-LABEL: fold_extract_vector_from_splat
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.splat %a : vector<1x2x4xf32>
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}

View File

@ -1,126 +0,0 @@
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
// This file should be removed when vector.splat is removed.
// This file tests canonicalization/folding with vector.splat.
// These tests all have equivalent tests using vector.broadcast in canonicalize.mlir
// CHECK-LABEL: fold_extract_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
}
// -----
// CHECK-LABEL: extract_strided_splat
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} f16 to vector<2x4xf16>
// CHECK-NEXT: return %[[B]] : vector<2x4xf16>
func.func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
%0 = vector.splat %arg0 : vector<16x4xf16>
%1 = vector.extract_strided_slice %0
{offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
vector<16x4xf16> to vector<2x4xf16>
return %1 : vector<2x4xf16>
}
// -----
// CHECK-LABEL: func @splat_fold
// CHECK-NEXT: [[V:%.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
// CHECK-NEXT: return [[V]] : vector<4xf32>
func.func @splat_fold() -> vector<4xf32> {
%c = arith.constant 1.0 : f32
%v = vector.splat %c : vector<4xf32>
return %v : vector<4xf32>
}
// -----
// CHECK-LABEL: func @transpose_splat2(
// CHECK-SAME: %[[VAL_0:.*]]: f32) -> vector<3x4xf32> {
// CHECK: %[[VAL_1:.*]] = vector.broadcast %[[VAL_0]] : f32 to vector<3x4xf32>
// CHECK: return %[[VAL_1]] : vector<3x4xf32>
func.func @transpose_splat2(%arg : f32) -> vector<3x4xf32> {
%splat = vector.splat %arg : vector<4x3xf32>
%0 = vector.transpose %splat, [1, 0] : vector<4x3xf32> to vector<3x4xf32>
return %0 : vector<3x4xf32>
}
// -----
// CHECK-LABEL: @insert_strided_slice_splat
// CHECK-SAME: (%[[ARG:.*]]: f32)
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : f32 to vector<8x16xf32>
// CHECK-NEXT: return %[[SPLAT]] : vector<8x16xf32>
func.func @insert_strided_slice_splat(%x: f32) -> (vector<8x16xf32>) {
%splat0 = vector.splat %x : vector<4x4xf32>
%splat1 = vector.splat %x : vector<8x16xf32>
%0 = vector.insert_strided_slice %splat0, %splat1 {offsets = [2, 2], strides = [1, 1]}
: vector<4x4xf32> into vector<8x16xf32>
return %0 : vector<8x16xf32>
}
// -----
// CHECK-LABEL: func @shuffle_splat
// CHECK-SAME: (%[[ARG:.*]]: i32)
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<4xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<4xi32>
func.func @shuffle_splat(%x : i32) -> vector<4xi32> {
%v0 = vector.splat %x : vector<4xi32>
%v1 = vector.splat %x : vector<2xi32>
%shuffle = vector.shuffle %v0, %v1 [2, 3, 4, 5] : vector<4xi32>, vector<2xi32>
return %shuffle : vector<4xi32>
}
// -----
// CHECK-LABEL: func @insert_splat
// CHECK-SAME: (%[[ARG:.*]]: i32)
// CHECK-NEXT: %[[SPLAT:.*]] = vector.broadcast %[[ARG]] : i32 to vector<2x4x3xi32>
// CHECK-NEXT: return %[[SPLAT]] : vector<2x4x3xi32>
func.func @insert_splat(%x : i32) -> vector<2x4x3xi32> {
%v0 = vector.splat %x : vector<4x3xi32>
%v1 = vector.splat %x : vector<2x4x3xi32>
%insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
return %insert : vector<2x4x3xi32>
}
// -----
// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression
// CHECK-SAME: (%[[A:.*]]: f32, %[[C:.*]]: vector<2xf32>)
func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %c: vector<2xf32>) -> (f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {
// Splat scalar to 0D and extract scalar.
%0 = vector.splat %a : vector<f32>
%1 = vector.extract %0[] : f32 from vector<f32>
// Broadcast scalar to 0D and extract scalar.
%2 = vector.splat %a : vector<f32>
%3 = vector.extract %2[] : f32 from vector<f32>
// Splat scalar to 2D and extract scalar.
%6 = vector.splat %a : vector<2x3xf32>
%7 = vector.extract %6[0, 1] : f32 from vector<2x3xf32>
// Broadcast scalar to 3D and extract scalar.
%8 = vector.splat %a : vector<5x6x7xf32>
%9 = vector.extract %8[2, 1, 5] : f32 from vector<5x6x7xf32>
// Extract 2D from 3D that was broadcasted from a scalar.
// CHECK: %[[EXTRACT2:.*]] = vector.broadcast %[[A]] : f32 to vector<6x7xf32>
%10 = vector.extract %8[2] : vector<6x7xf32> from vector<5x6x7xf32>
// Extract 1D from 2D that was splat'ed from a scalar.
// CHECK: %[[EXTRACT3:.*]] = vector.broadcast %[[A]] : f32 to vector<3xf32>
%11 = vector.extract %6[1] : vector<3xf32> from vector<2x3xf32>
// CHECK: return %[[A]], %[[A]], %[[A]], %[[A]], %[[EXTRACT2]], %[[EXTRACT3]]
return %1, %3, %7, %9, %10, %11 : f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>
}

View File

@ -28,7 +28,7 @@ func.func @float_constant_splat() -> vector<8xf32> {
// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index}
func.func @vector_splat() -> vector<4xindex> {
%0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index
%1 = vector.splat %0 : vector<4xindex>
%1 = vector.broadcast %0 : index to vector<4xindex>
%2 = test.reflect_bounds %1 : vector<4xindex>
func.return %2 : vector<4xindex>
}

View File

@ -320,7 +320,7 @@ func.func @test_vector.transfer_write(%m: memref<1xi32>, %2: vector<1x32xi32>)
func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = vector.splat %f0 : vector<4x3xf32>
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires memref or ranked tensor type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : vector<4x3xf32>, vector<1x1x2x3xf32>
}
@ -330,7 +330,7 @@ func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
func.func @test_vector.transfer_read(%arg0: memref<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = vector.splat %f0 : vector<4x3xf32>
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires vector type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<4x3xf32>, f32
}
@ -414,7 +414,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
%c3 = arith.constant 3 : index
%cst = arith.constant 3.0 : f32
// expected-note@+1 {{prior use here}}
%mask = vector.splat %c1 : vector<3x8x7xi1>
%mask = vector.broadcast %c1 : i1 to vector<3x8x7xi1>
// expected-error@+1 {{expects different type than prior uses: 'vector<3x7xi1>' vs 'vector<3x8x7xi1>'}}
%0 = vector.transfer_read %arg0[%c3, %c3, %c3], %cst, %mask {permutation_map = affine_map<(d0, d1, d2)->(d0, 0, d2)>} : memref<?x?x?xf32>, vector<3x8x7xf32>
}
@ -424,7 +424,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?x?xf32>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = vector.splat %f0 : vector<4x3xf32>
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{requires source vector element and vector result ranks to match}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<4x3xf32>>, vector<3xf32>
}
@ -434,7 +434,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<4x3xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = vector.splat %f0 : vector<6xf32>
%vf0 = vector.broadcast %f0 : f32 to vector<6xf32>
// expected-error@+1 {{requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the source}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 : memref<?x?xvector<6xf32>>, vector<3xf32>
}
@ -444,7 +444,7 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<6xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = vector.splat %f0 : vector<2x3xf32>
%vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
// expected-error@+1 {{ expects the in_bounds attr of same rank as permutation_map results: affine_map<(d0, d1) -> (d0, d1)>}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@ -454,8 +454,8 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
func.func @test_vector.transfer_read(%arg0: memref<?x?xvector<2x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = vector.splat %f0 : vector<2x3xf32>
%mask = vector.splat %c1 : vector<2x3xi1>
%vf0 = vector.broadcast %f0 : f32 to vector<2x3xf32>
%mask = vector.broadcast %c1 : f32 to vector<2x3xi1>
// expected-error@+1 {{does not support masks with vector element type}}
%0 = vector.transfer_read %arg0[%c3, %c3], %vf0, %mask {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : memref<?x?xvector<2x3xf32>>, vector<1x1x2x3xf32>
}
@ -492,7 +492,7 @@ func.func @test_vector.transfer_write(%arg0: memref<?x?xf32>) {
func.func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = vector.splat %f0 : vector<4x3xf32>
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires vector type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : memref<vector<4x3xf32>>, vector<4x3xf32>
}
@ -502,7 +502,7 @@ func.func @test_vector.transfer_write(%arg0: memref<vector<4x3xf32>>) {
func.func @test_vector.transfer_write(%arg0: vector<4x3xf32>) {
%c3 = arith.constant 3 : index
%f0 = arith.constant 0.0 : f32
%vf0 = vector.splat %f0 : vector<4x3xf32>
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
// expected-error@+1 {{ requires memref or ranked tensor type}}
vector.transfer_write %arg0, %arg0[%c3, %c3] : vector<4x3xf32>, f32
}
@ -1980,29 +1980,6 @@ func.func @invalid_step_2d() {
// -----
//===----------------------------------------------------------------------===//
// vector.splat
//===----------------------------------------------------------------------===//
// -----
func.func @vector_splat_invalid_result(%v : f32) {
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'memref<8xf32>'}}
vector.splat %v : memref<8xf32>
return
}
// -----
// expected-note @+1 {{prior use here}}
func.func @vector_splat_type_mismatch(%a: f32) {
// expected-error @+1 {{expects different type than prior uses: 'i32' vs 'f32'}}
%0 = vector.splat %a : vector<1xi32>
return
}
// -----
//===----------------------------------------------------------------------===//
// vector.load
//===----------------------------------------------------------------------===//

View File

@ -428,33 +428,6 @@ func.func @test_linearize_across_for(%arg0 : vector<4xi8>) -> vector<4xi8> {
// -----
// CHECK-LABEL: linearize_vector_splat
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x2xi32>
func.func @linearize_vector_splat(%arg0: i32) -> vector<4x2xi32> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<8xi32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<8xi32> to vector<4x2xi32>
// CHECK: return %[[CAST]] : vector<4x2xi32>
%0 = vector.splat %arg0 : vector<4x2xi32>
return %0 : vector<4x2xi32>
}
// -----
// CHECK-LABEL: linearize_scalable_vector_splat
// CHECK-SAME: (%[[ARG:.*]]: i32) -> vector<4x[2]xi32>
func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG]] : vector<[8]xi32>
// CHECK: %[[CAST:.*]] = vector.shape_cast %[[SPLAT]] : vector<[8]xi32> to vector<4x[2]xi32>
// CHECK: return %[[CAST]] : vector<4x[2]xi32>
%0 = vector.splat %arg0 : vector<4x[2]xi32>
return %0 : vector<4x[2]xi32>
}
// -----
// CHECK-LABEL: linearize_create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x16xi1>
func.func @linearize_create_mask(%arg0 : index, %arg1 : index) -> vector<1x16xi1> {

View File

@ -45,11 +45,11 @@ func.func @vector_transfer_ops(%arg0: memref<?x?xf32>,
%i0 = arith.constant 0 : index
%i1 = arith.constant 1 : i1
%vf0 = vector.splat %f0 : vector<4x3xf32>
%v0 = vector.splat %c0 : vector<4x3xi32>
%vi0 = vector.splat %i0 : vector<4x3xindex>
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
%v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
%vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
%m = arith.constant dense<[0, 0, 1, 0, 1]> : vector<5xi1>
%m2 = vector.splat %i1 : vector<4x5xi1>
%m2 = vector.broadcast %i1 : i1 to vector<4x5xi1>
//
// CHECK: vector.transfer_read
%0 = vector.transfer_read %arg0[%c3, %c3], %f0 {permutation_map = affine_map<(d0, d1)->(d0)>} : memref<?x?xf32>, vector<128xf32>
@ -106,9 +106,9 @@ func.func @vector_transfer_ops_tensor(%arg0: tensor<?x?xf32>,
%c0 = arith.constant 0 : i32
%i0 = arith.constant 0 : index
%vf0 = vector.splat %f0 : vector<4x3xf32>
%v0 = vector.splat %c0 : vector<4x3xi32>
%vi0 = vector.splat %i0 : vector<4x3xindex>
%vf0 = vector.broadcast %f0 : f32 to vector<4x3xf32>
%v0 = vector.broadcast %c0 : i32 to vector<4x3xi32>
%vi0 = vector.broadcast %i0 : index to vector<4x3xindex>
//
// CHECK: vector.transfer_read
@ -922,28 +922,6 @@ func.func @vector_scan(%0: vector<4x8x16x32xf32>) -> vector<4x8x16x32xf32> {
return %2#0 : vector<4x8x16x32xf32>
}
// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: %[[s:.*]]: f32, %[[s2:.*]]: !llvm.ptr<1>
func.func @test_splat_op(%s : f32, %s2 : !llvm.ptr<1>) {
// CHECK: vector.splat %[[s]] : vector<8xf32>
%v = vector.splat %s : vector<8xf32>
// CHECK: vector.splat %[[s]] : vector<4xf32>
%u = "vector.splat"(%s) : (f32) -> vector<4xf32>
// CHECK: vector.splat %[[s2]] : vector<16x!llvm.ptr<1>>
%w = vector.splat %s2 : vector<16x!llvm.ptr<1>>
return
}
// CHECK-LABEL: func @vector_splat_0d(
func.func @vector_splat_0d(%a: f32) -> vector<f32> {
// CHECK: vector.splat %{{.*}} : vector<f32>
%0 = vector.splat %a : vector<f32>
return %0 : vector<f32>
}
// CHECK-LABEL: func @vector_mask
func.func @vector_mask(%a: vector<8xi32>, %m0: vector<8xi1>) -> i32 {
// CHECK-NEXT: %{{.*}} = vector.mask %{{.*}} { vector.reduction <add>, %{{.*}} : vector<8xi32> into i32 } : vector<8xi1> -> i32

View File

@ -49,7 +49,7 @@ func.func @vector_maskedload(%arg0 : memref<4x5xf32>) -> vector<4xf32> {
%idx_4 = arith.constant 4 : index
%mask = vector.create_mask %idx_1 : vector<4xi1>
%s = arith.constant 0.0 : f32
%pass_thru = vector.splat %s : vector<4xf32>
%pass_thru = vector.broadcast %s : f32 to vector<4xf32>
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru : memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
return %0: vector<4xf32>
}
@ -65,7 +65,7 @@ func.func @vector_maskedload_with_alignment(%arg0 : memref<4x5xf32>) -> vector<4
%idx_4 = arith.constant 4 : index
%mask = vector.create_mask %idx_1 : vector<4xi1>
%s = arith.constant 0.0 : f32
%pass_thru = vector.splat %s : vector<4xf32>
%pass_thru = vector.broadcast %s : f32 to vector<4xf32>
%0 = vector.maskedload %arg0[%idx_0, %idx_4], %mask, %pass_thru {alignment = 8}: memref<4x5xf32>, vector<4xi1>, vector<4xf32> into vector<4xf32>
return %0: vector<4xf32>
}

View File

@ -107,7 +107,7 @@ func.func @return_not_in_function() {
// -----
func.func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
vector.splat %v : vector<8xf64>
vector.broadcast %v : f64 to vector<8xf64>
// expected-error@-1 {{expects different type than prior uses}}
return
}

View File

@ -21,13 +21,6 @@ func.func @print_vector_0d(%a: vector<f32>) {
return
}
func.func @splat_0d(%a: f32) {
%1 = vector.splat %a : vector<f32>
// CHECK: ( 42 )
vector.print %1: vector<f32>
return
}
func.func @broadcast_0d(%a: f32) {
%1 = vector.broadcast %a : f32 to vector<f32>
// CHECK: ( 42 )

View File

@ -56,7 +56,7 @@ func.func private @printMemrefF32(memref<*xf32>) attributes { llvm.emit_c_interf
func.func @vector_splat_2d() {
%c0 = arith.constant 0 : index
%f10 = arith.constant 10.0 : f32
%vf10 = vector.splat %f10: !vector_type_C
%vf10 = vector.broadcast %f10: f32 to !vector_type_C
%C = memref.alloc() : !matrix_type_CC
memref.store %vf10, %C[%c0, %c0]: !matrix_type_CC

View File

@ -181,7 +181,6 @@
"vector.insert_strided_slice"
"vector.matrix_multiply"
"vector.print"
"vector.splat"
"vector.transfer_read"
"vector.transfer_write"
"vector.yield"