handle multi-block host fallback regions by wrapping them in scf.execute_region, instead of rejecting with `not yet implemented: region with multiple blocks`.
283 lines
10 KiB
C++
283 lines
10 KiB
C++
//===- ACCIfClauseLowering.cpp - Lower ACC compute construct if clauses --===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This pass lowers OpenACC compute constructs (parallel, kernels, serial) with
|
|
// `if` clauses using region specialization. It creates two execution paths:
|
|
// device execution when the condition is true, host execution when false.
|
|
//
|
|
// Overview:
|
|
// ---------
|
|
// When an ACC compute construct has an `if` clause, the construct should only
|
|
// execute on the device when the condition is true. If the condition is false,
|
|
// the code should execute on the host instead. This pass transforms:
|
|
//
|
|
// acc.parallel if(%cond) { ... }
|
|
//
|
|
// Into:
|
|
//
|
|
// scf.if %cond {
|
|
// // Device path: clone data ops, compute construct without if, exit ops
|
|
// acc.parallel { ... }
|
|
// } else {
|
|
// // Host path: original region body with ACC ops converted to host
|
|
// }
|
|
//
|
|
// Transformations:
|
|
// ----------------
|
|
// For each compute construct with an `if` clause:
|
|
//
|
|
// 1. Device Path (true branch):
|
|
// - Clone data entry operations (acc.copyin, acc.create, etc.)
|
|
// - Clone the compute construct without the `if` clause
|
|
// - Clone data exit operations (acc.copyout, acc.delete, etc.)
|
|
//
|
|
// 2. Host Path (false branch):
|
|
// - Move the original region body to the else branch
|
|
// - Apply host fallback patterns to convert ACC ops to host equivalents
|
|
//
|
|
// 3. Cleanup:
|
|
// - Erase the original compute construct and data operations
|
|
// - Replace uses of ACC variables with host variables in the else branch
|
|
//
|
|
// Requirements:
|
|
// -------------
|
|
// To use this pass in a pipeline, the following requirements exist:
|
|
//
|
|
// 1. Analysis Registration (Optional): If custom behavior is needed for
|
|
// emitting not-yet-implemented messages for unsupported cases, the pipeline
|
|
// should pre-register the `acc::OpenACCSupport` analysis.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
|
|
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
|
#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/Support/Debug.h"
|
|
|
|
namespace mlir {
|
|
namespace acc {
|
|
#define GEN_PASS_DEF_ACCIFCLAUSELOWERING
|
|
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
|
|
} // namespace acc
|
|
} // namespace mlir
|
|
|
|
#define DEBUG_TYPE "acc-if-clause-lowering"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::acc;
|
|
|
|
namespace {
|
|
|
|
class ACCIfClauseLowering
|
|
: public acc::impl::ACCIfClauseLoweringBase<ACCIfClauseLowering> {
|
|
using ACCIfClauseLoweringBase<ACCIfClauseLowering>::ACCIfClauseLoweringBase;
|
|
|
|
private:
|
|
OpenACCSupport *accSupport = nullptr;
|
|
|
|
void convertHostRegion(Operation *computeOp, Region ®ion);
|
|
|
|
template <typename OpTy>
|
|
void lowerIfClauseForComputeConstruct(OpTy computeConstructOp,
|
|
SmallVector<Operation *> &eraseOps);
|
|
|
|
public:
|
|
void runOnOperation() override;
|
|
};
|
|
|
|
void ACCIfClauseLowering::convertHostRegion(Operation *computeOp,
|
|
Region ®ion) {
|
|
// Only collect ACC dialect operations - other ops don't need conversion
|
|
SmallVector<Operation *> hostOps;
|
|
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
|
|
if (isa<acc::OpenACCDialect>(op->getDialect()))
|
|
hostOps.push_back(op);
|
|
});
|
|
|
|
RewritePatternSet patterns(computeOp->getContext());
|
|
populateACCHostFallbackPatterns(patterns, *accSupport);
|
|
|
|
GreedyRewriteConfig config;
|
|
config.setUseTopDownTraversal(true);
|
|
config.setStrictness(GreedyRewriteStrictness::ExistingOps);
|
|
if (failed(applyOpPatternsGreedily(hostOps, std::move(patterns), config)))
|
|
accSupport->emitNYI(computeOp->getLoc(), "failed to convert host region");
|
|
}
|
|
|
|
// Template function to handle if condition conversion for ACC compute
|
|
// constructs
|
|
template <typename OpTy>
|
|
void ACCIfClauseLowering::lowerIfClauseForComputeConstruct(
|
|
OpTy computeConstructOp, SmallVector<Operation *> &eraseOps) {
|
|
Value ifCond = computeConstructOp.getIfCond();
|
|
if (!ifCond)
|
|
return;
|
|
|
|
IRRewriter rewriter(computeConstructOp);
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Converting " << computeConstructOp->getName()
|
|
<< " with if condition: " << computeConstructOp
|
|
<< "\n");
|
|
|
|
// Collect data clause operations that need to be recreated in the if
|
|
// condition
|
|
SmallVector<Operation *> dataEntryOps;
|
|
SmallVector<Operation *> dataExitOps;
|
|
SmallVector<Operation *> firstprivateOps;
|
|
SmallVector<Operation *> privateOps;
|
|
SmallVector<Operation *> reductionOps;
|
|
|
|
// Collect data entry operations
|
|
for (Value operand : computeConstructOp.getDataClauseOperands())
|
|
if (Operation *defOp = operand.getDefiningOp())
|
|
if (isa<ACC_DATA_ENTRY_OPS>(defOp))
|
|
dataEntryOps.push_back(defOp);
|
|
|
|
// Find corresponding exit operations for each entry operation.
|
|
// Iterate backwards through entry ops since exit ops appear in reverse order.
|
|
for (Operation *dataEntryOp : llvm::reverse(dataEntryOps))
|
|
for (Operation *user : dataEntryOp->getUsers())
|
|
if (isa<ACC_DATA_EXIT_OPS>(user))
|
|
dataExitOps.push_back(user);
|
|
|
|
// Collect firstprivate, private, and reduction operations
|
|
auto collectOps = [&](SmallVector<Operation *> &ops, OperandRange operands) {
|
|
for (Value operand : operands)
|
|
if (Operation *defOp = operand.getDefiningOp())
|
|
ops.push_back(defOp);
|
|
};
|
|
collectOps(firstprivateOps, computeConstructOp.getFirstprivateOperands());
|
|
collectOps(privateOps, computeConstructOp.getPrivateOperands());
|
|
collectOps(reductionOps, computeConstructOp.getReductionOperands());
|
|
|
|
// Create scf.if with device and host execution paths
|
|
auto ifOp = scf::IfOp::create(rewriter, computeConstructOp.getLoc(),
|
|
TypeRange{}, ifCond, /*withElseRegion=*/true);
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Cloning " << dataEntryOps.size()
|
|
<< " data entry operations for device path\n");
|
|
|
|
// Device execution path (true branch)
|
|
Block &thenBlock = ifOp.getThenRegion().front();
|
|
rewriter.setInsertionPointToStart(&thenBlock);
|
|
|
|
// Clone data entry operations
|
|
SmallVector<Value> deviceDataOperands;
|
|
SmallVector<Value> firstprivateOperands;
|
|
SmallVector<Value> privateOperands;
|
|
SmallVector<Value> reductionOperands;
|
|
|
|
// Map the data entry and firstprivate ops for the cloned region
|
|
IRMapping deviceMapping;
|
|
auto cloneAndMapOps = [&](SmallVector<Operation *> &ops,
|
|
SmallVector<Value> &operands) {
|
|
for (Operation *op : ops) {
|
|
Operation *clonedOp = rewriter.clone(*op, deviceMapping);
|
|
operands.push_back(clonedOp->getResult(0));
|
|
deviceMapping.map(op->getResult(0), clonedOp->getResult(0));
|
|
}
|
|
};
|
|
cloneAndMapOps(dataEntryOps, deviceDataOperands);
|
|
cloneAndMapOps(firstprivateOps, firstprivateOperands);
|
|
cloneAndMapOps(privateOps, privateOperands);
|
|
cloneAndMapOps(reductionOps, reductionOperands);
|
|
|
|
// Create new compute op without if condition for device execution by
|
|
// cloning
|
|
OpTy newComputeOp = cast<OpTy>(
|
|
rewriter.clone(*computeConstructOp.getOperation(), deviceMapping));
|
|
newComputeOp.getIfCondMutable().clear();
|
|
newComputeOp.getDataClauseOperandsMutable().assign(deviceDataOperands);
|
|
newComputeOp.getFirstprivateOperandsMutable().assign(firstprivateOperands);
|
|
newComputeOp.getPrivateOperandsMutable().assign(privateOperands);
|
|
newComputeOp.getReductionOperandsMutable().assign(reductionOperands);
|
|
|
|
// Clone data exit operations
|
|
rewriter.setInsertionPointAfter(newComputeOp);
|
|
for (Operation *dataOp : dataExitOps)
|
|
rewriter.clone(*dataOp, deviceMapping);
|
|
|
|
rewriter.setInsertionPointToEnd(&thenBlock);
|
|
if (!thenBlock.getTerminator())
|
|
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
|
|
|
|
// Host execution path (false branch)
|
|
Region &hostRegion = computeConstructOp.getRegion();
|
|
if (hostRegion.hasOneBlock()) {
|
|
// Don't need to clone original ops, just take them and legalize for host.
|
|
ifOp.getElseRegion().takeBody(hostRegion);
|
|
|
|
// Swap acc yield for scf yield.
|
|
Block &elseBlock = ifOp.getElseRegion().front();
|
|
elseBlock.getTerminator()->erase();
|
|
rewriter.setInsertionPointToEnd(&elseBlock);
|
|
scf::YieldOp::create(rewriter, computeConstructOp.getLoc());
|
|
|
|
convertHostRegion(computeConstructOp, ifOp.getElseRegion());
|
|
} else {
|
|
// scf.if regions must stay single-block. Wrap the original multi-block ACC
|
|
// body in scf.execute_region so it can be hosted in the else branch.
|
|
Block &elseBlock = ifOp.getElseRegion().front();
|
|
rewriter.setInsertionPoint(elseBlock.getTerminator());
|
|
IRMapping hostMapping;
|
|
auto hostExecuteRegion = wrapMultiBlockRegionWithSCFExecuteRegion(
|
|
hostRegion, hostMapping, computeConstructOp.getLoc(), rewriter);
|
|
convertHostRegion(computeConstructOp, hostExecuteRegion.getRegion());
|
|
}
|
|
|
|
// The original op is now empty and can be erased
|
|
eraseOps.push_back(computeConstructOp);
|
|
|
|
// TODO: Can probably 'move' the data ops instead of cloning them
|
|
// which would eliminate need to explicitly erase
|
|
for (Operation *dataOp : dataExitOps)
|
|
eraseOps.push_back(dataOp);
|
|
|
|
// The new host code may contain uses of the acc variables. Replace them by
|
|
// the host values.
|
|
auto replaceAndEraseOps = [&](SmallVector<Operation *> &ops) {
|
|
for (Operation *op : ops) {
|
|
getAccVar(op).replaceAllUsesWith(getVar(op));
|
|
eraseOps.push_back(op);
|
|
}
|
|
};
|
|
replaceAndEraseOps(dataEntryOps);
|
|
replaceAndEraseOps(firstprivateOps);
|
|
replaceAndEraseOps(privateOps);
|
|
replaceAndEraseOps(reductionOps);
|
|
}
|
|
|
|
void ACCIfClauseLowering::runOnOperation() {
|
|
func::FuncOp funcOp = getOperation();
|
|
accSupport = &getAnalysis<OpenACCSupport>();
|
|
|
|
SmallVector<Operation *> eraseOps;
|
|
funcOp.walk([&](Operation *op) {
|
|
if (auto parallelOp = dyn_cast<acc::ParallelOp>(op))
|
|
lowerIfClauseForComputeConstruct(parallelOp, eraseOps);
|
|
else if (auto kernelsOp = dyn_cast<acc::KernelsOp>(op))
|
|
lowerIfClauseForComputeConstruct(kernelsOp, eraseOps);
|
|
else if (auto serialOp = dyn_cast<acc::SerialOp>(op))
|
|
lowerIfClauseForComputeConstruct(serialOp, eraseOps);
|
|
});
|
|
|
|
for (Operation *op : eraseOps)
|
|
op->erase();
|
|
}
|
|
|
|
} // namespace
|