diff --git a/mlir/include/mlir/Analysis/SliceWalk.h b/mlir/include/mlir/Analysis/SliceWalk.h index 481c5690c533..eb9ced2ff63b 100644 --- a/mlir/include/mlir/Analysis/SliceWalk.h +++ b/mlir/include/mlir/Analysis/SliceWalk.h @@ -88,9 +88,9 @@ WalkContinuation walkSlice(mlir::ValueRange rootValues, WalkCallback walkCallback); /// Computes a vector of all control predecessors of `value`. Relies on -/// RegionBranchOpInterface and BranchOpInterface to determine predecessors. -/// Returns nullopt if `value` has no predecessors or when the relevant -/// operations are missing the interface implementations. +/// RegionBranchOpInterface, BranchOpInterface, and SelectLikeOpInterface to +/// determine predecessors. Returns nullopt if `value` has no predecessors or +/// when the relevant operations are missing the interface implementations. std::optional> getControlFlowPredecessors(Value value); } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index 00cdb13feb29..77241319851e 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -14,6 +14,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td index 477478a4651c..19a5e13a5d75 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -12,6 +12,7 @@ include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure, AllTypesMatch<["true_value", "false_value", "result"]>, BooleanConditionOrMatchingShape<"condition", "result">, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, ] # ElementwiseMappable.traits> { let summary = "select operation"; let description = [{ diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index 643522d5903f..71f249fa538c 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", def LLVM_SelectOp : LLVM_Op<"select", [Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>, - DeclareOpInterfaceMethods]>, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, LLVM_Builder< "$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> { let arguments = (ins LLVM_ScalarOrVectorOf:$condition, diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td index 61c5a7a6394f..ab535d7b2a30 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td @@ -760,7 +760,8 @@ def SPIRV_SLessThanEqualOp : SPIRV_LogicalBinaryOp<"SLessThanEqual", def SPIRV_SelectOp : SPIRV_Op<"Select", [Pure, AllTypesMatch<["true_value", "false_value", "result"]>, - UsableInSpecConstantOp]> { + UsableInSpecConstantOp, + DeclareOpInterfaceMethods]> { let summary = [{ Select between two objects. Before version 1.4, results are only computed per component. diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 95ac5dea243a..69bce78e946c 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -343,6 +343,38 @@ def RegionBranchTerminatorOpInterface : }]; } +def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> { + let description = [{ + This interface provides information for select-like operations, i.e., + operations that forward specific operands to the output, depending on a + binary condition. + + If the value of the condition is 1, then the `true` operand is returned, + and the third operand is ignored, even if it was poison. + + If the value of the condition is 0, then the `false` operand is returned, + and the second operand is ignored, even if it was poison. + + If the condition is poison, then poison is returned. + + Implementing operations can also accept shaped conditions, in which case + the operation works element-wise. + }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod<[{ + Returns the operand that would be chosen for a false condition. + }], "::mlir::Value", "getFalseValue", (ins)>, + InterfaceMethod<[{ + Returns the operand that would be chosen for a true condition. + }], "::mlir::Value", "getTrueValue", (ins)>, + InterfaceMethod<[{ + Returns the condition operand. + }], "::mlir::Value", "getCondition", (ins)> + ]; +} + //===----------------------------------------------------------------------===// // ControlFlow Traits //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp index 9d770639dc53..817d71a3452c 100644 --- a/mlir/lib/Analysis/SliceWalk.cpp +++ b/mlir/lib/Analysis/SliceWalk.cpp @@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) { std::optional> mlir::getControlFlowPredecessors(Value value) { - SmallVector result; if (OpResult opResult = dyn_cast(value)) { - auto regionOp = dyn_cast(opResult.getOwner()); + if (auto selectOp = opResult.getDefiningOp()) + return SmallVector( + {selectOp.getTrueValue(), selectOp.getFalseValue()}); + auto regionOp = opResult.getDefiningOp(); // If the interface is not implemented, there are no control flow // predecessors to work with. if (!regionOp) diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 1399d419735d..031930dcfc21 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) { if (auto addrCast = val.getDefiningOp()) return WalkContinuation::advanceTo(addrCast.getOperand()); - // TODO: Add a SelectLikeOpInterface and use it in the slicing utility. - if (auto selectOp = val.getDefiningOp()) - return WalkContinuation::advanceTo( - {selectOp.getTrueValue(), selectOp.getFalseValue()}); - // Attempt to advance to control flow predecessors. std::optional> controlFlowPredecessors = getControlFlowPredecessors(val); diff --git a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir index bd5e7aa996ad..6b369c501210 100644 --- a/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir @@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) { llvm.call @region(%arg0) : (!llvm.ptr) -> () llvm.return } + +// ----- + +// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}> +// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope + +llvm.func @foo(%arg: i32) + +llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) { + %cond = llvm.load %arg1 : !llvm.ptr -> i1 + %1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32 + %selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr + %2 = llvm.load %selected : !llvm.ptr -> i32 + llvm.call @foo(%2) : (i32) -> () + llvm.return +} + +// CHECK-LABEL: llvm.func @selects +// CHECK: llvm.load +// CHECK-NOT: alias_scopes +// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]] +// CHECK: llvm.load +// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]] +llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) { + llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> () + llvm.return +} + +// ----- + +llvm.func @foo(%arg: i32) + +llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) { + %selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr + %2 = llvm.load %selected : !llvm.ptr -> i32 + llvm.call @foo(%2) : (i32) -> () + llvm.return +} + +// CHECK-LABEL: llvm.func @multi_ptr_select +// CHECK: llvm.load +// CHECK-NOT: alias_scopes +// CHECK-NOT: noalias_scopes +// CHECK: llvm.call @foo +llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) { + llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> () + llvm.return +}