From 7ff0dc4b9fdee840de0901e969ea11880a28d433 Mon Sep 17 00:00:00 2001 From: "Chi-Chun, Chen" Date: Tue, 31 Mar 2026 11:11:08 -0500 Subject: [PATCH] [mlir][OpenMP] Add iterator support to depend clause (#189090) Extend the depend clause to support `!omp.iterated` handles alongside plain depend vars, so the IR can represent both forms. Assisted with copilot This is part of feature work for https://github.com/llvm/llvm-project/issues/188061 --- .../Optimizer/OpenMP/FunctionFiltering.cpp | 2 + .../Optimizer/OpenMP/LowerWorkdistribute.cpp | 4 + .../mlir/Dialect/OpenMP/OpenMPClauses.td | 16 +- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 6 +- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 165 ++++++++++++------ .../OpenMP/OpenMPToLLVMIRTranslation.cpp | 7 + mlir/test/Dialect/OpenMP/invalid.mlir | 28 ++- mlir/test/Dialect/OpenMP/ops.mlir | 37 +++- mlir/test/Target/LLVMIR/openmp-todo.mlir | 30 ++++ 9 files changed, 223 insertions(+), 72 deletions(-) diff --git a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp index 472d6a9f08a6..475ed35cac9f 100644 --- a/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp +++ b/flang/lib/Optimizer/OpenMP/FunctionFiltering.cpp @@ -309,6 +309,8 @@ private: // Variables unused by the device. targetOp.getDependVarsMutable().clear(); targetOp.setDependKindsAttr(nullptr); + targetOp.getDependIteratedMutable().clear(); + targetOp.setDependIteratedKindsAttr(nullptr); targetOp.getDeviceMutable().clear(); targetOp.getIfExprMutable().clear(); diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp index 2c7980064500..8a08f67006c0 100644 --- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp +++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp @@ -760,6 +760,7 @@ FailureOr splitTargetData(omp::TargetOp targetOp, rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), targetOp.getHostEvalVars(), targetOp.getIfExpr(), targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), @@ -1480,6 +1481,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector &preMapOperands, rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), preHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), @@ -1570,6 +1572,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector &postMapOperands, rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), isolatedHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), @@ -1650,6 +1653,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp, rewriter, targetOp.getLoc(), targetOp.getAllocateVars(), targetOp.getAllocatorVars(), targetOp.getBareAttr(), targetOp.getDependKindsAttr(), targetOp.getDependVars(), + targetOp.getDependIteratedKindsAttr(), targetOp.getDependIterated(), targetOp.getDevice(), targetOp.getHasDeviceAddrVars(), postHostEvalVars, targetOp.getIfExpr(), targetOp.getInReductionVars(), targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(), diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td index b5a047dc613f..f24efd0d4fc4 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td @@ -316,20 +316,26 @@ class OpenMP_DependClauseSkip< bit description = false, bit extraClassDeclaration = false > : OpenMP_Clause { - let arguments = (ins - OptionalAttr:$depend_kinds, - Variadic:$depend_vars - ); + let arguments = (ins OptionalAttr:$depend_kinds, + Variadic:$depend_vars, + OptionalAttr:$depend_iterated_kinds, + Variadic:$depend_iterated); let optAssemblyFormat = [{ `depend` `(` - custom($depend_vars, type($depend_vars), $depend_kinds) `)` + custom($depend_vars, type($depend_vars), $depend_kinds, + $depend_iterated, type($depend_iterated), + $depend_iterated_kinds) `)` }]; let description = [{ The `depend_kinds` and `depend_vars` arguments are variadic lists of values that specify the dependencies of this particular task in relation to other tasks. + + The `depend_iterated_kinds` and `depend_iterated` arguments are variadic + lists of iterator-produced handles (from `omp.iterator`) that specify + dependencies expanded at runtime via an iterator modifier. }]; } diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index 88c8ab4f6f94..6a2fd4484157 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -1665,9 +1665,9 @@ def OrderedRegionOp : OpenMP_Op<"ordered.region", clauses = [ // 2.17.5 taskwait Construct //===----------------------------------------------------------------------===// -def TaskwaitOp : OpenMP_Op<"taskwait", clauses = [ - OpenMP_DependClause, OpenMP_NowaitClause - ]> { +def TaskwaitOp + : OpenMP_Op<"taskwait", traits = [AttrSizedOperandSegments], + clauses = [OpenMP_DependClause, OpenMP_NowaitClause]> { let summary = "taskwait construct"; let description = [{ The taskwait construct specifies a wait on the completion of child tasks diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 04418ee39be5..82fbc909f527 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1712,50 +1712,77 @@ verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, /// depend-entry-list ::= depend-entry /// | depend-entry-list `,` depend-entry /// depend-entry ::= depend-kind `->` ssa-id `:` type -static ParseResult -parseDependVarList(OpAsmParser &parser, - SmallVectorImpl &dependVars, - SmallVectorImpl &dependTypes, ArrayAttr &dependKinds) { +/// | depend-kind `->` ssa-id `:` iterated-type +static ParseResult parseDependVarList( + OpAsmParser &parser, + SmallVectorImpl &dependVars, + SmallVectorImpl &dependTypes, ArrayAttr &dependKinds, + SmallVectorImpl &iteratedVars, + SmallVectorImpl &iteratedTypes, ArrayAttr &iteratedKinds) { SmallVector kindsVec; + SmallVector iterKindsVec; if (failed(parser.parseCommaSeparatedList([&]() { StringRef keyword; + OpAsmParser::UnresolvedOperand operand; + Type ty; if (parser.parseKeyword(&keyword) || parser.parseArrow() || - parser.parseOperand(dependVars.emplace_back()) || - parser.parseColonType(dependTypes.emplace_back())) + parser.parseOperand(operand) || parser.parseColonType(ty)) return failure(); - if (std::optional keywordDepend = - (symbolizeClauseTaskDepend(keyword))) - kindsVec.emplace_back( - ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend)); - else + std::optional keywordDepend = + symbolizeClauseTaskDepend(keyword); + if (!keywordDepend) return failure(); + auto kindAttr = + ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend); + if (llvm::isa(ty)) { + iteratedVars.push_back(operand); + iteratedTypes.push_back(ty); + iterKindsVec.push_back(kindAttr); + } else { + dependVars.push_back(operand); + dependTypes.push_back(ty); + kindsVec.push_back(kindAttr); + } return success(); }))) return failure(); SmallVector kinds(kindsVec.begin(), kindsVec.end()); dependKinds = ArrayAttr::get(parser.getContext(), kinds); + SmallVector iterKinds(iterKindsVec.begin(), iterKindsVec.end()); + iteratedKinds = ArrayAttr::get(parser.getContext(), iterKinds); return success(); } /// Print Depend clause static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, - std::optional dependKinds) { - - for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) { - if (i != 0) - p << ", "; - p << stringifyClauseTaskDepend( - llvm::cast((*dependKinds)[i]) - .getValue()) - << " -> " << dependVars[i] << " : " << dependTypes[i]; - } + std::optional dependKinds, + OperandRange iteratedVars, + TypeRange iteratedTypes, + std::optional iteratedKinds) { + bool first = true; + auto printEntries = [&](OperandRange vars, TypeRange types, + std::optional kinds) { + for (unsigned i = 0, e = vars.size(); i < e; ++i) { + if (!first) + p << ", "; + p << stringifyClauseTaskDepend( + llvm::cast((*kinds)[i]) + .getValue()) + << " -> " << vars[i] << " : " << types[i]; + first = false; + } + }; + printEntries(dependVars, dependTypes, dependKinds); + printEntries(iteratedVars, iteratedTypes, iteratedKinds); } /// Verifies Depend clause static LogicalResult verifyDependVarList(Operation *op, std::optional dependKinds, - OperandRange dependVars) { + OperandRange dependVars, + std::optional iteratedKinds, + OperandRange iteratedVars) { if (!dependVars.empty()) { if (!dependKinds || dependKinds->size() != dependVars.size()) return op->emitOpError() << "expected as many depend values" @@ -1763,7 +1790,15 @@ static LogicalResult verifyDependVarList(Operation *op, } else { if (dependKinds && !dependKinds->empty()) return op->emitOpError() << "unexpected depend values"; - return success(); + } + + if (!iteratedVars.empty()) { + if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size()) + return op->emitOpError() << "expected as many depend iterated values" + " as depend iterated variables"; + } else { + if (iteratedKinds && !iteratedKinds->empty()) + return op->emitOpError() << "unexpected depend iterated values"; } return success(); @@ -2266,15 +2301,17 @@ void TargetEnterDataOp::build( OpBuilder &builder, OperationState &state, const TargetEnterExitUpdateDataOperands &clauses) { MLIRContext *ctx = builder.getContext(); - TargetEnterDataOp::build(builder, state, - makeArrayAttr(ctx, clauses.dependKinds), - clauses.dependVars, clauses.device, clauses.ifExpr, - clauses.mapVars, clauses.nowait); + TargetEnterDataOp::build( + builder, state, makeArrayAttr(ctx, clauses.dependKinds), + clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds), + clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars, + clauses.nowait); } LogicalResult TargetEnterDataOp::verify() { LogicalResult verifyDependVars = - verifyDependVarList(*this, getDependKinds(), getDependVars()); + verifyDependVarList(*this, getDependKinds(), getDependVars(), + getDependIteratedKinds(), getDependIterated()); return failed(verifyDependVars) ? verifyDependVars : verifyMapClause(*this, getMapVars()); } @@ -2286,15 +2323,17 @@ LogicalResult TargetEnterDataOp::verify() { void TargetExitDataOp::build(OpBuilder &builder, OperationState &state, const TargetEnterExitUpdateDataOperands &clauses) { MLIRContext *ctx = builder.getContext(); - TargetExitDataOp::build(builder, state, - makeArrayAttr(ctx, clauses.dependKinds), - clauses.dependVars, clauses.device, clauses.ifExpr, - clauses.mapVars, clauses.nowait); + TargetExitDataOp::build( + builder, state, makeArrayAttr(ctx, clauses.dependKinds), + clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds), + clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars, + clauses.nowait); } LogicalResult TargetExitDataOp::verify() { LogicalResult verifyDependVars = - verifyDependVarList(*this, getDependKinds(), getDependVars()); + verifyDependVarList(*this, getDependKinds(), getDependVars(), + getDependIteratedKinds(), getDependIterated()); return failed(verifyDependVars) ? verifyDependVars : verifyMapClause(*this, getMapVars()); } @@ -2307,13 +2346,16 @@ void TargetUpdateOp::build(OpBuilder &builder, OperationState &state, const TargetEnterExitUpdateDataOperands &clauses) { MLIRContext *ctx = builder.getContext(); TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds), - clauses.dependVars, clauses.device, clauses.ifExpr, + clauses.dependVars, + makeArrayAttr(ctx, clauses.dependIteratedKinds), + clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars, clauses.nowait); } LogicalResult TargetUpdateOp::verify() { LogicalResult verifyDependVars = - verifyDependVarList(*this, getDependKinds(), getDependVars()); + verifyDependVarList(*this, getDependKinds(), getDependVars(), + getDependIteratedKinds(), getDependIterated()); return failed(verifyDependVars) ? verifyDependVars : verifyMapClause(*this, getMapVars()); } @@ -2327,20 +2369,24 @@ void TargetOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars, // inReductionByref, inReductionSyms. - TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, - clauses.bare, makeArrayAttr(ctx, clauses.dependKinds), - clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars, - clauses.hostEvalVars, clauses.ifExpr, - /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr, - /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, - clauses.mapVars, clauses.nowait, clauses.privateVars, - makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.threadLimitVars, - /*private_maps=*/nullptr); + TargetOp::build( + builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare, + makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars, + makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated, + clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars, + clauses.ifExpr, + /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr, + /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars, + clauses.nowait, clauses.privateVars, + makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier, + clauses.threadLimitVars, + /*private_maps=*/nullptr); } LogicalResult TargetOp::verify() { - if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars()))) + if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(), + getDependIteratedKinds(), + getDependIterated()))) return failure(); if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr", @@ -3270,21 +3316,23 @@ LogicalResult DeclareReductionOp::verifyRegions() { void TaskOp::build(OpBuilder &builder, OperationState &state, const TaskOperands &clauses) { MLIRContext *ctx = builder.getContext(); - TaskOp::build(builder, state, clauses.iterated, clauses.affinityVars, - clauses.allocateVars, clauses.allocatorVars, - makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars, - clauses.final, clauses.ifExpr, clauses.inReductionVars, - makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), - makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable, - clauses.priority, /*private_vars=*/clauses.privateVars, - /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms), - clauses.privateNeedsBarrier, clauses.untied, - clauses.eventHandle); + TaskOp::build( + builder, state, clauses.iterated, clauses.affinityVars, + clauses.allocateVars, clauses.allocatorVars, + makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars, + makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated, + clauses.final, clauses.ifExpr, clauses.inReductionVars, + makeDenseBoolArrayAttr(ctx, clauses.inReductionByref), + makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable, + clauses.priority, /*private_vars=*/clauses.privateVars, + /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms), + clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle); } LogicalResult TaskOp::verify() { LogicalResult verifyDependVars = - verifyDependVarList(*this, getDependKinds(), getDependVars()); + verifyDependVarList(*this, getDependKinds(), getDependVars(), + getDependIteratedKinds(), getDependIterated()); return failed(verifyDependVars) ? verifyDependVars : verifyReductionVarList(*this, getInReductionSyms(), @@ -4142,7 +4190,8 @@ void TaskwaitOp::build(OpBuilder &builder, OperationState &state, const TaskwaitOperands &clauses) { // TODO Store clauses in op: dependKinds, dependVars, nowait. TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr, - /*depend_vars=*/{}, /*nowait=*/nullptr); + /*depend_vars=*/{}, /*depend_iterated_kinds=*/nullptr, + /*depend_iterated=*/{}, /*nowait=*/nullptr); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index aa8a9b9b1f7f..281235a0af46 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -357,6 +357,11 @@ static LogicalResult checkImplementationStatus(Operation &op) { if (!op.getDependVars().empty() || op.getDependKinds()) result = todo("depend"); }; + auto checkDependIteratorModifier = [&todo](auto op, LogicalResult &result) { + if (!op.getDependIterated().empty() || + (op.getDependIteratedKinds() && !op.getDependIteratedKinds()->empty())) + result = todo("depend with iterator modifier"); + }; auto checkHint = [](auto op, LogicalResult &) { if (op.getHint()) op.emitWarning("hint clause discarded"); @@ -429,6 +434,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::TaskOp op) { checkAllocate(op, result); + checkDependIteratorModifier(op, result); checkInReduction(op, result); }) .Case([&](omp::TaskgroupOp op) { @@ -463,6 +469,7 @@ static LogicalResult checkImplementationStatus(Operation &op) { .Case([&](omp::TargetOp op) { checkAllocate(op, result); checkBare(op, result); + checkDependIteratorModifier(op, result); checkInReduction(op, result); checkThreadLimit(op, result); }) diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 4879ea754bf7..db5d1b60c569 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -1716,6 +1716,26 @@ func.func @omp_task_depend(%data_var: memref) { // ----- +func.func @omp_task_depend_iterated_no_vars(%data_var: memref) { + // expected-error @below {{op unexpected depend iterated values}} + "omp.task"() ({ + "omp.terminator"() : () -> () + }) {depend_iterated_kinds = [#omp], operandSegmentSizes = array} : () -> () + "func.return"() : () -> () +} + +// ----- + +func.func @omp_task_depend_iterated_mismatch(%it: !omp.iterated) { + // expected-error @below {{op expected as many depend iterated values as depend iterated variables}} + "omp.task"(%it) ({ + "omp.terminator"() : () -> () + }) {depend_iterated_kinds = [], operandSegmentSizes = array} : (!omp.iterated) -> () + "func.return"() : () -> () +} + +// ----- + func.func @omp_task(%ptr: !llvm.ptr) { // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}} omp.task in_reduction(@add_f32 %ptr -> %arg0 : !llvm.ptr) { @@ -2274,7 +2294,7 @@ func.func @omp_target_enter_data(%map1: memref) { func.func @omp_target_enter_data_depend(%a: memref) { %0 = omp.map.info var_ptr(%a: memref, tensor) map_clauses(to) capture(ByRef) -> memref // expected-error @below {{op expected as many depend values as depend variables}} - omp.target_enter_data map_entries(%0: memref ) {operandSegmentSizes = array} + omp.target_enter_data map_entries(%0: memref ) {operandSegmentSizes = array} return } @@ -2292,7 +2312,7 @@ func.func @omp_target_exit_data(%map1: memref) { func.func @omp_target_exit_data_depend(%a: memref) { %0 = omp.map.info var_ptr(%a: memref, tensor) map_clauses(from) capture(ByRef) -> memref // expected-error @below {{op expected as many depend values as depend variables}} - omp.target_exit_data map_entries(%0: memref ) {operandSegmentSizes = array} + omp.target_exit_data map_entries(%0: memref ) {operandSegmentSizes = array} return } @@ -2373,7 +2393,7 @@ llvm.mlir.global internal @_QFsubEx() : i32 func.func @omp_target_update_data_depend(%a: memref) { %0 = omp.map.info var_ptr(%a: memref, tensor) map_clauses(to) capture(ByRef) -> memref // expected-error @below {{op expected as many depend values as depend variables}} - omp.target_update map_entries(%0: memref ) {operandSegmentSizes = array} + omp.target_update map_entries(%0: memref ) {operandSegmentSizes = array} return } @@ -2475,7 +2495,7 @@ func.func @omp_target_depend(%data_var: memref) { // expected-error @below {{op expected as many depend values as depend variables}} "omp.target"(%data_var) ({ "omp.terminator"() : () -> () - }) {depend_kinds = [], operandSegmentSizes = array} : (memref) -> () + }) {depend_kinds = [], operandSegmentSizes = array} : (memref) -> () "func.return"() : () -> () } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index d924d479eba9..b0554eba459f 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -871,7 +871,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic "omp.target"(%device, %if_cond, %num_threads) ({ // CHECK: omp.terminator omp.terminator - }) {nowait, operandSegmentSizes = array} : ( si32, i1, i32 ) -> () + }) {nowait, operandSegmentSizes = array} : ( si32, i1, i32 ) -> () // Test with optional map clause. // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref, tensor) map_clauses(always, to) capture(ByRef) -> memref {name = ""} @@ -2269,6 +2269,39 @@ func.func @omp_task_depend(%arg0: memref, %arg1: memref) { return } +// CHECK-LABEL: func.func @omp_task_depend_iterated +func.func @omp_task_depend_iterated(%lb : index, %ub : index, %step : index, + %addr : !llvm.ptr) -> () { + // CHECK: %[[IT:.*]] = omp.iterator(%[[IV:.*]]: index) = (%{{.*}} to %{{.*}} step %{{.*}}) { + // CHECK: omp.yield(%{{.*}} : !llvm.ptr) + // CHECK: } -> !omp.iterated + // CHECK: omp.task depend(taskdependin -> %[[IT]] : !omp.iterated) { + %it = omp.iterator(%iv: index) = (%lb to %ub step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + + omp.task depend(taskdependin -> %it : !omp.iterated) { + omp.terminator + } + return +} + +// CHECK-LABEL: func.func @omp_task_depend_iterated_mixed +func.func @omp_task_depend_iterated_mixed(%lb : index, %ub : index, %step : index, + %addr : !llvm.ptr, + %plain : memref) -> () { + // CHECK: %[[IT:.*]] = omp.iterator + // CHECK: omp.task depend(taskdependout -> %{{.*}} : memref, taskdependin -> %[[IT]] : !omp.iterated) { + %it = omp.iterator(%iv: index) = (%lb to %ub step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + + omp.task depend(taskdependout -> %plain : memref, taskdependin -> %it : !omp.iterated) { + omp.terminator + } + return +} + // CHECK-LABEL: @omp_target_depend // CHECK-SAME: (%arg0: memref, %arg1: memref) { @@ -2277,7 +2310,7 @@ func.func @omp_target_depend(%arg0: memref, %arg1: memref) { omp.target depend(taskdependin -> %arg0 : memref, taskdependin -> %arg1 : memref, taskdependinout -> %arg0 : memref) { // CHECK: omp.terminator omp.terminator - } {operandSegmentSizes = array} + } {operandSegmentSizes = array} return } diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 8fb66cb4dd0e..ea7ec3cfc3bd 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -190,6 +190,36 @@ llvm.func @target_in_reduction(%x : !llvm.ptr) { // ----- +llvm.func @task_depend_iterator_modifier(%lb : i64, %ub : i64, %step : i64, + %addr : !llvm.ptr) { + %it = omp.iterator(%iv: i64) = (%lb to %ub step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + // expected-error@below {{not yet implemented: Unhandled clause depend with iterator modifier in omp.task operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.task}} + omp.task depend(taskdependin -> %it : !omp.iterated) { + omp.terminator + } + llvm.return +} + +// ----- + +llvm.func @target_depend_iterator_modifier(%lb : i64, %ub : i64, %step : i64, + %addr : !llvm.ptr) { + %it = omp.iterator(%iv: i64) = (%lb to %ub step %step) { + omp.yield(%addr : !llvm.ptr) + } -> !omp.iterated + // expected-error@below {{not yet implemented: Unhandled clause depend with iterator modifier in omp.target operation}} + // expected-error@below {{LLVM Translation failed for operation: omp.target}} + omp.target depend(taskdependin -> %it : !omp.iterated) { + omp.terminator + } + llvm.return +} + +// ----- + llvm.func @target_enter_data_depend(%x: !llvm.ptr) { // expected-error@below {{not yet implemented: Unhandled clause depend in omp.target_enter_data operation}} // expected-error@below {{LLVM Translation failed for operation: omp.target_enter_data}}