This patch fixes a bug in `affine::makeComposedFoldedAffineApply` that occurs
when `composeAffineMin == true`. The simplification assumed that the symbols
appearing in the `affine.apply` op correspond to the symbols of the
`affine.min` op, which is not always the case. For example:
```mlir
#map = affine_map<()[s0, s1] -> (s1)>
#map1 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
module {
func.func @min_max_full_simplify() -> index {
%0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
%1 = test.value_with_bounds {max = 64 : index, min = 32 : index}
%2 = affine.min #map()[%0, %1]
%3 = affine.apply #map1()[%2, %0]
return %3 : index
}
}
```
This patch also introduces the `make_composed_folded_affine_apply` test
transform operation to exercise this simplification, and adds tests
ensuring the behavior is correct.
---------
Co-authored-by: Nicolas Vasilache <nico.vasilache@amd.com>
113 lines
4.1 KiB
C++
113 lines
4.1 KiB
C++
//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines transform dialect operations for testing MLIR
// transformations.
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
|
|
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/IR/OpDefinition.h"
|
|
#include "mlir/Transforms/RegionUtils.h"
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTransformsOps.h.inc"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::transform;
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "TestTransformsOps.cpp.inc"
|
|
|
|
/// Test driver for `moveOperationDependencies`.
///
/// Uses the first payload operation mapped to the `getOp()` handle and the
/// first payload operation mapped to the `getInsertionPoint()` handle, and
/// asks `moveOperationDependencies` to move the dependencies of the former
/// before the latter. If that call fails, the most recent match-failure
/// message recorded by the rewriter's listener is emitted as a remark.
///
/// The transform itself always reports success, whether or not the move
/// succeeded.
///
/// NOTE(review): both handles are dereferenced via `begin()` without an
/// emptiness check, so this presumably expects each handle to map to at least
/// one payload operation -- confirm against the op definition.
DiagnosedSilenceableFailure
transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
                                      TransformResults &TransformResults,
                                      TransformState &state) {
  Operation *payloadOp = *state.getPayloadOps(getOp()).begin();
  Operation *insertionPoint =
      *state.getPayloadOps(getInsertionPoint()).begin();

  // Nothing to report when the dependencies were moved successfully.
  if (succeeded(moveOperationDependencies(rewriter, payloadOp, insertionPoint)))
    return DiagnosedSilenceableFailure::success();

  // Surface the reason for the failure as a remark rather than as an error.
  auto *trackingListener =
      cast<ErrorCheckingTrackingListener>(rewriter.getListener());
  std::string failureMessage = trackingListener->getLatestMatchFailureMessage();
  (void)emitRemark(failureMessage);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
/// Test driver for `moveValueDefinitions`.
///
/// For every handle in `getValues()`, collects the first payload value mapped
/// to it, then asks `moveValueDefinitions` to move the definitions of those
/// values before the first payload operation mapped to the
/// `getInsertionPoint()` handle. If that call fails, the most recent
/// match-failure message recorded by the rewriter's listener is emitted as a
/// remark.
///
/// The transform itself always reports success, whether or not the move
/// succeeded.
///
/// NOTE(review): every handle is dereferenced via `begin()` without an
/// emptiness check, so this presumably expects each handle to map to at least
/// one payload entity -- confirm against the op definition.
DiagnosedSilenceableFailure
transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
                                     TransformResults &TransformResults,
                                     TransformState &state) {
  // Only the first payload value of each handle takes part in the move.
  SmallVector<Value> payloadValues;
  for (Value handle : getValues())
    payloadValues.push_back(*state.getPayloadValues(handle).begin());

  Operation *insertionPoint =
      *state.getPayloadOps(getInsertionPoint()).begin();

  // Nothing to report when the definitions were moved successfully.
  if (succeeded(moveValueDefinitions(rewriter, payloadValues, insertionPoint)))
    return DiagnosedSilenceableFailure::success();

  // Surface the reason for the failure as a remark rather than as an error.
  auto *trackingListener =
      cast<ErrorCheckingTrackingListener>(rewriter.getListener());
  std::string failureMessage = trackingListener->getLatestMatchFailureMessage();
  (void)emitRemark(failureMessage);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Test affine functionality.
//===----------------------------------------------------------------------===//
|
|
/// Test driver for `affine::makeComposedFoldedAffineApply` with
/// `composeAffineMin` enabled.
///
/// Rebuilds `affineApplyOp` from its affine map and operands through
/// `makeComposedFoldedAffineApply`. When the outcome is not already a `Value`,
/// it is materialized as an `arith::ConstantIndexOp` at the location of the
/// original op. The defining operation of the resulting value is appended to
/// `results`, and `affineApplyOp` is replaced by that value.
///
/// Always returns success.
DiagnosedSilenceableFailure
transform::TestMakeComposedFoldedAffineApply::applyToOne(
    TransformRewriter &rewriter, affine::AffineApplyOp affineApplyOp,
    ApplyToEachResultList &results, TransformState &state) {
  Location loc = affineApplyOp.getLoc();
  OpFoldResult folded = affine::makeComposedFoldedAffineApply(
      rewriter, loc, affineApplyOp.getAffineMap(),
      getAsOpFoldResult(affineApplyOp.getOperands()),
      /*composeAffineMin=*/true);

  // Use the folded value directly when there is one; otherwise turn the
  // folded constant into an index constant op.
  Value replacement = dyn_cast<Value>(folded);
  if (!replacement) {
    replacement = rewriter.create<arith::ConstantIndexOp>(
        loc, getConstantIntValue(folded).value());
  }

  // NOTE(review): `getDefiningOp()` presumably yields null when the value is
  // not produced by an operation (e.g. a block argument) -- confirm that a
  // null entry in `results` is acceptable here.
  results.push_back(replacement.getDefiningOp());
  rewriter.replaceOp(affineApplyOp, replacement);
  return DiagnosedSilenceableFailure::success();
}
|
|
|
|
//===----------------------------------------------------------------------===//
// Extension
//===----------------------------------------------------------------------===//
|
|
namespace {

/// Transform dialect extension that registers the test transform operations
/// generated into `TestTransformsOps.cpp.inc`.
class TestTransformsDialectExtension
    : public transform::TransformDialectExtension<
          TestTransformsDialectExtension> {
public:
  // Reuse the constructors of the extension base class.
  using Base::Base;

  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)

  /// Registers every operation listed by the generated file. With
  /// `GET_OP_LIST` defined, the include below expands in place to the template
  /// argument list of `registerTransformOps`, so the define/include pair must
  /// stay between the angle brackets.
  void init() {
    registerTransformOps<
#define GET_OP_LIST
#include "TestTransformsOps.cpp.inc"
        >();
  }
};

} // namespace
|
|
|
|
namespace test {
|
|
void registerTestTransformsTransformDialectExtension(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtensions<TestTransformsDialectExtension>();
|
|
}
|
|
} // namespace test
|