S. VenkataKeerthy e86bd05bdc
[IR2Vec] Restructuring Vocabulary (#145119)
This PR restructures the vocabulary. 

* String based look-ups are removed. Vocabulary is changed from a map to vector. (#141832)
* Grouped all the vocabulary related methods under a single class - `ir2vec::Vocabulary`. This replaces `IR2VecVocabResult`.
* `ir2vec::Vocabulary` effectively abstracts out the _layout_ and other internal details of the vector structure. Exposes necessary APIs for accessing the Vocabulary. 

These changes ensure that _all_ known opcodes and types are present in the vocabulary. We have retained the original operands. This can be extended going forward. 

(Tracking issue - #141817)
2025-07-14 11:07:29 -07:00

507 lines
15 KiB
C++

//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <map>
#include <vector>
using namespace llvm;
using namespace ir2vec;
using namespace ::testing;
namespace {
/// Concrete Embedder subclass for tests: forwards its constructor arguments to
/// the base class and overrides both computeEmbeddings() overloads with empty
/// bodies.
class TestableEmbedder : public Embedder {
public:
  TestableEmbedder(const Function &Func, const Vocabulary &Vocab)
      : Embedder(Func, Vocab) {}

  // Intentionally empty: no embeddings are computed by this subclass.
  void computeEmbeddings(const BasicBlock & /*BB*/) const override {}
  void computeEmbeddings() const override {}
};
/// Exercises each Embedding constructor together with the basic accessors
/// (empty, size, getData, operator[]) and the iterator interface.
TEST(EmbeddingTest, ConstructorsAndAccessors) {
  // Default construction yields an empty embedding.
  {
    Embedding Empty;
    EXPECT_TRUE(Empty.empty());
    EXPECT_EQ(Empty.size(), 0u);
  }

  // Construction from an lvalue std::vector<double>.
  {
    std::vector<double> Values = {1.0, 2.0, 3.0};
    Embedding FromVec(Values);
    EXPECT_FALSE(FromVec.empty());
    ASSERT_THAT(FromVec, SizeIs(3u));
    EXPECT_THAT(FromVec.getData(), ElementsAre(1.0, 2.0, 3.0));
    EXPECT_EQ(FromVec[0], 1.0);
    EXPECT_EQ(FromVec[1], 2.0);
    EXPECT_EQ(FromVec[2], 3.0);
  }

  // Construction from a temporary std::vector<double>.
  {
    Embedding FromTemp(std::vector<double>({4.0, 5.0}));
    ASSERT_THAT(FromTemp, SizeIs(2u));
    EXPECT_THAT(FromTemp.getData(), ElementsAre(4.0, 5.0));
  }

  // Construction from a braced list; also checks element write-back through
  // operator[].
  {
    Embedding FromList({6.0, 7.0, 8.0, 9.0});
    ASSERT_THAT(FromList, SizeIs(4u));
    EXPECT_THAT(FromList.getData(), ElementsAre(6.0, 7.0, 8.0, 9.0));
    EXPECT_EQ(FromList[0], 6.0);
    FromList[0] = 6.5;
    EXPECT_EQ(FromList[0], 6.5);
  }

  // Size-only construction: every element is expected to be 0.0.
  {
    Embedding Zeros(5);
    ASSERT_THAT(Zeros, SizeIs(5u));
    EXPECT_THAT(Zeros.getData(), ElementsAre(0.0, 0.0, 0.0, 0.0, 0.0));
  }

  // Size plus fill value.
  {
    Embedding Filled(5, 1.5);
    ASSERT_THAT(Filled, SizeIs(5u));
    EXPECT_THAT(Filled.getData(), ElementsAre(1.5, 1.5, 1.5, 1.5, 1.5));
  }

  // Iteration over mutable and const embeddings.
  {
    Embedding Emb({6.5, 7.0, 8.0, 9.0});
    std::vector<double> Seen;
    for (auto It = Emb.begin(), End = Emb.end(); It != End; ++It)
      Seen.push_back(*It);
    EXPECT_THAT(Seen, ElementsAre(6.5, 7.0, 8.0, 9.0));

    const Embedding ConstEmb = Emb;
    std::vector<double> SeenConst;
    for (auto It = ConstEmb.begin(), End = ConstEmb.end(); It != End; ++It)
      SeenConst.push_back(*It);
    EXPECT_THAT(SeenConst, ElementsAre(6.5, 7.0, 8.0, 9.0));

    // First and last elements through begin/end and cbegin/cend.
    EXPECT_EQ(*Emb.begin(), 6.5);
    EXPECT_EQ(*(Emb.end() - 1), 9.0);
    EXPECT_EQ(*ConstEmb.cbegin(), 6.5);
    EXPECT_EQ(*(ConstEmb.cend() - 1), 9.0);
  }
}
/// operator+ returns the element-wise sum and leaves both operands untouched.
TEST(EmbeddingTest, AddVectorsOutOfPlace) {
  Embedding Lhs = {1.0, 2.0, 3.0};
  Embedding Rhs = {0.5, 1.5, -1.0};

  Embedding Sum = Lhs + Rhs;
  EXPECT_THAT(Sum, ElementsAre(1.5, 3.5, 2.0));

  // Neither input may be modified by the out-of-place operator.
  EXPECT_THAT(Lhs, ElementsAre(1.0, 2.0, 3.0));
  EXPECT_THAT(Rhs, ElementsAre(0.5, 1.5, -1.0));
}
/// operator+= accumulates into the left operand and leaves the right operand
/// untouched.
TEST(EmbeddingTest, AddVectors) {
  Embedding Acc = {1.0, 2.0, 3.0};
  Embedding Delta = {0.5, 1.5, -1.0};

  Acc += Delta;
  EXPECT_THAT(Acc, ElementsAre(1.5, 3.5, 2.0));

  // The right-hand side must be unchanged.
  EXPECT_THAT(Delta, ElementsAre(0.5, 1.5, -1.0));
}
/// operator- returns the element-wise difference and leaves both operands
/// untouched.
TEST(EmbeddingTest, SubtractVectorsOutOfPlace) {
  Embedding Lhs = {1.0, 2.0, 3.0};
  Embedding Rhs = {0.5, 1.5, -1.0};

  Embedding Diff = Lhs - Rhs;
  EXPECT_THAT(Diff, ElementsAre(0.5, 0.5, 4.0));

  // Neither input may be modified by the out-of-place operator.
  EXPECT_THAT(Lhs, ElementsAre(1.0, 2.0, 3.0));
  EXPECT_THAT(Rhs, ElementsAre(0.5, 1.5, -1.0));
}
/// operator-= subtracts in place and leaves the right operand untouched.
TEST(EmbeddingTest, SubtractVectors) {
  Embedding Acc = {1.0, 2.0, 3.0};
  Embedding Delta = {0.5, 1.5, -1.0};

  Acc -= Delta;
  EXPECT_THAT(Acc, ElementsAre(0.5, 0.5, 4.0));

  // The right-hand side must be unchanged.
  EXPECT_THAT(Delta, ElementsAre(0.5, 1.5, -1.0));
}
/// operator*= scales every element in place.
TEST(EmbeddingTest, ScaleVector) {
  Embedding Vec = {1.0, 2.0, 3.0};

  Vec *= 0.5f;

  EXPECT_THAT(Vec, ElementsAre(0.5, 1.0, 1.5));
}
/// operator* returns a scaled copy and leaves the source embedding untouched.
TEST(EmbeddingTest, ScaleVectorOutOfPlace) {
  Embedding Source = {1.0, 2.0, 3.0};

  Embedding Scaled = Source * 0.5f;
  EXPECT_THAT(Scaled, ElementsAre(0.5, 1.0, 1.5));

  // The source must be unchanged.
  EXPECT_THAT(Source, ElementsAre(1.0, 2.0, 3.0));
}
/// scaleAndAdd(Src, Factor) adds Factor * Src into the receiver and leaves Src
/// untouched.
TEST(EmbeddingTest, AddScaledVector) {
  Embedding Acc = {1.0, 2.0, 3.0};
  Embedding Src = {2.0, 0.5, -1.0};

  Acc.scaleAndAdd(Src, 0.5f);
  EXPECT_THAT(Acc, ElementsAre(2.0, 2.25, 2.5));

  // The source must be unchanged.
  EXPECT_THAT(Src, ElementsAre(2.0, 0.5, -1.0));
}
/// approximatelyEquals() with the default tolerance and with explicit
/// tolerances, for differences on both sides of the threshold.
TEST(EmbeddingTest, ApproximatelyEqual) {
  Embedding Base = {1.0, 2.0, 3.0};

  // Diff = 1e-7: accepted with the default tolerance.
  Embedding Tiny = {1.0000001, 2.0000001, 3.0000001};
  EXPECT_TRUE(Base.approximatelyEquals(Tiny));

  // Diff = 2e-5: rejected at 1e-6, accepted at 3e-5.
  Embedding Small = {1.00002, 2.00002, 3.00002};
  EXPECT_FALSE(Base.approximatelyEquals(Small, 1e-6));
  EXPECT_TRUE(Base.approximatelyEquals(Small, 3e-5));

  // Diff = 5e-7: accepted with the default tolerance.
  Embedding ClearlyWithin = {1.0000005, 2.0000005, 3.0000005};
  EXPECT_TRUE(Base.approximatelyEquals(ClearlyWithin));

  // Diff = 1e-5: rejected at 1e-6.
  Embedding ClearlyOutside = {1.00001, 2.00001, 3.00001};
  EXPECT_FALSE(Base.approximatelyEquals(ClearlyOutside, 1e-6));

  // Large difference in a single component.
  Embedding FarOff = {1.0, 2.0, 3.5};
  EXPECT_FALSE(Base.approximatelyEquals(FarOff, 0.01));

  // Identical contents: accepted at zero tolerance and at the default.
  Embedding Same = {1.0, 2.0, 3.0};
  EXPECT_TRUE(Base.approximatelyEquals(Same, 0.0));
  EXPECT_TRUE(Base.approximatelyEquals(Same));
}
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
/// Indexing past the end of an Embedding, for reads and for writes, is expected
/// to die with an "Index out of bounds" message.
TEST(EmbeddingTest, AccessOutOfBounds) {
  Embedding Vec = {1.0, 2.0, 3.0};

  // Reads: one past the end, and a negative index.
  EXPECT_DEATH(Vec[3], "Index out of bounds");
  EXPECT_DEATH(Vec[-1], "Index out of bounds");

  // Write through an out-of-range index.
  EXPECT_DEATH(Vec[4] = 4.0, "Index out of bounds");
}
/// operator+ on embeddings of different sizes is expected to die.
TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
  Embedding TwoElems = {1.0, 2.0};
  Embedding OneElem = {1.0};

  EXPECT_DEATH(TwoElems + OneElem, "Vectors must have the same dimension");
}
/// operator+= on embeddings of different sizes is expected to die.
TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
  Embedding TwoElems = {1.0, 2.0};
  Embedding OneElem = {1.0};

  EXPECT_DEATH(TwoElems += OneElem, "Vectors must have the same dimension");
}
/// operator-= on embeddings of different sizes is expected to die.
TEST(EmbeddingTest, MismatchedDimensionsSubtractVectors) {
  Embedding TwoElems = {1.0, 2.0};
  Embedding OneElem = {1.0};

  EXPECT_DEATH(TwoElems -= OneElem, "Vectors must have the same dimension");
}
/// scaleAndAdd() with a source of different size is expected to die.
TEST(EmbeddingTest, MismatchedDimensionsAddScaledVector) {
  Embedding TwoElems = {1.0, 2.0};
  Embedding OneElem = {1.0};

  EXPECT_DEATH(TwoElems.scaleAndAdd(OneElem, 1.0f),
               "Vectors must have the same dimension");
}
/// approximatelyEquals() between embeddings of different sizes is expected to
/// die.
TEST(EmbeddingTest, MismatchedDimensionsApproximatelyEqual) {
  Embedding TwoElems = {1.0, 2.0};
  Embedding OneElem = {1.010};

  EXPECT_DEATH(TwoElems.approximatelyEquals(OneElem),
               "Vectors must have the same dimension");
}
#endif // NDEBUG
#endif // GTEST_HAS_DEATH_TEST
/// Embedder::create() with IR2VecKind::Symbolic returns a non-null embedder for
/// an empty `void f()` declaration.
TEST(IR2VecTest, CreateSymbolicEmbedder) {
  Vocabulary V(Vocabulary::createDummyVocabForTest());

  // Build a module holding a single body-less function `void f()`.
  LLVMContext Ctx;
  Module M("M", Ctx);
  FunctionType *VoidFnTy =
      FunctionType::get(Type::getVoidTy(Ctx), /*isVarArg=*/false);
  Function *F = Function::Create(VoidFnTy, Function::ExternalLinkage, "f", M);

  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
  EXPECT_NE(Emb, nullptr);
}
/// Embedder::create() with an out-of-range IR2VecKind yields a result that
/// converts to false.
TEST(IR2VecTest, CreateInvalidMode) {
  Vocabulary V(Vocabulary::createDummyVocabForTest());

  // Build a module holding a single body-less function `void f()`.
  LLVMContext Ctx;
  Module M("M", Ctx);
  FunctionType *VoidFnTy =
      FunctionType::get(Type::getVoidTy(Ctx), /*isVarArg=*/false);
  Function *F = Function::Create(VoidFnTy, Function::ExternalLinkage, "f", M);

  // -1 is not a named IR2VecKind enumerator; force it through a cast.
  const auto BogusKind = static_cast<IR2VecKind>(-1);
  auto Result = Embedder::create(BogusKind, *F, V);
  EXPECT_FALSE(static_cast<bool>(Result));
}
/// Arithmetic on two default-constructed (empty) embeddings must not crash and
/// must leave the receiver empty.
TEST(IR2VecTest, ZeroDimensionEmbedding) {
  Embedding Acc;
  Embedding Other;

  // Each of these should be a no-op on empty operands.
  Acc += Other;
  Acc -= Other;
  Acc.scaleAndAdd(Other, 1.0f);

  EXPECT_TRUE(Acc.empty());
}
// Fixture for IR2Vec tests requiring IR setup.
class IR2VecTestFixture : public ::testing::Test {
protected:
Vocabulary V;
LLVMContext Ctx;
std::unique_ptr<Module> M;
Function *F = nullptr;
BasicBlock *BB = nullptr;
Instruction *AddInst = nullptr;
Instruction *RetInst = nullptr;
void SetUp() override {
V = Vocabulary(Vocabulary::createDummyVocabForTest(2));
// Setup IR
M = std::make_unique<Module>("TestM", Ctx);
FunctionType *FTy = FunctionType::get(
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
false);
F = Function::Create(FTy, Function::ExternalLinkage, "f", M.get());
BB = BasicBlock::Create(Ctx, "entry", F);
Argument *Arg = F->getArg(0);
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
}
};
/// Checks the per-instruction embedding map produced by the symbolic embedder
/// for the fixture function: exactly the `add` and `ret` instructions are
/// present, each with a dimension-2 embedding of the expected value.
TEST_F(IR2VecTestFixture, GetInstVecMap) {
  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
  ASSERT_TRUE(static_cast<bool>(Emb));

  const auto &InstMap = Emb->getInstVecMap();
  EXPECT_EQ(InstMap.size(), 2u);

  // Presence is a precondition for the at() lookups below, so use fatal
  // assertions: a missing entry then shows up as a test failure instead of an
  // invalid lookup.
  ASSERT_TRUE(InstMap.count(AddInst));
  ASSERT_TRUE(InstMap.count(RetInst));

  const auto &AddVec = InstMap.at(AddInst);
  const auto &RetVec = InstMap.at(RetInst);
  EXPECT_EQ(AddVec.size(), 2u);
  EXPECT_EQ(RetVec.size(), 2u);

  EXPECT_TRUE(AddVec.approximatelyEquals(Embedding(2, 27.6)));
  EXPECT_TRUE(RetVec.approximatelyEquals(Embedding(2, 16.8)));
}
/// Checks the per-basic-block embedding map produced by the symbolic embedder:
/// the single fixture block is present with a dimension-2 embedding.
TEST_F(IR2VecTestFixture, GetBBVecMap) {
  auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
  ASSERT_TRUE(static_cast<bool>(Emb));

  const auto &BBMap = Emb->getBBVecMap();
  EXPECT_EQ(BBMap.size(), 1u);

  // Presence is a precondition for the at() lookup below, so use a fatal
  // assertion: a missing entry then shows up as a test failure instead of an
  // invalid lookup.
  ASSERT_TRUE(BBMap.count(BB));

  const auto &BBVec = BBMap.at(BB);
  EXPECT_EQ(BBVec.size(), 2u);

  // BB vector should be sum of add and ret: {27.6, 27.6} + {16.8, 16.8} =
  // {44.4, 44.4}
  EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.4)));
}
/// getBBVector() for the fixture's single block returns a dimension-2 embedding
/// approximately equal to {44.4, 44.4}.
TEST_F(IR2VecTestFixture, GetBBVector) {
  auto SymEmb = Embedder::create(IR2VecKind::Symbolic, *F, V);
  ASSERT_TRUE(static_cast<bool>(SymEmb));

  const auto &BlockVec = SymEmb->getBBVector(*BB);
  EXPECT_EQ(BlockVec.size(), 2u);

  const Embedding Expected(2, 44.4);
  EXPECT_TRUE(BlockVec.approximatelyEquals(Expected));
}
/// getFunctionVector() for the fixture function returns a dimension-2
/// embedding approximately equal to {44.4, 44.4}.
TEST_F(IR2VecTestFixture, GetFunctionVector) {
  auto SymEmb = Embedder::create(IR2VecKind::Symbolic, *F, V);
  ASSERT_TRUE(static_cast<bool>(SymEmb));

  const auto &FnVec = SymEmb->getFunctionVector();
  EXPECT_EQ(FnVec.size(), 2u);

  // The function has only one basic block, so its vector should match that
  // block's vector: {44.4, 44.4}.
  const Embedding Expected(2, 44.4);
  EXPECT_TRUE(FnVec.approximatelyEquals(Expected));
}
// Section sizes the tests below expect of the vocabulary: the total entry
// count is checked against MaxOpcodes + MaxTypeIDs + MaxOperands, and the
// first two values are also used as offsets when querying string keys.
constexpr unsigned MaxOpcodes = 67;
constexpr unsigned MaxTypeIDs = 21;
constexpr unsigned MaxOperands = 4;
/// For dimensions 1..10: builds a dummy vocabulary, checks its shape, feeds it
/// through IR2VecVocabAnalysis, and verifies that the resulting Vocabulary is
/// valid, has the same dimension and size, and holds the same entries in the
/// same order.
TEST(IR2VecVocabularyTest, DummyVocabTest) {
  for (unsigned Dim = 1; Dim <= 10; ++Dim) {
    auto VocabVec = Vocabulary::createDummyVocabForTest(Dim);

    // All embeddings should have the same dimension
    for (const auto &Emb : VocabVec)
      EXPECT_EQ(Emb.size(), Dim);

    // Should have the correct total number of embeddings
    EXPECT_EQ(VocabVec.size(), MaxOpcodes + MaxTypeIDs + MaxOperands);

    // Keep a copy for comparison; the original is moved into the analysis.
    auto ExpectedVocab = VocabVec;
    IR2VecVocabAnalysis VocabAnalysis(std::move(VocabVec));

    LLVMContext TestCtx;
    Module TestMod("TestModuleForVocabAnalysis", TestCtx);
    ModuleAnalysisManager MAM;
    Vocabulary Result = VocabAnalysis.run(TestMod, MAM);

    // Fatal: the queries below are only meaningful on a valid vocabulary.
    ASSERT_TRUE(Result.isValid());
    EXPECT_EQ(Result.getDimension(), Dim);
    EXPECT_EQ(Result.size(), MaxOpcodes + MaxTypeIDs + MaxOperands);

    // Fatal: the element-wise comparison indexes ExpectedVocab by position, so
    // a size mismatch must stop here rather than read past its end.
    ASSERT_EQ(Result.size(), ExpectedVocab.size());
    unsigned CurPos = 0;
    for (const auto &Entry : Result) {
      EXPECT_TRUE(Entry.approximatelyEquals(ExpectedVocab[CurPos], 0.01));
      ++CurPos;
    }
  }
}
/// Spot-checks Vocabulary::getStringKey() at selected positions in each of the
/// three index ranges used by this test: opcodes first, then type IDs starting
/// at MaxOpcodes, then operand kinds starting at MaxOpcodes + MaxTypeIDs.
TEST(IR2VecVocabularyTest, StringKeyGeneration) {
  // Opcode range.
  EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
  EXPECT_EQ(Vocabulary::getStringKey(12), "Add");

  // Type range: offsets 0 and 2 are both expected to yield "FloatTy".
  const unsigned TypeBase = MaxOpcodes;
  StringRef HalfKey = Vocabulary::getStringKey(TypeBase + 0);
  StringRef FloatKey = Vocabulary::getStringKey(TypeBase + 2);
  StringRef VoidKey = Vocabulary::getStringKey(TypeBase + 7);
  StringRef IntKey = Vocabulary::getStringKey(TypeBase + 12);
  EXPECT_EQ(HalfKey, "FloatTy");
  EXPECT_EQ(FloatKey, "FloatTy");
  EXPECT_EQ(VoidKey, "VoidTy");
  EXPECT_EQ(IntKey, "IntegerTy");

  // Operand range.
  const unsigned OperandBase = MaxOpcodes + MaxTypeIDs;
  StringRef FunctionKey = Vocabulary::getStringKey(OperandBase + 0);
  StringRef PointerKey = Vocabulary::getStringKey(OperandBase + 1);
  EXPECT_EQ(FunctionKey, "Function");
  EXPECT_EQ(PointerKey, "Pointer");
}
/// A Vocabulary built from a dummy vocab of dimension 1, 5, or 10 is valid and
/// reports that same dimension.
TEST(IR2VecVocabularyTest, VocabularyDimensions) {
  for (unsigned Dim : {1u, 5u, 10u}) {
    Vocabulary Vocab(Vocabulary::createDummyVocabForTest(Dim));
    EXPECT_TRUE(Vocab.isValid());
    EXPECT_EQ(Vocab.getDimension(), Dim);
  }
}
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
/// Indexing a Vocabulary with 0 or with 100 is expected to die with an
/// "Invalid opcode" message.
TEST(IR2VecVocabularyTest, InvalidAccess) {
  Vocabulary Vocab(Vocabulary::createDummyVocabForTest(2));

  EXPECT_DEATH(Vocab[0u], "Invalid opcode");
  EXPECT_DEATH(Vocab[100u], "Invalid opcode");
}
#endif // NDEBUG
#endif // GTEST_HAS_DEATH_TEST
/// Checks the string key returned for a selection of Type::TypeID values when
/// queried at position MaxOpcodes + TypeID.
TEST(IR2VecVocabularyTest, TypeIDStringKeyMapping) {
  // Key lookup for a type ID, offset past the opcode entries.
  auto KeyForType = [](Type::TypeID ID) {
    return Vocabulary::getStringKey(MaxOpcodes + static_cast<unsigned>(ID));
  };

  EXPECT_EQ(KeyForType(Type::VoidTyID), "VoidTy");
  EXPECT_EQ(KeyForType(Type::IntegerTyID), "IntegerTy");
  EXPECT_EQ(KeyForType(Type::FloatTyID), "FloatTy");
  EXPECT_EQ(KeyForType(Type::PointerTyID), "PointerTy");
  EXPECT_EQ(KeyForType(Type::FunctionTyID), "FunctionTy");
  EXPECT_EQ(KeyForType(Type::StructTyID), "StructTy");
  EXPECT_EQ(KeyForType(Type::ArrayTyID), "ArrayTy");
  // Fixed vectors are expected to map to the generic "VectorTy" key.
  EXPECT_EQ(KeyForType(Type::FixedVectorTyID), "VectorTy");
  EXPECT_EQ(KeyForType(Type::LabelTyID), "LabelTy");
  EXPECT_EQ(KeyForType(Type::TokenTyID), "TokenTy");
  EXPECT_EQ(KeyForType(Type::MetadataTyID), "MetadataTy");
}
/// A Vocabulary built from only two embeddings, and a default-constructed
/// Vocabulary, both report isValid() == false; in assert-enabled builds,
/// getDimension() on the latter is expected to die.
TEST(IR2VecVocabularyTest, InvalidVocabularyConstruction) {
  // Two entries only.
  std::vector<Embedding> TwoEntries;
  TwoEntries.push_back(Embedding(2, 1.0));
  TwoEntries.push_back(Embedding(2, 2.0));

  Vocabulary Undersized(std::move(TwoEntries));
  EXPECT_FALSE(Undersized.isValid());

  {
    Vocabulary DefaultVocab;
    EXPECT_FALSE(DefaultVocab.isValid());
#if GTEST_HAS_DEATH_TEST
#ifndef NDEBUG
    EXPECT_DEATH(DefaultVocab.getDimension(), "IR2Vec Vocabulary is invalid");
#endif // NDEBUG
#endif // GTEST_HAS_DEATH_TEST
  }
}
} // end anonymous namespace