diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 566f4b8fadb5..b76c2891fad5 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -173,7 +173,7 @@ LogicalResult verifyRegionBranchWeights(Operation *op); namespace detail { /// Verify that types match along control flow edges described the given op. -LogicalResult verifyTypesAlongControlFlowEdges(Operation *op); +LogicalResult verifyRegionBranchOpInterface(Operation *op); } // namespace detail /// A mapping from successor operands to successor inputs. diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td index 2e654ba04ffe..ecad424e30c7 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -323,7 +323,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { let verify = [{ static_assert(!ConcreteOp::template hasTrait(), "expected operation to have non-zero regions"); - return detail::verifyTypesAlongControlFlowEdges($_op); + return detail::verifyRegionBranchOpInterface($_op); }]; let verifyWithRegions = 1; diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index d393ddb8d833..2574f4e73d31 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -152,115 +152,68 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) { // RegionBranchOpInterface //===----------------------------------------------------------------------===// -static InFlightDiagnostic &printRegionEdgeName(InFlightDiagnostic &diag, - RegionBranchPoint sourceNo, - RegionSuccessor succRegionNo) { - diag << "from "; - if (Operation *op = sourceNo.getTerminatorPredecessorOrNull()) - diag << "Operation " << op->getName(); - else - diag << "parent operands"; +/// Verify that types match along control flow edges described the given op. +LogicalResult detail::verifyRegionBranchOpInterface(Operation *op) { + auto regionInterface = cast(op); - diag << " to "; - if (Region *region = succRegionNo.getSuccessor()) - diag << "Region #" << region->getRegionNumber(); - else - diag << "parent results"; - return diag; -} - -/// Verify that types match along all region control flow edges originating from -/// `sourcePoint`. `getInputsTypesForRegion` is a function that returns the -/// types of the inputs that flow to a successor region. -static LogicalResult -verifyTypesAlongAllEdges(RegionBranchOpInterface branchOp, - RegionBranchPoint sourcePoint, - function_ref(RegionSuccessor)> - getInputsTypesForRegion) { - SmallVector successors; - branchOp.getSuccessorRegions(sourcePoint, successors); - - for (RegionSuccessor &succ : successors) { - FailureOr sourceTypes = getInputsTypesForRegion(succ); - if (failed(sourceTypes)) - return failure(); - - TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); - if (sourceTypes->size() != succInputsTypes.size()) { - InFlightDiagnostic diag = - branchOp->emitOpError("region control flow edge "); - std::string succStr; - llvm::raw_string_ostream os(succStr); - os << succ; - return printRegionEdgeName(diag, sourcePoint, succ) - << ": source has " << sourceTypes->size() - << " operands, but target successor " << os.str() << " needs " - << succInputsTypes.size(); - } - - for (const auto &typesIdx : - llvm::enumerate(llvm::zip(*sourceTypes, succInputsTypes))) { - Type sourceType = std::get<0>(typesIdx.value()); - Type inputType = std::get<1>(typesIdx.value()); - - if (!branchOp.areTypesCompatible(sourceType, inputType)) { + // Verify all control flow edges from region branch points to region + // successors. + SmallVector regionBranchPoints = + regionInterface.getAllRegionBranchPoints(); + for (const RegionBranchPoint &branchPoint : regionBranchPoints) { + SmallVector successors; + regionInterface.getSuccessorRegions(branchPoint, successors); + for (const RegionSuccessor &successor : successors) { + // Helper function that print the region branch point and the region + // successor. + auto emitRegionEdgeError = [&]() { InFlightDiagnostic diag = - branchOp->emitOpError("along control flow edge "); - return printRegionEdgeName(diag, sourcePoint, succ) - << ": source type #" << typesIdx.index() << " " << sourceType - << " should match input type #" << typesIdx.index() << " " - << inputType; + regionInterface->emitOpError("along control flow edge from "); + if (branchPoint.isParent()) { + diag << "parent"; + diag.attachNote(op->getLoc()) << "region branch point"; + } else { + diag << "Operation " + << branchPoint.getTerminatorPredecessorOrNull()->getName(); + diag.attachNote( + branchPoint.getTerminatorPredecessorOrNull()->getLoc()) + << "region branch point"; + } + diag << " to "; + if (Region *region = successor.getSuccessor()) { + diag << "Region #" << region->getRegionNumber(); + } else { + diag << "parent"; + } + return diag; + }; + + // Verify number of successor operands and successor inputs. + OperandRange succOperands = + regionInterface.getSuccessorOperands(branchPoint, successor); + ValueRange succInputs = successor.getSuccessorInputs(); + if (succOperands.size() != succInputs.size()) { + return emitRegionEdgeError() + << ": region branch point has " << succOperands.size() + << " operands, but region successor needs " << succInputs.size() + << " inputs"; + } + + // Verify that the types are compatible. + TypeRange succInputTypes = succInputs.getTypes(); + TypeRange succOperandTypes = succOperands.getTypes(); + for (const auto &typesIdx : + llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) { + Type succOperandType = std::get<0>(typesIdx.value()); + Type succInputType = std::get<1>(typesIdx.value()); + if (!regionInterface.areTypesCompatible(succOperandType, succInputType)) + return emitRegionEdgeError() + << ": successor operand type #" << typesIdx.index() << " " + << succOperandType << " should match successor input type #" + << typesIdx.index() << " " << succInputType; } } } - - return success(); -} - -/// Verify that types match along control flow edges described the given op. -LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) { - auto regionInterface = cast(op); - - auto inputTypesFromParent = [&](RegionSuccessor successor) -> TypeRange { - return regionInterface.getEntrySuccessorOperands(successor).getTypes(); - }; - - // Verify types along control flow edges originating from the parent. - if (failed(verifyTypesAlongAllEdges( - regionInterface, RegionBranchPoint::parent(), inputTypesFromParent))) - return failure(); - - // Verify types along control flow edges originating from each region. - for (Region ®ion : op->getRegions()) { - // Collect all return-like terminators in the region. - SmallVector regionReturnOps; - for (Block &block : region) - if (!block.empty()) - if (auto terminator = - dyn_cast(block.back())) - regionReturnOps.push_back(terminator); - - // If there is no return-like terminator, the op itself should verify - // type consistency. - if (regionReturnOps.empty()) - continue; - - // Verify types along control flow edges originating from each return-like - // terminator. - for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { - - auto inputTypesForRegion = - [&](RegionSuccessor successor) -> FailureOr { - OperandRange terminatorOperands = - regionReturnOp.getSuccessorOperands(successor); - return TypeRange(terminatorOperands.getTypes()); - }; - if (failed(verifyTypesAlongAllEdges(regionInterface, regionReturnOp, - inputTypesForRegion))) - return failure(); - } - } - return success(); } @@ -525,11 +478,15 @@ SmallVector RegionBranchOpInterface::getAllRegionBranchPoints() { SmallVector branchPoints; branchPoints.push_back(RegionBranchPoint::parent()); - for (Region ®ion : getOperation()->getRegions()) - for (Block &block : region) + for (Region ®ion : getOperation()->getRegions()) { + for (Block &block : region) { + if (block.empty()) + continue; if (auto terminator = dyn_cast(block.back())) branchPoints.push_back(RegionBranchPoint(terminator)); + } + } return branchPoints; } diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index 6db43ffd4b81..13a9b1cd38d8 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -404,9 +404,10 @@ func.func @reduceReturn_not_inside_reduce(%arg0 : f32) { func.func @std_if_incorrect_yield(%arg0: i1, %arg1: f32) { - // expected-error@+1 {{region control flow edge from Operation scf.yield to parent results: source has 1 operands, but target successor needs 2}} + // expected-error@+1 {{along control flow edge from Operation scf.yield to parent: region branch point has 1 operands, but region successor needs 2 inputs}} %x, %y = scf.if %arg0 -> (f32, f32) { %0 = arith.addf %arg1, %arg1 : f32 + // expected-note@+1 {{region branch point}} scf.yield %0 : f32 } else { %0 = arith.subf %arg1, %arg1 : f32 @@ -575,8 +576,9 @@ func.func @while_invalid_terminator() { func.func @while_cross_region_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{region control flow edge from Operation scf.condition to Region #1: source has 0 operands, but target successor needs 1}} + // expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: region branch point has 0 operands, but region successor needs 1 inputs}} scf.while : () -> () { + // expected-note@+1 {{region branch point}} scf.condition(%true) } do { ^bb0(%arg0: i32): @@ -588,8 +590,9 @@ func.func @while_cross_region_type_mismatch() { func.func @while_cross_region_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: source type #0 'i1' should match input type #0 'i32'}} + // expected-error@+1 {{along control flow edge from Operation scf.condition to Region #1: successor operand type #0 'i1' should match successor input type #0 'i32'}} %0 = scf.while : () -> (i1) { + // expected-note@+1 {{region branch point}} scf.condition(%true) %true : i1 } do { ^bb0(%arg0: i32): @@ -601,8 +604,9 @@ func.func @while_cross_region_type_mismatch() { func.func @while_result_type_mismatch() { %true = arith.constant true - // expected-error@+1 {{region control flow edge from Operation scf.condition to parent results: source has 1 operands, but target successor needs 0}} + // expected-error@+1 {{along control flow edge from Operation scf.condition to parent: region branch point has 1 operands, but region successor needs 0 inputs}} scf.while : () -> () { + // expected-note@+1 {{region branch point}} scf.condition(%true) %true : i1 } do { ^bb0(%arg0: i1):