Skip to content

Commit

Permalink
type fixes, replace SvgParser class with function call
Browse files Browse the repository at this point in the history
  • Loading branch information
hbmartin committed Jul 2, 2024
1 parent e9db9a1 commit d791fa8
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 30 deletions.
7 changes: 4 additions & 3 deletions graphviz2drawio/graphviz2drawio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#!/usr/bin/env python3

from pygraphviz import AGraph
from .models import SvgParser

from .models import parse_nodes_edges_clusters
from .mx.MxGraph import MxGraph


Expand All @@ -23,14 +24,14 @@ def convert(graph_to_convert, layout_prog="dot"):
graph_nodes = {n: list37(n.attr.iteritems()) for n in list37(graph.nodes_iter())}

svg_graph = graph.draw(prog=layout_prog, format="svg")
nodes, edges, clusters = SvgParser(svg_graph).get_elements()
nodes, edges, clusters = parse_nodes_edges_clusters(svg_graph)
[e.enrich_from_graph(graph_edges[e.gid]) for e in edges]
[n.enrich_from_graph(graph_nodes[n.gid]) for n in nodes.values()]

# Put clusters first, so that nodes are drawn in front
nodes_and_clusters = clusters
nodes_and_clusters.update(nodes)

mx_graph = MxGraph(nodes_and_clusters, edges)
return mx_graph.value()

Expand Down
43 changes: 21 additions & 22 deletions graphviz2drawio/models/SvgParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,30 @@
from graphviz2drawio.mx.NodeFactory import NodeFactory
from . import SVG
from .CoordsTranslate import CoordsTranslate
from ..mx import Node, Edge


class SvgParser:
def __init__(self, svg_data):
self.svg_data = svg_data
def parse_nodes_edges_clusters(
svg_data: bytes,
) -> tuple[dict[str, Node], list[Edge], dict[str, Node]]:
root = ElementTree.fromstring(svg_data)[0]

def get_elements(self):
root = ElementTree.fromstring(self.svg_data)[0]
coords = CoordsTranslate.from_svg_transform(root.attrib["transform"])
node_factory = NodeFactory(coords)
edge_factory = EdgeFactory(coords)

coords = CoordsTranslate.from_svg_transform(root.attrib["transform"])
node_factory = NodeFactory(coords)
edge_factory = EdgeFactory(coords)
nodes = OrderedDict()
edges = []
clusters = OrderedDict()

nodes = OrderedDict()
edges = []
clusters = OrderedDict()
for g in root:
if SVG.is_tag(g, "g"):
title = SVG.get_title(g)
if g.attrib["class"] == "node":
nodes[title] = node_factory.from_svg(g)
elif g.attrib["class"] == "edge":
edges.append(edge_factory.from_svg(g))
elif g.attrib["class"] == "cluster":
clusters[title] = node_factory.from_svg(g)

for g in root:
if SVG.is_tag(g, "g"):
title = SVG.get_title(g)
if g.attrib["class"] == "node":
nodes[title] = node_factory.from_svg(g)
elif g.attrib["class"] == "edge":
edges.append(edge_factory.from_svg(g))
elif g.attrib["class"] == "cluster":
clusters[title] = node_factory.from_svg(g)

return nodes, edges, clusters
return nodes, edges, clusters
2 changes: 1 addition & 1 deletion graphviz2drawio/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# flake8: noqa: F401
from .Arguments import Arguments
from .Rect import Rect
from .SvgParser import SvgParser
from .SvgParser import parse_nodes_edges_clusters
2 changes: 1 addition & 1 deletion graphviz2drawio/mx/EdgeFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, coords):
super(EdgeFactory, self).__init__()
self.curve_factory = CurveFactory(coords)

def from_svg(self, g):
def from_svg(self, g) -> Edge:
gid = SVG.get_title(g).replace("--", "->")
fr, to = gid.split("->")
curve = None
Expand Down
2 changes: 1 addition & 1 deletion graphviz2drawio/mx/NodeFactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def rect_from_ellipse_svg(self, attrib):
x, y = self.coords.translate(cx, cy)
return Rect(x=x - rx, y=y - ry, width=rx * 2, height=ry * 2)

def from_svg(self, g):
def from_svg(self, g) -> Node:
texts = []
current_text = None
for t in g:
Expand Down
5 changes: 3 additions & 2 deletions test/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ def test_polylines():
root = ET.fromstring(xml)
check_xml_top(root)


def test_cluster():
file = "./directed/cluster.gv.txt"
file = "test/directed/cluster.gv.txt"
xml = graphviz2drawio.convert(file)
print(xml)

root = ET.fromstring(xml)
root = ET.fromstring(xml)

elements = check_xml_top(root)
contains_cluster = False
Expand Down

0 comments on commit d791fa8

Please sign in to comment.