Skip to content

Commit

Permalink
support multiple labels, prevent graph obj override of existing prope…
Browse files Browse the repository at this point in the history
…rty on extension
  • Loading branch information
hbmartin committed Jul 2, 2024
1 parent 39076d6 commit f996cc9
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 13 deletions.
10 changes: 7 additions & 3 deletions graphviz2drawio/models/SvgParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def parse_nodes_edges_clusters(
edge_factory = EdgeFactory(coords)

nodes: OrderedDict[str, Node] = OrderedDict()
edges: list[Edge] = []
edges: OrderedDict[str, Edge] = OrderedDict()
clusters: OrderedDict[str, Node] = OrderedDict()

for g in root:
Expand All @@ -27,8 +27,12 @@ def parse_nodes_edges_clusters(
if g.attrib["class"] == "node":
nodes[title] = node_factory.from_svg(g)
elif g.attrib["class"] == "edge":
edges.append(edge_factory.from_svg(g))
edge = edge_factory.from_svg(g)
if existing_edge := edges.get(edge.key_for_label):
existing_edge.label += f"<div>{edge.label}</div>"
else:
edges[edge.key_for_label] = edge
elif g.attrib["class"] == "cluster":
clusters[title] = node_factory.from_svg(g)

return nodes, edges, clusters
return nodes, list(edges.values()), clusters
2 changes: 1 addition & 1 deletion graphviz2drawio/mx/CurveFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, coords):
super(CurveFactory, self).__init__()
self.coords = coords

def from_svg(self, svg_path):
def from_svg(self, svg_path) -> Curve:
path = parse_path(svg_path)
start = self.coords.complex_translate(path[0].start)
end = self.coords.complex_translate(path[len(path) - 1].end)
Expand Down
10 changes: 9 additions & 1 deletion graphviz2drawio/mx/Edge.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from graphviz2drawio.models import DotAttr
from .Curve import Curve
from .GraphObj import GraphObj


class Edge(GraphObj):
def __init__(self, sid, gid, fr, to, curve, label):
def __init__(self, sid: str, gid: str, fr: str, to: str, curve: Curve, label: str):
super(Edge, self).__init__(sid, gid)
self.fr = fr
self.to = to
Expand All @@ -18,3 +19,10 @@ def curve_start_end(self):
return self.curve.end, self.curve.start
else:
return self.curve.start, self.curve.end

@property
def key_for_label(self) -> str:
return f"{self.gid}-{self.curve}"

def __repr__(self):
return f"{self.fr}->{self.to}: {self.label} {self.style} {self.dir} {self.arrowtail}"
6 changes: 4 additions & 2 deletions graphviz2drawio/mx/GraphObj.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ def __init__(self, sid, gid):
self.gid = gid

def enrich_from_graph(self, attrs):
for e in attrs:
self.__setattr__(e[0], e[1])
for k, v in attrs:
if k in self.__dict__ and self.__dict__[k] is not None:
continue
self.__setattr__(k, v)
14 changes: 8 additions & 6 deletions graphviz2drawio/mx/MxGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ def add_edge(self, edge):
edge_label_element = ET.SubElement(
self.root,
MxConst.CELL,
id=uuid.uuid4().hex,
style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];",
parent=edge.sid,
value=edge.label,
vertex="1",
connectable="0",
attrib={
"id": uuid.uuid4().hex,
"style": "edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];",
"parent": edge.sid,
"value": edge.label,
"vertex": "1",
"connectable": "0",
},
)
self.add_mx_geo(edge_label_element)

Expand Down
21 changes: 21 additions & 0 deletions test/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,27 @@ def test_cluster():
assert contains_cluster


def test_convnet():
file = "test/directed/convnet.gv.txt"
xml = graphviz2drawio.convert(file)
print(xml)

root = ET.fromstring(xml)
elements = check_xml_top(root)

assert elements[-1].attrib["value"] == "$$l_t$$"

def test_multilabel():
file = "test/directed/multilabel.gv.txt"
xml = graphviz2drawio.convert(file)
print(xml)

root = ET.fromstring(xml)
elements = check_xml_top(root)

assert elements[-1].attrib["value"] == "c<div>b</div><div>a</div>"


# def test_runAll():
# for f in os.listdir('undirected'):
# xml = graphviz2drawio.convert(f)
Expand Down

0 comments on commit f996cc9

Please sign in to comment.