[flang] use fir.bitcast for FIRToMemRef scalar reinterpretation (#188328)

Use fir.bitcast in FIR-to-MemRef casts so bit patterns are preserved
(e.g. TRANSFER), while keeping fir.convert for memref/reference
marshaling and non-bitcast-compatible cases.
This commit is contained in:
Susan Tan (ス-ザン タン) 2026-03-25 15:27:43 -04:00 committed by GitHub
parent 2c0e63c9b6
commit 55111e8d17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 11 deletions

View File

@ -383,6 +383,46 @@ static Value castTypeToIndexType(Value originalValue,
originalValue);
}
static bool shouldUseBoundaryBitcast(mlir::Type fromTy, mlir::Type toTy) {
auto isBitcastCompatibleScalarType = [](mlir::Type ty) {
return mlir::isa<mlir::IntegerType, mlir::FloatType, fir::LogicalType>(
ty) ||
(mlir::isa<fir::CharacterType>(ty) &&
mlir::cast<fir::CharacterType>(ty).getLen() ==
fir::CharacterType::singleton());
};
auto getKnownScalarBitWidth = [](mlir::Type ty) -> std::optional<unsigned> {
if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(ty))
return intTy.getWidth();
if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty))
return floatTy.getWidth();
return std::nullopt;
};
if (fromTy == toTy)
return false;
const bool fromStd = fir::isa_std_type(fromTy);
const bool toStd = fir::isa_std_type(toTy);
if (fromStd == toStd)
return false;
if (!isBitcastCompatibleScalarType(fromTy) ||
!isBitcastCompatibleScalarType(toTy))
return false;
auto fromBits = getKnownScalarBitWidth(fromTy);
auto toBits = getKnownScalarBitWidth(toTy);
if (fromBits && toBits && *fromBits != *toBits)
return false;
return true;
}
static mlir::Value createTypeConversion(PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value value) {
if (shouldUseBoundaryBitcast(value.getType(), toTy))
return fir::BitcastOp::create(rewriter, loc, toTy, value);
return fir::ConvertOp::create(rewriter, loc, toTy, value);
}
FailureOr<SmallVector<Value>>
FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
PatternRewriter &rewriter, Value converted,
@ -983,11 +1023,10 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n";
loadOp.dump(); assert(succeeded(verify(loadOp))));
if (isa<fir::LogicalType>(originalType)) {
Value logicalVal =
fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp);
loadOp.getResult().replaceAllUsesExcept(logicalVal,
logicalVal.getDefiningOp());
if (loadOp.getType() != originalType) {
Value castVal =
createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
}
if (!isa<fir::LogicalType>(originalType))
@ -1019,11 +1058,10 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
Value value = store.getValue();
rewriter.setInsertionPointAfter(store);
if (isa<fir::LogicalType>(value.getType())) {
Type convertedType = typeConverter.convertType(value.getType());
Type convertedType = typeConverter.convertType(value.getType());
if (convertedType != value.getType())
value =
fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value);
}
createTypeConversion(rewriter, store.getLoc(), convertedType, value);
Attribute attr = (store.getOperation())->getAttr("tbaa");
memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(

View File

@ -4,7 +4,7 @@
// CHECK-NEXT: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
// CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
// CHECK-NEXT: [[LOAD:%[0-9]+]] = memref.load [[CONVERT]][] : memref<i32>
// CHECK-NEXT: fir.convert [[LOAD]] : (i32) -> !fir.logical<4>
// CHECK-NEXT: fir.bitcast [[LOAD]] : (i32) -> !fir.logical<4>
func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
%0 = fir.undefined !fir.dscope
%1 = fir.declare %arg0 dummy_scope %0 {uniq_name = "a"} : (!fir.ref<!fir.logical<4>>, !fir.dscope) -> !fir.ref<!fir.logical<4>>
@ -18,7 +18,7 @@ func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
// CHECK: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
// CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[CONSTTRUE]] : (i1) -> !fir.logical<4>
// CHECK-NEXT: [[CONVERT1:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
// CHECK-NEXT: [[INT:%[0-9]+]] = fir.convert [[CONVERT]] : (!fir.logical<4>) -> i32
// CHECK-NEXT: [[INT:%[0-9]+]] = fir.bitcast [[CONVERT]] : (!fir.logical<4>) -> i32
// CHECK-NEXT: memref.store [[INT]], [[CONVERT1]][] : memref<i32>
func.func @store_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
%true = arith.constant true