[mlir][Value] Add getNumUses, hasNUses, and hasNUsesOrMore to Value (#142084)

We already have hasOneUse. Like llvm::Value we provide helper methods to
query the number of uses of a Value. Add unittests for Value, because
that was missing.

---------

Co-authored-by: Michael Maitland <michaelmaitland@meta.com>
This commit is contained in:
Michael Maitland 2025-05-30 00:39:45 -04:00 committed by GitHub
parent f5d3470d42
commit 7454098a9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 125 additions and 17 deletions

View File

@ -257,14 +257,10 @@ results and print informations about them:
llvm::outs() << " has no uses\n"; llvm::outs() << " has no uses\n";
continue; continue;
} }
if (result.hasOneUse()) { if (result.hasOneUse())
llvm::outs() << " has a single use: "; llvm::outs() << " has a single use: ";
} else { else
llvm::outs() << " has " llvm::outs() << " has " << result.getNumUses() << " uses:\n";
<< std::distance(result.getUses().begin(),
result.getUses().end())
<< " uses:\n";
}
for (Operation *userOp : result.getUsers()) { for (Operation *userOp : result.getUsers()) {
llvm::outs() << " - " << userOp->getName() << "\n"; llvm::outs() << " - " << userOp->getName() << "\n";
} }

View File

@ -187,9 +187,23 @@ public:
/// Returns a range of all uses, which is useful for iterating over all uses. /// Returns a range of all uses, which is useful for iterating over all uses.
use_range getUses() const { return {use_begin(), use_end()}; } use_range getUses() const { return {use_begin(), use_end()}; }
/// This method computes the number of uses of this Value.
///
/// This is a linear time operation. Use hasOneUse, hasNUses, or
/// hasNUsesOrMore to check for specific values.
unsigned getNumUses() const;
/// Returns true if this value has exactly one use. /// Returns true if this value has exactly one use.
bool hasOneUse() const { return impl->hasOneUse(); } bool hasOneUse() const { return impl->hasOneUse(); }
/// Return true if this Value has exactly n uses.
bool hasNUses(unsigned n) const;
/// Return true if this value has n uses or more.
///
/// This is logically equivalent to getNumUses() >= N.
bool hasNUsesOrMore(unsigned n) const;
/// Returns true if this value has no uses. /// Returns true if this value has no uses.
bool use_empty() const { return impl->use_empty(); } bool use_empty() const { return impl->use_empty(); }

View File

@ -1993,8 +1993,7 @@ LogicalResult BytecodeReader::Impl::sortUseListOrder(Value value) {
UseListOrderStorage customOrder = UseListOrderStorage customOrder =
valueToUseListMap.at(value.getAsOpaquePointer()); valueToUseListMap.at(value.getAsOpaquePointer());
SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices); SmallVector<unsigned, 4> shuffle = std::move(customOrder.indices);
uint64_t numUses = uint64_t numUses = value.getNumUses();
std::distance(value.getUses().begin(), value.getUses().end());
// If the encoding was a pair of indices `(src, dst)` for every permutation, // If the encoding was a pair of indices `(src, dst)` for every permutation,
// reconstruct the shuffle vector for every use. Initialize the shuffle vector // reconstruct the shuffle vector for every use. Initialize the shuffle vector

View File

