[MLIR][Linalg] Rename convolution pass (#154400)
Rename the pass `LinalgNamedOpConversionPass` to `SimplifyDepthwiseConvPass` to avoid conflating it with the new morphisms we are creating between the norms.
This commit is contained in:
parent
a53e73e6ef
commit
32a5adbd42
@ -86,13 +86,13 @@ def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops">,
|
|||||||
let dependentDialects = ["linalg::LinalgDialect"];
|
let dependentDialects = ["linalg::LinalgDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
def LinalgNamedOpConversionPass: Pass<"linalg-named-op-conversion"> {
|
// ------------------ End of "form" conversions
|
||||||
let summary = "Convert from one named linalg op to another.";
|
|
||||||
|
def SimplifyDepthwiseConvPass: Pass<"simplify-depthwise-conv"> {
|
||||||
|
let summary = "Simplify depthwise convolution.";
|
||||||
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
|
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
|
||||||
}
|
}
|
||||||
|
|
||||||
// ------------------ End of "form" conversions
|
|
||||||
|
|
||||||
def ConvertElementwiseToLinalgPass : Pass<"convert-elementwise-to-linalg", ""> {
|
def ConvertElementwiseToLinalgPass : Pass<"convert-elementwise-to-linalg", ""> {
|
||||||
let summary = "Convert ElementwiseMappable ops to linalg";
|
let summary = "Convert ElementwiseMappable ops to linalg";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -1962,9 +1962,8 @@ void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns);
|
|||||||
void populateFuseTensorPadWithProducerLinalgOpPatterns(
|
void populateFuseTensorPadWithProducerLinalgOpPatterns(
|
||||||
RewritePatternSet &patterns);
|
RewritePatternSet &patterns);
|
||||||
|
|
||||||
/// Patterns to convert from one named op to another. These can be seen as
|
/// Patterns to simplify depthwise convolutions.
|
||||||
/// canonicalizations of named ops into another named op.
|
void populateSimplifyDepthwiseConvPatterns(RewritePatternSet &patterns);
|
||||||
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
|
|
||||||
|
|
||||||
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
|
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
|
||||||
/// tensors via reassociative reshape ops.
|
/// tensors via reassociative reshape ops.
|
||||||
|
@ -26,7 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||||||
MorphOps.cpp
|
MorphOps.cpp
|
||||||
TransposeMatmul.cpp
|
TransposeMatmul.cpp
|
||||||
ShardingInterfaceImpl.cpp
|
ShardingInterfaceImpl.cpp
|
||||||
NamedOpConversions.cpp
|
SimplifyDepthwiseConv.cpp
|
||||||
NamedToElementwise.cpp
|
NamedToElementwise.cpp
|
||||||
BlockPackMatmul.cpp
|
BlockPackMatmul.cpp
|
||||||
PackAndUnpackPatterns.cpp
|
PackAndUnpackPatterns.cpp
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS
|
#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS
|
||||||
#include "mlir/Dialect/Linalg/Passes.h.inc"
|
#include "mlir/Dialect/Linalg/Passes.h.inc"
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
@ -143,23 +143,22 @@ struct SimplifyDepthwiseConvQOp
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct LinalgNamedOpConversionPass
|
struct SimplifyDepthwiseConvPass
|
||||||
: public impl::LinalgNamedOpConversionPassBase<
|
: public impl::SimplifyDepthwiseConvPassBase<SimplifyDepthwiseConvPass> {
|
||||||
LinalgNamedOpConversionPass> {
|
using impl::SimplifyDepthwiseConvPassBase<
|
||||||
using impl::LinalgNamedOpConversionPassBase<
|
SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase;
|
||||||
LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase;
|
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
Operation *op = getOperation();
|
Operation *op = getOperation();
|
||||||
RewritePatternSet patterns(op->getContext());
|
RewritePatternSet patterns(op->getContext());
|
||||||
populateLinalgNamedOpConversionPatterns(patterns);
|
populateSimplifyDepthwiseConvPatterns(patterns);
|
||||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
void mlir::linalg::populateLinalgNamedOpConversionPatterns(
|
void mlir::linalg::populateSimplifyDepthwiseConvPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
|
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
|
||||||
patterns.getContext());
|
patterns.getContext());
|
@ -1,4 +1,4 @@
|
|||||||
// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s
|
// RUN: mlir-opt %s --simplify-depthwise-conv -split-input-file | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: @depthwise_conv
|
// CHECK-LABEL: @depthwise_conv
|
||||||
func.func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
|
func.func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
|
Loading…
x
Reference in New Issue
Block a user