[ViewOpGraph] Improve GraphViz output (#125509)

This patch improves the GraphViz output of ViewOpGraph
(--view-op-graph).

- Switch to rectangular record-based nodes, inspired by a similar
visualization in [Glow](https://github.com/pytorch/glow). Rectangles
make more efficient use of space when printing text.
- Add input and output ports for each operand and result, and remove
edge labels.
- Switch to a muted color palette to reduce eye strain.
This commit is contained in:
Eric Hein 2025-02-07 10:45:47 -05:00 committed by GitHub
parent 1611059f5d
commit 1f67070a3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 207 additions and 109 deletions

View File

@ -14,6 +14,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/IndentedOstream.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/GraphWriter.h"
#include <map>
@ -29,7 +30,7 @@ using namespace mlir;
static const StringRef kLineStyleControlFlow = "dashed";
static const StringRef kLineStyleDataFlow = "solid";
static const StringRef kShapeNode = "ellipse";
static const StringRef kShapeNode = "Mrecord";
static const StringRef kShapeNone = "plain";
/// Return the size limits for eliding large attributes.
@ -49,16 +50,25 @@ static std::string strFromOs(function_ref<void(raw_ostream &)> func) {
return buf;
}
/// Escape special characters such as '\n' and quotation marks.
static std::string escapeString(std::string str) {
return strFromOs([&](raw_ostream &os) { os.write_escaped(str); });
}
/// Put quotation marks around a given string.
static std::string quoteString(const std::string &str) {
return "\"" + str + "\"";
}
/// For Graphviz record nodes:
/// " Braces, vertical bars and angle brackets must be escaped with a backslash
/// character if you wish them to appear as a literal character "
std::string escapeLabelString(const std::string &str) {
std::string buf;
llvm::raw_string_ostream os(buf);
for (char c : str) {
if (llvm::is_contained({'{', '|', '<', '}', '>', '\n', '"'}, c))
os << '\\';
os << c;
}
return buf;
}
using AttributeMap = std::map<std::string, std::string>;
namespace {
@ -79,6 +89,12 @@ public:
std::optional<int> clusterId;
};
struct DataFlowEdge {
Value value;
Node node;
std::string port;
};
/// This pass generates a Graphviz dataflow visualization of an MLIR operation.
/// Note: See https://www.graphviz.org/doc/info/lang.html for more information
/// about the Graphviz DOT language.
@ -107,7 +123,7 @@ public:
private:
/// Generate a color mapping that will color every operation with the same
/// name the same way. It'll interpolate the hue in the HSV color-space,
/// attempting to keep the contrast suitable for black text.
/// using muted colors that provide good contrast for black text.
template <typename T>
void initColorMapping(T &irEntity) {
backgroundColors.clear();
@ -120,8 +136,10 @@ private:
});
for (auto indexedOps : llvm::enumerate(ops)) {
double hue = ((double)indexedOps.index()) / ops.size();
// Use lower saturation (0.3) and higher value (0.95) for better
// readability
backgroundColors[indexedOps.value()->getName()].second =
std::to_string(hue) + " 1.0 1.0";
std::to_string(hue) + " 0.3 0.95";
}
}
@ -129,8 +147,8 @@ private:
/// emitted.
void emitAllEdgeStmts() {
if (printDataFlowEdges) {
for (const auto &[value, node, label] : dataFlowEdges) {
emitEdgeStmt(valueToNode[value], node, label, kLineStyleDataFlow);
for (const auto &e : dataFlowEdges) {
emitEdgeStmt(valueToNode[e.value], e.node, e.port, kLineStyleDataFlow);
}
}
@ -147,8 +165,7 @@ private:
os.indent();
// Emit invisible anchor node from/to which arrows can be drawn.
Node anchorNode = emitNodeStmt(" ", kShapeNone);
os << attrStmt("label", quoteString(escapeString(std::move(label))))
<< ";\n";
os << attrStmt("label", quoteString(label)) << ";\n";
builder();
os.unindent();
os << "}\n";
@ -176,7 +193,8 @@ private:
// Always emit splat attributes.
if (isa<SplatElementsAttr>(attr)) {
attr.print(os);
os << escapeLabelString(
strFromOs([&](raw_ostream &os) { attr.print(os); }));
return;
}
@ -184,8 +202,8 @@ private:
auto elements = dyn_cast<ElementsAttr>(attr);
if (elements && elements.getNumElements() > largeAttrLimit) {
os << std::string(elements.getShapedType().getRank(), '[') << "..."
<< std::string(elements.getShapedType().getRank(), ']') << " : "
<< elements.getType();
<< std::string(elements.getShapedType().getRank(), ']') << " : ";
emitMlirType(os, elements.getType());
return;
}
@ -199,19 +217,27 @@ private:
std::string buf;
llvm::raw_string_ostream ss(buf);
attr.print(ss);
os << truncateString(buf);
os << escapeLabelString(truncateString(buf));
}
// Print a truncated and escaped MLIR type to `os`.
void emitMlirType(raw_ostream &os, Type type) {
std::string buf;
llvm::raw_string_ostream ss(buf);
type.print(ss);
os << escapeLabelString(truncateString(buf));
}
// Print a truncated and escaped MLIR operand to `os`.
void emitMlirOperand(raw_ostream &os, Value operand) {
operand.printAsOperand(os, OpPrintingFlags());
}
/// Append an edge to the list of edges.
/// Note: Edges are written to the output stream via `emitAllEdgeStmts`.
void emitEdgeStmt(Node n1, Node n2, std::string label, StringRef style) {
void emitEdgeStmt(Node n1, Node n2, std::string port, StringRef style) {
AttributeMap attrs;
attrs["style"] = style.str();
// Do not label edges that start/end at a cluster boundary. Such edges are
// clipped at the boundary, but labels are not. This can lead to labels
// floating around without any edge next to them.
if (!n1.clusterId && !n2.clusterId)
attrs["label"] = quoteString(escapeString(std::move(label)));
// Use `ltail` and `lhead` to draw edges between clusters.
if (n1.clusterId)
attrs["ltail"] = "cluster_" + std::to_string(*n1.clusterId);
@ -219,7 +245,15 @@ private:
attrs["lhead"] = "cluster_" + std::to_string(*n2.clusterId);
edges.push_back(strFromOs([&](raw_ostream &os) {
os << llvm::format("v%i -> v%i ", n1.id, n2.id);
os << "v" << n1.id;
if (!port.empty() && !n1.clusterId)
// Attach edge to south compass point of the result
os << ":res" << port << ":s";
os << " -> ";
os << "v" << n2.id;
if (!port.empty() && !n2.clusterId)
// Attach edge to north compass point of the operand
os << ":arg" << port << ":n";
emitAttrList(os, attrs);
}));
}
@ -240,11 +274,11 @@ private:
StringRef background = "") {
int nodeId = ++counter;
AttributeMap attrs;
attrs["label"] = quoteString(escapeString(std::move(label)));
attrs["label"] = quoteString(label);
attrs["shape"] = shape.str();
if (!background.empty()) {
attrs["style"] = "filled";
attrs["fillcolor"] = ("\"" + background + "\"").str();
attrs["fillcolor"] = quoteString(background.str());
}
os << llvm::format("v%i ", nodeId);
emitAttrList(os, attrs);
@ -252,8 +286,18 @@ private:
return Node(nodeId);
}
/// Generate a label for an operation.
std::string getLabel(Operation *op) {
std::string getValuePortName(Value operand) {
// Print value as an operand and omit the leading '%' character.
auto str = strFromOs([&](raw_ostream &os) {
operand.printAsOperand(os, OpPrintingFlags());
});
// Replace % and # with _
std::replace(str.begin(), str.end(), '%', '_');
std::replace(str.begin(), str.end(), '#', '_');
return str;
}
std::string getClusterLabel(Operation *op) {
return strFromOs([&](raw_ostream &os) {
// Print operation name and type.
os << op->getName();
@ -267,18 +311,73 @@ private:
// Print attributes.
if (printAttrs) {
os << "\n";
os << "\\l";
for (const NamedAttribute &attr : op->getAttrs()) {
os << '\n' << attr.getName().getValue() << ": ";
os << escapeLabelString(attr.getName().getValue().str()) << ": ";
emitMlirAttr(os, attr.getValue());
os << "\\l";
}
}
});
}
/// Generate a label for an operation.
std::string getRecordLabel(Operation *op) {
return strFromOs([&](raw_ostream &os) {
os << "{";
// Print operation inputs.
if (op->getNumOperands() > 0) {
os << "{";
auto operandToPort = [&](Value operand) {
os << "<arg" << getValuePortName(operand) << "> ";
emitMlirOperand(os, operand);
};
interleave(op->getOperands(), os, operandToPort, "|");
os << "}|";
}
// Print operation name and type.
os << op->getName() << "\\l";
// Print attributes.
if (printAttrs && !op->getAttrs().empty()) {
// Extra line break to separate attributes from the operation name.
os << "\\l";
for (const NamedAttribute &attr : op->getAttrs()) {
os << attr.getName().getValue() << ": ";
emitMlirAttr(os, attr.getValue());
os << "\\l";
}
}
if (op->getNumResults() > 0) {
os << "|{";
auto resultToPort = [&](Value result) {
os << "<res" << getValuePortName(result) << "> ";
emitMlirOperand(os, result);
if (printResultTypes) {
os << " ";
emitMlirType(os, result.getType());
}
};
interleave(op->getResults(), os, resultToPort, "|");
os << "}";
}
os << "}";
});
}
/// Generate a label for a block argument.
std::string getLabel(BlockArgument arg) {
return "arg" + std::to_string(arg.getArgNumber());
return strFromOs([&](raw_ostream &os) {
os << "<res" << getValuePortName(arg) << "> ";
arg.printAsOperand(os, OpPrintingFlags());
if (printResultTypes) {
os << " ";
emitMlirType(os, arg.getType());
}
});
}
/// Process a block. Emit a cluster and one node per block argument and
@ -287,14 +386,12 @@ private:
emitClusterStmt([&]() {
for (BlockArgument &blockArg : block.getArguments())
valueToNode[blockArg] = emitNodeStmt(getLabel(blockArg));
// Emit a node for each operation.
std::optional<Node> prevNode;
for (Operation &op : block) {
Node nextNode = processOperation(&op);
if (printControlFlowEdges && prevNode)
emitEdgeStmt(*prevNode, nextNode, /*label=*/"",
kLineStyleControlFlow);
emitEdgeStmt(*prevNode, nextNode, /*port=*/"", kLineStyleControlFlow);
prevNode = nextNode;
}
});
@ -311,18 +408,19 @@ private:
for (Region &region : op->getRegions())
processRegion(region);
},
getLabel(op));
getClusterLabel(op));
} else {
node = emitNodeStmt(getLabel(op), kShapeNode,
node = emitNodeStmt(getRecordLabel(op), kShapeNode,
backgroundColors[op->getName()].second);
}
// Insert data flow edges originating from each operand.
if (printDataFlowEdges) {
unsigned numOperands = op->getNumOperands();
for (unsigned i = 0; i < numOperands; i++)
dataFlowEdges.push_back({op->getOperand(i), node,
numOperands == 1 ? "" : std::to_string(i)});
for (unsigned i = 0; i < numOperands; i++) {
auto operand = op->getOperand(i);
dataFlowEdges.push_back({operand, node, getValuePortName(operand)});
}
}
for (Value result : op->getResults())
@ -352,7 +450,7 @@ private:
/// Mapping of SSA values to Graphviz nodes/clusters.
DenseMap<Value, Node> valueToNode;
/// Output for data flow edges is delayed until the end to handle cycles
std::vector<std::tuple<Value, Node, std::string>> dataFlowEdges;
std::vector<DataFlowEdge> dataFlowEdges;
/// Counter for generating unique node/subgraph identifiers.
int counter = 0;

View File

@ -1,21 +1,21 @@
// RUN: mlir-opt -view-op-graph %s -o %t 2>&1 | FileCheck -check-prefix=DFG %s
// DFG-LABEL: digraph G {
// DFG: compound = true;
// DFG: subgraph cluster_1 {
// DFG: v2 [label = " ", shape = plain];
// DFG: label = "builtin.module : ()\n";
// DFG: subgraph cluster_3 {
// DFG: v4 [label = " ", shape = plain];
// DFG: label = "";
// DFG: v5 [fillcolor = "0.000000 1.0 1.0", label = "arith.addi : (index)\n\noverflowFlags: #arith.overflow<none...", shape = ellipse, style = filled];
// DFG: v6 [fillcolor = "0.333333 1.0 1.0", label = "arith.constant : (index)\n\nvalue: 0 : index", shape = ellipse, style = filled];
// DFG: v7 [fillcolor = "0.333333 1.0 1.0", label = "arith.constant : (index)\n\nvalue: 1 : index", shape = ellipse, style = filled];
// DFG: }
// DFG: }
// DFG: v6 -> v5 [label = "0", style = solid];
// DFG: v7 -> v5 [label = "1", style = solid];
// DFG: }
// DFG-NEXT: compound = true;
// DFG-NEXT: subgraph cluster_1 {
// DFG-NEXT: v2 [label = " ", shape = plain];
// DFG-NEXT: label = "builtin.module : ()\l";
// DFG-NEXT: subgraph cluster_3 {
// DFG-NEXT: v4 [label = " ", shape = plain];
// DFG-NEXT: label = "";
// DFG-NEXT: v5 [fillcolor = "0.000000 0.3 0.95", label = "{{\{\{}}<arg_c0> %c0|<arg_c1> %c1}|arith.addi\l\loverflowFlags: #arith.overflow\<none...\l|{<res_0> %0 index}}", shape = Mrecord, style = filled];
// DFG-NEXT: v6 [fillcolor = "0.333333 0.3 0.95", label = "{arith.constant\l\lvalue: 0 : index\l|{<res_c0> %c0 index}}", shape = Mrecord, style = filled];
// DFG-NEXT: v7 [fillcolor = "0.333333 0.3 0.95", label = "{arith.constant\l\lvalue: 1 : index\l|{<res_c1> %c1 index}}", shape = Mrecord, style = filled];
// DFG-NEXT: }
// DFG-NEXT: }
// DFG-NEXT: v6:res_c0:s -> v5:arg_c0:n[style = solid];
// DFG-NEXT: v7:res_c1:s -> v5:arg_c1:n[style = solid];
// DFG-NEXT: }
module {
%add = arith.addi %c0, %c1 : index

View File

@ -1,45 +1,45 @@
// RUN: mlir-opt -view-op-graph -allow-unregistered-dialect %s -o %t 2>&1 | FileCheck -check-prefix=DFG %s
// DFG-LABEL: digraph G {
// DFG: compound = true;
// DFG: subgraph cluster_1 {
// DFG: v2 [label = " ", shape = plain];
// DFG: label = "builtin.module : ()\n";
// DFG: subgraph cluster_3 {
// DFG: v4 [label = " ", shape = plain];
// DFG: label = "";
// DFG: subgraph cluster_5 {
// DFG: v6 [label = " ", shape = plain];
// DFG: label = "test.graph_region : ()\n";
// DFG: subgraph cluster_7 {
// DFG: v8 [label = " ", shape = plain];
// DFG: label = "";
// DFG: v9 [fillcolor = "0.000000 1.0 1.0", label = "op1 : (i32)\n", shape = ellipse, style = filled];
// DFG: subgraph cluster_10 {
// DFG: v11 [label = " ", shape = plain];
// DFG: label = "test.ssacfg_region : (i32)\n";
// DFG: subgraph cluster_12 {
// DFG: v13 [label = " ", shape = plain];
// DFG: label = "";
// DFG: v14 [fillcolor = "0.166667 1.0 1.0", label = "op2 : (i32)\n", shape = ellipse, style = filled];
// DFG: }
// DFG: }
// DFG: v15 [fillcolor = "0.166667 1.0 1.0", label = "op2 : (i32)\n", shape = ellipse, style = filled];
// DFG: v16 [fillcolor = "0.500000 1.0 1.0", label = "op3 : (i32)\n", shape = ellipse, style = filled];
// DFG: }
// DFG: }
// DFG: }
// DFG: }
// DFG: v9 -> v9 [label = "0", style = solid];
// DFG: v15 -> v9 [label = "1", style = solid];
// DFG: v9 -> v14 [label = "0", style = solid];
// DFG: v11 -> v14 [ltail = cluster_10, style = solid];
// DFG: v15 -> v14 [label = "2", style = solid];
// DFG: v16 -> v14 [label = "3", style = solid];
// DFG: v9 -> v15 [label = "0", style = solid];
// DFG: v16 -> v15 [label = "1", style = solid];
// DFG: v9 -> v16 [label = "", style = solid];
// DFG: }
// DFG-NEXT: compound = true;
// DFG-NEXT: subgraph cluster_1 {
// DFG-NEXT: v2 [label = " ", shape = plain];
// DFG-NEXT: label = "builtin.module : ()\l";
// DFG-NEXT: subgraph cluster_3 {
// DFG-NEXT: v4 [label = " ", shape = plain];
// DFG-NEXT: label = "";
// DFG-NEXT: subgraph cluster_5 {
// DFG-NEXT: v6 [label = " ", shape = plain];
// DFG-NEXT: label = "test.graph_region : ()\l";
// DFG-NEXT: subgraph cluster_7 {
// DFG-NEXT: v8 [label = " ", shape = plain];
// DFG-NEXT: label = "";
// DFG-NEXT: v9 [fillcolor = "0.000000 0.3 0.95", label = "{{\{\{}}<arg_0> %0|<arg_2> %2}|op1\l|{<res_0> %0 i32}}", shape = Mrecord, style = filled];
// DFG-NEXT: subgraph cluster_10 {
// DFG-NEXT: v11 [label = " ", shape = plain];
// DFG-NEXT: label = "test.ssacfg_region : (i32)\l";
// DFG-NEXT: subgraph cluster_12 {
// DFG-NEXT: v13 [label = " ", shape = plain];
// DFG-NEXT: label = "";
// DFG-NEXT: v14 [fillcolor = "0.166667 0.3 0.95", label = "{{\{\{}}<arg_0> %0|<arg_1> %1|<arg_2> %2|<arg_3> %3}|op2\l|{<res_4> %4 i32}}", shape = Mrecord, style = filled];
// DFG-NEXT: }
// DFG-NEXT: }
// DFG-NEXT: v15 [fillcolor = "0.166667 0.3 0.95", label = "{{\{\{}}<arg_0> %0|<arg_3> %3}|op2\l|{<res_2> %2 i32}}", shape = Mrecord, style = filled];
// DFG-NEXT: v16 [fillcolor = "0.500000 0.3 0.95", label = "{{\{\{}}<arg_0> %0}|op3\l|{<res_3> %3 i32}}", shape = Mrecord, style = filled];
// DFG-NEXT: }
// DFG-NEXT: }
// DFG-NEXT: }
// DFG-NEXT: }
// DFG-NEXT: v9:res_0:s -> v9:arg_0:n[style = solid];
// DFG-NEXT: v15:res_2:s -> v9:arg_2:n[style = solid];
// DFG-NEXT: v9:res_0:s -> v14:arg_0:n[style = solid];
// DFG-NEXT: v11 -> v14:arg_1:n[ltail = cluster_10, style = solid];
// DFG-NEXT: v15:res_2:s -> v14:arg_2:n[style = solid];
// DFG-NEXT: v16:res_3:s -> v14:arg_3:n[style = solid];
// DFG-NEXT: v9:res_0:s -> v15:arg_0:n[style = solid];
// DFG-NEXT: v16:res_3:s -> v15:arg_3:n[style = solid];
// DFG-NEXT: v9:res_0:s -> v16:arg_0:n[style = solid];
// DFG-NEXT: }
"test.graph_region"() ({ // A Graph region
%1 = "op1"(%1, %3) : (i32, i32) -> (i32) // OK: %1, %3 allowed here

View File

@ -6,49 +6,49 @@
// DFG: subgraph {{.*}}
// DFG: label = "func.func{{.*}}merge_blocks
// DFG: subgraph {{.*}} {
// DFG: v[[ARG0:.*]] [label = "arg0"
// DFG: v[[ARG0:.*]] [label = "<res_arg0> %arg0 i32"
// DFG: v[[CONST10:.*]] [{{.*}}label ={{.*}}10 : i32
// DFG: subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
// DFG: v[[ANCHOR:.*]] [label = " ", shape = plain]
// DFG: label = "test.merge_blocks
// DFG: subgraph {{.*}} {
// DFG: v[[TEST_BR:.*]] [{{.*}}label = "test.br
// DFG: v[[TEST_BR:.*]] [{{.*}}label = "{{.*}}test.br
// DFG: }
// DFG: subgraph {{.*}} {
// DFG: }
// DFG: }
// DFG: v[[TEST_RET:.*]] [{{.*}}label = "test.return
// DFG: v[[ARG0]] -> v[[TEST_BR]]
// DFG: v[[CONST10]] -> v[[TEST_BR]]
// DFG: v[[ANCHOR]] -> v[[TEST_RET]] [ltail = [[CLUSTER_MERGE_BLOCKS]], style = solid];
// DFG: v[[ANCHOR]] -> v[[TEST_RET]] [ltail = [[CLUSTER_MERGE_BLOCKS]], style = solid];
// DFG: v[[TEST_RET:.*]] [{{.*}}label = "{{.*}}test.return
// DFG: v[[ARG0]]:res_arg0:s -> v[[TEST_BR]]:arg_arg0:n
// DFG: v[[CONST10]]:res_c10_i32:s -> v[[TEST_BR]]
// DFG: v[[ANCHOR]] -> v[[TEST_RET]]:arg_1_0:n[ltail = [[CLUSTER_MERGE_BLOCKS]], style = solid];
// DFG: v[[ANCHOR]] -> v[[TEST_RET]]:arg_1_1:n[ltail = [[CLUSTER_MERGE_BLOCKS]], style = solid];
// CFG-LABEL: digraph G {
// CFG: subgraph {{.*}} {
// CFG: subgraph {{.*}}
// CFG: label = "func.func{{.*}}merge_blocks
// CFG: subgraph {{.*}} {
// CFG: v[[C1:.*]] [{{.*}}label = "arith.constant
// CFG: v[[C2:.*]] [{{.*}}label = "arith.constant
// CFG: v[[C3:.*]] [{{.*}}label = "arith.constant
// CFG: v[[C4:.*]] [{{.*}}label = "arith.constant
// CFG: v[[TEST_FUNC:.*]] [{{.*}}label = "test.func
// CFG: v[[C1:.*]] [{{.*}}label = "{arith.constant
// CFG: v[[C2:.*]] [{{.*}}label = "{arith.constant
// CFG: v[[C3:.*]] [{{.*}}label = "{arith.constant
// CFG: v[[C4:.*]] [{{.*}}label = "{arith.constant
// CFG: v[[TEST_FUNC:.*]] [{{.*}}label = "{test.func
// CFG: subgraph [[CLUSTER_MERGE_BLOCKS:.*]] {
// CFG: v[[ANCHOR:.*]] [label = " ", shape = plain]
// CFG: label = "test.merge_blocks
// CFG: subgraph {{.*}} {
// CFG: v[[TEST_BR:.*]] [{{.*}}label = "test.br
// CFG: v[[TEST_BR:.*]] [{{.*}}label = "{{.*}}test.br
// CFG: }
// CFG: subgraph {{.*}} {
// CFG: }
// CFG: }
// CFG: v[[TEST_RET:.*]] [{{.*}}label = "test.return
// CFG: v[[TEST_RET:.*]] [{{.*}}label = "{{.*}}test.return
// CFG: v[[C1]] -> v[[C2]]
// CFG: v[[C2]] -> v[[C3]]
// CFG: v[[C3]] -> v[[C4]]
// CFG: v[[C4]] -> v[[TEST_FUNC]]
// CFG: v[[TEST_FUNC]] -> v[[ANCHOR]] [lhead = [[CLUSTER_MERGE_BLOCKS]], style = dashed];
// CFG: v[[ANCHOR]] -> v[[TEST_RET]] [ltail = [[CLUSTER_MERGE_BLOCKS]], style = dashed];
// CFG: v[[TEST_FUNC]] -> v[[ANCHOR]][lhead = [[CLUSTER_MERGE_BLOCKS]], style = dashed];
// CFG: v[[ANCHOR]] -> v[[TEST_RET]][ltail = [[CLUSTER_MERGE_BLOCKS]], style = dashed];
func.func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
%0 = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>