[mlir][LLVM] Add disjoint flag (#115855)
The implementation is mostly based on the one existing for the exact flag. disjoint means that for each bit, that bit is zero in at least one of the inputs. This allows the Or to be treated as an Add since no carry can occur from any bit. If the disjoint keyword is present, the result value of the or is a [poison value](https://llvm.org/docs/LangRef.html#poisonvalues) if both inputs have a one in the same bit position. For vectors, only the element containing the bit is poison.
This commit is contained in:
parent
6d058317e6
commit
40afff7bd9
@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
|
|||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
|
||||||
|
let description = [{
|
||||||
|
This interface defines an LLVM operation with a disjoint flag and
|
||||||
|
provides a uniform API for accessing it.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let cppNamespace = "::mlir::LLVM";
|
||||||
|
|
||||||
|
let methods = [
|
||||||
|
InterfaceMethod<[{
|
||||||
|
Get the disjoint flag for the operation.
|
||||||
|
}], "bool", "getIsDisjoint", (ins), [{}], [{
|
||||||
|
return $_op.getProperties().isDisjoint;
|
||||||
|
}]>,
|
||||||
|
InterfaceMethod<[{
|
||||||
|
Set the disjoint flag for the operation.
|
||||||
|
}], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
|
||||||
|
$_op.getProperties().isDisjoint = isDisjoint;
|
||||||
|
}]>,
|
||||||
|
StaticInterfaceMethod<[{
|
||||||
|
Get the attribute name of the isDisjoint property.
|
||||||
|
}], "StringRef", "getIsDisjointName", (ins), [{}], [{
|
||||||
|
return "isDisjoint";
|
||||||
|
}]>,
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|
||||||
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
|
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
|
||||||
let description = [{
|
let description = [{
|
||||||
This interface defines an LLVM operation with an nneg flag and
|
This interface defines an LLVM operation with an nneg flag and
|
||||||
|
@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
|
|||||||
"$res = builder.Create" # instName #
|
"$res = builder.Create" # instName #
|
||||||
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
|
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
|
||||||
}
|
}
|
||||||
|
class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
|
||||||
|
list<Trait> traits = []> :
|
||||||
|
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
|
||||||
|
!listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
|
||||||
|
let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
|
||||||
|
|
||||||
|
string mlirBuilder = [{
|
||||||
|
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
|
||||||
|
moduleImport.setDisjointFlag(inst, op);
|
||||||
|
$res = op;
|
||||||
|
}];
|
||||||
|
let assemblyFormat = [{
|
||||||
|
(`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
|
||||||
|
}];
|
||||||
|
string llvmBuilder = [{
|
||||||
|
auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/"");
|
||||||
|
moduleTranslation.setDisjointFlag(op, inst);
|
||||||
|
$res = inst;
|
||||||
|
}];
|
||||||
|
}
|
||||||
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
|
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
|
||||||
list<Trait> traits = []> :
|
list<Trait> traits = []> :
|
||||||
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
|
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
|
||||||
@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
|
|||||||
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
|
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
|
||||||
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
|
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
|
||||||
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
|
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
|
||||||
def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
|
def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
|
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
|
||||||
|
@ -192,6 +192,11 @@ public:
|
|||||||
/// implement the exact flag interface.
|
/// implement the exact flag interface.
|
||||||
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
|
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
|
||||||
|
|
||||||
|
/// Sets the disjoint flag attribute for the imported operation `op`
|
||||||
|
/// given the original instruction `inst`. Asserts if the operation does
|
||||||
|
/// not implement the disjoint flag interface.
|
||||||
|
void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
|
||||||
|
|
||||||
/// Sets the nneg flag attribute for the imported operation `op` given
|
/// Sets the nneg flag attribute for the imported operation `op` given
|
||||||
/// the original instruction `inst`. Asserts if the operation does not
|
/// the original instruction `inst`. Asserts if the operation does not
|
||||||
/// implement the nneg flag interface.
|
/// implement the nneg flag interface.
|
||||||
|
@ -167,6 +167,12 @@ public:
|
|||||||
/// attribute.
|
/// attribute.
|
||||||
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
|
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
|
||||||
|
|
||||||
|
/// Sets the disjoint flag attribute for the exported instruction `value`
|
||||||
|
/// given the original operation `op`. Asserts if the operation does
|
||||||
|
/// not implement the disjoint flag interface, and asserts if the value
|
||||||
|
/// is an instruction that implements the disjoint flag.
|
||||||
|
void setDisjointFlag(Operation *op, llvm::Value *value);
|
||||||
|
|
||||||
/// Converts the type from MLIR LLVM dialect to LLVM.
|
/// Converts the type from MLIR LLVM dialect to LLVM.
|
||||||
llvm::Type *convertType(Type type);
|
llvm::Type *convertType(Type type);
|
||||||
|
|
||||||
|
@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
|
|||||||
iface.setIsExact(inst->isExact());
|
iface.setIsExact(inst->isExact());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ModuleImport::setDisjointFlag(llvm::Instruction *inst,
|
||||||
|
Operation *op) const {
|
||||||
|
auto iface = cast<DisjointFlagInterface>(op);
|
||||||
|
auto instDisjoint = cast<llvm::PossiblyDisjointInst>(inst);
|
||||||
|
|
||||||
|
iface.setIsDisjoint(instDisjoint->isDisjoint());
|
||||||
|
}
|
||||||
|
|
||||||
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
|
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
|
||||||
auto iface = cast<NonNegFlagInterface>(op);
|
auto iface = cast<NonNegFlagInterface>(op);
|
||||||
|
|
||||||
|
@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
|
|||||||
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
|
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) {
|
||||||
|
auto iface = cast<DisjointFlagInterface>(op);
|
||||||
|
// We do a dyn_cast here in case the value got folded into a constant.
|
||||||
|
if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(value))
|
||||||
|
disjointInst->setIsDisjoint(iface.getIsDisjoint());
|
||||||
|
}
|
||||||
|
|
||||||
llvm::Type *ModuleTranslation::convertType(Type type) {
|
llvm::Type *ModuleTranslation::convertType(Type type) {
|
||||||
return typeTranslator.translateType(type);
|
return typeTranslator.translateType(type);
|
||||||
}
|
}
|
||||||
|
@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
|
|||||||
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
|
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
|
||||||
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
|
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
|
||||||
|
|
||||||
|
// Integer disjoint flag.
|
||||||
|
// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
|
||||||
|
%or_flag = llvm.or disjoint %arg0, %arg0 : i32
|
||||||
|
|
||||||
// Floating point binary operations.
|
// Floating point binary operations.
|
||||||
//
|
//
|
||||||
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
|
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
|
||||||
|
8
mlir/test/Target/LLVMIR/Import/disjoint.ll
Normal file
8
mlir/test/Target/LLVMIR/Import/disjoint.ll
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
|
||||||
|
|
||||||
|
; CHECK-LABEL: @disjointflag_inst
|
||||||
|
define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
|
||||||
|
; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
|
||||||
|
%1 = or disjoint i64 %arg1, %arg2
|
||||||
|
ret void
|
||||||
|
}
|
8
mlir/test/Target/LLVMIR/disjoint.mlir
Normal file
8
mlir/test/Target/LLVMIR/disjoint.mlir
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: define void @disjointflag_func
|
||||||
|
llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
|
||||||
|
// CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
|
||||||
|
%0 = llvm.or disjoint %arg0, %arg1 : i64
|
||||||
|
llvm.return
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user