[mlir][vector] Standardize base Naming Across Vector Ops (NFC) (#137859)

This change standardizes the naming convention for the argument representing the value to read from or write to in Vector ops that interface with Tensors or MemRefs. Specifically, it ensures that all such ops use the name `base` (i.e., the base address or location to which offsets are applied).

Updated operations:

* `vector.transfer_read`
* `vector.transfer_write`

For reference, these ops already use `base`:

* `vector.load`, `vector.store`, `vector.scatter`, `vector.gather`, `vector.expandload`, `vector.compressstore`, `vector.maskedstore`, `vector.maskedload`

This is a non-functional change (NFC) and does not alter the semantics of these operations. However, it does require users of the xfer ops to switch from `op.getSource()` to `op.getBase()`. To ease the transition, this PR temporarily adds a `getSource()` interface method for compatibility. This is intended for downstream use only and should not be relied on upstream; the method will be removed prior to the LLVM 21 release.

Implements #131602
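For downstream projects the migration is mechanical: every `getSource()` call on `vector.transfer_read`, `vector.transfer_write`, or the `VectorTransferOpInterface` becomes `getBase()`. The sketch below shows the rename in a hypothetical downstream rewrite pattern; the pattern name and the memref check are illustrative and not part of this commit — only the accessor rename is.

  // Hypothetical downstream pattern, sketching the getSource() -> getBase()
  // migration; only the accessor rename comes from this commit.
  #include "mlir/Dialect/Vector/IR/VectorOps.h"
  #include "mlir/IR/PatternMatch.h"

  using namespace mlir;

  struct ExampleXferPattern : OpRewritePattern<vector::TransferReadOp> {
    using OpRewritePattern::OpRewritePattern;

    LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                  PatternRewriter &rewriter) const override {
      // Before this commit: Value base = readOp.getSource();
      Value base = readOp.getBase();
      // The base operand may be a memref or a tensor; this sketch only
      // handles memrefs.
      if (!isa<MemRefType>(base.getType()))
        return rewriter.notifyMatchFailure(readOp, "expected a memref base");
      // ... pattern-specific rewriting would go here ...
      return failure();
    }
  };

Until the compatibility shim is removed, unmigrated call sites keep compiling through the deprecated `getSource()` wrapper shown in the interface diff below, at the cost of a deprecation warning.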
parent 5a1edf0f51
commit c45cc3e420
@ -1273,7 +1273,7 @@ def Vector_TransferReadOp :
|
|||||||
AttrSizedOperandSegments,
|
AttrSizedOperandSegments,
|
||||||
DestinationStyleOpInterface
|
DestinationStyleOpInterface
|
||||||
]>,
|
]>,
|
||||||
Arguments<(ins AnyShaped:$source,
|
Arguments<(ins AnyShaped:$base,
|
||||||
Variadic<Index>:$indices,
|
Variadic<Index>:$indices,
|
||||||
AffineMapAttr:$permutation_map,
|
AffineMapAttr:$permutation_map,
|
||||||
AnyType:$padding,
|
AnyType:$padding,
|
||||||
@ -1522,7 +1522,7 @@ def Vector_TransferWriteOp :
|
|||||||
DestinationStyleOpInterface
|
DestinationStyleOpInterface
|
||||||
]>,
|
]>,
|
||||||
Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
|
Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
|
||||||
AnyShaped:$source,
|
AnyShaped:$base,
|
||||||
Variadic<Index>:$indices,
|
Variadic<Index>:$indices,
|
||||||
AffineMapAttr:$permutation_map,
|
AffineMapAttr:$permutation_map,
|
||||||
Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
|
Optional<VectorOfNonZeroRankOf<[I1]>>:$mask,
|
||||||
@ -1663,7 +1663,7 @@ def Vector_TransferWriteOp :
|
|||||||
/// ops of other dialects.
|
/// ops of other dialects.
|
||||||
Value getValue() { return getVector(); }
|
Value getValue() { return getVector(); }
|
||||||
|
|
||||||
MutableOperandRange getDpsInitsMutable() { return getSourceMutable(); }
|
MutableOperandRange getDpsInitsMutable() { return getBaseMutable(); }
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
@ -111,7 +111,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
|
|||||||
TODO: Change name of operand, which is not accurate for xfer_write.
|
TODO: Change name of operand, which is not accurate for xfer_write.
|
||||||
}],
|
}],
|
||||||
/*retTy=*/"::mlir::Value",
|
/*retTy=*/"::mlir::Value",
|
||||||
/*methodName=*/"getSource",
|
/*methodName=*/"getBase",
|
||||||
/*args=*/(ins)
|
/*args=*/(ins)
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
@ -187,6 +187,12 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
|
|||||||
return inBounds;
|
return inBounds;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Wrapper for getBase, which replaced getSource.
|
||||||
|
[[deprecated("Use getBase instead!")]]
|
||||||
|
::mlir::Value getSource() {
|
||||||
|
return $_op.getBase();
|
||||||
|
}
|
||||||
|
|
||||||
/// Return the number of leading shaped dimensions (of the "source" operand)
|
/// Return the number of leading shaped dimensions (of the "source" operand)
|
||||||
/// that do not participate in the permutation map.
|
/// that do not participate in the permutation map.
|
||||||
unsigned getLeadingShapedRank() {
|
unsigned getLeadingShapedRank() {
|
||||||
@ -203,7 +209,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
|
|||||||
|
|
||||||
/// Return the shaped type of the "source" operand value.
|
/// Return the shaped type of the "source" operand value.
|
||||||
::mlir::ShapedType getShapedType() {
|
::mlir::ShapedType getShapedType() {
|
||||||
return ::llvm::cast<::mlir::ShapedType>($_op.getSource().getType());
|
return ::llvm::cast<::mlir::ShapedType>($_op.getBase().getType());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the number of dimensions that participate in the permutation map.
|
/// Return the number of dimensions that participate in the permutation map.
|
||||||
|
@ -58,7 +58,7 @@ struct TransferReadToArmSMELowering
|
|||||||
return rewriter.notifyMatchFailure(transferReadOp,
|
return rewriter.notifyMatchFailure(transferReadOp,
|
||||||
"not a valid vector type for SME");
|
"not a valid vector type for SME");
|
||||||
|
|
||||||
if (!llvm::isa<MemRefType>(transferReadOp.getSource().getType()))
|
if (!llvm::isa<MemRefType>(transferReadOp.getBase().getType()))
|
||||||
return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
|
return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");
|
||||||
|
|
||||||
// Out-of-bounds dims are not supported.
|
// Out-of-bounds dims are not supported.
|
||||||
@ -84,7 +84,7 @@ struct TransferReadToArmSMELowering
|
|||||||
auto mask = transferReadOp.getMask();
|
auto mask = transferReadOp.getMask();
|
||||||
auto padding = mask ? transferReadOp.getPadding() : nullptr;
|
auto padding = mask ? transferReadOp.getPadding() : nullptr;
|
||||||
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
|
rewriter.replaceOpWithNewOp<arm_sme::TileLoadOp>(
|
||||||
transferReadOp, vectorType, transferReadOp.getSource(),
|
transferReadOp, vectorType, transferReadOp.getBase(),
|
||||||
transferReadOp.getIndices(), padding, mask, layout);
|
transferReadOp.getIndices(), padding, mask, layout);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
@ -128,7 +128,7 @@ struct TransferWriteToArmSMELowering
|
|||||||
if (!arm_sme::isValidSMETileVectorType(vType))
|
if (!arm_sme::isValidSMETileVectorType(vType))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
if (!llvm::isa<MemRefType>(writeOp.getSource().getType()))
|
if (!llvm::isa<MemRefType>(writeOp.getBase().getType()))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
// Out-of-bounds dims are not supported.
|
// Out-of-bounds dims are not supported.
|
||||||
@ -149,7 +149,7 @@ struct TransferWriteToArmSMELowering
|
|||||||
: arm_sme::TileSliceLayout::Horizontal;
|
: arm_sme::TileSliceLayout::Horizontal;
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
|
rewriter.replaceOpWithNewOp<arm_sme::TileStoreOp>(
|
||||||
writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices(),
|
writeOp, writeOp.getVector(), writeOp.getBase(), writeOp.getIndices(),
|
||||||
writeOp.getMask(), layout);
|
writeOp.getMask(), layout);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -686,7 +686,7 @@ struct FoldTransferWriteOfExtractTileSlice
|
|||||||
|
|
||||||
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
|
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
|
||||||
PatternRewriter &rewriter) const final {
|
PatternRewriter &rewriter) const final {
|
||||||
if (!isa<MemRefType>(writeOp.getSource().getType()))
|
if (!isa<MemRefType>(writeOp.getBase().getType()))
|
||||||
return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
|
return rewriter.notifyMatchFailure(writeOp, "destination not a memref");
|
||||||
|
|
||||||
if (writeOp.hasOutOfBoundsDim())
|
if (writeOp.hasOutOfBoundsDim())
|
||||||
@ -713,7 +713,7 @@ struct FoldTransferWriteOfExtractTileSlice
|
|||||||
|
|
||||||
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
|
rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
|
||||||
writeOp, extractTileSlice.getTile(),
|
writeOp, extractTileSlice.getTile(),
|
||||||
extractTileSlice.getTileSliceIndex(), mask, writeOp.getSource(),
|
extractTileSlice.getTileSliceIndex(), mask, writeOp.getBase(),
|
||||||
writeOp.getIndices(), extractTileSlice.getLayout());
|
writeOp.getIndices(), extractTileSlice.getLayout());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -486,7 +486,7 @@ struct CombineTransferReadOpTranspose final
|
|||||||
Value result =
|
Value result =
|
||||||
rewriter
|
rewriter
|
||||||
.create<vector::TransferReadOp>(
|
.create<vector::TransferReadOp>(
|
||||||
loc, resultType, transferReadOp.getSource(),
|
loc, resultType, transferReadOp.getBase(),
|
||||||
transferReadOp.getIndices(), AffineMapAttr::get(newMap),
|
transferReadOp.getIndices(), AffineMapAttr::get(newMap),
|
||||||
transferReadOp.getPadding(), transferReadOp.getMask(),
|
transferReadOp.getPadding(), transferReadOp.getMask(),
|
||||||
transferReadOp.getInBoundsAttr())
|
transferReadOp.getInBoundsAttr())
|
||||||
@ -581,7 +581,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
|
|||||||
gpu::MMAMatrixType type =
|
gpu::MMAMatrixType type =
|
||||||
gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
|
gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
|
||||||
Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
|
Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
|
||||||
op.getLoc(), type, op.getSource(), op.getIndices(),
|
op.getLoc(), type, op.getBase(), op.getIndices(),
|
||||||
rewriter.getIndexAttr(*stride),
|
rewriter.getIndexAttr(*stride),
|
||||||
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
|
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
|
||||||
valueMapping[mappingResult] = load;
|
valueMapping[mappingResult] = load;
|
||||||
@ -612,7 +612,7 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
|
|||||||
|
|
||||||
Value matrix = it->second;
|
Value matrix = it->second;
|
||||||
auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
|
auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
|
||||||
op.getLoc(), matrix, op.getSource(), op.getIndices(),
|
op.getLoc(), matrix, op.getBase(), op.getIndices(),
|
||||||
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
|
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
|
||||||
(void)store;
|
(void)store;
|
||||||
|
|
||||||
@ -759,7 +759,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
|
|||||||
indices);
|
indices);
|
||||||
|
|
||||||
nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
|
nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
|
||||||
loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
|
loc, vectorType, op.getBase(), indices, *transpose, params->numTiles);
|
||||||
valueMapping[op] = newOp->getResult(0);
|
valueMapping[op] = newOp->getResult(0);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -818,7 +818,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
|
|||||||
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
|
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
|
||||||
|
|
||||||
Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
|
Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
|
||||||
op.getSource(), newIndices);
|
op.getBase(), newIndices);
|
||||||
result = rewriter.create<vector::InsertOp>(loc, el, result, i);
|
result = rewriter.create<vector::InsertOp>(loc, el, result, i);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -841,7 +841,7 @@ createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
|
|||||||
getXferIndices<vector::TransferReadOp>(
|
getXferIndices<vector::TransferReadOp>(
|
||||||
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
|
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
|
||||||
Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
|
Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
|
||||||
op.getSource(), newIndices);
|
op.getBase(), newIndices);
|
||||||
result = rewriter.create<vector::InsertOp>(
|
result = rewriter.create<vector::InsertOp>(
|
||||||
op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
|
op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
|
||||||
}
|
}
|
||||||
@ -875,7 +875,7 @@ convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
|
|||||||
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
|
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
|
||||||
|
|
||||||
bool isLdMatrixCompatible =
|
bool isLdMatrixCompatible =
|
||||||
isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
|
isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
|
||||||
nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
|
nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
|
||||||
|
|
||||||
VectorType vecTy = op.getVectorType();
|
VectorType vecTy = op.getVectorType();
|
||||||
@ -933,7 +933,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
|
|||||||
SmallVector<Value, 4> newIndices;
|
SmallVector<Value, 4> newIndices;
|
||||||
getXferIndices<vector::TransferWriteOp>(
|
getXferIndices<vector::TransferWriteOp>(
|
||||||
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
|
rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
|
||||||
rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
|
rewriter.create<vector::StoreOp>(loc, el, op.getBase(), newIndices);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
|
LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
|
||||||
|
@ -198,8 +198,7 @@ static Value generateInBoundsCheck(
|
|||||||
Location loc = xferOp.getLoc();
|
Location loc = xferOp.getLoc();
|
||||||
ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
|
ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
|
||||||
if (!xferOp.isDimInBounds(0) && !isBroadcast) {
|
if (!xferOp.isDimInBounds(0) && !isBroadcast) {
|
||||||
Value memrefDim =
|
Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.getBase(), *dim);
|
||||||
vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
|
|
||||||
AffineExpr d0, d1;
|
AffineExpr d0, d1;
|
||||||
bindDims(xferOp.getContext(), d0, d1);
|
bindDims(xferOp.getContext(), d0, d1);
|
||||||
Value base = xferOp.getIndices()[*dim];
|
Value base = xferOp.getIndices()[*dim];
|
||||||
@ -426,7 +425,7 @@ struct Strategy<TransferReadOp> {
|
|||||||
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
|
auto vecType = dyn_cast<VectorType>(bufferType.getElementType());
|
||||||
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
|
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
|
||||||
auto newXferOp = b.create<vector::TransferReadOp>(
|
auto newXferOp = b.create<vector::TransferReadOp>(
|
||||||
loc, vecType, xferOp.getSource(), xferIndices,
|
loc, vecType, xferOp.getBase(), xferIndices,
|
||||||
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
|
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
|
||||||
xferOp.getPadding(), Value(), inBoundsAttr);
|
xferOp.getPadding(), Value(), inBoundsAttr);
|
||||||
|
|
||||||
@ -512,7 +511,7 @@ struct Strategy<TransferWriteOp> {
|
|||||||
Location loc = xferOp.getLoc();
|
Location loc = xferOp.getLoc();
|
||||||
auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
|
auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
|
||||||
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
|
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
|
||||||
auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
|
auto source = loopState.empty() ? xferOp.getBase() : loopState[0];
|
||||||
Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
|
Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
|
||||||
auto newXferOp = b.create<vector::TransferWriteOp>(
|
auto newXferOp = b.create<vector::TransferWriteOp>(
|
||||||
loc, type, vec, source, xferIndices,
|
loc, type, vec, source, xferIndices,
|
||||||
@ -544,7 +543,7 @@ struct Strategy<TransferWriteOp> {
|
|||||||
|
|
||||||
/// Return the initial loop state for the generated scf.for loop.
|
/// Return the initial loop state for the generated scf.for loop.
|
||||||
static Value initialLoopState(TransferWriteOp xferOp) {
|
static Value initialLoopState(TransferWriteOp xferOp) {
|
||||||
return isTensorOp(xferOp) ? xferOp.getSource() : Value();
|
return isTensorOp(xferOp) ? xferOp.getBase() : Value();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1145,7 +1144,7 @@ struct ScalableTransposeTransferWriteConversion
|
|||||||
ArrayRef<OpFoldResult>(*maskDims).drop_front());
|
ArrayRef<OpFoldResult>(*maskDims).drop_front());
|
||||||
}
|
}
|
||||||
|
|
||||||
Value initDest = isTensorOp(writeOp) ? writeOp.getSource() : Value{};
|
Value initDest = isTensorOp(writeOp) ? writeOp.getBase() : Value{};
|
||||||
ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
|
ValueRange initLoopArgs = initDest ? initDest : ValueRange{};
|
||||||
auto result = rewriter.create<scf::ForOp>(
|
auto result = rewriter.create<scf::ForOp>(
|
||||||
loc, lb, ub, step, initLoopArgs,
|
loc, lb, ub, step, initLoopArgs,
|
||||||
@ -1165,7 +1164,7 @@ struct ScalableTransposeTransferWriteConversion
|
|||||||
|
|
||||||
// Create the transfer_write for the slice.
|
// Create the transfer_write for the slice.
|
||||||
Value dest =
|
Value dest =
|
||||||
loopIterArgs.empty() ? writeOp.getSource() : loopIterArgs.front();
|
loopIterArgs.empty() ? writeOp.getBase() : loopIterArgs.front();
|
||||||
auto newWriteOp = b.create<vector::TransferWriteOp>(
|
auto newWriteOp = b.create<vector::TransferWriteOp>(
|
||||||
loc, sliceVec, dest, xferIndices,
|
loc, sliceVec, dest, xferIndices,
|
||||||
ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
|
ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front());
|
||||||
@ -1340,7 +1339,7 @@ struct UnrollTransferReadConversion
|
|||||||
|
|
||||||
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
|
auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
|
||||||
auto newXferOp = b.create<vector::TransferReadOp>(
|
auto newXferOp = b.create<vector::TransferReadOp>(
|
||||||
loc, newXferVecType, xferOp.getSource(), xferIndices,
|
loc, newXferVecType, xferOp.getBase(), xferIndices,
|
||||||
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
|
AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
|
||||||
xferOp.getPadding(), Value(), inBoundsAttr);
|
xferOp.getPadding(), Value(), inBoundsAttr);
|
||||||
maybeAssignMask(b, xferOp, newXferOp, i);
|
maybeAssignMask(b, xferOp, newXferOp, i);
|
||||||
@ -1449,7 +1448,7 @@ struct UnrollTransferWriteConversion
|
|||||||
}
|
}
|
||||||
|
|
||||||
int64_t dimSize = inputVectorTy.getShape()[0];
|
int64_t dimSize = inputVectorTy.getShape()[0];
|
||||||
Value source = xferOp.getSource(); // memref or tensor to be written to.
|
Value source = xferOp.getBase(); // memref or tensor to be written to.
|
||||||
auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
|
auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
|
||||||
|
|
||||||
// Generate fully unrolled loop of transfer ops.
|
// Generate fully unrolled loop of transfer ops.
|
||||||
@ -1567,8 +1566,7 @@ struct Strategy1d<TransferReadOp> {
|
|||||||
b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
|
b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
|
||||||
/*inBoundsCase=*/
|
/*inBoundsCase=*/
|
||||||
[&](OpBuilder &b, Location loc) {
|
[&](OpBuilder &b, Location loc) {
|
||||||
Value val =
|
Value val = b.create<memref::LoadOp>(loc, xferOp.getBase(), indices);
|
||||||
b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
|
|
||||||
return b.create<vector::InsertElementOp>(loc, val, vec, iv);
|
return b.create<vector::InsertElementOp>(loc, val, vec, iv);
|
||||||
},
|
},
|
||||||
/*outOfBoundsCase=*/
|
/*outOfBoundsCase=*/
|
||||||
@ -1599,7 +1597,7 @@ struct Strategy1d<TransferWriteOp> {
|
|||||||
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
|
/*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
|
||||||
auto val =
|
auto val =
|
||||||
b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
|
b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
|
||||||
b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
|
b.create<memref::StoreOp>(loc, val, xferOp.getBase(), indices);
|
||||||
});
|
});
|
||||||
b.create<scf::YieldOp>(loc);
|
b.create<scf::YieldOp>(loc);
|
||||||
}
|
}
|
||||||
|
@ -192,7 +192,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
|
|||||||
|
|
||||||
xegpu::CreateNdDescOp ndDesc =
|
xegpu::CreateNdDescOp ndDesc =
|
||||||
createNdDescriptor(rewriter, loc, descType,
|
createNdDescriptor(rewriter, loc, descType,
|
||||||
dyn_cast<TypedValue<MemRefType>>(readOp.getSource()),
|
dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
|
||||||
readOp.getIndices());
|
readOp.getIndices());
|
||||||
|
|
||||||
DenseI64ArrayAttr transposeAttr =
|
DenseI64ArrayAttr transposeAttr =
|
||||||
@ -231,10 +231,10 @@ struct TransferWriteLowering
|
|||||||
vecTy.getShape(), vecTy.getElementType(),
|
vecTy.getShape(), vecTy.getElementType(),
|
||||||
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
|
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
|
||||||
xegpu::MemorySpace::Global);
|
xegpu::MemorySpace::Global);
|
||||||
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
|
xegpu::CreateNdDescOp ndDesc =
|
||||||
rewriter, loc, descType,
|
createNdDescriptor(rewriter, loc, descType,
|
||||||
dyn_cast<TypedValue<MemRefType>>(writeOp.getSource()),
|
dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
|
||||||
writeOp.getIndices());
|
writeOp.getIndices());
|
||||||
|
|
||||||
// By default, no specific caching policy is assigned.
|
// By default, no specific caching policy is assigned.
|
||||||
xegpu::CachePolicyAttr hint = nullptr;
|
xegpu::CachePolicyAttr hint = nullptr;
|
||||||
|
@ -118,7 +118,7 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
|
|||||||
Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
|
Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
|
||||||
readOp.getPadding());
|
readOp.getPadding());
|
||||||
Value load = builder.create<vector::LoadOp>(
|
Value load = builder.create<vector::LoadOp>(
|
||||||
loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
|
loc, unbroadcastedVectorType, readOp.getBase(), readOp.getIndices());
|
||||||
Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
|
Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
|
||||||
readOp.getMask(), load, fill);
|
readOp.getMask(), load, fill);
|
||||||
// Insert a broadcasting op if required.
|
// Insert a broadcasting op if required.
|
||||||
@ -149,7 +149,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Location loc = readOp.getLoc();
|
Location loc = readOp.getLoc();
|
||||||
Value src = readOp.getSource();
|
Value src = readOp.getBase();
|
||||||
|
|
||||||
VectorType vectorType = readOp.getVectorType();
|
VectorType vectorType = readOp.getVectorType();
|
||||||
int64_t vectorSize = vectorType.getNumElements();
|
int64_t vectorSize = vectorType.getNumElements();
|
||||||
|
@ -315,7 +315,7 @@ struct LegalizeTransferReadOpsByDecomposition
|
|||||||
decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
|
decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
|
||||||
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
|
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
|
||||||
auto smeRead = rewriter.create<vector::TransferReadOp>(
|
auto smeRead = rewriter.create<vector::TransferReadOp>(
|
||||||
loc, smeTileType, readOp.getSource(),
|
loc, smeTileType, readOp.getBase(),
|
||||||
getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
|
getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
|
||||||
readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
|
readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
|
||||||
readOp.getInBoundsAttr());
|
readOp.getInBoundsAttr());
|
||||||
@ -359,7 +359,7 @@ struct LegalizeTransferWriteOpsByDecomposition
|
|||||||
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
|
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
|
||||||
auto inputSMETiles = adaptor.getValueToStore();
|
auto inputSMETiles = adaptor.getValueToStore();
|
||||||
|
|
||||||
Value destTensorOrMemref = writeOp.getSource();
|
Value destTensorOrMemref = writeOp.getBase();
|
||||||
for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
|
for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
|
||||||
rewriter, vectorType, smeTileType, transposed))) {
|
rewriter, vectorType, smeTileType, transposed))) {
|
||||||
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
|
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
|
||||||
@ -497,7 +497,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
|
|||||||
auto slice =
|
auto slice =
|
||||||
rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
|
rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
|
||||||
rewriter.create<vector::TransferWriteOp>(
|
rewriter.create<vector::TransferWriteOp>(
|
||||||
loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
|
loc, slice, writeOp.getBase(), ValueRange{storeRow, storeCol},
|
||||||
AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
|
AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
|
||||||
sliceMask,
|
sliceMask,
|
||||||
rewriter.getBoolArrayAttr(
|
rewriter.getBoolArrayAttr(
|
||||||
@ -677,7 +677,7 @@ struct LiftIllegalVectorTransposeToMemory
|
|||||||
});
|
});
|
||||||
SmallVector<Value> strides(readType.getRank(), Value(one));
|
SmallVector<Value> strides(readType.getRank(), Value(one));
|
||||||
auto readSubview = rewriter.create<memref::SubViewOp>(
|
auto readSubview = rewriter.create<memref::SubViewOp>(
|
||||||
loc, illegalRead.getSource(), illegalRead.getIndices(), readSizes,
|
loc, illegalRead.getBase(), illegalRead.getIndices(), readSizes,
|
||||||
strides);
|
strides);
|
||||||
|
|
||||||
// Apply the transpose to all values/attributes of the transfer_read:
|
// Apply the transpose to all values/attributes of the transfer_read:
|
||||||
@ -851,7 +851,7 @@ struct LowerIllegalTransposeStoreViaZA
|
|||||||
|
|
||||||
// Note: We need to use `get_tile` as there's no vector-level `undef`.
|
// Note: We need to use `get_tile` as there's no vector-level `undef`.
|
||||||
Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
|
Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
|
||||||
Value destTensorOrMemref = writeOp.getSource();
|
Value destTensorOrMemref = writeOp.getBase();
|
||||||
auto numSlicesPerTile =
|
auto numSlicesPerTile =
|
||||||
std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
|
std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
|
||||||
auto numSlices =
|
auto numSlices =
|
||||||
|
@ -171,7 +171,7 @@ void mlir::linalg::hoistRedundantVectorBroadcasts(RewriterBase &rewriter,
|
|||||||
|
|
||||||
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
|
static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
|
||||||
LoopLikeOpInterface loop) {
|
LoopLikeOpInterface loop) {
|
||||||
Value source = transferRead.getSource();
|
Value source = transferRead.getBase();
|
||||||
|
|
||||||
// Skip view-like Ops and retrive the actual soruce Operation
|
// Skip view-like Ops and retrive the actual soruce Operation
|
||||||
while (auto srcOp =
|
while (auto srcOp =
|
||||||
@ -276,7 +276,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
|
|||||||
for (auto *sliceOp : llvm::reverse(forwardSlice)) {
|
for (auto *sliceOp : llvm::reverse(forwardSlice)) {
|
||||||
auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
|
auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
|
||||||
if (!candidateWrite ||
|
if (!candidateWrite ||
|
||||||
candidateWrite.getSource() != transferRead.getSource())
|
candidateWrite.getBase() != transferRead.getBase())
|
||||||
continue;
|
continue;
|
||||||
transferWrite = candidateWrite;
|
transferWrite = candidateWrite;
|
||||||
}
|
}
|
||||||
@ -312,11 +312,11 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
|
|||||||
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
|
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
|
|
||||||
auto *source = transferRead.getSource().getDefiningOp();
|
auto *source = transferRead.getBase().getDefiningOp();
|
||||||
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
|
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
|
|
||||||
source = transferWrite.getSource().getDefiningOp();
|
source = transferWrite.getBase().getDefiningOp();
|
||||||
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
|
if (source && isa_and_nonnull<ViewLikeOpInterface>(source))
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
|
|
||||||
@ -325,7 +325,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
|
|||||||
DominanceInfo dom(loop);
|
DominanceInfo dom(loop);
|
||||||
if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
|
if (!dom.properlyDominates(transferRead.getOperation(), transferWrite))
|
||||||
return WalkResult::advance();
|
return WalkResult::advance();
|
||||||
for (auto &use : transferRead.getSource().getUses()) {
|
for (auto &use : transferRead.getBase().getUses()) {
|
||||||
if (!loop->isAncestor(use.getOwner()))
|
if (!loop->isAncestor(use.getOwner()))
|
||||||
continue;
|
continue;
|
||||||
if (use.getOwner() == transferRead.getOperation() ||
|
if (use.getOwner() == transferRead.getOperation() ||
|
||||||
|
@ -2627,7 +2627,7 @@ struct PadOpVectorizationWithTransferReadPattern
|
|||||||
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
|
SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
|
||||||
xferOp->setAttr(xferOp.getInBoundsAttrName(),
|
xferOp->setAttr(xferOp.getInBoundsAttrName(),
|
||||||
rewriter.getBoolArrayAttr(inBounds));
|
rewriter.getBoolArrayAttr(inBounds));
|
||||||
xferOp.getSourceMutable().assign(padOp.getSource());
|
xferOp.getBaseMutable().assign(padOp.getSource());
|
||||||
xferOp.getPaddingMutable().assign(padValue);
|
xferOp.getPaddingMutable().assign(padValue);
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -3114,7 +3114,7 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
|
|||||||
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
|
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
|
||||||
|
|
||||||
// Transfer into `view`.
|
// Transfer into `view`.
|
||||||
Value viewOrAlloc = xferOp.getSource();
|
Value viewOrAlloc = xferOp.getBase();
|
||||||
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
|
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
|
||||||
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
|
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
|
||||||
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
|
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
|
||||||
@ -3191,7 +3191,7 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
|
|||||||
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
|
return rewriter.notifyMatchFailure(xferOp, "unsupported mask");
|
||||||
|
|
||||||
// Transfer into `viewOrAlloc`.
|
// Transfer into `viewOrAlloc`.
|
||||||
Value viewOrAlloc = xferOp.getSource();
|
Value viewOrAlloc = xferOp.getBase();
|
||||||
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
|
if (!viewOrAlloc.getDefiningOp<memref::ViewOp>() &&
|
||||||
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
|
!viewOrAlloc.getDefiningOp<memref::AllocOp>())
|
||||||
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
|
return rewriter.notifyMatchFailure(xferOp, "source not a view or alloc");
|
||||||
|
@ -119,7 +119,7 @@ static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter,
|
|||||||
template <typename TransferLikeOp>
|
template <typename TransferLikeOp>
|
||||||
static FailureOr<Value>
|
static FailureOr<Value>
|
||||||
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
|
getTransferLikeOpSrcMemRef(TransferLikeOp transferLikeOp) {
|
||||||
Value src = transferLikeOp.getSource();
|
Value src = transferLikeOp.getBase();
|
||||||
if (isa<MemRefType>(src.getType()))
|
if (isa<MemRefType>(src.getType()))
|
||||||
return src;
|
return src;
|
||||||
return failure();
|
return failure();
|
||||||
|
@ -224,7 +224,7 @@ static Value getMemRefOperand(LoadOrStoreOpTy op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static Value getMemRefOperand(vector::TransferReadOp op) {
|
static Value getMemRefOperand(vector::TransferReadOp op) {
|
||||||
return op.getSource();
|
return op.getBase();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
|
static Value getMemRefOperand(nvgpu::LdMatrixOp op) {
|
||||||
@ -240,7 +240,7 @@ static Value getMemRefOperand(vector::MaskedLoadOp op) { return op.getBase(); }
|
|||||||
static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
|
static Value getMemRefOperand(vector::MaskedStoreOp op) { return op.getBase(); }
|
||||||
|
|
||||||
static Value getMemRefOperand(vector::TransferWriteOp op) {
|
static Value getMemRefOperand(vector::TransferWriteOp op) {
|
||||||
return op.getSource();
|
return op.getBase();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
|
static Value getMemRefOperand(gpu::SubgroupMmaLoadMatrixOp op) {
|
||||||
|
@ -172,7 +172,7 @@ static Value getValueLoadedFromGlobal(Operation *op) {
|
|||||||
if (!load)
|
if (!load)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
||||||
auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
|
auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
|
||||||
if (!loadType || !hasDefaultMemorySpace(loadType))
|
if (!loadType || !hasDefaultMemorySpace(loadType))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return load;
|
return load;
|
||||||
@ -185,7 +185,7 @@ static bool isStoreToShared(Operation *op, Value v) {
|
|||||||
if (!store || store.getVector() != v)
|
if (!store || store.getVector() != v)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
|
auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
|
||||||
return storeType || hasSharedMemorySpace(storeType);
|
return storeType || hasSharedMemorySpace(storeType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -71,9 +71,9 @@ Value nvgpu::getMemrefOperand(Operation *op) {
|
|||||||
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
|
if (auto storeOp = dyn_cast<memref::StoreOp>(op))
|
||||||
return storeOp.getMemref();
|
return storeOp.getMemref();
|
||||||
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
|
if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
|
||||||
return transferWrite.getSource();
|
return transferWrite.getBase();
|
||||||
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
|
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
|
||||||
return transferRead.getSource();
|
return transferRead.getBase();
|
||||||
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
|
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
|
||||||
return storeOp.getBase();
|
return storeOp.getBase();
|
||||||
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
|
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
|
||||||
|
@ -285,7 +285,7 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferReadOp op) {
|
|||||||
// information to ensure correctness of downstream assumptions. It is possible
|
// information to ensure correctness of downstream assumptions. It is possible
|
||||||
// to enable this if caller can assert that tensor will be lowered in a
|
// to enable this if caller can assert that tensor will be lowered in a
|
||||||
// particular manner.
|
// particular manner.
|
||||||
auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
|
auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
|
||||||
if (!sourceType)
|
if (!sourceType)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@ -309,7 +309,7 @@ bool nvgpu::canLowerToWarpMatrixOperation(vector::TransferWriteOp op) {
|
|||||||
return false;
|
return false;
|
||||||
// Currently we can't support reads on tensor types because we need stride
|
// Currently we can't support reads on tensor types because we need stride
|
||||||
// information to ensure correctness of downstream assumptions.
|
// information to ensure correctness of downstream assumptions.
|
||||||
auto sourceType = dyn_cast<MemRefType>(op.getSource().getType());
|
auto sourceType = dyn_cast<MemRefType>(op.getBase().getType());
|
||||||
if (!sourceType)
|
if (!sourceType)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ namespace tensor {
|
|||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
|
||||||
static Value getTensorOperand(vector::TransferReadOp op) {
|
static Value getTensorOperand(vector::TransferReadOp op) {
|
||||||
return op.getSource();
|
return op.getBase();
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value getTensorOperand(tensor::InsertSliceOp op) {
|
static Value getTensorOperand(tensor::InsertSliceOp op) {
|
||||||
|
@ -314,7 +314,7 @@ bool mlir::vector::isDisjointTransferIndices(
|
|||||||
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
|
bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
|
||||||
VectorTransferOpInterface transferB,
|
VectorTransferOpInterface transferB,
|
||||||
bool testDynamicValueUsingBounds) {
|
bool testDynamicValueUsingBounds) {
|
||||||
if (transferA.getSource() != transferB.getSource())
|
if (transferA.getBase() != transferB.getBase())
|
||||||
return false;
|
return false;
|
||||||
return isDisjointTransferIndices(transferA, transferB,
|
return isDisjointTransferIndices(transferA, transferB,
|
||||||
testDynamicValueUsingBounds);
|
testDynamicValueUsingBounds);
|
||||||
@ -4205,7 +4205,7 @@ static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TransferReadOp::print(OpAsmPrinter &p) {
|
void TransferReadOp::print(OpAsmPrinter &p) {
|
||||||
p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
|
p << " " << getBase() << "[" << getIndices() << "], " << getPadding();
|
||||||
if (getMask())
|
if (getMask())
|
||||||
p << ", " << getMask();
|
p << ", " << getMask();
|
||||||
printTransferAttrs(p, *this);
|
printTransferAttrs(p, *this);
|
||||||
@ -4464,7 +4464,7 @@ static LogicalResult foldTransferFullMask(TransferOp op) {
|
|||||||
static Value foldRAW(TransferReadOp readOp) {
|
static Value foldRAW(TransferReadOp readOp) {
|
||||||
if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
|
if (!llvm::isa<RankedTensorType>(readOp.getShapedType()))
|
||||||
return {};
|
return {};
|
||||||
auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
|
auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
|
||||||
while (defWrite) {
|
while (defWrite) {
|
||||||
if (checkSameValueRAW(defWrite, readOp))
|
if (checkSameValueRAW(defWrite, readOp))
|
||||||
return defWrite.getVector();
|
return defWrite.getVector();
|
||||||
@ -4472,7 +4472,7 @@ static Value foldRAW(TransferReadOp readOp) {
|
|||||||
cast<VectorTransferOpInterface>(defWrite.getOperation()),
|
cast<VectorTransferOpInterface>(defWrite.getOperation()),
|
||||||
cast<VectorTransferOpInterface>(readOp.getOperation())))
|
cast<VectorTransferOpInterface>(readOp.getOperation())))
|
||||||
break;
|
break;
|
||||||
defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
|
defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
|
||||||
}
|
}
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
@ -4500,7 +4500,7 @@ void TransferReadOp::getEffects(
|
|||||||
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
||||||
&effects) {
|
&effects) {
|
||||||
if (llvm::isa<MemRefType>(getShapedType()))
|
if (llvm::isa<MemRefType>(getShapedType()))
|
||||||
effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
|
effects.emplace_back(MemoryEffects::Read::get(), &getBaseMutable(),
|
||||||
SideEffects::DefaultResource::get());
|
SideEffects::DefaultResource::get());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4542,7 +4542,7 @@ struct TransferReadAfterWriteToBroadcast
|
|||||||
if (readOp.hasOutOfBoundsDim() ||
|
if (readOp.hasOutOfBoundsDim() ||
|
||||||
!llvm::isa<RankedTensorType>(readOp.getShapedType()))
|
!llvm::isa<RankedTensorType>(readOp.getShapedType()))
|
||||||
return failure();
|
return failure();
|
||||||
auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
|
auto defWrite = readOp.getBase().getDefiningOp<vector::TransferWriteOp>();
|
||||||
if (!defWrite)
|
if (!defWrite)
|
||||||
return failure();
|
return failure();
|
||||||
// TODO: If the written transfer chunk is a superset of the read transfer
|
// TODO: If the written transfer chunk is a superset of the read transfer
|
||||||
@ -4727,7 +4727,7 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void TransferWriteOp::print(OpAsmPrinter &p) {
|
void TransferWriteOp::print(OpAsmPrinter &p) {
|
||||||
p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
|
p << " " << getVector() << ", " << getBase() << "[" << getIndices() << "]";
|
||||||
if (getMask())
|
if (getMask())
|
||||||
p << ", " << getMask();
|
p << ", " << getMask();
|
||||||
printTransferAttrs(p, *this);
|
printTransferAttrs(p, *this);
|
||||||
@ -4806,7 +4806,7 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
|
|||||||
if (write.getTransferRank() == 0)
|
if (write.getTransferRank() == 0)
|
||||||
return failure();
|
return failure();
|
||||||
auto rankedTensorType =
|
auto rankedTensorType =
|
||||||
llvm::dyn_cast<RankedTensorType>(write.getSource().getType());
|
llvm::dyn_cast<RankedTensorType>(write.getBase().getType());
|
||||||
// If not operating on tensors, bail.
|
// If not operating on tensors, bail.
|
||||||
if (!rankedTensorType)
|
if (!rankedTensorType)
|
||||||
return failure();
|
return failure();
|
||||||
@ -4828,7 +4828,7 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
|
|||||||
if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
|
if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
|
||||||
return failure();
|
return failure();
|
||||||
// Tensor types must be the same.
|
// Tensor types must be the same.
|
||||||
if (read.getSource().getType() != rankedTensorType)
|
if (read.getBase().getType() != rankedTensorType)
|
||||||
return failure();
|
return failure();
|
||||||
// Vector types must be the same.
|
// Vector types must be the same.
|
||||||
if (read.getVectorType() != write.getVectorType())
|
if (read.getVectorType() != write.getVectorType())
|
||||||
@ -4845,13 +4845,13 @@ static LogicalResult foldReadInitWrite(TransferWriteOp write,
|
|||||||
llvm::any_of(write.getIndices(), isNotConstantZero))
|
llvm::any_of(write.getIndices(), isNotConstantZero))
|
||||||
return failure();
|
return failure();
|
||||||
// Success.
|
// Success.
|
||||||
results.push_back(read.getSource());
|
results.push_back(read.getBase());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool checkSameValueWAR(vector::TransferReadOp read,
|
static bool checkSameValueWAR(vector::TransferReadOp read,
|
||||||
vector::TransferWriteOp write) {
|
vector::TransferWriteOp write) {
|
||||||
return read.getSource() == write.getSource() &&
|
return read.getBase() == write.getBase() &&
|
||||||
read.getIndices() == write.getIndices() &&
|
read.getIndices() == write.getIndices() &&
|
||||||
read.getPermutationMap() == write.getPermutationMap() &&
|
read.getPermutationMap() == write.getPermutationMap() &&
|
||||||
read.getVectorType() == write.getVectorType() && !read.getMask() &&
|
read.getVectorType() == write.getVectorType() && !read.getMask() &&
|
||||||
@ -4873,7 +4873,7 @@ static bool checkSameValueWAR(vector::TransferReadOp read,
|
|||||||
/// ```
|
/// ```
|
||||||
static LogicalResult foldWAR(TransferWriteOp write,
|
static LogicalResult foldWAR(TransferWriteOp write,
|
||||||
SmallVectorImpl<OpFoldResult> &results) {
|
SmallVectorImpl<OpFoldResult> &results) {
|
||||||
if (!llvm::isa<RankedTensorType>(write.getSource().getType()))
|
if (!llvm::isa<RankedTensorType>(write.getBase().getType()))
|
||||||
return failure();
|
return failure();
|
||||||
auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
|
auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
|
||||||
if (!read)
|
if (!read)
|
||||||
@ -4881,7 +4881,7 @@ static LogicalResult foldWAR(TransferWriteOp write,
|
|||||||
|
|
||||||
if (!checkSameValueWAR(read, write))
|
if (!checkSameValueWAR(read, write))
|
||||||
return failure();
|
return failure();
|
||||||
results.push_back(read.getSource());
|
results.push_back(read.getBase());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -4953,12 +4953,11 @@ public:
|
|||||||
return failure();
|
return failure();
|
||||||
vector::TransferWriteOp writeToModify = writeOp;
|
vector::TransferWriteOp writeToModify = writeOp;
|
||||||
|
|
||||||
auto defWrite =
|
auto defWrite = writeOp.getBase().getDefiningOp<vector::TransferWriteOp>();
|
||||||
writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
|
|
||||||
while (defWrite) {
|
while (defWrite) {
|
||||||
if (checkSameValueWAW(writeOp, defWrite)) {
|
if (checkSameValueWAW(writeOp, defWrite)) {
|
||||||
rewriter.modifyOpInPlace(writeToModify, [&]() {
|
rewriter.modifyOpInPlace(writeToModify, [&]() {
|
||||||
writeToModify.getSourceMutable().assign(defWrite.getSource());
|
writeToModify.getBaseMutable().assign(defWrite.getBase());
|
||||||
});
|
});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
@ -4971,7 +4970,7 @@ public:
|
|||||||
if (!defWrite->hasOneUse())
|
if (!defWrite->hasOneUse())
|
||||||
break;
|
break;
|
||||||
writeToModify = defWrite;
|
writeToModify = defWrite;
|
||||||
defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
|
defWrite = defWrite.getBase().getDefiningOp<vector::TransferWriteOp>();
|
||||||
}
|
}
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
@ -52,7 +52,7 @@ struct TransferReadOpInterface
|
|||||||
auto readOp = cast<vector::TransferReadOp>(op);
|
auto readOp = cast<vector::TransferReadOp>(op);
|
||||||
assert(isa<TensorType>(readOp.getShapedType()) &&
|
assert(isa<TensorType>(readOp.getShapedType()) &&
|
||||||
"only tensor types expected");
|
"only tensor types expected");
|
||||||
FailureOr<Value> buffer = getBuffer(rewriter, readOp.getSource(), options);
|
FailureOr<Value> buffer = getBuffer(rewriter, readOp.getBase(), options);
|
||||||
if (failed(buffer))
|
if (failed(buffer))
|
||||||
return failure();
|
return failure();
|
||||||
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
|
replaceOpWithNewBufferizedOp<vector::TransferReadOp>(
|
||||||
@ -110,7 +110,7 @@ struct TransferWriteOpInterface
|
|||||||
|
|
||||||
// Create a new transfer_write on buffer that doesn't have a return value.
|
// Create a new transfer_write on buffer that doesn't have a return value.
|
||||||
FailureOr<Value> resultBuffer =
|
FailureOr<Value> resultBuffer =
|
||||||
getBuffer(rewriter, writeOp.getSource(), options);
|
getBuffer(rewriter, writeOp.getBase(), options);
|
||||||
if (failed(resultBuffer))
|
if (failed(resultBuffer))
|
||||||
return failure();
|
return failure();
|
||||||
rewriter.create<vector::TransferWriteOp>(
|
rewriter.create<vector::TransferWriteOp>(
|
||||||
|
@ -222,7 +222,7 @@ public:
|
|||||||
|
|
||||||
// Replace the `vector.mask` operation.
|
// Replace the `vector.mask` operation.
|
||||||
rewriter.replaceOpWithNewOp<TransferReadOp>(
|
rewriter.replaceOpWithNewOp<TransferReadOp>(
|
||||||
maskingOp.getOperation(), readOp.getVectorType(), readOp.getSource(),
|
maskingOp.getOperation(), readOp.getVectorType(), readOp.getBase(),
|
||||||
readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
|
readOp.getIndices(), readOp.getPermutationMap(), readOp.getPadding(),
|
||||||
maskingOp.getMask(), readOp.getInBounds());
|
maskingOp.getMask(), readOp.getInBounds());
|
||||||
return success();
|
return success();
|
||||||
@ -245,7 +245,7 @@ public:
|
|||||||
// Replace the `vector.mask` operation.
|
// Replace the `vector.mask` operation.
|
||||||
rewriter.replaceOpWithNewOp<TransferWriteOp>(
|
rewriter.replaceOpWithNewOp<TransferWriteOp>(
|
||||||
maskingOp.getOperation(), resultType, writeOp.getVector(),
|
maskingOp.getOperation(), resultType, writeOp.getVector(),
|
||||||
writeOp.getSource(), writeOp.getIndices(), writeOp.getPermutationMap(),
|
writeOp.getBase(), writeOp.getIndices(), writeOp.getPermutationMap(),
|
||||||
maskingOp.getMask(), writeOp.getInBounds());
|
maskingOp.getMask(), writeOp.getInBounds());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -139,7 +139,7 @@ struct TransferReadPermutationLowering
|
|||||||
VectorType newReadType = VectorType::get(
|
VectorType newReadType = VectorType::get(
|
||||||
newVectorShape, op.getVectorType().getElementType(), newScalableDims);
|
newVectorShape, op.getVectorType().getElementType(), newScalableDims);
|
||||||
Value newRead = rewriter.create<vector::TransferReadOp>(
|
Value newRead = rewriter.create<vector::TransferReadOp>(
|
||||||
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
|
op.getLoc(), newReadType, op.getBase(), op.getIndices(),
|
||||||
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
|
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
|
||||||
newInBoundsAttr);
|
newInBoundsAttr);
|
||||||
|
|
||||||
@ -214,7 +214,7 @@ struct TransferWritePermutationLowering
|
|||||||
auto newMap = AffineMap::getMinorIdentityMap(
|
auto newMap = AffineMap::getMinorIdentityMap(
|
||||||
map.getNumDims(), map.getNumResults(), rewriter.getContext());
|
map.getNumDims(), map.getNumResults(), rewriter.getContext());
|
||||||
auto newWrite = rewriter.create<vector::TransferWriteOp>(
|
auto newWrite = rewriter.create<vector::TransferWriteOp>(
|
||||||
op.getLoc(), newVec, op.getSource(), op.getIndices(),
|
op.getLoc(), newVec, op.getBase(), op.getIndices(),
|
||||||
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
|
AffineMapAttr::get(newMap), op.getMask(), newInBoundsAttr);
|
||||||
if (newWrite.hasPureTensorSemantics())
|
if (newWrite.hasPureTensorSemantics())
|
||||||
return newWrite.getResult();
|
return newWrite.getResult();
|
||||||
@ -300,7 +300,7 @@ struct TransferWriteNonPermutationLowering
|
|||||||
}
|
}
|
||||||
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
|
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
|
||||||
auto newWrite = rewriter.create<vector::TransferWriteOp>(
|
auto newWrite = rewriter.create<vector::TransferWriteOp>(
|
||||||
op.getLoc(), newVec, op.getSource(), op.getIndices(),
|
op.getLoc(), newVec, op.getBase(), op.getIndices(),
|
||||||
AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
|
AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
|
||||||
if (newWrite.hasPureTensorSemantics())
|
if (newWrite.hasPureTensorSemantics())
|
||||||
return newWrite.getResult();
|
return newWrite.getResult();
|
||||||
@ -371,7 +371,7 @@ struct TransferOpReduceRank
|
|||||||
op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
|
op.getInBoundsAttr().getValue().take_back(reducedShapeRank))
|
||||||
: ArrayAttr();
|
: ArrayAttr();
|
||||||
Value newRead = rewriter.create<vector::TransferReadOp>(
|
Value newRead = rewriter.create<vector::TransferReadOp>(
|
||||||
op.getLoc(), newReadType, op.getSource(), op.getIndices(),
|
op.getLoc(), newReadType, op.getBase(), op.getIndices(),
|
||||||
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
|
AffineMapAttr::get(newMap), op.getPadding(), op.getMask(),
|
||||||
newInBoundsAttr);
|
newInBoundsAttr);
|
||||||
return rewriter
|
return rewriter
|
||||||
@ -474,12 +474,12 @@ struct TransferReadToVectorLoadLowering
|
|||||||
Value fill = rewriter.create<vector::SplatOp>(
|
Value fill = rewriter.create<vector::SplatOp>(
|
||||||
read.getLoc(), unbroadcastedVectorType, read.getPadding());
|
read.getLoc(), unbroadcastedVectorType, read.getPadding());
|
||||||
res = rewriter.create<vector::MaskedLoadOp>(
|
res = rewriter.create<vector::MaskedLoadOp>(
|
||||||
read.getLoc(), unbroadcastedVectorType, read.getSource(),
|
read.getLoc(), unbroadcastedVectorType, read.getBase(),
|
||||||
read.getIndices(), read.getMask(), fill);
|
read.getIndices(), read.getMask(), fill);
|
||||||
} else {
|
} else {
|
||||||
res = rewriter.create<vector::LoadOp>(
|
res = rewriter.create<vector::LoadOp>(read.getLoc(),
|
||||||
read.getLoc(), unbroadcastedVectorType, read.getSource(),
|
unbroadcastedVectorType,
|
||||||
read.getIndices());
|
read.getBase(), read.getIndices());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Insert a broadcasting op if required.
|
// Insert a broadcasting op if required.
|
||||||
@ -570,11 +570,11 @@ struct TransferWriteToVectorStoreLowering
|
|||||||
});
|
});
|
||||||
|
|
||||||
rewriter.create<vector::MaskedStoreOp>(
|
rewriter.create<vector::MaskedStoreOp>(
|
||||||
write.getLoc(), write.getSource(), write.getIndices(),
|
write.getLoc(), write.getBase(), write.getIndices(), write.getMask(),
|
||||||
write.getMask(), write.getVector());
|
write.getVector());
|
||||||
} else {
|
} else {
|
||||||
rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
|
rewriter.create<vector::StoreOp>(write.getLoc(), write.getVector(),
|
||||||
write.getSource(), write.getIndices());
|
write.getBase(), write.getIndices());
|
||||||
}
|
}
|
||||||
// There's no return value for StoreOps. Use Value() to signal success to
|
// There's no return value for StoreOps. Use Value() to signal success to
|
||||||
// matchAndRewrite.
|
// matchAndRewrite.
|
||||||
|
@ -37,7 +37,7 @@ struct TransferReadOpSubsetExtractionOpInterface
|
|||||||
: public SubsetExtractionOpInterface::ExternalModel<
|
: public SubsetExtractionOpInterface::ExternalModel<
|
||||||
TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
|
TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
|
||||||
OpOperand &getSourceOperand(Operation *op) const {
|
OpOperand &getSourceOperand(Operation *op) const {
|
||||||
return cast<vector::TransferReadOp>(op).getSourceMutable();
|
return cast<vector::TransferReadOp>(op).getBaseMutable();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -49,7 +49,7 @@ struct TransferWriteOpSubsetInsertionOpInterface
|
|||||||
}
|
}
|
||||||
|
|
||||||
OpOperand &getDestinationOperand(Operation *op) const {
|
OpOperand &getDestinationOperand(Operation *op) const {
|
||||||
return cast<vector::TransferWriteOp>(op).getSourceMutable();
|
return cast<vector::TransferWriteOp>(op).getBaseMutable();
|
||||||
}
|
}
|
||||||
|
|
||||||
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
|
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
|
||||||

@ -718,7 +718,7 @@ struct WarpOpTransferRead : public WarpDistributionPattern {
     auto read = operand->get().getDefiningOp<vector::TransferReadOp>();
 
     // Source must be defined outside of the region.
-    if (!warpOp.isDefinedOutsideOfRegion(read.getSource()))
+    if (!warpOp.isDefinedOutsideOfRegion(read.getBase()))
       return rewriter.notifyMatchFailure(
           read, "source must be defined outside of the region");
 
@ -802,7 +802,7 @@ struct WarpOpTransferRead : public WarpDistributionPattern {
         hasMask ? newWarpOp.getResult(newRetIndices[newRetIndices.size() - 1])
                 : Value();
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        read.getLoc(), distributedVal.getType(), read.getSource(), newIndices,
+        read.getLoc(), distributedVal.getType(), read.getBase(), newIndices,
         read.getPermutationMapAttr(), newPadding, newMask,
         read.getInBoundsAttr());
 

@ -230,7 +230,7 @@ struct CastAwayTransferReadLeadingOneDim
     if (read.getTransferRank() == 0)
       return failure();
 
-    auto shapedType = cast<ShapedType>(read.getSource().getType());
+    auto shapedType = cast<ShapedType>(read.getBase().getType());
     if (shapedType.getElementType() != read.getVectorType().getElementType())
       return failure();
 
@ -260,7 +260,7 @@ struct CastAwayTransferReadLeadingOneDim
     }
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        read.getLoc(), newType, read.getSource(), read.getIndices(),
+        read.getLoc(), newType, read.getBase(), read.getIndices(),
         AffineMapAttr::get(newMap), read.getPadding(), mask, inBoundsAttr);
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
@ -284,7 +284,7 @@ struct CastAwayTransferWriteLeadingOneDim
     if (write.getTransferRank() == 0)
       return failure();
 
-    auto shapedType = dyn_cast<ShapedType>(write.getSource().getType());
+    auto shapedType = dyn_cast<ShapedType>(write.getBase().getType());
     if (shapedType.getElementType() != write.getVectorType().getElementType())
       return failure();
 
@ -314,13 +314,13 @@ struct CastAwayTransferWriteLeadingOneDim
       Value newMask = dropUnitDimsFromMask(
           rewriter, write.getLoc(), write.getMask(), newType, newMap, maskType);
       rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-          write, newVector, write.getSource(), write.getIndices(),
+          write, newVector, write.getBase(), write.getIndices(),
           AffineMapAttr::get(newMap), newMask, inBoundsAttr);
       return success();
     }
 
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        write, newVector, write.getSource(), write.getIndices(),
+        write, newVector, write.getBase(), write.getIndices(),
         AffineMapAttr::get(newMap), inBoundsAttr);
     return success();
   }
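
These four hunks touch the cast-away patterns, which peel a leading unit dimension off the vector type and, for reads, broadcast the result back to the original shape. A before/after sketch of the read case, written as IR in comments (illustrative shapes; syntax assumes current upstream conventions and is not taken from this patch):

// Illustrative only -- the shape of the rewrite performed by
// CastAwayTransferReadLeadingOneDim:
//
//   %v = vector.transfer_read %base[%i, %j], %pad
//          : memref<8x8xf32>, vector<1x4xf32>
// becomes
//   %r = vector.transfer_read %base[%i, %j], %pad
//          : memref<8x8xf32>, vector<4xf32>
//   %v = vector.broadcast %r : vector<4xf32> to vector<1x4xf32>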

@ -1249,7 +1249,7 @@ struct ConvertVectorTransferRead final
 
     auto loc = op.getLoc();
     auto containerElemTy =
-        cast<MemRefType>(adaptor.getSource().getType()).getElementType();
+        cast<MemRefType>(adaptor.getBase().getType()).getElementType();
     Type emulatedElemTy = op.getType().getElementType();
     int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
     int containerBits = containerElemTy.getIntOrFloatBitWidth();
@ -1272,7 +1272,7 @@ struct ConvertVectorTransferRead final
         adaptor.getPadding());
 
     auto stridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
     memref::LinearizedMemRefInfo linearizedInfo;
@ -1294,7 +1294,7 @@ struct ConvertVectorTransferRead final
         emulatedPerContainerElem);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
-        loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
+        loc, VectorType::get(numElements, containerElemTy), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
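
For context on the surrounding arithmetic: the narrow-type emulation pass packs several emulated elements into one container element, so the ratio containerBits / emulatedBits decides how many container elements a read must cover. A worked example with assumed widths (i4 emulated in an i8 container; the values are illustrative, the variable names follow the hunks above):

// Illustrative arithmetic only.
int emulatedBits = 4;  // e.g. reading a vector of i4
int containerBits = 8; // e.g. from an i8 memref
int emulatedPerContainerElem = containerBits / emulatedBits; // 2
// A transfer of 6 x i4 then spans ceil(6 / 2) = 3 container elements.
int numElements = (6 + emulatedPerContainerElem - 1) / emulatedPerContainerElem;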

@ -92,7 +92,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> blockingAccesses;
   Operation *firstOverwriteCandidate = nullptr;
-  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@ -112,8 +112,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
     if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
       // Check candidate that can override the store.
       if (memref::isSameViewOrTrivialAlias(
-              cast<MemrefValue>(nextWrite.getSource()),
-              cast<MemrefValue>(write.getSource())) &&
+              cast<MemrefValue>(nextWrite.getBase()),
+              cast<MemrefValue>(write.getBase())) &&
           checkSameValueWAW(nextWrite, write) &&
           postDominators.postDominates(nextWrite, write)) {
         if (firstOverwriteCandidate == nullptr ||
@ -178,7 +178,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@ -202,8 +202,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
             /*testDynamicValueUsingBounds=*/true))
       continue;
     if (memref::isSameViewOrTrivialAlias(
-            cast<MemrefValue>(read.getSource()),
-            cast<MemrefValue>(write.getSource())) &&
+            cast<MemrefValue>(read.getBase()),
+            cast<MemrefValue>(write.getBase())) &&
         dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
       if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
         lastwrite = write;
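
The dominance queries in these hunks carry the legality argument: a store is dead only if a later store to the same view post-dominates it, and a load may be forwarded only from a store that dominates it. A compressed restatement (sameView is hypothetical shorthand for skipping view-like ops and calling memref::isSameViewOrTrivialAlias, as the hunks do):

// Illustrative restatement of the two checks above.
// deadStoreOp: nextWrite may erase `write` when it overwrites the same view
// with the same value shape and executes on every path after `write`.
bool killsEarlierStore = sameView(nextWrite, write) &&
                         checkSameValueWAW(nextWrite, write) &&
                         postDominators.postDominates(nextWrite, write);
// storeToLoadForwarding: `write` may feed `read` directly when it stores to
// the same view, executes on every path before `read`, and stores the value
// the read would produce.
bool forwardsToLoad = sameView(read, write) &&
                      dominators.dominates(write, read) &&
                      checkSameValueRAW(write, read);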
@ -351,7 +351,7 @@ class TransferReadDropUnitDimsPattern
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferReadOp.getSource();
+    Value source = transferReadOp.getBase();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor types.
     if (!sourceType)
@ -433,7 +433,7 @@ class TransferWriteDropUnitDimsPattern
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferWriteOp.getSource();
+    Value source = transferWriteOp.getBase();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor type.
     if (!sourceType)
@ -604,7 +604,7 @@ public:
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    auto source = transferReadOp.getSource();
+    auto source = transferReadOp.getBase();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
 
     // 0. Check pre-conditions
@ -695,7 +695,7 @@ public:
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
-    Value source = transferWriteOp.getSource();
+    Value source = transferWriteOp.getBase();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
 
     // 0. Check pre-conditions
@ -851,12 +851,12 @@ class RewriteScalarExtractElementOfTransferRead
                                                   *getConstantIntValue(ofr));
       }
     }
-    if (isa<MemRefType>(xferOp.getSource().getType())) {
-      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+    if (isa<MemRefType>(xferOp.getBase().getType())) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                   newIndices);
     } else {
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-          extractOp, xferOp.getSource(), newIndices);
+          extractOp, xferOp.getBase(), newIndices);
     }
 
     return success();
@ -899,12 +899,12 @@ class RewriteScalarExtractOfTransferRead
             extractOp.getLoc(), *getConstantIntValue(ofr));
       }
     }
-    if (isa<MemRefType>(xferOp.getSource().getType())) {
-      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getSource(),
+    if (isa<MemRefType>(xferOp.getBase().getType())) {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(extractOp, xferOp.getBase(),
                                                   newIndices);
     } else {
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
-          extractOp, xferOp.getSource(), newIndices);
+          extractOp, xferOp.getBase(), newIndices);
     }
 
     return success();
@ -932,12 +932,12 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
     Value scalar =
         rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
     // Construct a scalar store.
-    if (isa<MemRefType>(xferOp.getSource().getType())) {
+    if (isa<MemRefType>(xferOp.getBase().getType())) {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(
-          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
+          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
     } else {
       rewriter.replaceOpWithNewOp<tensor::InsertOp>(
-          xferOp, scalar, xferOp.getSource(), xferOp.getIndices());
+          xferOp, scalar, xferOp.getBase(), xferOp.getIndices());
     }
     return success();
   }
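
The three scalarization patterns above turn single-element transfers into plain loads and stores, picking the memref or tensor form from the type of the base. A before/after sketch for the memref read case, as IR in comments (illustrative; the exact printed syntax may differ across MLIR versions):

// Illustrative only -- what the scalar-extract patterns produce:
//
//   %v = vector.transfer_read %base[%i], %pad : memref<8xf32>, vector<1xf32>
//   %e = vector.extract %v[0] : f32 from vector<1xf32>
// becomes
//   %e = memref.load %base[%i] : memref<8xf32>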

@ -58,7 +58,7 @@ static Value createInBoundsCond(RewriterBase &b,
         b, loc, b.getAffineDimExpr(0) + b.getAffineConstantExpr(vectorSize),
         {xferOp.getIndices()[indicesIdx]});
     OpFoldResult dimSz =
-        memref::getMixedSize(b, loc, xferOp.getSource(), indicesIdx);
+        memref::getMixedSize(b, loc, xferOp.getBase(), indicesIdx);
     auto maybeCstSum = getConstantIntValue(sum);
     auto maybeCstDimSz = getConstantIntValue(dimSz);
     if (maybeCstSum && maybeCstDimSz && *maybeCstSum <= *maybeCstDimSz)
@ -185,7 +185,7 @@ static Value castToCompatibleMemRefType(OpBuilder &b, Value memref,
 }
 
 /// Operates under a scoped context to build the intersection between the
-/// view `xferOp.getSource()` @ `xferOp.getIndices()` and the view `alloc`.
+/// view `xferOp.getBase()` @ `xferOp.getIndices()` and the view `alloc`.
 // TODO: view intersection/union/differences should be a proper std op.
 static std::pair<Value, Value>
 createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
@ -202,8 +202,8 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
   auto isaWrite = isa<vector::TransferWriteOp>(xferOp);
   xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) {
     using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-    Value dimMemRef = b.create<memref::DimOp>(xferOp.getLoc(),
-                                              xferOp.getSource(), indicesIdx);
+    Value dimMemRef =
+        b.create<memref::DimOp>(xferOp.getLoc(), xferOp.getBase(), indicesIdx);
     Value dimAlloc = b.create<memref::DimOp>(loc, alloc, resultIdx);
     Value index = xferOp.getIndices()[indicesIdx];
     AffineExpr i, j, k;
@ -221,9 +221,9 @@ createSubViewIntersection(RewriterBase &b, VectorTransferOpInterface xferOp,
   SmallVector<OpFoldResult> destIndices(memrefRank, b.getIndexAttr(0));
   SmallVector<OpFoldResult> strides(memrefRank, b.getIndexAttr(1));
   auto copySrc = b.create<memref::SubViewOp>(
-      loc, isaWrite ? alloc : xferOp.getSource(), srcIndices, sizes, strides);
+      loc, isaWrite ? alloc : xferOp.getBase(), srcIndices, sizes, strides);
   auto copyDest = b.create<memref::SubViewOp>(
-      loc, isaWrite ? xferOp.getSource() : alloc, destIndices, sizes, strides);
+      loc, isaWrite ? xferOp.getBase() : alloc, destIndices, sizes, strides);
   return std::make_pair(copySrc, copyDest);
 }
 
@ -252,7 +252,7 @@ createFullPartialLinalgCopy(RewriterBase &b, vector::TransferReadOp xferOp,
                             MemRefType compatibleMemRefType, Value alloc) {
   Location loc = xferOp.getLoc();
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.getSource();
+  Value memref = xferOp.getBase();
   return b.create<scf::IfOp>(
       loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
@ -305,7 +305,7 @@ static scf::IfOp createFullPartialVectorTransferRead(
   Location loc = xferOp.getLoc();
   scf::IfOp fullPartialIfOp;
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.getSource();
+  Value memref = xferOp.getBase();
   return b.create<scf::IfOp>(
       loc, inBoundsCond,
       [&](OpBuilder &b, Location loc) {
@ -352,7 +352,7 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp,
                           MemRefType compatibleMemRefType, Value alloc) {
   Location loc = xferOp.getLoc();
   Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-  Value memref = xferOp.getSource();
+  Value memref = xferOp.getBase();
   return b
       .create<scf::IfOp>(
           loc, inBoundsCond,
@ -509,7 +509,7 @@ static Operation *getAutomaticAllocationScope(Operation *op) {
 ///
 /// Preconditions:
 /// 1. `xferOp.getPermutationMap()` must be a minor identity map
-/// 2. the rank of the `xferOp.getSource()` and the rank of the
+/// 2. the rank of the `xferOp.getBase()` and the rank of the
 ///    `xferOp.getVector()` must be equal. This will be relaxed in the future
 ///    but requires rank-reducing subviews.
 LogicalResult mlir::vector::splitFullAndPartialTransfer(
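
The two documented preconditions translate directly into guards. A minimal sketch of such a check (a hypothetical helper mirroring the doc comment, not code from this patch):

// Hypothetical precondition check for splitFullAndPartialTransfer.
static bool canSplitFullAndPartial(VectorTransferOpInterface xferOp) {
  // 1. The permutation map must be a minor identity map.
  if (!xferOp.getPermutationMap().isMinorIdentity())
    return false;
  // 2. The rank of the base must equal the rank of the vector.
  auto shapedType = cast<ShapedType>(xferOp.getBase().getType());
  return shapedType.getRank() == xferOp.getVectorType().getRank();
}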
@ -611,7 +611,7 @@ LogicalResult mlir::vector::splitFullAndPartialTransfer(
     // The operation is cloned to prevent deleting information needed for the
     // later IR creation.
     IRMapping mapping;
-    mapping.map(xferWriteOp.getSource(), memrefAndIndices.front());
+    mapping.map(xferWriteOp.getBase(), memrefAndIndices.front());
     mapping.map(xferWriteOp.getIndices(), memrefAndIndices.drop_front());
     auto *clone = b.clone(*xferWriteOp, mapping);
     clone->setAttr(xferWriteOp.getInBoundsAttrName(), inBoundsAttr);

@ -1265,7 +1265,7 @@ public:
     unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1;
     Value off = xferOp.getIndices()[lastIndex];
     Value dim =
-        vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex);
+        vector::createOrFoldDimOp(rewriter, loc, xferOp.getBase(), lastIndex);
     Value b = rewriter.create<arith::SubIOp>(loc, dim.getType(), dim, off);
     Value mask = rewriter.create<vector::CreateMaskOp>(
         loc,
@ -1437,7 +1437,7 @@ class DropInnerMostUnitDimsTransferRead
     if (readOp.getMask())
       return failure();
 
-    auto srcType = dyn_cast<MemRefType>(readOp.getSource().getType());
+    auto srcType = dyn_cast<MemRefType>(readOp.getBase().getType());
     if (!srcType)
       return failure();
 
@ -1469,7 +1469,7 @@ class DropInnerMostUnitDimsTransferRead
 
     auto loc = readOp.getLoc();
     SmallVector<OpFoldResult> sizes =
-        memref::getMixedSizes(rewriter, loc, readOp.getSource());
+        memref::getMixedSizes(rewriter, loc, readOp.getBase());
     SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(srcType.getRank(),
@ -1480,7 +1480,7 @@ class DropInnerMostUnitDimsTransferRead
     ArrayAttr inBoundsAttr = rewriter.getArrayAttr(
         readOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, readOp.getSource(), offsets, sizes, strides);
+        loc, resultMemrefType, readOp.getBase(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
     Value result = rewriter.create<vector::TransferReadOp>(
@ -1527,7 +1527,7 @@ class DropInnerMostUnitDimsTransferWrite
     if (writeOp.getMask())
       return failure();
 
-    auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
+    auto srcType = dyn_cast<MemRefType>(writeOp.getBase().getType());
     if (!srcType)
       return failure();
 
@ -1559,7 +1559,7 @@ class DropInnerMostUnitDimsTransferWrite
 
     Location loc = writeOp.getLoc();
     SmallVector<OpFoldResult> sizes =
-        memref::getMixedSizes(rewriter, loc, writeOp.getSource());
+        memref::getMixedSizes(rewriter, loc, writeOp.getBase());
     SmallVector<OpFoldResult> offsets(srcType.getRank(),
                                       rewriter.getIndexAttr(0));
     SmallVector<OpFoldResult> strides(srcType.getRank(),
@ -1571,7 +1571,7 @@ class DropInnerMostUnitDimsTransferWrite
         writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop));
 
     Value rankedReducedView = rewriter.create<memref::SubViewOp>(
-        loc, resultMemrefType, writeOp.getSource(), offsets, sizes, strides);
+        loc, resultMemrefType, writeOp.getBase(), offsets, sizes, strides);
     auto permMap = getTransferMinorIdentityMap(
         cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);
 

@ -164,7 +164,7 @@ struct UnrollTransferReadPattern
           sliceTransferIndices(elementOffsets, originalIndices,
                                readOp.getPermutationMap(), loc, rewriter);
       auto slicedRead = rewriter.create<vector::TransferReadOp>(
-          loc, targetType, readOp.getSource(), indices,
+          loc, targetType, readOp.getBase(), indices,
           readOp.getPermutationMapAttr(), readOp.getPadding(), readOp.getMask(),
           readOp.getInBoundsAttr());
 
@ -215,7 +215,7 @@ struct UnrollTransferWritePattern
           sliceTransferIndices(elementOffsets, originalIndices,
                                writeOp.getPermutationMap(), loc, rewriter);
       Operation *slicedWrite = rewriter.create<vector::TransferWriteOp>(
-          loc, slicedVector, resultTensor ? resultTensor : writeOp.getSource(),
+          loc, slicedVector, resultTensor ? resultTensor : writeOp.getBase(),
          indices, writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
       // For the tensor case update the destination for the next transfer write.
       if (!slicedWrite->getResults().empty())

@ -312,7 +312,7 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
 
   Value base = TypeSwitch<Operation *, Value>(xfer)
                    .Case<vector::TransferReadOp>(
-                       [&](auto readOp) { return readOp.getSource(); })
+                       [&](auto readOp) { return readOp.getBase(); })
                    .Case<vector::TransferWriteOp>(
                        [&](auto writeOp) { return writeOp.getOperand(1); });
 
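
One detail worth flagging in this last hunk: the TransferWriteOp case still fetches the shaped operand positionally via getOperand(1) (operand 0 is the vector being stored), while the read case now uses the named accessor. With the rename landed, both branches could be spelled symmetrically; a sketch of that alternative (a suggestion, not part of this patch):

// Suggested symmetric spelling (not in this patch): use getBase() on both ops.
Value base = TypeSwitch<Operation *, Value>(xfer)
                 .Case<vector::TransferReadOp>(
                     [&](auto readOp) { return readOp.getBase(); })
                 .Case<vector::TransferWriteOp>(
                     [&](auto writeOp) { return writeOp.getBase(); });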