[mlir][vector] Use source as the source argument name (#158258)
This patch updates the following ops to use `source` (instead of `vector`) as the name for their source argument: * `vector.extract` * `vector.scalable.extract` * `vector.extract_strided_slice` This change ensures naming consistency with the "builders" for these Ops that already use the name `source` rather than `vector`. It also addresses part of: * https://github.com/llvm/llvm-project/issues/131602 Specifically, it ensures that we use `source` and `dest` for read and write operations, respectively (as opposed to `vector` and `dest`).
This commit is contained in:
parent
04cd39ae28
commit
1287ed1fa2
@ -675,7 +675,7 @@ def Vector_ExtractOp :
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
AnyVectorOfAnyRank:$vector,
|
||||
AnyVectorOfAnyRank:$source,
|
||||
Variadic<Index>:$dynamic_position,
|
||||
DenseI64ArrayAttr:$static_position
|
||||
);
|
||||
@ -692,7 +692,7 @@ def Vector_ExtractOp :
|
||||
|
||||
let extraClassDeclaration = extraPoisonClassDeclaration # [{
|
||||
VectorType getSourceVectorType() {
|
||||
return ::llvm::cast<VectorType>(getVector().getType());
|
||||
return ::llvm::cast<VectorType>(getSource().getType());
|
||||
}
|
||||
|
||||
/// Return a vector with all the static and dynamic position indices.
|
||||
@ -709,12 +709,17 @@ def Vector_ExtractOp :
|
||||
bool hasDynamicPosition() {
|
||||
return !getDynamicPosition().empty();
|
||||
}
|
||||
|
||||
/// Wrapper for getSource, which replaced getVector.
|
||||
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
|
||||
return getSource();
|
||||
}
|
||||
}];
|
||||
|
||||
let assemblyFormat = [{
|
||||
$vector ``
|
||||
$source ``
|
||||
custom<DynamicIndexList>($dynamic_position, $static_position)
|
||||
attr-dict `:` type($result) `from` type($vector)
|
||||
attr-dict `:` type($result) `from` type($source)
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
@ -1023,6 +1028,10 @@ def Vector_ScalableExtractOp :
|
||||
VectorType getResultVectorType() {
|
||||
return ::llvm::cast<VectorType>(getResult().getType());
|
||||
}
|
||||
/// Wrapper for getSource, which replaced getVector.
|
||||
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
|
||||
return getSource();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
@ -1174,7 +1183,7 @@ def Vector_ExtractStridedSliceOp :
|
||||
Vector_Op<"extract_strided_slice", [Pure,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>]>,
|
||||
Arguments<(ins AnyVectorOfNonZeroRank:$vector, I64ArrayAttr:$offsets,
|
||||
Arguments<(ins AnyVectorOfNonZeroRank:$source, I64ArrayAttr:$offsets,
|
||||
I64ArrayAttr:$sizes, I64ArrayAttr:$strides)>,
|
||||
Results<(outs AnyVectorOfNonZeroRank)> {
|
||||
let summary = "extract_strided_slice operation";
|
||||
@ -1209,7 +1218,7 @@ def Vector_ExtractStridedSliceOp :
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getSourceVectorType() {
|
||||
return ::llvm::cast<VectorType>(getVector().getType());
|
||||
return ::llvm::cast<VectorType>(getSource().getType());
|
||||
}
|
||||
void getOffsets(SmallVectorImpl<int64_t> &results);
|
||||
bool hasNonUnitStrides() {
|
||||
@ -1217,11 +1226,15 @@ def Vector_ExtractStridedSliceOp :
|
||||
return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
|
||||
});
|
||||
}
|
||||
/// Wrapper for getSource, which replaced getVector.
|
||||
[[deprecated("Use getSource instead!")]] ::mlir::Value getVector() {
|
||||
return getSource();
|
||||
}
|
||||
}];
|
||||
let hasCanonicalizer = 1;
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
|
||||
let assemblyFormat = "$source attr-dict `:` type($source) `to` type(results)";
|
||||
}
|
||||
|
||||
// TODO: Tighten semantics so that masks and inbounds can't be used
|
||||
|
||||
@ -462,7 +462,7 @@ struct VectorExtractToArmSMELowering
|
||||
auto loc = extractOp.getLoc();
|
||||
auto position = extractOp.getMixedPosition();
|
||||
|
||||
Value sourceVector = extractOp.getVector();
|
||||
Value sourceVector = extractOp.getSource();
|
||||
|
||||
// Extract entire vector. Should be handled by folder, but just to be safe.
|
||||
if (position.empty()) {
|
||||
@ -692,7 +692,7 @@ struct ExtractFromCreateMaskToPselLowering
|
||||
return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
|
||||
|
||||
auto createMaskOp =
|
||||
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
|
||||
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
|
||||
if (!createMaskOp)
|
||||
return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
|
||||
|
||||
|
||||
@ -962,7 +962,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
|
||||
return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
|
||||
|
||||
// Find the vector.transer_read whose result vector is being sliced.
|
||||
auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
|
||||
auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
|
||||
if (!transferReadOp)
|
||||
return rewriter.notifyMatchFailure(op, "no transfer read");
|
||||
|
||||
|
||||
@ -1131,7 +1131,7 @@ public:
|
||||
positionVec.push_back(rewriter.getZeroAttr(idxType));
|
||||
}
|
||||
|
||||
Value extracted = adaptor.getVector();
|
||||
Value extracted = adaptor.getSource();
|
||||
if (extractsAggregate) {
|
||||
ArrayRef<OpFoldResult> position(positionVec);
|
||||
if (extractsScalar) {
|
||||
|
||||
@ -1414,7 +1414,7 @@ struct UnrollTransferWriteConversion
|
||||
/// Return the vector from which newly generated ExtracOps will extract.
|
||||
Value getDataVector(TransferWriteOp xferOp) const {
|
||||
if (auto extractOp = getExtractOp(xferOp))
|
||||
return extractOp.getVector();
|
||||
return extractOp.getSource();
|
||||
return xferOp.getVector();
|
||||
}
|
||||
|
||||
|
||||
@ -189,8 +189,8 @@ struct VectorExtractOpConvert final
|
||||
if (!dstType)
|
||||
return failure();
|
||||
|
||||
if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
|
||||
rewriter.replaceOp(extractOp, adaptor.getVector());
|
||||
if (isa<spirv::ScalarType>(adaptor.getSource().getType())) {
|
||||
rewriter.replaceOp(extractOp, adaptor.getSource());
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -201,7 +201,7 @@ struct VectorExtractOpConvert final
|
||||
extractOp,
|
||||
"Static use of poison index handled elsewhere (folded to poison)");
|
||||
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
|
||||
extractOp, dstType, adaptor.getVector(),
|
||||
extractOp, dstType, adaptor.getSource(),
|
||||
rewriter.getI32ArrayAttr(id.value()));
|
||||
} else {
|
||||
Value sanitizedIndex = sanitizeDynamicIndex(
|
||||
@ -209,7 +209,7 @@ struct VectorExtractOpConvert final
|
||||
vector::ExtractOp::kPoisonIndex,
|
||||
extractOp.getSourceVectorType().getNumElements());
|
||||
rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
|
||||
extractOp, dstType, adaptor.getVector(), sanitizedIndex);
|
||||
extractOp, dstType, adaptor.getSource(), sanitizedIndex);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -445,7 +445,7 @@ struct SwapVectorExtractOfArithExtend
|
||||
return rewriter.notifyMatchFailure(
|
||||
extractOp, "extracted type is not a 1-D scalable vector type");
|
||||
|
||||
auto *extendOp = extractOp.getVector().getDefiningOp();
|
||||
auto *extendOp = extractOp.getSource().getDefiningOp();
|
||||
if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
|
||||
extendOp))
|
||||
return rewriter.notifyMatchFailure(extractOp,
|
||||
|
||||
@ -542,7 +542,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = extractOp.getLoc();
|
||||
auto createMaskOp =
|
||||
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
|
||||
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
|
||||
if (!createMaskOp)
|
||||
return rewriter.notifyMatchFailure(
|
||||
extractOp, "extract not from vector.create_mask op");
|
||||
|
||||
@ -105,7 +105,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
|
||||
return WalkResult::advance();
|
||||
|
||||
// Check that the vector to extract from is a BlockArgument.
|
||||
auto blockArg = dyn_cast<BlockArgument>(extractOp.getVector());
|
||||
auto blockArg = dyn_cast<BlockArgument>(extractOp.getSource());
|
||||
if (!blockArg)
|
||||
return WalkResult::advance();
|
||||
|
||||
@ -141,7 +141,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
|
||||
return WalkResult::advance();
|
||||
|
||||
rewriter.modifyOpInPlace(broadcast, [&] {
|
||||
extractOp.getVectorMutable().assign(initArg->get());
|
||||
extractOp.getSourceMutable().assign(initArg->get());
|
||||
});
|
||||
loop.moveOutOfLoop(extractOp);
|
||||
rewriter.moveOpAfter(broadcast, loop);
|
||||
|
||||
@ -71,7 +71,7 @@ static FailureOr<TransferMask> getMaskOp(Operation *loadOp) {
|
||||
if (auto extractOp =
|
||||
transferRead.getMask().getDefiningOp<vector::ExtractOp>())
|
||||
if (auto maskOp =
|
||||
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>())
|
||||
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>())
|
||||
return TransferMask{maskOp,
|
||||
SmallVector<int64_t>(extractOp.getStaticPosition())};
|
||||
|
||||
|
||||
@ -1309,7 +1309,7 @@ LogicalResult
|
||||
ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
|
||||
ExtractOp::Adaptor adaptor,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto vectorType = llvm::cast<VectorType>(adaptor.getVector().getType());
|
||||
auto vectorType = llvm::cast<VectorType>(adaptor.getSource().getType());
|
||||
if (static_cast<int64_t>(adaptor.getStaticPosition().size()) ==
|
||||
vectorType.getRank()) {
|
||||
inferredReturnTypes.push_back(vectorType.getElementType());
|
||||
@ -1379,7 +1379,7 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
|
||||
/// Fold the result of chains of ExtractOp in place by simply concatenating the
|
||||
/// positions.
|
||||
static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
|
||||
if (!extractOp.getVector().getDefiningOp<ExtractOp>())
|
||||
if (!extractOp.getSource().getDefiningOp<ExtractOp>())
|
||||
return failure();
|
||||
|
||||
// TODO: Canonicalization for dynamic position not implemented yet.
|
||||
@ -1390,7 +1390,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
|
||||
ExtractOp currentOp = extractOp;
|
||||
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
|
||||
globalPosition.append(extrPos.rbegin(), extrPos.rend());
|
||||
while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
|
||||
while (ExtractOp nextOp = currentOp.getSource().getDefiningOp<ExtractOp>()) {
|
||||
currentOp = nextOp;
|
||||
// TODO: Canonicalization for dynamic position not implemented yet.
|
||||
if (currentOp.hasDynamicPosition())
|
||||
@ -1398,7 +1398,7 @@ static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
|
||||
ArrayRef<int64_t> extrPos = currentOp.getStaticPosition();
|
||||
globalPosition.append(extrPos.rbegin(), extrPos.rend());
|
||||
}
|
||||
extractOp.setOperand(0, currentOp.getVector());
|
||||
extractOp.setOperand(0, currentOp.getSource());
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
std::reverse(globalPosition.begin(), globalPosition.end());
|
||||
@ -1584,7 +1584,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
|
||||
return Value();
|
||||
|
||||
// If we can't fold (either internal transposition, or nothing to fold), bail.
|
||||
bool nothingToFold = (source == extractOp.getVector());
|
||||
bool nothingToFold = (source == extractOp.getSource());
|
||||
if (nothingToFold || !canFold())
|
||||
return Value();
|
||||
|
||||
@ -1592,7 +1592,7 @@ Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp.setStaticPosition(
|
||||
ArrayRef(extractPosition).take_front(extractedRank));
|
||||
extractOp.getVectorMutable().assign(source);
|
||||
extractOp.getSourceMutable().assign(source);
|
||||
return extractOp.getResult();
|
||||
}
|
||||
|
||||
@ -1602,7 +1602,7 @@ Value ExtractFromInsertTransposeChainState::fold() {
|
||||
if (extractOp.hasDynamicPosition())
|
||||
return Value();
|
||||
|
||||
Value valueToExtractFrom = extractOp.getVector();
|
||||
Value valueToExtractFrom = extractOp.getSource();
|
||||
updateStateForNextIteration(valueToExtractFrom);
|
||||
while (nextInsertOp || nextTransposeOp) {
|
||||
// Case 1. If we hit a transpose, just compose the map and iterate.
|
||||
@ -1693,7 +1693,7 @@ static bool isBroadcastLike(Operation *op) {
|
||||
/// `extract` shape.
|
||||
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
|
||||
|
||||
Operation *defOp = extractOp.getVector().getDefiningOp();
|
||||
Operation *defOp = extractOp.getSource().getDefiningOp();
|
||||
if (!defOp || !isBroadcastLike(defOp))
|
||||
return Value();
|
||||
|
||||
@ -1762,7 +1762,7 @@ static Value foldExtractFromShuffle(ExtractOp extractOp) {
|
||||
if (extractOp.hasDynamicPosition())
|
||||
return Value();
|
||||
|
||||
auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
|
||||
auto shuffleOp = extractOp.getSource().getDefiningOp<ShuffleOp>();
|
||||
if (!shuffleOp)
|
||||
return Value();
|
||||
|
||||
@ -1793,7 +1793,7 @@ static Value foldExtractFromShapeCast(ExtractOp extractOp) {
|
||||
if (extractOp.hasDynamicPosition())
|
||||
return Value();
|
||||
|
||||
auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
|
||||
auto shapeCastOp = extractOp.getSource().getDefiningOp<vector::ShapeCastOp>();
|
||||
if (!shapeCastOp)
|
||||
return Value();
|
||||
|
||||
@ -1859,7 +1859,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
|
||||
return Value();
|
||||
|
||||
auto extractStridedSliceOp =
|
||||
extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
|
||||
extractOp.getSource().getDefiningOp<vector::ExtractStridedSliceOp>();
|
||||
if (!extractStridedSliceOp)
|
||||
return Value();
|
||||
|
||||
@ -1896,7 +1896,7 @@ static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
|
||||
assert(extractedPos.size() >= sliceOffsets.size());
|
||||
for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
|
||||
extractedPos[i] = extractedPos[i] + sliceOffsets[i];
|
||||
extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
|
||||
extractOp.getSourceMutable().assign(extractStridedSliceOp.getSource());
|
||||
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
@ -1914,7 +1914,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
|
||||
llvm::isa<VectorType>(extractOp.getType())
|
||||
? llvm::cast<VectorType>(extractOp.getType()).getRank()
|
||||
: 0;
|
||||
auto insertOp = extractOp.getVector().getDefiningOp<InsertStridedSliceOp>();
|
||||
auto insertOp = extractOp.getSource().getDefiningOp<InsertStridedSliceOp>();
|
||||
if (!insertOp)
|
||||
return Value();
|
||||
|
||||
@ -1966,7 +1966,7 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) {
|
||||
insertRankDiff))
|
||||
return Value();
|
||||
}
|
||||
extractOp.getVectorMutable().assign(insertOp.getValueToStore());
|
||||
extractOp.getSourceMutable().assign(insertOp.getValueToStore());
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
OpBuilder b(extractOp.getContext());
|
||||
extractOp.setStaticPosition(offsetDiffs);
|
||||
@ -1991,7 +1991,7 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
|
||||
return {};
|
||||
|
||||
// Look for extract(from_elements).
|
||||
auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
|
||||
auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
|
||||
if (!fromElementsOp)
|
||||
return {};
|
||||
|
||||
@ -2142,20 +2142,20 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
|
||||
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
|
||||
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
|
||||
// mismatch).
|
||||
if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
|
||||
return getVector();
|
||||
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
|
||||
if (getNumIndices() == 0 && getSource().getType() == getResult().getType())
|
||||
return getSource();
|
||||
if (auto res = foldPoisonSrcExtractOp(adaptor.getSource()))
|
||||
return res;
|
||||
// Fold `arith.constant` indices into the `vector.extract` operation.
|
||||
// Do not stop here as this fold may enable subsequent folds that require
|
||||
// constant indices.
|
||||
SmallVector<Value> operands = {getVector()};
|
||||
SmallVector<Value> operands = {getSource()};
|
||||
auto inplaceFolded = extractInsertFoldConstantOp(*this, adaptor, operands);
|
||||
|
||||
if (auto res = foldPoisonIndexInsertExtractOp(
|
||||
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
|
||||
return res;
|
||||
if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
|
||||
if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getSource()))
|
||||
return res;
|
||||
if (succeeded(foldExtractOpFromExtractChain(*this)))
|
||||
return getResult();
|
||||
@ -2187,7 +2187,7 @@ public:
|
||||
LogicalResult matchAndRewrite(ExtractOp extractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Operation *defOp = extractOp.getVector().getDefiningOp();
|
||||
Operation *defOp = extractOp.getSource().getDefiningOp();
|
||||
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
|
||||
if (!defOp || !isBroadcastLike(defOp) || !outType)
|
||||
return failure();
|
||||
@ -2210,7 +2210,7 @@ public:
|
||||
LogicalResult matchAndRewrite(ExtractOp extractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto createMaskOp =
|
||||
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
|
||||
extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
|
||||
if (!createMaskOp)
|
||||
return failure();
|
||||
|
||||
@ -2271,7 +2271,7 @@ public:
|
||||
// does not change.
|
||||
LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
|
||||
PatternRewriter &rewriter) {
|
||||
auto castOp = extractOp.getVector().getDefiningOp<ShapeCastOp>();
|
||||
auto castOp = extractOp.getSource().getDefiningOp<ShapeCastOp>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
|
||||
@ -2306,7 +2306,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
|
||||
return failure();
|
||||
|
||||
// Look for extracts from a from_elements op.
|
||||
auto fromElementsOp = extractOp.getVector().getDefiningOp<FromElementsOp>();
|
||||
auto fromElementsOp = extractOp.getSource().getDefiningOp<FromElementsOp>();
|
||||
if (!fromElementsOp)
|
||||
return failure();
|
||||
VectorType inputType = fromElementsOp.getType();
|
||||
@ -2558,8 +2558,8 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
|
||||
// Check condition (i) by checking that all elements have the same source
|
||||
// as the first element.
|
||||
if (insertIndex == 0) {
|
||||
source = extractOp.getVector();
|
||||
} else if (extractOp.getVector() != source) {
|
||||
source = extractOp.getSource();
|
||||
} else if (extractOp.getSource() != source) {
|
||||
return rewriter.notifyMatchFailure(fromElements,
|
||||
"element from different vector");
|
||||
}
|
||||
@ -4095,7 +4095,7 @@ foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
|
||||
ArrayAttr extractOffsets = op.getOffsets();
|
||||
ArrayAttr extractStrides = op.getStrides();
|
||||
ArrayAttr extractSizes = op.getSizes();
|
||||
auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
|
||||
auto insertOp = op.getSource().getDefiningOp<InsertStridedSliceOp>();
|
||||
while (insertOp) {
|
||||
if (op.getSourceVectorType().getRank() !=
|
||||
insertOp.getSourceVectorType().getRank())
|
||||
@ -4199,17 +4199,17 @@ foldExtractStridedSliceNonSplatConstant(ExtractStridedSliceOp op,
|
||||
|
||||
OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) {
|
||||
if (getSourceVectorType() == getResult().getType())
|
||||
return getVector();
|
||||
return getSource();
|
||||
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
|
||||
return getResult();
|
||||
|
||||
// ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
|
||||
if (auto splat =
|
||||
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector()))
|
||||
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
|
||||
DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>());
|
||||
|
||||
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp.
|
||||
return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getVector());
|
||||
return foldExtractStridedSliceNonSplatConstant(*this, adaptor.getSource());
|
||||
}
|
||||
|
||||
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
|
||||
@ -4241,7 +4241,7 @@ public:
|
||||
// Return if 'extractStridedSliceOp' operand is not defined by a
|
||||
// CreateMaskOp.
|
||||
auto createMaskOp =
|
||||
extractStridedSliceOp.getVector().getDefiningOp<CreateMaskOp>();
|
||||
extractStridedSliceOp.getSource().getDefiningOp<CreateMaskOp>();
|
||||
if (!createMaskOp)
|
||||
return failure();
|
||||
// Return if 'extractStridedSliceOp' has non-unit strides.
|
||||
@ -4298,7 +4298,7 @@ public:
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Return if 'extractStridedSliceOp' operand is not defined by a
|
||||
// ConstantMaskOp.
|
||||
auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
|
||||
auto *defOp = extractStridedSliceOp.getSource().getDefiningOp();
|
||||
auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
|
||||
if (!constantMaskOp)
|
||||
return failure();
|
||||
@ -4351,7 +4351,7 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
|
||||
auto broadcast = op.getSource().getDefiningOp<BroadcastOp>();
|
||||
if (!broadcast)
|
||||
return failure();
|
||||
auto srcVecType =
|
||||
@ -4403,7 +4403,7 @@ public:
|
||||
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Value splat = getScalarSplatSource(op.getVector());
|
||||
Value splat = getScalarSplatSource(op.getSource());
|
||||
if (!splat)
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
|
||||
|
||||
@ -1341,7 +1341,7 @@ struct WarpOpExtractStridedSlice : public WarpDistributionPattern {
|
||||
VectorType::get(newDistributedShape, distributedType.getElementType());
|
||||
SmallVector<size_t> newRetIndices;
|
||||
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
|
||||
rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
|
||||
rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
|
||||
newRetIndices);
|
||||
rewriter.setInsertionPointAfter(newWarpOp);
|
||||
SmallVector<Attribute> distributedSizes = llvm::map_to_vector(
|
||||
@ -1395,7 +1395,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
|
||||
// the 1d case).
|
||||
SmallVector<size_t> newRetIndices;
|
||||
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
|
||||
rewriter, warpOp, {extractOp.getVector()},
|
||||
rewriter, warpOp, {extractOp.getSource()},
|
||||
{extractOp.getSourceVectorType()}, newRetIndices);
|
||||
rewriter.setInsertionPointAfter(newWarpOp);
|
||||
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
|
||||
@ -1424,7 +1424,7 @@ struct WarpOpExtract : public WarpDistributionPattern {
|
||||
VectorType::get(newDistributedShape, distributedType.getElementType());
|
||||
SmallVector<size_t> newRetIndices;
|
||||
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
|
||||
rewriter, warpOp, {extractOp.getVector()}, {newDistributedType},
|
||||
rewriter, warpOp, {extractOp.getSource()}, {newDistributedType},
|
||||
newRetIndices);
|
||||
rewriter.setInsertionPointAfter(newWarpOp);
|
||||
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
|
||||
@ -1478,7 +1478,7 @@ struct WarpOpExtractScalar : public WarpDistributionPattern {
|
||||
distributedVecType = extractSrcType;
|
||||
}
|
||||
// Yield source vector and position (if present) from warp op.
|
||||
SmallVector<Value> additionalResults{extractOp.getVector()};
|
||||
SmallVector<Value> additionalResults{extractOp.getSource()};
|
||||
SmallVector<Type> additionalResultTypes{distributedVecType};
|
||||
additionalResults.append(
|
||||
SmallVector<Value>(extractOp.getDynamicPosition()));
|
||||
|
||||
@ -78,7 +78,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
|
||||
Location loc = extractOp.getLoc();
|
||||
|
||||
Value newSrcVector = vector::ExtractOp::create(
|
||||
rewriter, loc, extractOp.getVector(), splatZero(dropCount));
|
||||
rewriter, loc, extractOp.getSource(), splatZero(dropCount));
|
||||
|
||||
// The offsets/sizes/strides attribute can have a less number of elements
|
||||
// than the input vector's rank: it is meant for the leading dimensions.
|
||||
|
||||
@ -94,7 +94,7 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
|
||||
!isa<arith::ConstantOp, vector::CreateMaskOp, vector::ConstantMaskOp>(
|
||||
maskOp)) {
|
||||
if (auto extractOp = dyn_cast<vector::ExtractOp>(maskOp)) {
|
||||
maskOp = extractOp.getVector().getDefiningOp();
|
||||
maskOp = extractOp.getSource().getDefiningOp();
|
||||
extractOps.push_back(extractOp);
|
||||
}
|
||||
}
|
||||
|
||||
@ -213,8 +213,8 @@ public:
|
||||
for (int64_t off = offset, e = offset + size * stride; off < e;
|
||||
off += stride)
|
||||
offsets.push_back(off);
|
||||
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
|
||||
op.getVector(), offsets);
|
||||
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getSource(),
|
||||
op.getSource(), offsets);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -250,7 +250,7 @@ public:
|
||||
SmallVector<Value> elements;
|
||||
elements.reserve(size);
|
||||
for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
|
||||
elements.push_back(ExtractOp::create(rewriter, loc, op.getVector(), i));
|
||||
elements.push_back(ExtractOp::create(rewriter, loc, op.getSource(), i));
|
||||
|
||||
Value result = arith::ConstantOp::create(
|
||||
rewriter, loc, rewriter.getZeroAttr(op.getType()));
|
||||
@ -306,7 +306,7 @@ public:
|
||||
Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
|
||||
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
|
||||
off += stride, ++idx) {
|
||||
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
|
||||
Value one = ExtractOp::create(rewriter, loc, op.getSource(), off);
|
||||
Value extracted = ExtractStridedSliceOp::create(
|
||||
rewriter, loc, one, getI64SubArray(op.getOffsets(), /* dropFront=*/1),
|
||||
getI64SubArray(op.getSizes(), /* dropFront=*/1),
|
||||
|
||||
@ -252,7 +252,7 @@ struct LinearizeVectorExtractStridedSlice final
|
||||
SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
|
||||
outputShape, inputShape, offsets.value());
|
||||
|
||||
Value srcVector = adaptor.getVector();
|
||||
Value srcVector = adaptor.getSource();
|
||||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
|
||||
extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
|
||||
return success();
|
||||
@ -438,8 +438,8 @@ struct LinearizeVectorExtract final
|
||||
return rewriter.notifyMatchFailure(extractOp,
|
||||
"dynamic position is not supported.");
|
||||
|
||||
llvm::ArrayRef<int64_t> shape = extractOp.getVector().getType().getShape();
|
||||
int64_t size = extractOp.getVector().getType().getNumElements();
|
||||
llvm::ArrayRef<int64_t> shape = extractOp.getSource().getType().getShape();
|
||||
int64_t size = extractOp.getSource().getType().getNumElements();
|
||||
|
||||
// Compute linearized offset.
|
||||
int64_t linearizedOffset = 0;
|
||||
@ -449,7 +449,7 @@ struct LinearizeVectorExtract final
|
||||
linearizedOffset += offsets[i] * size;
|
||||
}
|
||||
|
||||
Value srcVector = adaptor.getVector();
|
||||
Value srcVector = adaptor.getSource();
|
||||
if (!isa<VectorType>(extractOp.getType())) {
|
||||
// Scalar case: generate a 1-D extract.
|
||||
Value result = rewriter.createOrFold<vector::ExtractOp>(
|
||||
|
||||
@ -1007,7 +1007,7 @@ public:
|
||||
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Match phase.
|
||||
auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
|
||||
auto xferOp = extractOp.getSource().getDefiningOp<vector::TransferReadOp>();
|
||||
if (!xferOp)
|
||||
return failure();
|
||||
// Check that we are extracting a scalar and not a sub-vector.
|
||||
|
||||
@ -576,7 +576,7 @@ struct BubbleDownVectorBitCastForExtract
|
||||
if (extractOp.getSourceVectorType().getRank() != 1)
|
||||
return failure();
|
||||
|
||||
auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
|
||||
auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
|
||||
@ -647,7 +647,7 @@ struct BubbleDownBitCastForStridedSliceExtract
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto castOp = extractOp.getVector().getDefiningOp<vector::BitCastOp>();
|
||||
auto castOp = extractOp.getSource().getDefiningOp<vector::BitCastOp>();
|
||||
if (!castOp)
|
||||
return failure();
|
||||
|
||||
@ -1135,7 +1135,7 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ExtractOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *eltwise = op.getVector().getDefiningOp();
|
||||
Operation *eltwise = op.getSource().getDefiningOp();
|
||||
|
||||
// TODO: vector::FMAOp is not an ElemetwiseMappable even if it claims to be,
|
||||
// as it doesn't support scalars.
|
||||
@ -1210,7 +1210,7 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(vector::ExtractOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
|
||||
auto loadOp = op.getSource().getDefiningOp<vector::LoadOp>();
|
||||
if (!loadOp)
|
||||
return rewriter.notifyMatchFailure(op, "expected a load op");
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user