[MLIR] Introduce a SelectLikeOpInterface (#104751)

This commit introduces a `SelectLikeOpInterface` that can be used to
handle select-like operations generically. Select operations are similar
to control flow operations, as they forward operands depending on
conditions. This is the reason why it was placed to the already existing
control flow interfaces.
This commit is contained in:
Christian Ulmann 2024-08-20 07:32:12 +02:00 committed by GitHub
parent f9031f00f2
commit bf68e9047f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 94 additions and 12 deletions

View File

@ -88,9 +88,9 @@ WalkContinuation walkSlice(mlir::ValueRange rootValues,
WalkCallback walkCallback);
/// Computes a vector of all control predecessors of `value`. Relies on
/// RegionBranchOpInterface and BranchOpInterface to determine predecessors.
/// Returns nullopt if `value` has no predecessors or when the relevant
/// operations are missing the interface implementations.
/// RegionBranchOpInterface, BranchOpInterface, and SelectLikeOpInterface to
/// determine predecessors. Returns nullopt if `value` has no predecessors or
/// when the relevant operations are missing the interface implementations.
std::optional<SmallVector<Value>> getControlFlowPredecessors(Value value);
} // namespace mlir

View File

@ -14,6 +14,7 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

View File

@ -12,6 +12,7 @@
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Interfaces/CastInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
BooleanConditionOrMatchingShape<"condition", "result">,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
DeclareOpInterfaceMethods<SelectLikeOpInterface>,
] # ElementwiseMappable.traits> {
let summary = "select operation";
let description = [{

View File

@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
def LLVM_SelectOp
: LLVM_Op<"select",
[Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
DeclareOpInterfaceMethods<FastmathFlagsInterface>]>,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<SelectLikeOpInterface>]>,
LLVM_Builder<
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,

View File

@ -760,7 +760,8 @@ def SPIRV_SLessThanEqualOp : SPIRV_LogicalBinaryOp<"SLessThanEqual",
def SPIRV_SelectOp : SPIRV_Op<"Select",
[Pure,
AllTypesMatch<["true_value", "false_value", "result"]>,
UsableInSpecConstantOp]> {
UsableInSpecConstantOp,
DeclareOpInterfaceMethods<SelectLikeOpInterface>]> {
let summary = [{
Select between two objects. Before version 1.4, results are only
computed per component.

View File

@ -343,6 +343,38 @@ def RegionBranchTerminatorOpInterface :
}];
}
def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
let description = [{
This interface provides information for select-like operations, i.e.,
operations that forward specific operands to the output, depending on a
binary condition.
If the value of the condition is 1, then the `true` operand is returned,
and the third operand is ignored, even if it was poison.
If the value of the condition is 0, then the `false` operand is returned,
and the second operand is ignored, even if it was poison.
If the condition is poison, then poison is returned.
Implementing operations can also accept shaped conditions, in which case
the operation works element-wise.
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<[{
Returns the operand that would be chosen for a false condition.
}], "::mlir::Value", "getFalseValue", (ins)>,
InterfaceMethod<[{
Returns the operand that would be chosen for a true condition.
}], "::mlir::Value", "getTrueValue", (ins)>,
InterfaceMethod<[{
Returns the condition operand.
}], "::mlir::Value", "getCondition", (ins)>
];
}
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//

View File

@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) {
std::optional<SmallVector<Value>>
mlir::getControlFlowPredecessors(Value value) {
SmallVector<Value> result;
if (OpResult opResult = dyn_cast<OpResult>(value)) {
auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
if (auto selectOp = opResult.getDefiningOp<SelectLikeOpInterface>())
return SmallVector<Value>(
{selectOp.getTrueValue(), selectOp.getFalseValue()});
auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
// If the interface is not implemented, there are no control flow
// predecessors to work with.
if (!regionOp)

View File

@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) {
if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
return WalkContinuation::advanceTo(addrCast.getOperand());
// TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
return WalkContinuation::advanceTo(
{selectOp.getTrueValue(), selectOp.getFalseValue()});
// Attempt to advance to control flow predecessors.
std::optional<SmallVector<Value>> controlFlowPredecessors =
getControlFlowPredecessors(val);

View File

@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) {
llvm.call @region(%arg0) : (!llvm.ptr) -> ()
llvm.return
}
// -----
// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>
llvm.func @foo(%arg: i32)
llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
%cond = llvm.load %arg1 : !llvm.ptr -> i1
%1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32
%selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr
%2 = llvm.load %selected : !llvm.ptr -> i32
llvm.call @foo(%2) : (i32) -> ()
llvm.return
}
// CHECK-LABEL: llvm.func @selects
// CHECK: llvm.load
// CHECK-NOT: alias_scopes
// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]]
// CHECK: llvm.load
// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]]
llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
llvm.return
}
// -----
llvm.func @foo(%arg: i32)
llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
%selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr
%2 = llvm.load %selected : !llvm.ptr -> i32
llvm.call @foo(%2) : (i32) -> ()
llvm.return
}
// CHECK-LABEL: llvm.func @multi_ptr_select
// CHECK: llvm.load
// CHECK-NOT: alias_scopes
// CHECK-NOT: noalias_scopes
// CHECK: llvm.call @foo
llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> ()
llvm.return
}