diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index 02ffa0da7a8b..c0c11c9e3899 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -126,6 +126,9 @@ FailureOr loopUnrollByFactor( scf::ForOp forOp, uint64_t unrollFactor, function_ref annotateFn = nullptr); +/// Unrolls this loop completely. +LogicalResult loopUnrollFull(scf::ForOp forOp); + /// Unrolls and jams this `scf.for` operation by the specified unroll factor. /// Returns failure if the loop cannot be unrolled either due to restrictions or /// due to invalid unroll factors. In case of unroll factor of 1, the function diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index fa82bcb816a2..bc1cb24303ad 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -498,6 +498,20 @@ FailureOr mlir::loopUnrollByFactor( return resultLoops; } +/// Unrolls this loop completely. +LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) { + IRRewriter rewriter(forOp.getContext()); + std::optional mayBeConstantTripCount = getConstantTripCount(forOp); + if (!mayBeConstantTripCount.has_value()) + return failure(); + uint64_t tripCount = *mayBeConstantTripCount; + if (tripCount == 0) + return success(); + if (tripCount == 1) + return forOp.promoteIfSingleIteration(rewriter); + return loopUnrollByFactor(forOp, tripCount); +} + /// Check if bounds of all inner loops are defined outside of `forOp` /// and return false if not. static bool areInnerBoundsInvariant(scf::ForOp forOp) { diff --git a/mlir/test/Transforms/scf-loop-unroll.mlir b/mlir/test/Transforms/scf-loop-unroll.mlir index baf6b2970ac0..0ef6ad15d4eb 100644 --- a/mlir/test/Transforms/scf-loop-unroll.mlir +++ b/mlir/test/Transforms/scf-loop-unroll.mlir @@ -1,5 +1,6 @@ // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=3" -split-input-file -canonicalize | FileCheck %s // RUN: mlir-opt %s --test-loop-unrolling="unroll-factor=1" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-BY-1 +// RUN: mlir-opt %s --test-loop-unrolling="unroll-full=true" -split-input-file -canonicalize | FileCheck %s --check-prefix UNROLL-FULL // CHECK-LABEL: scf_loop_unroll_single func.func @scf_loop_unroll_single(%arg0 : f32, %arg1 : f32) -> f32 { @@ -56,3 +57,59 @@ func.func @scf_loop_unroll_factor_1_promote() -> () { // UNROLL-BY-1-NEXT: %[[C0:.*]] = arith.constant 0 : index // UNROLL-BY-1-NEXT: %{{.*}} = "test.foo"(%[[C0]]) : (index) -> i32 } + +// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_single +// UNROLL-FULL-SAME: %[[ARG:.*]]: index) +func.func @scf_loop_unroll_full_single(%arg : index) -> index { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %2 = arith.constant 4 : index + %4 = scf.for %iv = %0 to %2 step %1 iter_args(%arg1 = %1) -> index { + %3 = arith.addi %arg1, %arg : index + scf.yield %3 : index + } + return %4 : index + // UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index + // UNROLL-FULL: %[[V0:.*]] = arith.addi %[[ARG]], %[[C1]] : index + // UNROLL-FULL: %[[V1:.*]] = arith.addi %[[V0]], %[[ARG]] : index + // UNROLL-FULL: %[[V2:.*]] = arith.addi %[[V1]], %[[ARG]] : index + // UNROLL-FULL: %[[V3:.*]] = arith.addi %[[V2]], %[[ARG]] : index + // UNROLL-FULL: return %[[V3]] : index +} + +// UNROLL-FULL-LABEL: func @scf_loop_unroll_full_outter_loops +// UNROLL-FULL-SAME: %[[ARG:.*]]: vector<4x4xindex>) +func.func @scf_loop_unroll_full_outter_loops(%arg0: vector<4x4xindex>) -> index { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %2 = arith.constant 4 : index + %6 = scf.for %arg1 = %0 to %2 step %1 iter_args(%it0 = %0) -> index { + %5 = scf.for %arg2 = %0 to %2 step %1 iter_args(%it1 = %it0) -> index { + %3 = vector.extract %arg0[%arg1, %arg2] : index from vector<4x4xindex> + %4 = arith.addi %3, %it1 : index + scf.yield %3 : index + } + scf.yield %5 : index + } + return %6 : index + // UNROLL-FULL: %[[C0:.*]] = arith.constant 0 : index + // UNROLL-FULL: %[[C1:.*]] = arith.constant 1 : index + // UNROLL-FULL: %[[C4:.*]] = arith.constant 4 : index + // UNROLL-FULL: %[[SUM0:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[C0]]) + // UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][0, %[[IV]]] : index from vector<4x4xindex> + // UNROLL-FULL: scf.yield %[[VAL]] : index + // UNROLL-FULL: } + // UNROLL-FULL: %[[SUM1:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM0]]) + // UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][1, %[[IV]]] : index from vector<4x4xindex> + // UNROLL-FULL: scf.yield %[[VAL]] : index + // UNROLL-FULL: } + // UNROLL-FULL: %[[SUM2:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM1]]) + // UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][2, %[[IV]]] : index from vector<4x4xindex> + // UNROLL-FULL: scf.yield %[[VAL]] : index + // UNROLL-FULL: } + // UNROLL-FULL: %[[SUM3:.*]] = scf.for %[[IV:.*]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%{{.*}} = %[[SUM2]]) + // UNROLL-FULL: %[[VAL:.*]] = vector.extract %[[ARG]][3, %[[IV]]] : index from vector<4x4xindex> + // UNROLL-FULL: scf.yield %[[VAL]] : index + // UNROLL-FULL: } + // UNROLL-FULL: return %[[SUM3]] : index +} diff --git a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp index 8694a7f9bbd6..ced003305a7b 100644 --- a/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp +++ b/mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp @@ -42,10 +42,11 @@ struct TestLoopUnrollingPass TestLoopUnrollingPass(const TestLoopUnrollingPass &) {} explicit TestLoopUnrollingPass(uint64_t unrollFactorParam, unsigned loopDepthParam, - bool annotateLoopParam) { + bool annotateLoopParam, bool unrollFullParam) { unrollFactor = unrollFactorParam; loopDepth = loopDepthParam; annotateLoop = annotateLoopParam; + unrollFull = unrollFactorParam; } void getDependentDialects(DialectRegistry ®istry) const override { @@ -63,8 +64,12 @@ struct TestLoopUnrollingPass op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i)); } }; - for (auto loop : loops) - (void)loopUnrollByFactor(loop, unrollFactor, annotateFn); + for (auto loop : loops) { + if (unrollFull) + (void)loopUnrollFull(loop); + else + (void)loopUnrollByFactor(loop, unrollFactor, annotateFn); + } } Option unrollFactor{*this, "unroll-factor", llvm::cl::desc("Loop unroll factor."), @@ -77,6 +82,9 @@ struct TestLoopUnrollingPass llvm::cl::init(false)}; Option loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."), llvm::cl::init(0)}; + Option unrollFull{*this, "unroll-full", + llvm::cl::desc("Full unroll loops."), + llvm::cl::init(false)}; }; } // namespace