
These include changes were identified by clang-tidy's misc-include-cleaner check. I've filtered out the ones that break builds. I'm also leaving llvm-config.h, config.h, and Compiler.h untouched, since changing them is likely to cause platform- or compiler-specific build failures.
88 lines
3.0 KiB
C++
88 lines
3.0 KiB
C++
//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements folding ops such as transpose and broadcast into the
// affine maps of the elementwise op. Currently only linalg.transpose producers
// are folded (see FoldTransposePattern).
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
|
|
#include "mlir/Dialect/Linalg/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
#define DEBUG_TYPE "linalg-fold-into-elementwise"
|
|
|
|
namespace {
|
|
/// Rewrites an ElementwiseOp so that any input produced by a TransposeOp is
/// read directly from that transpose's input, with the transpose's own
/// input indexing map substituted for the elementwise operand's map.
///
/// An input is folded only when its current indexing map is the identity.
/// Every other input keeps its original value and map, and the output
/// (init) map is carried over unchanged. The pattern reports failure when
/// no input could be folded.
struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
  using OpRewritePattern<ElementwiseOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ElementwiseOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedInputs;
    SmallVector<AffineMap> foldedMaps;
    bool anyFolded = false;

    for (OpOperand *input : op.getDpsInputOperands()) {
      Value source = input->get();
      AffineMap accessMap = op.getMatchingIndexingMap(input);

      // Only an identity access lets the transpose's input map be reused
      // verbatim as the new operand map.
      if (accessMap.isIdentity()) {
        if (auto producer = source.getDefiningOp<TransposeOp>()) {
          source = producer.getInput();
          // The transpose's indexing map for its input operand (its inverse
          // permutation map) becomes the new map for this operand.
          accessMap =
              producer.getMatchingIndexingMap(producer.getDpsInputOperand(0));
          anyFolded = true;
        }
      }

      foldedInputs.push_back(source);
      foldedMaps.push_back(accessMap);
    }

    if (!anyFolded)
      return failure();

    // The output keeps the map it already had (last entry of the map list).
    foldedMaps.push_back(op.getIndexingMapsArray().back());

    rewriter.replaceOpWithNewOp<ElementwiseOp>(
        op, foldedInputs, op.getDpsInits()[0], op.getKindAttr(),
        rewriter.getAffineMapArrayAttr(foldedMaps));
    return success();
  }
};
|
|
|
|
struct LinalgFoldIntoElementwisePass
|
|
: public impl::LinalgFoldIntoElementwisePassBase<
|
|
LinalgFoldIntoElementwisePass> {
|
|
using impl::LinalgFoldIntoElementwisePassBase<
|
|
LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
|
|
|
|
void runOnOperation() override {
|
|
Operation *op = getOperation();
|
|
RewritePatternSet patterns(op->getContext());
|
|
populateLinalgFoldIntoElementwisePatterns(patterns);
|
|
|
|
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Adds every fold-into-elementwise rewrite defined in this file (currently
/// only FoldTransposePattern) to `patterns`, using the set's own context.
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
    RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<FoldTransposePattern>(context);
}
|