From 55111e8d171dad5cefab756cfd443707fbc69aad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Susan=20Tan=20=28=E3=82=B9-=E3=82=B6=E3=83=B3=E3=80=80?= =?UTF-8?q?=E3=82=BF=E3=83=B3=29?= Date: Wed, 25 Mar 2026 15:27:43 -0400 Subject: [PATCH] [flang] use fir.bitcast for FIRToMemRef scalar reinterpretation (#188328) Use fir.bitcast in FIR-to-MemRef casts so bit patterns are preserved (e.g. TRANSFER), while keeping fir.convert for memref/reference marshaling and non-bitcast-compatible cases. --- .../lib/Optimizer/Transforms/FIRToMemRef.cpp | 56 ++++++++++++++++--- .../test/Transforms/FIRToMemRef/logical.mlir | 4 +- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp index 447ee9c35f81..3b0b4bc007e6 100644 --- a/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToMemRef.cpp @@ -383,6 +383,46 @@ static Value castTypeToIndexType(Value originalValue, originalValue); } +static bool shouldUseBoundaryBitcast(mlir::Type fromTy, mlir::Type toTy) { + auto isBitcastCompatibleScalarType = [](mlir::Type ty) { + return mlir::isa( + ty) || + (mlir::isa(ty) && + mlir::cast(ty).getLen() == + fir::CharacterType::singleton()); + }; + auto getKnownScalarBitWidth = [](mlir::Type ty) -> std::optional { + if (auto intTy = mlir::dyn_cast(ty)) + return intTy.getWidth(); + if (auto floatTy = mlir::dyn_cast(ty)) + return floatTy.getWidth(); + return std::nullopt; + }; + + if (fromTy == toTy) + return false; + const bool fromStd = fir::isa_std_type(fromTy); + const bool toStd = fir::isa_std_type(toTy); + if (fromStd == toStd) + return false; + if (!isBitcastCompatibleScalarType(fromTy) || + !isBitcastCompatibleScalarType(toTy)) + return false; + auto fromBits = getKnownScalarBitWidth(fromTy); + auto toBits = getKnownScalarBitWidth(toTy); + if (fromBits && toBits && *fromBits != *toBits) + return false; + return true; +} + +static mlir::Value createTypeConversion(PatternRewriter &rewriter, + mlir::Location loc, mlir::Type toTy, + mlir::Value value) { + if (shouldUseBoundaryBitcast(value.getType(), toTy)) + return fir::BitcastOp::create(rewriter, loc, toTy, value); + return fir::ConvertOp::create(rewriter, loc, toTy, value); +} + FailureOr> FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref, PatternRewriter &rewriter, Value converted, @@ -983,11 +1023,10 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter, LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n"; loadOp.dump(); assert(succeeded(verify(loadOp)))); - if (isa(originalType)) { - Value logicalVal = - fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp); - loadOp.getResult().replaceAllUsesExcept(logicalVal, - logicalVal.getDefiningOp()); + if (loadOp.getType() != originalType) { + Value castVal = + createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp); + loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp()); } if (!isa(originalType)) @@ -1019,11 +1058,10 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter, Value value = store.getValue(); rewriter.setInsertionPointAfter(store); - if (isa(value.getType())) { - Type convertedType = typeConverter.convertType(value.getType()); + Type convertedType = typeConverter.convertType(value.getType()); + if (convertedType != value.getType()) value = - fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value); - } + createTypeConversion(rewriter, store.getLoc(), convertedType, value); Attribute attr = (store.getOperation())->getAttr("tbaa"); memref::StoreOp storeOp = rewriter.replaceOpWithNewOp( diff --git a/flang/test/Transforms/FIRToMemRef/logical.mlir b/flang/test/Transforms/FIRToMemRef/logical.mlir index 75a9fac3e1e4..948b8dcb2ae6 100644 --- a/flang/test/Transforms/FIRToMemRef/logical.mlir +++ b/flang/test/Transforms/FIRToMemRef/logical.mlir @@ -4,7 +4,7 @@ // CHECK-NEXT: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]] // CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref>) -> memref // CHECK-NEXT: [[LOAD:%[0-9]+]] = memref.load [[CONVERT]][] : memref -// CHECK-NEXT: fir.convert [[LOAD]] : (i32) -> !fir.logical<4> +// CHECK-NEXT: fir.bitcast [[LOAD]] : (i32) -> !fir.logical<4> func.func @load_scalar(%arg0: !fir.ref>) { %0 = fir.undefined !fir.dscope %1 = fir.declare %arg0 dummy_scope %0 {uniq_name = "a"} : (!fir.ref>, !fir.dscope) -> !fir.ref> @@ -18,7 +18,7 @@ func.func @load_scalar(%arg0: !fir.ref>) { // CHECK: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]] // CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[CONSTTRUE]] : (i1) -> !fir.logical<4> // CHECK-NEXT: [[CONVERT1:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref>) -> memref -// CHECK-NEXT: [[INT:%[0-9]+]] = fir.convert [[CONVERT]] : (!fir.logical<4>) -> i32 +// CHECK-NEXT: [[INT:%[0-9]+]] = fir.bitcast [[CONVERT]] : (!fir.logical<4>) -> i32 // CHECK-NEXT: memref.store [[INT]], [[CONVERT1]][] : memref func.func @store_scalar(%arg0: !fir.ref>) { %true = arith.constant true