[mlir][OpenMP] Add __atomic_store to AtomicInfo (#121055)

This PR adds functionality for `__atomic_store` libcall in AtomicInfo.
This allows for supporting complex types in `atomic write`.

Fixes https://github.com/llvm/llvm-project/issues/113479
Fixes https://github.com/llvm/llvm-project/issues/115652
This commit is contained in:
NimishMishra 2025-04-29 20:23:36 +05:30 committed by GitHub
parent 6ffccea1c2
commit b62afbccc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 108 additions and 6 deletions

View File

@ -97,6 +97,8 @@ public:
bool IsVolatile, bool IsWeak);
std::pair<LoadInst *, AllocaInst *> EmitAtomicLoadLibcall(AtomicOrdering AO);
void EmitAtomicStoreLibcall(AtomicOrdering AO, Value *Source);
};
} // end namespace llvm

View File

@ -3285,11 +3285,12 @@ public:
/// \param Expr The value to store.
/// \param AO Atomic ordering of the generated atomic
/// instructions.
/// \param AllocaIP Insert point for allocas
///
/// \return Insertion point after generated atomic Write IR.
InsertPointTy createAtomicWrite(const LocationDescription &Loc,
AtomicOpValue &X, Value *Expr,
AtomicOrdering AO);
AtomicOrdering AO, InsertPointTy AllocaIP);
/// Emit atomic update for constructs: X = X BinOp Expr ,or X = Expr BinOp X
/// For complex Operations: X = UpdateOp(X) => CmpExch X, old_X, UpdateOp(X)

View File

@ -145,6 +145,42 @@ AtomicInfo::EmitAtomicLoadLibcall(AtomicOrdering AO) {
AllocaResult);
}
void AtomicInfo::EmitAtomicStoreLibcall(AtomicOrdering AO, Value *Source) {
LLVMContext &Ctx = getLLVMContext();
SmallVector<Value *, 6> Args;
AttributeList Attr;
Module *M = Builder->GetInsertBlock()->getModule();
const DataLayout &DL = M->getDataLayout();
Args.push_back(
ConstantInt::get(DL.getIntPtrType(Ctx), this->getAtomicSizeInBits() / 8));
Value *PtrVal = getAtomicPointer();
PtrVal = Builder->CreateAddrSpaceCast(PtrVal, PointerType::getUnqual(Ctx));
Args.push_back(PtrVal);
auto CurrentIP = Builder->saveIP();
Builder->restoreIP(AllocaIP);
Value *SourceAlloca = Builder->CreateAlloca(Source->getType());
Builder->restoreIP(CurrentIP);
Builder->CreateStore(Source, SourceAlloca);
SourceAlloca = Builder->CreatePointerBitCastOrAddrSpaceCast(
SourceAlloca, Builder->getPtrTy());
Args.push_back(SourceAlloca);
Constant *OrderingVal =
ConstantInt::get(Type::getInt32Ty(Ctx), (int)toCABI(AO));
Args.push_back(OrderingVal);
SmallVector<Type *, 6> ArgTys;
for (Value *Arg : Args)
ArgTys.push_back(Arg->getType());
FunctionType *FnType = FunctionType::get(Type::getVoidTy(Ctx), ArgTys, false);
FunctionCallee LibcallFn =
M->getOrInsertFunction("__atomic_store", FnType, Attr);
CallInst *Call = Builder->CreateCall(LibcallFn, Args);
Call->setAttributes(Attr);
}
std::pair<Value *, Value *> AtomicInfo::EmitAtomicCompareExchange(
Value *ExpectedVal, Value *DesiredVal, AtomicOrdering Success,
AtomicOrdering Failure, bool IsVolatile, bool IsWeak) {

View File

@ -8684,7 +8684,7 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
AtomicOpValue &X, Value *Expr,
AtomicOrdering AO) {
AtomicOrdering AO, InsertPointTy AllocaIP) {
if (!updateToLocation(Loc))
return Loc.IP;
@ -8692,12 +8692,22 @@ OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
"OMP Atomic expects a pointer to target memory");
Type *XElemTy = X.ElemTy;
assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
XElemTy->isPointerTy()) &&
XElemTy->isPointerTy() || XElemTy->isStructTy()) &&
"OMP atomic write expected a scalar type");
if (XElemTy->isIntegerTy()) {
StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
XSt->setAtomic(AO);
} else if (XElemTy->isStructTy()) {
LoadInst *OldVal = Builder.CreateLoad(XElemTy, X.Var, "omp.atomic.read");
const DataLayout &LoadDL = OldVal->getModule()->getDataLayout();
unsigned LoadSize =
LoadDL.getTypeStoreSize(OldVal->getPointerOperand()->getType());
OpenMPIRBuilder::AtomicInfo atomicInfo(
&Builder, XElemTy, LoadSize * 8, LoadSize * 8, OldVal->getAlign(),
OldVal->getAlign(), true /* UseLibcall */, AllocaIP, X.Var);
atomicInfo.EmitAtomicStoreLibcall(AO, Expr);
OldVal->eraseFromParent();
} else {
// We need to bitcast and perform atomic op as integers
IntegerType *IntCastTy =

View File

@ -3875,6 +3875,9 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicWriteFlt) {
IRBuilder<> Builder(BB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
BasicBlock *EntryBB = BB;
OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB,
EntryBB->getFirstInsertionPt());
LLVMContext &Ctx = M->getContext();
Type *Float32 = Type::getFloatTy(Ctx);
@ -3884,7 +3887,8 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicWriteFlt) {
AtomicOrdering AO = AtomicOrdering::Monotonic;
Constant *ValToWrite = ConstantFP::get(Float32, 1.0);
Builder.restoreIP(OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO));
Builder.restoreIP(
OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO, AllocaIP));
IntegerType *IntCastTy =
IntegerType::get(M->getContext(), Float32->getScalarSizeInBits());
@ -3918,8 +3922,11 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicWriteInt) {
ConstantInt *ValToWrite = ConstantInt::get(Type::getInt32Ty(Ctx), 1U);
BasicBlock *EntryBB = BB;
OpenMPIRBuilder::InsertPointTy AllocaIP(EntryBB,
EntryBB->getFirstInsertionPt());
Builder.restoreIP(OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO));
Builder.restoreIP(
OMPBuilder.createAtomicWrite(Loc, X, ValToWrite, AO, AllocaIP));
StoreInst *StoreofAtomic = nullptr;

