[Attributor] Take the address space from addrspacecast directly (#108258)

Currently `AAAddressSpace` relies on identifying the address spaces of
all underlying objects. However, it might infer a sub-optimal address
space when the underlying object is a function argument. In
`AMDGPUPromoteKernelArgumentsPass`, a pointer kernel argument is
promoted by adding a series of `addrspacecast` instructions (as shown
below), in the hope that `InferAddressSpacePass` picks them up and does
the rewriting accordingly.

Before promotion:

```
define amdgpu_kernel void @kernel(ptr %to_be_promoted) {
  %val = load i32, ptr %to_be_promoted
  ...
  ret void
}
```

After promotion:

```
define amdgpu_kernel void @kernel(ptr %to_be_promoted) {
  %ptr.cast.0 = addrspacecast ptr %to_be_promoted to ptr addrspace(1)
  %ptr.cast.1 = addrspacecast ptr addrspace(1) %ptr.cast.0 to ptr
  ; all uses of %to_be_promoted will use %ptr.cast.1
  %val = load i32, ptr %ptr.cast.1
  ...
  ret void
}
```

When `AAAddressSpace` analyzes the code after promotion, it will take
`%to_be_promoted` as the underlying object of `%ptr.cast.1`, and use its
address space (which is 0) as its final address space, and thus simply
does nothing in `manifest`. The attributor framework will then eliminate
the address space cast from 0 to 1 and back to 0, and replace
`%ptr.cast.1` with `%to_be_promoted`, which basically reverts all changes
made by `AMDGPUPromoteKernelArgumentsPass`.

IMHO I'm not sure if `AMDGPUPromoteKernelArgumentsPass` promotes the
argument in a proper way. To improve the handling of this case, this PR
adds an extra handling when iterating over all underlying objects. If an
underlying object is a function argument, it means it reaches a terminal
such that we can't deduce its underlying object any further. In this
case, we check all uses of the argument. If they are all `addrspacecast`
instructions and their destination address spaces are the same, we take
the destination address space.

Fixes: SWDEV-482640.
This commit is contained in:
Shilei Tian 2024-10-09 22:51:07 -04:00 committed by GitHub
parent 03229e7c0b
commit 5a74a4a667
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 80 additions and 14 deletions

View File

@ -12583,16 +12583,36 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
}
ChangeStatus updateImpl(Attributor &A) override {
unsigned FlatAS = A.getInfoCache().getFlatAddressSpace().value();
uint32_t OldAddressSpace = AssumedAddressSpace;
auto *AUO = A.getOrCreateAAFor<AAUnderlyingObjects>(getIRPosition(), this,
DepClassTy::REQUIRED);
auto Pred = [&](Value &Obj) {
auto CheckAddressSpace = [&](Value &Obj) {
if (isa<UndefValue>(&Obj))
return true;
// If an argument in flat address space only has addrspace cast uses, and
// those casts are same, then we take the dst addrspace.
if (auto *Arg = dyn_cast<Argument>(&Obj)) {
if (Arg->getType()->getPointerAddressSpace() == FlatAS) {
unsigned CastAddrSpace = FlatAS;
for (auto *U : Arg->users()) {
auto *ASCI = dyn_cast<AddrSpaceCastInst>(U);
if (!ASCI)
return takeAddressSpace(Obj.getType()->getPointerAddressSpace());
if (CastAddrSpace != FlatAS &&
CastAddrSpace != ASCI->getDestAddressSpace())
return false;
CastAddrSpace = ASCI->getDestAddressSpace();
}
if (CastAddrSpace != FlatAS)
return takeAddressSpace(CastAddrSpace);
}
}
return takeAddressSpace(Obj.getType()->getPointerAddressSpace());
};
if (!AUO->forallUnderlyingObjects(Pred))
auto *AUO = A.getOrCreateAAFor<AAUnderlyingObjects>(getIRPosition(), this,
DepClassTy::REQUIRED);
if (!AUO->forallUnderlyingObjects(CheckAddressSpace))
return indicatePessimisticFixpoint();
return OldAddressSpace == AssumedAddressSpace ? ChangeStatus::UNCHANGED
@ -12601,17 +12621,21 @@ struct AAAddressSpaceImpl : public AAAddressSpace {
/// See AbstractAttribute::manifest(...).
ChangeStatus manifest(Attributor &A) override {
if (getAddressSpace() == InvalidAddressSpace ||
getAddressSpace() == getAssociatedType()->getPointerAddressSpace())
unsigned NewAS = getAddressSpace();
if (NewAS == InvalidAddressSpace ||
NewAS == getAssociatedType()->getPointerAddressSpace())
return ChangeStatus::UNCHANGED;
unsigned FlatAS = A.getInfoCache().getFlatAddressSpace().value();
Value *AssociatedValue = &getAssociatedValue();
Value *OriginalValue = peelAddrspacecast(AssociatedValue);
Value *OriginalValue = peelAddrspacecast(AssociatedValue, FlatAS);
PointerType *NewPtrTy =
PointerType::get(getAssociatedType()->getContext(), getAddressSpace());
PointerType::get(getAssociatedType()->getContext(), NewAS);
bool UseOriginalValue =
OriginalValue->getType()->getPointerAddressSpace() == getAddressSpace();
OriginalValue->getType()->getPointerAddressSpace() == NewAS;
bool Changed = false;
@ -12671,12 +12695,19 @@ private:
return AssumedAddressSpace == AS;
}
static Value *peelAddrspacecast(Value *V) {
if (auto *I = dyn_cast<AddrSpaceCastInst>(V))
return peelAddrspacecast(I->getPointerOperand());
static Value *peelAddrspacecast(Value *V, unsigned FlatAS) {
if (auto *I = dyn_cast<AddrSpaceCastInst>(V)) {
assert(I->getSrcAddressSpace() != FlatAS &&
"there should not be flat AS -> non-flat AS");
return I->getPointerOperand();
}
if (auto *C = dyn_cast<ConstantExpr>(V))
if (C->getOpcode() == Instruction::AddrSpaceCast)
return peelAddrspacecast(C->getOperand(0));
if (C->getOpcode() == Instruction::AddrSpaceCast) {
assert(C->getOperand(0)->getType()->getPointerAddressSpace() !=
FlatAS &&
"there should not be flat AS -> non-flat AS X");
return C->getOperand(0);
}
return V;
}
};

View File

@ -243,3 +243,38 @@ define void @foo(ptr addrspace(3) %val) {
ret void
}
; Intra-procedural case: the flat pointer argument %p is cast to addrspace(1)
; and immediately cast back to a flat pointer, and only the round-tripped
; pointer is used by the store. The CHECK lines expect the cast back to flat
; to be gone and the store to use the addrspace(1) pointer directly.
define void @kernel_argument_promotion_pattern_intra_procedure(ptr %p, i32 %val) {
; CHECK-LABEL: define void @kernel_argument_promotion_pattern_intra_procedure(
; CHECK-SAME: ptr [[P:%.*]], i32 [[VAL:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[P_CAST_0:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
; CHECK-NEXT: store i32 [[VAL]], ptr addrspace(1) [[P_CAST_0]], align 4
; CHECK-NEXT: ret void
;
%p.cast.0 = addrspacecast ptr %p to ptr addrspace(1)
%p.cast.1 = addrspacecast ptr addrspace(1) %p.cast.0 to ptr
store i32 %val, ptr %p.cast.1
ret void
}
; Internal callee that stores through its flat pointer argument %p. The input
; body contains no address space cast; the CHECK lines expect one to
; addrspace(1) to be inserted and the store to go through it.
; NOTE(review): presumably addrspace(1) is inferred from the pointer passed at
; the call site in @kernel_argument_promotion_pattern_inter_procedure - confirm.
define internal void @use_argument_after_promotion(ptr %p, i32 %val) {
; CHECK-LABEL: define internal void @use_argument_after_promotion(
; CHECK-SAME: ptr [[P:%.*]], i32 [[VAL:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[TMP1:%.*]] = addrspacecast ptr [[P]] to ptr addrspace(1)
; CHECK-NEXT: store i32 [[VAL]], ptr addrspace(1) [[TMP1]], align 4
; CHECK-NEXT: ret void
;
store i32 %val, ptr %p
ret void
}
; Inter-procedural case: the flat pointer argument %p is cast to addrspace(1)
; and back to flat, and the round-tripped pointer is passed to
; @use_argument_after_promotion. The CHECK lines expect both casts to be gone
; from this function and the call to receive %p directly.
define void @kernel_argument_promotion_pattern_inter_procedure(ptr %p, i32 %val) {
; CHECK-LABEL: define void @kernel_argument_promotion_pattern_inter_procedure(
; CHECK-SAME: ptr [[P:%.*]], i32 [[VAL:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: call void @use_argument_after_promotion(ptr [[P]], i32 [[VAL]])
; CHECK-NEXT: ret void
;
%p.cast.0 = addrspacecast ptr %p to ptr addrspace(1)
%p.cast.1 = addrspacecast ptr addrspace(1) %p.cast.0 to ptr
call void @use_argument_after_promotion(ptr %p.cast.1, i32 %val)
ret void
}