[flang][CUDA] Unify element size computation in CUF helpers (#167398)
Refactor computeWidth from CUFOpConversion into a shared helper function computeElementByteSize in CUFCommon.
This commit is contained in:
parent
95db31e7f6
commit
d5125b3089
@ -18,6 +18,7 @@ static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
|
||||
|
||||
namespace fir {
|
||||
class FirOpBuilder;
|
||||
class KindMapping;
|
||||
} // namespace fir
|
||||
|
||||
namespace cuf {
|
||||
@ -34,6 +35,10 @@ bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
|
||||
|
||||
void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
|
||||
|
||||
int computeElementByteSize(mlir::Location loc, mlir::Type type,
|
||||
fir::KindMapping &kindMap,
|
||||
bool emitErrorOnFailure = true);
|
||||
|
||||
} // namespace cuf
|
||||
|
||||
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
#include "flang/Optimizer/Builder/CUFCommon.h"
|
||||
#include "flang/Optimizer/Builder/FIRBuilder.h"
|
||||
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
|
||||
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
|
||||
#include "flang/Optimizer/HLFIR/HLFIROps.h"
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
@ -91,3 +92,25 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
|
||||
fir::KindMapping &kindMap,
|
||||
bool emitErrorOnFailure) {
|
||||
auto eleTy = fir::unwrapSequenceType(type);
|
||||
if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
|
||||
return t.getWidth() / 8;
|
||||
if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
|
||||
return t.getWidth() / 8;
|
||||
if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
|
||||
return kindMap.getLogicalBitsize(t.getFKind()) / 8;
|
||||
if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
|
||||
int elemSize =
|
||||
mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
|
||||
return 2 * elemSize;
|
||||
}
|
||||
if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
|
||||
return kindMap.getCharacterBitsize(t.getFKind()) / 8;
|
||||
if (emitErrorOnFailure)
|
||||
mlir::emitError(loc, "unsupported type");
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
|
||||
return false;
|
||||
}
|
||||
|
||||
static int computeWidth(mlir::Location loc, mlir::Type type,
|
||||
fir::KindMapping &kindMap) {
|
||||
auto eleTy = fir::unwrapSequenceType(type);
|
||||
if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
|
||||
return t.getWidth() / 8;
|
||||
if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
|
||||
return t.getWidth() / 8;
|
||||
if (eleTy.isInteger(1))
|
||||
return 1;
|
||||
if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
|
||||
return kindMap.getLogicalBitsize(t.getFKind()) / 8;
|
||||
if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
|
||||
int elemSize =
|
||||
mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
|
||||
return 2 * elemSize;
|
||||
}
|
||||
if (auto t{mlir::dyn_cast_or_null<fir::CharacterType>(eleTy)})
|
||||
return kindMap.getCharacterBitsize(t.getFKind()) / 8;
|
||||
mlir::emitError(loc, "unsupported type");
|
||||
return 0;
|
||||
}
|
||||
|
||||
struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
||||
@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
|
||||
mlir::Value bytes;
|
||||
fir::KindMapping kindMap{fir::getKindMapping(mod)};
|
||||
if (fir::isa_trivial(op.getInType())) {
|
||||
int width = computeWidth(loc, op.getInType(), kindMap);
|
||||
int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
|
||||
bytes =
|
||||
builder.createIntegerConstant(loc, builder.getIndexType(), width);
|
||||
} else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
|
||||
@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
|
||||
mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
|
||||
size = dl->getTypeSizeInBits(structTy) / 8;
|
||||
} else {
|
||||
size = computeWidth(loc, seqTy.getEleTy(), kindMap);
|
||||
size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
|
||||
}
|
||||
mlir::Value width =
|
||||
builder.createIntegerConstant(loc, builder.getIndexType(), size);
|
||||
@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
|
||||
typeConverter->convertType(fir::unwrapSequenceType(dstTy));
|
||||
width = dl->getTypeSizeInBits(structTy) / 8;
|
||||
} else {
|
||||
width = computeWidth(loc, dstTy, kindMap);
|
||||
width = cuf::computeElementByteSize(loc, dstTy, kindMap);
|
||||
}
|
||||
mlir::Value widthValue = mlir::arith::ConstantOp::create(
|
||||
rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user