[mlir][Interfaces] Simplify and improve errors of RegionBranchOpInterface verifier (#174805)

Simplify the `RegionBranchOpInterface` verifier by utilizing new API
functions such as `getAllRegionBranchPoints`.

Also improve the error message by using the same terms that are used in
the interface definition: `region branch point`, `region successor`,
`successor operand`, `successor input`.
This commit is contained in:
Matthias Springer 2026-01-08 08:59:58 +01:00 committed by GitHub
parent 1fe27df5a9
commit f33b42dc4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 73 additions and 112 deletions

View File

@ -173,7 +173,7 @@ LogicalResult verifyRegionBranchWeights(Operation *op);
namespace detail {
/// Verify that types match along control flow edges described the given op.
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
LogicalResult verifyRegionBranchOpInterface(Operation *op);
} // namespace detail
/// A mapping from successor operands to successor inputs.

View File

@ -323,7 +323,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
let verify = [{
static_assert(!ConcreteOp::template hasTrait<OpTrait::ZeroRegions>(),
"expected operation to have non-zero regions");
return detail::verifyTypesAlongControlFlowEdges($_op);
return detail::verifyRegionBranchOpInterface($_op);
}];
let verifyWithRegions = 1;

View File

@ -152,115 +152,68 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
RegionBranchPoint sourceNo,
RegionSuccessor succRegionNo) {
diag << "from ";
if (Operation *op = sourceNo.getTerminatorPredecessorOrNull())
diag << "Operation " << op->getName();
else
diag << "parent operands";
/// Verify that types match along control flow edges described the given op.
LogicalResult detail::verifyRegionBranchOpInterface(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
diag << " to ";
if (Region *region = succRegionNo.getSuccessor())
diag << "Region #" << region->getRegionNumber();
else
diag << "parent results";
return diag;
}
/// Verify that types match along all region control flow edges originating from
/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
/// types of the inputs that flow to a successor region.
static LogicalResult
verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
RegionBranchPoint sourcePoint,
function_ref<FailureOr<TypeRange>(RegionSuccessor)>
getInputsTypesForRegion) {
SmallVector<RegionSuccessor, 2> successors;
branchOp.getSuccessorRegions(sourcePoint, successors);
for (RegionSuccessor &succ : successors) {
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
if (failed(sourceTypes))
return failure();
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
if (sourceTypes->size() != succInputsTypes.size()) {
InFlightDiagnostic diag =
branchOp->emitOpError("region control flow edge ");
std::string succStr;
llvm::raw_string_ostream os(succStr);
os << succ;
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source has " << sourceTypes->size()
<< " operands, but target successor " << os.str() << " needs "
<< succInputsTypes.size();
}
for (const auto &typesIdx :
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
Type sourceType = std::get<0>(typesIdx.value());
Type inputType = std::get<1>(typesIdx.value());
if (!branchOp.areTypesCompatible(sourceType, inputType)) {
// Verify all control flow edges from region branch points to region
// successors.
SmallVector<RegionBranchPoint> regionBranchPoints =
regionInterface.getAllRegionBranchPoints();
for (const RegionBranchPoint &branchPoint : regionBranchPoints) {
SmallVector<RegionSuccessor> successors;
regionInterface.getSuccessorRegions(branchPoint, successors);
for (const RegionSuccessor &successor : successors) {
// Helper function that print the region branch point and the region
// successor.
auto emitRegionEdgeError = [&]() {
InFlightDiagnostic diag =
branchOp->emitOpError("along control flow edge ");
return printRegionEdgeName(diag, sourcePoint, succ)
<< ": source type #" << typesIdx.index() << " " << sourceType
<< " should match input type #" << typesIdx.index() << " "
<< inputType;
regionInterface->emitOpError("along control flow edge from ");
if (branchPoint.isParent()) {
diag << "parent";
diag.attachNote(op->getLoc()) << "region branch point";
} else {
diag << "Operation "
<< branchPoint.getTerminatorPredecessorOrNull()->getName();
diag.attachNote(
branchPoint.getTerminatorPredecessorOrNull()->getLoc())
<< "region branch point";
}
diag << " to ";
if (Region *region = successor.getSuccessor()) {
diag << "Region #" << region->getRegionNumber();
} else {
diag << "parent";
}
return diag;
};
// Verify number of successor operands and successor inputs.
OperandRange succOperands =
regionInterface.getSuccessorOperands(branchPoint, successor);
ValueRange succInputs = successor.getSuccessorInputs();
if (succOperands.size() != succInputs.size()) {
return emitRegionEdgeError()
<< ": region branch point has " << succOperands.size()
<< " operands, but region successor needs " << succInputs.size()
<< " inputs";
}
// Verify that the types are compatible.
TypeRange succInputTypes = succInputs.getTypes();
TypeRange succOperandTypes = succOperands.getTypes();
for (const auto &typesIdx :
llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) {
Type succOperandType = std::get<0>(typesIdx.value());
Type succInputType = std::get<1>(typesIdx.value());
if (!regionInterface.areTypesCompatible(succOperandType, succInputType))
return emitRegionEdgeError()
<< ": successor operand type #" << typesIdx.index() << " "
<< succOperandType << " should match successor input type #"
<< typesIdx.index() << " " << succInputType;
}
}
}
return success();
}
/// Verify that types match along control flow edges described the given op.
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange {
return regionInterface.getEntrySuccessorOperands(successor).getTypes();
};
// Verify types along control flow edges originating from the parent.
if (failed(verifyTypesAlongAllEdges(
regionInterface, RegionBranchPoint::parent(), inputTypesFromParent)))
return failure();
// Verify types along control flow edges originating from each region.
for (Region &region : op->getRegions()) {
// Collect all return-like terminators in the region.
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
for (Block &block : region)
if (!block.empty())
if (auto terminator =
dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
regionReturnOps.push_back(terminator);
// If there is no return-like terminator, the op itself should verify
// type consistency.
if (regionReturnOps.empty())
continue;
// Verify types along control flow edges originating from each return-like
// terminator.
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
auto inputTypesForRegion =
[&](RegionSuccessor successor) -> FailureOr<TypeRange> {
OperandRange terminatorOperands =
regionReturnOp.getSuccessorOperands(successor);
return TypeRange(terminatorOperands.getTypes());
};
if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp,
inputTypesForRegion)))
return failure();
}
}
return success();
}
@ -525,11 +478,15 @@ SmallVector<RegionBranchPoint>
RegionBranchOpInterface::getAllRegionBranchPoints() {
SmallVector<RegionBranchPoint> branchPoints;
branchPoints.push_back(RegionBranchPoint::parent());
for (Region &region : getOperation()->getRegions())
for (Block &block : region)
for (Region &region : getOperation()->getRegions()) {
for (Block &block : region) {
if (block.empty())
continue;
if (auto terminator =
dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
branchPoints.push_back(RegionBranchPoint(terminator));
}
}
return branchPoints;
}

View File

@ -404,9 +404,10 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) {
func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
{
// expected-error@+1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor <to parent> needs 2}}
// expected-error@+1 {{along control flow edge from Operation scf.yield to parent: region branch point has 1 operands, but region successor needs 2 inputs}}
%x, %y = scf.if %arg0 -> (f32, f32) {
%0 = arith.addf %arg1, %arg1 : f32
// expected-note@+1 {{region branch point}}
scf.yield %0 : f32
} else {
%0 = arith.subf %arg1, %arg1 : f32
@ -575,8 +576,9 @@ func.func @while_invalid_terminator() {
func.func @while_cross_region_type_mismatch() {
%true = arith.constant true
// expected-error@+1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor <to region #1 with 1 inputs> needs 1}}
// expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: region branch point has 0 operands, but region successor needs 1 inputs}}
scf.while : () -> () {
// expected-note@+1 {{region branch point}}
scf.condition(%true)
} do {
^bb0(%arg0: i32):
@ -588,8 +590,9 @@ func.func @while_cross_region_type_mismatch() {
func.func @while_cross_region_type_mismatch() {
%true = arith.constant true
// expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}}
// expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: successor operand type #0 'i1' should match successor input type #0 'i32'}}
%0 = scf.while : () -> (i1) {
// expected-note@+1 {{region branch point}}
scf.condition(%true) %true : i1
} do {
^bb0(%arg0: i32):
@ -601,8 +604,9 @@ func.func @while_cross_region_type_mismatch() {
func.func @while_result_type_mismatch() {
%true = arith.constant true
// expected-error@+1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor <to parent> needs 0}}
// expected-error@+1 {{along control flow edge from Operation scf.condition to parent: region branch point has 1 operands, but region successor needs 0 inputs}}
scf.while : () -> () {
// expected-note@+1 {{region branch point}}
scf.condition(%true) %true : i1
} do {
^bb0(%arg0: i1):