[mlir][vector] Folder: shape_cast(extract) -> extract (#146368)

In a later PR more shape_cast ops will appear. Specifically, broadcasts that 
just prepend ones become shape_cast ops (i.e. volume preserving broadcasts 
are canonicalized to shape_casts). This PR ensures that broadcast-like 
shape_cast ops fold at least as well as broadcast ops.

This is done by modifying patterns that target broadcast ops, to target
'broadcast-like' ops. No new patterns are added, the patterns that exist
are just made to match on shape_casts where appropriate.

This PR also includes minor code simplifications: use
`isBroadcastableTo` to simplify `ExtractOpFromBroadcast` and simplify
how broadcast dims are detected in `foldExtractFromBroadcast`. These are
NFC.

---------

Co-authored-by: Andrzej Warzyński <andrzej.warzynski@gmail.com>
This commit is contained in:
James Newling 2025-07-21 14:12:50 -04:00 committed by GitHub
parent 881b3fdfad
commit abce4e9ad0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 146 additions and 69 deletions

View File

@ -1707,59 +1707,99 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
/// Returns true when `op` is 'broadcastlike'. Every BroadcastOp and SplatOp
/// qualifies, and so does a ShapeCastOp that does nothing but prepend 1s to
/// the shape of its source.
static bool isBroadcastLike(Operation *op) {
  if (auto castOp = dyn_cast<ShapeCastOp>(op)) {
    ArrayRef<int64_t> inShape = castOp.getSourceVectorType().getShape();
    ArrayRef<int64_t> outShape = castOp.getType().getShape();
    // A shape_cast is broadcastlike only if it purely prepends 1s, for example
    // (2,3) -> (1,1,2,3). Looking for leading 1s in the destination shape is
    // not enough: (2,3) -> (1,3,2) has one but is not broadcastlike. It is
    // sufficient to require that the source shape is a suffix of the
    // destination shape.
    const uint64_t inRank = inShape.size();
    if (outShape.size() < inRank)
      return false;
    return outShape.take_back(inRank) == inShape;
  }
  // Not a shape_cast: only broadcast and splat count.
  return isa<BroadcastOp, SplatOp>(op);
}
/// Fold extract(broadcast(X)) to either extract(X) or just X.
///
/// Example:
///
/// broadcast extract [1][2]
/// (3, 4) --------> (2, 3, 4) ----------------> (4)
///
/// becomes
/// extract [1]
/// (3,4) -------------------------------------> (4)
///
///
/// The variable names used in this implementation correspond to the above
/// shapes as,
///
/// - (3, 4) is `input` shape.
/// - (2, 3, 4) is `broadcast` shape.
/// - (4) is `extract` shape.
///
/// This folding is possible when the suffix of `input` shape is the same as
/// `extract` shape.
///
/// Returns a null `Value` when the fold does not apply. When the result is
/// extract(X), `extractOp` itself is updated in place (its operands and static
/// position are rewritten) and its own result is returned.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
if (!defOp || !isBroadcastLike(defOp))
return Value();
Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return source;
auto getRank = [](Type type) {
return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
: 0;
};
Value input = defOp->getOperand(0);
// If splat or broadcast from a scalar, just return the source scalar.
unsigned broadcastSrcRank = getRank(source.getType());
if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
return source;
// Replace extract(broadcast(X)) with X
if (extractOp.getType() == input.getType())
return input;
unsigned extractResultRank = getRank(extractOp.getType());
if (extractResultRank > broadcastSrcRank)
return Value();
// Check that the dimensions of the result haven't been broadcasted.
auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
if (extractVecType && broadcastVecType &&
extractVecType.getShape() !=
broadcastVecType.getShape().take_back(extractResultRank))
// Get required types and ranks in the chain
// input -> broadcast -> extract
// (scalars are treated as rank-0).
auto inputType = llvm::dyn_cast<VectorType>(input.getType());
auto extractType = llvm::dyn_cast<VectorType>(extractOp.getType());
unsigned inputRank = inputType ? inputType.getRank() : 0;
unsigned broadcastRank = extractOp.getSourceVectorType().getRank();
unsigned extractRank = extractType ? extractType.getRank() : 0;
// Cannot do without the broadcast if overall the rank increases.
if (extractRank > inputRank)
return Value();
auto broadcastOp = cast<vector::BroadcastOp>(defOp);
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
// The above condition guarantees that input is a vector.
assert(inputType && "input must be a vector type because of previous checks");
ArrayRef<int64_t> inputShape = inputType.getShape();
// Detect all the positions that come from "dim-1" broadcasting.
// These dimensions correspond to "dim-1" broadcasted dims; set the matching
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
OpBuilder b(extractOp.getContext());
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i)
if (broadcastedUnitDims.contains(i))
extractPos[i] = b.getIndexAttr(0);
// `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
// matching extract position when extracting from the source operand.
int64_t rankDiff = broadcastSrcRank - extractResultRank;
extractPos.erase(extractPos.begin(),
std::next(extractPos.begin(), extractPos.size() - rankDiff));
// OpBuilder is only used as a helper to build an I64ArrayAttr.
auto [staticPos, dynPos] = decomposeMixedValues(extractPos);
// In the case where there is a broadcast dimension in the suffix, it is not
// possible to replace extract(broadcast(X)) with extract(X). Example:
//
// broadcast extract
// (1) --------> (3,4) ------> (4)
if (extractType &&
extractType.getShape() != inputShape.take_back(extractRank))
return Value();
// Replace extract(broadcast(X)) with extract(X).
// First, determine the new extraction position.
unsigned deltaOverall = inputRank - extractRank;
unsigned deltaBroadcast = broadcastRank - inputRank;
SmallVector<OpFoldResult> oldPositions = extractOp.getMixedPosition();
SmallVector<OpFoldResult> newPositions(deltaOverall);
IntegerAttr zero = OpBuilder(extractOp.getContext()).getIndexAttr(0);
for (auto [i, size] : llvm::enumerate(inputShape.take_front(deltaOverall))) {
newPositions[i] = size == 1 ? zero : oldPositions[i + deltaBroadcast];
}
auto [staticPos, dynPos] = decomposeMixedValues(newPositions);
extractOp->setOperands(
llvm::to_vector(llvm::concat<Value>(ValueRange(source), dynPos)));
llvm::to_vector(llvm::concat<Value>(ValueRange(input), dynPos)));
extractOp.setStaticPosition(staticPos);
return extractOp.getResult();
}
@ -2204,32 +2244,18 @@ public:
/// Rewrites `extractOp` into a `BroadcastOp` that broadcasts the source of the
/// op defining its vector operand directly to the extract result type.
/// Returns failure() when the vector operand has no defining op or when one of
/// the checks below rejects the rewrite.
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!defOp || !isBroadcastLike(defOp) || !outType)
return failure();
Value source = defOp->getOperand(0);
if (extractOp.getType() == source.getType())
return failure();
auto getRank = [](Type type) {
return llvm::isa<VectorType>(type)
? llvm::cast<VectorType>(type).getRank()
: 0;
};
unsigned broadcastSrcRank = getRank(source.getType());
unsigned extractResultRank = getRank(extractOp.getType());
// We only consider the case where the rank of the source is less than or
// equal to the rank of the extract dst. The other cases are handled in the
// folding patterns.
if (extractResultRank < broadcastSrcRank)
return failure();
// For scalar result, the input can only be a rank-0 vector, which will
// be handled by the folder.
if (extractResultRank == 0)
if (isBroadcastableTo(source.getType(), outType) !=
BroadcastableToResult::Success)
return failure();
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
extractOp, extractOp.getType(), source);
rewriter.replaceOpWithNewOp<BroadcastOp>(extractOp, outType, source);
return success();
}
};

