[mlir][LLVM] Add disjoint flag (#115855)

The implementation is mostly based on the one existing for the exact
flag.

disjoint means that for each bit, that bit is zero in at least one of
the inputs. This allows the Or to be treated as an Add since no carry
can occur from any bit. If the disjoint keyword is present, the result
value of the or is a [poison
value](https://llvm.org/docs/LangRef.html#poisonvalues) if both inputs
have a one in the same bit position. For vectors, only the element
containing the bit is poison.
This commit is contained in:
lfrenot 2024-11-15 12:48:01 +00:00 committed by GitHub
parent 6d058317e6
commit 40afff7bd9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 94 additions and 1 deletions

View File

@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
];
}
def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
let description = [{
This interface defines an LLVM operation with a disjoint flag and
provides a uniform API for accessing it.
}];
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<[{
Get the disjoint flag for the operation.
}], "bool", "getIsDisjoint", (ins), [{}], [{
return $_op.getProperties().isDisjoint;
}]>,
InterfaceMethod<[{
Set the disjoint flag for the operation.
}], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
$_op.getProperties().isDisjoint = isDisjoint;
}]>,
StaticInterfaceMethod<[{
Get the attribute name of the isDisjoint property.
}], "StringRef", "getIsDisjointName", (ins), [{}], [{
return "isDisjoint";
}]>,
];
}
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
let description = [{
This interface defines an LLVM operation with an nneg flag and

View File

@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
}
class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
moduleImport.setDisjointFlag(inst, op);
$res = op;
}];
let assemblyFormat = [{
(`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
}];
string llvmBuilder = [{
auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/"");
moduleTranslation.setDisjointFlag(op, inst);
$res = inst;
}];
}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
let hasFolder = 1;
}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;

View File

@ -192,6 +192,11 @@ public:
/// implement the exact flag interface.
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
/// Sets the disjoint flag attribute for the imported operation `op`
/// given the original instruction `inst`. Asserts if the operation does
/// not implement the disjoint flag interface.
void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
/// Sets the nneg flag attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the nneg flag interface.

View File

@ -167,6 +167,12 @@ public:
/// attribute.
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
/// Sets the disjoint flag attribute for the exported instruction `value`
/// given the original operation `op`. Asserts if the operation does
/// not implement the disjoint flag interface, and asserts if the value
/// is an instruction that implements the disjoint flag.
void setDisjointFlag(Operation *op, llvm::Value *value);
/// Converts the type from MLIR LLVM dialect to LLVM.
llvm::Type *convertType(Type type);

View File

@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
iface.setIsExact(inst->isExact());
}
void ModuleImport::setDisjointFlag(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<DisjointFlagInterface>(op);
auto instDisjoint = cast<llvm::PossiblyDisjointInst>(inst);
iface.setIsDisjoint(instDisjoint->isDisjoint());
}
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
auto iface = cast<NonNegFlagInterface>(op);

View File

@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
}
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) {
auto iface = cast<DisjointFlagInterface>(op);
// We do a dyn_cast here in case the value got folded into a constant.
if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(value))
disjointInst->setIsDisjoint(iface.getIsDisjoint());
}
llvm::Type *ModuleTranslation::convertType(Type type) {
return typeTranslator.translateType(type);
}

View File

@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
// Integer disjoint flag.
// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
%or_flag = llvm.or disjoint %arg0, %arg0 : i32
// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32

View File

@ -0,0 +1,8 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
; CHECK-LABEL: @disjointflag_inst
define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
%1 = or disjoint i64 %arg1, %arg2
ret void
}

View File

@ -0,0 +1,8 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
// CHECK-LABEL: define void @disjointflag_func
llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
%0 = llvm.or disjoint %arg0, %arg1 : i64
llvm.return
}