[mlir][Interfaces] Simplify and improve errors of RegionBranchOpInterface verifier (#174805)
Simplify the `RegionBranchOpInterface` verifier by utilizing new API functions such as `getAllRegionBranchPoints`. Also improve the error message by using the same terms that are used in the interface definition: `region branch point`, `region successor`, `successor operand`, `successor input`.
This commit is contained in:
parent
1fe27df5a9
commit
f33b42dc4e
@ -173,7 +173,7 @@ LogicalResult verifyRegionBranchWeights(Operation *op);
|
||||
|
||||
namespace detail {
|
||||
/// Verify that types match along control flow edges described the given op.
|
||||
LogicalResult verifyTypesAlongControlFlowEdges(Operation *op);
|
||||
LogicalResult verifyRegionBranchOpInterface(Operation *op);
|
||||
} // namespace detail
|
||||
|
||||
/// A mapping from successor operands to successor inputs.
|
||||
|
||||
@ -323,7 +323,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
|
||||
let verify = [{
|
||||
static_assert(!ConcreteOp::template hasTrait<OpTrait::ZeroRegions>(),
|
||||
"expected operation to have non-zero regions");
|
||||
return detail::verifyTypesAlongControlFlowEdges($_op);
|
||||
return detail::verifyRegionBranchOpInterface($_op);
|
||||
}];
|
||||
let verifyWithRegions = 1;
|
||||
|
||||
|
||||
@ -152,115 +152,68 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
|
||||
// RegionBranchOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag,
|
||||
RegionBranchPoint sourceNo,
|
||||
RegionSuccessor succRegionNo) {
|
||||
diag << "from ";
|
||||
if (Operation *op = sourceNo.getTerminatorPredecessorOrNull())
|
||||
diag << "Operation " << op->getName();
|
||||
else
|
||||
diag << "parent operands";
|
||||
/// Verify that types match along control flow edges described the given op.
|
||||
LogicalResult detail::verifyRegionBranchOpInterface(Operation *op) {
|
||||
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
||||
|
||||
diag << " to ";
|
||||
if (Region *region = succRegionNo.getSuccessor())
|
||||
diag << "Region #" << region->getRegionNumber();
|
||||
else
|
||||
diag << "parent results";
|
||||
return diag;
|
||||
}
|
||||
|
||||
/// Verify that types match along all region control flow edges originating from
|
||||
/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the
|
||||
/// types of the inputs that flow to a successor region.
|
||||
static LogicalResult
|
||||
verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp,
|
||||
RegionBranchPoint sourcePoint,
|
||||
function_ref<FailureOr<TypeRange>(RegionSuccessor)>
|
||||
getInputsTypesForRegion) {
|
||||
SmallVector<RegionSuccessor, 2> successors;
|
||||
branchOp.getSuccessorRegions(sourcePoint, successors);
|
||||
|
||||
for (RegionSuccessor &succ : successors) {
|
||||
FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succ);
|
||||
if (failed(sourceTypes))
|
||||
return failure();
|
||||
|
||||
TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes();
|
||||
if (sourceTypes->size() != succInputsTypes.size()) {
|
||||
InFlightDiagnostic diag =
|
||||
branchOp->emitOpError("region control flow edge ");
|
||||
std::string succStr;
|
||||
llvm::raw_string_ostream os(succStr);
|
||||
os << succ;
|
||||
return printRegionEdgeName(diag, sourcePoint, succ)
|
||||
<< ": source has " << sourceTypes->size()
|
||||
<< " operands, but target successor " << os.str() << " needs "
|
||||
<< succInputsTypes.size();
|
||||
}
|
||||
|
||||
for (const auto &typesIdx :
|
||||
llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) {
|
||||
Type sourceType = std::get<0>(typesIdx.value());
|
||||
Type inputType = std::get<1>(typesIdx.value());
|
||||
|
||||
if (!branchOp.areTypesCompatible(sourceType, inputType)) {
|
||||
// Verify all control flow edges from region branch points to region
|
||||
// successors.
|
||||
SmallVector<RegionBranchPoint> regionBranchPoints =
|
||||
regionInterface.getAllRegionBranchPoints();
|
||||
for (const RegionBranchPoint &branchPoint : regionBranchPoints) {
|
||||
SmallVector<RegionSuccessor> successors;
|
||||
regionInterface.getSuccessorRegions(branchPoint, successors);
|
||||
for (const RegionSuccessor &successor : successors) {
|
||||
// Helper function that print the region branch point and the region
|
||||
// successor.
|
||||
auto emitRegionEdgeError = [&]() {
|
||||
InFlightDiagnostic diag =
|
||||
branchOp->emitOpError("along control flow edge ");
|
||||
return printRegionEdgeName(diag, sourcePoint, succ)
|
||||
<< ": source type #" << typesIdx.index() << " " << sourceType
|
||||
<< " should match input type #" << typesIdx.index() << " "
|
||||
<< inputType;
|
||||
regionInterface->emitOpError("along control flow edge from ");
|
||||
if (branchPoint.isParent()) {
|
||||
diag << "parent";
|
||||
diag.attachNote(op->getLoc()) << "region branch point";
|
||||
} else {
|
||||
diag << "Operation "
|
||||
<< branchPoint.getTerminatorPredecessorOrNull()->getName();
|
||||
diag.attachNote(
|
||||
branchPoint.getTerminatorPredecessorOrNull()->getLoc())
|
||||
<< "region branch point";
|
||||
}
|
||||
diag << " to ";
|
||||
if (Region *region = successor.getSuccessor()) {
|
||||
diag << "Region #" << region->getRegionNumber();
|
||||
} else {
|
||||
diag << "parent";
|
||||
}
|
||||
return diag;
|
||||
};
|
||||
|
||||
// Verify number of successor operands and successor inputs.
|
||||
OperandRange succOperands =
|
||||
regionInterface.getSuccessorOperands(branchPoint, successor);
|
||||
ValueRange succInputs = successor.getSuccessorInputs();
|
||||
if (succOperands.size() != succInputs.size()) {
|
||||
return emitRegionEdgeError()
|
||||
<< ": region branch point has " << succOperands.size()
|
||||
<< " operands, but region successor needs " << succInputs.size()
|
||||
<< " inputs";
|
||||
}
|
||||
|
||||
// Verify that the types are compatible.
|
||||
TypeRange succInputTypes = succInputs.getTypes();
|
||||
TypeRange succOperandTypes = succOperands.getTypes();
|
||||
for (const auto &typesIdx :
|
||||
llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) {
|
||||
Type succOperandType = std::get<0>(typesIdx.value());
|
||||
Type succInputType = std::get<1>(typesIdx.value());
|
||||
if (!regionInterface.areTypesCompatible(succOperandType, succInputType))
|
||||
return emitRegionEdgeError()
|
||||
<< ": successor operand type #" << typesIdx.index() << " "
|
||||
<< succOperandType << " should match successor input type #"
|
||||
<< typesIdx.index() << " " << succInputType;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Verify that types match along control flow edges described the given op.
|
||||
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
|
||||
auto regionInterface = cast<RegionBranchOpInterface>(op);
|
||||
|
||||
auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange {
|
||||
return regionInterface.getEntrySuccessorOperands(successor).getTypes();
|
||||
};
|
||||
|
||||
// Verify types along control flow edges originating from the parent.
|
||||
if (failed(verifyTypesAlongAllEdges(
|
||||
regionInterface, RegionBranchPoint::parent(), inputTypesFromParent)))
|
||||
return failure();
|
||||
|
||||
// Verify types along control flow edges originating from each region.
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
// Collect all return-like terminators in the region.
|
||||
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
|
||||
for (Block &block : region)
|
||||
if (!block.empty())
|
||||
if (auto terminator =
|
||||
dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
|
||||
regionReturnOps.push_back(terminator);
|
||||
|
||||
// If there is no return-like terminator, the op itself should verify
|
||||
// type consistency.
|
||||
if (regionReturnOps.empty())
|
||||
continue;
|
||||
|
||||
// Verify types along control flow edges originating from each return-like
|
||||
// terminator.
|
||||
for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) {
|
||||
|
||||
auto inputTypesForRegion =
|
||||
[&](RegionSuccessor successor) -> FailureOr<TypeRange> {
|
||||
OperandRange terminatorOperands =
|
||||
regionReturnOp.getSuccessorOperands(successor);
|
||||
return TypeRange(terminatorOperands.getTypes());
|
||||
};
|
||||
if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp,
|
||||
inputTypesForRegion)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
@ -525,11 +478,15 @@ SmallVector<RegionBranchPoint>
|
||||
RegionBranchOpInterface::getAllRegionBranchPoints() {
|
||||
SmallVector<RegionBranchPoint> branchPoints;
|
||||
branchPoints.push_back(RegionBranchPoint::parent());
|
||||
for (Region ®ion : getOperation()->getRegions())
|
||||
for (Block &block : region)
|
||||
for (Region ®ion : getOperation()->getRegions()) {
|
||||
for (Block &block : region) {
|
||||
if (block.empty())
|
||||
continue;
|
||||
if (auto terminator =
|
||||
dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
|
||||
branchPoints.push_back(RegionBranchPoint(terminator));
|
||||
}
|
||||
}
|
||||
return branchPoints;
|
||||
}
|
||||
|
||||
|
||||
@ -404,9 +404,10 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) {
|
||||
|
||||
func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32)
|
||||
{
|
||||
// expected-error@+1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor <to parent> needs 2}}
|
||||
// expected-error@+1 {{along control flow edge from Operation scf.yield to parent: region branch point has 1 operands, but region successor needs 2 inputs}}
|
||||
%x, %y = scf.if %arg0 -> (f32, f32) {
|
||||
%0 = arith.addf %arg1, %arg1 : f32
|
||||
// expected-note@+1 {{region branch point}}
|
||||
scf.yield %0 : f32
|
||||
} else {
|
||||
%0 = arith.subf %arg1, %arg1 : f32
|
||||
@ -575,8 +576,9 @@ func.func @while_invalid_terminator() {
|
||||
|
||||
func.func @while_cross_region_type_mismatch() {
|
||||
%true = arith.constant true
|
||||
// expected-error@+1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor <to region #1 with 1 inputs> needs 1}}
|
||||
// expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: region branch point has 0 operands, but region successor needs 1 inputs}}
|
||||
scf.while : () -> () {
|
||||
// expected-note@+1 {{region branch point}}
|
||||
scf.condition(%true)
|
||||
} do {
|
||||
^bb0(%arg0: i32):
|
||||
@ -588,8 +590,9 @@ func.func @while_cross_region_type_mismatch() {
|
||||
|
||||
func.func @while_cross_region_type_mismatch() {
|
||||
%true = arith.constant true
|
||||
// expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}}
|
||||
// expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: successor operand type #0 'i1' should match successor input type #0 'i32'}}
|
||||
%0 = scf.while : () -> (i1) {
|
||||
// expected-note@+1 {{region branch point}}
|
||||
scf.condition(%true) %true : i1
|
||||
} do {
|
||||
^bb0(%arg0: i32):
|
||||
@ -601,8 +604,9 @@ func.func @while_cross_region_type_mismatch() {
|
||||
|
||||
func.func @while_result_type_mismatch() {
|
||||
%true = arith.constant true
|
||||
// expected-error@+1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor <to parent> needs 0}}
|
||||
// expected-error@+1 {{along control flow edge from Operation scf.condition to parent: region branch point has 1 operands, but region successor needs 0 inputs}}
|
||||
scf.while : () -> () {
|
||||
// expected-note@+1 {{region branch point}}
|
||||
scf.condition(%true) %true : i1
|
||||
} do {
|
||||
^bb0(%arg0: i1):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user