diff --git a/graphviz2drawio/models/SvgParser.py b/graphviz2drawio/models/SvgParser.py
index 73404d6..9def2e5 100644
--- a/graphviz2drawio/models/SvgParser.py
+++ b/graphviz2drawio/models/SvgParser.py
@@ -18,7 +18,7 @@ def parse_nodes_edges_clusters(
edge_factory = EdgeFactory(coords)
nodes: OrderedDict[str, Node] = OrderedDict()
- edges: list[Edge] = []
+ edges: OrderedDict[str, Edge] = OrderedDict()
clusters: OrderedDict[str, Node] = OrderedDict()
for g in root:
@@ -27,8 +27,12 @@ def parse_nodes_edges_clusters(
if g.attrib["class"] == "node":
nodes[title] = node_factory.from_svg(g)
elif g.attrib["class"] == "edge":
- edges.append(edge_factory.from_svg(g))
+ edge = edge_factory.from_svg(g)
+ if existing_edge := edges.get(edge.key_for_label):
+ existing_edge.label += f"
{edge.label}
"
+ else:
+ edges[edge.key_for_label] = edge
elif g.attrib["class"] == "cluster":
clusters[title] = node_factory.from_svg(g)
- return nodes, edges, clusters
+ return nodes, list(edges.values()), clusters
diff --git a/graphviz2drawio/mx/CurveFactory.py b/graphviz2drawio/mx/CurveFactory.py
index 58ed37b..2afbcdf 100644
--- a/graphviz2drawio/mx/CurveFactory.py
+++ b/graphviz2drawio/mx/CurveFactory.py
@@ -8,7 +8,7 @@ def __init__(self, coords):
super(CurveFactory, self).__init__()
self.coords = coords
- def from_svg(self, svg_path):
+ def from_svg(self, svg_path) -> Curve:
path = parse_path(svg_path)
start = self.coords.complex_translate(path[0].start)
end = self.coords.complex_translate(path[len(path) - 1].end)
diff --git a/graphviz2drawio/mx/Edge.py b/graphviz2drawio/mx/Edge.py
index 8e33bc1..27a7877 100644
--- a/graphviz2drawio/mx/Edge.py
+++ b/graphviz2drawio/mx/Edge.py
@@ -1,9 +1,10 @@
from graphviz2drawio.models import DotAttr
+from .Curve import Curve
from .GraphObj import GraphObj
class Edge(GraphObj):
- def __init__(self, sid, gid, fr, to, curve, label):
+ def __init__(self, sid: str, gid: str, fr: str, to: str, curve: Curve, label: str):
super(Edge, self).__init__(sid, gid)
self.fr = fr
self.to = to
@@ -18,3 +19,10 @@ def curve_start_end(self):
return self.curve.end, self.curve.start
else:
return self.curve.start, self.curve.end
+
+ @property
+ def key_for_label(self) -> str:
+ return f"{self.gid}-{self.curve}"
+
+ def __repr__(self):
+ return f"{self.fr}->{self.to}: {self.label} {self.style} {self.dir} {self.arrowtail}"
diff --git a/graphviz2drawio/mx/GraphObj.py b/graphviz2drawio/mx/GraphObj.py
index 00c5b0b..b5d81fb 100644
--- a/graphviz2drawio/mx/GraphObj.py
+++ b/graphviz2drawio/mx/GraphObj.py
@@ -4,5 +4,7 @@ def __init__(self, sid, gid):
self.gid = gid
def enrich_from_graph(self, attrs):
- for e in attrs:
- self.__setattr__(e[0], e[1])
+ for k, v in attrs:
+ if k in self.__dict__ and self.__dict__[k] is not None:
+ continue
+ self.__setattr__(k, v)
diff --git a/graphviz2drawio/mx/MxGraph.py b/graphviz2drawio/mx/MxGraph.py
index 8147984..3de3fbc 100644
--- a/graphviz2drawio/mx/MxGraph.py
+++ b/graphviz2drawio/mx/MxGraph.py
@@ -41,12 +41,14 @@ def add_edge(self, edge):
edge_label_element = ET.SubElement(
self.root,
MxConst.CELL,
- id=uuid.uuid4().hex,
- style="edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];",
- parent=edge.sid,
- value=edge.label,
- vertex="1",
- connectable="0",
+ attrib={
+ "id": uuid.uuid4().hex,
+ "style": "edgeLabel;html=1;align=center;verticalAlign=middle;resizable=0;points=[];",
+ "parent": edge.sid,
+ "value": edge.label,
+ "vertex": "1",
+ "connectable": "0",
+ },
)
self.add_mx_geo(edge_label_element)
diff --git a/test/test_graphs.py b/test/test_graphs.py
index a3d8fb7..2984ee8 100644
--- a/test/test_graphs.py
+++ b/test/test_graphs.py
@@ -94,6 +94,27 @@ def test_cluster():
assert contains_cluster
+def test_convnet():
+ file = "test/directed/convnet.gv.txt"
+ xml = graphviz2drawio.convert(file)
+ print(xml)
+
+ root = ET.fromstring(xml)
+ elements = check_xml_top(root)
+
+ assert elements[-1].attrib["value"] == "$$l_t$$"
+
+def test_multilabel():
+ file = "test/directed/multilabel.gv.txt"
+ xml = graphviz2drawio.convert(file)
+ print(xml)
+
+ root = ET.fromstring(xml)
+ elements = check_xml_top(root)
+
+ assert elements[-1].attrib["value"] == "cb
a
"
+
+
# def test_runAll():
# for f in os.listdir('undirected'):
# xml = graphviz2drawio.convert(f)