[mlir][acc] Add ACCSpecializeForDevice and ACCSpecializeForHost passes Add two new transformation passes for specializing OpenACC IR for different execution contexts: ACCSpecializeForDevice: - Strips OpenACC constructs that are invalid in device code - Replaces data entry ops with their var operands - Unwraps regions from compute/data constructs - Erases runtime operations (init, shutdown, wait, etc.) This pass is applicable in two contexts: 1. Functions marked with the `acc.specialized_routine` attribute, where the entire function body is device code 2. Non-specialized functions, where patterns are applied only to `acc` operations nested inside compute constructs (parallel, serial, kernels), not to the constructs themselves ACCSpecializeForHost: - Converts orphan `acc` operations for host execution - Transforms `acc.atomic.*` to load/store via `PointerLikeType` interface - Converts `acc.loop` to `scf.for` or `scf.execute_region` - Replaces orphan data entry ops with their var operands This pass operates in two modes: 1. Default (orphan) mode: Only converts `acc` operations that are not inside or attached to compute regions. Used for host `acc routine`s where compute constructs should be preserved. 2. Host fallback mode (enable-host-fallback=true): Converts ALL `acc` operations including compute constructs, data regions, and runtime ops. This is used to allow testing of the full conversion. These patterns will be used to handle conditional host execution of `acc` regions with an `if` clause. The pattern population functions (populateACCSpecializeForDevicePatterns, populateACCOrphanToHostPatterns, populateACCHostFallbackPatterns) are exposed so other passes can reuse these patterns. --------- Co-authored-by: Susan Tan <zujunt@nvidia.com> Co-authored-by: Scott Manley <rscottmanley@gmail.com>
173 lines
6.9 KiB
C++
173 lines
6.9 KiB
C++
//===- ACCSpecializeForDevice.cpp -----------------------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This pass strips OpenACC constructs that are invalid or unnecessary inside
|
|
// device code (specialized acc routines or compute construct regions).
|
|
//
|
|
// Overview:
|
|
// ---------
|
|
// In a specialized acc routine or compute construct, many OpenACC operations
|
|
// do not make sense because they are host-side constructs. This pass removes
|
|
// or transforms these operations appropriately:
|
|
//
|
|
// - Data operations that manage device memory from host perspective
|
|
// - Compute constructs that launch kernels (we're already on device)
|
|
// - Runtime operations like init/shutdown/set/wait
|
|
//
|
|
// Transformations:
|
|
// ----------------
|
|
// The pass applies the following transformations:
|
|
//
|
|
// 1. Data Entry Ops (replaced with var operand):
|
|
// acc.attach, acc.copyin, acc.create, acc.declare_device_resident,
|
|
// acc.declare_link, acc.deviceptr, acc.get_deviceptr, acc.nocreate,
|
|
// acc.present, acc.update_device, acc.use_device
|
|
//
|
|
// 2. Data Exit Ops (erased):
|
|
// acc.copyout, acc.delete, acc.detach, acc.update_host
|
|
//
|
|
// 3. Structured Data/Compute Constructs (region inlined):
|
|
// acc.data, acc.host_data, acc.kernel_environment, acc.parallel,
|
|
// acc.serial, acc.kernels
|
|
//
|
|
// 4. Unstructured Data Ops (erased):
|
|
// acc.enter_data, acc.exit_data, acc.update, acc.declare_enter,
|
|
// acc.declare_exit
|
|
//
|
|
// 5. Runtime Ops (erased):
|
|
// acc.init, acc.shutdown, acc.set, acc.wait
|
|
//
|
|
// Scope of Application:
|
|
// ---------------------
|
|
// - For functions with `acc.specialized_routine` attribute: patterns are
|
|
// applied to the entire function body.
|
|
// - For non-specialized functions: patterns are applied only to ACC
|
|
// operations INSIDE compute constructs (parallel, serial, kernels),
|
|
// not to the compute constructs themselves or their data operands.
|
|
//
|
|
// Note: acc.cache, acc.private, acc.reduction, acc.firstprivate are NOT
|
|
// transformed by this pass as they are valid in device code.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
|
|
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
|
#include "mlir/Dialect/OpenACC/Transforms/ACCSpecializePatterns.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
namespace mlir {
|
|
namespace acc {
|
|
#define GEN_PASS_DEF_ACCSPECIALIZEFORDEVICE
|
|
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
|
|
} // namespace acc
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::acc;
|
|
|
|
namespace {
|
|
|
|
class ACCSpecializeForDevice
|
|
: public acc::impl::ACCSpecializeForDeviceBase<ACCSpecializeForDevice> {
|
|
public:
|
|
using ACCSpecializeForDeviceBase<
|
|
ACCSpecializeForDevice>::ACCSpecializeForDeviceBase;
|
|
|
|
void runOnOperation() override {
|
|
func::FuncOp func = getOperation();
|
|
|
|
RewritePatternSet patterns(&getContext());
|
|
acc::populateACCSpecializeForDevicePatterns(patterns);
|
|
GreedyRewriteConfig config;
|
|
config.setUseTopDownTraversal(true);
|
|
|
|
if (acc::isSpecializedAccRoutine(func)) {
|
|
// For specialized acc routines, apply patterns to the entire function
|
|
(void)applyPatternsGreedily(func, std::move(patterns), config);
|
|
} else {
|
|
// For non-specialized functions, apply patterns only to ACC operations
|
|
// inside compute constructs (not to the compute constructs themselves).
|
|
SmallVector<Operation *> opsToTransform;
|
|
func.walk([&](Operation *op) {
|
|
if (isa<ACC_COMPUTE_CONSTRUCT_OPS>(op)) {
|
|
// Walk inside the compute construct and collect ACC ops
|
|
op->walk([&](Operation *innerOp) {
|
|
// Skip the compute construct itself
|
|
if (innerOp == op)
|
|
return;
|
|
if (isa<acc::OpenACCDialect>(innerOp->getDialect()))
|
|
opsToTransform.push_back(innerOp);
|
|
});
|
|
}
|
|
});
|
|
if (!opsToTransform.empty())
|
|
(void)applyOpPatternsGreedily(opsToTransform, std::move(patterns),
|
|
config);
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Pattern population functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Adds to `patterns` the rewrites that remove host-side OpenACC operations
/// from device code.
///
/// - Data entry ops are replaced by their `var` operand.
/// - Data exit, unstructured data and runtime ops are erased.
/// - Structured data and compute constructs have their regions inlined.
/// - acc.declare_enter is erased together with its associated
///   acc.declare_exit.
void mlir::acc::populateACCSpecializeForDevicePatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();

  // acc.declare_enter, plus the acc.declare_exit associated with it.
  patterns.add<ACCDeclareEnterOpConversion>(ctx);

  // Data entry ops: each is replaced by its `var` operand. acc.cache,
  // acc.private, acc.reduction and acc.firstprivate are deliberately absent
  // because they remain valid in device code.
  patterns.add<ACCOpReplaceWithVarConversion<acc::AttachOp>,
               ACCOpReplaceWithVarConversion<acc::CopyinOp>,
               ACCOpReplaceWithVarConversion<acc::CreateOp>,
               ACCOpReplaceWithVarConversion<acc::DeclareDeviceResidentOp>,
               ACCOpReplaceWithVarConversion<acc::DeclareLinkOp>,
               ACCOpReplaceWithVarConversion<acc::DevicePtrOp>,
               ACCOpReplaceWithVarConversion<acc::GetDevicePtrOp>,
               ACCOpReplaceWithVarConversion<acc::NoCreateOp>,
               ACCOpReplaceWithVarConversion<acc::PresentOp>,
               ACCOpReplaceWithVarConversion<acc::UpdateDeviceOp>,
               ACCOpReplaceWithVarConversion<acc::UseDeviceOp>>(ctx);

  // Data exit ops produce no results and are simply dropped.
  patterns.add<ACCOpEraseConversion<acc::CopyoutOp>,
               ACCOpEraseConversion<acc::DeleteOp>,
               ACCOpEraseConversion<acc::DetachOp>,
               ACCOpEraseConversion<acc::UpdateHostOp>>(ctx);

  // Structured data constructs followed by compute constructs: the region
  // body is inlined in place of the op.
  patterns.add<ACCRegionUnwrapConversion<acc::DataOp>,
               ACCRegionUnwrapConversion<acc::HostDataOp>,
               ACCRegionUnwrapConversion<acc::KernelEnvironmentOp>,
               ACCRegionUnwrapConversion<acc::ParallelOp>,
               ACCRegionUnwrapConversion<acc::SerialOp>,
               ACCRegionUnwrapConversion<acc::KernelsOp>>(ctx);

  // Unstructured data ops followed by runtime ops: dropped entirely.
  patterns.add<ACCOpEraseConversion<acc::EnterDataOp>,
               ACCOpEraseConversion<acc::ExitDataOp>,
               ACCOpEraseConversion<acc::UpdateOp>,
               ACCOpEraseConversion<acc::InitOp>,
               ACCOpEraseConversion<acc::ShutdownOp>,
               ACCOpEraseConversion<acc::SetOp>,
               ACCOpEraseConversion<acc::WaitOp>>(ctx);
}
|