View File

@ -2808,6 +2808,8 @@ convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
return failure();
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
@ -2816,7 +2818,8 @@ convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::Type *ty = moduleTranslation.convertType(writeOp.getExpr().getType());
llvm::OpenMPIRBuilder::AtomicOpValue x = {dest, ty, /*isSigned=*/false,
/*isVolatile=*/false};
builder.restoreIP(ompBuilder->createAtomicWrite(ompLoc, x, expr, ao));
builder.restoreIP(
ompBuilder->createAtomicWrite(ompLoc, x, expr, ao, allocaIP));
return success();
}

View File

@ -1481,6 +1481,49 @@ llvm.func @omp_atomic_update(%x:!llvm.ptr, %expr: i32, %xbool: !llvm.ptr, %exprb
// -----
// CHECK-LABEL: @omp_atomic_write
llvm.func @omp_atomic_write() {
// CHECK: %[[ALLOCA0:.*]] = alloca { float, float }, align 8
// CHECK: %[[ALLOCA1:.*]] = alloca { float, float }, align 8
// CHECK: %[[X:.*]] = alloca float, i64 1, align 4
// CHECK: %[[R1:.*]] = alloca float, i64 1, align 4
// CHECK: %[[ALLOCA:.*]] = alloca { float, float }, i64 1, align 8
// CHECK: %[[LOAD:.*]] = load float, ptr %[[R1]], align 4
// CHECK: %[[IDX1:.*]] = insertvalue { float, float } undef, float %[[LOAD]], 0
// CHECK: %[[IDX2:.*]] = insertvalue { float, float } %[[IDX1]], float 0.000000e+00, 1
// CHECK: br label %entry
// CHECK: entry:
// CHECK: store { float, float } %[[IDX2]], ptr %[[ALLOCA1]], align 4
// CHECK: call void @__atomic_store(i64 8, ptr %[[ALLOCA]], ptr %[[ALLOCA1]], i32 0)
// CHECK: store { float, float } { float 1.000000e+00, float 1.000000e+00 }, ptr %[[ALLOCA0]], align 4
// CHECK: call void @__atomic_store(i64 8, ptr %[[ALLOCA]], ptr %[[ALLOCA0]], i32 0)
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x f32 {bindc_name = "x"} : (i64) -> !llvm.ptr
%2 = llvm.mlir.constant(1 : i64) : i64
%3 = llvm.alloca %2 x f32 {bindc_name = "r1"} : (i64) -> !llvm.ptr
%4 = llvm.mlir.constant(1 : i64) : i64
%5 = llvm.alloca %4 x !llvm.struct<(f32, f32)> {bindc_name = "c1"} : (i64) -> !llvm.ptr
%6 = llvm.mlir.constant(1.000000e+00 : f32) : f32
%7 = llvm.mlir.constant(0.000000e+00 : f32) : f32
%8 = llvm.mlir.constant(1 : i64) : i64
%9 = llvm.mlir.constant(1 : i64) : i64
%10 = llvm.mlir.constant(1 : i64) : i64
%11 = llvm.load %3 : !llvm.ptr -> f32
%12 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
%13 = llvm.insertvalue %11, %12[0] : !llvm.struct<(f32, f32)>
%14 = llvm.insertvalue %7, %13[1] : !llvm.struct<(f32, f32)>
omp.atomic.write %5 = %14 : !llvm.ptr, !llvm.struct<(f32, f32)>
%15 = llvm.mlir.undef : !llvm.struct<(f32, f32)>
%16 = llvm.insertvalue %6, %15[0] : !llvm.struct<(f32, f32)>
%17 = llvm.insertvalue %6, %16[1] : !llvm.struct<(f32, f32)>
omp.atomic.write %5 = %17 : !llvm.ptr, !llvm.struct<(f32, f32)>
llvm.return
}
// -----
//CHECK: %[[ATOMIC_TEMP_LOAD:.*]] = alloca { float, float }, align 8
//CHECK: %[[X_NEW_VAL:.*]] = alloca { float, float }, align 8
//CHECK: {{.*}} = alloca { float, float }, i64 1, align 8