[mlir][linalg] Data layout propagation test schedule (#184151)
Replaces data layout propagation test pass with equivalent transform schedule. The two required patterns are wrapped in new linalg transform ops improving reusability.
This commit is contained in:
parent
6fae863eba
commit
6b59ad6e8d
@ -143,6 +143,28 @@ def ApplyFoldPackUnpackIntoEmptyPatternsOp : Op<Transform_Dialect,
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplyDataLayoutPropagationPatternsOp : Op<Transform_Dialect,
|
||||
"apply_patterns.linalg.data_layout_propagation",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
let description = [{
|
||||
Collection of patterns to bubble up or down data layout ops across other
|
||||
operations.
|
||||
}];
|
||||
|
||||
let arguments = (ins DefaultValuedAttr<BoolAttr, "false">:$poison_padding);
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def ApplyExtractSliceSinkingPatternsOp : Op<Transform_Dialect,
|
||||
"apply_patterns.linalg.extract_slice_sinking",
|
||||
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
|
||||
let description = [{
|
||||
Patterns to sink extract slice across other operations.
|
||||
}];
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferizeToAllocationOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -272,6 +272,26 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
|
||||
linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
|
||||
}
|
||||
|
||||
void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
linalg::ControlPropagationFn defaultControlFn = [](OpOperand *operand) {
|
||||
return true;
|
||||
};
|
||||
linalg::populateDataLayoutPropagationPatterns(patterns, defaultControlFn,
|
||||
getPoisonPadding());
|
||||
}
|
||||
|
||||
void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
linalg::ControlPropagationFn defaultControlFn =
|
||||
[](OpOperand *opOperand) -> bool {
|
||||
Operation *producer = opOperand->get().getDefiningOp();
|
||||
Operation *consumer = opOperand->getOwner();
|
||||
return consumer->getBlock() == producer->getBlock();
|
||||
};
|
||||
linalg::populateExtractSliceSinkingPatterns(patterns, defaultControlFn);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BufferizeToAllocationOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -split-input-file \
|
||||
// RUN: -transform-preload-library='transform-library-paths=%p/td/propagate-data-layout.mlir' \
|
||||
// RUN: -transform-interpreter=entry-point=propagate_data_layout \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
|
||||
|
||||
12
mlir/test/Dialect/Linalg/td/propagate-data-layout.mlir
Normal file
12
mlir/test/Dialect/Linalg/td/propagate-data-layout.mlir
Normal file
@ -0,0 +1,12 @@
|
||||
module @transforms attributes { transform.with_named_sequence } {
|
||||
transform.named_sequence @propagate_data_layout(%module: !transform.any_op {transform.readonly}) {
|
||||
%func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
|
||||
|
||||
transform.apply_patterns to %func {
|
||||
transform.apply_patterns.linalg.data_layout_propagation {poison_padding = true}
|
||||
transform.apply_patterns.linalg.extract_slice_sinking
|
||||
} : !transform.any_op
|
||||
|
||||
transform.yield
|
||||
}
|
||||
}
|
||||
@ -1,6 +1,5 @@
|
||||
# Exclude tests from libMLIR.so
|
||||
add_mlir_library(MLIRLinalgTestPasses
|
||||
TestDataLayoutPropagation.cpp
|
||||
TestLinalgDecomposeOps.cpp
|
||||
TestLinalgDropUnitDims.cpp
|
||||
TestLinalgElementwiseFusion.cpp
|
||||
|
||||
@ -1,57 +0,0 @@
|
||||
//===- TestDataLayoutPropagation.cpp --------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
struct TestDataLayoutPropagationPass
|
||||
: public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass)
|
||||
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<affine::AffineDialect, linalg::LinalgDialect,
|
||||
tensor::TensorDialect>();
|
||||
}
|
||||
|
||||
StringRef getArgument() const final {
|
||||
return "test-linalg-data-layout-propagation";
|
||||
}
|
||||
StringRef getDescription() const final {
|
||||
return "Test data layout propagation";
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateDataLayoutPropagationPatterns(
|
||||
patterns, [](OpOperand *opOperand) { return true; },
|
||||
/*poisonPaddingOk=*/true);
|
||||
linalg::ControlPropagationFn controlExtract =
|
||||
[](OpOperand *opOperand) -> bool {
|
||||
Operation *producer = opOperand->get().getDefiningOp();
|
||||
Operation *consumer = opOperand->getOwner();
|
||||
return consumer->getBlock() == producer->getBlock();
|
||||
};
|
||||
linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace test {
|
||||
void registerTestDataLayoutPropagation() {
|
||||
PassRegistration<TestDataLayoutPropagationPass>();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace mlir
|
||||
@ -89,7 +89,6 @@ void registerTestComposeSubView();
|
||||
void registerTestCompositePass();
|
||||
void registerTestControlFlowSink();
|
||||
void registerTestConvertToSPIRVPass();
|
||||
void registerTestDataLayoutPropagation();
|
||||
void registerTestDataLayoutQuery();
|
||||
void registerTestDeadCodeAnalysisPass();
|
||||
void registerTestDecomposeCallGraphTypes();
|
||||
@ -238,7 +237,6 @@ static void registerTestPasses() {
|
||||
mlir::test::registerTestCompositePass();
|
||||
mlir::test::registerTestControlFlowSink();
|
||||
mlir::test::registerTestConvertToSPIRVPass();
|
||||
mlir::test::registerTestDataLayoutPropagation();
|
||||
mlir::test::registerTestDataLayoutQuery();
|
||||
mlir::test::registerTestDeadCodeAnalysisPass();
|
||||
mlir::test::registerTestDecomposeCallGraphTypes();
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user