[mlir][IR] Deprecate match and rewrite functions (#130031)
Deprecate the `match` and `rewrite` functions. They mainly exist for historic reasons. This PR also updates all remaining uses of these functions in the MLIR codebase. This is addressing a [comment](https://github.com/llvm/llvm-project/pull/129861#pullrequestreview-2662696084) on an earlier PR.

Note for LLVM integration: `SplitMatchAndRewrite` will be deleted soon; update your patterns to use `matchAndRewrite` instead of separate `match` / `rewrite`.

---------

Co-authored-by: Jakub Kuderski <jakub@nod-labs.com>
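The migration the note above asks for can be sketched as follows. This is a minimal, self-contained illustration of folding a split `match`/`rewrite` pair into a single `matchAndRewrite`; the `LogicalResult`, `Operation`, and `PatternRewriter` types below are hypothetical stand-ins for the MLIR API (so the snippet runs on its own), not the real classes, and `ConvertLeakyRelu` mirrors the doc example rather than any shipped pattern:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for the MLIR types involved, just enough to
// show the control flow of a combined matchAndRewrite.
struct LogicalResult {
  bool ok;
};
LogicalResult success() { return {true}; }
LogicalResult failure() { return {false}; }
bool succeeded(LogicalResult r) { return r.ok; }

struct Operation {
  std::string name;
};

struct PatternRewriter {
  // Records each replacement so the effect is observable.
  std::vector<std::string> replacements;
  void replaceOpWithNewOp(Operation *op, const std::string &newName) {
    replacements.push_back(op->name + " -> " + newName);
  }
};

// New style: one hook. The rule from the pattern docs still applies:
// perform no IR mutation before the match is deemed successful, and
// return success() if and only if the IR was modified.
struct ConvertLeakyRelu {
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const {
    // Match phase: read-only checks, no mutation yet.
    if (op->name != "tf.LeakyRelu")
      return failure();
    // Rewrite phase: all mutation goes through the rewriter.
    rewriter.replaceOpWithNewOp(op, "tfl.leaky_relu");
    return success();
  }
};
```

In a former split pattern, the body of `match` becomes the early-return checks at the top and the body of `rewrite` follows them, with each `return;` in the old `rewrite` becoming `return success();`.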
This commit is contained in:
parent
6b094020c2
commit
a21cfca320
@@ -179,13 +179,13 @@ updated/remapped operands of an operation, such as when the types of results
 defined by an operation have changed. The general Rewrite Patterns can no longer
 be used in these situations, as the types of the operands of the operation being
 matched will not correspond with those expected by the user. This pattern
-provides, as an additional argument to the `matchAndRewrite` and `rewrite`
-methods, the list of operands that the operation should use after conversion. If
-an operand was the result of a non-converted operation, for example if it was
-already legal, the original operand is used. This means that the operands
-provided always have a 1-1 non-null correspondence with the operands on the
-operation. The original operands of the operation are still intact and may be
-inspected as normal. These patterns also utilize a special `PatternRewriter`,
+provides, as an additional argument to the `matchAndRewrite` method, the list
+of operands that the operation should use after conversion. If an operand was
+the result of a non-converted operation, for example if it was already legal,
+the original operand is used. This means that the operands provided always have
+a 1-1 non-null correspondence with the operands on the operation. The original
+operands of the operation are still intact and may be inspected as normal.
+These patterns also utilize a special `PatternRewriter`,
 `ConversionPatternRewriter`, that provides special hooks for use with the
 conversion infrastructure.

@@ -48,13 +48,9 @@ operation type, a special tag must be provided to make the intent explicit:
 ### `matchAndRewrite` implementation

 This is the chunk of code that matches a given root `Operation` and performs a
-rewrite of the IR. A `RewritePattern` can specify this implementation either via
-the `matchAndRewrite` method or via separate `match` and `rewrite` methods when
-deriving from `RewritePattern::SplitMatchAndRewrite`. When using the combined
-`matchAndRewrite` method, no IR mutation should take place before the match is
-deemed successful. The combined `matchAndRewrite` is useful when non-trivially
-recomputable information is required by the matching and rewriting phase. See
-below for examples:
+rewrite of the IR. A `RewritePattern` can specify this implementation via the
+`matchAndRewrite` method. No IR mutation should take place before the match is
+deemed successful. See below for examples:

 ```c++
 class MyPattern : public RewritePattern {

@@ -67,21 +63,6 @@ public:
   MyPattern(PatternBenefit benefit)
       : RewritePattern(benefit, MatchAnyOpTypeTag()) {}

-  /// In this section, the `match` and `rewrite` implementation is specified
-  /// using the separate hooks.
-  LogicalResult match(Operation *op) const override {
-    // The `match` method returns `success()` if the pattern is a match, failure
-    // otherwise.
-    // ...
-  }
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    // The `rewrite` method performs mutations on the IR rooted at `op` using
-    // the provided rewriter. All mutations must go through the provided
-    // rewriter.
-  }
-
   /// In this section, the `match` and `rewrite` implementation is specified
   /// using a single hook.
   LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override {
     // The `matchAndRewrite` method performs both the matching and the mutation.
     // Note that the match must reach a successful point before IR mutation may

@@ -92,12 +73,6 @@ public:

 #### Restrictions

-Within the `match` section of a pattern, the following constraints apply:
-
-*   No mutation of the IR is allowed.
-
 Within the `rewrite` section of a pattern, the following constraints apply:

 *   All IR mutations, including creation, *must* be performed by the given
     `PatternRewriter`. This class provides hooks for performing all of the
     possible mutations that may take place within a pattern. For example, this

@@ -107,8 +82,6 @@ Within the `rewrite` section of a pattern, the following constraints apply:
 *   The root operation is required to either be: updated in-place, replaced, or
     erased.
 *   `matchAndRewrite` must return "success" if and only if the IR was modified.
-    `match` must return "success" if and only if the IR is going to be modified
-    during `rewrite`.


 ### Application Recursion

@@ -216,25 +216,6 @@ In case ODS patterns and `matchAndRewrite`-style functions are not sufficient
 you can also specify rewrites as a general set of `RewritePattern`s:

 ```c++
-/// Multi-step rewrite using "match" and "rewrite". This allows for separating
-/// the concerns of matching and rewriting.
-struct ConvertTFLeakyRelu : public RewritePattern {
-  ConvertTFLeakyRelu(MLIRContext *context)
-      : RewritePattern("tf.LeakyRelu", 1, context) {}
-
-  LogicalResult match(Operation *op) const override {
-    return success();
-  }
-
-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TFL::LeakyReluOp>(
-        op, op->getResult(0).getType(), op->getOperand(0),
-        /*alpha=*/op->getAttrOfType<FloatAttr>("alpha"));
-  }
-};
-
 /// Single-step rewrite with "matchAndRewrite". This allows for performing the
 /// rewrite immediately upon a successful match.
 struct ConvertTFLeakyRelu : public RewritePattern {
   ConvertTFLeakyRelu(MLIRContext *context)
       : RewritePattern("tf.LeakyRelu", 1, context) {}

@@ -40,6 +40,8 @@ LogicalResult oneToOneRewrite(
 /// during the entire pattern lifetime.
 class ConvertToLLVMPattern : public ConversionPattern {
 public:
+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;

@@ -149,6 +151,9 @@ public:
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;

+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
       ConvertOpToLLVMPattern<SourceOp>>;

@@ -237,6 +237,9 @@ private:
 namespace detail {
 /// Helper class that derives from a RewritePattern class and provides separate
 /// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
+///
+/// This class is deprecated. Use `matchAndRewrite` instead of separate `match`
+/// and `rewrite`.
 template <typename PatternT>
 class SplitMatchAndRewriteImpl : public PatternT {
   using PatternT::PatternT;

@@ -268,6 +271,9 @@ class SplitMatchAndRewriteImpl : public PatternT {
 class RewritePattern : public Pattern {
 public:
   using OperationT = Operation *;

+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;

   virtual ~RewritePattern() = default;

@@ -350,6 +356,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
 template <typename SourceOp>
 struct OpRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {

+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;

@@ -368,6 +377,9 @@ struct OpRewritePattern
 template <typename SourceOp>
 struct OpInterfaceRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {

+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;

@@ -598,6 +598,9 @@ public:
   using OperationT = Operation *;
   using OpAdaptor = ArrayRef<Value>;
   using OneToNOpAdaptor = ArrayRef<ValueRange>;

+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::ConversionSplitMatchAndRewriteImpl<ConversionPattern>;

@@ -669,6 +672,9 @@ public:
   using OpAdaptor = typename SourceOp::Adaptor;
   using OneToNOpAdaptor =
       typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;

+  /// `SplitMatchAndRewrite` is deprecated. Use `matchAndRewrite` instead of
+  /// separate `match` and `rewrite`.
   using SplitMatchAndRewrite =
       detail::ConversionSplitMatchAndRewriteImpl<OpConversionPattern<SourceOp>>;

@@ -41,48 +41,46 @@ struct ArithToAMDGPUConversionPass final
   void runOnOperation() override;
 };

-struct ExtFOnFloat8RewritePattern final
-    : OpRewritePattern<arith::ExtFOp>::SplitMatchAndRewrite {
-  using SplitMatchAndRewrite::SplitMatchAndRewrite;
+struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;

   Chipset chipset;
   ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset)
-      : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx), chipset(chipset) {}
+      : OpRewritePattern::OpRewritePattern(ctx), chipset(chipset) {}

-  LogicalResult match(arith::ExtFOp op) const override;
-  void rewrite(arith::ExtFOp op, PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const override;
 };

-struct TruncFToFloat8RewritePattern final
-    : OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
+struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
   bool saturateFP8 = false;
   TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                                Chipset chipset)
-      : SplitMatchAndRewrite::SplitMatchAndRewrite(ctx),
-        saturateFP8(saturateFP8), chipset(chipset) {}
+      : OpRewritePattern::OpRewritePattern(ctx), saturateFP8(saturateFP8),
+        chipset(chipset) {}
   Chipset chipset;

-  LogicalResult match(arith::TruncFOp op) const override;
-  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const override;
 };

 struct TruncfToFloat16RewritePattern final
-    : public OpRewritePattern<arith::TruncFOp>::SplitMatchAndRewrite {
+    : public OpRewritePattern<arith::TruncFOp> {

-  using SplitMatchAndRewrite::SplitMatchAndRewrite;
+  using OpRewritePattern::OpRewritePattern;

-  LogicalResult match(arith::TruncFOp op) const override;
-  void rewrite(arith::TruncFOp op, PatternRewriter &rewriter) const override;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const override;
 };

 } // end namespace

-static LogicalResult isSupportedF8(Type elementType, Chipset chipset) {
+static bool isSupportedF8(Type elementType, Chipset chipset) {
   if (chipset == kGfx942)
-    return success(isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType));
+    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
   if (hasOcpFp8(chipset))
-    return success(isa<Float8E4M3FNType, Float8E5M2Type>(elementType));
-  return failure();
+    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
+  return false;
 }

 static Value castF32To(Type elementType, Value f32, Location loc,

@@ -96,35 +94,36 @@ static Value castF32To(Type elementType, Value f32, Location loc,
   llvm_unreachable("The only 32-bit float type is f32");
 }

-LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
+LogicalResult
+ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
+                                            PatternRewriter &rewriter) const {
   Type inType = op.getIn().getType();
-  if (auto inVecType = dyn_cast<VectorType>(inType)) {
+  auto inVecType = dyn_cast<VectorType>(inType);
+  if (inVecType) {
     if (inVecType.isScalable())
       return failure();
     inType = inVecType.getElementType();
   }
-  return isSupportedF8(inType, chipset);
-}
+  if (!isSupportedF8(inType, chipset))
+    return failure();

-void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
-                                         PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  auto inType = dyn_cast<VectorType>(in.getType());
-  if (!inType) {
+  if (!inVecType) {
     Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
         loc, rewriter.getF32Type(), in, 0);
     Value result = castF32To(outElemType, asFloat, loc, rewriter);
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }
-  int64_t numElements = inType.getNumElements();
+  int64_t numElements = inVecType.getNumElements();

   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
   VectorType outType = cast<VectorType>(op.getOut().getType());

-  if (inType.getShape().empty()) {
+  if (inVecType.getShape().empty()) {
     Value zerodSplat =
         rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
     Value scalarIn =

@@ -133,17 +132,18 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
     Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
                                                      ArrayRef<int64_t>{});
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }

   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                       outType.getElementType());
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

-  if (inType.getRank() > 1) {
-    inType = VectorType::get(SmallVector<int64_t>{numElements},
-                             inType.getElementType());
-    in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
+  if (inVecType.getRank() > 1) {
+    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
+                                inVecType.getElementType());
+    in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
   }

   for (int64_t i = 0; i < numElements; i += 4) {

@@ -158,11 +158,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
     }
   }

-  if (inType.getRank() != outType.getRank()) {
+  if (inVecType.getRank() != outType.getRank()) {
     result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
   }

   rewriter.replaceOp(op, result);
+  return success();
 }

 static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {

@@ -222,12 +223,15 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
   return res;
 }

-LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
+LogicalResult
+TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
+                                              PatternRewriter &rewriter) const {
   // Only supporting default rounding mode as of now.
   if (op.getRoundingmodeAttr())
     return failure();
   Type outType = op.getOut().getType();
-  if (auto outVecType = dyn_cast<VectorType>(outType)) {
+  auto outVecType = dyn_cast<VectorType>(outType);
+  if (outVecType) {
     if (outVecType.isScalable())
       return failure();
     outType = outVecType.getElementType();

@@ -237,11 +241,9 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
     // Conversion between 8-bit floats is not supported with truncation enabled.
     return failure();

-  return isSupportedF8(outType, chipset);
-}
+  if (!isSupportedF8(outType, chipset))
+    return failure();

-void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
-                                           PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());

@@ -255,13 +257,14 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
         /*existing=*/nullptr);
     Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }
-  VectorType outType = cast<VectorType>(op.getOut().getType());
-  int64_t numElements = outType.getNumElements();

+  int64_t numElements = outVecType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  if (outType.getShape().empty()) {
+  if (outVecType.getShape().empty()) {
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case

@@ -269,11 +272,12 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
     Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
                                                      ArrayRef<int64_t>{});
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }

   VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
-                                      outType.getElementType());
+                                      outVecType.getElementType());
   Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

   if (inVectorTy.getRank() > 1) {

@@ -303,26 +307,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
                                                 result, i, 1);
   }

-  if (inVectorTy.getRank() != outType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  if (inVectorTy.getRank() != outVecType.getRank()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
   }

   rewriter.replaceOp(op, result);
+  return success();
 }

-LogicalResult TruncfToFloat16RewritePattern::match(arith::TruncFOp op) const {
+LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
+    arith::TruncFOp op, PatternRewriter &rewriter) const {
   Type outType = op.getOut().getType();
   Type inputType = getElementTypeOrSelf(op.getIn());
-  if (auto outVecType = dyn_cast<VectorType>(outType)) {
+  auto outVecType = dyn_cast<VectorType>(outType);
+  if (outVecType) {
     if (outVecType.isScalable())
       return failure();
     outType = outVecType.getElementType();
   }
-  return success(outType.isF16() && inputType.isF32());
-}
+  if (!(outType.isF16() && inputType.isF32()))
+    return failure();

-void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
-                                            PatternRewriter &rewriter) const {
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());

@@ -335,13 +340,13 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
     Value asF16s =
         rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
     Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
-    return rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, result);
+    return success();
   }
-  VectorType outType = cast<VectorType>(op.getOut().getType());
-  int64_t numElements = outType.getNumElements();
+  int64_t numElements = outVecType.getNumElements();
   Value zero = rewriter.createOrFold<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);

   if (inVectorTy.getRank() > 1) {
     inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},

@@ -371,11 +376,12 @@ void TruncfToFloat16RewritePattern::rewrite(arith::TruncFOp op,
                                                 result, i, 1);
   }

-  if (inVectorTy.getRank() != outType.getRank()) {
-    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  if (inVectorTy.getRank() != outVecType.getRank()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
   }

   rewriter.replaceOp(op, result);
+  return success();
 }

 void mlir::arith::populateArithToAMDGPUConversionPatterns(
@@ -657,11 +657,12 @@ struct RankOpLowering : public ConvertOpToLLVMPattern<memref::RankOp> {
   }
 };

-struct MemRefCastOpLowering
-    : public ConvertOpToLLVMPattern<memref::CastOp>::SplitMatchAndRewrite {
-  using SplitMatchAndRewrite::SplitMatchAndRewrite;
+struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

-  LogicalResult match(memref::CastOp memRefCastOp) const override {
+  LogicalResult
+  matchAndRewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     Type srcType = memRefCastOp.getOperand().getType();
     Type dstType = memRefCastOp.getType();

@@ -671,30 +672,22 @@ struct MemRefCastOpLowering
     // perform a sanity check that the underlying structs are the same. Once op
     // semantics are relaxed we can revisit.
     if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
-      return success(typeConverter->convertType(srcType) ==
-                     typeConverter->convertType(dstType));
-
-    // At least one of the operands is unranked type
-    assert(isa<UnrankedMemRefType>(srcType) ||
-           isa<UnrankedMemRefType>(dstType));
+      if (typeConverter->convertType(srcType) !=
+          typeConverter->convertType(dstType))
+        return failure();

     // Unranked to unranked cast is disallowed
-    return !(isa<UnrankedMemRefType>(srcType) &&
-             isa<UnrankedMemRefType>(dstType))
-               ? success()
-               : failure();
-  }
+    if (isa<UnrankedMemRefType>(srcType) && isa<UnrankedMemRefType>(dstType))
+      return failure();

-  void rewrite(memref::CastOp memRefCastOp, OpAdaptor adaptor,
-               ConversionPatternRewriter &rewriter) const override {
-    auto srcType = memRefCastOp.getOperand().getType();
-    auto dstType = memRefCastOp.getType();
     auto targetStructType = typeConverter->convertType(memRefCastOp.getType());
     auto loc = memRefCastOp.getLoc();

     // For ranked/ranked case, just keep the original descriptor.
-    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType))
-      return rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
+    if (isa<MemRefType>(srcType) && isa<MemRefType>(dstType)) {
+      rewriter.replaceOp(memRefCastOp, {adaptor.getSource()});
+      return success();
+    }

     if (isa<MemRefType>(srcType) && isa<UnrankedMemRefType>(dstType)) {
       // Casting ranked to unranked memref type

@@ -733,6 +726,8 @@ struct MemRefCastOpLowering
     } else {
       llvm_unreachable("Unsupported unranked memref to unranked memref cast");
     }
+
+    return success();
   }
 };

@@ -40,36 +40,33 @@ struct EmulateUnsupportedFloatsPass
   void runOnOperation() override;
 };

-struct EmulateFloatPattern final : ConversionPattern::SplitMatchAndRewrite {
+struct EmulateFloatPattern final : ConversionPattern {
   EmulateFloatPattern(const TypeConverter &converter, MLIRContext *ctx)
-      : ConversionPattern::SplitMatchAndRewrite(
+      : ConversionPattern::ConversionPattern(
             converter, Pattern::MatchAnyOpTypeTag(), 1, ctx) {}

-  LogicalResult match(Operation *op) const override;
-  void rewrite(Operation *op, ArrayRef<Value> operands,
-               ConversionPatternRewriter &rewriter) const override;
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
 };
 } // end namespace

-LogicalResult EmulateFloatPattern::match(Operation *op) const {
+LogicalResult EmulateFloatPattern::matchAndRewrite(
+    Operation *op, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
   if (getTypeConverter()->isLegal(op))
     return failure();
   // The rewrite doesn't handle cloning regions.
   if (op->getNumRegions() != 0)
     return failure();
-  return success();
-}

-void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
-                                  ConversionPatternRewriter &rewriter) const {
   Location loc = op->getLoc();
   const TypeConverter *converter = getTypeConverter();
   SmallVector<Type> resultTypes;
   if (failed(converter->convertTypes(op->getResultTypes(), resultTypes))) {
     // Note to anyone looking for this error message: this is a "can't happen".
     // If you're seeing it, there's a bug.
-    op->emitOpError("type conversion failed in float emulation");
-    return;
+    return op->emitOpError("type conversion failed in float emulation");
   }
   Operation *expandedOp =
       rewriter.create(loc, op->getName().getIdentifier(), operands, resultTypes,

@@ -84,6 +81,7 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
     }
   }
   rewriter.replaceOp(op, newResults);
+  return success();
 }

 void mlir::arith::populateEmulateUnsupportedFloatsConversions(
@@ -115,14 +115,14 @@ protected:
 /// and replace their uses with that constant. Return success() if all results
 /// where thus replaced and the operation is erased. Also replace any block
 /// arguments with their constant values.
-struct MaterializeKnownConstantValues
-    : public RewritePattern::SplitMatchAndRewrite {
+struct MaterializeKnownConstantValues : public RewritePattern {
   MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
-      : RewritePattern::SplitMatchAndRewrite(Pattern::MatchAnyOpTypeTag(),
-                                             /*benefit=*/1, context),
+      : RewritePattern::RewritePattern(Pattern::MatchAnyOpTypeTag(),
+                                       /*benefit=*/1, context),
         solver(s) {}

-  LogicalResult match(Operation *op) const override {
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
     if (matchPattern(op, m_Constant()))
       return failure();

@@ -131,7 +131,8 @@ struct MaterializeKnownConstantValues
     };
     bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
     if (op->getNumRegions() == 0)
-      return success(hasConstantResults);
+      if (!hasConstantResults)
+        return failure();
     bool hasConstantRegionArgs = false;
     for (Region &region : op->getRegions()) {
       for (Block &block : region.getBlocks()) {

@@ -139,10 +140,9 @@ struct MaterializeKnownConstantValues
             llvm::any_of(block.getArguments(), needsReplacing);
       }
     }
-    return success(hasConstantResults || hasConstantRegionArgs);
-  }
+    if (!hasConstantResults && !hasConstantRegionArgs)
+      return failure();

-  void rewrite(Operation *op, PatternRewriter &rewriter) const override {
     bool replacedAll = (op->getNumResults() != 0);
     for (Value v : op->getResults())
       replacedAll &=

@@ -150,7 +150,7 @@ struct MaterializeKnownConstantValues
                    v.use_empty());
     if (replacedAll && isOpTriviallyDead(op)) {
       rewriter.eraseOp(op);
-      return;
+      return success();
     }

     PatternRewriter::InsertionGuard guard(rewriter);

@@ -162,6 +162,8 @@ struct MaterializeKnownConstantValues
       }
     }
   }
+
+    return success();
   }

 private:
@@ -772,17 +772,16 @@ private:
 /// `vector.extract` and `vector.extract_element`.
 template <class VectorExtractOp>
 class RewriteScalarExtractOfTransferReadBase
-    : public OpRewritePattern<VectorExtractOp>::SplitMatchAndRewrite {
-  using Base = typename OpRewritePattern<VectorExtractOp>::SplitMatchAndRewrite;
+    : public OpRewritePattern<VectorExtractOp> {
+  using Base = OpRewritePattern<VectorExtractOp>;

 public:
   RewriteScalarExtractOfTransferReadBase(MLIRContext *context,
                                          PatternBenefit benefit,
                                          bool allowMultipleUses)
-      : Base::SplitMatchAndRewrite(context, benefit),
-        allowMultipleUses(allowMultipleUses) {}
+      : Base(context, benefit), allowMultipleUses(allowMultipleUses) {}

-  LogicalResult match(VectorExtractOp extractOp) const override {
+  LogicalResult match(VectorExtractOp extractOp) const {
     auto xferOp =
         extractOp.getVector().template getDefiningOp<vector::TransferReadOp>();
     if (!xferOp)

@@ -828,8 +827,11 @@ class RewriteScalarExtractElementOfTransferRead
   using RewriteScalarExtractOfTransferReadBase::
       RewriteScalarExtractOfTransferReadBase;

-  void rewrite(vector::ExtractElementOp extractOp,
-               PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::ExtractElementOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(match(extractOp)))
+      return failure();
+
     // Construct scalar load.
     auto loc = extractOp.getLoc();
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();

@@ -856,6 +858,8 @@ class RewriteScalarExtractElementOfTransferRead
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
           extractOp, xferOp.getSource(), newIndices);
     }
+
+    return success();
   }
 };

@@ -872,8 +876,11 @@ class RewriteScalarExtractOfTransferRead
   using RewriteScalarExtractOfTransferReadBase::
       RewriteScalarExtractOfTransferReadBase;

-  void rewrite(vector::ExtractOp extractOp,
-               PatternRewriter &rewriter) const override {
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(match(extractOp)))
+      return failure();
+
     // Construct scalar load.
     auto xferOp = extractOp.getVector().getDefiningOp<vector::TransferReadOp>();
     SmallVector<Value> newIndices(xferOp.getIndices().begin(),

@@ -899,6 +906,8 @@ class RewriteScalarExtractOfTransferRead
       rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
          extractOp, xferOp.getSource(), newIndices);
     }
+
+    return success();
   }
 };
