[mlir][vector] Folder: shape_cast(extract) -> extract (#146368)
In a later PR more shape_cast ops will appear. Specifically, broadcasts that just prepend ones become shape_cast ops (i.e. volume preserving broadcasts are canonicalized to shape_casts). This PR ensures that broadcast-like shape_cast ops fold at least as well as broadcast ops. This is done by modifying patterns that target broadcast ops, to target 'broadcast-like' ops. No new patterns are added, the patterns that exist are just made to match on shape_casts where appropriate. This PR also includes minor code simplifications: use `isBroadcastableTo` to simplify `ExtractOpFromBroadcast` and simplify how broadcast dims are detected in `foldExtractFromBroadcast`. These are NFC. --------- Co-authored-by: Andrzej Warzyński <andrzej.warzynski@gmail.com>
This commit is contained in:
parent
881b3fdfad
commit
abce4e9ad0
@ -1707,59 +1707,99 @@ static bool hasZeroDimVectors(Operation *op) {
|
||||
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
|
||||
}
|
||||
|
||||
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
|
||||
/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
|
||||
/// 1s, are considered to be 'broadcastlike'.
|
||||
static bool isBroadcastLike(Operation *op) {
|
||||
if (isa<BroadcastOp, SplatOp>(op))
|
||||
return true;
|
||||
|
||||
auto shapeCast = dyn_cast<ShapeCastOp>(op);
|
||||
if (!shapeCast)
|
||||
return false;
|
||||
|
||||
// Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
|
||||
// Checking that the destination shape has a prefix of 1s is not sufficient,
|
||||
// for example (2,3) -> (1,3,2) is not broadcastlike. A sufficient condition
|
||||
// is that the source shape is a suffix of the destination shape.
|
||||
VectorType srcType = shapeCast.getSourceVectorType();
|
||||
ArrayRef<int64_t> srcShape = srcType.getShape();
|
||||
uint64_t srcRank = srcType.getRank();
|
||||
ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
|
||||
return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
|
||||
}
|
||||
|
||||
/// Fold extract(broadcast(X)) to either extract(X) or just X.
|
||||
///
|
||||
/// Example:
|
||||
///
|
||||
/// broadcast extract [1][2]
|
||||
/// (3, 4) --------> (2, 3, 4) ----------------> (4)
|
||||
///
|
||||
/// becomes
|
||||
/// extract [1]
|
||||
/// (3,4) -------------------------------------> (4)
|
||||
///
|
||||
///
|
||||
/// The variable names used in this implementation correspond to the above
|
||||
/// shapes as,
|
||||
///
|
||||
/// - (3, 4) is `input` shape.
|
||||
/// - (2, 3, 4) is `broadcast` shape.
|
||||
/// - (4) is `extract` shape.
|
||||
///
|
||||
/// This folding is possible when the suffix of `input` shape is the same as
|
||||
/// `extract` shape.
|
||||
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
|
||||
|
||||
Operation *defOp = extractOp.getVector().getDefiningOp();
|
||||
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
|
||||
if (!defOp || !isBroadcastLike(defOp))
|
||||
return Value();
|
||||
|
||||
Value source = defOp->getOperand(0);
|
||||
if (extractOp.getType() == source.getType())
|
||||
return source;
|
||||
auto getRank = [](Type type) {
|
||||
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
|
||||
: 0;
|
||||
};
|
||||
Value input = defOp->getOperand(0);
|
||||
|
||||
// If splat or broadcast from a scalar, just return the source scalar.
|
||||
unsigned broadcastSrcRank = getRank(source.getType());
|
||||
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
|
||||
return source;
|
||||
// Replace extract(broadcast(X)) with X
|
||||
if (extractOp.getType() == input.getType())
|
||||
return input;
|
||||
|
||||
unsigned extractResultRank = getRank(extractOp.getType());
|
||||
if (extractResultRank > broadcastSrcRank)
|
||||
return Value();
|
||||
// Check that the dimensions of the result haven't been broadcasted.
|
||||
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
|
||||
auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
|
||||
if (extractVecType && broadcastVecType &&
|
||||
extractVecType.getShape() !=
|
||||
broadcastVecType.getShape().take_back(extractResultRank))
|
||||
// Get required types and ranks in the chain
|
||||
// input -> broadcast -> extract
|
||||
// (scalars are treated as rank-0).
|
||||
auto inputType = llvm::dyn_cast<VectorType>(input.getType());
|
||||
auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
|
||||
unsigned inputRank = inputType ? inputType.getRank() : 0;
|
||||
unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
|
||||
unsigned extractRank = extractType ? extractType.getRank() : 0;
|
||||
|
||||
// Cannot do without the broadcast if overall the rank increases.
|
||||
if (extractRank > inputRank)
|
||||
return Value();
|
||||
|
||||
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
|
||||
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
|
||||
// The above condition guarantees that input is a vector.
|
||||
assert(inputType && "input must be a vector type because of previous checks");
|
||||
ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
|
||||
// Detect all the positions that come from "dim-1" broadcasting.
|
||||
// These dimensions correspond to "dim-1" broadcasted dims; set the matching
|
||||
// extract position to `0` when extracting from the source operand.
|
||||
llvm::SetVector<int64_t> broadcastedUnitDims =
|
||||
broadcastOp.computeBroadcastedUnitDims();
|
||||
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
|
||||
OpBuilder b(extractOp.getContext());
|
||||
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
|
||||
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
|
||||
if (broadcastedUnitDims.contains(i))
|
||||
extractPos[i] = b.getIndexAttr(0);
|
||||
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
|
||||
// matching extract position when extracting from the source operand.
|
||||
int64_t rankDiff = broadcastSrcRank - extractResultRank;
|
||||
extractPos.erase(extractPos.begin(),
|
||||
std::next(extractPos.begin(), extractPos.size() - rankDiff));
|
||||
// OpBuilder is only used as a helper to build an I64ArrayAttr.
|
||||
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
|
||||
// In the case where there is a broadcast dimension in the suffix, it is not
|
||||
// possible to replace extract(broadcast(X)) with extract(X). Example:
|
||||
//
|
||||
// broadcast extract
|
||||
// (1) --------> (3,4) ------> (4)
|
||||
if (extractType &&
|
||||
extractType.getShape() != inputShape.take_back(extractRank))
|
||||
return Value();
|
||||
|
||||
// Replace extract(broadcast(X)) with extract(X).
|
||||
// First, determine the new extraction position.
|
||||
unsigned deltaOverall = inputRank - extractRank;
|
||||
unsigned deltaBroadcast = broadcastRank - inputRank;
|
||||
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
|
||||
SmallVector<OpFoldResult> newPositions(deltaOverall);
|
||||
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
|
||||
for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
|
||||
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
|
||||
}
|
||||
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
|
||||
extractOp->setOperands(
|
||||
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
|
||||
llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
|
||||
extractOp.setStaticPosition(staticPos);
|
||||
return extractOp.getResult();
|
||||
}
|
||||
@ -2204,32 +2244,18 @@ public:
|
||||
|
||||
LogicalResult matchAndRewrite(ExtractOp extractOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
|
||||
Operation *defOp = extractOp.getVector().getDefiningOp();
|
||||
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
|
||||
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
|
||||
if (!defOp || !isBroadcastLike(defOp) || !outType)
|
||||
return failure();
|
||||
|
||||
Value source = defOp->getOperand(0);
|
||||
if (extractOp.getType() == source.getType())
|
||||
return failure();
|
||||
auto getRank = [](Type type) {
|
||||
return llvm::isa<VectorType>(type)
|
||||
? llvm::cast<VectorType>(type).getRank()
|
||||
: 0;
|
||||
};
|
||||
unsigned broadcastSrcRank = getRank(source.getType());
|
||||
unsigned extractResultRank = getRank(extractOp.getType());
|
||||
// We only consider the case where the rank of the source is less than or
|
||||
// equal to the rank of the extract dst. The other cases are handled in the
|
||||
// folding patterns.
|
||||
if (extractResultRank < broadcastSrcRank)
|
||||
return failure();
|
||||
// For scalar result, the input can only be a rank-0 vector, which will
|
||||
// be handled by the folder.
|
||||
if (extractResultRank == 0)
|
||||
if (isBroadcastableTo(source.getType(), outType) !=
|
||||
BroadcastableToResult::Success)
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
|
||||
extractOp, extractOp.getType(), source);
|
||||
rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -558,10 +558,9 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
|
||||
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
|
||||
// CHECK: vector.print punctuation <open>
|
||||
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
|
||||
// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
|
||||
// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][] : f32 from vector<f32>
|
||||
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
|
||||
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
|
||||
// CHECK: scf.if %[[IS_NOT_LAST]] {
|
||||
|
@ -823,10 +823,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fold_extract_splat
|
||||
// CHECK-LABEL: fold_extract_scalar_from_splat
|
||||
// CHECK-SAME: %[[A:.*]]: f32
|
||||
// CHECK: return %[[A]] : f32
|
||||
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
|
||||
func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
|
||||
%b = vector.splat %a : vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
|
||||
return %r : f32
|
||||
@ -834,6 +834,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fold_extract_vector_from_splat
|
||||
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
|
||||
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
|
||||
%b = vector.splat %a : vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
|
||||
return %r : vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
|
||||
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
|
||||
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
|
||||
@ -863,6 +873,35 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
|
||||
|
||||
// -----
|
||||
|
||||
// Test where the shape_cast is broadcast-like.
|
||||
// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
|
||||
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
|
||||
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
|
||||
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
|
||||
// CHECK: return %[[B]] : vector<4xf32>
|
||||
func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
|
||||
%idx0 : index, %idx1 : index) -> vector<4xf32> {
|
||||
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
|
||||
return %r : vector<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test where the shape_cast is not broadcast-like, even though it prepends 1s.
|
||||
// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
|
||||
// CHECK-NEXT: vector.shape_cast
|
||||
// CHECK-NEXT: vector.extract
|
||||
// CHECK-NEXT: return
|
||||
func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
|
||||
%idx0 : index, %idx1 : index) -> vector<2xf32> {
|
||||
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
|
||||
%r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
|
||||
return %r : vector<2xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
|
||||
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
|
||||
// CHECK: return %[[B]] : vector<4xf32>
|
||||
@ -890,6 +929,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
|
||||
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
|
||||
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
|
||||
// CHECK: return %[[R]] : vector<1x1xf32>
|
||||
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
|
||||
-> vector<1x1xf32> {
|
||||
%s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
|
||||
%r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
|
||||
return %r : vector<1x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @fold_extract_shuffle
|
||||
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
|
||||
// CHECK-NOT: vector.shuffle
|
||||
@ -1623,7 +1675,7 @@ func.func @negative_store_to_load_tensor_memref(
|
||||
%arg0 : tensor<?x?xf32>,
|
||||
%arg1 : memref<?x?xf32>,
|
||||
%v0 : vector<4x2xf32>
|
||||
) -> vector<4x2xf32>
|
||||
) -> vector<4x2xf32>
|
||||
{
|
||||
%c0 = arith.constant 0 : index
|
||||
%cf0 = arith.constant 0.0 : f32
|
||||
@ -1680,7 +1732,7 @@ func.func @negative_store_to_load_tensor_broadcast_out_of_bounds(%arg0 : tensor<
|
||||
// CHECK: vector.transfer_read
|
||||
func.func @negative_store_to_load_tensor_broadcast_masked(
|
||||
%arg0 : tensor<?x?xf32>, %v0 : vector<4x2xf32>, %mask : vector<4x2xi1>)
|
||||
-> vector<4x2x6xf32>
|
||||
-> vector<4x2x6xf32>
|
||||
{
|
||||
%c0 = arith.constant 0 : index
|
||||
%cf0 = arith.constant 0.0 : f32
|
||||
|
Loading…
x
Reference in New Issue
Block a user