[flang] use fir.bitcast for FIRToMemRef scalar reinterpretation (#188328)
Use fir.bitcast in FIR-to-MemRef casts so bit patterns are preserved (e.g. TRANSFER), while keeping fir.convert for memref/reference marshaling and non-bitcast-compatible cases.
This commit is contained in:
parent
2c0e63c9b6
commit
55111e8d17
@ -383,6 +383,46 @@ static Value castTypeToIndexType(Value originalValue,
|
||||
originalValue);
|
||||
}
|
||||
|
||||
static bool shouldUseBoundaryBitcast(mlir::Type fromTy, mlir::Type toTy) {
|
||||
auto isBitcastCompatibleScalarType = [](mlir::Type ty) {
|
||||
return mlir::isa<mlir::IntegerType, mlir::FloatType, fir::LogicalType>(
|
||||
ty) ||
|
||||
(mlir::isa<fir::CharacterType>(ty) &&
|
||||
mlir::cast<fir::CharacterType>(ty).getLen() ==
|
||||
fir::CharacterType::singleton());
|
||||
};
|
||||
auto getKnownScalarBitWidth = [](mlir::Type ty) -> std::optional<unsigned> {
|
||||
if (auto intTy = mlir::dyn_cast<mlir::IntegerType>(ty))
|
||||
return intTy.getWidth();
|
||||
if (auto floatTy = mlir::dyn_cast<mlir::FloatType>(ty))
|
||||
return floatTy.getWidth();
|
||||
return std::nullopt;
|
||||
};
|
||||
|
||||
if (fromTy == toTy)
|
||||
return false;
|
||||
const bool fromStd = fir::isa_std_type(fromTy);
|
||||
const bool toStd = fir::isa_std_type(toTy);
|
||||
if (fromStd == toStd)
|
||||
return false;
|
||||
if (!isBitcastCompatibleScalarType(fromTy) ||
|
||||
!isBitcastCompatibleScalarType(toTy))
|
||||
return false;
|
||||
auto fromBits = getKnownScalarBitWidth(fromTy);
|
||||
auto toBits = getKnownScalarBitWidth(toTy);
|
||||
if (fromBits && toBits && *fromBits != *toBits)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
static mlir::Value createTypeConversion(PatternRewriter &rewriter,
|
||||
mlir::Location loc, mlir::Type toTy,
|
||||
mlir::Value value) {
|
||||
if (shouldUseBoundaryBitcast(value.getType(), toTy))
|
||||
return fir::BitcastOp::create(rewriter, loc, toTy, value);
|
||||
return fir::ConvertOp::create(rewriter, loc, toTy, value);
|
||||
}
|
||||
|
||||
FailureOr<SmallVector<Value>>
|
||||
FIRToMemRef::getMemrefIndices(fir::ArrayCoorOp arrayCoorOp, Operation *memref,
|
||||
PatternRewriter &rewriter, Value converted,
|
||||
@ -983,11 +1023,10 @@ void FIRToMemRef::rewriteLoadOp(fir::LoadOp load, PatternRewriter &rewriter,
|
||||
LLVM_DEBUG(llvm::dbgs() << "FIRToMemRef: new memref.load op:\n";
|
||||
loadOp.dump(); assert(succeeded(verify(loadOp))));
|
||||
|
||||
if (isa<fir::LogicalType>(originalType)) {
|
||||
Value logicalVal =
|
||||
fir::ConvertOp::create(rewriter, loadOp.getLoc(), originalType, loadOp);
|
||||
loadOp.getResult().replaceAllUsesExcept(logicalVal,
|
||||
logicalVal.getDefiningOp());
|
||||
if (loadOp.getType() != originalType) {
|
||||
Value castVal =
|
||||
createTypeConversion(rewriter, loadOp.getLoc(), originalType, loadOp);
|
||||
loadOp.getResult().replaceAllUsesExcept(castVal, castVal.getDefiningOp());
|
||||
}
|
||||
|
||||
if (!isa<fir::LogicalType>(originalType))
|
||||
@ -1019,11 +1058,10 @@ void FIRToMemRef::rewriteStoreOp(fir::StoreOp store, PatternRewriter &rewriter,
|
||||
Value value = store.getValue();
|
||||
rewriter.setInsertionPointAfter(store);
|
||||
|
||||
if (isa<fir::LogicalType>(value.getType())) {
|
||||
Type convertedType = typeConverter.convertType(value.getType());
|
||||
Type convertedType = typeConverter.convertType(value.getType());
|
||||
if (convertedType != value.getType())
|
||||
value =
|
||||
fir::ConvertOp::create(rewriter, store.getLoc(), convertedType, value);
|
||||
}
|
||||
createTypeConversion(rewriter, store.getLoc(), convertedType, value);
|
||||
|
||||
Attribute attr = (store.getOperation())->getAttr("tbaa");
|
||||
memref::StoreOp storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
// CHECK-NEXT: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
|
||||
// CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
|
||||
// CHECK-NEXT: [[LOAD:%[0-9]+]] = memref.load [[CONVERT]][] : memref<i32>
|
||||
// CHECK-NEXT: fir.convert [[LOAD]] : (i32) -> !fir.logical<4>
|
||||
// CHECK-NEXT: fir.bitcast [[LOAD]] : (i32) -> !fir.logical<4>
|
||||
func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
|
||||
%0 = fir.undefined !fir.dscope
|
||||
%1 = fir.declare %arg0 dummy_scope %0 {uniq_name = "a"} : (!fir.ref<!fir.logical<4>>, !fir.dscope) -> !fir.ref<!fir.logical<4>>
|
||||
@ -18,7 +18,7 @@ func.func @load_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
|
||||
// CHECK: [[DECLARE:%[0-9]+]] = fir.declare %arg0 dummy_scope [[DUMMY]]
|
||||
// CHECK-NEXT: [[CONVERT:%[0-9]+]] = fir.convert [[CONSTTRUE]] : (i1) -> !fir.logical<4>
|
||||
// CHECK-NEXT: [[CONVERT1:%[0-9]+]] = fir.convert [[DECLARE]] : (!fir.ref<!fir.logical<4>>) -> memref<i32>
|
||||
// CHECK-NEXT: [[INT:%[0-9]+]] = fir.convert [[CONVERT]] : (!fir.logical<4>) -> i32
|
||||
// CHECK-NEXT: [[INT:%[0-9]+]] = fir.bitcast [[CONVERT]] : (!fir.logical<4>) -> i32
|
||||
// CHECK-NEXT: memref.store [[INT]], [[CONVERT1]][] : memref<i32>
|
||||
func.func @store_scalar(%arg0: !fir.ref<!fir.logical<4>>) {
|
||||
%true = arith.constant true
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user