Skip to content

Commit

Permalink
fix for non-labeled edges
Browse files Browse the repository at this point in the history
  • Loading branch information
hbmartin committed Jul 2, 2024
1 parent 052bba2 commit 39076d6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 16 deletions.
21 changes: 15 additions & 6 deletions graphviz2drawio/models/SVG.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
def get_first(g, tag):
from xml.etree.ElementTree import Element


def get_first(g: Element, tag: str) -> Element:
return g.findall("./{http://www.w3.org/2000/svg}" + tag)[0]


def get_title(g):
def get_title(g: Element) -> str:
return get_first(g, "title").text

def get_text(g):
return get_first(g, "text").text

def is_tag(g, tag):
def get_text(g: Element) -> str | None:
try:
text_el = get_first(g, "text")
return text_el.text
except IndexError:
return None


def is_tag(g: Element, tag: str) -> bool:
return g.tag == "{http://www.w3.org/2000/svg}" + tag


def has(g, tag):
def has(g: Element, tag: str) -> bool:
return len(g.findall("./{http://www.w3.org/2000/svg}" + tag)) > 0
14 changes: 4 additions & 10 deletions graphviz2drawio/mx/EdgeFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,10 @@ def __init__(self, coords):
self.curve_factory = CurveFactory(coords)

def from_svg(self, g) -> Edge:
gid = SVG.get_title(g).replace("--", "->")
fr, to = gid.split("->")
gid_template = "{}->{}"
sp_fr = fr.split(":")
sp_to = to.split(":")
if len(sp_fr) == 2:
fr = sp_fr[0]
if len(sp_to) == 2:
to = sp_to[0]
gid = gid_template.format(fr, to)
fr, to = SVG.get_title(g).replace("--", "->").split("->")
fr = fr.split(":")[0]
to = to.split(":")[0]
gid = f"{fr}->{to}"
curve = None
label = SVG.get_text(g)
if SVG.has(g, "path"):
Expand Down
34 changes: 34 additions & 0 deletions test/directed/convnet.gv.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
digraph {
"$$x_t$$" [color="#187DF9" fill="#187DF9" fillcolor="#187DF9" fontcolor=white shape=box style=filled]
Conv2d [color=plum fill=plum fillcolor=plum fontcolor=white shape=ellipse style=filled]
"$$x_t$$" -> Conv2d [label="$$x_t$$" value=x_t]
Conv2d2 [color=orange fill=orange fillcolor=orange fontcolor=white shape=ellipse style=filled]
"$$x_t$$" -> Conv2d2 [label="$$x_t$$" value=x_t]
"|" [color=black fill=black fillcolor=black fontcolor=black shape=point style=filled]
"$$x_t$$" -> "|" [label="$$x_t$$" value=x_t]
Conv2d [color=thistle fill=thistle fillcolor=thistle fontcolor=white shape=box style=filled]
MaxPooling [color=steelblue1 fill=steelblue1 fillcolor=steelblue1 fontcolor=white shape=ellipse style=filled]
Conv2d -> MaxPooling [label="$$g_t$$" value=g_t]
Conv2d2 [color=steelblue1 fill=steelblue1 fillcolor=steelblue1 fontcolor=white shape=box style=filled]
"|" [color=black fill=black fillcolor=black fontcolor=black shape=point style=filled]
Conv2d2 -> "|" [label="$$c_t$$" value=c_t]
MaxPooling [color=steelblue1 fill=steelblue1 fillcolor=steelblue1 fontcolor=white shape=box style=filled]
Dropout [color=plum fill=plum fillcolor=plum fontcolor=white shape=ellipse style=filled]
MaxPooling -> Dropout [label="$$s_t$$" value=s_t]
Dropout [color=plum fill=plum fillcolor=plum fontcolor=white shape=box style=filled]
"." [color=skyblue fill=skyblue fillcolor=skyblue fontcolor=skyblue height=0.01 shape=doublecircle style=filled width=0.01]
Dropout -> "." [label="$$s_{t+1}$$" value="s_{t+1}"]
"|" [color=black fill=black fillcolor=black fontcolor=black shape=point style=filled]
"$$W_x$$" [color=plum fill=plum fillcolor=plum fontcolor=white shape=ellipse style=filled]
"|" -> "$$W_x$$" [label="$$[x_t, c_t]$$" value="[x_t, c_t]"]
"$$W_x$$" [color=plum fill=plum fillcolor=plum fontcolor=white shape=box style=filled]
"." [color=skyblue fill=skyblue fillcolor=skyblue fontcolor=skyblue height=0.01 shape=doublecircle style=filled width=0.01]
"$$W_x$$" -> "." [label="$$x_{t-1}$$" value="x_{t-1}"]
"." [color=skyblue fill=skyblue fillcolor=skyblue fontcolor=skyblue height=0.01 shape=doublecircle style=filled width=0.01]
"$$W_q$$" [color=powderblue fill=powderblue fillcolor=powderblue fontcolor=white shape=ellipse style=filled]
"." -> "$$W_q$$" [label="$$q_t$$" value=q_t]
"$$W_q$$" [color=powderblue fill=powderblue fillcolor=powderblue fontcolor=white shape=box style=filled]
"$$l_t$$" [color=thistle fill=thistle fillcolor=thistle fontcolor=white shape=ellipse style=filled]
"$$W_q$$" -> "$$l_t$$" [label="$$l_t$$" value=l_t]
"$$x_t$$" [color=thistle fill=thistle fillcolor=thistle fontcolor=white shape=ellipse style=filled]
}
8 changes: 8 additions & 0 deletions test/directed/multilabel.gv.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
digraph {
node [shape=rectangle];
rankdir=LR;
splines=false;
"X" -> "Y"[ label="c" ];
"X" -> "Y"[ label="b" ];
"X" -> "Y"[ label="a" ];
}

0 comments on commit 39076d6

Please sign in to comment.