[mlir][acc] Sink constants into acc.compute_region when creating (#187777)

When converting OpenACC compute constructs to acc.compute_region, also
sink constants inside so they do not become live-ins.
This commit is contained in:
Razvan Lupusoru 2026-03-20 12:56:58 -07:00 committed by GitHub
parent bd3b06b0a7
commit 66f06f54cb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 56 additions and 0 deletions

View File

@ -52,6 +52,7 @@
#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
@ -126,6 +127,36 @@ static void setParDimsAttr(Operation *op, GPUParallelDimsAttr attr) {
op->setAttr(GPUParallelDimsAttr::name, attr);
}
/// Sink constant live-ins into `region`.
///
/// For every value in `liveInValues` whose defining op matches `m_Constant()`,
/// clone that op at the start of the first block of `region`, redirect the
/// uses nested inside `region` to the clone, and drop the value from
/// `liveInValues` so it is not threaded through `acc.compute_region` ins.
/// The original constant op is left in place; uses outside `region` keep
/// referring to it.
///
/// \param region       Region that receives the cloned constants.
/// \param liveInValues Candidate live-in values for `region`; constants that
///                     were sunk are removed from this set in place.
/// \param rewriter     Rewriter used to clone the constants and to update
///                     their uses. Its insertion point is restored on return.
static void materializeConstantLiveInsIntoRegion(Region &region,
                                                 SetVector<Value> &liveInValues,
                                                 RewriterBase &rewriter) {
  // Collect the constants first: `liveInValues` is mutated below and must not
  // be modified while it is being iterated.
  SmallVector<Value> constantLiveIns;
  for (Value v : liveInValues) {
    Operation *defOp = v.getDefiningOp();
    // Block arguments have no defining op and can never be constants.
    if (!defOp || !matchPattern(defOp, m_Constant()))
      continue;
    // As per the definition of ConstantLike trait, constants must have a
    // single result.
    assert(defOp->getNumResults() == 1 &&
           "constants must have a single result");
    constantLiveIns.push_back(v);
  }
  if (constantLiveIns.empty())
    return;

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(&region.front());
  for (Value v : constantLiveIns) {
    // Result 0 is the only result of the clone (see the assertion above).
    Value newV = rewriter.clone(*v.getDefiningOp())->getResult(0);
    // Update the uses through the rewriter rather than mutating operands
    // directly, so that any listener attached to it is notified of the
    // in-place modifications. Only uses nested within `region` are redirected.
    rewriter.replaceUsesWithIf(v, newV, [&](OpOperand &use) {
      return region.isAncestor(use.getOwner()->getParentRegion());
    });
    liveInValues.remove(v);
  }
}
/// Insert a parallel dimension into the list, maintaining order by
/// GPUParallelDimAttr::getOrder (descending).
static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
@ -320,6 +351,7 @@ public:
Region &region = computeOp.getRegion();
SetVector<Value> liveInValues;
getUsedValuesDefinedAbove(region, region, liveInValues);
materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
IRMapping mapping;
auto computeRegion = buildComputeRegion(
computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),

View File

@ -105,3 +105,27 @@ func.func @kernels_loop(%buf: memref<8xi32>) {
acc.copyout accPtr(%dev : memref<8xi32>) to varPtr(%buf : memref<8xi32>)
return
}
// -----
// Constant live-ins are cloned into the compute region body so they are not
// passed through `acc.compute_region` arguments.
//
// Input: `%c0` and `%c42` are defined outside `acc.serial` and are only used
// by the `memref.store` inside it; `%dev` is the single data operand.
// Expected output: the `acc.compute_region` takes exactly one input of type
// `memref<1xi32>`, and both constants appear inside its body (in either
// order) ahead of the store.
// CHECK-LABEL: func.func @constant_livein_materialized_into_compute_region
func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>) {
// Constants defined above the compute construct.
%c0 = arith.constant 0 : index
%c42 = arith.constant 42 : i32
%dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
// CHECK: acc.kernel_environment
// The `ins` type list holds only the memref, i.e. no constant is passed in.
// CHECK: acc.compute_region ins({{.*}}) : (memref<1xi32>) {
// Both constants are expected after the region header; their relative order
// is not constrained.
// CHECK-DAG: arith.constant 42 : i32
// CHECK-DAG: arith.constant 0 : index
// CHECK: memref.store
// CHECK: acc.yield
acc.serial dataOperands(%dev : memref<1xi32>) {
// Uses both outer constants, which makes them live-ins of the region.
memref.store %c42, %dev[%c0] : memref<1xi32>
acc.yield
}
acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>)
return
}