[mlir][linalg][elementwise] Fold transpose into new elementwise (#130207)
Fold transpose into new elementwise Op which has affine-map attached. Will add broadcast folding in next diff.
This commit is contained in:
parent
be0215d745
commit
ecf4d995f6
@ -601,6 +601,17 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
|
|||||||
[{
|
[{
|
||||||
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
|
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
|
||||||
attributes, ElementwiseOp::getRegionBuilder());
|
attributes, ElementwiseOp::getRegionBuilder());
|
||||||
|
}]>,
|
||||||
|
|
||||||
|
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
|
||||||
|
"ElementwiseKindAttr":$kind,
|
||||||
|
"ArrayAttr":$indexingMaps,
|
||||||
|
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
|
||||||
|
[{
|
||||||
|
$_state.addAttribute("kind", kind);
|
||||||
|
$_state.addAttribute("indexing_maps", indexingMaps);
|
||||||
|
buildStructuredOp($_builder, $_state, std::nullopt, inputs, outputs,
|
||||||
|
attributes, ElementwiseOp::getRegionBuilder());
|
||||||
}]>
|
}]>
|
||||||
];
|
];
|
||||||
|
|
||||||
|
@ -99,6 +99,11 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
|
|||||||
let dependentDialects = ["linalg::LinalgDialect"];
|
let dependentDialects = ["linalg::LinalgDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
|
||||||
|
let summary = "Fold transform, broadcast and other ops into elementwise";
|
||||||
|
let dependentDialects = ["linalg::LinalgDialect"];
|
||||||
|
}
|
||||||
|
|
||||||
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
|
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
|
||||||
let summary = "Detensorize linalg ops";
|
let summary = "Detensorize linalg ops";
|
||||||
let dependentDialects = [];
|
let dependentDialects = [];
|
||||||
|
@ -1710,6 +1710,10 @@ void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
|
|||||||
void populateLinalgGenericOpsSpecializationPatterns(
|
void populateLinalgGenericOpsSpecializationPatterns(
|
||||||
RewritePatternSet &patterns);
|
RewritePatternSet &patterns);
|
||||||
|
|
||||||
|
/// Populates `patterns` with patterns that fold operations like
|
||||||
|
/// `linalg.transform` into elementwise op map.
|
||||||
|
void populateLinalgFoldIntoElementwisePatterns(RewritePatternSet &patterns);
|
||||||
|
|
||||||
/// Linalg decompose convolutions patterns
|
/// Linalg decompose convolutions patterns
|
||||||
|
|
||||||
/// Populates patterns to decompose high-D convolution ops into low-D ones.
|
/// Populates patterns to decompose high-D convolution ops into low-D ones.
|
||||||
|
@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||||||
EliminateEmptyTensors.cpp
|
EliminateEmptyTensors.cpp
|
||||||
EraseUnusedOperandsAndResults.cpp
|
EraseUnusedOperandsAndResults.cpp
|
||||||
FoldAddIntoDest.cpp
|
FoldAddIntoDest.cpp
|
||||||
|
FoldIntoElementwise.cpp
|
||||||
FusePadOpWithLinalgProducer.cpp
|
FusePadOpWithLinalgProducer.cpp
|
||||||
Fusion.cpp
|
Fusion.cpp
|
||||||
Generalization.cpp
|
Generalization.cpp
|
||||||
|
89
mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
Normal file
89
mlir/lib/Dialect/Linalg/Transforms/FoldIntoElementwise.cpp
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file implements folding ops such as transpose and broadcast into the
|
||||||
|
// affine maps of the elementwise op.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Passes.h"
|
||||||
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
#define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS
|
||||||
|
#include "mlir/Dialect/Linalg/Passes.h.inc"
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::linalg;
|
||||||
|
|
||||||
|
#define DEBUG_TYPE "linalg-fold-into-elementwise"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
struct FoldTransposePattern : public OpRewritePattern<ElementwiseOp> {
|
||||||
|
using OpRewritePattern<ElementwiseOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(ElementwiseOp op,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
bool changed = false;
|
||||||
|
SmallVector<Value> newIns;
|
||||||
|
SmallVector<AffineMap> newMaps;
|
||||||
|
for (OpOperand *operand : op.getDpsInputOperands()) {
|
||||||
|
AffineMap map = op.getMatchingIndexingMap(operand);
|
||||||
|
auto transposeOp = operand->get().getDefiningOp<TransposeOp>();
|
||||||
|
|
||||||
|
if (!map.isIdentity() || !transposeOp) {
|
||||||
|
// push in original operand and its map.
|
||||||
|
newIns.push_back(operand->get());
|
||||||
|
newMaps.push_back(map);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
newIns.push_back(transposeOp.getInput());
|
||||||
|
// push in transposeOp's inverse permutation map.
|
||||||
|
newMaps.push_back(transposeOp.getMatchingIndexingMap(
|
||||||
|
transposeOp.getDpsInputOperand(0)));
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
if (!changed)
|
||||||
|
return failure();
|
||||||
|
newMaps.push_back(op.getIndexingMapsArray().back());
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<ElementwiseOp>(
|
||||||
|
op, newIns, op.getDpsInits()[0], op.getKindAttr(),
|
||||||
|
rewriter.getAffineMapArrayAttr(newMaps));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct LinalgFoldIntoElementwisePass
|
||||||
|
: public impl::LinalgFoldIntoElementwisePassBase<
|
||||||
|
LinalgFoldIntoElementwisePass> {
|
||||||
|
using impl::LinalgFoldIntoElementwisePassBase<
|
||||||
|
LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase;
|
||||||
|
|
||||||
|
void runOnOperation() override {
|
||||||
|
llvm::outs() << "Hellow from fold into elemenwise \n";
|
||||||
|
Operation *op = getOperation();
|
||||||
|
RewritePatternSet patterns(op->getContext());
|
||||||
|
populateLinalgFoldIntoElementwisePatterns(patterns);
|
||||||
|
|
||||||
|
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||||
|
return signalPassFailure();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
void mlir::linalg::populateLinalgFoldIntoElementwisePatterns(
|
||||||
|
RewritePatternSet &patterns) {
|
||||||
|
patterns.add<FoldTransposePattern>(patterns.getContext());
|
||||||
|
}
|
43
mlir/test/Dialect/Linalg/elementwise/fold.mlir
Normal file
43
mlir/test/Dialect/Linalg/elementwise/fold.mlir
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
// RUN: mlir-opt %s -linalg-fold-into-elementwise -split-input-file | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||||
|
// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
|
||||||
|
//
|
||||||
|
// CHECK: func.func @unary_transpose(%[[A:.+]]: tensor<16x8x32xf32>, %[[B:.+]]: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
|
||||||
|
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<exp>
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[TRANSPOSED]], #[[IDENTITY]]]
|
||||||
|
// CHECK-SAME: ins(%[[A]] : tensor<16x8x32xf32>) outs(%[[B]] : tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
|
||||||
|
// CHECK-NEXT: return %[[RES]] : tensor<8x16x32xf32>
|
||||||
|
//
|
||||||
|
func.func @unary_transpose(%A : tensor<16x8x32xf32>, %B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32> {
|
||||||
|
%empty = tensor.empty() : tensor<8x16x32xf32>
|
||||||
|
%transposed_A = linalg.transpose ins(%A : tensor<16x8x32xf32>) outs(%empty : tensor<8x16x32xf32>) permutation = [1, 0, 2]
|
||||||
|
%result = linalg.elementwise kind=#linalg.elementwise_kind<exp>
|
||||||
|
ins(%transposed_A : tensor<8x16x32xf32>) outs(%B: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
|
||||||
|
return %result : tensor<8x16x32xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-DAG: #[[IDENTITY:.+]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||||
|
// CHECK-DAG: #[[TRANSPOSED:.+]] = affine_map<(d0, d1) -> (d1, d0)>
|
||||||
|
//
|
||||||
|
// CHECK: func.func @binary_transposed(%[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
|
// CHECK-NEXT: %[[RES:.+]] = linalg.elementwise kind=#linalg.elementwise_kind<add>
|
||||||
|
// CHECK-SAME: indexing_maps = [#[[IDENTITY]], #[[TRANSPOSED]], #[[IDENTITY]]]
|
||||||
|
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[C]] : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
// CHECK-NEXT: return %[[RES]] : tensor<?x?xf32>
|
||||||
|
//
|
||||||
|
func.func @binary_transposed(%A : tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%dim0 = tensor.dim %A, %c0 : tensor<?x?xf32>
|
||||||
|
%dim1 = tensor.dim %A, %c1 : tensor<?x?xf32>
|
||||||
|
|
||||||
|
%empty = tensor.empty(%dim1, %dim0) : tensor<?x?xf32>
|
||||||
|
%transposed_B = linalg.transpose ins(%B : tensor<?x?xf32>) outs(%empty : tensor<?x?xf32>) permutation = [1, 0]
|
||||||
|
%result = linalg.elementwise kind=#linalg.elementwise_kind<add>
|
||||||
|
ins(%A, %transposed_B : tensor<?x?xf32>, tensor<?x?xf32>)
|
||||||
|
outs(%C: tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||||
|
return %result : tensor<?x?xf32>
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user