llvm-project/mlir/test/lib/Transforms/TestTransformsOps.cpp
Fabian Mora 8f4da2cbf0
[mlir][affine] Fix min simplification in makeComposedAffineApply (#145376)
This patch fixes a bug discovered in the
`affine::makeComposedFoldedAffineApply` function when `composeAffineMin
== true`. The bug happened because the simplification assumed the
symbols appearing in the `affine.apply` op corresponded to symbols in
the `affine.min` op, and that's not always the case. For example:

```mlir
#map = affine_map<()[s0, s1] -> (s1)>
#map1 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
module {
  func.func @min_max_full_simplify() -> index {
    %0 = test.value_with_bounds {max = 64 : index, min = 32 : index}
    %1 = test.value_with_bounds {max = 64 : index, min = 32 : index}
    %2 = affine.min #map()[%0, %1]
    %3 = affine.apply #map1()[%2, %0]
    return %3 : index
  }
}
```

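Here the `affine.apply` binds its `s1` to `%0`, which is `s0` of the
`affine.min`, so the apply's symbols do not line up with the min's symbols.

This patch also introduces the `make_composed_folded_affine_apply` test
transform operation to exercise this simplification, and adds tests that
check the simplified results are correct.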

---------

Co-authored-by: Nicolas Vasilache <nico.vasilache@amd.com>
2025-06-24 07:55:12 -04:00

113 lines
4.1 KiB
C++

//===- TestTransformsOps.cpp - Test Transforms ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines transform dialect operations for testing MLIR
// transformations
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Transforms/RegionUtils.h"
#define GET_OP_CLASSES
#include "TestTransformsOps.h.inc"
using namespace mlir;
using namespace mlir::transform;
#define GET_OP_CLASSES
#include "TestTransformsOps.cpp.inc"
/// Test driver for `moveOperationDependencies`.
///
/// This driver does the following:
/// - Takes the first payload op of the `op` handle and the first payload op of
///   the insertion-point handle.
/// - Calls `moveOperationDependencies(rewriter, op, moveBefore)`.
/// - If that call fails, does not propagate the failure. Instead, it emits the
///   latest match-failure message recorded by the rewriter's
///   `ErrorCheckingTrackingListener` as a remark on this transform op, and the
///   transform still succeeds.
///
/// If either handle maps to no payload op, this is reported as a silenceable
/// failure. Dereferencing the first element of an empty range is undefined
/// behavior, so that case is never reached.
DiagnosedSilenceableFailure
transform::TestMoveOperandDeps::apply(TransformRewriter &rewriter,
                                      TransformResults &transformResults,
                                      TransformState &state) {
  auto opRange = state.getPayloadOps(getOp());
  auto insertionPointRange = state.getPayloadOps(getInsertionPoint());

  // `*range.begin()` on an empty range is undefined behavior; diagnose instead.
  if (opRange.begin() == opRange.end())
    return emitSilenceableError()
           << "expected the op handle to map to at least one payload op";
  if (insertionPointRange.begin() == insertionPointRange.end())
    return emitSilenceableError() << "expected the insertion point handle to "
                                     "map to at least one payload op";

  Operation *op = *opRange.begin();
  Operation *moveBefore = *insertionPointRange.begin();
  if (failed(moveOperationDependencies(rewriter, op, moveBefore))) {
    // Report the reason recorded by the tracking listener as a remark rather
    // than failing the transform.
    auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
    std::string errorMsg = listener->getLatestMatchFailureMessage();
    (void)emitRemark(errorMsg);
  }
  return DiagnosedSilenceableFailure::success();
}
/// Test driver for `moveValueDefinitions`.
///
/// This driver does the following:
/// - For every transform value handle in `values`, takes the first payload
///   value it maps to.
/// - Takes the first payload op of the insertion-point handle.
/// - Calls `moveValueDefinitions(rewriter, values, moveBefore)`.
/// - If that call fails, does not propagate the failure. Instead, it emits the
///   latest match-failure message recorded by the rewriter's
///   `ErrorCheckingTrackingListener` as a remark on this transform op, and the
///   transform still succeeds.
///
/// If any handle maps to nothing, this is reported as a silenceable failure.
/// Dereferencing the first element of an empty range is undefined behavior, so
/// that case is never reached.
DiagnosedSilenceableFailure
transform::TestMoveValueDefns::apply(TransformRewriter &rewriter,
                                     TransformResults &transformResults,
                                     TransformState &state) {
  SmallVector<Value> values;
  values.reserve(getValues().size());
  for (Value tdValue : getValues()) {
    auto payloadValues = state.getPayloadValues(tdValue);
    // `*range.begin()` on an empty range is undefined behavior.
    if (payloadValues.begin() == payloadValues.end())
      return emitSilenceableError()
             << "expected each value handle to map to at least one payload "
                "value";
    values.push_back(*payloadValues.begin());
  }

  auto insertionPointRange = state.getPayloadOps(getInsertionPoint());
  if (insertionPointRange.begin() == insertionPointRange.end())
    return emitSilenceableError() << "expected the insertion point handle to "
                                     "map to at least one payload op";
  Operation *moveBefore = *insertionPointRange.begin();

  if (failed(moveValueDefinitions(rewriter, values, moveBefore))) {
    // Report the reason recorded by the tracking listener as a remark rather
    // than failing the transform.
    auto listener = cast<ErrorCheckingTrackingListener>(rewriter.getListener());
    std::string errorMsg = listener->getLatestMatchFailureMessage();
    (void)emitRemark(errorMsg);
  }
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Test affine functionality.
//===----------------------------------------------------------------------===//
/// Rebuilds `affineApplyOp` and replaces the original op with the result.
///
/// The op is rebuilt through `affine::makeComposedFoldedAffineApply` using the
/// op's own affine map and operands, with `composeAffineMin` enabled. The op
/// that defines the replacement value is appended to `results`.
///
/// If the helper returns an attribute rather than a value, the attribute's
/// integer value is materialized as an `arith::ConstantIndexOp`. This ensures
/// there is always a Value to replace the original op with.
DiagnosedSilenceableFailure
transform::TestMakeComposedFoldedAffineApply::applyToOne(
    TransformRewriter &rewriter, affine::AffineApplyOp affineApplyOp,
    ApplyToEachResultList &results, TransformState &state) {
  Location loc = affineApplyOp.getLoc();
  SmallVector<OpFoldResult> foldOperands =
      getAsOpFoldResult(affineApplyOp.getOperands());
  OpFoldResult folded = affine::makeComposedFoldedAffineApply(
      rewriter, loc, affineApplyOp.getAffineMap(), foldOperands,
      /*composeAffineMin=*/true);

  // An attribute result has no SSA value yet; create an index constant for it.
  Value replacement = dyn_cast<Value>(folded);
  if (!replacement)
    replacement = rewriter.create<arith::ConstantIndexOp>(
        loc, getConstantIntValue(folded).value());

  results.push_back(replacement.getDefiningOp());
  rewriter.replaceOp(affineApplyOp, replacement);
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Extension
//===----------------------------------------------------------------------===//
namespace {
/// Transform dialect extension that registers the test transform ops of this
/// file (the ops whose definitions are included from
/// "TestTransformsOps.cpp.inc").
class TestTransformsDialectExtension
: public transform::TransformDialectExtension<
TestTransformsDialectExtension> {
public:
// Gives this extension class an explicit TypeID.
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTransformsDialectExtension)
// Inherit the constructors of the base extension class.
using Base::Base;
/// Registers the generated test transform ops with the extension.
void init() {
// NOTE: the generated .cpp.inc is included with GET_OP_LIST defined, so it
// presumably expands to the comma-separated list of generated op classes
// (the usual mlir-tblgen convention). The #define must stay immediately
// before the #include inside the template argument list.
registerTransformOps<
#define GET_OP_LIST
#include "TestTransformsOps.cpp.inc"
>();
}
};
} // namespace
namespace test {
/// Adds `TestTransformsDialectExtension`, the extension defined in this file
/// that carries the test transform ops, to `registry`.
void registerTestTransformsTransformDialectExtension(
    DialectRegistry &registry) {
  using Extension = TestTransformsDialectExtension;
  registry.addExtensions<Extension>();
}
} // namespace test