From 7454098a9ed46ca5ad310bf3ec9347eb08eba007 Mon Sep 17 00:00:00 2001 From: Michael Maitland Date: Fri, 30 May 2025 00:39:45 -0400 Subject: [PATCH] [mlir][Value] Add getNumUses, hasNUses, and hasNUsesOrMore to Value (#142084) We already have hasOneUse. Like llvm::Value we provide helper methods to query the number of uses of a Value. Add unittests for Value, because that was missing. --------- Co-authored-by: Michael Maitland --- .../Tutorials/UnderstandingTheIRStructure.md | 10 +-- mlir/include/mlir/IR/Value.h | 14 +++ mlir/lib/Bytecode/Reader/BytecodeReader.cpp | 3 +- mlir/lib/Dialect/SCF/IR/SCF.cpp | 2 +- mlir/lib/IR/Value.cpp | 12 +++ mlir/test/lib/IR/TestPrintDefUse.cpp | 10 +-- mlir/unittests/IR/CMakeLists.txt | 1 + mlir/unittests/IR/ValueTest.cpp | 90 +++++++++++++++++++ 8 files changed, 125 insertions(+), 17 deletions(-) create mode 100644 mlir/unittests/IR/ValueTest.cpp diff --git a/mlir/docs/Tutorials/UnderstandingTheIRStructure.md b/mlir/docs/Tutorials/UnderstandingTheIRStructure.md index 595d6949a03f..30b50cb09490 100644 --- a/mlir/docs/Tutorials/UnderstandingTheIRStructure.md +++ b/mlir/docs/Tutorials/UnderstandingTheIRStructure.md @@ -257,14 +257,10 @@ results and print informations about them: llvm::outs() << " has no uses\n"; continue; } - if (result.hasOneUse()) { + if (result.hasOneUse()) llvm::outs() << " has a single use: "; - } else { - llvm::outs() << " has " - << std::distance(result.getUses().begin(), - result.getUses().end()) - << " uses:\n"; - } + else + llvm::outs() << " has " << result.getNumUses() << " uses:\n"; for (Operation *userOp : result.getUsers()) { llvm::outs() << " - " << userOp->getName() << "\n"; } diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index d54e3c0ad26d..4d6d89fa69a0 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -187,9 +187,23 @@ public: /// Returns a range of all uses, which is useful for iterating over all uses. use_range getUses() const { return {use_begin(), use_end()}; } + /// This method computes the number of uses of this Value. + /// + /// This is a linear time operation. Use hasOneUse, hasNUses, or + /// hasNUsesOrMore to check for specific values. + unsigned getNumUses() const; + /// Returns true if this value has exactly one use. bool hasOneUse() const { return impl->hasOneUse(); } + /// Return true if this Value has exactly n uses. + bool hasNUses(unsigned n) const; + + /// Return true if this value has n uses or more. + /// + /// This is logically equivalent to getNumUses() >= N. + bool hasNUsesOrMore(unsigned n) const; + /// Returns true if this value has no uses. bool use_empty() const { return impl->use_empty(); } diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 1052946d4550..44458d010c6c 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -1993,8 +1993,7 @@ LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) { UseListOrderStorage customOrder = valueToUseListMap.at(value.getAsOpaquePointer()); SmallVector shuffle = std::move(customOrder.indices); - uint64_t numUses = - std::distance(value.getUses().begin(), value.getUses().end()); + uint64_t numUses = value.getNumUses(); // If the encoding was a pair of indices `(src, dst)` for every permutation, // reconstruct the shuffle vector for every use. Initialize the shuffle vector diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 748379ea671b..5a0b8a058dd6 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1787,7 +1787,7 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern { for (auto [lb, ub, step, iv] : llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(), op.getMixedStep(), op.getInductionVars())) { - if (iv.getUses().begin() == iv.getUses().end()) + if (iv.hasNUses(0)) continue; auto numIterations = constantTripCount(lb, ub, step); if (!numIterations.has_value() || numIterations.value() != 1) { diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 178765353cc1..7b3a9462a091 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -51,6 +51,18 @@ Block *Value::getParentBlock() { return llvm::cast(*this).getOwner(); } +unsigned Value::getNumUses() const { + return (unsigned)std::distance(use_begin(), use_end()); +} + +bool Value::hasNUses(unsigned n) const { + return hasNItems(use_begin(), use_end(), n); +} + +bool Value::hasNUsesOrMore(unsigned n) const { + return hasNItemsOrMore(use_begin(), use_end(), n); +} + //===----------------------------------------------------------------------===// // Value::UseLists //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/IR/TestPrintDefUse.cpp b/mlir/test/lib/IR/TestPrintDefUse.cpp index 5d489a342f57..b983366fc16d 100644 --- a/mlir/test/lib/IR/TestPrintDefUse.cpp +++ b/mlir/test/lib/IR/TestPrintDefUse.cpp @@ -49,14 +49,10 @@ struct TestPrintDefUsePass llvm::outs() << " has no uses\n"; continue; } - if (result.hasOneUse()) { + if (result.hasOneUse()) llvm::outs() << " has a single use: "; - } else { - llvm::outs() << " has " - << std::distance(result.getUses().begin(), - result.getUses().end()) - << " uses:\n"; - } + else + llvm::outs() << " has " << result.getNumUses() << " uses:\n"; for (Operation *userOp : result.getUsers()) { llvm::outs() << " - " << userOp->getName() << "\n"; } diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 821ff7d14dab..9ab6029c3480 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_unittest(MLIRIRTests TypeTest.cpp TypeAttrNamesTest.cpp OpPropertiesTest.cpp + ValueTest.cpp DEPENDS MLIRTestInterfaceIncGen diff --git a/mlir/unittests/IR/ValueTest.cpp b/mlir/unittests/IR/ValueTest.cpp new file mode 100644 index 000000000000..433681c5ceaa --- /dev/null +++ b/mlir/unittests/IR/ValueTest.cpp @@ -0,0 +1,90 @@ +//===- mlir/unittest/IR/ValueTest.cpp - Value unit tests ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Value.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OperationSupport.h" +#include "gtest/gtest.h" + +using namespace mlir; + +static Operation *createOp(MLIRContext *context, + ArrayRef operands = std::nullopt, + ArrayRef resultTypes = std::nullopt, + unsigned int numRegions = 0) { + context->allowUnregisteredDialects(); + return Operation::create( + UnknownLoc::get(context), OperationName("foo.bar", context), resultTypes, + operands, std::nullopt, nullptr, std::nullopt, numRegions); +} + +namespace { + +TEST(ValueTest, getNumUses) { + MLIRContext context; + Builder builder(&context); + + Operation *op0 = + createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16)); + + Value v0 = op0->getResult(0); + EXPECT_EQ(v0.getNumUses(), (unsigned)0); + + createOp(&context, {v0}, builder.getIntegerType(16)); + EXPECT_EQ(v0.getNumUses(), (unsigned)1); + + createOp(&context, {v0, v0}, builder.getIntegerType(16)); + EXPECT_EQ(v0.getNumUses(), (unsigned)3); +} + +TEST(ValueTest, hasNUses) { + MLIRContext context; + Builder builder(&context); + + Operation *op = + createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16)); + Value v0 = op->getResult(0); + EXPECT_TRUE(v0.hasNUses(0)); + EXPECT_FALSE(v0.hasNUses(1)); + + createOp(&context, {v0}, builder.getIntegerType(16)); + EXPECT_FALSE(v0.hasNUses(0)); + EXPECT_TRUE(v0.hasNUses(1)); + + createOp(&context, {v0, v0}, builder.getIntegerType(16)); + EXPECT_FALSE(v0.hasNUses(0)); + EXPECT_FALSE(v0.hasNUses(1)); + EXPECT_TRUE(v0.hasNUses(3)); +} + +TEST(ValueTest, hasNUsesOrMore) { + MLIRContext context; + Builder builder(&context); + + Operation *op = + createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16)); + Value v0 = op->getResult(0); + EXPECT_TRUE(v0.hasNUsesOrMore(0)); + EXPECT_FALSE(v0.hasNUsesOrMore(1)); + + createOp(&context, {v0}, builder.getIntegerType(16)); + EXPECT_TRUE(v0.hasNUsesOrMore(0)); + EXPECT_TRUE(v0.hasNUsesOrMore(1)); + EXPECT_FALSE(v0.hasNUsesOrMore(2)); + + createOp(&context, {v0, v0}, builder.getIntegerType(16)); + EXPECT_TRUE(v0.hasNUsesOrMore(0)); + EXPECT_TRUE(v0.hasNUsesOrMore(1)); + EXPECT_TRUE(v0.hasNUsesOrMore(3)); + EXPECT_FALSE(v0.hasNUsesOrMore(4)); +} + +} // end anonymous namespace