From d791fa8f763e9e55d91365a376594f657236f05b Mon Sep 17 00:00:00 2001 From: Harold Martin Date: Mon, 1 Jul 2024 18:11:32 -0700 Subject: [PATCH] type fixes, replace SvgParser class with function call --- graphviz2drawio/graphviz2drawio.py | 7 +++-- graphviz2drawio/models/SvgParser.py | 43 ++++++++++++++--------------- graphviz2drawio/models/__init__.py | 2 +- graphviz2drawio/mx/EdgeFactory.py | 2 +- graphviz2drawio/mx/NodeFactory.py | 2 +- test/test_graphs.py | 5 ++-- 6 files changed, 31 insertions(+), 30 deletions(-) diff --git a/graphviz2drawio/graphviz2drawio.py b/graphviz2drawio/graphviz2drawio.py index a65ccc6..2af730e 100755 --- a/graphviz2drawio/graphviz2drawio.py +++ b/graphviz2drawio/graphviz2drawio.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 from pygraphviz import AGraph -from .models import SvgParser + +from .models import parse_nodes_edges_clusters from .mx.MxGraph import MxGraph @@ -23,14 +24,14 @@ def convert(graph_to_convert, layout_prog="dot"): graph_nodes = {n: list37(n.attr.iteritems()) for n in list37(graph.nodes_iter())} svg_graph = graph.draw(prog=layout_prog, format="svg") - nodes, edges, clusters = SvgParser(svg_graph).get_elements() + nodes, edges, clusters = parse_nodes_edges_clusters(svg_graph) [e.enrich_from_graph(graph_edges[e.gid]) for e in edges] [n.enrich_from_graph(graph_nodes[n.gid]) for n in nodes.values()] # Put clusters first, so that nodes are drawn in front nodes_and_clusters = clusters nodes_and_clusters.update(nodes) - + mx_graph = MxGraph(nodes_and_clusters, edges) return mx_graph.value() diff --git a/graphviz2drawio/models/SvgParser.py b/graphviz2drawio/models/SvgParser.py index 8cca3e0..e906525 100644 --- a/graphviz2drawio/models/SvgParser.py +++ b/graphviz2drawio/models/SvgParser.py @@ -4,31 +4,30 @@ from graphviz2drawio.mx.NodeFactory import NodeFactory from . import SVG from .CoordsTranslate import CoordsTranslate +from ..mx import Node, Edge -class SvgParser: - def __init__(self, svg_data): - self.svg_data = svg_data +def parse_nodes_edges_clusters( + svg_data: bytes, +) -> tuple[dict[str, Node], list[Edge], dict[str, Node]]: + root = ElementTree.fromstring(svg_data)[0] - def get_elements(self): - root = ElementTree.fromstring(self.svg_data)[0] + coords = CoordsTranslate.from_svg_transform(root.attrib["transform"]) + node_factory = NodeFactory(coords) + edge_factory = EdgeFactory(coords) - coords = CoordsTranslate.from_svg_transform(root.attrib["transform"]) - node_factory = NodeFactory(coords) - edge_factory = EdgeFactory(coords) + nodes = OrderedDict() + edges = [] + clusters = OrderedDict() - nodes = OrderedDict() - edges = [] - clusters = OrderedDict() + for g in root: + if SVG.is_tag(g, "g"): + title = SVG.get_title(g) + if g.attrib["class"] == "node": + nodes[title] = node_factory.from_svg(g) + elif g.attrib["class"] == "edge": + edges.append(edge_factory.from_svg(g)) + elif g.attrib["class"] == "cluster": + clusters[title] = node_factory.from_svg(g) - for g in root: - if SVG.is_tag(g, "g"): - title = SVG.get_title(g) - if g.attrib["class"] == "node": - nodes[title] = node_factory.from_svg(g) - elif g.attrib["class"] == "edge": - edges.append(edge_factory.from_svg(g)) - elif g.attrib["class"] == "cluster": - clusters[title] = node_factory.from_svg(g) - - return nodes, edges, clusters + return nodes, edges, clusters diff --git a/graphviz2drawio/models/__init__.py b/graphviz2drawio/models/__init__.py index 050578d..d3fe148 100644 --- a/graphviz2drawio/models/__init__.py +++ b/graphviz2drawio/models/__init__.py @@ -1,4 +1,4 @@ # flake8: noqa: F401 from .Arguments import Arguments from .Rect import Rect -from .SvgParser import SvgParser +from .SvgParser import parse_nodes_edges_clusters diff --git a/graphviz2drawio/mx/EdgeFactory.py b/graphviz2drawio/mx/EdgeFactory.py index 9f9132d..dd019ab 100644 --- a/graphviz2drawio/mx/EdgeFactory.py +++ b/graphviz2drawio/mx/EdgeFactory.py @@ -8,7 +8,7 @@ def __init__(self, coords): super(EdgeFactory, self).__init__() self.curve_factory = CurveFactory(coords) - def from_svg(self, g): + def from_svg(self, g) -> Edge: gid = SVG.get_title(g).replace("--", "->") fr, to = gid.split("->") curve = None diff --git a/graphviz2drawio/mx/NodeFactory.py b/graphviz2drawio/mx/NodeFactory.py index b07edd1..b8115ca 100644 --- a/graphviz2drawio/mx/NodeFactory.py +++ b/graphviz2drawio/mx/NodeFactory.py @@ -37,7 +37,7 @@ def rect_from_ellipse_svg(self, attrib): x, y = self.coords.translate(cx, cy) return Rect(x=x - rx, y=y - ry, width=rx * 2, height=ry * 2) - def from_svg(self, g): + def from_svg(self, g) -> Node: texts = [] current_text = None for t in g: diff --git a/test/test_graphs.py b/test/test_graphs.py index 15fd135..60fedd5 100644 --- a/test/test_graphs.py +++ b/test/test_graphs.py @@ -70,12 +70,13 @@ def test_polylines(): root = ET.fromstring(xml) check_xml_top(root) + def test_cluster(): - file = "./directed/cluster.gv.txt" + file = "test/directed/cluster.gv.txt" xml = graphviz2drawio.convert(file) print(xml) - root = ET.fromstring(xml) + root = ET.fromstring(xml) elements = check_xml_top(root) contains_cluster = False