[mlir][llvm] Add align attribute to llvm.intr.masked.{expandload,compressstore} (#153063)

* Add `requiresArgAndResultAttrs` to `LLVM_OneResultIntrOp`
* Add `arg_attrs` to `llvm.intr.masked.{expandload,compressstore}`

The LLVM intrinsics
[`llvm.masked.expandload`](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics)
and
[`llvm.masked.compressstore`](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics)
(modeled in MLIR by `llvm.intr.masked.expandload` and `llvm.intr.masked.compressstore`)
both accept an optional `align` parameter attribute on their pointer argument, which defaults to 1.

Quoting the LangRef documentation for
[`llvm.masked.expandload`'s](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics) and
[`llvm.masked.compressstore`'s](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics)
arguments, respectively:

> The `align` parameter attribute can be provided for the first
> argument. The pointer alignment defaults to 1.

> The `align` parameter attribute can be provided for the second
> argument. The pointer alignment defaults to 1.
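
For illustration, a minimal MLIR sketch of how the alignment surfaces on the two ops (the function name is made up; the `arg_attrs` layout mirrors the tests added in this commit):

```mlir
llvm.func @expand_compress_aligned(%ptr: !llvm.ptr, %mask: vector<7xi1>,
                                   %passthru: vector<7xf32>) {
  // The pointer is expandload's first operand, so the align entry comes first.
  %v = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru)
         {arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}
       : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> vector<7xf32>
  // The pointer is compressstore's second operand, so the align entry comes second.
  "llvm.intr.masked.compressstore"(%v, %ptr, %mask)
         {arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}
       : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
  llvm.return
}
```

On the C++ side, the new `OpBuilder` overloads below take the alignment as a trailing `uint64_t` defaulting to 1 and only materialize `arg_attrs` when the alignment is non-trivial.
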
Erick Ochoa Lopez, 2025-08-15 08:34:14 -04:00, committed by GitHub
commit 61caab7789 (parent 69453d7021)
5 changed files with 116 additions and 14 deletions


@@ -87,21 +87,21 @@ class LLVM_TernarySameArgsIntrOpF<string func, list<Trait> traits = []> :
 class LLVM_CountZerosIntrOp<string func, list<Trait> traits = []> :
     LLVM_OneResultIntrOp<func, [], [0],
         !listconcat([Pure, SameOperandsAndResultType], traits),
-        /*requiresFastmath=*/0,
+        /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
         /*immArgPositions=*/[1], /*immArgAttrNames=*/["is_zero_poison"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
                    I1Attr:$is_zero_poison);
 }

 def LLVM_AbsOp : LLVM_OneResultIntrOp<"abs", [], [0], [Pure],
-    /*requiresFastmath=*/0,
+    /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
     /*immArgPositions=*/[1], /*immArgAttrNames=*/["is_int_min_poison"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<AnySignlessInteger>:$in,
                    I1Attr:$is_int_min_poison);
 }

 def LLVM_IsFPClass : LLVM_OneResultIntrOp<"is.fpclass", [], [0], [Pure],
-    /*requiresFastmath=*/0,
+    /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
     /*immArgPositions=*/[1], /*immArgAttrNames=*/["bit"]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$in, I32Attr:$bit);
 }
@@ -360,8 +360,8 @@ def LLVM_LifetimeEndOp : LLVM_LifetimeBaseOp<"lifetime.end">;
 def LLVM_InvariantStartOp : LLVM_OneResultIntrOp<"invariant.start", [], [1],
     [DeclareOpInterfaceMethods<PromotableOpInterface>],
-    /*requiresFastmath=*/0, /*immArgPositions=*/[0],
-    /*immArgAttrNames=*/["size"]> {
+    /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
+    /*immArgPositions=*/[0], /*immArgAttrNames=*/["size"]> {
   let arguments = (ins I64Attr:$size, LLVM_AnyPointer:$ptr);
   let results = (outs LLVM_DefaultPointer:$res);
   let assemblyFormat = "$size `,` $ptr attr-dict `:` qualified(type($ptr))";
@@ -412,6 +412,7 @@ class LLVM_ConstrainedIntr<string mnem, int numArgs,
         !gt(hasRoundingMode, 0) : [DeclareOpInterfaceMethods<RoundingModeOpInterface>],
         true : []),
       /*requiresFastmath=*/0,
+      /*requiresArgAndResultAttrs=*/0,
       /*immArgPositions=*/[],
       /*immArgAttrNames=*/[]> {
   dag regularArgs = !dag(ins, !listsplat(LLVM_Type, numArgs), !foreach(i, !range(numArgs), "arg_" #i));
@@ -589,7 +590,7 @@ def LLVM_ExpectOp
 def LLVM_ExpectWithProbabilityOp
     : LLVM_OneResultIntrOp<"expect.with.probability", [], [0],
         [Pure, AllTypesMatch<["val", "expected", "res"]>],
-        /*requiresFastmath=*/0,
+        /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
         /*immArgPositions=*/[2], /*immArgAttrNames=*/["prob"]> {
   let arguments = (ins AnySignlessInteger:$val,
                        AnySignlessInteger:$expected,
@@ -825,7 +826,7 @@ class LLVM_VecReductionAccBase<string mnem, Type element>
                  /*overloadedResults=*/[],
                  /*overloadedOperands=*/[1],
                  /*traits=*/[Pure, SameOperandsAndResultElementType],
-                 /*equiresFastmath=*/1>,
+                 /*requiresFastmath=*/1>,
       Arguments<(ins element:$start_value,
                      LLVM_VectorOf<element>:$input,
                      DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags)>;
@@ -1069,14 +1070,36 @@ def LLVM_masked_scatter : LLVM_ZeroResultIntrOp<"masked.scatter"> {
 }

 /// Create a call to Masked Expand Load intrinsic.
-def LLVM_masked_expandload : LLVM_IntrOp<"masked.expandload", [0], [], [], 1> {
-  let arguments = (ins LLVM_AnyPointer, LLVM_VectorOf<I1>, LLVM_AnyVector);
+def LLVM_masked_expandload
+    : LLVM_OneResultIntrOp<"masked.expandload", [0], [],
+          /*traits=*/[], /*requiresFastMath=*/0, /*requiresArgAndResultAttrs=*/1,
+          /*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
+  dag args = (ins LLVM_AnyPointer:$ptr,
+                  LLVM_VectorOf<I1>:$mask,
+                  LLVM_AnyVector:$passthru);
+  let arguments = !con(args, baseArgs);
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resTy, "Value":$ptr, "Value":$mask, "Value":$passthru, CArg<"uint64_t", "1">:$align)>
+  ];
 }

 /// Create a call to Masked Compress Store intrinsic.
 def LLVM_masked_compressstore
-    : LLVM_IntrOp<"masked.compressstore", [], [0], [], 0> {
-  let arguments = (ins LLVM_AnyVector, LLVM_AnyPointer, LLVM_VectorOf<I1>);
+    : LLVM_ZeroResultIntrOp<"masked.compressstore", [0],
+          /*traits=*/[], /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
+          /*requiresArgAndResultAttrs=*/1, /*requiresOpBundles=*/0,
+          /*immArgPositions=*/[], /*immArgAttrNames=*/[]> {
+  dag args = (ins LLVM_AnyVector:$value,
+                  LLVM_AnyPointer:$ptr,
+                  LLVM_VectorOf<I1>:$mask);
+  let arguments = !con(args, baseArgs);
+  let builders = [
+    OpBuilder<(ins "Value":$value, "Value":$ptr, "Value":$mask, CArg<"uint64_t", "1">:$align)>
+  ];
 }

 //
@@ -1155,7 +1178,7 @@ def LLVM_vector_insert
        PredOpTrait<"it is not inserting scalable into fixed-length vectors.",
                    CPred<"!isScalableVectorType($srcvec.getType()) || "
                          "isScalableVectorType($dstvec.getType())">>],
-       /*requiresFastmath=*/0,
+       /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
        /*immArgPositions=*/[2], /*immArgAttrNames=*/["pos"]> {
   let arguments = (ins LLVM_AnyVector:$dstvec, LLVM_AnyVector:$srcvec,
                    I64Attr:$pos);
@@ -1189,7 +1212,7 @@ def LLVM_vector_extract
        PredOpTrait<"it is not extracting scalable from fixed-length vectors.",
                    CPred<"!isScalableVectorType($res.getType()) || "
                          "isScalableVectorType($srcvec.getType())">>],
-       /*requiresFastmath=*/0,
+       /*requiresFastmath=*/0, /*requiresArgAndResultAttrs=*/0,
        /*immArgPositions=*/[1], /*immArgAttrNames=*/["pos"]> {
   let arguments = (ins LLVM_AnyVector:$srcvec, I64Attr:$pos);
   let results = (outs LLVM_AnyVector:$res);


@@ -475,11 +475,12 @@ class LLVM_OneResultIntrOp<string mnem, list<int> overloadedResults = [],
                            list<int> overloadedOperands = [],
                            list<Trait> traits = [],
                            bit requiresFastmath = 0,
+                           bit requiresArgAndResultAttrs = 0,
                            list<int> immArgPositions = [],
                            list<string> immArgAttrNames = []>
     : LLVM_IntrOp<mnem, overloadedResults, overloadedOperands, traits, 1,
                   /*requiresAccessGroup=*/0, /*requiresAliasAnalysis=*/0,
-                  requiresFastmath, /*requiresArgAndResultAttrs=*/0,
+                  requiresFastmath, requiresArgAndResultAttrs,
                   /*requiresOpBundles=*/0, immArgPositions,
                   immArgAttrNames>;


@@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
   return success();
 }

+static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder,
+                                                    bool isExpandLoad,
+                                                    uint64_t alignment = 1) {
+  // From
+  // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+  // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+  //
+  // The pointer alignment defaults to 1.
+  if (alignment == 1) {
+    return nullptr;
+  }
+
+  auto emptyDictAttr = builder.getDictionaryAttr({});
+  auto alignmentAttr = builder.getI64IntegerAttr(alignment);
+  auto namedAttr =
+      builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
+  SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
+  auto alignDictAttr = builder.getDictionaryAttr(attrs);
+  // From
+  // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+  // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+  //
+  // The align parameter attribute can be provided for [expandload]'s first
+  // argument. The align parameter attribute can be provided for
+  // [compressstore]'s second argument.
+  int pos = isExpandLoad ? 0 : 1;
+  return pos == 0 ? builder.getArrayAttr(
+                        {alignDictAttr, emptyDictAttr, emptyDictAttr})
+                  : builder.getArrayAttr(
+                        {emptyDictAttr, alignDictAttr, emptyDictAttr});
+}
+
 //===----------------------------------------------------------------------===//
 // Operand bundle helpers.
 //===----------------------------------------------------------------------===//
@@ -4116,6 +4148,32 @@ LogicalResult LLVM::masked_scatter::verify() {
   return success();
 }

+//===----------------------------------------------------------------------===//
+// masked_expandload (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
+                                    mlir::TypeRange resTys, Value ptr,
+                                    Value mask, Value passthru,
+                                    uint64_t align) {
+  ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align);
+  build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs,
+        /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// masked_compressstore (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_compressstore::build(OpBuilder &builder,
+                                       OperationState &state, Value value,
+                                       Value ptr, Value mask, uint64_t align) {
+  ArrayAttr argAttrs =
+      getLLVMAlignParamForCompressExpand(builder, false, align);
+  build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs,
+        /*res_attrs=*/nullptr);
+}
+
 //===----------------------------------------------------------------------===//
 // InlineAsmOp
 //===----------------------------------------------------------------------===//


@@ -545,6 +545,15 @@ define void @masked_expand_compress_intrinsics(ptr %0, <7 x i1> %1, <7 x float>
   ret void
 }

+; CHECK-LABEL: llvm.func @masked_expand_compress_intrinsics_with_alignment
+define void @masked_expand_compress_intrinsics_with_alignment(ptr %0, <7 x i1> %1, <7 x float> %2) {
+  ; CHECK: %[[val1:.+]] = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) <{arg_attrs = [{llvm.align = 8 : i64}, {}, {}]}> : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> vector<7xf32>
+  %4 = call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %0, <7 x i1> %1, <7 x float> %2)
+  ; CHECK: "llvm.intr.masked.compressstore"(%[[val1]], %{{.*}}, %{{.*}}) <{arg_attrs = [{}, {llvm.align = 8 : i64}, {}]}> : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
+  call void @llvm.masked.compressstore.v7f32(<7 x float> %4, ptr align 8 %0, <7 x i1> %1)
+  ret void
+}
+
 ; CHECK-LABEL: llvm.func @annotate_intrinsics
 define void @annotate_intrinsics(ptr %var, ptr %ptr, i16 %int, ptr %annotation, ptr %fileName, i32 %line, ptr %args) {
   ; CHECK: "llvm.intr.var.annotation"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, !llvm.ptr, !llvm.ptr, i32, !llvm.ptr) -> ()


@@ -577,6 +577,17 @@ llvm.func @masked_expand_compress_intrinsics(%ptr: !llvm.ptr, %mask: vector<7xi1
   llvm.return
 }

+// CHECK-LABEL: @masked_expand_compress_intrinsics_with_alignment
+llvm.func @masked_expand_compress_intrinsics_with_alignment(%ptr: !llvm.ptr, %mask: vector<7xi1>, %passthru: vector<7xf32>) {
+  // CHECK: call <7 x float> @llvm.masked.expandload.v7f32(ptr align 8 %{{.*}}, <7 x i1> %{{.*}}, <7 x float> %{{.*}})
+  %0 = "llvm.intr.masked.expandload"(%ptr, %mask, %passthru) {arg_attrs = [{llvm.align = 8 : i32}, {}, {}]}
+    : (!llvm.ptr, vector<7xi1>, vector<7xf32>) -> (vector<7xf32>)
+  // CHECK: call void @llvm.masked.compressstore.v7f32(<7 x float> %{{.*}}, ptr align 8 %{{.*}}, <7 x i1> %{{.*}})
+  "llvm.intr.masked.compressstore"(%0, %ptr, %mask) {arg_attrs = [{}, {llvm.align = 8 : i32}, {}]}
+    : (vector<7xf32>, !llvm.ptr, vector<7xi1>) -> ()
+  llvm.return
+}
+
 // CHECK-LABEL: @annotate_intrinsics
 llvm.func @annotate_intrinsics(%var: !llvm.ptr, %int: i16, %ptr: !llvm.ptr, %annotation: !llvm.ptr, %fileName: !llvm.ptr, %line: i32, %attr: !llvm.ptr) {
   // CHECK: call void @llvm.var.annotation.p0.p0(ptr %{{.*}}, ptr %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, ptr %{{.*}})