View File

@ -558,10 +558,9 @@ func.func @vector_print_vector_0d(%arg0: vector<f32>) {
// CHECK-SAME: %[[VEC:.*]]: vector<f32>) {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[FLAT_VEC:.*]] = vector.shape_cast %[[VEC]] : vector<f32> to vector<1xf32>
// CHECK: vector.print punctuation <open>
// CHECK: scf.for %[[IDX:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
// CHECK: %[[EL:.*]] = vector.extract %[[FLAT_VEC]][%[[IDX]]] : f32 from vector<1xf32>
// CHECK: %[[EL:.*]] = vector.extract %[[VEC]][] : f32 from vector<f32>
// CHECK: vector.print %[[EL]] : f32 punctuation <no_punctuation>
// CHECK: %[[IS_NOT_LAST:.*]] = arith.cmpi ult, %[[IDX]], %[[C0]] : index
// CHECK: scf.if %[[IS_NOT_LAST]] {

View File

@ -823,10 +823,10 @@ func.func @negative_fold_extract_broadcast(%a : vector<1x1xf32>) -> vector<4xf32
// -----
// CHECK-LABEL: fold_extract_splat
// CHECK-LABEL: fold_extract_scalar_from_splat
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
func.func @fold_extract_scalar_from_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
return %r : f32
@ -834,6 +834,16 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
// -----
// Test that extracting a vector from a splat of a scalar yields a broadcast
// from f32 to the extracted vector type.
// CHECK-LABEL: fold_extract_vector_from_splat
// CHECK: vector.broadcast {{.*}} f32 to vector<4xf32>
func.func @fold_extract_vector_from_splat(%a : f32, %idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.splat %a : vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
// -----
// CHECK-LABEL: fold_extract_broadcast_dim1_broadcasting
// CHECK-SAME: %[[A:.*]]: vector<2x1xf32>
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
@ -863,6 +873,35 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
// -----
// Test where the shape_cast is broadcast-like.
// CHECK-LABEL: fold_extract_shape_cast_to_lower_rank
// CHECK-SAME: %[[A:.*]]: vector<2x4xf32>
// CHECK-SAME: %[[IDX0:.*]]: index, %[[IDX1:.*]]: index
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
// CHECK: return %[[B]] : vector<4xf32>
func.func @fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
return %r : vector<4xf32>
}
// -----
// Test where the shape_cast is not broadcast-like, even though it prepends 1s.
// CHECK-LABEL: negative_fold_extract_shape_cast_to_lower_rank
// CHECK-NEXT: vector.shape_cast
// CHECK-NEXT: vector.extract
// CHECK-NEXT: return
func.func @negative_fold_extract_shape_cast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<2xf32> {
%b = vector.shape_cast %a : vector<2x4xf32> to vector<1x4x2xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<2xf32> from vector<1x4x2xf32>
return %r : vector<2xf32>
}
// -----
// CHECK-LABEL: fold_extract_broadcast_to_higher_rank
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
@ -890,6 +929,19 @@ func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : inde
// -----
// Test that an extract from a broadcast-like shape_cast is rewritten to a
// broadcast of the shape_cast source.
// CHECK-LABEL: fold_extract_broadcastlike_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<1xf32>
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<1x1xf32>
// CHECK: return %[[R]] : vector<1x1xf32>
func.func @fold_extract_broadcastlike_shape_cast(%a : vector<1xf32>, %idx0 : index)
-> vector<1x1xf32> {
%s = vector.shape_cast %a : vector<1xf32> to vector<1x1x1xf32>
%r = vector.extract %s[%idx0] : vector<1x1xf32> from vector<1x1x1xf32>
return %r : vector<1x1xf32>
}
// -----
// CHECK-LABEL: @fold_extract_shuffle
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
// CHECK-NOT: vector.shuffle