@ -1787,7 +1787,7 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
for (auto [lb, ub, step, iv] : for (auto [lb, ub, step, iv] :
llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(), llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
op.getMixedStep(), op.getInductionVars())) { op.getMixedStep(), op.getInductionVars())) {
if (iv.getUses().begin() == iv.getUses().end()) if (iv.hasNUses(0))
continue; continue;
auto numIterations = constantTripCount(lb, ub, step); auto numIterations = constantTripCount(lb, ub, step);
if (!numIterations.has_value() || numIterations.value() != 1) { if (!numIterations.has_value() || numIterations.value() != 1) {

View File

@ -51,6 +51,18 @@ Block *Value::getParentBlock() {
return llvm::cast<BlockArgument>(*this).getOwner(); return llvm::cast<BlockArgument>(*this).getOwner();
} }
unsigned Value::getNumUses() const {
return (unsigned)std::distance(use_begin(), use_end());
}
bool Value::hasNUses(unsigned n) const {
return hasNItems(use_begin(), use_end(), n);
}
bool Value::hasNUsesOrMore(unsigned n) const {
return hasNItemsOrMore(use_begin(), use_end(), n);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Value::UseLists // Value::UseLists
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -49,14 +49,10 @@ struct TestPrintDefUsePass
llvm::outs() << " has no uses\n"; llvm::outs() << " has no uses\n";
continue; continue;
} }
if (result.hasOneUse()) { if (result.hasOneUse())
llvm::outs() << " has a single use: "; llvm::outs() << " has a single use: ";
} else { else
llvm::outs() << " has " llvm::outs() << " has " << result.getNumUses() << " uses:\n";
<< std::distance(result.getUses().begin(),
result.getUses().end())
<< " uses:\n";
}
for (Operation *userOp : result.getUsers()) { for (Operation *userOp : result.getUsers()) {
llvm::outs() << " - " << userOp->getName() << "\n"; llvm::outs() << " - " << userOp->getName() << "\n";
} }

View File

@ -17,6 +17,7 @@ add_mlir_unittest(MLIRIRTests
TypeTest.cpp TypeTest.cpp
TypeAttrNamesTest.cpp TypeAttrNamesTest.cpp
OpPropertiesTest.cpp OpPropertiesTest.cpp
ValueTest.cpp
DEPENDS DEPENDS
MLIRTestInterfaceIncGen MLIRTestInterfaceIncGen

View File

@ -0,0 +1,90 @@
//===- mlir/unittest/IR/ValueTest.cpp - Value unit tests ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Value.h"
#include "../../test/lib/Dialect/Test/TestDialect.h"
#include "../../test/lib/Dialect/Test/TestOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OperationSupport.h"
#include "gtest/gtest.h"
using namespace mlir;
static Operation *createOp(MLIRContext *context,
ArrayRef<Value> operands = std::nullopt,
ArrayRef<Type> resultTypes = std::nullopt,
unsigned int numRegions = 0) {
context->allowUnregisteredDialects();
return Operation::create(
UnknownLoc::get(context), OperationName("foo.bar", context), resultTypes,
operands, std::nullopt, nullptr, std::nullopt, numRegions);
}
namespace {
TEST(ValueTest, getNumUses) {
MLIRContext context;
Builder builder(&context);
Operation *op0 =
createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16));
Value v0 = op0->getResult(0);
EXPECT_EQ(v0.getNumUses(), (unsigned)0);
createOp(&context, {v0}, builder.getIntegerType(16));
EXPECT_EQ(v0.getNumUses(), (unsigned)1);
createOp(&context, {v0, v0}, builder.getIntegerType(16));
EXPECT_EQ(v0.getNumUses(), (unsigned)3);
}
TEST(ValueTest, hasNUses) {
MLIRContext context;
Builder builder(&context);
Operation *op =
createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16));
Value v0 = op->getResult(0);
EXPECT_TRUE(v0.hasNUses(0));
EXPECT_FALSE(v0.hasNUses(1));
createOp(&context, {v0}, builder.getIntegerType(16));
EXPECT_FALSE(v0.hasNUses(0));
EXPECT_TRUE(v0.hasNUses(1));
createOp(&context, {v0, v0}, builder.getIntegerType(16));
EXPECT_FALSE(v0.hasNUses(0));
EXPECT_FALSE(v0.hasNUses(1));
EXPECT_TRUE(v0.hasNUses(3));
}
TEST(ValueTest, hasNUsesOrMore) {
MLIRContext context;
Builder builder(&context);
Operation *op =
createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16));
Value v0 = op->getResult(0);
EXPECT_TRUE(v0.hasNUsesOrMore(0));
EXPECT_FALSE(v0.hasNUsesOrMore(1));
createOp(&context, {v0}, builder.getIntegerType(16));
EXPECT_TRUE(v0.hasNUsesOrMore(0));
EXPECT_TRUE(v0.hasNUsesOrMore(1));
EXPECT_FALSE(v0.hasNUsesOrMore(2));
createOp(&context, {v0, v0}, builder.getIntegerType(16));
EXPECT_TRUE(v0.hasNUsesOrMore(0));
EXPECT_TRUE(v0.hasNUsesOrMore(1));
EXPECT_TRUE(v0.hasNUsesOrMore(3));
EXPECT_FALSE(v0.hasNUsesOrMore(4));
}
} // end anonymous namespace