[MLIR] Introduce a SelectLikeOpInterface (#104751)
This commit introduces a `SelectLikeOpInterface` that can be used to handle select-like operations generically. Select operations are similar to control flow operations, as they forward operands depending on conditions. This is the reason why it was placed to the already existing control flow interfaces.
This commit is contained in:
parent
f9031f00f2
commit
bf68e9047f
@ -88,9 +88,9 @@ WalkContinuation walkSlice(mlir::ValueRange rootValues,
|
||||
WalkCallback walkCallback);
|
||||
|
||||
/// Computes a vector of all control predecessors of `value`. Relies on
|
||||
/// RegionBranchOpInterface and BranchOpInterface to determine predecessors.
|
||||
/// Returns nullopt if `value` has no predecessors or when the relevant
|
||||
/// operations are missing the interface implementations.
|
||||
/// RegionBranchOpInterface, BranchOpInterface, and SelectLikeOpInterface to
|
||||
/// determine predecessors. Returns nullopt if `value` has no predecessors or
|
||||
/// when the relevant operations are missing the interface implementations.
|
||||
std::optional<SmallVector<Value>> getControlFlowPredecessors(Value value);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/InferIntRangeInterface.h"
|
||||
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
@ -12,6 +12,7 @@
|
||||
include "mlir/Dialect/Arith/IR/ArithBase.td"
|
||||
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/InferIntRangeInterface.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
|
||||
AllTypesMatch<["true_value", "false_value", "result"]>,
|
||||
BooleanConditionOrMatchingShape<"condition", "result">,
|
||||
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
|
||||
DeclareOpInterfaceMethods<SelectLikeOpInterface>,
|
||||
] # ElementwiseMappable.traits> {
|
||||
let summary = "select operation";
|
||||
let description = [{
|
||||
|
||||
@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
|
||||
def LLVM_SelectOp
|
||||
: LLVM_Op<"select",
|
||||
[Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
|
||||
DeclareOpInterfaceMethods<FastmathFlagsInterface>]>,
|
||||
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
|
||||
DeclareOpInterfaceMethods<SelectLikeOpInterface>]>,
|
||||
LLVM_Builder<
|
||||
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
|
||||
let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,
|
||||
|
||||
@ -760,7 +760,8 @@ def SPIRV_SLessThanEqualOp : SPIRV_LogicalBinaryOp<"SLessThanEqual",
|
||||
def SPIRV_SelectOp : SPIRV_Op<"Select",
|
||||
[Pure,
|
||||
AllTypesMatch<["true_value", "false_value", "result"]>,
|
||||
UsableInSpecConstantOp]> {
|
||||
UsableInSpecConstantOp,
|
||||
DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
|
||||
let summary = [{
|
||||
Select between two objects. Before version 1.4, results are only
|
||||
computed per component.
|
||||
|
||||
@ -343,6 +343,38 @@ def RegionBranchTerminatorOpInterface :
|
||||
}];
|
||||
}
|
||||
|
||||
def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
|
||||
let description = [{
|
||||
This interface provides information for select-like operations, i.e.,
|
||||
operations that forward specific operands to the output, depending on a
|
||||
binary condition.
|
||||
|
||||
If the value of the condition is 1, then the `true` operand is returned,
|
||||
and the third operand is ignored, even if it was poison.
|
||||
|
||||
If the value of the condition is 0, then the `false` operand is returned,
|
||||
and the second operand is ignored, even if it was poison.
|
||||
|
||||
If the condition is poison, then poison is returned.
|
||||
|
||||
Implementing operations can also accept shaped conditions, in which case
|
||||
the operation works element-wise.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<[{
|
||||
Returns the operand that would be chosen for a false condition.
|
||||
}], "::mlir::Value", "getFalseValue", (ins)>,
|
||||
InterfaceMethod<[{
|
||||
Returns the operand that would be chosen for a true condition.
|
||||
}], "::mlir::Value", "getTrueValue", (ins)>,
|
||||
InterfaceMethod<[{
|
||||
Returns the condition operand.
|
||||
}], "::mlir::Value", "getCondition", (ins)>
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ControlFlow Traits
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) {
|
||||
|
||||
std::optional<SmallVector<Value>>
|
||||
mlir::getControlFlowPredecessors(Value value) {
|
||||
SmallVector<Value> result;
|
||||
if (OpResult opResult = dyn_cast<OpResult>(value)) {
|
||||
auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
|
||||
if (auto selectOp = opResult.getDefiningOp<SelectLikeOpInterface>())
|
||||
return SmallVector<Value>(
|
||||
{selectOp.getTrueValue(), selectOp.getFalseValue()});
|
||||
auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
|
||||
// If the interface is not implemented, there are no control flow
|
||||
// predecessors to work with.
|
||||
if (!regionOp)
|
||||
|
||||
@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) {
|
||||
if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
|
||||
return WalkContinuation::advanceTo(addrCast.getOperand());
|
||||
|
||||
// TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
|
||||
if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
|
||||
return WalkContinuation::advanceTo(
|
||||
{selectOp.getTrueValue(), selectOp.getFalseValue()});
|
||||
|
||||
// Attempt to advance to control flow predecessors.
|
||||
std::optional<SmallVector<Value>> controlFlowPredecessors =
|
||||
getControlFlowPredecessors(val);
|
||||
|
||||
@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) {
|
||||
llvm.call @region(%arg0) : (!llvm.ptr) -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
|
||||
// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>
|
||||
|
||||
llvm.func @foo(%arg: i32)
|
||||
|
||||
llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
|
||||
%cond = llvm.load %arg1 : !llvm.ptr -> i1
|
||||
%1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32
|
||||
%selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr
|
||||
%2 = llvm.load %selected : !llvm.ptr -> i32
|
||||
llvm.call @foo(%2) : (i32) -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @selects
|
||||
// CHECK: llvm.load
|
||||
// CHECK-NOT: alias_scopes
|
||||
// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]]
|
||||
// CHECK: llvm.load
|
||||
// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]]
|
||||
llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
|
||||
llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @foo(%arg: i32)
|
||||
|
||||
llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
|
||||
%selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr
|
||||
%2 = llvm.load %selected : !llvm.ptr -> i32
|
||||
llvm.call @foo(%2) : (i32) -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: llvm.func @multi_ptr_select
|
||||
// CHECK: llvm.load
|
||||
// CHECK-NOT: alias_scopes
|
||||
// CHECK-NOT: noalias_scopes
|
||||
// CHECK: llvm.call @foo
|
||||
llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
|
||||
llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user