[mlir][flang] Added Weighted[Region]BranchOpInterface's. (#142079)

The new interfaces provide getters and setters for the weight
information about the branches of BranchOpInterface and
RegionBranchOpInterface operations.

These interfaces are done the same way as LLVM dialect's
BranchWeightOpInterface.

The plan is to produce this information in Flang, e.g. mark
most probably "cold" code as such and allow LLVM to order
basic blocks accordingly. An example of such a code is
copy loops generated for arrays repacking - we can mark it
as "cold" assuming that the copy will not happen dynamically.
If the copy actually happens the overhead of the copy is probably high
enough so that we may not care about the little overhead
of jumping to the "cold" code and fetching it.
This commit is contained in:
Slava Zakharin 2025-06-17 16:14:13 -07:00 committed by GitHub
parent af65cb68f5
commit 70343c8d44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 461 additions and 135 deletions

View File

@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
}];
}
def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
NoRegionArguments]> {
def fir_IfOp
: region_Op<
"if", [DeclareOpInterfaceMethods<
RegionBranchOpInterface, ["getRegionInvocationBounds",
"getEntrySuccessorRegions"]>,
RecursiveMemoryEffects, NoRegionArguments,
WeightedRegionBranchOpInterface]> {
let summary = "if-then-else conditional operation";
let description = [{
Used to conditionally execute operations. This operation is the FIR
@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
```
}];
let arguments = (ins I1:$condition);
let arguments = (ins I1:$condition,
OptionalAttr<DenseI32ArrayAttr>:$region_weights);
let results = (outs Variadic<AnyType>:$results);
let regions = (region
@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
unsigned resultNum);
/// Returns the display name string for the region_weights attribute.
static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
return "weights";
}
}];
}

View File

@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
parser.resolveOperand(cond, i1Type, result.operands))
return mlir::failure();
if (mlir::succeeded(
parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
if (parser.parseLParen())
return mlir::failure();
mlir::DenseI32ArrayAttr weights;
if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
return mlir::failure();
if (weights)
result.addAttribute(getRegionWeightsAttrName(result.name), weights);
if (parser.parseRParen())
return mlir::failure();
}
if (parser.parseOptionalArrowTypeList(result.types))
return mlir::failure();
@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
void fir::IfOp::print(mlir::OpAsmPrinter &p) {
bool printBlockTerminators = false;
p << ' ' << getCondition();
if (auto weights = getRegionWeightsAttr()) {
p << ' ' << getWeightsAttrAssemblyName() << '(';
p.printStrippedAttrOrType(weights);
p << ')';
}
if (!getResults().empty()) {
p << " -> (" << getResultTypes() << ')';
printBlockTerminators = true;
@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
printBlockTerminators);
}
p.printOptionalAttrDict((*this)->getAttrs());
p.printOptionalAttrDict((*this)->getAttrs(),
/*elideAttrs=*/{getRegionWeightsAttrName()});
}
void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,

View File

@ -212,9 +212,12 @@ public:
}
rewriter.setInsertionPointToEnd(condBlock);
rewriter.create<mlir::cf::CondBranchOp>(
auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
llvm::ArrayRef<int32_t> weights = ifOp.getWeights();
if (!weights.empty())
branchOp.setWeights(weights);
rewriter.replaceOp(ifOp, continueBlock->getArguments());
return success();
}

View File

@ -0,0 +1,46 @@
// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s
func.func private @callee() -> none
// CHECK-LABEL: func.func @if_then(
// CHECK-SAME: %[[ARG0:.*]]: i1) {
// CHECK: cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: %[[VAL_0:.*]] = fir.call @callee() : () -> none
// CHECK: cf.br ^bb2
// CHECK: ^bb2:
// CHECK: return
// CHECK: }
func.func @if_then(%cond: i1) {
fir.if %cond weights([10, 90]) {
fir.call @callee() : () -> none
}
return
}
// -----
// CHECK-LABEL: func.func @if_then_else(
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i32 {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
// CHECK: cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: cf.br ^bb3(%[[VAL_0]] : i32)
// CHECK: ^bb2:
// CHECK: cf.br ^bb3(%[[VAL_1]] : i32)
// CHECK: ^bb3(%[[VAL_2:.*]]: i32):
// CHECK: cf.br ^bb4
// CHECK: ^bb4:
// CHECK: return %[[VAL_2]] : i32
// CHECK: }
func.func @if_then_else(%cond: i1) -> i32 {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%result = fir.if %cond weights([90, 10]) -> i32 {
fir.result %c0 : i32
} else {
fir.result %c1 : i32
}
return %result : i32
}

View File

@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
%6 = arith.addi %2, %5 : index
return %6 : index
}
// CHECK-LABEL: func.func @test_if_weights(
// CHECK-SAME: %[[ARG0:.*]]: i1) {
func.func @test_if_weights(%cond: i1) {
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
// CHECK: }
fir.if %cond weights([99, 1]) {
}
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
// CHECK: } else {
// CHECK: }
fir.if %cond weights ([99,1]) {
} else {
}
return
}

View File

@ -1393,3 +1393,31 @@ fir.local {type = local_init} @x.localizer : f32 init {
^bb0(%arg0: f32, %arg1: f32):
fir.yield(%arg0 : f32)
}
// -----
func.func @wrong_weights_number_in_if_then(%cond: i1) {
// expected-error @below {{expects number of region weights to match number of regions: 1 vs 2}}
fir.if %cond weights([50]) {
}
return
}
// -----
func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
// expected-error @below {{expects number of region weights to match number of regions: 3 vs 2}}
fir.if %cond weights([50, 40, 10]) {
} else {
}
return
}
// -----
func.func @negative_weight_in_if_then(%cond: i1) {
// expected-error @below {{weight #0 must be non-negative}}
fir.if %cond weights([-1, 101]) {
}
return
}

View File

@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
// CondBranchOp
//===----------------------------------------------------------------------===//
def CondBranchOp : CF_Op<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
Pure, Terminator]> {
def CondBranchOp
: CF_Op<"cond_br", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<
BranchOpInterface, ["getSuccessorForOperands"]>,
WeightedBranchOpInterface, Pure, Terminator]> {
let summary = "Conditional branch operation";
let description = [{
The `cf.cond_br` terminator operation represents a conditional branch on a
@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
```
}];
let arguments = (ins I1:$condition,
Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands);
let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
let builders = [
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands, "Block *":$falseDest,
"ValueRange":$falseOperands), [{
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"ValueRange":$trueOperands,
"Block *":$falseDest,
"ValueRange":$falseOperands),
[{
build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
falseDest);
}]>,
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
"Block *":$falseDest,
CArg<"ValueRange", "{}">:$falseOperands),
[{
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];
@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",
let hasCanonicalizer = 1;
let assemblyFormat = [{
$condition `,`
$condition (`weights` `(` $branch_weights^ `)` )? `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict

View File

@ -168,42 +168,6 @@ def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
];
}
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
provides setters and getters for the operation's branch weights attribute.
The default implementation of the interface methods expect the operation to
have an attribute of type DenseI32ArrayAttr named branch_weights.
}];
let cppNamespace = "::mlir::LLVM";
let methods = [
InterfaceMethod<
/*desc=*/ "Returns the branch weights attribute or nullptr",
/*returnType=*/ "::mlir::DenseI32ArrayAttr",
/*methodName=*/ "getBranchWeightsOrNull",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getBranchWeightsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Sets the branch weights attribute",
/*returnType=*/ "void",
/*methodName=*/ "setBranchWeights",
/*args=*/ (ins "::mlir::DenseI32ArrayAttr":$attr),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
op.setBranchWeightsAttr(attr);
}]
>
];
}
def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
let description = [{
An interface for memory operations that can carry access groups metadata.

View File

@ -660,12 +660,12 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
// Call-related operations.
def LLVM_InvokeOp : LLVM_Op<"invoke", [
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Terminator]> {
def LLVM_InvokeOp
: LLVM_Op<"invoke", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
Terminator]> {
let arguments = (ins
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
@ -734,12 +734,12 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
// CallOp
//===----------------------------------------------------------------------===//
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
def LLVM_CallOp
: LLVM_MemAccessOpBase<
"call", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
DeclareOpInterfaceMethods<CallOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = "Call to an LLVM function.";
let description = [{
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
@ -788,21 +788,16 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind,
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
UnitAttr:$convergent,
UnitAttr:$no_unwind,
UnitAttr:$will_return,
UnitAttr:$convergent, UnitAttr:$no_unwind, UnitAttr:$will_return,
VariadicOfVariadic<LLVM_Type, "op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
OptionalAttr<ArrayAttr>:$op_bundle_tags,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
UnitAttr:$no_inline,
UnitAttr:$always_inline,
UnitAttr:$inline_hint);
OptionalAttr<DictArrayAttr>:$res_attrs, UnitAttr:$no_inline,
UnitAttr:$always_inline, UnitAttr:$inline_hint);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
@ -1047,11 +1042,12 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
LLVM_TerminatorPassthroughOpBuilder
];
}
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
def LLVM_CondBrOp
: LLVM_TerminatorOp<
"cond_br", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
Pure]> {
let arguments = (ins I1:$condition,
Variadic<LLVM_Type>:$trueDestOperands,
Variadic<LLVM_Type>:$falseDestOperands,
@ -1136,11 +1132,12 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
}];
}
def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
Pure]> {
def LLVM_SwitchOp
: LLVM_TerminatorOp<
"switch", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface>,
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
Pure]> {
let arguments = (ins
AnySignlessInteger:$value,
Variadic<AnyType>:$defaultOperands,

View File

@ -142,6 +142,26 @@ LogicalResult verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands);
} // namespace detail
//===----------------------------------------------------------------------===//
// WeightedBranchOpInterface
//===----------------------------------------------------------------------===//
namespace detail {
/// Verify that the branch weights attached to an operation
/// implementing WeightedBranchOpInterface are correct.
LogicalResult verifyBranchWeights(Operation *op);
} // namespace detail
//===----------------------------------------------------------------------===//
// WeightedRegiobBranchOpInterface
//===----------------------------------------------------------------------===//
namespace detail {
/// Verify that the region weights attached to an operation
/// implementing WeightedRegiobBranchOpInterface are correct.
LogicalResult verifyRegionBranchWeights(Operation *op);
} // namespace detail
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//

View File

@ -375,6 +375,118 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
];
}
//===----------------------------------------------------------------------===//
// WeightedBranchOpInterface
//===----------------------------------------------------------------------===//
def WeightedBranchOpInterface : OpInterface<"WeightedBranchOpInterface"> {
let description = [{
This interface provides weight information for branching terminator
operations, i.e. terminator operations with successors.
This interface provides methods for getting/setting integer non-negative
weight of each branch. The probability of executing a branch
is computed as the ratio between the branch's weight and the total
sum of the weights (which cannot be zero).
The weights are optional. If they are provided, then their number
must match the number of successors of the operation.
The default implementations of the methods expect the operation
to have an attribute of type DenseI32ArrayAttr named branch_weights.
}];
let cppNamespace = "::mlir";
let methods = [InterfaceMethod<
/*desc=*/"Returns the branch weights",
/*returnType=*/"::llvm::ArrayRef<int32_t>",
/*methodName=*/"getWeights",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImpl=*/[{
auto op = cast<ConcreteOp>(this->getOperation());
if (auto attr = op.getBranchWeightsAttr())
return attr.asArrayRef();
return {};
}]>,
InterfaceMethod<
/*desc=*/"Sets the branch weights",
/*returnType=*/"void",
/*methodName=*/"setWeights",
/*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
/*methodBody=*/[{}],
/*defaultImpl=*/[{
auto op = cast<ConcreteOp>(this->getOperation());
op.setBranchWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
}]>,
];
let verify = [{
return ::mlir::detail::verifyBranchWeights($_op);
}];
}
//===----------------------------------------------------------------------===//
// WeightedRegionBranchOpInterface
//===----------------------------------------------------------------------===//
// TODO: the probabilities of entering a particular region seem
// to correlate with the values returned by
// RegionBranchOpInterface::invocationBounds(), and we should probably
// verify that the values are consistent. In that case, should
// WeightedRegionBranchOpInterface extend RegionBranchOpInterface?
def WeightedRegionBranchOpInterface
: OpInterface<"WeightedRegionBranchOpInterface"> {
let description = [{
This interface provides weight information for region operations
that exhibit branching behavior between held regions.
This interface provides methods for getting/setting integer non-negative
weight of each branch. The probability of executing a region is computed
as the ratio between the region branch's weight and the total sum
of the weights (which cannot be zero).
The weights are optional. If they are provided, then their number
must match the number of regions held by the operation
(including empty regions).
The weights specify the probability of branching to a particular
region when first executing the operation.
For example, for loop-like operations with a single region
the weight specifies the probability of entering the loop.
The default implementations of the methods expect the operation
to have an attribute of type DenseI32ArrayAttr named branch_weights.
}];
let cppNamespace = "::mlir";
let methods = [InterfaceMethod<
/*desc=*/"Returns the region weights",
/*returnType=*/"::llvm::ArrayRef<int32_t>",
/*methodName=*/"getWeights",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImpl=*/[{
auto op = cast<ConcreteOp>(this->getOperation());
if (auto attr = op.getRegionWeightsAttr())
return attr.asArrayRef();
return {};
}]>,
InterfaceMethod<
/*desc=*/"Sets the region weights",
/*returnType=*/"void",
/*methodName=*/"setWeights",
/*args=*/(ins "::llvm::ArrayRef<int32_t>":$weights),
/*methodBody=*/[{}],
/*defaultImpl=*/[{
auto op = cast<ConcreteOp>(this->getOperation());
op.setRegionWeightsAttr(::mlir::DenseI32ArrayAttr::get(op->getContext(), weights));
}]>,
];
let verify = [{
return ::mlir::detail::verifyRegionBranchWeights($_op);
}];
}
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//

View File

@ -189,7 +189,7 @@ public:
llvm::Instruction *inst);
/// Sets LLVM profiling metadata for operations that have branch weights.
void setBranchWeightsMetadata(BranchWeightOpInterface op);
void setBranchWeightsMetadata(WeightedBranchOpInterface op);
/// Sets LLVM loop metadata for branch operations that have a loop annotation
/// attribute.

View File

@ -166,10 +166,15 @@ struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
TypeRange(adaptor.getFalseDestOperands()));
if (failed(convertedFalseBlock))
return failure();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
op, adaptor.getCondition(), *convertedTrueBlock,
adaptor.getTrueDestOperands(), *convertedFalseBlock,
adaptor.getFalseDestOperands());
ArrayRef<int32_t> weights = op.getWeights();
if (!weights.empty()) {
newOp.setWeights(weights);
op.removeBranchWeightsAttr();
}
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(op->getAttrDictionary());

View File

@ -589,10 +589,6 @@ LogicalResult SwitchOp::verify() {
static_cast<int64_t>(getCaseDestinations().size())))
return emitOpError("expects number of case values to match number of "
"case destinations");
if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
return emitError("expects number of branch weights to match number of "
"successors: ")
<< getBranchWeights()->size() << " vs " << getNumSuccessors();
if (getCaseValues() &&
getValue().getType() != getCaseValues()->getElementType())
return emitError("expects case value type to match condition value type");
@ -962,7 +958,6 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
assert(callee && "expected non-null callee in direct call builder");
build(builder, state, results,
/*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
/*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
@ -992,7 +987,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), callee, args,
/*fastmathFlags=*/nullptr,
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
/*CConv=*/nullptr,
/*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr,
/*no_unwind=*/nullptr, /*will_return=*/nullptr,
@ -1009,7 +1004,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType),
/*callee=*/nullptr, args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*fastmathFlags=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
@ -1025,7 +1020,7 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
auto calleeType = func.getFunctionType();
build(builder, state, getCallOpResultTypes(calleeType),
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
/*fastmathFlags=*/nullptr,
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
/*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},

View File

@ -9,6 +9,7 @@
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
@ -80,6 +81,51 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
return success();
}
//===----------------------------------------------------------------------===//
// WeightedBranchOpInterface
//===----------------------------------------------------------------------===//
static LogicalResult verifyWeights(Operation *op,
llvm::ArrayRef<int32_t> weights,
std::size_t expectedWeightsNum,
llvm::StringRef weightAnchorName,
llvm::StringRef weightRefName) {
if (weights.empty())
return success();
if (weights.size() != expectedWeightsNum)
return op->emitError() << "expects number of " << weightAnchorName
<< " weights to match number of " << weightRefName
<< ": " << weights.size() << " vs "
<< expectedWeightsNum;
for (auto [index, weight] : llvm::enumerate(weights))
if (weight < 0)
return op->emitError() << "weight #" << index << " must be non-negative";
if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
return op->emitError() << "branch weights cannot all be zero";
return success();
}
LogicalResult detail::verifyBranchWeights(Operation *op) {
llvm::ArrayRef<int32_t> weights =
cast<WeightedBranchOpInterface>(op).getWeights();
return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
"successors");
}
//===----------------------------------------------------------------------===//
// WeightedRegionBranchOpInterface
//===----------------------------------------------------------------------===//
LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
llvm::ArrayRef<int32_t> weights =
cast<WeightedRegionBranchOpInterface>(op).getWeights();
return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
}
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//

View File

@ -146,8 +146,15 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
branchWeights.push_back(branchWeight->getZExtValue());
}
if (auto iface = dyn_cast<BranchWeightOpInterface>(op)) {
iface.setBranchWeights(builder.getDenseI32ArrayAttr(branchWeights));
if (auto iface = dyn_cast<WeightedBranchOpInterface>(op)) {
// LLVM allows attaching a single weight to call instructions.
// This is used for carrying the execution count information
// in PGO modes. MLIR WeightedBranchOpInterface does not allow this,
// so we drop the metadata in this case.
// LLVM should probably use the VP form of MD_prof metadata
// for such cases.
if (op->getNumSuccessors() != 0)
iface.setWeights(branchWeights);
return success();
}
return failure();

View File

@ -1055,7 +1055,7 @@ LogicalResult ModuleTranslation::convertBlockImpl(Block &bb,
return failure();
// Set the branch weight metadata on the translated instruction.
if (auto iface = dyn_cast<BranchWeightOpInterface>(op))
if (auto iface = dyn_cast<WeightedBranchOpInterface>(op))
setBranchWeightsMetadata(iface);
}
@ -2026,14 +2026,15 @@ void ModuleTranslation::setDereferenceableMetadata(
inst->setMetadata(kindId, derefSizeNode);
}
void ModuleTranslation::setBranchWeightsMetadata(BranchWeightOpInterface op) {
DenseI32ArrayAttr weightsAttr = op.getBranchWeightsOrNull();
if (!weightsAttr)
void ModuleTranslation::setBranchWeightsMetadata(WeightedBranchOpInterface op) {
SmallVector<uint32_t> weights;
llvm::transform(op.getWeights(), std::back_inserter(weights),
[](int32_t value) { return static_cast<uint32_t>(value); });
if (weights.empty())
return;
llvm::Instruction *inst = isa<CallOp>(op) ? lookupCall(op) : lookupBranch(op);
assert(inst && "expected the operation to have a mapping to an instruction");
SmallVector<uint32_t> weights(weightsAttr.asArrayRef());
inst->setMetadata(
llvm::LLVMContext::MD_prof,
llvm::MDBuilder(getLLVMContext()).createBranchWeights(weights));

View File

@ -67,3 +67,17 @@ func.func @unreachable_block() {
^bb1(%arg0: index):
cf.br ^bb1(%arg0 : index)
}
// -----
// Test case for cf.cond_br with weights.
// CHECK-LABEL: func.func @cf_cond_br_with_weights(
func.func @cf_cond_br_with_weights(%cond: i1, %a: index, %b: index) -> index {
// CHECK: llvm.cond_br %{{.*}} weights([90, 10]), ^bb1(%{{.*}} : i64), ^bb2(%{{.*}} : i64)
cf.cond_br %cond, ^bb1(%a : index), ^bb2(%b : index) {branch_weights = array<i32: 90, 10>}
^bb1(%arg1: index):
return %arg1 : index
^bb2(%arg2: index):
return %arg2 : index
}

View File

@ -67,3 +67,39 @@ func.func @switch_missing_default(%flag : i32, %caseOperand : i32) {
^bb3(%bb3arg : i32):
return
}
// -----
// CHECK-LABEL: func @wrong_weights_number
func.func @wrong_weights_number(%cond: i1) {
// expected-error@+1 {{expects number of branch weights to match number of successors: 1 vs 2}}
cf.cond_br %cond weights([100]), ^bb1, ^bb2
^bb1:
return
^bb2:
return
}
// -----
// CHECK-LABEL: func @negative_weight
func.func @wrong_total_weight(%cond: i1) {
// expected-error@+1 {{weight #0 must be non-negative}}
cf.cond_br %cond weights([-1, 101]), ^bb1, ^bb2
^bb1:
return
^bb2:
return
}
// -----
// CHECK-LABEL: func @zero_weights
func.func @wrong_total_weight(%cond: i1) {
// expected-error@+1 {{branch weights cannot all be zero}}
cf.cond_br %cond weights([0, 0]), ^bb1, ^bb2
^bb1:
return
^bb2:
return
}

View File

@ -51,3 +51,13 @@ func.func @switch_result_number(%arg0: i32) {
^bb2:
return
}
// CHECK-LABEL: func @cond_weights
func.func @cond_weights(%cond: i1) {
// CHECK: cf.cond_br %{{.*}} weights([60, 40]), ^{{.*}}, ^{{.*}}
cf.cond_br %cond weights([60, 40]), ^bb1, ^bb2
^bb1:
return
^bb2:
return
}

View File

@ -36,14 +36,17 @@ bbd:
; // -----
; Verify that a single weight attached to a call is not translated.
; The MLIR WeightedBranchOpInterface does not support this case.
; CHECK: llvm.func @fn()
declare void @fn()
declare i32 @fn()
; CHECK-LABEL: @call_branch_weights
define void @call_branch_weights() {
; CHECK: llvm.call @fn() {branch_weights = array<i32: 42>}
call void @fn(), !prof !0
ret void
define i32 @call_branch_weights() {
; CHECK: llvm.call @fn() : () -> i32
%1 = call i32 @fn(), !prof !0
ret i32 %1
}
!0 = !{!"branch_weights", i32 42}

View File

@ -448,3 +448,19 @@ llvm.mlir.global external constant @const() {addr_space = 0 : i32, dso_local} :
}
llvm.func extern_weak @extern_func()
// -----
llvm.func @invoke_branch_weights_callee()
llvm.func @__gxx_personality_v0(...) -> i32
llvm.func @invoke_branch_weights() -> i32 attributes {personality = @__gxx_personality_v0} {
%0 = llvm.mlir.constant(1 : i32) : i32
// expected-error @below{{expects number of branch weights to match number of successors: 1 vs 2}}
llvm.invoke @invoke_branch_weights_callee() to ^bb2 unwind ^bb1 {branch_weights = array<i32 : 42>} : () -> ()
^bb1: // pred: ^bb0
%1 = llvm.landingpad cleanup : !llvm.struct<(ptr, i32)>
llvm.br ^bb2
^bb2: // 2 preds: ^bb0, ^bb1
llvm.return %0 : i32
}

View File

@ -1906,32 +1906,6 @@ llvm.func @cond_br_weights(%cond : i1, %arg0 : i32, %arg1 : i32) -> i32 {
// -----
llvm.func @fn()
// CHECK-LABEL: @call_branch_weights
llvm.func @call_branch_weights() {
// CHECK: !prof ![[NODE:[0-9]+]]
llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> ()
llvm.return
}
// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
// -----
llvm.func @fn() -> i32
// CHECK-LABEL: @call_branch_weights
llvm.func @call_branch_weights() {
// CHECK: !prof ![[NODE:[0-9]+]]
%res = llvm.call @fn() {branch_weights = array<i32 : 42>} : () -> i32
llvm.return
}
// CHECK: ![[NODE]] = !{!"branch_weights", i32 42}
// -----
llvm.func @foo()
llvm.func @__gxx_personality_v0(...) -> i32