[mlir][acc] Sink constants into acc.compute_region when creating (#187777)
When converting OpenACC compute constructs to acc.compute_region, also sink constants inside so they do not become live-ins.
This commit is contained in:
parent
bd3b06b0a7
commit
66f06f54cb
@ -52,6 +52,7 @@
|
||||
#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
|
||||
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
@ -126,6 +127,36 @@ static void setParDimsAttr(Operation *op, GPUParallelDimsAttr attr) {
|
||||
op->setAttr(GPUParallelDimsAttr::name, attr);
|
||||
}
|
||||
|
||||
/// Clone defining ops of constant live-in values into `region`, rewrite uses
|
||||
/// inside the region to the clones, and remove those values from
|
||||
/// `liveInValues` so they are not threaded through `acc.compute_region` ins.
|
||||
static void materializeConstantLiveInsIntoRegion(Region ®ion,
|
||||
SetVector<Value> &liveInValues,
|
||||
RewriterBase &rewriter) {
|
||||
SmallVector<Value> constantLiveIns;
|
||||
for (Value v : liveInValues) {
|
||||
Operation *defOp = v.getDefiningOp();
|
||||
if (defOp && matchPattern(defOp, m_Constant())) {
|
||||
// As per the definition of ConstantLike trait, constants must have a
|
||||
// single result.
|
||||
assert(defOp->getNumResults() == 1 &&
|
||||
"constants must have a single result");
|
||||
constantLiveIns.push_back(v);
|
||||
}
|
||||
}
|
||||
if (constantLiveIns.empty())
|
||||
return;
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(®ion.front());
|
||||
|
||||
for (Value v : constantLiveIns) {
|
||||
Value newV = rewriter.clone(*v.getDefiningOp())->getResult(0);
|
||||
replaceAllUsesInRegionWith(v, newV, region);
|
||||
liveInValues.remove(v);
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert a parallel dimension into the list, maintaining order by
|
||||
/// GPUParallelDimAttr::getOrder (descending).
|
||||
static void insertParDim(SmallVectorImpl<GPUParallelDimAttr> &parDims,
|
||||
@ -320,6 +351,7 @@ public:
|
||||
Region &region = computeOp.getRegion();
|
||||
SetVector<Value> liveInValues;
|
||||
getUsedValuesDefinedAbove(region, region, liveInValues);
|
||||
materializeConstantLiveInsIntoRegion(region, liveInValues, rewriter);
|
||||
IRMapping mapping;
|
||||
auto computeRegion = buildComputeRegion(
|
||||
computeOp->getLoc(), launchArgs, liveInValues.getArrayRef(),
|
||||
|
||||
@ -105,3 +105,27 @@ func.func @kernels_loop(%buf: memref<8xi32>) {
|
||||
acc.copyout accPtr(%dev : memref<8xi32>) to varPtr(%buf : memref<8xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Constant live-ins are cloned into the compute region body so they are not
|
||||
// passed through `acc.compute_region` arguments.
|
||||
|
||||
// CHECK-LABEL: func.func @constant_livein_materialized_into_compute_region
|
||||
// Test input: two constants are defined at function scope and used only
// inside the acc.serial body, so both are values defined above that region.
// The directives below expect the resulting acc.compute_region to take a
// single memref<1xi32> operand and to contain both constants (in either
// order) ahead of the store.
func.func @constant_livein_materialized_into_compute_region(%buf: memref<1xi32>) {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c42 = arith.constant 42 : i32
|
||||
%dev = acc.copyin varPtr(%buf : memref<1xi32>) -> memref<1xi32>
|
||||
// CHECK: acc.kernel_environment
|
||||
// CHECK: acc.compute_region ins({{.*}}) : (memref<1xi32>) {
|
||||
// CHECK-DAG: arith.constant 42 : i32
|
||||
// CHECK-DAG: arith.constant 0 : index
|
||||
// CHECK: memref.store
|
||||
// CHECK: acc.yield
|
||||
acc.serial dataOperands(%dev : memref<1xi32>) {
|
||||
memref.store %c42, %dev[%c0] : memref<1xi32>
|
||||
acc.yield
|
||||
}
|
||||
acc.copyout accPtr(%dev : memref<1xi32>) to varPtr(%buf : memref<1xi32>)
|
||||
return
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user