
- Replaces dyn_cast<cir::ConstantOp>(v.getDefiningOp()) and similar with v.getDefiningOp<cir::ConstantOp>() - Adds `getValueAttr` method to ConstantOp
325 lines
11 KiB
C++
325 lines
11 KiB
C++
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "PassDetail.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/IR/Block.h"
|
|
#include "mlir/IR/Operation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/Region.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "clang/CIR/Dialect/IR/CIRDialect.h"
|
|
#include "clang/CIR/Dialect/Passes.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
using namespace mlir;
|
|
using namespace cir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Rewrite patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Simplify suitable ternary operations into select operations.
|
|
///
|
|
/// For now we only simplify those ternary operations whose true and false
|
|
/// branches directly yield a value or a constant. That is, both of the true and
|
|
/// the false branch must either contain a cir.yield operation as the only
|
|
/// operation in the branch, or contain a cir.const operation followed by a
|
|
/// cir.yield operation that yields the constant value.
|
|
///
|
|
/// For example, we will simplify the following ternary operation:
|
|
///
|
|
/// %0 = ...
|
|
/// %1 = cir.ternary (%condition, true {
|
|
/// %2 = cir.const ...
|
|
/// cir.yield %2
|
|
/// } false {
|
|
/// cir.yield %0
|
|
///
|
|
/// into the following sequence of operations:
|
|
///
|
|
/// %1 = cir.const ...
|
|
/// %0 = cir.select if %condition then %1 else %2
|
|
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
|
|
using OpRewritePattern<TernaryOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(TernaryOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
if (op->getNumResults() != 1)
|
|
return mlir::failure();
|
|
|
|
if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
|
|
!isSimpleTernaryBranch(op.getFalseRegion()))
|
|
return mlir::failure();
|
|
|
|
cir::YieldOp trueBranchYieldOp =
|
|
mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
|
|
cir::YieldOp falseBranchYieldOp =
|
|
mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
|
|
mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
|
|
mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
|
|
|
|
rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
|
|
rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
|
|
rewriter.eraseOp(trueBranchYieldOp);
|
|
rewriter.eraseOp(falseBranchYieldOp);
|
|
rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
|
|
falseValue);
|
|
|
|
return mlir::success();
|
|
}
|
|
|
|
private:
|
|
bool isSimpleTernaryBranch(mlir::Region ®ion) const {
|
|
if (!region.hasOneBlock())
|
|
return false;
|
|
|
|
mlir::Block &onlyBlock = region.front();
|
|
mlir::Block::OpListType &ops = onlyBlock.getOperations();
|
|
|
|
// The region/block could only contain at most 2 operations.
|
|
if (ops.size() > 2)
|
|
return false;
|
|
|
|
if (ops.size() == 1) {
|
|
// The region/block only contain a cir.yield operation.
|
|
return true;
|
|
}
|
|
|
|
// Check whether the region/block contains a cir.const followed by a
|
|
// cir.yield that yields the value.
|
|
auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
|
|
auto yieldValueDefOp =
|
|
yieldOp.getArgs()[0].getDefiningOp<cir::ConstantOp>();
|
|
return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
|
|
}
|
|
};
|
|
|
|
/// Simplify select operations with boolean constants into simpler forms.
|
|
///
|
|
/// This pattern simplifies select operations where both true and false values
|
|
/// are boolean constants. Two specific cases are handled:
|
|
///
|
|
/// 1. When selecting between true and false based on a condition,
|
|
/// the operation simplifies to just the condition itself:
|
|
///
|
|
/// %0 = cir.select if %condition then true else false
|
|
/// ->
|
|
/// (replaced with %condition directly)
|
|
///
|
|
/// 2. When selecting between false and true based on a condition,
|
|
/// the operation simplifies to the logical negation of the condition:
|
|
///
|
|
/// %0 = cir.select if %condition then false else true
|
|
/// ->
|
|
/// %0 = cir.unary not %condition
|
|
struct SimplifySelect : public OpRewritePattern<SelectOp> {
|
|
using OpRewritePattern<SelectOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(SelectOp op,
|
|
PatternRewriter &rewriter) const final {
|
|
auto trueValueOp = op.getTrueValue().getDefiningOp<cir::ConstantOp>();
|
|
auto falseValueOp = op.getFalseValue().getDefiningOp<cir::ConstantOp>();
|
|
if (!trueValueOp || !falseValueOp)
|
|
return mlir::failure();
|
|
|
|
auto trueValue = trueValueOp.getValueAttr<cir::BoolAttr>();
|
|
auto falseValue = falseValueOp.getValueAttr<cir::BoolAttr>();
|
|
if (!trueValue || !falseValue)
|
|
return mlir::failure();
|
|
|
|
// cir.select if %0 then #true else #false -> %0
|
|
if (trueValue.getValue() && !falseValue.getValue()) {
|
|
rewriter.replaceAllUsesWith(op, op.getCondition());
|
|
rewriter.eraseOp(op);
|
|
return mlir::success();
|
|
}
|
|
|
|
// cir.select if %0 then #false else #true -> cir.unary not %0
|
|
if (!trueValue.getValue() && falseValue.getValue()) {
|
|
rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
|
|
op.getCondition());
|
|
return mlir::success();
|
|
}
|
|
|
|
return mlir::failure();
|
|
}
|
|
};
|
|
|
|
/// Simplify `cir.switch` operations by folding cascading cases
|
|
/// into a single `cir.case` with the `anyof` kind.
|
|
///
|
|
/// This pattern identifies cascading cases within a `cir.switch` operation.
|
|
/// Cascading cases are defined as consecutive `cir.case` operations of kind
|
|
/// `equal`, each containing a single `cir.yield` operation in their body.
|
|
///
|
|
/// The pattern merges these cascading cases into a single `cir.case` operation
|
|
/// with kind `anyof`, aggregating all the case values.
|
|
///
|
|
/// The merging process continues until a `cir.case` with a different body
|
|
/// (e.g., containing `cir.break` or compound stmt) is encountered, which
|
|
/// breaks the chain.
|
|
///
|
|
/// Example:
|
|
///
|
|
/// Before:
|
|
/// cir.case equal, [#cir.int<0> : !s32i] {
|
|
/// cir.yield
|
|
/// }
|
|
/// cir.case equal, [#cir.int<1> : !s32i] {
|
|
/// cir.yield
|
|
/// }
|
|
/// cir.case equal, [#cir.int<2> : !s32i] {
|
|
/// cir.break
|
|
/// }
|
|
///
|
|
/// After applying SimplifySwitch:
|
|
/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
|
|
/// !s32i] {
|
|
/// cir.break
|
|
/// }
|
|
struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
|
|
using OpRewritePattern<SwitchOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(SwitchOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
|
|
LogicalResult changed = mlir::failure();
|
|
SmallVector<CaseOp, 8> cases;
|
|
SmallVector<CaseOp, 4> cascadingCases;
|
|
SmallVector<mlir::Attribute, 4> cascadingCaseValues;
|
|
|
|
op.collectCases(cases);
|
|
if (cases.empty())
|
|
return mlir::failure();
|
|
|
|
auto flushMergedOps = [&]() {
|
|
for (CaseOp &c : cascadingCases)
|
|
rewriter.eraseOp(c);
|
|
cascadingCases.clear();
|
|
cascadingCaseValues.clear();
|
|
};
|
|
|
|
auto mergeCascadingInto = [&](CaseOp &target) {
|
|
rewriter.modifyOpInPlace(target, [&]() {
|
|
target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
|
|
target.setKind(CaseOpKind::Anyof);
|
|
});
|
|
changed = mlir::success();
|
|
};
|
|
|
|
for (CaseOp c : cases) {
|
|
cir::CaseOpKind kind = c.getKind();
|
|
if (kind == cir::CaseOpKind::Equal &&
|
|
isa<YieldOp>(c.getCaseRegion().front().front())) {
|
|
// If the case contains only a YieldOp, collect it for cascading merge
|
|
cascadingCases.push_back(c);
|
|
cascadingCaseValues.push_back(c.getValue()[0]);
|
|
} else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
|
|
// merge previously collected cascading cases
|
|
cascadingCaseValues.push_back(c.getValue()[0]);
|
|
mergeCascadingInto(c);
|
|
flushMergedOps();
|
|
} else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
|
|
// If a Default, Anyof or Range case is found and there are previous
|
|
// cascading cases, merge all of them into the last cascading case.
|
|
// We don't currently fold case range statements with other case
|
|
// statements.
|
|
assert(!cir::MissingFeatures::foldRangeCase());
|
|
CaseOp lastCascadingCase = cascadingCases.back();
|
|
mergeCascadingInto(lastCascadingCase);
|
|
cascadingCases.pop_back();
|
|
flushMergedOps();
|
|
} else {
|
|
cascadingCases.clear();
|
|
cascadingCaseValues.clear();
|
|
}
|
|
}
|
|
|
|
// Edge case: all cases are simple cascading cases
|
|
if (cascadingCases.size() == cases.size()) {
|
|
CaseOp lastCascadingCase = cascadingCases.back();
|
|
mergeCascadingInto(lastCascadingCase);
|
|
cascadingCases.pop_back();
|
|
flushMergedOps();
|
|
}
|
|
|
|
return changed;
|
|
}
|
|
};
|
|
|
|
struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
|
|
using OpRewritePattern<VecSplatOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(VecSplatOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
mlir::Value splatValue = op.getValue();
|
|
auto constant = splatValue.getDefiningOp<cir::ConstantOp>();
|
|
if (!constant)
|
|
return mlir::failure();
|
|
|
|
auto value = constant.getValue();
|
|
if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&
|
|
!mlir::isa_and_nonnull<cir::FPAttr>(value))
|
|
return mlir::failure();
|
|
|
|
cir::VectorType resultType = op.getResult().getType();
|
|
SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);
|
|
auto constVecAttr = cir::ConstVectorAttr::get(
|
|
resultType, mlir::ArrayAttr::get(getContext(), elements));
|
|
|
|
rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);
|
|
return mlir::success();
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CIRSimplifyPass
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
|
|
using CIRSimplifyBase::CIRSimplifyBase;
|
|
|
|
void runOnOperation() override;
|
|
};
|
|
|
|
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
|
|
// clang-format off
|
|
patterns.add<
|
|
SimplifyTernary,
|
|
SimplifySelect,
|
|
SimplifySwitch,
|
|
SimplifyVecSplat
|
|
>(patterns.getContext());
|
|
// clang-format on
|
|
}
|
|
|
|
void CIRSimplifyPass::runOnOperation() {
|
|
// Collect rewrite patterns.
|
|
RewritePatternSet patterns(&getContext());
|
|
populateMergeCleanupPatterns(patterns);
|
|
|
|
// Collect operations to apply patterns.
|
|
llvm::SmallVector<Operation *, 16> ops;
|
|
getOperation()->walk([&](Operation *op) {
|
|
if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
|
|
ops.push_back(op);
|
|
});
|
|
|
|
// Apply patterns.
|
|
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
|
|
signalPassFailure();
|
|
}
|
|
|
|
} // namespace
|
|
|
|
std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
|
|
return std::make_unique<CIRSimplifyPass>();
|
|
}
|