In the past, it was hard to set padding values because we did not have `ub.poison`, and using zeros as padding values is not always correct. Now we can use `ub.poison` in this case. The revision adds support for setting the padding value using `ub.poison` when padding is required in the propagation; otherwise, the propagation creates an invalid pack op. Additionally, the revision adds a control option for allowing padding in the pattern, which is false by default. To do this correctly, a new `requirePaddingValueStrict` method is added, which assumes that dynamic dims mean padding is required. The revision also removes trailing whitespace in the lit test file. Co-authored-by: Nirvedh Meshram <nirvedh@gmail.com> --------- Signed-off-by: hanhanW <hanhan0912@gmail.com> Signed-off-by: Nirvedh Meshram <nirvedh@gmail.com> Co-authored-by: Nirvedh Meshram <nirvedh@gmail.com>
58 lines
2.0 KiB
C++
58 lines
2.0 KiB
C++
//===- TestDataLayoutPropagation.cpp --------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
struct TestDataLayoutPropagationPass
|
|
: public PassWrapper<TestDataLayoutPropagationPass, OperationPass<>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass)
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<affine::AffineDialect, linalg::LinalgDialect,
|
|
tensor::TensorDialect>();
|
|
}
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-linalg-data-layout-propagation";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test data layout propagation";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
RewritePatternSet patterns(context);
|
|
linalg::populateDataLayoutPropagationPatterns(
|
|
patterns, [](OpOperand *opOperand) { return true; },
|
|
/*poisonPaddingOk=*/true);
|
|
linalg::ControlPropagationFn controlExtract =
|
|
[](OpOperand *opOperand) -> bool {
|
|
Operation *producer = opOperand->get().getDefiningOp();
|
|
Operation *consumer = opOperand->getOwner();
|
|
return consumer->getBlock() == producer->getBlock();
|
|
};
|
|
linalg::populateExtractSliceSinkingPatterns(patterns, controlExtract);
|
|
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestDataLayoutPropagation() {
|
|
PassRegistration<TestDataLayoutPropagationPass>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|