forked from tjwei/tf-play
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtfdot.py
76 lines (67 loc) · 2.54 KB
/
tfdot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import tensorflow as tf
from graphviz import Digraph
from random import randint
from collections import defaultdict
# Fill colour for graph nodes, keyed by TensorFlow op type.  tfdot() adds a
# random pastel entry here the first time it meets an op type not listed.
color_table = dict(
    Const="yellow",
    MatMul="#bbffbb",
    Variable="#ffbbbb",
    Assign="#bbbbff",
)
def split_name(n):
    """Split a '/'-separated name into ``(name_space, last_component)``.

    ``name_space`` is everything before the final '/', or '' when ``n``
    contains no '/'.
    """
    head, _separator, tail = n.rpartition('/')
    return head, tail
def common_name_space(n1, n2):
    """Return the deepest name space shared by names ``n1`` and ``n2``.

    Only the name-space part of each name (everything before its last '/')
    is compared, component by component; the result is the '/'-joined common
    leading components, or '' when none match.
    """
    shared = []
    # zip() stops at the shorter name space, mirroring a min-length bound.
    for left, right in zip(n1.split('/')[:-1], n2.split('/')[:-1]):
        if left != right:
            break
        shared.append(left)
    return "/".join(shared)
def tfdot(graph=None):
    """Render the operations of a TensorFlow graph as a graphviz ``Digraph``.

    Every op becomes a filled node labelled with the last component of its
    name and coloured by op type via the module-level ``color_table``.  An op
    type not yet in the table gets a random pastel colour, which is stored in
    ``color_table`` and therefore reused by later calls.  Name scopes become
    nested ``cluster_<scope>`` subgraphs.  Each data edge is placed in the
    deepest scope shared by its two endpoints.  For ``Assign`` ops, the edge
    to input 0 is drawn reversed (Assign -> input) and coloured red.

    Args:
        graph: graph whose ``get_operations()`` are drawn; defaults to the
            current default graph.

    Returns:
        The root ``graphviz.Digraph`` (named ``root``).  Subgraphs and edges
        are emitted in sorted order, so the DOT source is reproducible.
    """
    # name_space -> {"subgraphs": child name spaces,
    #                "edges": (src, dst) op-name pairs,
    #                "nodes": keyword dicts for Digraph.node}
    dot_data_dict = defaultdict(lambda: {"subgraphs": set(), "edges": set(), "nodes": []})
    # (src, dst) -> extra keyword attributes passed to Digraph.edge
    extra_attr = {}

    def get_dot_data(name_space):
        """Return the record for name_space, registering it (and its ancestors) with its parent."""
        if name_space != '':
            parent, _ = split_name(name_space)
            if name_space not in dot_data_dict[parent]['subgraphs']:
                get_dot_data(parent)['subgraphs'].add(name_space)
        return dot_data_dict[name_space]

    def update_dot(name_space=''):
        """Build the (sub)graph for name_space, recursing into its child name spaces."""
        name = "cluster_" + name_space if name_space else 'root'
        # graph_attr lets graphviz quote the label and emit a well-formed
        # attribute statement, instead of splicing raw text into dot.body.
        dot = Digraph(comment="subgraph: " + name_space, name=name,
                      graph_attr={'label': name_space})
        dot_data = dot_data_dict[name_space]
        # Sets iterate in hash order, which changes between interpreter runs
        # (string hash randomization); sorting makes the output deterministic.
        for child in sorted(dot_data['subgraphs']):
            dot.subgraph(update_dot(child))
        for node in dot_data['nodes']:
            dot.node(**node)
        for edge in sorted(dot_data['edges']):
            dot.edge(*edge, **extra_attr.get(edge, {}))
        return dot

    if graph is None:
        # tf.get_default_graph is the public TF1 entry point; TF2 moved it
        # under tf.compat.v1.
        get_default_graph = getattr(tf, "get_default_graph", None) or tf.compat.v1.get_default_graph
        graph = get_default_graph()

    # Pass 1: one node per op, filed under the op's own name space.
    for op in graph.get_operations():
        if op.type not in color_table:
            # Random pastel colour: every RGB channel lies in [155, 255].
            color_table[op.type] = "#%02x%02x%02x" % tuple(randint(0, 100) + 155 for _ in range(3))
        name_space, name = split_name(op.name)
        get_dot_data(name_space)['nodes'].append(
            dict(name=op.name, label=name, style="filled", fillcolor=color_table[op.type]))

    # Pass 2: one edge per (producer op, consumer op) data dependency.
    for op in graph.get_operations():
        for i, ip in enumerate(op.inputs):
            dot_data = get_dot_data(common_name_space(ip.op.name, op.name))
            if op.type == 'Assign' and i == 0:
                # Input 0 of Assign is the assigned-to op: draw the arrow
                # from the Assign to it, highlighted in red.
                edge = (op.name, ip.op.name)
                extra_attr[edge] = {'color': 'red'}
            else:
                edge = (ip.op.name, op.name)
            dot_data['edges'].add(edge)
    return update_dot()