[mlir][AMDGPU] Add PermlaneSwapOp (#154345)

- Add PermlaneSwapOp that lowers to `rocdl.permlane16.swap` and
`rocdl.permlane32.swap`

---------

Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
This commit is contained in:
Tim Gymnich 2025-08-21 18:21:43 +02:00 committed by GitHub
parent fc62990657
commit e20fa4f412
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 281 additions and 1 deletions

View File

@ -656,6 +656,48 @@ def AMDGPU_SwizzleBitModeOp : AMDGPU_Op<"swizzle_bitmode",
}];
}
def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["result", "src"]>]> {
let summary = "AMDGPU permlane swap op";
let description = [{
High-level wrapper on `rocdl.permlane{16,32}.swap` variants for permutations
on rows of lanes in a subgroup.
Supports arbitrary int/float/vector types, which will be repacked to i32 and
one or more `rocdl.permlane_swap` ops during lowering.
Supported lane permutations:
- Swap the data between odd and even rows of 16 lanes
- Swap the data between the first 32 lanes and the last 32 lanes
Example:
```mlir
%0 = amdgpu.permlane %src 16 : f16
%1 = amdgpu.permlane %src 32 { fetch_inactive = true, bound_ctrl = true } : f16
```
Operands:
* `$src`: Vector register to permute across lanes of the subgroup.
* `$row_length`: The length of a row to permute in number of lanes (valid values are 16 and 32).
* `$fetch_inactive`: Optional. Used to dertermine behavior of a fetch from a disabled lane.
`fetch_inactive = false`: If the source lane is disabled, use `bound_ctrl` to determine the source value.
`fetch_inactive = true`: If the source lane is disabled, fetch the source value anyway (ignoring `bound_ctrl`).
* `$bound_ctrl`: Optional. Used to determine what a thread should do if its source operand is from
a disabled lane: use the value zero, or disable the write.
`bound_ctrl = false`: Do not write when source is from a disabled lane
`bound_ctrl = true`: Use zero as input if source is from a disabled lane
Note: Lowering is only supported on gfx950 and up.
}];
let arguments = (ins AnyIntegerOrFloatOr1DVector:$src,
I32Attr:$row_length,
DefaultValuedAttr<BoolAttr, "false">:$fetch_inactive,
DefaultValuedAttr<BoolAttr, "false">:$bound_ctrl);
let results = (outs AnyIntegerOrFloatOr1DVector:$result);
let assemblyFormat = [{
$src $row_length attr-dict `:` type($result)
}];
let hasVerifier = 1;
}
def AMDGPU_LDSBarrierOp : AMDGPU_Op<"lds_barrier"> {
let summary = "Barrier that includes a wait for LDS memory operations.";
let description = [{

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
@ -1876,6 +1877,54 @@ struct AMDGPUSwizzleBitModeLowering
}
};
struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {}
Chipset chipset;
LogicalResult
matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset < kGfx950)
return op->emitOpError("permlane_swap is only supported on gfx950+");
Location loc = op.getLoc();
Type i32 = rewriter.getI32Type();
Value src = adaptor.getSrc();
unsigned row_length = op.getRowLength();
bool fi = op.getFetchInactive();
bool boundctrl = op.getBoundCtrl();
SmallVector<Value> decomposed =
LLVM::decomposeValue(rewriter, loc, src, i32);
SmallVector<Value> permuted;
for (Value v : decomposed) {
Value res;
Type i32pair = LLVM::LLVMStructType::getLiteral(
rewriter.getContext(), {v.getType(), v.getType()});
if (row_length == 16)
res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi,
boundctrl);
else if (row_length == 32)
res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi,
boundctrl);
else
llvm_unreachable("unsupported row length");
Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0});
permuted.emplace_back(vdstNew);
}
Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType());
rewriter.replaceOp(op, result);
return success();
}
};
struct ConvertAMDGPUToROCDLPass
: public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> {
using Base::Base;
@ -1944,6 +1993,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering,
PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering,
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering,
TransposeLoadOpLowering>(converter, chipset);
TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset);
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
}

View File

