From f996cc9447905fe7159c823608d800d89842b7db Mon Sep 17 00:00:00 2001 From: Harold Martin Date: Tue, 2 Jul 2024 14:55:37 -0700 Subject: [PATCH] support multiple labels, prevent graph obj override of existing property on extension --- graphviz2drawio/models/SvgParser.py | 10 +++++++--- graphviz2drawio/mx/CurveFactory.py | 2 +- graphviz2drawio/mx/Edge.py | 10 +++++++++- graphviz2drawio/mx/GraphObj.py | 6 ++++-- graphviz2drawio/mx/MxGraph.py | 14 ++++++++------ test/test_graphs.py | 21 +++++++++++++++++++++ 6 files changed, 50 insertions(+), 13 deletions(-) diff --git a/graphviz2drawio/models/SvgParser.py b/graphviz2drawio/models/SvgParser.py index 73404d6..9def2e5 100644 --- a/graphviz2drawio/models/SvgParser.py +++ b/graphviz2drawio/models/SvgParser.py @@ -18,7 +18,7 @@ def parse_nodes_edges_clusters( edge_factory = EdgeFactory(coords) nodes: OrderedDict[str, Node] = OrderedDict() - edges: list[Edge] = [] + edges: OrderedDict[str, Edge] = OrderedDict() clusters: OrderedDict[str, Node] = OrderedDict() for g in root: @@ -27,8 +27,12 @@ def parse_nodes_edges_clusters( if g.attrib["class"] == "node": nodes[title] = node_factory.from_svg(g) elif g.attrib["class"] == "edge": - edges.append(edge_factory.from_svg(g)) + edge = edge_factory.from_svg(g) + if existing_edge := edges.get(edge.key_for_label): + existing_edge.label += f"
{edge.label}
" + else: + edges[edge.key_for_label] = edge elif g.attrib["class"] == "cluster": clusters[title] = node_factory.from_svg(g) - return nodes, edges, clusters + return nodes, list(edges.values()), clusters diff --git a/graphviz2drawio/mx/CurveFactory.py b/graphviz2drawio/mx/CurveFactory.py index 58ed37b..2afbcdf 100644 --- a/graphviz2drawio/mx/CurveFactory.py +++ b/graphviz2drawio/mx/CurveFactory.py @@ -8,7 +8,7 @@ def __init__(self, coords): super(CurveFactory, self).__init__() self.coords = coords - def from_svg(self, svg_path): + def from_svg(self, svg_path) -> Curve: path = parse_path(svg_path) start = self.coords.complex_translate(path[0].start) end = self.coords.complex_translate(path[len(path) - 1].end) diff --git a/graphviz2drawio/mx/Edge.py b/graphviz2drawio/mx/Edge.py index 8e33bc1..27a7877 100644 --- a/graphviz2drawio/mx/Edge.py +++ b/graphviz2drawio/mx/Edge.py @@ -1,9 +1,10 @@ from graphviz2drawio.models import DotAttr +from .Curve import Curve from .GraphObj import GraphObj class Edge(GraphObj): - def __init__(self, sid, gid, fr, to, curve, label): + def __init__(self, sid: str, gid: str, fr: str, to: str, curve: Curve, label: str): super(Edge, self).__init__(sid, gid) self.fr = fr self.to = to @@ -18,3 +19,10 @@ def curve_start_end(self): return self.curve.end, self.curve.start else: return self.curve.start, self.curve.end + + @property + def key_for_label(self) -> str: + return f"{self.gid}-{self.curve}" + + def __repr__(self): + return f"{self.fr}->{self.to}: {self.label} {self.style} {self.dir} {self.arrowtail}" diff --git a/graphviz2drawio/mx/GraphObj.py b/graphviz2drawio/mx/GraphObj.py index 00c5b0b..b5d81fb 100644 --- a/graphviz2drawio/mx/GraphObj.py +++ b/graphviz2drawio/mx/GraphObj.py @@ -4,5 +4,7 @@ def __init__(self, sid, gid): self.gid = gid def enrich_from_graph(self, attrs): - for e in attrs: - self.__setattr__(e[0], e[1]) + for k, v in attrs: + if k in self.__dict__ and self.__dict__[k] is not None: + continue + self.__setattr__(k, v) diff --git a/graphviz2drawio/mx/MxGraph.py b/graphviz2drawio/mx/MxGraph.py index 8147984..3de3fbc 100644 --- a/graphviz2drawio/mx/MxGraph.py +++ b/graphviz2drawio/mx/MxGraph.py @@ -41,12 +41,14 @@ def add_edge(self, edge): edge_label_element = ET.SubElement( self.root, MxConst.CELL, - id=uuid.uuid4().hex, - style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];", - parent=edge.sid, - value=edge.label, - vertex="1", - connectable="0", + attrib={ + "id": uuid.uuid4().hex, + "style": "edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];", + "parent": edge.sid, + "value": edge.label, + "vertex": "1", + "connectable": "0", + }, ) self.add_mx_geo(edge_label_element) diff --git a/test/test_graphs.py b/test/test_graphs.py index a3d8fb7..2984ee8 100644 --- a/test/test_graphs.py +++ b/test/test_graphs.py @@ -94,6 +94,27 @@ def test_cluster(): assert contains_cluster +def test_convnet(): + file = "test/directed/convnet.gv.txt" + xml = graphviz2drawio.convert(file) + print(xml) + + root = ET.fromstring(xml) + elements = check_xml_top(root) + + assert elements[-1].attrib["value"] == "$$l_t$$" + +def test_multilabel(): + file = "test/directed/multilabel.gv.txt" + xml = graphviz2drawio.convert(file) + print(xml) + + root = ET.fromstring(xml) + elements = check_xml_top(root) + + assert elements[-1].attrib["value"] == "c
b
a
" + + # def test_runAll(): # for f in os.listdir('undirected'): # xml = graphviz2drawio.convert(f)