llvm-project/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp
River Riddle 3655069234 [mlir] Move the Builtin FuncOp to the Func dialect
This commit moves FuncOp out of the builtin dialect, and into the Func
dialect. This move has been planned in some capacity from the moment
we made FuncOp an operation (years ago). This commit handles the
functional aspects of the move, but various aspects are left untouched
to ease migration: func::FuncOp is re-exported into mlir to reduce
the actual API churn, the assembly format still accepts the unqualified
`func`. These temporary measures will remain for a little while to
simplify migration before being removed.

Differential Revision: https://reviews.llvm.org/D121266
2022-03-16 17:07:03 -07:00

178 lines
6.5 KiB
C++

//===- TestLinalgElementwiseFusion.cpp - Test Linalg elementwise fusion ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
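// This file implements a pass for testing fusion of elementwise operations in
// Linalg: elementwise op fusion, reshape propagation, and reshape fusion by
// expansion or by collapsing, with and without custom fusion control functions.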
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
/// Inserts into `operandSet` the values that `op` contributes as operands of a
/// would-be fused op. For a Linalg op only its input operands are inserted; for
/// any other op every operand is inserted. A null `op` is ignored.
static void addOperands(Operation *op, SetVector<Value> &operandSet) {
  if (op == nullptr)
    return;

  // Linalg ops: restrict to the input operands.
  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
    SmallVector<Value> inputValues = linalgOp.getInputOperands();
    operandSet.insert(inputValues.begin(), inputValues.end());
    return;
  }

  // Everything else: take all operands.
  operandSet.insert(op->operand_begin(), op->operand_end());
}
template <int limit = 3>
static bool setFusedOpOperandLimit(const OpResult &producer,
const OpOperand &consumer) {
SetVector<Value> fusedOpOperands;
if (producer.getOwner()->getNumResults() != 1)
return false;
addOperands(consumer.getOwner(), fusedOpOperands);
fusedOpOperands.remove(producer);
addOperands(producer.getOwner(), fusedOpOperands);
return fusedOpOperands.size() <= limit;
}
namespace {
/// Test pass that applies Linalg fusion pattern sets to the body of a
/// function. Which pattern sets run is selected by the boolean options below.
struct TestLinalgElementwiseFusion
    : public PassWrapper<TestLinalgElementwiseFusion, OperationPass<FuncOp>> {
  TestLinalgElementwiseFusion() = default;
  // Only the PassWrapper base is copied; the option members below are
  // re-initialized from their in-class initializers.
  TestLinalgElementwiseFusion(const TestLinalgElementwiseFusion &pass)
      : PassWrapper(pass) {}

  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
                    tensor::TensorDialect>();
  }

  StringRef getArgument() const final {
    return "test-linalg-elementwise-fusion-patterns";
  }

  StringRef getDescription() const final {
    return "Test Linalg element wise operation fusion patterns";
  }

  Option<bool> fuseGenericOps{
      *this, "fuse-generic-ops",
      llvm::cl::desc("Test fusion of generic operations."),
      llvm::cl::init(false)};

  Option<bool> controlFuseByExpansion{
      *this, "control-fusion-by-expansion",
      llvm::cl::desc(
          "Test controlling fusion of reshape with generic op by expansion"),
      llvm::cl::init(false)};

  Option<bool> pushExpandingReshape{
      *this, "push-expanding-reshape",
      llvm::cl::desc("Test linalg expand_shape -> generic "
                     "to generic -> expand_shape pattern"),
      llvm::cl::init(false)};

  Option<bool> fuseWithReshapeByCollapsing{
      *this, "fuse-with-reshape-by-collapsing",
      llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that "
                     "collapse the iteration space of the consumer"),
      llvm::cl::init(false)};

  Option<bool> fuseWithReshapeByCollapsingWithControlFn{
      *this, "fuse-with-reshape-by-collapsing-control",
      llvm::cl::desc("Test controlling the linalg expand_shape -> generic "
                     "fusion patterns that "
                     "collapse the iteration space of the consumer"),
      llvm::cl::init(false)};

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    FuncOp func = getOperation();

    // Greedily applies `patterns` to the function body. Whether the driver
    // converged is deliberately ignored.
    auto applyToBody = [&](RewritePatternSet &&patterns) {
      (void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
    };

    // `fuse-generic-ops` and `control-fusion-by-expansion` return right after
    // running, so they take precedence over every option checked later.
    if (fuseGenericOps) {
      RewritePatternSet patterns(ctx);
      linalg::populateElementwiseOpsFusionPatterns(
          patterns,
          linalg::LinalgElementwiseFusionOptions()
              .setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
      applyToBody(std::move(patterns));
      return;
    }

    if (controlFuseByExpansion) {
      RewritePatternSet patterns(ctx);
      linalg::ControlElementwiseOpsFusionFn controlFn =
          controlFusionByExpansion;
      linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, controlFn);
      applyToBody(std::move(patterns));
      return;
    }

    // The remaining options are independent: every enabled one runs, in this
    // order, each with its own greedy rewrite.
    if (pushExpandingReshape) {
      RewritePatternSet patterns(ctx);
      linalg::populatePushReshapeOpsPatterns(patterns);
      applyToBody(std::move(patterns));
    }

    if (fuseWithReshapeByCollapsing) {
      RewritePatternSet patterns(ctx);
      linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns);
      applyToBody(std::move(patterns));
    }

    if (fuseWithReshapeByCollapsingWithControlFn) {
      RewritePatternSet patterns(ctx);
      linalg::ControlElementwiseOpsFusionFn controlFn =
          controlFusionByCollapsing;
      linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
      applyToBody(std::move(patterns));
    }
  }

private:
  /// Control function for fold-by-expansion:
  ///  - rejects a tensor.collapse_shape producer whose source is not defined
  ///    by a Linalg op;
  ///  - accepts when the consumer is a tensor.expand_shape whose single use
  ///    is an output tensor operand of a Linalg op;
  ///  - otherwise defers to `linalg::skipUnitDimReshape`.
  static bool controlFusionByExpansion(const OpResult &producer,
                                       OpOperand &consumer) {
    auto collapseOp = producer.getDefiningOp<tensor::CollapseShapeOp>();
    if (collapseOp && !collapseOp.src().getDefiningOp<linalg::LinalgOp>())
      return false;

    if (auto expandOp = dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
      if (expandOp->hasOneUse()) {
        OpOperand &onlyUse = *expandOp->getUses().begin();
        auto userOp = dyn_cast<linalg::LinalgOp>(onlyUse.getOwner());
        if (userOp && userOp.isOutputTensor(&onlyUse))
          return true;
      }
    }
    return linalg::skipUnitDimReshape(producer, consumer);
  }

  /// Control function for fold-by-collapsing. When the producer is a
  /// tensor.expand_shape, fusion is refused for operand #0 of the consumer
  /// and allowed for every other operand. Any other producer is allowed.
  static bool controlFusionByCollapsing(const OpResult &producer,
                                        OpOperand &consumer) {
    if (!isa<tensor::ExpandShapeOp>(producer.getDefiningOp()))
      return true;
    // Skip fusing the first operand.
    return consumer.getOperandNumber() != 0;
  }
};
} // namespace
namespace test {
void registerTestLinalgElementwiseFusion() {
PassRegistration<TestLinalgElementwiseFusion>();
}
} // namespace test
} // namespace mlir