@ -510,6 +510,18 @@ LogicalResult DPPOp::verify() {
return success();
}
//===----------------------------------------------------------------------===//
// PermlaneSwapOp
//===----------------------------------------------------------------------===//
LogicalResult PermlaneSwapOp::verify() {
unsigned rowLength = getRowLength();
if (rowLength != 16 && rowLength != 32)
return emitOpError("row_length attribute must either be 16 or 32.");
return success();
}
//===----------------------------------------------------------------------===//
// GatherToLDSOp
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,163 @@
// RUN: mlir-opt --convert-amdgpu-to-rocdl=chipset=gfx950 --canonicalize %s | FileCheck %s
// CHECK-LABEL: func @test_permlane16_i32
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane16_i32(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: return %[[RES]] : i32
%0 = amdgpu.permlane_swap %arg0 16 : i32
return %0 : i32
}
// CHECK-LABEL: func @test_permlane16_i32_optional_attr
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane16_i32_optional_attr(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ARG0]], %[[ARG0]], true, true : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: return %[[RES]] : i32
%0 = amdgpu.permlane_swap %arg0 16 { fetch_inactive = true, bound_ctrl = true } : i32
return %0 : i32
}
// CHECK-LABEL: func @test_permlane32_i32
// CHECK-SAME: (%[[ARG0:.*]]: i32)
func.func @test_permlane32_i32(%arg0 : i32) -> i32 {
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ARG0]], %[[ARG0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: return %[[RES]] : i32
%0 = amdgpu.permlane_swap %arg0 32 : i32
return %0 : i32
}
// CHECK-LABEL: func @test_permlane16_f32
// CHECK-SAME: (%[[ARG0:.*]]: f32)
func.func @test_permlane16_f32(%arg0 : f32) -> f32 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.permlane_swap %arg0 16 : f32
return %0 : f32
}
// CHECK-LABEL: func @test_permlane32_f32
// CHECK-SAME: (%[[ARG0:.*]]: f32)
func.func @test_permlane32_f32(%arg0 : f32) -> f32 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f32 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[CAST]], %[[CAST]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[RES]] : i32 to f32
// CHECK: return %[[RES_CAST]] : f32
%0 = amdgpu.permlane_swap %arg0 32 : f32
return %0 : f32
}
// CHECK-LABEL: func @test_permlane16_f16
// CHECK-SAME: (%[[ARG0:.*]]: f16)
func.func @test_permlane16_f16(%arg0 : f16) -> f16 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane16.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.permlane_swap %arg0 16 : f16
return %0 : f16
}
// CHECK-LABEL: func @test_permlane32_f16
// CHECK-SAME: (%[[ARG0:.*]]: f16)
func.func @test_permlane32_f16(%arg0 : f16) -> f16 {
// CHECK: %[[CAST:.*]] = llvm.bitcast %[[ARG0]] : f16 to i16
// CHECK: %[[ZEXT:.*]] = llvm.zext %[[CAST]] : i16 to i32
// CHECK: %[[PERM:.*]] = rocdl.permlane32.swap %[[ZEXT]], %[[ZEXT]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[RES:.*]] = llvm.extractvalue %[[PERM]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[TRUNC:.*]] = llvm.trunc %[[RES]] : i32 to i16
// CHECK: %[[RES_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i16 to f16
// CHECK: return %[[RES_CAST]] : f16
%0 = amdgpu.permlane_swap %arg0 32 : f16
return %0 : f16
}
// CHECK-LABEL: func @test_permlane16_2xi32
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
func.func @test_permlane16_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
%0 = amdgpu.permlane_swap %arg0 16 : vector<2xi32>
return %0 : vector<2xi32>
}
// CHECK-LABEL: func @test_permlane32_2xi32
// CHECK-SAME: (%[[ARG0:.*]]: vector<2xi32>)
func.func @test_permlane32_2xi32(%arg0 : vector<2xi32>) -> vector<2xi32> {
// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[ARG0]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[ARG0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: return %[[VEC_INSERT1]] : vector<2xi32>
%0 = amdgpu.permlane_swap %arg0 32 : vector<2xi32>
return %0 : vector<2xi32>
}
// CHECK-LABEL: func @test_permlane16_4xf16
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
func.func @test_permlane16_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane16.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
// CHECK: return %[[CAST2]] : vector<4xf16>
%0 = amdgpu.permlane_swap %arg0 16 : vector<4xf16>
return %0 : vector<4xf16>
}
// CHECK-LABEL: func @test_permlane32_4xf16
// CHECK-SAME: (%[[ARG0:.*]]: vector<4xf16>)
func.func @test_permlane32_4xf16(%arg0 : vector<4xf16>) -> vector<4xf16> {
// CHECK-DAG: %[[POISON:.*]] = llvm.mlir.poison : vector<2xi32>
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[CAST1:.*]] = llvm.bitcast %[[ARG0]] : vector<4xf16> to vector<2xi32>
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[CAST1]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[CAST1]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[PERM0_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM0]], %[[ELEM0]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM0:.*]] = llvm.extractvalue %[[PERM0_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[PERM1_TUPLE:.*]] = rocdl.permlane32.swap %[[ELEM1]], %[[ELEM1]], false, false : (i32, i32) -> <(i32, i32)>
// CHECK: %[[PERM1:.*]] = llvm.extractvalue %[[PERM1_TUPLE]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[VEC_INSERT0:.*]] = llvm.insertelement %[[PERM0]], %[[POISON]][%[[C0]] : i32] : vector<2xi32>
// CHECK: %[[VEC_INSERT1:.*]] = llvm.insertelement %[[PERM1]], %[[VEC_INSERT0]][%[[C1]] : i32] : vector<2xi32>
// CHECK: %[[CAST2:.*]] = llvm.bitcast %[[VEC_INSERT1]] : vector<2xi32> to vector<4xf16>
// CHECK: return %[[CAST2]] : vector<4xf16>
%0 = amdgpu.permlane_swap %arg0 32 : vector<4xf16>
return %0 : vector<4xf16>
}

View File

@ -524,6 +524,20 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
func.return %0 : f32
}
// CHECK-LABEL: func @permlane16_swap
func.func @permlane16_swap(%arg0 : f32) -> f32 {
// CHECK: amdgpu.permlane_swap
%0 = amdgpu.permlane_swap %arg0 16 : f32
func.return %0 : f32
}
// CHECK-LABEL: func @permlane32_swap
func.func @permlane32_swap(%arg0 : f32) -> f32 {
// CHECK: amdgpu.permlane_swap
%0 = amdgpu.permlane_swap %arg0 32 : f32
func.return %0 : f32
}
// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
// CHECK: amdgpu.scaled_mfma