[llvm-ir2vec][MIR2Vec] Supporting MIR mode in triplet and entity generation (#164329)
Add support for Machine IR (MIR) triplet and entity generation in llvm-ir2vec.

This change extends llvm-ir2vec to support Machine IR (MIR) in addition to LLVM IR, enabling the generation of training data for MIR2Vec embeddings. MIR2Vec provides machine-level code embeddings that capture target-specific instruction semantics, complementing the target-independent IR2Vec embeddings.

- Extended llvm-ir2vec to support triplet and entity generation for Machine IR (MIR)
- Added `--mode=mir` option to specify MIR mode (vs LLVM IR mode)
- Implemented MIR triplet generation with Next and Arg relationships
- Added entity mapping generation for MIR vocabulary
- Updated documentation to explain MIR-specific features and usage

(Partially addresses #162200; Tracking issue - #141817)
parent 2b6686f2cd
commit 10bec2cd9d
@@ -68,32 +68,52 @@ these two modes are used to generate the triplets and entity mappings.
 Triplet Generation
 ~~~~~~~~~~~~~~~~~~
 
-With the `triplets` subcommand, :program:`llvm-ir2vec` analyzes LLVM IR and extracts
-numeric triplets consisting of opcode IDs, type IDs, and operand IDs. These triplets
+With the `triplets` subcommand, :program:`llvm-ir2vec` analyzes LLVM IR or Machine IR
+and extracts numeric triplets consisting of opcode IDs and operand IDs. These triplets
 are generated in the standard format used for knowledge graph embedding training.
-The tool outputs numeric IDs directly using the ir2vec::Vocabulary mapping
-infrastructure, eliminating the need for string-to-ID preprocessing.
+The tool outputs numeric IDs directly using the vocabulary mapping infrastructure,
+eliminating the need for string-to-ID preprocessing.
 
-Usage:
+Usage for LLVM IR:
 
 .. code-block:: bash
 
-   llvm-ir2vec triplets input.bc -o triplets_train2id.txt
+   llvm-ir2vec triplets --mode=llvm input.bc -o triplets_train2id.txt
+
+Usage for Machine IR:
+
+.. code-block:: bash
+
+   llvm-ir2vec triplets --mode=mir input.mir -o triplets_train2id.txt
 
 Entity Mapping Generation
 ~~~~~~~~~~~~~~~~~~~~~~~~~
 
 With the `entities` subcommand, :program:`llvm-ir2vec` generates the entity mappings
-supported by IR2Vec in the standard format used for knowledge graph embedding
-training. This subcommand outputs all supported entities (opcodes, types, and
-operands) with their corresponding numeric IDs, and is not specific for an
-LLVM IR file.
+supported by IR2Vec or MIR2Vec in the standard format used for knowledge graph embedding
+training. This subcommand outputs all supported entities with their corresponding numeric IDs.
 
-Usage:
+For LLVM IR, entities include opcodes, types, and operands. For Machine IR, entities include
+machine opcodes, common operands, and register classes (both physical and virtual).
+
+Usage for LLVM IR:
 
 .. code-block:: bash
 
-   llvm-ir2vec entities -o entity2id.txt
+   llvm-ir2vec entities --mode=llvm -o entity2id.txt
+
+Usage for Machine IR:
+
+.. code-block:: bash
+
+   llvm-ir2vec entities --mode=mir input.mir -o entity2id.txt
+
+.. note::
+
+   For LLVM IR mode, the entity mapping is target-independent and does not require an input file.
+   For Machine IR mode, an input .mir file is required to determine the target architecture,
+   as entity mappings vary by target (different architectures have different instruction sets
+   and register classes).
 
 Embedding Generation
 ~~~~~~~~~~~~~~~~~~~~
@@ -222,12 +242,17 @@ Subcommand-specific options:
 
 .. option:: <input-file>
 
-   The input LLVM IR or bitcode file to process. This positional argument is
-   required for the `triplets` subcommand.
+   The input LLVM IR/bitcode file (.ll/.bc) or Machine IR file (.mir) to process.
+   This positional argument is required for the `triplets` subcommand.
 
 **entities** subcommand:
 
-No subcommand-specific options.
+.. option:: <input-file>
+
+   The input Machine IR file (.mir) to process. This positional argument is required
+   for the `entities` subcommand when using ``--mode=mir``, as the entity mappings
+   are target-specific. For ``--mode=llvm``, no input file is required as IR2Vec
+   entity mappings are target-independent.
 
 OUTPUT FORMAT
 -------------
@@ -240,19 +265,37 @@ metadata headers. The format includes:
 
 .. code-block:: text
 
-   MAX_RELATIONS=<max_relations_count>
+   MAX_RELATION=<max_relation_count>
    <head_entity_id> <tail_entity_id> <relation_id>
    <head_entity_id> <tail_entity_id> <relation_id>
    ...
 
 Each line after the metadata header represents one instruction relationship,
-with numeric IDs for head entity, relation, and tail entity. The metadata
-header (MAX_RELATIONS) provides counts for post-processing and training setup.
+with numeric IDs for head entity, tail entity, and relation type. The metadata
+header (MAX_RELATION) indicates the maximum relation ID used.
+
+**Relation Types:**
+
+For LLVM IR (IR2Vec):
+   * **0** = Type relationship (instruction to its type)
+   * **1** = Next relationship (sequential instructions)
+   * **2+** = Argument relationships (Arg0, Arg1, Arg2, ...)
+
+For Machine IR (MIR2Vec):
+   * **0** = Next relationship (sequential instructions)
+   * **1+** = Argument relationships (Arg0, Arg1, Arg2, ...)
+
+**Entity IDs:**
+
+For LLVM IR: Entity IDs represent opcodes, types, and operands as defined by the IR2Vec vocabulary.
+
+For Machine IR: Entity IDs represent machine opcodes, common operands (immediate, frame index, etc.),
+physical register classes, and virtual register classes as defined by the MIR2Vec vocabulary. The entity layout is target-specific.
 
 Entity Mode Output
 ~~~~~~~~~~~~~~~~~~
 
-In entity mode, the output consists of entity mapping in the format:
+In entity mode, the output consists of entity mappings in the format:
 
 .. code-block:: text
 
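The relation numbering documented in the hunk above differs between the two modes, which is easy to get wrong when post-processing triplet files. The following is a minimal C++ sketch of a decoder, not part of the commit; `relationName` and its `MIRMode` parameter are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not part of llvm-ir2vec): map a numeric relation ID
// from a triplet file to a readable name.
//   LLVM IR (IR2Vec):     0 = Type, 1 = Next, 2+ = Arg0, Arg1, ...
//   Machine IR (MIR2Vec): 0 = Next, 1+ = Arg0, Arg1, ...
std::string relationName(unsigned RelationID, bool MIRMode) {
  const unsigned ArgBase = MIRMode ? 1u : 2u; // relation ID of Arg0
  if (RelationID >= ArgBase)
    return "Arg" + std::to_string(RelationID - ArgBase);
  if (MIRMode)
    return "Next"; // only ID 0 remains in MIR mode
  return RelationID == 0 ? "Type" : "Next";
}
```

With this sketch, relation `2` decodes to `Arg0` for an LLVM IR triplet file but to `Arg1` for a Machine IR one, so a consumer has to know which mode produced the file.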
@@ -264,6 +307,13 @@ In entity mode, the output consists of entity mapping in the format:
 The first line contains the total number of entities, followed by one entity
 mapping per line with tab-separated entity string and numeric ID.
 
+For LLVM IR, entities include instruction opcodes (e.g., "Add", "Ret"), types
+(e.g., "INT", "PTR"), and operand kinds.
+
+For Machine IR, entities include machine opcodes (e.g., "COPY", "ADD"),
+common operands (e.g., "Immediate", "FrameIndex"), physical register classes
+(e.g., "PhyReg_GR32"), and virtual register classes (e.g., "VirtReg_GR32").
+
 Embedding Mode Output
 ~~~~~~~~~~~~~~~~~~~~~
 
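The entity-mode format described above (a count line followed by tab-separated name/ID pairs) can be read back with a few lines of code. This is a rough sketch under that description only; `readEntityMap` is a hypothetical helper, and the entity names used in the usage note are taken from the examples in the documentation, with made-up IDs.

```cpp
#include <cassert>
#include <cstddef>
#include <istream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical reader (not part of llvm-ir2vec) for entity-mode output:
// first line is the entity count, then one "<entity-name>\t<id>" per line.
// Returns names indexed by ID, or an empty vector if the input is malformed.
std::vector<std::string> readEntityMap(std::istream &In) {
  std::string Line;
  if (!std::getline(In, Line))
    return {};
  std::size_t Count = 0;
  try {
    Count = std::stoul(Line);
  } catch (...) {
    return {}; // header is not a number
  }

  std::vector<std::string> Names(Count);
  std::size_t Seen = 0;
  while (std::getline(In, Line)) {
    if (Line.empty())
      continue;
    const std::size_t Tab = Line.rfind('\t');
    if (Tab == std::string::npos)
      return {}; // missing separator
    std::size_t ID = 0;
    try {
      ID = std::stoul(Line.substr(Tab + 1));
    } catch (...) {
      return {}; // ID is not a number
    }
    if (ID >= Count)
      return {}; // ID outside the declared range
    Names[ID] = Line.substr(0, Tab);
    ++Seen;
  }
  // The declared count must match the number of mappings read.
  return Seen == Count ? Names : std::vector<std::string>{};
}
```

For example, feeding the text `3`, `COPY<TAB>0`, `Immediate<TAB>1`, `VirtReg_GR32<TAB>2` (one per line) through a `std::istringstream` yields a three-element vector with `"VirtReg_GR32"` at index 2.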
@@ -111,6 +111,11 @@ class MIRVocabulary {
     size_t TotalEntries = 0;
   } Layout;
 
+  // TODO: See if we can have only one reg classes section instead of physical
+  // and virtual separate sections in the vocabulary. This would reduce the
+  // number of vocabulary entities significantly.
+  // We can potentially distinguish physical and virtual registers by
+  // considering them as a separate feature.
   enum class Section : unsigned {
     Opcodes = 0,
     CommonOperands = 1,
@@ -185,6 +190,25 @@ class MIRVocabulary {
     return Storage[static_cast<unsigned>(SectionID)][LocalIndex];
   }
 
+  /// Get entity ID (flat index) for a common operand type
+  /// This is used for triplet generation
+  unsigned getEntityIDForCommonOperand(
+      MachineOperand::MachineOperandType OperandType) const {
+    return Layout.CommonOperandBase + getCommonOperandIndex(OperandType);
+  }
+
+  /// Get entity ID (flat index) for a register
+  /// This is used for triplet generation
+  unsigned getEntityIDForRegister(Register Reg) const {
+    if (!Reg.isValid() || Reg.isStack())
+      return Layout
+          .VirtRegBase; // Return VirtRegBase for invalid/stack registers
+    unsigned LocalIndex = getRegisterOperandIndex(Reg);
+    size_t BaseOffset =
+        Reg.isPhysical() ? Layout.PhyRegBase : Layout.VirtRegBase;
+    return BaseOffset + LocalIndex;
+  }
+
 public:
   /// Static method for extracting base opcode names (public for testing)
   static std::string extractBaseOpcodeName(StringRef InstrName);
@@ -201,6 +225,20 @@ public:
 
   unsigned getDimension() const { return Storage.getDimension(); }
 
+  /// Get entity ID (flat index) for an opcode
+  /// This is used for triplet generation
+  unsigned getEntityIDForOpcode(unsigned Opcode) const {
+    return Layout.OpcodeBase + getCanonicalOpcodeIndex(Opcode);
+  }
+
+  /// Get entity ID (flat index) for a machine operand
+  /// This is used for triplet generation
+  unsigned getEntityIDForMachineOperand(const MachineOperand &MO) const {
+    if (MO.getType() == MachineOperand::MO_Register)
+      return getEntityIDForRegister(MO.getReg());
+    return getEntityIDForCommonOperand(MO.getType());
+  }
+
   // Accessor methods
   const Embedding &operator[](unsigned Opcode) const {
     unsigned LocalIndex = getCanonicalOpcodeIndex(Opcode);
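The `getEntityIDFor*` helpers in the hunks above all compute a flat entity ID as a section base plus a local index. The sketch below illustrates that idea in isolation. It is a simplified stand-in, not the `MIRVocabulary` implementation: `FlatLayout`, `makeLayout`, `SketchSection`, and `entityID` are hypothetical names, the section sizes are made up, and the section order (opcodes, common operands, physical register classes, virtual register classes) is an assumption based on the order in which the documentation lists the entities.

```cpp
#include <cassert>
#include <cstddef>

// Simplified stand-in for the vocabulary layout: four consecutive sections
// sharing one flat ID space. All names here are illustrative.
struct FlatLayout {
  std::size_t OpcodeBase = 0;
  std::size_t CommonOperandBase = 0;
  std::size_t PhyRegBase = 0;
  std::size_t VirtRegBase = 0;
  std::size_t TotalEntries = 0;
};

// Each section starts where the previous one ends (assumed section order).
FlatLayout makeLayout(std::size_t NumOpcodes, std::size_t NumCommonOperands,
                      std::size_t NumPhyRegClasses,
                      std::size_t NumVirtRegClasses) {
  FlatLayout L;
  L.OpcodeBase = 0;
  L.CommonOperandBase = L.OpcodeBase + NumOpcodes;
  L.PhyRegBase = L.CommonOperandBase + NumCommonOperands;
  L.VirtRegBase = L.PhyRegBase + NumPhyRegClasses;
  L.TotalEntries = L.VirtRegBase + NumVirtRegClasses;
  return L;
}

enum class SketchSection { Opcodes, CommonOperands, PhyRegisters, VirtRegisters };

// Flat entity ID = base offset of the section + index within the section.
std::size_t entityID(const FlatLayout &L, SketchSection S,
                     std::size_t LocalIndex) {
  switch (S) {
  case SketchSection::Opcodes:
    return L.OpcodeBase + LocalIndex;
  case SketchSection::CommonOperands:
    return L.CommonOperandBase + LocalIndex;
  case SketchSection::PhyRegisters:
    return L.PhyRegBase + LocalIndex;
  case SketchSection::VirtRegisters:
    return L.VirtRegBase + LocalIndex;
  }
  return L.TotalEntries; // not reached for valid SketchSection values
}
```

With `makeLayout(100, 10, 20, 20)`, local index 3 in the common-operand section maps to flat ID 103 and local index 7 in the virtual-register section maps to 137. Because the section sizes depend on the target's opcode and register-class counts, the same entity can receive different flat IDs on different targets, which matches the note that MIR entity mappings are target-specific.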
28
llvm/test/tools/llvm-ir2vec/entities.mir
Normal file
@@ -0,0 +1,28 @@
+# REQUIRES: x86_64-linux
+# RUN: llvm-ir2vec entities --mode=mir %s -o 2>&1 %t1.log
+# RUN: diff %S/output/reference_x86_entities.txt %t1.log
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+  target triple = "x86_64-unknown-linux-gnu"
+
+  define dso_local noundef i32 @test_function(i32 noundef %a) {
+  entry:
+    ret i32 %a
+  }
+...
+---
+name: test_function
+alignment: 16
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gr32 }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+body: |
+  bb.0.entry:
+    liveins: $edi
+
+    %0:gr32 = COPY $edi
+    $eax = COPY %0
+    RET 0, $eax
3
llvm/test/tools/llvm-ir2vec/output/lit.local.cfg
Normal file
@@ -0,0 +1,3 @@
+# Don't treat files in this directory as tests
+# These are reference data files, not test scripts
+config.suffixes = []
33
llvm/test/tools/llvm-ir2vec/output/reference_triplets.txt
Normal file
@@ -0,0 +1,33 @@
+MAX_RELATION=4
+187 7072 1
+187 6968 2
+187 187 0
+187 7072 1
+187 6969 2
+187 10 0
+10 7072 1
+10 7072 2
+10 7072 3
+10 6961 4
+10 187 0
+187 6952 1
+187 7072 2
+187 1555 0
+1555 6882 1
+1555 6952 2
+187 7072 1
+187 6968 2
+187 187 0
+187 7072 1
+187 6969 2
+187 601 0
+601 7072 1
+601 7072 2
+601 7072 3
+601 6961 4
+601 187 0
+187 6952 1
+187 7072 2
+187 1555 0
+1555 6882 1
+1555 6952 2
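A reference file like the one above can be sanity-checked by reading the `MAX_RELATION` header and confirming that it equals the largest relation ID that actually appears. The following is a small sketch of such a check, assuming whitespace-separated `head tail relation` fields; `Triplet` and `readTriplets` are hypothetical names and not part of the tool.

```cpp
#include <algorithm>
#include <cassert>
#include <istream>
#include <sstream>
#include <string>
#include <vector>

struct Triplet {
  unsigned Head = 0;
  unsigned Tail = 0;
  unsigned Relation = 0;
};

// Hypothetical reader (not part of llvm-ir2vec): parses a "MAX_RELATION=N"
// header followed by whitespace-separated "head tail relation" triplets.
// Returns true only if the input parses cleanly to end-of-stream and N equals
// the largest relation ID seen.
bool readTriplets(std::istream &In, unsigned &MaxRelation,
                  std::vector<Triplet> &Out) {
  const std::string Key = "MAX_RELATION=";
  std::string Header;
  if (!std::getline(In, Header) || Header.compare(0, Key.size(), Key) != 0)
    return false;
  try {
    MaxRelation = static_cast<unsigned>(std::stoul(Header.substr(Key.size())));
  } catch (...) {
    return false; // header value is not a number
  }

  unsigned Seen = 0;
  Triplet T;
  while (In >> T.Head >> T.Tail >> T.Relation) {
    Seen = std::max(Seen, T.Relation);
    Out.push_back(T);
  }
  return In.eof() && Seen == MaxRelation;
}
```

Run on the first ten triplets of the reference file above, the sketch accepts the header value 4, because the `ADD32rr` line `10 6961 4` carries the highest relation ID.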
7174
llvm/test/tools/llvm-ir2vec/output/reference_x86_entities.txt
Normal file
File diff suppressed because it is too large
61
llvm/test/tools/llvm-ir2vec/triplets.mir
Normal file
@@ -0,0 +1,61 @@
+# REQUIRES: x86_64-linux
+# RUN: llvm-ir2vec triplets --mode=mir %s -o 2>&1 %t1.log
+# RUN: diff %S/output/reference_triplets.txt %t1.log
+
+--- |
+  target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+  target triple = "x86_64-unknown-linux-gnu"
+
+  define dso_local noundef i32 @add_function(i32 noundef %a, i32 noundef %b) {
+  entry:
+    %sum = add nsw i32 %a, %b
+    ret i32 %sum
+  }
+
+  define dso_local noundef i32 @mul_function(i32 noundef %x, i32 noundef %y) {
+  entry:
+    %product = mul nsw i32 %x, %y
+    ret i32 %product
+  }
+...
+---
+name: add_function
+alignment: 16
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gr32 }
+  - { id: 1, class: gr32 }
+  - { id: 2, class: gr32 }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+  - { reg: '$esi', virtual-reg: '%1' }
+body: |
+  bb.0.entry:
+    liveins: $edi, $esi
+
+    %1:gr32 = COPY $esi
+    %0:gr32 = COPY $edi
+    %2:gr32 = nsw ADD32rr %0, %1, implicit-def dead $eflags
+    $eax = COPY %2
+    RET 0, $eax
+
+---
+name: mul_function
+alignment: 16
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: gr32 }
+  - { id: 1, class: gr32 }
+  - { id: 2, class: gr32 }
+liveins:
+  - { reg: '$edi', virtual-reg: '%0' }
+  - { reg: '$esi', virtual-reg: '%1' }
+body: |
+  bb.0.entry:
+    liveins: $edi, $esi
+
+    %1:gr32 = COPY $esi
+    %0:gr32 = COPY $edi
+    %2:gr32 = nsw IMUL32rr %0, %1, implicit-def dead $eflags
+    $eax = COPY %2
+    RET 0, $eax
@@ -19,12 +19,22 @@
 /// Generates numeric triplets (head, tail, relation) for vocabulary
 /// training. Output format: MAX_RELATION=N header followed by
 /// head\ttail\trelation lines. Relations: 0=Type, 1=Next, 2+=Arg0,Arg1,...
-/// Usage: llvm-ir2vec triplets input.bc -o train2id.txt
+///
+/// For LLVM IR:
+/// llvm-ir2vec triplets input.bc -o train2id.txt
+///
+/// For Machine IR:
+/// llvm-ir2vec triplets -mode=mir input.mir -o train2id.txt
 ///
 /// 2. Entity Mappings (entities):
 /// Generates entity mappings for vocabulary training.
 /// Output format: <total_entities> header followed by entity\tid lines.
-/// Usage: llvm-ir2vec entities input.bc -o entity2id.txt
+///
+/// For LLVM IR:
+/// llvm-ir2vec entities input.bc -o entity2id.txt
+///
+/// For Machine IR:
+/// llvm-ir2vec entities -mode=mir input.mir -o entity2id.txt
 ///
 /// 3. Embedding Generation (embeddings):
 /// Generates IR2Vec/MIR2Vec embeddings using a trained vocabulary.
@@ -67,6 +77,8 @@
 #include "llvm/CodeGen/MIRParser/MIRParser.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineModuleInfo.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Support/WithColor.h"
@@ -106,11 +118,10 @@ static cl::SubCommand
                      "Generate embeddings using trained vocabulary");
 
 // Common options
-static cl::opt<std::string>
-    InputFilename(cl::Positional,
-                  cl::desc("<input bitcode file or '-' for stdin>"),
-                  cl::init("-"), cl::sub(TripletsSubCmd),
-                  cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
+static cl::opt<std::string> InputFilename(
+    cl::Positional, cl::desc("<input bitcode/MIR file or '-' for stdin>"),
+    cl::init("-"), cl::sub(TripletsSubCmd), cl::sub(EntitiesSubCmd),
+    cl::sub(EmbeddingsSubCmd), cl::cat(CommonCategory));
 
 static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
                                            cl::value_desc("filename"),
@@ -345,6 +356,12 @@ Error processModule(Module &M, raw_ostream &OS) {
 
 namespace mir2vec {
 
+/// Relation types for MIR2Vec triplet generation
+enum MIRRelationType {
+  MIRNextRelation = 0, ///< Sequential instruction relationship
+  MIRArgRelation = 1 ///< Instruction to operand relationship (ArgRelation + N)
+};
+
 /// Helper class for MIR2Vec embedding generation
 class MIR2VecTool {
 private:
@@ -354,7 +371,7 @@ private:
 public:
   explicit MIR2VecTool(MachineModuleInfo &MMI) : MMI(MMI) {}
 
-  /// Initialize MIR2Vec vocabulary
+  /// Initialize MIR2Vec vocabulary from file (for embeddings generation)
   bool initializeVocabulary(const Module &M) {
     MIR2VecVocabProvider Provider(MMI);
     auto VocabOrErr = Provider.getVocabulary(M);
@@ -368,6 +385,146 @@ public:
     return true;
   }
 
+  /// Initialize vocabulary with layout information only.
+  /// This creates a minimal vocabulary with correct layout but no actual
+  /// embeddings. Sufficient for generating training data and entity mappings.
+  ///
+  /// Note: Requires target-specific information from the first machine function
+  /// to determine the vocabulary layout (number of opcodes, register classes).
+  ///
+  /// FIXME: Use --target option to get target info directly, avoiding the need
+  /// to parse machine functions for pre-training operations.
+  bool initializeVocabularyForLayout(const Module &M) {
+    for (const Function &F : M) {
+      if (F.isDeclaration())
+        continue;
+
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF)
+        continue;
+
+      const TargetInstrInfo &TII = *MF->getSubtarget().getInstrInfo();
+      const TargetRegisterInfo &TRI = *MF->getSubtarget().getRegisterInfo();
+      const MachineRegisterInfo &MRI = MF->getRegInfo();
+
+      auto VocabOrErr =
+          MIRVocabulary::createDummyVocabForTest(TII, TRI, MRI, 1);
+      if (!VocabOrErr) {
+        WithColor::error(errs(), ToolName)
+            << "Failed to create dummy vocabulary - "
+            << toString(VocabOrErr.takeError()) << "\n";
+        return false;
+      }
+      Vocab = std::make_unique<MIRVocabulary>(std::move(*VocabOrErr));
+      return true;
+    }
+
+    WithColor::error(errs(), ToolName)
+        << "No machine functions found to initialize vocabulary\n";
+    return false;
+  }
+
+  /// Generate triplets for the module
+  /// Output format: MAX_RELATION=N header followed by relationships
+  void generateTriplets(const Module &M, raw_ostream &OS) const {
+    unsigned MaxRelation = MIRNextRelation; // Track maximum relation ID
+    std::string Relationships;
+    raw_string_ostream RelOS(Relationships);
+
+    for (const Function &F : M) {
+      if (F.isDeclaration())
+        continue;
+
+      MachineFunction *MF = MMI.getMachineFunction(F);
+      if (!MF) {
+        WithColor::warning(errs(), ToolName)
+            << "No MachineFunction for " << F.getName() << "\n";
+        continue;
+      }
+
+      unsigned FuncMaxRelation = generateTriplets(*MF, RelOS);
+      MaxRelation = std::max(MaxRelation, FuncMaxRelation);
+    }
+
+    RelOS.flush();
+
+    // Write metadata header followed by relationships
+    OS << "MAX_RELATION=" << MaxRelation << '\n';
+    OS << Relationships;
+  }
+
+  /// Generate triplets for a single machine function
+  /// Returns the maximum relation ID used in this function
+  unsigned generateTriplets(const MachineFunction &MF, raw_ostream &OS) const {
+    unsigned MaxRelation = MIRNextRelation;
+    unsigned PrevOpcode = 0;
+    bool HasPrevOpcode = false;
+
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName)
+          << "MIR Vocabulary must be initialized for triplet generation.\n";
+      return MaxRelation;
+    }
+
+    for (const MachineBasicBlock &MBB : MF) {
+      for (const MachineInstr &MI : MBB) {
+        // Skip debug instructions
+        if (MI.isDebugInstr())
+          continue;
+
+        // Get opcode entity ID
+        unsigned OpcodeID = Vocab->getEntityIDForOpcode(MI.getOpcode());
+
+        // Add "Next" relationship with previous instruction
+        if (HasPrevOpcode) {
+          OS << PrevOpcode << '\t' << OpcodeID << '\t' << MIRNextRelation
+             << '\n';
+          LLVM_DEBUG(dbgs()
+                     << Vocab->getStringKey(PrevOpcode) << '\t'
+                     << Vocab->getStringKey(OpcodeID) << '\t' << "Next\n");
+        }
+
+        // Add "Arg" relationships for operands
+        unsigned ArgIndex = 0;
+        for (const MachineOperand &MO : MI.operands()) {
+          auto OperandID = Vocab->getEntityIDForMachineOperand(MO);
+          unsigned RelationID = MIRArgRelation + ArgIndex;
+          OS << OpcodeID << '\t' << OperandID << '\t' << RelationID << '\n';
+          LLVM_DEBUG({
+            std::string OperandStr = Vocab->getStringKey(OperandID);
+            dbgs() << Vocab->getStringKey(OpcodeID) << '\t' << OperandStr
+                   << '\t' << "Arg" << ArgIndex << '\n';
+          });
+
+          ++ArgIndex;
+        }
+
+        // Update MaxRelation if there were operands
+        if (ArgIndex > 0)
+          MaxRelation = std::max(MaxRelation, MIRArgRelation + ArgIndex - 1);
+
+        PrevOpcode = OpcodeID;
+        HasPrevOpcode = true;
+      }
+    }
+
+    return MaxRelation;
+  }
+
+  /// Generate entity mappings with vocabulary
+  void generateEntityMappings(raw_ostream &OS) const {
+    if (!Vocab) {
+      WithColor::error(errs(), ToolName)
+          << "Vocabulary must be initialized for entity mappings.\n";
+      return;
+    }
+
+    const unsigned EntityCount = Vocab->getCanonicalSize();
+    OS << EntityCount << "\n";
+    for (unsigned EntityID = 0; EntityID < EntityCount; ++EntityID)
+      OS << Vocab->getStringKey(EntityID) << '\t' << EntityID << '\n';
+  }
+
   /// Generate embeddings for all machine functions in the module
   void generateEmbeddings(const Module &M, raw_ostream &OS) const {
     if (!Vocab) {
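The per-function loop in `generateTriplets` above can be followed more easily without the LLVM types. The sketch below mirrors its control flow on a simplified stand-in: `MockInstr`, `emitTriplets`, and `referenceAddFunction` are hypothetical names, and the entity IDs are copied from `reference_triplets.txt` in this commit rather than computed from a real vocabulary. The pairing of each ID with a particular operand follows the instruction order in `triplets.mir` and is an interpretation of the test, not something the commit states.

```cpp
#include <algorithm>
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <vector>

// Simplified stand-in for a machine instruction whose opcode and operands
// have already been mapped to flat entity IDs.
struct MockInstr {
  unsigned OpcodeID;
  std::vector<unsigned> OperandIDs;
};

constexpr unsigned NextRelation = 0; // same value as MIRNextRelation
constexpr unsigned ArgRelation = 1;  // same value as MIRArgRelation (Arg0)

// Emits "head\ttail\trelation" lines for one function and returns the
// largest relation ID used, following the loop structure in the commit.
unsigned emitTriplets(const std::vector<MockInstr> &Instrs, std::ostream &OS) {
  unsigned MaxRelation = NextRelation;
  unsigned PrevOpcode = 0;
  bool HasPrevOpcode = false;

  for (const MockInstr &MI : Instrs) {
    // "Next" links the previous instruction's opcode to this one.
    if (HasPrevOpcode)
      OS << PrevOpcode << '\t' << MI.OpcodeID << '\t' << NextRelation << '\n';

    // One "Arg" triplet per operand, numbered from ArgRelation upward.
    unsigned ArgIndex = 0;
    for (unsigned OperandID : MI.OperandIDs) {
      OS << MI.OpcodeID << '\t' << OperandID << '\t'
         << (ArgRelation + ArgIndex) << '\n';
      ++ArgIndex;
    }
    if (ArgIndex > 0)
      MaxRelation = std::max(MaxRelation, ArgRelation + ArgIndex - 1);

    PrevOpcode = MI.OpcodeID;
    HasPrevOpcode = true;
  }
  return MaxRelation;
}

// add_function from triplets.mir, with IDs taken from the reference output.
std::vector<MockInstr> referenceAddFunction() {
  return {
      {187, {7072, 6968}},              // %1:gr32 = COPY $esi
      {187, {7072, 6969}},              // %0:gr32 = COPY $edi
      {10, {7072, 7072, 7072, 6961}},   // %2:gr32 = ADD32rr %0, %1, $eflags
      {187, {6952, 7072}},              // $eax = COPY %2
      {1555, {6882, 6952}},             // RET 0, $eax
  };
}
```

Calling `emitTriplets(referenceAddFunction(), OS)` returns 4 and writes 16 lines, which correspond to the first 16 triplets of the reference file (the sketch always uses tabs as separators). The first instruction produces no `Next` triplet because there is no previous opcode, and the state is reset per function, which is why the reference output for `mul_function` again starts with an `Arg` triplet.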
@@ -538,38 +695,67 @@ int main(int argc, char **argv) {
       return 1;
     }
 
-    // Create MIR2Vec tool and initialize vocabulary
+    // Create MIR2Vec tool
     MIR2VecTool Tool(*MMI);
-    if (!Tool.initializeVocabulary(*M))
-      return 1;
 
+    // Initialize vocabulary. For triplet/entity generation, only layout is
+    // needed For embedding generation, the full vocabulary is needed.
+    //
+    // Note: Unlike IR2Vec, MIR2Vec vocabulary initialization requires
+    // target-specific information for generating the vocabulary layout. So, we
+    // always initialize the vocabulary in this case.
+    if (TripletsSubCmd || EntitiesSubCmd) {
+      if (!Tool.initializeVocabularyForLayout(*M)) {
+        WithColor::error(errs(), ToolName)
+            << "Failed to initialize MIR2Vec vocabulary for layout.\n";
+        return 1;
+      }
+    } else {
+      if (!Tool.initializeVocabulary(*M)) {
+        WithColor::error(errs(), ToolName)
+            << "Failed to initialize MIR2Vec vocabulary.\n";
+        return 1;
+      }
+    }
+    assert(Tool.getVocabulary() &&
+           "MIR2Vec vocabulary should be initialized at this point");
     LLVM_DEBUG(dbgs() << "MIR2Vec vocabulary loaded successfully.\n"
                       << "Vocabulary dimension: "
                       << Tool.getVocabulary()->getDimension() << "\n"
                       << "Vocabulary size: "
                       << Tool.getVocabulary()->getCanonicalSize() << "\n");
 
-    // Generate embeddings based on subcommand
-    if (!FunctionName.empty()) {
-      // Process single function
-      Function *F = M->getFunction(FunctionName);
-      if (!F) {
-        WithColor::error(errs(), ToolName)
-            << "Function '" << FunctionName << "' not found\n";
-        return 1;
-      }
+    // Handle subcommands
+    if (TripletsSubCmd) {
+      Tool.generateTriplets(*M, OS);
+    } else if (EntitiesSubCmd) {
+      Tool.generateEntityMappings(OS);
+    } else if (EmbeddingsSubCmd) {
+      if (!FunctionName.empty()) {
+        // Process single function
+        Function *F = M->getFunction(FunctionName);
+        if (!F) {
+          WithColor::error(errs(), ToolName)
+              << "Function '" << FunctionName << "' not found\n";
+          return 1;
+        }
 
-      MachineFunction *MF = MMI->getMachineFunction(*F);
-      if (!MF) {
-        WithColor::error(errs(), ToolName)
-            << "No MachineFunction for " << FunctionName << "\n";
-        return 1;
-      }
+        MachineFunction *MF = MMI->getMachineFunction(*F);
+        if (!MF) {
+          WithColor::error(errs(), ToolName)
+              << "No MachineFunction for " << FunctionName << "\n";
+          return 1;
+        }
 
-      Tool.generateEmbeddings(*MF, OS);
+        Tool.generateEmbeddings(*MF, OS);
+      } else {
+        // Process all functions
+        Tool.generateEmbeddings(*M, OS);
+      }
     } else {
-      // Process all functions
-      Tool.generateEmbeddings(*M, OS);
+      WithColor::error(errs(), ToolName)
+          << "Please specify a subcommand: triplets, entities, or embeddings\n";
+      return 1;
     }
 
     return 0;