[mlir][vector] Standardize base Naming Across Vector Ops (NFC) (#137859)

This change standardizes the naming convention for the argument
representing the value to read from or write to in Vector ops that
interface with Tensors or MemRefs. Specifically, it ensures that all
such ops use the name `base` (i.e., the base address or location to
which offsets are applied).

Updated operations:
* `vector.transfer_read`,
* `vector.transfer_write`.

For reference, these ops already use `base`:
* `vector.load`, `vector.store`, `vector.scatter`, `vector.gather`,
  `vector.expandload`, `vector.compressstore`, `vector.maskedstore`,
  `vector.maskedload`.

This is a non-functional change (NFC) and does not alter the semantics
of these operations. However, it does require users of the xfer ops to
switch from `op.getSource()` to `op.getBase()`. To ease the transition,
this PR temporarily adds a `getSource()` interface method for
compatibility. This is intended for downstream use only and should not
be relied on upstream; the method will be removed prior to the LLVM 21
release.

Implements #131602
This commit is contained in:
parent 5a1edf0f51
commit c45cc3e420
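For downstream users the migration is mechanical: replace getSource() with
getBase() on the two xfer ops. A minimal, hypothetical sketch (the helper
below is illustrative only and not part of this commit):

  // During the transition window, readOp.getSource() still compiles via the
  // deprecated interface wrapper shown in the diff below, but it emits a
  // -Wdeprecated-declarations warning and will be removed before LLVM 21.
  #include "mlir/Dialect/Vector/IR/VectorOps.h"

  static mlir::Value getXferBase(mlir::vector::TransferReadOp readOp) {
    return readOp.getBase(); // was: readOp.getSource()
  }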
@@ -1273,7 +1273,7 @@ def Vector_TransferReadOp :
       AttrSizedOperandSegments,
       DestinationStyleOpInterface
     ]>,
-    Arguments<(ins AnyShaped:$source,
+    Arguments<(ins AnyShaped:$base,
                    Variadic<Index>:$indices,
                    AffineMapAttr:$permutation_map,
                    AnyType:$padding,
@@ -1522,7 +1522,7 @@ def Vector_TransferWriteOp :
       DestinationStyleOpInterface
     ]>,
     Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
-                   AnyShaped:$source,
+                   AnyShaped:$base,
                    Variadic<Index>:$indices,
                    AffineMapAttr:$permutation_map,
                    Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
@@ -1663,7 +1663,7 @@ def Vector_TransferWriteOp :
     /// ops of other dialects.
     Value getValue() { return getVector(); }
 
-    MutableOperandRange getDpsInitsMutable() { return getSourceMutable(); }
+    MutableOperandRange getDpsInitsMutable() { return getBaseMutable(); }
   }];
 
   let hasFolder = 1;
@@ -111,7 +111,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
         TODO: Change name of operand, which is not accurate for xfer_write.
       }],
       /*retTy=*/"::mlir::Value",
-      /*methodName=*/"getSource",
+      /*methodName=*/"getBase",
       /*args=*/(ins)
     >,
     InterfaceMethod<
@@ -187,6 +187,12 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
       return inBounds;
    }
 
+    /// Wrapper for getBase, which replaced getSource.
+    [[deprecated("Use getBase instead!")]]
+    ::mlir::Value getSource() {
+      return $_op.getBase();
+    }
+
    /// Return the number of leading shaped dimensions (of the "source" operand)
    /// that do not participate in the permutation map.
    unsigned getLeadingShapedRank() {
@@ -203,7 +209,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
 
    /// Return the shaped type of the "source" operand value.
    ::mlir::ShapedType getShapedType() {
-      return ::llvm::cast<::mlir::ShapedType>($_op.getSource().getType());
+      return ::llvm::cast<::mlir::ShapedType>($_op.getBase().getType());
    }
 
    /// Return the number of dimensions that participate in the permutation map.
@@ -58,7 +58,7 @@ struct TransferReadToArmSMELowering
       return rewriter.notifyMatchFailure(transferReadOp,
                                          "not a valid vector type for SME");
 
-    if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
+    if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType()))
       return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
 
     // Out-of-bounds dims are not supported.
@@ -84,7 +84,7 @@ struct TransferReadToArmSMELowering
     auto mask = transferReadOp.getMask();
     auto padding = mask ? transferReadOp.getPadding() : nullptr;
     rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
-        transferReadOp, vectorType, transferReadOp.getSource(),
+        transferReadOp, vectorType, transferReadOp.getBase(),
        transferReadOp.getIndices(), padding, mask, layout);
 
    return success();
@@ -128,7 +128,7 @@ struct TransferWriteToArmSMELowering
     if (!arm_sme::isValidSMETileVectorType(vType))
       return failure();
 
-    if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
+    if (!llvm::isa<MemRefType>(writeOp.getBase().getType()))
       return failure();
 
     // Out-of-bounds dims are not supported.
@@ -149,7 +149,7 @@ struct TransferWriteToArmSMELowering
                       : arm_sme::TileSliceLayout::Horizontal;
 
     rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
-        writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
+        writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(),
        writeOp.getMask(), layout);
    return success();
  }
@@ -686,7 +686,7 @@ struct FoldTransferWriteOfExtractTileSlice
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                 PatternRewriter &rewriter) const final {
-    if (!isa<MemRefType>(writeOp.getSource().getType()))
+    if (!isa<MemRefType>(writeOp.getBase().getType()))
       return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
 
     if (writeOp.hasOutOfBoundsDim())
@@ -713,7 +713,7 @@ struct FoldTransferWriteOfExtractTileSlice
 
     rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
         writeOp, extractTileSlice.getTile(),
-        extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
+        extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(),
        writeOp.getIndices(), extractTileSlice.getLayout());
    return success();
  }
@@ -486,7 +486,7 @@ struct CombineTransferReadOpTranspose final
     Value result =
         rewriter
             .create<vector::TransferReadOp>(
-                loc, resultType, transferReadOp.getSource(),
+                loc, resultType, transferReadOp.getBase(),
                 transferReadOp.getIndices(), AffineMapAttr::get(newMap),
                 transferReadOp.getPadding(), transferReadOp.getMask(),
                 transferReadOp.getInBoundsAttr())
@@ -581,7 +581,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   gpu::MMAMatrixType type =
       gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
   Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
-      op.getLoc(), type, op.getSource(), op.getIndices(),
+      op.getLoc(), type, op.getBase(), op.getIndices(),
       rewriter.getIndexAttr(*stride),
       isTranspose ? rewriter.getUnitAttr() : UnitAttr());
   valueMapping[mappingResult] = load;
@@ -612,7 +612,7 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
 
   Value matrix = it->second;
   auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
-      op.getLoc(), matrix, op.getSource(), op.getIndices(),
+      op.getLoc(), matrix, op.getBase(), op.getIndices(),
       rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
   (void)store;
 
@@ -759,7 +759,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
                                   indices);
 
   nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
-      loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
+      loc, vectorType, op.getBase(), indices, *transpose, params->numTiles);
   valueMapping[op] = newOp->getResult(0);
   return success();
 }
@@ -818,7 +818,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
           rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
 
       Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
-                                                 op.getSource(), newIndices);
+                                                 op.getBase(), newIndices);
       result = rewriter.create<vector::InsertOp>(loc, el, result, i);
     }
   } else {
@@ -841,7 +841,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
         getXferIndices<vector::TransferReadOp>(
             rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
         Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
-                                                   op.getSource(), newIndices);
+                                                   op.getBase(), newIndices);
         result = rewriter.create<vector::InsertOp>(
             op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
       }
@@ -875,7 +875,7 @@ convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
 
   bool isLdMatrixCompatible =
-      isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
+      isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
       nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
 
   VectorType vecTy = op.getVectorType();
@@ -933,7 +933,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
     SmallVector<Value, 4> newIndices;
     getXferIndices<vector::TransferWriteOp>(
         rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
-    rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
+    rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices);
   }
 
   LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
@@ -198,8 +198,7 @@ static Value generateInBoundsCheck(
   Location loc = xferOp.getLoc();
   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
-    Value memrefDim =
-        vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
+    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
     AffineExpr d0, d1;
     bindDims(xferOp.getContext(), d0, d1);
     Value base = xferOp.getIndices()[*dim];
@@ -426,7 +425,7 @@ struct Strategy<TransferReadOp> {
     auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
     auto newXferOp = b.create<vector::TransferReadOp>(
-        loc, vecType, xferOp.getSource(), xferIndices,
+        loc, vecType, xferOp.getBase(), xferIndices,
         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
         xferOp.getPadding(), Value(), inBoundsAttr);
 
@@ -512,7 +511,7 @@ struct Strategy<TransferWriteOp> {
     Location loc = xferOp.getLoc();
     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
-    auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
+    auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
     auto newXferOp = b.create<vector::TransferWriteOp>(
         loc, type, vec, source, xferIndices,
@@ -544,7 +543,7 @@ struct Strategy<TransferWriteOp> {
 
   /// Return the initial loop state for the generated scf.for loop.
   static Value initialLoopState(TransferWriteOp xferOp) {
-    return isTensorOp(xferOp) ? xferOp.getSource() : Value();
+    return isTensorOp(xferOp) ? xferOp.getBase() : Value();
   }
 };
 
@@ -1145,7 +1144,7 @@ struct ScalableTransposeTransferWriteConversion
           ArrayRef<OpFoldResult>(*maskDims).drop_front());
     }
 
-    Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
+    Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
     ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
     auto result = rewriter.create<scf::ForOp>(
         loc, lb, ub, step, initLoopArgs,
@@ -1165,7 +1164,7 @@ struct ScalableTransposeTransferWriteConversion
 
           // Create the transfer_write for the slice.
           Value dest =
-              loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
+              loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
           auto newWriteOp = b.create<vector::TransferWriteOp>(
               loc, sliceVec, dest, xferIndices,
               ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
@@ -1340,7 +1339,7 @@ struct UnrollTransferReadConversion
 
       auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
       auto newXferOp = b.create<vector::TransferReadOp>(
-          loc, newXferVecType, xferOp.getSource(), xferIndices,
+          loc, newXferVecType, xferOp.getBase(), xferIndices,
          AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
          xferOp.getPadding(), Value(), inBoundsAttr);
      maybeAssignMask(b, xferOp, newXferOp, i);
@@ -1449,7 +1448,7 @@ struct UnrollTransferWriteConversion
     }
 
     int64_t dimSize = inputVectorTy.getShape()[0];
-    Value source = xferOp.getSource(); // memref or tensor to be written to.
+    Value source = xferOp.getBase(); // memref or tensor to be written to.
     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
 
     // Generate fully unrolled loop of transfer ops.
@@ -1567,8 +1566,7 @@ struct Strategy1d<TransferReadOp> {
         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
         /*inBoundsCase=*/
         [&](OpBuilder &b, Location loc) {
-          Value val =
-              b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
+          Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
         },
         /*outOfBoundsCase=*/
@@ -1599,7 +1597,7 @@ struct Strategy1d<TransferWriteOp> {
         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
           auto val =
               b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
-          b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
+          b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
         });
     b.create<scf::YieldOp>(loc);
   }
@@ -192,7 +192,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
 
     xegpu::CreateNdDescOp ndDesc =
         createNdDescriptor(rewriter, loc, descType,
-                           dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
+                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                            readOp.getIndices());
 
     DenseI64ArrayAttr transposeAttr =
@@ -231,9 +231,9 @@ struct TransferWriteLowering
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
         xegpu::MemorySpace::Global);
-    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
-        rewriter, loc, descType,
-        dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
+    xegpu::CreateNdDescOp ndDesc =
+        createNdDescriptor(rewriter, loc, descType,
+                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                            writeOp.getIndices());
 
     // By default, no specific caching policy is assigned.
@@ -118,7 +118,7 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
                                                readOp.getPadding());
   Value load = builder.create<vector::LoadOp>(
-      loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+      loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices());
   Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
                                               readOp.getMask(), load, fill);
   // Insert a broadcasting op if required.
@@ -149,7 +149,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }
 
     Location loc = readOp.getLoc();
-    Value src = readOp.getSource();
+    Value src = readOp.getBase();
 
     VectorType vectorType = readOp.getVectorType();
     int64_t vectorSize = vectorType.getNumElements();
@@ -315,7 +315,7 @@ struct LegalizeTransferReadOpsByDecomposition
          decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
       auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
       auto smeRead = rewriter.create<vector::TransferReadOp>(
-          loc, smeTileType, readOp.getSource(),
+          loc, smeTileType, readOp.getBase(),
           getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
           readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
           readOp.getInBoundsAttr());
@@ -359,7 +359,7 @@ struct LegalizeTransferWriteOpsByDecomposition
     auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
     auto inputSMETiles = adaptor.getValueToStore();
 
-    Value destTensorOrMemref = writeOp.getSource();
+    Value destTensorOrMemref = writeOp.getBase();
     for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
              rewriter, vectorType, smeTileType, transposed))) {
       auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
@@ -497,7 +497,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
           auto slice =
               rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
           rewriter.create<vector::TransferWriteOp>(
-              loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
+              loc, slice, writeOp.getBase(), ValueRange{storeRow, storeCol},
              AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
              sliceMask,
              rewriter.getBoolArrayAttr(
@@ -677,7 +677,7 @@ struct LiftIllegalVectorTransposeToMemory
         });
     SmallVector<Value> strides(readType.getRank(), Value(one));
     auto readSubview = rewriter.create<memref::SubViewOp>(
-        loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
+        loc, illegalRead.getBase(), illegalRead.getIndices(), readSizes,
        strides);
 
    // Apply the transpose to all values/attributes of the transfer_read:
@@ -851,7 +851,7 @@ struct LowerIllegalTransposeStoreViaZA
 
     // Note: We need to use `get_tile` as there's no vector-level `undef`.
     Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
-    Value destTensorOrMemref = writeOp.getSource();
+    Value destTensorOrMemref = writeOp.getBase();
     auto numSlicesPerTile =
         std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
     auto numSlices =
@@ -171,7 +171,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
 
 static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
                                 LoopLikeOpInterface loop) {
-  Value source = transferRead.getSource();
+  Value source = transferRead.getBase();
 
   // Skip view-like Ops and retrive the actual soruce Operation
   while (auto srcOp =
@@ -276,7 +276,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
     for (auto *sliceOp : llvm::reverse(forwardSlice)) {
       auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
       if (!candidateWrite ||
-          candidateWrite.getSource() != transferRead.getSource())
+          candidateWrite.getBase() != transferRead.getBase())
         continue;
       transferWrite = candidateWrite;
     }
@@ -312,11 +312,11 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
         transferRead.getPermutationMap() != transferWrite.getPermutationMap())
       return WalkResult::advance();
 
-    auto *source = transferRead.getSource().getDefiningOp();
+    auto *source = transferRead.getBase().getDefiningOp();
     if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
       return WalkResult::advance();
 
-    source = transferWrite.getSource().getDefiningOp();
+    source = transferWrite.getBase().getDefiningOp();
     if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
       return WalkResult::advance();
 
@@ -325,7 +325,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
     DominanceInfo dom(loop);
     if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
       return WalkResult::advance();
-    for (auto &use : transferRead.getSource().getUses()) {
+    for (auto &use : transferRead.getBase().getUses()) {
       if (!loop->isAncestor(use.getOwner()))
         continue;
       if (use.getOwner() == transferRead.getOperation() ||
@@ -2627,7 +2627,7 @@ struct PadOpVectorizationWithTransferReadPattern
     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
     xferOp->setAttr(xferOp.getInBoundsAttrName(),
                     rewriter.getBoolArrayAttr(inBounds));
-    xferOp.getSourceMutable().assign(padOp.getSource());
+    xferOp.getBaseMutable().assign(padOp.getSource());
     xferOp.getPaddingMutable().assign(padValue);
   });
 
@@ -3114,7 +3114,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
 
   // Transfer into `view`.
-  Value viewOrAlloc = xferOp.getSource();
+  Value viewOrAlloc = xferOp.getBase();
   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
@@ -3191,7 +3191,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
     return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
 
   // Transfer into `viewOrAlloc`.
-  Value viewOrAlloc = xferOp.getSource();
+  Value viewOrAlloc = xferOp.getBase();
   if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
       !viewOrAlloc.getDefiningOp<memref::AllocOp>())
     return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
@@ -119,7 +119,7 @@ static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
 template <typename TransferLikeOp>
 static FailureOr<Value>
 getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
-  Value src = transferLikeOp.getSource();
+  Value src = transferLikeOp.getBase();
   if (isa<MemRefType>(src.getType()))
     return src;
   return failure();
@@ -224,7 +224,7 @@ static Value getMemRefOperand(LoadOrStoreOpTy op) {
 }
 
 static Value getMemRefOperand(vector::TransferReadOp op) {
-  return op.getSource();
+  return op.getBase();
 }
 
 static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
@@ -240,7 +240,7 @@ static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
 static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
 
 static Value getMemRefOperand(vector::TransferWriteOp op) {
-  return op.getSource();
+  return op.getBase();
 }
 
 static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
@@ -172,7 +172,7 @@ static Value getValueLoadedFromGlobal(Operation *op) {
   if (!load)
     return nullptr;
 
-  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
+  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
   if (!loadType || !hasDefaultMemorySpace(loadType))
     return nullptr;
   return load;
@@ -185,7 +185,7 @@ static bool isStoreToShared(Operation *op, Value v) {
   if (!store || store.getVector() != v)
     return false;
 
-  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
+  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
   return storeType || hasSharedMemorySpace(storeType);
 }
 
@@ -71,9 +71,9 @@ Value nvgpu::getMemrefOperand(Operation *op) {
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
     return storeOp.getMemref();
   if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
-    return transferWrite.getSource();
+    return transferWrite.getBase();
   if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
-    return transferRead.getSource();
+    return transferRead.getBase();
   if (auto storeOp = dyn_cast<vector::StoreOp>(op))
     return storeOp.getBase();
   if (auto loadOp = dyn_cast<vector::LoadOp>(op))
@@ -285,7 +285,7 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
   // information to ensure correctness of downstream assumptions. It is possible
   // to enable this if caller can assert that tensor will be lowered in a
   // particular manner.
-  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
   if (!sourceType)
     return false;
 
@@ -309,7 +309,7 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
     return false;
   // Currently we can't support reads on tensor types because we need stride
   // information to ensure correctness of downstream assumptions.
-  auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
+  auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
   if (!sourceType)
     return false;
 
@@ -36,7 +36,7 @@ namespace tensor {
 using namespace mlir;
 
 static Value getTensorOperand(vector::TransferReadOp op) {
-  return op.getSource();
+  return op.getBase();
 }
 
 static Value getTensorOperand(tensor::InsertSliceOp op) {
@@ -314,7 +314,7 @@ bool mlir::vector::isDisjointTransferIndices(
 bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
                                          VectorTransferOpInterface transferB,
                                          bool testDynamicValueUsingBounds) {
-  if (transferA.getSource() != transferB.getSource())
+  if (transferA.getBase() != transferB.getBase())
     return false;
   return isDisjointTransferIndices(transferA, transferB,
                                    testDynamicValueUsingBounds);
@@ -4205,7 +4205,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
 }
 
 void TransferReadOp::print(OpAsmPrinter &p) {
-  p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
+  p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
   if (getMask())
     p << ", " << getMask();
   printTransferAttrs(p, *this);
@@ -4464,7 +4464,7 @@ static LogicalResult foldTransferFullMask(TransferOp op) {
 static Value foldRAW(TransferReadOp readOp) {
   if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
     return {};
-  auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+  auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
   while (defWrite) {
     if (checkSameValueRAW(defWrite, readOp))
       return defWrite.getVector();
@@ -4472,7 +4472,7 @@ static Value foldRAW(TransferReadOp readOp) {
             cast<VectorTransferOpInterface>(defWrite.getOperation()),
             cast<VectorTransferOpInterface>(readOp.getOperation())))
       break;
-    defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
+    defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
   }
   return {};
 }
@@ -4500,7 +4500,7 @@ void TransferReadOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
   if (llvm::isa<MemRefType>(getShapedType()))
-    effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
+    effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
                          SideEffects::DefaultResource::get());
 }
 
@@ -4542,7 +4542,7 @@ struct TransferReadAfterWriteToBroadcast
     if (readOp.hasOutOfBoundsDim() ||
         !llvm::isa<RankedTensorType>(readOp.getShapedType()))
       return failure();
-    auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+    auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     if (!defWrite)
       return failure();
     // TODO: If the written transfer chunk is a superset of the read transfer
@@ -4727,7 +4727,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
 }
 
 void TransferWriteOp::print(OpAsmPrinter &p) {
-  p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
+  p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
   if (getMask())
     p << ", " << getMask();
   printTransferAttrs(p, *this);
@@ -4806,7 +4806,7 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
   if (write.getTransferRank() == 0)
     return failure();
   auto rankedTensorType =
-      llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
+      llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
   // If not operating on tensors, bail.
   if (!rankedTensorType)
     return failure();
@@ -4828,7 +4828,7 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
   if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
     return failure();
   // Tensor types must be the same.
-  if (read.getSource().getType() != rankedTensorType)
+  if (read.getBase().getType() != rankedTensorType)
     return failure();
   // Vector types must be the same.
   if (read.getVectorType() != write.getVectorType())
@@ -4845,13 +4845,13 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
       llvm::any_of(write.getIndices(), isNotConstantZero))
     return failure();
   // Success.
-  results.push_back(read.getSource());
+  results.push_back(read.getBase());
   return success();
 }
 
 static bool checkSameValueWAR(vector::TransferReadOp read,
                               vector::TransferWriteOp write) {
-  return read.getSource() == write.getSource() &&
+  return read.getBase() == write.getBase() &&
          read.getIndices() == write.getIndices() &&
         read.getPermutationMap() == write.getPermutationMap() &&
         read.getVectorType() == write.getVectorType() && !read.getMask() &&
@@ -4873,7 +4873,7 @@ static bool checkSameValueWAR(vector::TransferReadOp read,
 /// ```
 static LogicalResult foldWAR(TransferWriteOp write,
                              SmallVectorImpl<OpFoldResult> &results) {
-  if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
+  if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
     return failure();
   auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
   if (!read)
@@ -4881,7 +4881,7 @@ static LogicalResult foldWAR(TransferWriteOp write,
 
   if (!checkSameValueWAR(read, write))
     return failure();
-  results.push_back(read.getSource());
+  results.push_back(read.getBase());
   return success();
 }
 
@@ -4953,12 +4953,11 @@ public:
       return failure();
     vector::TransferWriteOp writeToModify = writeOp;
 
-    auto defWrite =
-        writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+    auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
     while (defWrite) {
       if (checkSameValueWAW(writeOp, defWrite)) {
         rewriter.modifyOpInPlace(writeToModify, [&]() {
-          writeToModify.getSourceMutable().assign(defWrite.getSource());
+          writeToModify.getBaseMutable().assign(defWrite.getBase());
        });
        return success();
      }
@@ -4971,7 +4970,7 @@ public:
       if (!defWrite->hasOneUse())
         break;
       writeToModify = defWrite;
-      defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
+      defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
     }
     return failure();
   }
@@ -52,7 +52,7 @@ struct TransferReadOpInterface
     auto readOp = cast<vector::TransferReadOp>(op);
     assert(isa<TensorType>(readOp.getShapedType()) &&
            "only tensor types expected");
-    FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
+    FailureOr<Value> buffer = getBuffer(rewriter, readOp.getBase(), options);
     if (failed(buffer))
       return failure();
     replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
@@ -110,7 +110,7 @@ struct TransferWriteOpInterface
 
     // Create a new transfer_write on buffer that doesn't have a return value.
     FailureOr<Value> resultBuffer =
-        getBuffer(rewriter, writeOp.getSource(), options);
+        getBuffer(rewriter, writeOp.getBase(), options);
     if (failed(resultBuffer))
       return failure();
     rewriter.create<vector::TransferWriteOp>(
@@ -222,7 +222,7 @@ public:
 
     // Replace the `vector.mask` operation.
     rewriter.replaceOpWithNewOp<TransferReadOp>(
-        maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
+        maskingOp.getOperation(), readOp.getVectorType(), readOp.getBase(),
        readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
        maskingOp.getMask(), readOp.getInBounds());
    return success();
@@ -245,7 +245,7 @@ public:
     // Replace the `vector.mask` operation.
     rewriter.replaceOpWithNewOp<TransferWriteOp>(
         maskingOp.getOperation(), resultType, writeOp.getVector(),
-        writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
+        writeOp.getBase(), writeOp.getIndices(), writeOp.getPermutationMap(),
        maskingOp.getMask(), writeOp.getInBounds());
    return success();
  }
@@ -139,7 +139,7 @@ struct TransferReadPermutationLowering
     VectorType newReadType = VectorType::get(
         newVectorShape, op.getVectorType().getElementType(), newScalableDims);
     Value newRead = rewriter.create<vector::TransferReadOp>(
-        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
+        op.getLoc(), newReadType, op.getBase(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
         newInBoundsAttr);
 
@@ -214,7 +214,7 @@ struct TransferWritePermutationLowering
     auto newMap = AffineMap::getMinorIdentityMap(
         map.getNumDims(), map.getNumResults(), rewriter.getContext());
     auto newWrite = rewriter.create<vector::TransferWriteOp>(
-        op.getLoc(), newVec, op.getSource(), op.getIndices(),
+        op.getLoc(), newVec, op.getBase(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
     if (newWrite.hasPureTensorSemantics())
       return newWrite.getResult();
@@ -300,7 +300,7 @@ struct TransferWriteNonPermutationLowering
     }
     ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
     auto newWrite = rewriter.create<vector::TransferWriteOp>(
-        op.getLoc(), newVec, op.getSource(), op.getIndices(),
+        op.getLoc(), newVec, op.getBase(), op.getIndices(),
         AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
     if (newWrite.hasPureTensorSemantics())
       return newWrite.getResult();
@@ -371,7 +371,7 @@ struct TransferOpReduceRank
                   op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
             : ArrayAttr();
     Value newRead = rewriter.create<vector::TransferReadOp>(
-        op.getLoc(), newReadType, op.getSource(), op.getIndices(),
+        op.getLoc(), newReadType, op.getBase(), op.getIndices(),
         AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
         newInBoundsAttr);
     return rewriter
@@ -474,12 +474,12 @@ struct TransferReadToVectorLoadLowering
       Value fill = rewriter.create<vector::SplatOp>(
           read.getLoc(), unbroadcastedVectorType, read.getPadding());
       res = rewriter.create<vector::MaskedLoadOp>(
-          read.getLoc(), unbroadcastedVectorType, read.getSource(),
+          read.getLoc(), unbroadcastedVectorType, read.getBase(),
           read.getIndices(), read.getMask(), fill);
     } else {
-      res = rewriter.create<vector::LoadOp>(
-          read.getLoc(), unbroadcastedVectorType, read.getSource(),
-          read.getIndices());
+      res = rewriter.create<vector::LoadOp>(read.getLoc(),
+                                            unbroadcastedVectorType,
+                                            read.getBase(), read.getIndices());
     }
 
     // Insert a broadcasting op if required.
@@ -570,11 +570,11 @@ struct TransferWriteToVectorStoreLowering
           });
 
       rewriter.create<vector::MaskedStoreOp>(
-          write.getLoc(), write.getSource(), write.getIndices(),
-          write.getMask(), write.getVector());
+          write.getLoc(), write.getBase(), write.getIndices(), write.getMask(),
+          write.getVector());
     } else {
       rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
-                                       write.getSource(), write.getIndices());
+                                       write.getBase(), write.getIndices());
     }
     // There's no return value for StoreOps. Use Value() to signal success to
     // matchAndRewrite.
@@ -37,7 +37,7 @@ struct TransferReadOpSubsetExtractionOpInterface
     : public SubsetExtractionOpInterface::ExternalModel<
           TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
   OpOperand &getSourceOperand(Operation *op) const {
-    return cast<vector::TransferReadOp>(op).getSourceMutable();
+    return cast<vector::TransferReadOp>(op).getBaseMutable();
   }
 };
 
@@ -49,7 +49,7 @@ struct TransferWriteOpSubsetInsertionOpInterface
   }
 
   OpOperand &getDestinationOperand(Operation *op) const {
-    return cast<vector::TransferWriteOp>(op).getSourceMutable();
+    return cast<vector::TransferWriteOp>(op).getBaseMutable();
   }
 
   Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
@@ -718,7 +718,7 @@ struct WarpOpTransferRead : public WarpDistributionPattern {
     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
 
     // Source must be defined outside of the region.
-    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
+    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
       return rewriter.notifyMatchFailure(
           read, "source must be defined outside of the region");
 
@@ -802,7 +802,7 @@ struct WarpOpTransferRead : public WarpDistributionPattern {
         hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                 : Value();
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
+        read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
        read.getPermutationMapAttr(), newPadding, newMask,
        read.getInBoundsAttr());
 
@@ -230,7 +230,7 @@ struct CastAwayTransferReadLeadingOneDim
     if (read.getTransferRank() == 0)
       return failure();
 
-    auto shapedType = cast<ShapedType>(read.getSource().getType());
+    auto shapedType = cast<ShapedType>(read.getBase().getType());
     if (shapedType.getElementType() != read.getVectorType().getElementType())
       return failure();
 
@@ -260,7 +260,7 @@ struct CastAwayTransferReadLeadingOneDim
     }
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        read.getLoc(), newType, read.getSource(), read.getIndices(),
+        read.getLoc(), newType, read.getBase(), read.getIndices(),
        AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
@@ -284,7 +284,7 @@ struct CastAwayTransferWriteLeadingOneDim
     if (write.getTransferRank() == 0)
       return failure();
 
-    auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
+    auto shapedType = dyn_cast<ShapedType>(write.getBase().getType());
     if (shapedType.getElementType() != write.getVectorType().getElementType())
       return failure();
 
@@ -314,13 +314,13 @@ struct CastAwayTransferWriteLeadingOneDim
       Value newMask = dropUnitDimsFromMask(
           rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-          write, newVector, write.getSource(), write.getIndices(),
+          write, newVector, write.getBase(), write.getIndices(),
          AffineMapAttr::get(newMap), newMask, inBoundsAttr);
      return success();
    }
 
    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        write, newVector, write.getSource(), write.getIndices(),
+        write, newVector, write.getBase(), write.getIndices(),
        AffineMapAttr::get(newMap), inBoundsAttr);
    return success();
  }
@@ -1249,7 +1249,7 @@ struct ConvertVectorTransferRead final
 
     auto loc = op.getLoc();
     auto containerElemTy =
-        cast<MemRefType>(adaptor.getSource().getType()).getElementType();
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
     Type emulatedElemTy = op.getType().getElementType();
     int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
     int containerBits = containerElemTy.getIntOrFloatBitWidth();
@@ -1272,7 +1272,7 @@ struct ConvertVectorTransferRead final
         adaptor.getPadding());
 
     auto stridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
@@ -1294,7 +1294,7 @@ struct ConvertVectorTransferRead final
         emulatedPerContainerElem);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
+        loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(),
        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
        newPadding);
 
@@ -92,7 +92,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> blockingAccesses;
   Operation *firstOverwriteCandidate = nullptr;
-  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -112,8 +112,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
       if (memref::isSameViewOrTrivialAlias(
-              cast<MemrefValue>(nextWrite.getSource()),
-              cast<MemrefValue>(write.getSource())) &&
+              cast<MemrefValue>(nextWrite.getBase()),
+              cast<MemrefValue>(write.getBase())) &&
           checkSameValueWAW(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
@@ -178,7 +178,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -202,8 +202,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
               /*testDynamicValueUsingBounds=*/true))
         continue;
       if (memref::isSameViewOrTrivialAlias(
-              cast<MemrefValue>(read.getSource()),
-              cast<MemrefValue>(write.getSource())) &&
+              cast<MemrefValue>(read.getBase()),
+              cast<MemrefValue>(write.getBase())) &&
          dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
        if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
          lastwrite = write;
@@ -351,7 +351,7 @@ class TransferReadDropUnitDimsPattern
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferReadOp.getSource();
+    Value source = transferReadOp.getBase();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor types.
     if (!sourceType)
@@ -433,7 +433,7 @@ class TransferWriteDropUnitDimsPattern
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferWriteOp.getSource();
+    Value source = transferWriteOp.getBase();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor type.
     if (!sourceType)
@@ -604,7 +604,7 @@ public:
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    auto source = transferReadOp.getSource();
+    auto source = transferReadOp.getBase();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
 
     // 0. Check pre-conditions
@@ -695,7 +695,7 @@ public:
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
    VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferWriteOp.getSource();
+    Value source = transferWriteOp.getBase();
    MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
 
    // 0. Check pre-conditions
@@ -851,12 +851,12 @@ class RewriteScalarExtractElementOfTransferRead
                                      *getConstantIntValue(ofr));
       }
     }
-    if (isa<MemRefType>(xferOp.getSource().getType())) {
-      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+    if (isa<MemRefType>(xferOp.getBase().getType())) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                   newIndices);
     } else {
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-          extractOp, xferOp.getSource(), newIndices);
+          extractOp, xferOp.getBase(), newIndices);
     }
 
     return success();
@@ -899,12 +899,12 @@ class RewriteScalarExtractOfTransferRead
             extractOp.getLoc(), *getConstantIntValue(ofr));
       }
     }
-    if (isa<MemRefType>(xferOp.getSource().getType())) {
-      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+    if (isa<MemRefType>(xferOp.getBase().getType())) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                   newIndices);
     } else {
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-          extractOp, xferOp.getSource(), newIndices);
+          extractOp, xferOp.getBase(), newIndices);
     }
 
     return success();
@@ -932,12 +932,12 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
     Value scalar =
         rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
     // Construct a scalar store.
-    if (isa<MemRefType>(xferOp.getSource().getType())) {
+    if (isa<MemRefType>(xferOp.getBase().getType())) {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(
-          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
+          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
     } else {
       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
-          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
+          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
     }
     return success();
   }
@@ -58,7 +58,7 @@ static Value createInBoundsCond(RewriterBase &b,
         b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
         {xferOp.getIndices()[indicesIdx]});
     OpFoldResult dimSz =
-        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
+        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
     auto maybeCstSum = getConstantIntValue(sum);
     auto maybeCstDimSz = getConstantIntValue(dimSz);
     if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
@@ -185,7 +185,7 @@ static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
 }
 
 /// Operates under a scoped context to build the intersection between the
-/// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
+/// view `xferOp.getbase()` @ `xferOp.getIndices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
 static std::pair<Value, Value>
 createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
@@ -202,8 +202,8 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
   auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
-                                              xferOp.getSource(), indicesIdx);
+    Value dimMemRef =
+        b.create<memref::DimOp>(xferOp.getLoc(), xferOp.getBase(), indicesIdx);
     Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
     Value index = xferOp.getIndices()[indicesIdx];
     AffineExpr i, j, k;
@@ -221,9 +221,9 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
   SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
   SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
   auto copySrc = b.create<memref::SubViewOp>(
-      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
+      loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
   auto copyDest = b.create<memref::SubViewOp>(
-      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
+      loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
   return std::make_pair(copySrc, copyDest);
 }
 
@@ -252,7 +252,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                             MemRefType compatibleMemRefType, Value alloc) {
   Location loc = xferOp.getLoc();
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.getSource();
+  Value memref = xferOp.getBase();
   return b.create<scf::IfOp>(
       loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
@@ -305,7 +305,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
   Location loc = xferOp.getLoc();
   scf::IfOp fullPartialIfOp;
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.getSource();
+  Value memref = xferOp.getBase();
   return b.create<scf::IfOp>(
       loc, inBoundsCond,
      [&](OpBuilder &b, Location loc) {
@@ -352,7 +352,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                           MemRefType compatibleMemRefType, Value alloc) {
   Location loc = xferOp.getLoc();
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.getSource();
+  Value memref = xferOp.getBase();
   return b
       .create<scf::IfOp>(
           loc, inBoundsCond,
@@ -509,7 +509,7 @@ static Operation *getAutomaticAllocationScope(Operation *op) {
 ///
 /// Preconditions:
 ///  1. `xferOp.getPermutationMap()` must be a minor identity map
-///  2. the rank of the `xferOp.getSource()` and the rank of the
+///  2. the rank of the `xferOp.getBase()` and the rank of the
 ///     `xferOp.getVector()` must be equal. This will be relaxed in the future
 ///     but requires rank-reducing subviews.
 LogicalResult mlir::vector::splitFullAndPartialTransfer(
@@ -611,7 +611,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
   // The operation is cloned to prevent deleting information needed for the
   // later IR creation.
   IRMapping mapping;
-  mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
+  mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
   mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
   auto *clone = b.clone(*xferWriteOp, mapping);
   clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);
@@ -1265,7 +1265,7 @@ public:
     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
     Value off = xferOp.getIndices()[lastIndex];
     Value dim =
-        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
+        vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
     Value mask = rewriter.create<vector::CreateMaskOp>(
         loc,
@@ -1437,7 +1437,7 @@ class DropInnerMostUnitDimsTransferRead
     if (readOp.getMask())
       return failure();
 
-    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
+    auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
     if (!srcType)
       return failure();
 
@@ -1469,7 +1469,7 @@ class DropInnerMostUnitDimsTransferRead
 
     auto loc = readOp.getLoc();
     SmallVector<OpFoldResult> sizes =
-        memref::getMixedSizes(rewriter, loc, readOp.getSource());
+        memref::getMixedSizes(rewriter, loc, readOp.getBase());
     SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(srcType.getRank(),
@@ -1480,7 +1480,7 @@ class DropInnerMostUnitDimsTransferRead
     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
         readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
+        loc, resultMemrefType, readOp.getBase(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
     Value result = rewriter.create<vector::TransferReadOp>(
@@ -1527,7 +1527,7 @@ class DropInnerMostUnitDimsTransferWrite
     if (writeOp.getMask())
       return failure();
 
-    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+    auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
     if (!srcType)
       return failure();
 
@@ -1559,7 +1559,7 @@ class DropInnerMostUnitDimsTransferWrite
 
     Location loc = writeOp.getLoc();
     SmallVector<OpFoldResult> sizes =
-        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
+        memref::getMixedSizes(rewriter, loc, writeOp.getBase());
     SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(srcType.getRank(),
@@ -1571,7 +1571,7 @@ class DropInnerMostUnitDimsTransferWrite
         writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
 
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
+        loc, resultMemrefType, writeOp.getBase(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
 
@@ -164,7 +164,7 @@ struct UnrollTransferReadPattern
           sliceTransferIndices(elementOffsets, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
-          loc, targetType, readOp.getSource(), indices,
+          loc, targetType, readOp.getBase(), indices,
          readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
          readOp.getInBoundsAttr());
 
@@ -215,7 +215,7 @@ struct UnrollTransferWritePattern
           sliceTransferIndices(elementOffsets, originalIndices,
                                writeOp.getPermutationMap(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
-          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
+          loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(),
         indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
      // For the tensor case update the destination for the next transfer write.
      if (!slicedWrite->getResults().empty())
@@ -312,7 +312,7 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
 
   Value base = TypeSwitch<Operation *, Value>(xfer)
                    .Case<vector::TransferReadOp>(
-                       [&](auto readOp) { return readOp.getSource(); })
+                       [&](auto readOp) { return readOp.getBase(); })
                    .Case<vector::TransferWriteOp>(
                        [&](auto writeOp) { return writeOp.getOperand(1); });
 