[mlir][Vector] Move vector.mask canonicalization to folder (#140324)

This MR moves the canonicalization that elides empty `vector.mask` ops
to folders.
This commit is contained in:
Diego Caballero 2025-05-21 17:25:01 -07:00 committed by GitHub
parent 12c62ebcb2
commit d6f394e141
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 36 additions and 39 deletions

View File

@ -2559,7 +2559,6 @@ def Vector_MaskOp : Vector_Op<"mask", [
Location loc); Location loc);
}]; }];
let hasCanonicalizer = 1;
let hasFolder = 1; let hasFolder = 1;
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
let hasVerifier = 1; let hasVerifier = 1;

View File

@ -6650,13 +6650,40 @@ LogicalResult MaskOp::verify() {
return success(); return success();
} }
/// Folds vector.mask ops with an all-true mask. /// Folds empty `vector.mask` with no passthru operand and with or without
LogicalResult MaskOp::fold(FoldAdaptor adaptor, /// return values. For example:
SmallVectorImpl<OpFoldResult> &results) { ///
MaskFormat maskFormat = getMaskFormat(getMask()); /// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
if (isEmpty()) /// vector<8xi1> -> vector<8xf32>
/// %1 = user_op %0 : vector<8xf32>
///
/// becomes:
///
/// %0 = user_op %a : vector<8xf32>
///
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (!maskOp.isEmpty() || maskOp.hasPassthru())
return failure(); return failure();
Block *block = maskOp.getMaskBlock();
auto terminator = cast<vector::YieldOp>(block->front());
if (terminator.getNumOperands() == 0) {
// `vector.mask` has no results, just remove the `vector.mask`.
return success();
}
// `vector.mask` has results, propagate the results.
llvm::append_range(results, terminator.getOperands());
return success();
}
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
return success();
MaskFormat maskFormat = getMaskFormat(getMask());
if (maskFormat != MaskFormat::AllTrue) if (maskFormat != MaskFormat::AllTrue)
return failure(); return failure();
@ -6669,37 +6696,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
return success(); return success();
} }
// Elides empty vector.mask operations with or without return values. Propagates
// the yielded values by the vector.yield terminator, if any, or erases the op,
// otherwise.
class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const override {
auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
if (maskingOp.getMaskableOp())
return failure();
if (!maskOp.isEmpty())
return failure();
Block *block = maskOp.getMaskBlock();
auto terminator = cast<vector::YieldOp>(block->front());
if (terminator.getNumOperands() == 0)
rewriter.eraseOp(maskOp);
else
rewriter.replaceOp(maskOp, terminator.getOperands());
return success();
}
};
void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ElideEmptyMaskOp>(context);
}
// MaskingOpInterface definitions. // MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'. /// Returns the operation masked by this 'vector.mask'.

View File

@ -1,10 +1,12 @@
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
module { module {
// CHECK-LABEL: func @func
// CHECK-SAME: %[[IN:.*]]: vector<11xf32>
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> { func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
%cst_41 = arith.constant dense<true> : vector<11xi1> %cst_41 = arith.constant dense<true> : vector<11xi1>
// CHECK: vector.mask // CHECK-NOT: vector.mask
// CHECK-SAME: vector.yield %arg0 // CHECK: return %[[IN]] : vector<11xf32>
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32> %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
return %127 : vector<11xf32> return %127 : vector<11xf32>
} }