Part of https://discourse.llvm.org/t/rfc-switching-the-llvm-dialect-and-dialect-lowerings-to-opaque-pointers/68179. This patch adds the pass option `use-opaque-pointers`, which instructs the pass to use opaque pointers instead of typed pointers during conversion. The pass itself needs no pointer-specific changes, so the only change required is setting the option in the type converter. The tests have also been converted to use opaque pointers. Differential Revision: https://reviews.llvm.org/D144462
245 lines
10 KiB
C++
245 lines
10 KiB
C++
//===- OpenACCToLLVM.cpp - Prepare OpenACC data for LLVM translation ------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/OpenACCToLLVM/ConvertOpenACCToLLVM.h"
|
|
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/OpenACC/OpenACC.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTOPENACCTOLLVMPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// DataDescriptor implementation
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Base name of the identified LLVM struct type used to model an OpenACC data
/// descriptor. Identified struct types created from it may carry a uniquing
/// suffix, which is why `isValid` matches this name as a prefix.
constexpr StringRef getStructName() {
  return StringRef("openacc_data");
}
|
|
|
|
/// Construct a helper for the given descriptor value.
|
|
DataDescriptor::DataDescriptor(Value descriptor) : StructBuilder(descriptor) {
|
|
assert(value != nullptr && "value cannot be null");
|
|
}
|
|
|
|
/// Builds IR creating an `undef` value of the data descriptor.
|
|
DataDescriptor DataDescriptor::undef(OpBuilder &builder, Location loc,
|
|
Type basePtrTy, Type ptrTy) {
|
|
Type descriptorType = LLVM::LLVMStructType::getNewIdentified(
|
|
builder.getContext(), getStructName(),
|
|
{basePtrTy, ptrTy, builder.getI64Type()});
|
|
Value descriptor = builder.create<LLVM::UndefOp>(loc, descriptorType);
|
|
return DataDescriptor(descriptor);
|
|
}
|
|
|
|
/// Check whether the type is a valid data descriptor.
|
|
bool DataDescriptor::isValid(Value descriptor) {
|
|
if (auto type = descriptor.getType().dyn_cast<LLVM::LLVMStructType>()) {
|
|
if (type.isIdentified() && type.getName().startswith(getStructName()) &&
|
|
type.getBody().size() == 3 &&
|
|
(type.getBody()[kPtrBasePosInDataDescriptor]
|
|
.isa<LLVM::LLVMPointerType>() ||
|
|
type.getBody()[kPtrBasePosInDataDescriptor]
|
|
.isa<LLVM::LLVMStructType>()) &&
|
|
type.getBody()[kPtrPosInDataDescriptor].isa<LLVM::LLVMPointerType>() &&
|
|
type.getBody()[kSizePosInDataDescriptor].isInteger(64))
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/// Builds IR inserting the base pointer value into the descriptor.
|
|
void DataDescriptor::setBasePointer(OpBuilder &builder, Location loc,
|
|
Value basePtr) {
|
|
setPtr(builder, loc, kPtrBasePosInDataDescriptor, basePtr);
|
|
}
|
|
|
|
/// Builds IR inserting the pointer value into the descriptor.
|
|
void DataDescriptor::setPointer(OpBuilder &builder, Location loc, Value ptr) {
|
|
setPtr(builder, loc, kPtrPosInDataDescriptor, ptr);
|
|
}
|
|
|
|
/// Builds IR inserting the size value into the descriptor.
|
|
void DataDescriptor::setSize(OpBuilder &builder, Location loc, Value size) {
|
|
setPtr(builder, loc, kSizePosInDataDescriptor, size);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Conversion patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
template <typename Op>
|
|
class LegalizeDataOpForLLVMTranslation : public ConvertOpToLLVMPattern<Op> {
|
|
using ConvertOpToLLVMPattern<Op>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
|
ConversionPatternRewriter &builder) const override {
|
|
Location loc = op.getLoc();
|
|
TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
|
|
|
|
unsigned numDataOperand = op.getNumDataOperands();
|
|
|
|
// Keep the non data operands without modification.
|
|
auto nonDataOperands = adaptor.getOperands().take_front(
|
|
adaptor.getOperands().size() - numDataOperand);
|
|
SmallVector<Value> convertedOperands;
|
|
convertedOperands.append(nonDataOperands.begin(), nonDataOperands.end());
|
|
|
|
// Go over the data operand and legalize them for translation.
|
|
for (unsigned idx = 0; idx < numDataOperand; ++idx) {
|
|
Value originalDataOperand = op.getDataOperand(idx);
|
|
|
|
// Traverse operands that were converted to MemRefDescriptors.
|
|
if (auto memRefType =
|
|
originalDataOperand.getType().dyn_cast<MemRefType>()) {
|
|
Type structType = converter->convertType(memRefType);
|
|
Value memRefDescriptor = builder
|
|
.create<UnrealizedConversionCastOp>(
|
|
loc, structType, originalDataOperand)
|
|
.getResult(0);
|
|
|
|
// Calculate the size of the memref and get the pointer to the allocated
|
|
// buffer.
|
|
SmallVector<Value> sizes;
|
|
SmallVector<Value> strides;
|
|
Value sizeBytes;
|
|
ConvertToLLVMPattern::getMemRefDescriptorSizes(
|
|
loc, memRefType, {}, builder, sizes, strides, sizeBytes);
|
|
MemRefDescriptor descriptor(memRefDescriptor);
|
|
Value dataPtr = descriptor.alignedPtr(builder, loc);
|
|
auto ptrType = descriptor.getElementPtrType();
|
|
|
|
auto descr = DataDescriptor::undef(builder, loc, structType, ptrType);
|
|
descr.setBasePointer(builder, loc, memRefDescriptor);
|
|
descr.setPointer(builder, loc, dataPtr);
|
|
descr.setSize(builder, loc, sizeBytes);
|
|
convertedOperands.push_back(descr);
|
|
} else if (originalDataOperand.getType().isa<LLVM::LLVMPointerType>()) {
|
|
convertedOperands.push_back(originalDataOperand);
|
|
} else {
|
|
// Type not supported.
|
|
return builder.notifyMatchFailure(op, "unsupported type");
|
|
}
|
|
}
|
|
|
|
builder.replaceOpWithNewOp<Op>(op, TypeRange(), convertedOperands,
|
|
op.getOperation()->getAttrs());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers, using `converter`, the patterns that legalize the data operands
/// of acc::DataOp, acc::EnterDataOp, acc::ExitDataOp, acc::ParallelOp and
/// acc::UpdateOp for LLVM IR translation.
void mlir::populateOpenACCToLLVMConversionPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  patterns.add<LegalizeDataOpForLLVMTranslation<acc::DataOp>,
               LegalizeDataOpForLLVMTranslation<acc::EnterDataOp>,
               LegalizeDataOpForLLVMTranslation<acc::ExitDataOp>,
               LegalizeDataOpForLLVMTranslation<acc::ParallelOp>,
               LegalizeDataOpForLLVMTranslation<acc::UpdateOp>>(converter);
}
|
|
|
|
namespace {
|
|
struct ConvertOpenACCToLLVMPass
|
|
: public impl::ConvertOpenACCToLLVMPassBase<ConvertOpenACCToLLVMPass> {
|
|
using Base::Base;
|
|
|
|
void runOnOperation() override;
|
|
};
|
|
} // namespace
|
|
|
|
void ConvertOpenACCToLLVMPass::runOnOperation() {
|
|
auto op = getOperation();
|
|
auto *context = op.getContext();
|
|
|
|
// Convert to OpenACC operations with LLVM IR dialect
|
|
RewritePatternSet patterns(context);
|
|
LowerToLLVMOptions options(context);
|
|
options.useOpaquePointers = useOpaquePointers;
|
|
LLVMTypeConverter converter(context, options);
|
|
populateOpenACCToLLVMConversionPatterns(converter, patterns);
|
|
|
|
ConversionTarget target(*context);
|
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
|
target.addLegalOp<UnrealizedConversionCastOp>();
|
|
|
|
auto allDataOperandsAreConverted = [](ValueRange operands) {
|
|
for (Value operand : operands) {
|
|
if (!DataDescriptor::isValid(operand) &&
|
|
!operand.getType().isa<LLVM::LLVMPointerType>())
|
|
return false;
|
|
}
|
|
return true;
|
|
};
|
|
|
|
target.addDynamicallyLegalOp<acc::DataOp>(
|
|
[allDataOperandsAreConverted](acc::DataOp op) {
|
|
return allDataOperandsAreConverted(op.getCopyOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyinOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyoutOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
|
|
allDataOperandsAreConverted(op.getCreateOperands()) &&
|
|
allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
|
|
allDataOperandsAreConverted(op.getNoCreateOperands()) &&
|
|
allDataOperandsAreConverted(op.getPresentOperands()) &&
|
|
allDataOperandsAreConverted(op.getDeviceptrOperands()) &&
|
|
allDataOperandsAreConverted(op.getAttachOperands());
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<acc::EnterDataOp>(
|
|
[allDataOperandsAreConverted](acc::EnterDataOp op) {
|
|
return allDataOperandsAreConverted(op.getCopyinOperands()) &&
|
|
allDataOperandsAreConverted(op.getCreateOperands()) &&
|
|
allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
|
|
allDataOperandsAreConverted(op.getAttachOperands());
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<acc::ExitDataOp>(
|
|
[allDataOperandsAreConverted](acc::ExitDataOp op) {
|
|
return allDataOperandsAreConverted(op.getCopyoutOperands()) &&
|
|
allDataOperandsAreConverted(op.getDeleteOperands()) &&
|
|
allDataOperandsAreConverted(op.getDetachOperands());
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<acc::ParallelOp>(
|
|
[allDataOperandsAreConverted](acc::ParallelOp op) {
|
|
return allDataOperandsAreConverted(op.getReductionOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyinOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyinReadonlyOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyoutOperands()) &&
|
|
allDataOperandsAreConverted(op.getCopyoutZeroOperands()) &&
|
|
allDataOperandsAreConverted(op.getCreateOperands()) &&
|
|
allDataOperandsAreConverted(op.getCreateZeroOperands()) &&
|
|
allDataOperandsAreConverted(op.getNoCreateOperands()) &&
|
|
allDataOperandsAreConverted(op.getPresentOperands()) &&
|
|
allDataOperandsAreConverted(op.getDevicePtrOperands()) &&
|
|
allDataOperandsAreConverted(op.getAttachOperands()) &&
|
|
allDataOperandsAreConverted(op.getGangPrivateOperands()) &&
|
|
allDataOperandsAreConverted(op.getGangFirstPrivateOperands());
|
|
});
|
|
|
|
target.addDynamicallyLegalOp<acc::UpdateOp>(
|
|
[allDataOperandsAreConverted](acc::UpdateOp op) {
|
|
return allDataOperandsAreConverted(op.getHostOperands()) &&
|
|
allDataOperandsAreConverted(op.getDeviceOperands());
|
|
});
|
|
|
|
if (failed(applyPartialConversion(op, target, std::move(patterns))))
|
|
signalPassFailure();
|
|
}
|