From 66f06f54cb4d9fda87aed346b9d5747d0bc0215e Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru
Date: Fri, 20 Mar 2026 12:56:58 -0700
Subject: [PATCH] [mlir][acc] Sink constants into acc.compute_region when creating (#187777)

When converting OpenACC compute constructs to acc.compute_region, also sink
constants inside so they do not become live-ins.
---
 .../OpenACC/Transforms/ACCComputeLowering.cpp | 32 +++++++++++++++++++
 .../OpenACC/acc-compute-lowering-compute.mlir | 24 ++++++++++++++
 2 files changed, 56 insertions(+)

diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
index e0b0acff57ca..9cc36312d361 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCComputeLowering.cpp
@@ -52,6 +52,7 @@
 #include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -126,6 +127,36 @@ static void setParDimsAttr(Operation *op, GPUParallelDimsAttr attr) {
   op->setAttr(GPUParallelDimsAttr::name, attr);
 }
 
+/// Clone defining ops of constant live-in values into `region`, rewrite uses
+/// inside the region to the clones, and remove those values from
+/// `liveInValues` so they are not threaded through `acc.compute_region` ins.
+static void materializeConstantLiveInsIntoRegion(Region &region,
+                                                 SetVector<Value> &liveInValues,
+                                                 RewriterBase &rewriter) {
+  SmallVector<Value> constantLiveIns;
+  for (Value v : liveInValues) {
+    Operation *defOp = v.getDefiningOp();
+    if (defOp && matchPattern(defOp, m_Constant())) {
+      // As per the definition of ConstantLike trait, constants must have a
+      // single result.
+      assert(defOp->getNumResults() == 1 &&
+             "constants must have a single result");
+      constantLiveIns.push_back(v);
+    }
+  }
+  if (constantLiveIns.empty())
+    return;
+
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPointToStart(&region.front());
+
+  for (Value v : constantLiveIns) {
+    Value newV = rewriter.clone(*v.getDefiningOp())->getResult(0);
+    replaceAllUsesInRegionWith(v, newV, region);
+    liveInValues.remove(v);
+  }
+}
+
 /// Insert a parallel dimension into the list, maintaining order by
 /// GPUParallelDimAttr::getOrder (descending).
 static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
@@ -320,6 +351,7 @@ public:
     Region &region = computeOp.getRegion();
     SetVector<Value> liveInValues;
     getUsedValuesDefinedAbove(region, region, liveInValues);
+    materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
     IRMapping mapping;
     auto computeRegion = buildComputeRegion(
         computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
diff --git a/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir b/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
index 77c4ba94c4f1..ee177aaf6e7a 100644
--- a/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
+++ b/mlir/test/Dialect/OpenACC/acc-compute-lowering-compute.mlir
@@ -105,3 +105,27 @@ func.func @kernels_loop(%buf: memref<8xi32>) {
   acc.copyout accPtr(%dev : memref<8xi32>) to varPtr(%buf : memref<8xi32>)
   return
 }
+
+// -----
+
+// Constant live-ins are cloned into the compute region body so they are not
+// passed through `acc.compute_region` arguments.
+
+// CHECK-LABEL: func.func @constant_livein_materialized_into_compute_region
+func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>) {
+  %c0 = arith.constant 0 : index
+  %c42 = arith.constant 42 : i32
+  %dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
+  // CHECK: acc.kernel_environment
+  // CHECK: acc.compute_region ins({{.*}}) : (memref<1xi32>) {
+  // CHECK-DAG: arith.constant 42 : i32
+  // CHECK-DAG: arith.constant 0 : index
+  // CHECK: memref.store
+  // CHECK: acc.yield
+  acc.serial dataOperands(%dev : memref<1xi32>) {
+    memref.store %c42, %dev[%c0] : memref<1xi32>
+    acc.yield
+  }
+  acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>)
+  return
+}