From 39076d6a468f122b5f91818810415a4dc042bed8 Mon Sep 17 00:00:00 2001 From: Harold Martin Date: Tue, 2 Jul 2024 14:02:53 -0700 Subject: [PATCH] fix for non-labeled edges --- graphviz2drawio/models/SVG.py | 21 +++++++++++++------ graphviz2drawio/mx/EdgeFactory.py | 14 ++++--------- test/directed/convnet.gv.txt | 34 +++++++++++++++++++++++++++++++ test/directed/multilabel.gv.txt | 8 ++++++++ 4 files changed, 61 insertions(+), 16 deletions(-) create mode 100644 test/directed/convnet.gv.txt create mode 100644 test/directed/multilabel.gv.txt diff --git a/graphviz2drawio/models/SVG.py b/graphviz2drawio/models/SVG.py index 1ea128d..f409ab8 100644 --- a/graphviz2drawio/models/SVG.py +++ b/graphviz2drawio/models/SVG.py @@ -1,16 +1,25 @@ -def get_first(g, tag): +from xml.etree.ElementTree import Element + + +def get_first(g: Element, tag: str) -> Element: return g.findall("./{http://www.w3.org/2000/svg}" + tag)[0] -def get_title(g): +def get_title(g: Element) -> str: return get_first(g, "title").text -def get_text(g): - return get_first(g, "text").text -def is_tag(g, tag): +def get_text(g: Element) -> str | None: + try: + text_el = get_first(g, "text") + return text_el.text + except IndexError: + return None + + +def is_tag(g: Element, tag: str) -> bool: return g.tag == "{http://www.w3.org/2000/svg}" + tag -def has(g, tag): +def has(g: Element, tag: str) -> bool: return len(g.findall("./{http://www.w3.org/2000/svg}" + tag)) > 0 diff --git a/graphviz2drawio/mx/EdgeFactory.py b/graphviz2drawio/mx/EdgeFactory.py index 149ab3d..94412c7 100644 --- a/graphviz2drawio/mx/EdgeFactory.py +++ b/graphviz2drawio/mx/EdgeFactory.py @@ -9,16 +9,10 @@ def __init__(self, coords): self.curve_factory = CurveFactory(coords) def from_svg(self, g) -> Edge: - gid = SVG.get_title(g).replace("--", "->") - fr, to = gid.split("->") - gid_template = "{}->{}" - sp_fr = fr.split(":") - sp_to = to.split(":") - if len(sp_fr) == 2: - fr = sp_fr[0] - if len(sp_to) == 2: - to = sp_to[0] - gid = gid_template.format(fr, to) + fr, to = SVG.get_title(g).replace("--", "->").split("->") + fr = fr.split(":")[0] + to = to.split(":")[0] + gid = f"{fr}->{to}" curve = None label = SVG.get_text(g) if SVG.has(g, "path"): diff --git a/test/directed/convnet.gv.txt b/test/directed/convnet.gv.txt new file mode 100644 index 0000000..9ba5b30 --- /dev/null +++ b/test/directed/convnet.gv.txt @@ -0,0 +1,34 @@ +digraph { + "$$x_t$$" [color="#187DF9" fill="#187DF9" fillcolor="#187DF9" fontcolor=white shape=box style=filled] + Conv2d [color=plum fill=plum fillcolor=plum fontcolor=white shape=ellipse style=filled] + "$$x_t$$" -> Conv2d [label="$$x_t$$" value=x_t] + Conv2d2 [color=orange fill=orange fillcolor=orange fontcolor=white shape=ellipse style=filled] + "$$x_t$$" -> Conv2d2 [label="$$x_t$$" value=x_t] + "|" [color=black fill=black fillcolor=black fontcolor=black shape=point style=filled] + "$$x_t$$" -> "|" [label="$$x_t$$" value=x_t] + Conv2d [color=thistle fill=thistle fillcolor=thistle fontcolor=white shape=box style=filled] + MaxPooling [color=steelblue1 fill=steelblue1 fillcolor=steelblue1 fontcolor=white shape=ellipse style=filled] + Conv2d -> MaxPooling [label="$$g_t$$" value=g_t] + Conv2d2 [color=steelblue1 fill=steelblue1 fillcolor=steelblue1 fontcolor=white shape=box style=filled] + "|" [color=black fill=black fillcolor=black fontcolor=black shape=point style=filled] + Conv2d2 -> "|" [label="$$c_t$$" value=c_t] + MaxPooling [color=steelblue1 fill=steelblue1 fillcolor=steelblue1 fontcolor=white shape=box style=filled] + Dropout [color=plum fill=plum fillcolor=plum fontcolor=white shape=ellipse style=filled] + MaxPooling -> Dropout [label="$$s_t$$" value=s_t] + Dropout [color=plum fill=plum fillcolor=plum fontcolor=white shape=box style=filled] + "." [color=skyblue fill=skyblue fillcolor=skyblue fontcolor=skyblue height=0.01 shape=doublecircle style=filled width=0.01] + Dropout -> "." [label="$$s_{t+1}$$" value="s_{t+1}"] + "|" [color=black fill=black fillcolor=black fontcolor=black shape=point style=filled] + "$$W_x$$" [color=plum fill=plum fillcolor=plum fontcolor=white shape=ellipse style=filled] + "|" -> "$$W_x$$" [label="$$[x_t, c_t]$$" value="[x_t, c_t]"] + "$$W_x$$" [color=plum fill=plum fillcolor=plum fontcolor=white shape=box style=filled] + "." [color=skyblue fill=skyblue fillcolor=skyblue fontcolor=skyblue height=0.01 shape=doublecircle style=filled width=0.01] + "$$W_x$$" -> "." [label="$$x_{t-1}$$" value="x_{t-1}"] + "." [color=skyblue fill=skyblue fillcolor=skyblue fontcolor=skyblue height=0.01 shape=doublecircle style=filled width=0.01] + "$$W_q$$" [color=powderblue fill=powderblue fillcolor=powderblue fontcolor=white shape=ellipse style=filled] + "." -> "$$W_q$$" [label="$$q_t$$" value=q_t] + "$$W_q$$" [color=powderblue fill=powderblue fillcolor=powderblue fontcolor=white shape=box style=filled] + "$$l_t$$" [color=thistle fill=thistle fillcolor=thistle fontcolor=white shape=ellipse style=filled] + "$$W_q$$" -> "$$l_t$$" [label="$$l_t$$" value=l_t] + "$$x_t$$" [color=thistle fill=thistle fillcolor=thistle fontcolor=white shape=ellipse style=filled] +} \ No newline at end of file diff --git a/test/directed/multilabel.gv.txt b/test/directed/multilabel.gv.txt new file mode 100644 index 0000000..1cf0a13 --- /dev/null +++ b/test/directed/multilabel.gv.txt @@ -0,0 +1,8 @@ +digraph { + node [shape=rectangle]; + rankdir=LR; + splines=false; + "X" -> "Y"[ label="c" ]; + "X" -> "Y"[ label="b" ]; + "X" -> "Y"[ label="a" ]; +} \ No newline at end of file