Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f16109b

Browse files
[WIP]Debug print graph
1 parent 36ed5e0 commit f16109b

4 files changed

Lines changed: 97 additions & 0 deletions

File tree

src/common/exec_utils.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,5 +75,25 @@ bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx) {
7575
return true;
7676
}
7777

78+
void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os) {
79+
auto node_str = [&idx](uint32_t nid) {
80+
return std::to_string(nid) + " " + idx[nid].source->attrs.name;
81+
};
82+
for (size_t i = 0; i < idx.num_nodes(); ++i) {
83+
const auto& attrs = idx[i].source->attrs;
84+
os << "node " << node_str(i) << " " << (attrs.op ? attrs.op->name : "(var)") << "\n";
85+
for (auto [k, v] : attrs.dict)
86+
os << "attr " << k << " " << v << "\n";
87+
for (const auto& inp : idx[i].inputs)
88+
os << "inp " << node_str(inp.node_id) << " " << inp.index << " " << inp.version << "\n";
89+
for (auto dep : idx[i].control_deps)
90+
os << "dep " << node_str(dep) << "\n";
91+
for (const auto& sub : attrs.subgraphs) {
92+
std::string name;
93+
os << "sub " << (sub->GetAttr("name", &name) ? name : "(noname)") << "\n";
94+
}
95+
}
96+
}
97+
7898
} // namespace common
7999
} // namespace mxnet

src/common/exec_utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <nnvm/graph.h>
2828
#include <nnvm/pass_functions.h>
2929
#include <map>
30+
#include <ostream>
3031
#include <vector>
3132
#include <string>
3233
#include <utility>
@@ -570,6 +571,14 @@ void CopyGraph(nnvm::Graph* dst, const nnvm::Graph& src, bool copy_variables);
570571
*/
571572
bool CheckForInputNameDuplicates(const nnvm::IndexedGraph& idx);
572573

574+
/*!
575+
* \brief Prints graph to the specified stream.
576+
*
577+
* \param idx Indexed graph to print
578+
* \param os Output stream
579+
*/
580+
void PrintGraph(const nnvm::IndexedGraph& idx, std::ostream& os);
581+
573582
} // namespace common
574583
} // namespace mxnet
575584
#endif // MXNET_COMMON_EXEC_UTILS_H_

src/imperative/cached_op.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
#define MXNET_IMPERATIVE_CACHED_OP_H_
2222

2323
#include <mxnet/imperative.h>
24+
#include <fstream>
2425
#include <vector>
2526
#include <numeric>
2627
#include <atomic>
2728
#include <utility>
2829
#include <string>
2930
#include <unordered_map>
3031
#include <map>
32+
#include "../common/exec_utils.h"
3133
#include "../operator/operator_common.h"
3234
#include "../operator/subgraph/common.h"
3335
#include "./imperative_utils.h"
@@ -327,13 +329,34 @@ void SetRefCounts(nnvm::Graph* fwd_graph, const nnvm::Graph& full_graph) {
327329
std::make_shared<dmlc::any>(std::move(full_ref_count));
328330
}
329331

332+
void MaybePrintGraph(const nnvm::IndexedGraph& idx, const std::string& msg) {
333+
if (!dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH", false))
334+
return;
335+
336+
std::ofstream f;
337+
std::ostream* dest = &std::cout;
338+
std::string dest_name = dmlc::GetEnv("MXNET_DEBUG_PRINT_GRAPH_PATH", std::string("stdout"));
339+
if (dest_name == "stderr") {
340+
dest = &std::cerr;
341+
} else if (dest_name != "stdout") {
342+
f.open(dest_name.c_str(), std::ios::app);
343+
CHECK(f.good());
344+
dest = &f;
345+
}
346+
347+
*dest << "[[[ " << msg << "\n";
348+
common::PrintGraph(idx, *dest);
349+
*dest << "]]] " << msg << "\n";
350+
}
351+
330352
void OptimizeGraph(nnvm::Graph* full_graph,
331353
nnvm::Graph* fwd_graph,
332354
nnvm::Graph* grad_graph,
333355
std::vector<size_t>* input_map,
334356
const Context& context,
335357
size_t num_forward_outputs,
336358
const bool inlining) {
359+
MaybePrintGraph(full_graph->indexed_graph(), "graph before optimization");
337360
input_map->resize(full_graph->indexed_graph().input_nodes().size());
338361
std::iota(input_map->begin(), input_map->end(), 0);
339362
#if MXNET_USE_CUDA && !defined(_WIN32)
@@ -383,6 +406,7 @@ void OptimizeGraph(nnvm::Graph* full_graph,
383406
grad_graph->outputs = std::vector<nnvm::NodeEntry>(
384407
full_graph->outputs.begin() + num_forward_outputs, full_graph->outputs.end());
385408
SetRefCounts(fwd_graph, *full_graph);
409+
MaybePrintGraph(full_graph->indexed_graph(), "graph after optimization");
386410
}
387411

388412
/* \brief Check if param indices and data indices are set, if not then set data indices */

tools/print_graph.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/usr/bin/env python3
2+
3+
import re
4+
import sys
5+
6+
7+
RE_NODE = re.compile(r'node\s(.+)\n')
8+
RE_ATTR = re.compile(r'attr\s(.+)\n')
9+
RE_INP = re.compile(r'inp\s(.+)\n')
10+
RE_DEP = re.compile(r'dep\s(.+)\n')
11+
RE_SUB = re.compile(r'node\s(.+)\n')
12+
13+
14+
def to_dot(f):
15+
print('digraph Net {')
16+
for line in f:
17+
m = RE_NODE.fullmatch(line)
18+
if m:
19+
nid, name, op = m.group(1).split()
20+
shape = 'ellipse' if op == '(var)' else 'rectangle'
21+
print(f' node_{nid} [shape={shape}, label={name}]')
22+
continue
23+
m = RE_ATTR.fullmatch(line)
24+
if m:
25+
continue
26+
m = RE_INP.fullmatch(line)
27+
if m:
28+
njd, _name, index, _version = m.group(1).split()
29+
print(f' node_{njd} -> node_{nid} [label={index}, style=solid]')
30+
continue
31+
m = RE_DEP.fullmatch(line)
32+
if m:
33+
njd, _name = m.group(1).split()
34+
print(f' node_{njd} -> node_{nid} [style=dashed]')
35+
continue
36+
m = RE_SUB.fullmatch(line)
37+
if m:
38+
continue
39+
break
40+
print('}')
41+
42+
43+
if __name__ == '__main__':
44+
to_dot(sys.stdin)

0 commit comments

Comments
 (0)