[mlir][flang] Added Weighted[Region]BranchOpInterface's. (#142079)
The new interfaces provide getters and setters for the weight information about the branches of BranchOpInterface and RegionBranchOpInterface operations. These interfaces are implemented the same way as the LLVM dialect's BranchWeightOpInterface. The plan is to produce this information in Flang, e.g. to mark code that is most likely "cold" as such and allow LLVM to order basic blocks accordingly. An example of such code is the copy loops generated for array repacking - we can mark them as "cold", assuming that the copy will not happen dynamically. If the copy actually happens, the overhead of the copy itself is probably high enough that we may not care about the small additional overhead of jumping to the "cold" code and fetching it.
This commit is contained in:
parent
af65cb68f5
commit
70343c8d44
@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
|
||||
}];
|
||||
}
|
||||
|
||||
def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
|
||||
"getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
|
||||
NoRegionArguments]> {
|
||||
def fir_IfOp
|
||||
: region_Op<
|
||||
"if", [DeclareOpInterfaceMethods<
|
||||
RegionBranchOpInterface, ["getRegionInvocationBounds",
|
||||
"getEntrySuccessorRegions"]>,
|
||||
RecursiveMemoryEffects, NoRegionArguments,
|
||||
WeightedRegionBranchOpInterface]> {
|
||||
let summary = "if-then-else conditional operation";
|
||||
let description = [{
|
||||
Used to conditionally execute operations. This operation is the FIR
|
||||
@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins I1:$condition);
|
||||
let arguments = (ins I1:$condition,
|
||||
OptionalAttr<DenseI32ArrayAttr>:$region_weights);
|
||||
let results = (outs Variadic<AnyType>:$results);
|
||||
|
||||
let regions = (region
|
||||
@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
|
||||
|
||||
void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
|
||||
unsigned resultNum);
|
||||
|
||||
/// Returns the display name string for the region_weights attribute.
|
||||
static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
|
||||
return "weights";
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
|
||||
parser.resolveOperand(cond, i1Type, result.operands))
|
||||
return mlir::failure();
|
||||
|
||||
if (mlir::succeeded(
|
||||
parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
|
||||
if (parser.parseLParen())
|
||||
return mlir::failure();
|
||||
mlir::DenseI32ArrayAttr weights;
|
||||
if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
|
||||
return mlir::failure();
|
||||
if (weights)
|
||||
result.addAttribute(getRegionWeightsAttrName(result.name), weights);
|
||||
if (parser.parseRParen())
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
if (parser.parseOptionalArrowTypeList(result.types))
|
||||
return mlir::failure();
|
||||
|
||||
@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
|
||||
void fir::IfOp::print(mlir::OpAsmPrinter &p) {
|
||||
bool printBlockTerminators = false;
|
||||
p << ' ' << getCondition();
|
||||
if (auto weights = getRegionWeightsAttr()) {
|
||||
p << ' ' << getWeightsAttrAssemblyName() << '(';
|
||||
p.printStrippedAttrOrType(weights);
|
||||
p << ')';
|
||||
}
|
||||
if (!getResults().empty()) {
|
||||
p << " -> (" << getResultTypes() << ')';
|
||||
printBlockTerminators = true;
|
||||
@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
|
||||
p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
|
||||
printBlockTerminators);
|
||||
}
|
||||
p.printOptionalAttrDict((*this)->getAttrs());
|
||||
p.printOptionalAttrDict((*this)->getAttrs(),
|
||||
/*elideAttrs=*/{getRegionWeightsAttrName()});
|
||||
}
|
||||
|
||||
void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
|
||||
|
@ -212,9 +212,12 @@ public:
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(condBlock);
|
||||
rewriter.create<mlir::cf::CondBranchOp>(
|
||||
auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
|
||||
loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
|
||||
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
|
||||
llvm::ArrayRef<int32_t> weights = ifOp.getWeights();
|
||||
if (!weights.empty())
|
||||
branchOp.setWeights(weights);
|
||||
rewriter.replaceOp(ifOp, continueBlock->getArguments());
|
||||
return success();
|
||||
}
|
||||
|
46
flang/test/Fir/cfg-conversion-if.fir
Normal file
46
flang/test/Fir/cfg-conversion-if.fir
Normal file
@ -0,0 +1,46 @@
|
||||
// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s
|
||||
|
||||
func.func private @callee() -> none
|
||||
|
||||
// CHECK-LABEL: func.func @if_then(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i1) {
|
||||
// CHECK: cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
|
||||
// CHECK: ^bb1:
|
||||
// CHECK: %[[VAL_0:.*]] = fir.call @callee() : () -> none
|
||||
// CHECK: cf.br ^bb2
|
||||
// CHECK: ^bb2:
|
||||
// CHECK: return
|
||||
// CHECK: }
|
||||
func.func @if_then(%cond: i1) {
|
||||
fir.if %cond weights([10, 90]) {
|
||||
fir.call @callee() : () -> none
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @if_then_else(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i32 {
|
||||
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
|
||||
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
|
||||
// CHECK: cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
|
||||
// CHECK: ^bb1:
|
||||
// CHECK: cf.br ^bb3(%[[VAL_0]] : i32)
|
||||
// CHECK: ^bb2:
|
||||
// CHECK: cf.br ^bb3(%[[VAL_1]] : i32)
|
||||
// CHECK: ^bb3(%[[VAL_2:.*]]: i32):
|
||||
// CHECK: cf.br ^bb4
|
||||
// CHECK: ^bb4:
|
||||
// CHECK: return %[[VAL_2]] : i32
|
||||
// CHECK: }
|
||||
func.func @if_then_else(%cond: i1) -> i32 {
|
||||
%c0 = arith.constant 0 : i32
|
||||
%c1 = arith.constant 1 : i32
|
||||
%result = fir.if %cond weights([90, 10]) -> i32 {
|
||||
fir.result %c0 : i32
|
||||
} else {
|
||||
fir.result %c1 : i32
|
||||
}
|
||||
return %result : i32
|
||||
}
|
@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
|
||||
%6 = arith.addi %2, %5 : index
|
||||
return %6 : index
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @test_if_weights(
|
||||
// CHECK-SAME: %[[ARG0:.*]]: i1) {
|
||||
func.func @test_if_weights(%cond: i1) {
|
||||
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
|
||||
// CHECK: }
|
||||
fir.if %cond weights([99, 1]) {
|
||||
}
|
||||
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
|
||||
// CHECK: } else {
|
||||
// CHECK: }
|
||||
fir.if %cond weights ([99,1]) {
|
||||
} else {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -1393,3 +1393,31 @@ fir.local {type = local_init} @x.localizer : f32 init {
|
||||
^bb0(%arg0: f32, %arg1: f32):
|
||||
fir.yield(%arg0 : f32)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @wrong_weights_number_in_if_then(%cond: i1) {
|
||||
// expected-error @below {{expects number of region weights to match number of regions: 1 vs 2}}
|
||||
fir.if %cond weights([50]) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
|
||||
// expected-error @below {{expects number of region weights to match number of regions: 3 vs 2}}
|
||||
fir.if %cond weights([50, 40, 10]) {
|
||||
} else {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @negative_weight_in_if_then(%cond: i1) {
|
||||
// expected-error @below {{weight #0 must be non-negative}}
|
||||
fir.if %cond weights([-1, 101]) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
|
||||
// CondBranchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def CondBranchOp : CF_Op<"cond_br",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
Pure, Terminator]> {
|
||||
def CondBranchOp
|
||||
: CF_Op<"cond_br", [AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<
|
||||
BranchOpInterface, ["getSuccessorForOperands"]>,
|
||||
WeightedBranchOpInterface, Pure, Terminator]> {
|
||||
let summary = "Conditional branch operation";
|
||||
let description = [{
|
||||
The `cf.cond_br` terminator operation represents a conditional branch on a
|
||||
@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins I1:$condition,
|
||||
Variadic<AnyType>:$trueDestOperands,
|
||||
Variadic<AnyType>:$falseDestOperands);
|
||||
let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
|
||||
Variadic<AnyType>:$falseDestOperands,
|
||||
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
|
||||
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"ValueRange":$trueOperands, "Block *":$falseDest,
|
||||
"ValueRange":$falseOperands), [{
|
||||
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
|
||||
let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"ValueRange":$trueOperands,
|
||||
"Block *":$falseDest,
|
||||
"ValueRange":$falseOperands),
|
||||
[{
|
||||
build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
|
||||
falseDest);
|
||||
}]>,
|
||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
|
||||
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
|
||||
"Block *":$falseDest,
|
||||
CArg<"ValueRange", "{}">:$falseOperands),
|
||||
[{
|
||||
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
|
||||
falseOperands);
|
||||
}]>];
|
||||
@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
let assemblyFormat = [{
|
||||
$condition `,`
|
||||
$condition (`weights` `(` $branch_weights^ `)` )? `,`
|
||||
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
|
||||
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
|
||||
attr-dict
|
||||
|
@ -168,42 +168,6 @@ def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
|
||||
];
|
||||
}
|
||||
|
||||
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
|
||||
let description = [{
|
||||
An interface for operations that can carry branch weights metadata. It
|
||||
provides setters and getters for the operation's branch weights attribute.
|
||||
The default implementation of the interface methods expect the operation to
|
||||
have an attribute of type DenseI32ArrayAttr named branch_weights.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::LLVM";
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<
|
||||
/*desc=*/ "Returns the branch weights attribute or nullptr",
|
||||
/*returnType=*/ "::mlir::DenseI32ArrayAttr",
|
||||
/*methodName=*/ "getBranchWeightsOrNull",
|
||||
/*args=*/ (ins),
|
||||
/*methodBody=*/ [{}],
|
||||
/*defaultImpl=*/ [{
|
||||
auto op = cast<ConcreteOp>(this->getOperation());
|
||||
return op.getBranchWeightsAttr();
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/ "Sets the branch weights attribute",
|
||||
/*returnType=*/ "void",
|
||||
/*methodName=*/ "setBranchWeights",
|
||||
/*args=*/ (ins "::mlir::DenseI32ArrayAttr":$attr),
|
||||
/*methodBody=*/ [{}],
|
||||
/*defaultImpl=*/ [{
|
||||
auto op = cast<ConcreteOp>(this->getOperation());
|
||||
op.setBranchWeightsAttr(attr);
|
||||
}]
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
|
||||
let description = [{
|
||||
An interface for memory operations that can carry access groups metadata.
|
||||
|
@ -660,12 +660,12 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
|
||||
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
|
||||
|
||||
// Call-related operations.
|
||||
def LLVM_InvokeOp : LLVM_Op<"invoke", [
|
||||
AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface>,
|
||||
DeclareOpInterfaceMethods<CallOpInterface>,
|
||||
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
|
||||
Terminator]> {
|
||||
def LLVM_InvokeOp
|
||||
: LLVM_Op<"invoke", [AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface>,
|
||||
DeclareOpInterfaceMethods<CallOpInterface>,
|
||||
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
|
||||
Terminator]> {
|
||||
let arguments = (ins
|
||||
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
|
||||
OptionalAttr<FlatSymbolRefAttr>:$callee,
|
||||
@ -734,12 +734,12 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
|
||||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
|
||||
DeclareOpInterfaceMethods<CallOpInterface>,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
|
||||
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
|
||||
def LLVM_CallOp
|
||||
: LLVM_MemAccessOpBase<
|
||||
"call", [AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
|
||||
DeclareOpInterfaceMethods<CallOpInterface>,
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
|
||||
let summary = "Call to an LLVM function.";
|
||||
let description = [{
|
||||
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
|
||||
@ -788,21 +788,16 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
|
||||
OptionalAttr<FlatSymbolRefAttr>:$callee,
|
||||
Variadic<LLVM_Type>:$callee_operands,
|
||||
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags,
|
||||
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
|
||||
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
|
||||
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind,
|
||||
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
|
||||
UnitAttr:$convergent,
|
||||
UnitAttr:$no_unwind,
|
||||
UnitAttr:$will_return,
|
||||
UnitAttr:$convergent, UnitAttr:$no_unwind, UnitAttr:$will_return,
|
||||
VariadicOfVariadic<LLVM_Type, "op_bundle_sizes">:$op_bundle_operands,
|
||||
DenseI32ArrayAttr:$op_bundle_sizes,
|
||||
OptionalAttr<ArrayAttr>:$op_bundle_tags,
|
||||
OptionalAttr<DictArrayAttr>:$arg_attrs,
|
||||
OptionalAttr<DictArrayAttr>:$res_attrs,
|
||||
UnitAttr:$no_inline,
|
||||
UnitAttr:$always_inline,
|
||||
UnitAttr:$inline_hint);
|
||||
OptionalAttr<DictArrayAttr>:$res_attrs, UnitAttr:$no_inline,
|
||||
UnitAttr:$always_inline, UnitAttr:$inline_hint);
|
||||
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
|
||||
let arguments = !con(args, aliasAttrs);
|
||||
let results = (outs Optional<LLVM_Type>:$result);
|
||||
@ -1047,11 +1042,12 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
|
||||
LLVM_TerminatorPassthroughOpBuilder
|
||||
];
|
||||
}
|
||||
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface>,
|
||||
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
|
||||
Pure]> {
|
||||
def LLVM_CondBrOp
|
||||
: LLVM_TerminatorOp<
|
||||
"cond_br", [AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface>,
|
||||
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
|
||||
Pure]> {
|
||||
let arguments = (ins I1:$condition,
|
||||
Variadic<LLVM_Type>:$trueDestOperands,
|
||||
Variadic<LLVM_Type>:$falseDestOperands,
|
||||
@ -1136,11 +1132,12 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
|
||||
[AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface>,
|
||||
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
|
||||
Pure]> {
|
||||
def LLVM_SwitchOp
|
||||
: LLVM_TerminatorOp<
|
||||
"switch", [AttrSizedOperandSegments,
|
||||
DeclareOpInterfaceMethods<BranchOpInterface>,
|
||||
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
|
||||
Pure]> {
|
||||
let arguments = (ins
|
||||
AnySignlessInteger:$value,
|
||||
Variadic<AnyType>:$defaultOperands,
|
||||
|
@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
const SuccessorOperands &operands);
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WeightedBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace detail {
|
||||
/// Verify that the branch weights attached to an operation
|
||||
/// implementing WeightedBranchOpInterface are correct.
|
||||
LogicalResult verifyBranchWeights(Operation *op);
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WeightedRegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace detail {
|
||||
/// Verify that the region weights attached to an operation
|
||||
/// implementing WeightedRegionBranchOpInterface are correct.
|
||||
LogicalResult verifyRegionBranchWeights(Operation *op);
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -375,6 +375,118 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
|
||||
];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WeightedBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
|
||||
let description = [{
|
||||
This interface provides weight information for branching terminator
|
||||
operations, i.e. terminator operations with successors.
|
||||
|
||||
This interface provides methods for getting/setting integer non-negative
|
||||
weight of each branch. The probability of executing a branch
|
||||
is computed as the ratio between the branch's weight and the total
|
||||
sum of the weights (which cannot be zero).
|
||||
The weights are optional. If they are provided, then their number
|
||||
must match the number of successors of the operation.
|
||||
|
||||
The default implementations of the methods expect the operation
|
||||
to have an attribute of type DenseI32ArrayAttr named branch_weights.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [InterfaceMethod<
|
||||
/*desc=*/"Returns the branch weights",
|
||||
/*returnType=*/"::llvm::ArrayRef<int32_t>",
|
||||
/*methodName=*/"getWeights",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImpl=*/[{
|
||||
auto op = cast<ConcreteOp>(this->getOperation());
|
||||
if (auto attr = op.getBranchWeightsAttr())
|
||||
return attr.asArrayRef();
|
||||
return {};
|
||||
}]>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/"Sets the branch weights",
|
||||
/*returnType=*/"void",
|
||||
/*methodName=*/"setWeights",
|
||||
/*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImpl=*/[{
|
||||
auto op = cast<ConcreteOp>(this->getOperation());
|
||||
op.setBranchWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
|
||||
}]>,
|
||||
];
|
||||
|
||||
let verify = [{
|
||||
return ::mlir::detail::verifyBranchWeights($_op);
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WeightedRegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO: the probabilities of entering a particular region seem
|
||||
// to correlate with the values returned by
|
||||
// RegionBranchOpInterface::getRegionInvocationBounds(), and we should probably
|
||||
// verify that the values are consistent. In that case, should
|
||||
// WeightedRegionBranchOpInterface extend RegionBranchOpInterface?
|
||||
def WeightedRegionBranchOpInterface
|
||||
: OpInterface<"WeightedRegionBranchOpInterface"> {
|
||||
let description = [{
|
||||
This interface provides weight information for region operations
|
||||
that exhibit branching behavior between held regions.
|
||||
|
||||
This interface provides methods for getting/setting integer non-negative
|
||||
weight of each branch. The probability of executing a region is computed
|
||||
as the ratio between the region branch's weight and the total sum
|
||||
of the weights (which cannot be zero).
|
||||
The weights are optional. If they are provided, then their number
|
||||
must match the number of regions held by the operation
|
||||
(including empty regions).
|
||||
|
||||
The weights specify the probability of branching to a particular
|
||||
region when first executing the operation.
|
||||
For example, for loop-like operations with a single region
|
||||
the weight specifies the probability of entering the loop.
|
||||
|
||||
The default implementations of the methods expect the operation
|
||||
to have an attribute of type DenseI32ArrayAttr named region_weights.
|
||||
}];
|
||||
let cppNamespace = "::mlir";
|
||||
|
||||
let methods = [InterfaceMethod<
|
||||
/*desc=*/"Returns the region weights",
|
||||
/*returnType=*/"::llvm::ArrayRef<int32_t>",
|
||||
/*methodName=*/"getWeights",
|
||||
/*args=*/(ins),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImpl=*/[{
|
||||
auto op = cast<ConcreteOp>(this->getOperation());
|
||||
if (auto attr = op.getRegionWeightsAttr())
|
||||
return attr.asArrayRef();
|
||||
return {};
|
||||
}]>,
|
||||
InterfaceMethod<
|
||||
/*desc=*/"Sets the region weights",
|
||||
/*returnType=*/"void",
|
||||
/*methodName=*/"setWeights",
|
||||
/*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
|
||||
/*methodBody=*/[{}],
|
||||
/*defaultImpl=*/[{
|
||||
auto op = cast<ConcreteOp>(this->getOperation());
|
||||
op.setRegionWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
|
||||
}]>,
|
||||
];
|
||||
|
||||
let verify = [{
|
||||
return ::mlir::detail::verifyRegionBranchWeights($_op);
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ControlFlow Traits
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -189,7 +189,7 @@ public:
|
||||
llvm::Instruction *inst);
|
||||
|
||||
/// Sets LLVM profiling metadata for operations that have branch weights.
|
||||
void setBranchWeightsMetadata(BranchWeightOpInterface op);
|
||||
void setBranchWeightsMetadata(WeightedBranchOpInterface op);
|
||||
|
||||
/// Sets LLVM loop metadata for branch operations that have a loop annotation
|
||||
/// attribute.
|
||||
|
@ -166,10 +166,15 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
|
||||
TypeRange(adaptor.getFalseDestOperands()));
|
||||
if (failed(convertedFalseBlock))
|
||||
return failure();
|
||||
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
||||
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
|
||||
op, adaptor.getCondition(), *convertedTrueBlock,
|
||||
adaptor.getTrueDestOperands(), *convertedFalseBlock,
|
||||
adaptor.getFalseDestOperands());
|
||||
ArrayRef<int32_t> weights = op.getWeights();
|
||||
if (!weights.empty()) {
|
||||
newOp.setWeights(weights);
|
||||
op.removeBranchWeightsAttr();
|
||||
}
|
||||
// TODO: We should not just forward all attributes like that. But there are
|
||||
// existing Flang tests that depend on this behavior.
|
||||
newOp->setAttrs(op->getAttrDictionary());
|
||||
|
@ -589,10 +589,6 @@ LogicalResult SwitchOp::verify() {
|
||||
static_cast<int64_t>(getCaseDestinations().size())))
|
||||
return emitOpError("expects number of case values to match number of "
|
||||
"case destinations");
|
||||
if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
|
||||
return emitError("expects number of branch weights to match number of "
|
||||
"successors: ")
|
||||
<< getBranchWeights()->size() << " vs " << getNumSuccessors();
|
||||
if (getCaseValues() &&
|
||||
getValue().getType() != getCaseValues()->getElementType())
|
||||
return emitError("expects case value type to match condition value type");
|
||||
@ -962,7 +958,6 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
|
||||
assert(callee && "expected non-null callee in direct call builder");
|
||||
build(builder, state, results,
|
||||
/*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
|
||||
/*branch_weights=*/nullptr,
|
||||
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
|
||||
/*memory_effects=*/nullptr,
|
||||
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
|
||||
@ -992,7 +987,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
|
||||
build(builder, state, getCallOpResultTypes(calleeType),
|
||||
getCallOpVarCalleeType(calleeType), callee, args,
|
||||
/*fastmathFlags=*/nullptr,
|
||||
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
|
||||
/*CConv=*/nullptr,
|
||||
/*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
|
||||
/*convergent=*/nullptr,
|
||||
/*no_unwind=*/nullptr, /*will_return=*/nullptr,
|
||||
@ -1009,7 +1004,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
|
||||
build(builder, state, getCallOpResultTypes(calleeType),
|
||||
getCallOpVarCalleeType(calleeType),
|
||||
/*callee=*/nullptr, args,
|
||||
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
|
||||
/*fastmathFlags=*/nullptr,
|
||||
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
|
||||
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
|
||||
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
|
||||
@ -1025,7 +1020,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
|
||||
auto calleeType = func.getFunctionType();
|
||||
build(builder, state, getCallOpResultTypes(calleeType),
|
||||
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
|
||||
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
|
||||
/*fastmathFlags=*/nullptr,
|
||||
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
|
||||
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
|
||||
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "llvm/ADT/SmallPtrSet.h"
|
||||
|
||||
@ -80,6 +81,51 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WeightedBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Shared checker for an optional list of weights attached to \p op.
///
/// An empty \p weights list is accepted as-is. Otherwise the list must have
/// exactly \p expectedWeightsNum entries, every entry must be non-negative,
/// and at least one entry must be non-zero. \p weightAnchorName and
/// \p weightRefName are only used to build the size-mismatch diagnostic
/// ("expects number of <anchor> weights to match number of <ref>").
/// Each failure is reported through op->emitError().
static LogicalResult verifyWeights(Operation *op,
|
||||
llvm::ArrayRef<int32_t> weights,
|
||||
std::size_t expectedWeightsNum,
|
||||
llvm::StringRef weightAnchorName,
|
||||
llvm::StringRef weightRefName) {
|
||||
// An empty list means no weights were provided; there is nothing to verify.
if (weights.empty())
|
||||
return success();
|
||||
|
||||
// The number of weights must match the caller-supplied expected count.
if (weights.size() != expectedWeightsNum)
|
||||
return op->emitError() << "expects number of " << weightAnchorName
|
||||
<< " weights to match number of " << weightRefName
|
||||
<< ": " << weights.size() << " vs "
|
||||
<< expectedWeightsNum;
|
||||
|
||||
// Report the first negative weight by its position in the list.
for (auto [index, weight] : llvm::enumerate(weights))
|
||||
if (weight < 0)
|
||||
return op->emitError() << "weight #" << index << " must be non-negative";
|
||||
|
||||
// NOTE(review): this diagnostic always says "branch weights", regardless of
// the weightAnchorName passed by the caller.
if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
|
||||
return op->emitError() << "branch weights cannot all be zero";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verifies the weights returned by the WeightedBranchOpInterface of \p op.
/// The expected number of weights is the number of successors of \p op.
/// \p op is cast<> to WeightedBranchOpInterface, so it must implement it.
LogicalResult detail::verifyBranchWeights(Operation *op) {
|
||||
llvm::ArrayRef<int32_t> weights =
|
||||
cast<WeightedBranchOpInterface>(op).getWeights();
|
||||
return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
|
||||
"successors");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// WeightedRegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Verifies the weights returned by the WeightedRegionBranchOpInterface of
/// \p op. The expected number of weights is the number of regions of \p op.
/// \p op is cast<> to WeightedRegionBranchOpInterface, so it must implement it.
LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
|
||||
llvm::ArrayRef<int32_t> weights =
|
||||
cast<WeightedRegionBranchOpInterface>(op).getWeights();
|
||||
return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -146,8 +146,15 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
|
||||
branchWeights.push_back(branchWeight->getZExtValue());
|
||||
}
|
||||
|
||||
if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
|
||||
iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
|
||||
if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
|
||||
// LLVM allows attaching a single weight to call instructions.
|
||||
// This is used for carrying the execution count information
|
||||
// in PGO modes. MLIR WeightedBranchOpInterface does not allow this,
|
||||
// so we drop the metadata in this case.
|
||||
// LLVM should probably use the VP form of MD_prof metadata
|
||||
// for such cases.
|
||||
if (op->getNumSuccessors() != 0)
|
||||
iface.setWeights(branchWeights);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
|
@ -1055,7 +1055,7 @@ LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
|
||||
return failure();
|
||||
|
||||
// Set the branch weight metadata on the translated instruction.
|
||||
if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
|
||||
if (auto iface = dyn_cast<WeightedBranchOpInterface>(op))
|
||||
setBranchWeightsMetadata(iface);
|
||||
}
|
||||
|
||||
@ -2026,14 +2026,15 @@ void ModuleTranslation::setDereferenceableMetadata(
|
||||
inst->setMetadata(kindId, derefSizeNode);
|
||||
}
|
||||
|
||||
void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
|
||||
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
|
||||
if (!weightsAttr)
|
||||
void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) {
|
||||
SmallVector<uint32_t> weights;
|
||||
llvm::transform(op.getWeights(), std::back_inserter(weights),
|
||||
[](int32_t value) { return static_cast<uint32_t>(value); });
|
||||
if (weights.empty())
|
||||
return;
|
||||
|
||||
llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
|
||||
assert(inst && "expected the operation to have a mapping to an instruction");
|
||||
SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
|
||||
inst->setMetadata(
|
||||
llvm::LLVMContext::MD_prof,
|
||||
llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));
|
||||
|
@ -67,3 +67,17 @@ func.func @unreachable_block() {
|
||||
^bb1(%arg0: index):
|
||||
cf.br ^bb1(%arg0 : index)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case for cf.cond_br with weights.
|
||||
|
||||
// CHECK-LABEL: func.func @cf_cond_br_with_weights(
|
||||
func.func @cf_cond_br_with_weights(%cond: i1, %a: index, %b: index) -> index {
|
||||
// CHECK: llvm.cond_br %{{.*}} weights([90, 10]), ^bb1(%{{.*}} : i64), ^bb2(%{{.*}} : i64)
|
||||
cf.cond_br %cond, ^bb1(%a : index), ^bb2(%b : index) {branch_weights = array<i32: 90, 10>}
|
||||
^bb1(%arg1: index):
|
||||
return %arg1 : index
|
||||
^bb2(%arg2: index):
|
||||
return %arg2 : index
|
||||
}
|
||||
|
@ -67,3 +67,39 @@ func.func @switch_missing_default(%flag : i32, %caseOperand : i32) {
|
||||
^bb3(%bb3arg : i32):
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @wrong_weights_number
|
||||
func.func @wrong_weights_number(%cond: i1) {
|
||||
// expected-error@+1 {{expects number of branch weights to match number of successors: 1 vs 2}}
|
||||
cf.cond_br %cond weights([100]), ^bb1, ^bb2
|
||||
^bb1:
|
||||
return
|
||||
^bb2:
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @negative_weight
|
||||
func.func @wrong_total_weight(%cond: i1) {
|
||||
// expected-error@+1 {{weight #0 must be non-negative}}
|
||||
cf.cond_br %cond weights([-1, 101]), ^bb1, ^bb2
|
||||
^bb1:
|
||||
return
|
||||
^bb2:
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @zero_weights
|
||||
func.func @wrong_total_weight(%cond: i1) {
|
||||
// expected-error@+1 {{branch weights cannot all be zero}}
|
||||
cf.cond_br %cond weights([0, 0]), ^bb1, ^bb2
|
||||
^bb1:
|
||||
return
|
||||
^bb2:
|
||||
return
|
||||
}
|
||||
|
@ -51,3 +51,13 @@ func.func @switch_result_number(%arg0: i32) {
|
||||
^bb2:
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @cond_weights
|
||||
func.func @cond_weights(%cond: i1) {
|
||||
// CHECK: cf.cond_br %{{.*}} weights([60, 40]), ^{{.*}}, ^{{.*}}
|
||||
cf.cond_br %cond weights([60, 40]), ^bb1, ^bb2
|
||||
^bb1:
|
||||
return
|
||||
^bb2:
|
||||
return
|
||||
}
|
||||
|
@ -36,14 +36,17 @@ bbd:
|
||||
|
||||
; // -----
|
||||
|
||||
; Verify that a single weight attached to a call is not translated.
|
||||
; The MLIR WeightedBranchOpInterface does not support this case.
|
||||
|
||||
; CHECK: llvm.func @fn()
|
||||
declare void @fn()
|
||||
declare i32 @fn()
|
||||
|
||||
; CHECK-LABEL: @call_branch_weights
|
||||
define void @call_branch_weights() {
|
||||
; CHECK: llvm.call @fn() {branch_weights = array<i32: 42>}
|
||||
call void @fn(), !prof !0
|
||||
ret void
|
||||
define i32 @call_branch_weights() {
|
||||
; CHECK: llvm.call @fn() : () -> i32
|
||||
%1 = call i32 @fn(), !prof !0
|
||||
ret i32 %1
|
||||
}
|
||||
|
||||
!0 = !{!"branch_weights", i32 42}
|
||||
|
@ -448,3 +448,19 @@ llvm.mlir.global external constant @const() {addr_space = 0 : i32, dso_local} :
|
||||
}
|
||||
|
||||
llvm.func extern_weak @extern_func()
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @invoke_branch_weights_callee()
|
||||
llvm.func @__gxx_personality_v0(...) -> i32
|
||||
|
||||
llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
|
||||
%0 = llvm.mlir.constant(1 : i32) : i32
|
||||
// expected-error @below{{expects number of branch weights to match number of successors: 1 vs 2}}
|
||||
llvm.invoke @invoke_branch_weights_callee() to ^bb2 unwind ^bb1 {branch_weights = array<i32 : 42>} : () -> ()
|
||||
^bb1: // pred: ^bb0
|
||||
%1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
|
||||
llvm.br ^bb2
|
||||
^bb2: // 2 preds: ^bb0, ^bb1
|
||||
llvm.return %0 : i32
|
||||
}
|
||||
|
@ -1906,32 +1906,6 @@ llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @fn()
|
||||
|
||||
// CHECK-LABEL: @call_branch_weights
|
||||
llvm.func @call_branch_weights() {
|
||||
// CHECK: !prof ![[NODE:[0-9]+]]
|
||||
llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @fn() -> i32
|
||||
|
||||
// CHECK-LABEL: @call_branch_weights
|
||||
llvm.func @call_branch_weights() {
|
||||
// CHECK: !prof ![[NODE:[0-9]+]]
|
||||
%res = llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> i32
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
|
||||
|
||||
// -----
|
||||
|
||||
llvm.func @foo()
|
||||
llvm.func @__gxx_personality_v0(...) -> i32
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user