diff --git a/flang/lib/Optimizer/Builder/FIRBuilder.cpp b/flang/lib/Optimizer/Builder/FIRBuilder.cpp index 9ad37c8df434..8fa695a5c0c2 100644 --- a/flang/lib/Optimizer/Builder/FIRBuilder.cpp +++ b/flang/lib/Optimizer/Builder/FIRBuilder.cpp @@ -479,7 +479,10 @@ mlir::Value fir::factory::createConvert(mlir::OpBuilder &builder, mlir::Location loc, mlir::Type toTy, mlir::Value val) { if (val.getType() != toTy) { - assert(!fir::isa_derived(toTy)); + assert((!fir::isa_derived(toTy) || + mlir::cast(val.getType()).getTypeList() == + mlir::cast(toTy).getTypeList()) && + "incompatible record types"); return builder.create(loc, toTy, val); } return val; diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp index 1611de9e6389..15fcc09c6219 100644 --- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp +++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp @@ -660,6 +660,31 @@ struct ConvertOpConversion : public fir::FIROpConversion { auto loc = convert.getLoc(); auto i1Type = mlir::IntegerType::get(convert.getContext(), 1); + if (mlir::isa(toFirTy)) { + // Convert to compatible BIND(C) record type. + // Double check that the record types are compatible (it should have + // already been checked by the verifier). + assert(mlir::cast(fromFirTy).getTypeList() == + mlir::cast(toFirTy).getTypeList() && + "incompatible record types"); + + auto toStTy = mlir::cast(toTy); + mlir::Value val = rewriter.create(loc, toStTy); + auto indexTypeMap = toStTy.getSubelementIndexMap(); + assert(indexTypeMap.has_value() && "invalid record type"); + + for (auto [attr, type] : indexTypeMap.value()) { + int64_t index = mlir::cast(attr).getInt(); + auto extVal = + rewriter.create(loc, op0, index); + val = + rewriter.create(loc, val, extVal, index); + } + + rewriter.replaceOp(convert, val); + return mlir::success(); + } + if (mlir::isa(fromFirTy) || mlir::isa(toFirTy)) { // By specification fir::LogicalType value may be any number, diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 8fdc06f6fce3..90ce8b876059 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1410,6 +1410,15 @@ bool fir::ConvertOp::areVectorsCompatible(mlir::Type inTy, mlir::Type outTy) { return true; } +static bool areRecordsCompatible(mlir::Type inTy, mlir::Type outTy) { + // Both records must have the same field types. + // Trust frontend semantics for in-depth checks, such as if both records + // have the BIND(C) attribute. + auto inRecTy = mlir::dyn_cast(inTy); + auto outRecTy = mlir::dyn_cast(outTy); + return inRecTy && outRecTy && inRecTy.getTypeList() == outRecTy.getTypeList(); +} + bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) { if (inType == outType) return true; @@ -1428,7 +1437,8 @@ bool fir::ConvertOp::canBeConverted(mlir::Type inType, mlir::Type outType) { (fir::isBoxedRecordType(inType) && fir::isPolymorphicType(outType)) || (fir::isPolymorphicType(inType) && fir::isPolymorphicType(outType)) || (fir::isPolymorphicType(inType) && mlir::isa(outType)) || - areVectorsCompatible(inType, outType); + areVectorsCompatible(inType, outType) || + areRecordsCompatible(inType, outType); } llvm::LogicalResult fir::ConvertOp::verify() { diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir index 0c17d7c25a8c..1182a0a10f21 100644 --- a/flang/test/Fir/convert-to-llvm.fir +++ b/flang/test/Fir/convert-to-llvm.fir @@ -816,6 +816,31 @@ func.func @convert_complex16(%arg0 : complex) -> complex { // ----- +// Test `fir.convert` operation conversion between compatible fir.record types. + +func.func @convert_record(%arg0 : !fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) -> + !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}> { + %0 = fir.convert %arg0 : (!fir.type<_QMmod1Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}>) -> + !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}> + return %0 : !fir.type<_QMmod2Trec{i:i32,f:f64,c:!llvm.struct<(f32, f32)>,cstr:!fir.array<4x!fir.char<1>>}> +} + +// CHECK-LABEL: func @convert_record( +// CHECK-SAME: %[[ARG0:.*]]: [[MOD1_REC:!llvm.struct<"_QMmod1Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]]) -> +// CHECK-SAME: [[MOD2_REC:!llvm.struct<"_QMmod2Trec", \(i32, f64, struct<\(f32, f32\)>, array<4 x array<1 x i8>>\)>]] +// CHECK: %{{.*}} = llvm.mlir.undef : [[MOD2_REC]] +// CHECK-DAG: %[[I:.*]] = llvm.extractvalue %[[ARG0]][0] : [[MOD1_REC]] +// CHECK-DAG: %{{.*}} = llvm.insertvalue %[[I]], %{{.*}}[0] : [[MOD2_REC]] +// CHECK-DAG: %[[F:.*]] = llvm.extractvalue %[[ARG0]][1] : [[MOD1_REC]] +// CHECK-DAG: %{{.*}} = llvm.insertvalue %[[F]], %{{.*}}[1] : [[MOD2_REC]] +// CHECK-DAG: %[[C:.*]] = llvm.extractvalue %[[ARG0]][2] : [[MOD1_REC]] +// CHECK-DAG: %{{.*}} = llvm.insertvalue %[[C]], %{{.*}}[2] : [[MOD2_REC]] +// CHECK-DAG: %[[CSTR:.*]] = llvm.extractvalue %[[ARG0]][3] : [[MOD1_REC]] +// CHECK-DAG: %{{.*}} = llvm.insertvalue %[[CSTR]], %{{.*}}[3] : [[MOD2_REC]] +// CHECK: llvm.return %{{.*}} : [[MOD2_REC]] + +// ----- + // Test `fir.store` --> `llvm.store` conversion func.func @test_store_index(%val_to_store : index, %addr : !fir.ref) { diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir index 086a426db564..7e3f9d649841 100644 --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -965,6 +965,14 @@ func.func @fp_to_logical(%arg0: f32) -> !fir.logical<4> { // ----- +func.func @rec_to_rec(%arg0: !fir.type) -> !fir.type { + // expected-error@+1{{'fir.convert' op invalid type conversion}} + %0 = fir.convert %arg0 : (!fir.type) -> !fir.type + return %0 : !fir.type +} + +// ----- + func.func @bad_box_offset(%not_a_box : !fir.ref) { // expected-error@+1{{'fir.box_offset' op box_ref operand must have !fir.ref> type}} %addr1 = fir.box_offset %not_a_box base_addr : (!fir.ref) -> !fir.llvm_ptr>