llvm-project/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
Han-Chung Wang 6f58c16c49
[mlir][linalg] Use ub.poison in data layout propagation if a packed operand requires padding. (#159467)
In the past, it was hard to set padding values because `ub.poison` did not
exist, and using zeros as padding values is not always correct.
Now `ub.poison` can be used in this case. This revision adds support
for setting the padding value to `ub.poison` when padding is required during
propagation; without a padding value, the propagation creates an invalid pack op.

Additionally, the revision adds a control option that allows padding in
the pattern; it is false by default. To do this correctly, a new
`requirePaddingValueStrict` method is added, which assumes that dynamic dims
mean padding is required.

The revision also removes trailing whitespace in the lit test file.

Co-authored-by: Nirvedh Meshram <nirvedh@gmail.com>

---------

Signed-off-by: hanhanW <hanhan0912@gmail.com>
Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com>
Co-authored-by: Nirvedh Meshram <nirvedh@gmail.com>
2025-09-24 16:27:30 -05:00

58 lines
2.0 KiB
C++

//===- TestDataLayoutPropagation.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
namespace {
/// Test-only pass that applies the Linalg data layout propagation patterns
/// together with the extract-slice sinking patterns, using the greedy pattern
/// rewrite driver, to the operation the pass runs on.
struct TestDataLayoutPropagationPass
    : public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass)

  /// Declares the affine, linalg and tensor dialects as dependencies of this
  /// pass.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, linalg::LinalgDialect,
                    tensor::TensorDialect>();
  }

  /// Returns the command-line argument used to select this pass.
  StringRef getArgument() const final {
    return "test-linalg-data-layout-propagation";
  }

  /// Returns the human-readable description of this pass.
  StringRef getDescription() const final {
    return "Test data layout propagation";
  }

  /// Collects both pattern sets and applies them greedily; signals pass
  /// failure if the greedy driver reports failure.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    // The control callback accepts every operand, and the `poisonPaddingOk`
    // option is switched on for the test.
    linalg::populateDataLayoutPropagationPatterns(
        patterns, [](OpOperand *opOperand) { return true; },
        /*poisonPaddingOk=*/true);

    // Only allow extract-slice sinking when the producer and the consumer of
    // the operand live in the same block.
    linalg::ControlPropagationFn controlExtract =
        [](OpOperand *opOperand) -> bool {
      Operation *producer = opOperand->get().getDefiningOp();
      // `getDefiningOp()` yields null when the value is a block argument;
      // there is no producer op to compare against, so reject the operand
      // instead of dereferencing a null pointer.
      if (!producer)
        return false;
      Operation *consumer = opOperand->getOwner();
      return consumer->getBlock() == producer->getBlock();
    };
    linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace
namespace mlir {
namespace test {
void registerTestDataLayoutPropagation() {
PassRegistration<TestDataLayoutPropagationPass>();
}
} // namespace test
} // namespace mlir