[mlir][linalg] Data layout propagation test schedule (#184151)

Replaces data layout propagation test pass with equivalent transform
schedule.

The two required patterns are wrapped in new linalg transform ops
improving reusability.
This commit is contained in:
Adam Siemieniuk 2026-03-04 08:36:29 +01:00 committed by GitHub
parent 6fae863eba
commit 6b59ad6e8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 58 additions and 61 deletions

View File

@ -143,6 +143,28 @@ def ApplyFoldPackUnpackIntoEmptyPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
def ApplyDataLayoutPropagationPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.data_layout_propagation",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collection of patterns to bubble up or down data layout ops across other
operations.
}];
let arguments = (ins DefaultValuedAttr<BoolAttr, "false">:$poison_padding);
let assemblyFormat = "attr-dict";
}
def ApplyExtractSliceSinkingPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.extract_slice_sinking",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Patterns to sink extract slice across other operations.
}];
let assemblyFormat = "attr-dict";
}
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//

View File

@ -272,6 +272,26 @@ void transform::ApplyFoldPackUnpackIntoEmptyPatternsOp::populatePatterns(
linalg::populateFoldPackUnpackIntoTensorEmptyPatterns(patterns);
}
void transform::ApplyDataLayoutPropagationPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::ControlPropagationFn defaultControlFn = [](OpOperand *operand) {
return true;
};
linalg::populateDataLayoutPropagationPatterns(patterns, defaultControlFn,
getPoisonPadding());
}
void transform::ApplyExtractSliceSinkingPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
linalg::ControlPropagationFn defaultControlFn =
[](OpOperand *opOperand) -> bool {
Operation *producer = opOperand->get().getDefiningOp();
Operation *consumer = opOperand->getOwner();
return consumer->getBlock() == producer->getBlock();
};
linalg::populateExtractSliceSinkingPatterns(patterns, defaultControlFn);
}
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//

View File

@ -1,4 +1,7 @@
// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s
// RUN: mlir-opt %s -split-input-file \
// RUN: -transform-preload-library='transform-library-paths=%p/td/propagate-data-layout.mlir' \
// RUN: -transform-interpreter=entry-point=propagate_data_layout \
// RUN: | FileCheck %s
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>

View File

@ -0,0 +1,12 @@
module @transforms attributes { transform.with_named_sequence } {
transform.named_sequence @propagate_data_layout(%module: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %func {
transform.apply_patterns.linalg.data_layout_propagation {poison_padding = true}
transform.apply_patterns.linalg.extract_slice_sinking
} : !transform.any_op
transform.yield
}
}

View File

@ -1,6 +1,5 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRLinalgTestPasses
TestDataLayoutPropagation.cpp
TestLinalgDecomposeOps.cpp
TestLinalgDropUnitDims.cpp
TestLinalgElementwiseFusion.cpp

View File

@ -1,57 +0,0 @@
//===- TestDataLayoutPropagation.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
struct TestDataLayoutPropagationPass
: public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass)
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<affine::AffineDialect, linalg::LinalgDialect,
tensor::TensorDialect>();
}
StringRef getArgument() const final {
return "test-linalg-data-layout-propagation";
}
StringRef getDescription() const final {
return "Test data layout propagation";
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; },
/*poisonPaddingOk=*/true);
linalg::ControlPropagationFn controlExtract =
[](OpOperand *opOperand) -> bool {
Operation *producer = opOperand->get().getDefiningOp();
Operation *consumer = opOperand->getOwner();
return consumer->getBlock() == producer->getBlock();
};
linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
namespace mlir {
namespace test {
void registerTestDataLayoutPropagation() {
PassRegistration<TestDataLayoutPropagationPass>();
}
} // namespace test
} // namespace mlir

View File

@ -89,7 +89,6 @@ void registerTestComposeSubView();
void registerTestCompositePass();
void registerTestControlFlowSink();
void registerTestConvertToSPIRVPass();
void registerTestDataLayoutPropagation();
void registerTestDataLayoutQuery();
void registerTestDeadCodeAnalysisPass();
void registerTestDecomposeCallGraphTypes();
@ -238,7 +237,6 @@ static void registerTestPasses() {
mlir::test::registerTestCompositePass();
mlir::test::registerTestControlFlowSink();
mlir::test::registerTestConvertToSPIRVPass();
mlir::test::registerTestDataLayoutPropagation();
mlir::test::registerTestDataLayoutQuery();
mlir::test::registerTestDeadCodeAnalysisPass();
mlir::test::registerTestDecomposeCallGraphTypes();