[mlir][IR] Deprecate match and rewrite functions (#130031)

Deprecate the `match` and `rewrite` functions. They mainly exist for
historical reasons. This PR also updates all remaining uses of them in
the MLIR codebase.

This is addressing a
[comment](https://github.com/llvm/llvm-project/pull/129861#pullrequestreview-2662696084)
on an earlier PR.

Note for LLVM integration: `SplitMatchAndRewrite` will be deleted soon,
update your patterns to use `matchAndRewrite` instead of separate
`match` / `rewrite`.
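As a schematic illustration of the contract (plain C++ stand-ins, not the real MLIR API; `Op` and both pattern structs are invented for this sketch): the two hooks fold into a single `matchAndRewrite` that mutates nothing until the match succeeds and returns success exactly when the IR was modified.

```cpp
#include <cassert>

// Invented stand-in for an operation; not the real MLIR API.
struct Op {
  int value;
  bool rewritten = false;
};

// Old, now-deprecated shape: separate match and rewrite hooks.
struct SplitPattern {
  bool match(const Op &op) const { return op.value % 2 == 0; }
  void rewrite(Op &op) const {
    op.value /= 2;
    op.rewritten = true;
  }
};

// New shape: one combined hook. No mutation happens before the
// match is deemed successful, and success is returned iff the
// op was actually modified.
struct CombinedPattern {
  bool matchAndRewrite(Op &op) const {
    if (op.value % 2 != 0)
      return false; // match failed; op left untouched
    op.value /= 2;
    op.rewritten = true;
    return true;
  }
};
```

Both shapes behave identically on a match; the combined form simply avoids recomputing matching state between the two phases.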

---------

Co-authored-by: Jakub Kuderski <jakub@nod-labs.com>
Matthias Springer 2025-03-07 08:43:01 +01:00 committed by GitHub
parent 6b094020c2
commit a21cfca320
11 changed files with 158 additions and 171 deletions

View File

@@ -179,13 +179,13 @@ updated/remapped operands of an operation, such as when the types of results
 defined by an operation have changed. The general Rewrite Patterns can no longer
 be used in these situations, as the types of the operands of the operation being
 matched will not correspond with those expected by the user. This pattern
-provides, as an additional argument to the `matchAndRewrite` and `rewrite`
-methods, the list of operands that the operation should use after conversion. If
-an operand was the result of a non-converted operation, for example if it was
-already legal, the original operand is used. This means that the operands
-provided always have a 1-1 non-null correspondence with the operands on the
-operation. The original operands of the operation are still intact and may be
-inspected as normal. These patterns also utilize a special `PatternRewriter`,
+provides, as an additional argument to the `matchAndRewrite` method, the list
+of operands that the operation should use after conversion. If an operand was
+the result of a non-converted operation, for example if it was already legal,
+the original operand is used. This means that the operands provided always have
+a 1-1 non-null correspondence with the operands on the operation. The original
+operands of the operation are still intact and may be inspected as normal.
+These patterns also utilize a special `PatternRewriter`,
 `ConversionPatternRewriter`, that provides special hooks for use with the
 conversion infrastructure.
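The 1-1 operand correspondence described above can be sketched with a standalone stand-in (plain C++, not the MLIR API; `Op`, `SumConversionPattern`, and `adaptorOperands` are all invented for illustration):

```cpp
#include <cassert>
#include <vector>

// Invented stand-in for an op whose operands have been type-converted.
struct Op {
  std::vector<int> operands; // original (unconverted) operands
  int result = 0;
  bool replaced = false;
};

// Schematic conversion pattern: matchAndRewrite receives the
// already-converted operands as an extra argument, in 1-1
// correspondence with the op's original operands.
struct SumConversionPattern {
  bool matchAndRewrite(Op &op, const std::vector<int> &adaptorOperands) const {
    // The adaptor list always mirrors the original operand list 1-1.
    if (adaptorOperands.size() != op.operands.size())
      return false; // match failed; op left untouched
    int sum = 0;
    for (int v : adaptorOperands)
      sum += v;
    op.result = sum; // mutation only after the match succeeded
    op.replaced = true;
    return true;
  }
};
```

The pattern reads the converted values rather than the original operands, which is the point of the extra argument.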

View File

@@ -48,13 +48,9 @@ operation type, a special tag must be provided to make the intent explicit:
 ### `matchAndRewrite` implementation
 
 This is the chunk of code that matches a given root `Operation` and performs a
-rewrite of the IR. A `RewritePattern` can specify this implementation either via
-the `matchAndRewrite` method or via separate `match` and `rewrite` methods when
-deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined
-`matchAndRewrite` method, no IR mutation should take place before the match is
-deemed successful. The combined `matchAndRewrite` is useful when non-trivially
-recomputable information is required by the matching and rewriting phase. See
-below for examples:
+rewrite of the IR. A `RewritePattern` can specify this implementation via the
+`matchAndRewrite` method. No IR mutation should take place before the match is
+deemed successful. See below for examples:
 
 ```c++
 class MyPattern : public RewritePattern {
@@ -67,21 +63,6 @@ public:
   MyPattern(PatternBenefit benefit)
       : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
 
-  /// In this section, the `match` and `rewrite` implementation is specified
-  /// using the separate hooks.
-  LogicalResult match(Operation *op) const override {
-    // The `match` method returns `success()` if the pattern is a match, failure
-    // otherwise.
-    // ...
-  }
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    // The `rewrite` method performs mutations on the IR rooted at `op` using
-    // the provided rewriter. All mutations must go through the provided
-    // rewriter.
-  }
-
-  /// In this section, the `match` and `rewrite` implementation is specified
-  /// using a single hook.
   LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
     // The `matchAndRewrite` method performs both the matching and the mutation.
     // Note that the match must reach a successful point before IR mutation may
@@ -92,12 +73,6 @@ public:
 #### Restrictions
 
-Within the `match` section of a pattern, the following constraints apply:
-
-*   No mutation of the IR is allowed.
-
 Within the `rewrite` section of a pattern, the following constraints apply:
 
 *   All IR mutations, including creation, *must* be performed by the given
     `PatternRewriter`. This class provides hooks for performing all of the
     possible mutations that may take place within a pattern. For example, this
@@ -107,8 +82,6 @@ Within the `rewrite` section of a pattern, the following constraints apply:
 *   The root operation is required to either be: updated in-place, replaced, or
     erased.
 *   `matchAndRewrite` must return "success" if and only if the IR was modified.
-    `match` must return "success" if and only if the IR is going to be modified
-    during `rewrite`.
 
 ### Application Recursion

View File

@@ -216,25 +216,6 @@ In case ODS patterns and `matchAndRewrite`-style functions are not sufficient
 you can also specify rewrites as a general set of `RewritePattern`s:
 
 ```c++
-/// Multi-step rewrite using "match" and "rewrite". This allows for separating
-/// the concerns of matching and rewriting.
-struct ConvertTFLeakyRelu : public RewritePattern {
-  ConvertTFLeakyRelu(MLIRContext *context)
-      : RewritePattern("tf.LeakyRelu", 1, context) {}
-
-  LogicalResult match(Operation *op) const override {
-    return success();
-  }
-
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
-        op, op->getResult(0).getType(), op->getOperand(0),
-        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
-  }
-};
-
-/// Single-step rewrite with "matchAndRewrite". This allows for performing the
-/// rewrite immediately upon a successful match.
 struct ConvertTFLeakyRelu : public RewritePattern {
   ConvertTFLeakyRelu(MLIRContext *context)
       : RewritePattern("tf.LeakyRelu", 1, context) {}

View File

@@ -40,6 +40,8 @@ LogicalResult oneToOneRewrite(
 /// during the entire pattern lifetime.
 class ConvertToLLVMPattern : public ConversionPattern {
 public:
+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
@@ -149,6 +151,9 @@ public:
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
       ConvertOpToLLVMPattern<SourceOp>>;

View File

@@ -237,6 +237,9 @@ private:
 namespace detail {
 /// Helper class that derives from a RewritePattern class and provides separate
 /// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+///
+/// This class is deprecated. Use `matchAndRewrite` instead of separate `match`
+/// and `rewrite`.
 template <typename PatternT>
 class SplitMatchAndRewriteImpl : public PatternT {
   using PatternT::PatternT;
@@ -268,6 +271,9 @@ class SplitMatchAndRewriteImpl : public PatternT {
 class RewritePattern : public Pattern {
 public:
   using OperationT = Operation *;
+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
 
   virtual ~RewritePattern() = default;
@@ -350,6 +356,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 template <typename SourceOp>
 struct OpRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
@@ -368,6 +377,9 @@ struct OpRewritePattern
 template <typename SourceOp>
 struct OpInterfaceRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;

View File

@@ -598,6 +598,9 @@ public:
   using OperationT = Operation *;
   using OpAdaptor = ArrayRef<Value>;
   using OneToNOpAdaptor = ArrayRef<ValueRange>;
+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;
@@ -669,6 +672,9 @@ public:
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;

View File

@@ -41,48 +41,46 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };
 
-struct ExtFOnFloat8RewritePattern final
-    : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
-  using SplitMatchAndRewrite::SplitMatchAndRewrite;
+struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   Chipset chipset;
   ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
-      : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), chipset(chipset) {}
+      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}
 
-  LogicalResult match(arith::ExtFOp op) const override;
-  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const override;
 };
 
-struct TruncFToFloat8RewritePattern final
-    : OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
+struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                                Chipset chipset)
-      : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx),
-        saturateFP8(saturateFP8), chipset(chipset) {}
+      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
+        chipset(chipset) {}
   Chipset chipset;
 
-  LogicalResult match(arith::TruncFOp op) const override;
-  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const override;
 };
 
 struct TruncfToFloat16RewritePattern final
-    : public OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
+    : public OpRewritePattern<arith::TruncFOp> {
 
-  using SplitMatchAndRewrite::SplitMatchAndRewrite;
+  using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult match(arith::TruncFOp op) const override;
-  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const override;
 };
 } // end namespace
 
-static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
+static bool isSupportedF8(Type elementType, Chipset chipset) {
   if (chipset == kGfx942)
-    return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
+    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
   if (hasOcpFp8(chipset))
-    return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
-  return failure();
+    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
+  return false;
 }
 
 static Value castF32To(Type elementType, Value f32, Location loc,
@@ -96,35 +94,36 @@ static Value castF32To(Type elementType, Value f32, Location loc,
   llvm_unreachable("The only 32-bit float type is f32");
 }
 
-LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+LogicalResult
+ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
+                                            PatternRewriter &rewriter) const {
   Type inType = op.getIn().getType();
-  if (auto inVecType = dyn_cast<VectorType>(inType)) {
+  auto inVecType = dyn_cast<VectorType>(inType);
+  if (inVecType) {
     if (inVecType.isScalable())
       return failure();
     inType = inVecType.getElementType();
   }
-  return isSupportedF8(inType, chipset);
-}
+  if (!isSupportedF8(inType, chipset))
+    return failure();
 
-void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
-                                         PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  auto inType = dyn_cast<VectorType>(in.getType());
-  if (!inType) {
+  if (!inVecType) {
     Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
         loc, rewriter.getF32Type(), in, 0);
     Value result = castF32To(outElemType, asFloat, loc, rewriter);
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }
-  int64_t numElements = inType.getNumElements();
+  int64_t numElements = inVecType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   VectorType outType = cast<VectorType>(op.getOut().getType());
-  if (inType.getShape().empty()) {
+  if (inVecType.getShape().empty()) {
     Value zerodSplat =
         rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
     Value scalarIn =
@@ -133,17 +132,18 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
     Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
                                                      ArrayRef<int64_t>{});
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }
 
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outType.getElementType());
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
 
-  if (inType.getRank() > 1) {
-    inType = VectorType::get(SmallVector<int64_t>{numElements},
-                             inType.getElementType());
-    in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
+  if (inVecType.getRank() > 1) {
+    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
+                                inVecType.getElementType());
+    in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
   }
 
   for (int64_t i = 0; i < numElements; i += 4) {
@@ -158,11 +158,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
     }
   }
 
-  if (inType.getRank() != outType.getRank()) {
+  if (inVecType.getRank() != outType.getRank()) {
     result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
   }
 
   rewriter.replaceOp(op, result);
+  return success();
 }
 
 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
@@ -222,12 +223,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
   return res;
 }
 
-LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+LogicalResult
+TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
+                                              PatternRewriter &rewriter) const {
   // Only supporting default rounding mode as of now.
   if (op.getRoundingmodeAttr())
     return failure();
   Type outType = op.getOut().getType();
-  if (auto outVecType = dyn_cast<VectorType>(outType)) {
+  auto outVecType = dyn_cast<VectorType>(outType);
+  if (outVecType) {
     if (outVecType.isScalable())
       return failure();
     outType = outVecType.getElementType();
@@ -237,11 +241,9 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();
 
-  return isSupportedF8(outType, chipset);
-}
+  if (!isSupportedF8(outType, chipset))
+    return failure();
 
-void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
-                                           PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
@@ -255,13 +257,14 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
         /*existing=*/nullptr);
     Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }
-  VectorType outType = cast<VectorType>(op.getOut().getType());
-  int64_t numElements = outType.getNumElements();
+  int64_t numElements = outVecType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  if (outType.getShape().empty()) {
+  if (outVecType.getShape().empty()) {
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
@@ -269,11 +272,12 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
     Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
                                                      ArrayRef<int64_t>{});
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }
 
   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
-                                      outType.getElementType());
+                                      outVecType.getElementType());
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
 
   if (inVectorTy.getRank() > 1) {
@@ -303,26 +307,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                                    result, i, 1);
   }
 
-  if (inVectorTy.getRank() != outType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  if (inVectorTy.getRank() != outVecType.getRank()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
   }
 
   rewriter.replaceOp(op, result);
+  return success();
 }
 
-LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
+LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
+    arith::TruncFOp op, PatternRewriter &rewriter) const {
   Type outType = op.getOut().getType();
   Type inputType = getElementTypeOrSelf(op.getIn());
-  if (auto outVecType = dyn_cast<VectorType>(outType)) {
+  auto outVecType = dyn_cast<VectorType>(outType);
+  if (outVecType) {
     if (outVecType.isScalable())
       return failure();
     outType = outVecType.getElementType();
   }
-  return success(outType.isF16() && inputType.isF32());
-}
+  if (!(outType.isF16() && inputType.isF32()))
+    return failure();
 
-void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
-                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
@@ -335,13 +340,13 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
     Value asF16s =
         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
     Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }
-  VectorType outType = cast<VectorType>(op.getOut().getType());
-  int64_t numElements = outType.getNumElements();
+  int64_t numElements = outVecType.getNumElements();
   Value zero = rewriter.createOrFold<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);
 
   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
@@ -371,11 +376,12 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
                                                    result, i, 1);
   }
 
-  if (inVectorTy.getRank() != outType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  if (inVectorTy.getRank() != outVecType.getRank()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
   }
 
   rewriter.replaceOp(op, result);
+  return success();
 }
 
 void mlir::arith::populateArithToAMDGPUConversionPatterns(

View File

@@ -657,11 +657,12 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
   }
 };
 
-struct MemRefCastOpLowering
-    : public ConvertOpToLLVMPattern<memref::CastOp>::SplitMatchAndRewrite {
-  using SplitMatchAndRewrite::SplitMatchAndRewrite;
+struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
-  LogicalResult match(memref::CastOp memRefCastOp) const override {
+  LogicalResult
+  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     Type srcType = memRefCastOp.getOperand().getType();
     Type dstType = memRefCastOp.getType();
@@ -671,30 +672,22 @@ struct MemRefCastOpLowering
     // perform a sanity check that the underlying structs are the same. Once op
     // semantics are relaxed we can revisit.
     if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
-      return success(typeConverter->convertType(srcType) ==
-                     typeConverter->convertType(dstType));
-
-    // At least one of the operands is unranked type
-    assert(isa<UnrankedMemRefType>(srcType) ||
-           isa<UnrankedMemRefType>(dstType));
+      if (typeConverter->convertType(srcType) !=
+          typeConverter->convertType(dstType))
+        return failure();
 
     // Unranked to unranked cast is disallowed
-    return !(isa<UnrankedMemRefType>(srcType) &&
-             isa<UnrankedMemRefType>(dstType))
-               ? success()
-               : failure();
-  }
+    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
+      return failure();
 
-  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
-               ConversionPatternRewriter &rewriter) const override {
-    auto srcType = memRefCastOp.getOperand().getType();
-    auto dstType = memRefCastOp.getType();
     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
     auto loc = memRefCastOp.getLoc();
 
     // For ranked/ranked case, just keep the original descriptor.
-    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
-      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
+    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
+      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
+      return success();
+    }
 
     if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
       // Casting ranked to unranked memref type
@@ -733,6 +726,8 @@ struct MemRefCastOpLowering
     } else {
       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
     }
+
+    return success();
   }
 };

View File

@@ -40,36 +40,33 @@ struct EmulateUnsupportedFloatsPass
   void runOnOperation() override;
 };
 
-struct EmulateFloatPattern final : ConversionPattern::SplitMatchAndRewrite {
+struct EmulateFloatPattern final : ConversionPattern {
   EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
-      : ConversionPattern::SplitMatchAndRewrite(
+      : ConversionPattern::ConversionPattern(
             converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}
 
-  LogicalResult match(Operation *op) const override;
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 } // end namespace
 
-LogicalResult EmulateFloatPattern::match(Operation *op) const {
+LogicalResult EmulateFloatPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
   if (getTypeConverter()->isLegal(op))
     return failure();
   // The rewrite doesn't handle cloning regions.
   if (op->getNumRegions() != 0)
     return failure();
-  return success();
-}
 
-void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
-                                  ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   const TypeConverter *converter = getTypeConverter();
   SmallVector<Type> resultTypes;
   if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
     // Note to anyone looking for this error message: this is a "can't happen".
     // If you're seeing it, there's a bug.
-    op->emitOpError("type conversion failed in float emulation");
-    return;
+    return op->emitOpError("type conversion failed in float emulation");
   }
   Operation *expandedOp =
       rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,
@@ -84,6 +81,7 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
     }
   }
   rewriter.replaceOp(op, newResults);
+  return success();
 }
 
 void mlir::arith::populateEmulateUnsupportedFloatsConversions(

View File

@@ -115,14 +115,14 @@ protected:
 /// and replace their uses with that constant. Return success() if all results
 /// where thus replaced and the operation is erased. Also replace any block
 /// arguments with their constant values.
-struct MaterializeKnownConstantValues
-    : public RewritePattern::SplitMatchAndRewrite {
+struct MaterializeKnownConstantValues : public RewritePattern {
   MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
-      : RewritePattern::SplitMatchAndRewrite(Pattern::MatchAnyOpTypeTag(),
-                                             /*benefit=*/1, context),
+      : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
+                                       /*benefit=*/1, context),
         solver(s) {}
 
-  LogicalResult match(Operation *op) const override {
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
     if (matchPattern(op, m_Constant()))
       return failure();
@@ -131,7 +131,8 @@ struct MaterializeKnownConstantValues
    };
 
    bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
    if (op->getNumRegions() == 0)
-     return success(hasConstantResults);
+      if (!hasConstantResults)
+        return failure();
    bool hasConstantRegionArgs = false;
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks()) {
@@ -139,10 +140,9 @@ struct MaterializeKnownConstantValues
            llvm::any_of(block.getArguments(), needsReplacing);
      }
    }
-    return success(hasConstantResults || hasConstantRegionArgs);
-  }
+    if (!hasConstantResults && !hasConstantRegionArgs)
+      return failure();
 
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
    bool replacedAll = (op->getNumResults() != 0);
    for (Value v : op->getResults())
      replacedAll &=
@@ -150,7 +150,7 @@ struct MaterializeKnownConstantValues
               v.use_empty());
    if (replacedAll && isOpTriviallyDead(op)) {
      rewriter.eraseOp(op);
-      return;
+      return success();
    }
 
    PatternRewriter::InsertionGuard guard(rewriter);
@@ -162,6 +162,8 @@ struct MaterializeKnownConstantValues
        }
      }
    }
+    return success();
  }
 
private:

View File

@@ -772,17 +772,16 @@ private:
 /// `vector.extract` and `vector.extract_element`.
 template <class VectorExtractOp>
 class RewriteScalarExtractOfTransferReadBase
-    : public OpRewritePattern<VectorExtractOp>::SplitMatchAndRewrite {
-  using Base = typename OpRewritePattern<VectorExtractOp>::SplitMatchAndRewrite;
+    : public OpRewritePattern<VectorExtractOp> {
+  using Base = OpRewritePattern<VectorExtractOp>;
 
 public:
   RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                          PatternBenefit benefit,
                                          bool allowMultipleUses)
-      : Base::SplitMatchAndRewrite(context, benefit),
-        allowMultipleUses(allowMultipleUses) {}
+      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}
 
-  LogicalResult match(VectorExtractOp extractOp) const override {
+  LogicalResult match(VectorExtractOp extractOp) const {
     auto xferOp =
         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)
@@ -828,8 +827,11 @@ class RewriteScalarExtractElementOfTransferRead
   using RewriteScalarExtractOfTransferReadBase::
       RewriteScalarExtractOfTransferReadBase;
 
-  void rewrite(vector::ExtractElementOp extractOp,
-               PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(match(extractOp)))
+      return failure();
+
     // Construct scalar load.
     auto loc = extractOp.getLoc();
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
@@ -856,6 +858,9 @@ class RewriteScalarExtractElementOfTransferRead
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
     }
+
+    return success();
   }
 };
@@ -872,8 +876,11 @@ class RewriteScalarExtractOfTransferRead
   using RewriteScalarExtractOfTransferReadBase::
       RewriteScalarExtractOfTransferReadBase;
 
-  void rewrite(vector::ExtractOp extractOp,
-               PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(match(extractOp)))
+      return failure();
+
     // Construct scalar load.
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),
@@ -899,6 +906,8 @@ class RewriteScalarExtractOfTransferRead
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
     }
+
+    return success();
   }
 };