[flang][CUDA] Unify element size computation in CUF helpers (#167398)

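Move the file-local computeWidth helper out of CUFOpConversion into CUFCommon as
the shared function cuf::computeElementByteSize, and update its three call sites.
The shared helper also takes an emitErrorOnFailure flag (default true) so that
callers can suppress the "unsupported type" diagnostic.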
This commit is contained in:
Zhen Wang 2025-11-10 14:28:32 -08:00 committed by GitHub
parent 95db31e7f6
commit d5125b3089
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 31 additions and 25 deletions

View File

@ -18,6 +18,7 @@ static constexpr llvm::StringRef cudaSharedMemSuffix = "__shared_mem";
namespace fir {
class FirOpBuilder;
class KindMapping;
} // namespace fir
namespace cuf {
@ -34,6 +35,10 @@ bool isRegisteredDeviceAttr(std::optional<cuf::DataAttribute> attr);
void genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder);
/// Compute the size in bytes of the element type of \p type.
///
/// Sequence types are first unwrapped to their element type. Integer,
/// floating-point, complex (with floating-point components), fir.logical and
/// fir.char element types are handled; the bit sizes of logical and character
/// kinds are looked up in \p kindMap.
///
/// \param loc location used for the diagnostic on failure.
/// \param type type (or sequence type) whose element size is requested.
/// \param kindMap kind mapping used for logical and character kinds.
/// \param emitErrorOnFailure when true, an "unsupported type" error is emitted
///        at \p loc if the element type is not one of the handled types.
/// \return the element size in bytes, or 0 if the type is not handled.
int computeElementByteSize(mlir::Location loc, mlir::Type type,
fir::KindMapping &kindMap,
bool emitErrorOnFailure = true);
} // namespace cuf
#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFCOMMON_H_

View File

@ -9,6 +9,7 @@
#include "flang/Optimizer/Builder/CUFCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@ -91,3 +92,25 @@ void cuf::genPointerSync(const mlir::Value box, fir::FirOpBuilder &builder) {
}
}
}
/// Compute the size in bytes of the element type of \p type.
///
/// Sequence types are unwrapped to their element type first. Logical and
/// character sizes depend on the kind and are looked up in \p kindMap.
///
/// \param loc location used for the diagnostic on failure.
/// \param type type (or sequence type) whose element size is requested.
/// \param kindMap kind mapping used for logical and character kinds.
/// \param emitErrorOnFailure when true, emit an "unsupported type" error at
///        \p loc if the element type is not handled.
/// \return the element size in bytes, or 0 if the type is not handled.
int cuf::computeElementByteSize(mlir::Location loc, mlir::Type type,
                                fir::KindMapping &kindMap,
                                bool emitErrorOnFailure) {
  auto eleTy = fir::unwrapSequenceType(type);
  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)}) {
    // Round up to whole bytes: a plain `width / 8` truncates sub-byte
    // integers such as i1 to 0, which callers cannot tell apart from the
    // "unsupported type" return value.
    return (t.getWidth() + 7) / 8;
  }
  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
    return t.getWidth() / 8;
  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
    // A complex value is stored as two floating-point parts of equal width.
    int elemSize =
        mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
    return 2 * elemSize;
  }
  if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
    return kindMap.getCharacterBitsize(t.getFKind()) / 8;
  // The diagnostic is optional so that callers can probe for support.
  if (emitErrorOnFailure)
    mlir::emitError(loc, "unsupported type");
  return 0;
}

View File

@ -263,28 +263,6 @@ static bool inDeviceContext(mlir::Operation *op) {
return false;
}
/// Compute the size in bytes of the element type of \p type.
///
/// Sequence types are unwrapped to their element type first. Logical and
/// character sizes depend on the kind and are looked up in \p kindMap.
/// Emits an "unsupported type" error at \p loc and returns 0 when the element
/// type is not handled.
static int computeWidth(mlir::Location loc, mlir::Type type,
                        fir::KindMapping &kindMap) {
  auto eleTy = fir::unwrapSequenceType(type);
  // i1 has to be tested before the generic integer case: IntegerType also
  // matches i1, and `1 / 8` would yield 0 and leave this branch unreachable.
  if (eleTy.isInteger(1))
    return 1;
  if (auto t{mlir::dyn_cast<mlir::IntegerType>(eleTy)})
    return t.getWidth() / 8;
  if (auto t{mlir::dyn_cast<mlir::FloatType>(eleTy)})
    return t.getWidth() / 8;
  if (auto t{mlir::dyn_cast<fir::LogicalType>(eleTy)})
    return kindMap.getLogicalBitsize(t.getFKind()) / 8;
  if (auto t{mlir::dyn_cast<mlir::ComplexType>(eleTy)}) {
    // A complex value is stored as two floating-point parts of equal width.
    int elemSize =
        mlir::cast<mlir::FloatType>(t.getElementType()).getWidth() / 8;
    return 2 * elemSize;
  }
  // eleTy has already been dereferenced above, so a plain dyn_cast suffices.
  if (auto t{mlir::dyn_cast<fir::CharacterType>(eleTy)})
    return kindMap.getCharacterBitsize(t.getFKind()) / 8;
  mlir::emitError(loc, "unsupported type");
  return 0;
}
struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
using OpRewritePattern::OpRewritePattern;
@ -320,7 +298,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
mlir::Value bytes;
fir::KindMapping kindMap{fir::getKindMapping(mod)};
if (fir::isa_trivial(op.getInType())) {
int width = computeWidth(loc, op.getInType(), kindMap);
int width = cuf::computeElementByteSize(loc, op.getInType(), kindMap);
bytes =
builder.createIntegerConstant(loc, builder.getIndexType(), width);
} else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
@ -330,7 +308,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
mlir::Type structTy = typeConverter->convertType(seqTy.getEleTy());
size = dl->getTypeSizeInBits(structTy) / 8;
} else {
size = computeWidth(loc, seqTy.getEleTy(), kindMap);
size = cuf::computeElementByteSize(loc, seqTy.getEleTy(), kindMap);
}
mlir::Value width =
builder.createIntegerConstant(loc, builder.getIndexType(), size);
@ -704,7 +682,7 @@ struct CUFDataTransferOpConversion
typeConverter->convertType(fir::unwrapSequenceType(dstTy));
width = dl->getTypeSizeInBits(structTy) / 8;
} else {
width = computeWidth(loc, dstTy, kindMap);
width = cuf::computeElementByteSize(loc, dstTy, kindMap);
}
mlir::Value widthValue = mlir::arith::ConstantOp::create(
rewriter, loc, i64Ty, rewriter.getIntegerAttr(i64Ty, width));