diff --git a/graphviz2drawio/mx/Node.py b/graphviz2drawio/mx/Node.py index 9a9b6e6..f3bc5a7 100644 --- a/graphviz2drawio/mx/Node.py +++ b/graphviz2drawio/mx/Node.py @@ -2,14 +2,14 @@ class Node(GraphObj): - def __init__(self, sid, gid, rect, texts, fill, stroke): + def __init__(self, sid, gid, rect, texts, fill, stroke, shape): super(Node, self).__init__(sid, gid) self.rect = rect self.texts = texts self.fill = fill self.stroke = stroke self.label = None - self.shape = None + self.shape = shape def text_to_mx_value(self): value = "" diff --git a/graphviz2drawio/mx/NodeFactory.py b/graphviz2drawio/mx/NodeFactory.py index e776d7d..39f7129 100644 --- a/graphviz2drawio/mx/NodeFactory.py +++ b/graphviz2drawio/mx/NodeFactory.py @@ -2,6 +2,7 @@ from graphviz2drawio.models.Rect import Rect from .Node import Node from .Text import Text +from . import Shape class NodeFactory: @@ -56,8 +57,10 @@ def from_svg(self, g) -> Node: rect = self.rect_from_svg_points( SVG.get_first(g, "polygon").attrib["points"] ) + shape = Shape.RECT else: rect = self.rect_from_ellipse_svg(SVG.get_first(g, "ellipse").attrib) + shape = Shape.ELLIPSE stroke = None if SVG.has(g, "polygon"): @@ -78,4 +81,5 @@ def from_svg(self, g) -> Node: texts=texts, fill=fill, stroke=stroke, + shape=shape ) diff --git a/test/directed/hello_rect.gv.txt b/test/directed/hello_rect.gv.txt new file mode 100644 index 0000000..30497e9 --- /dev/null +++ b/test/directed/hello_rect.gv.txt @@ -0,0 +1,4 @@ +digraph G { + node [shape=rect] + Hello->World +} diff --git a/test/test_graphs.py b/test/test_graphs.py index 2984ee8..ee09984 100644 --- a/test/test_graphs.py +++ b/test/test_graphs.py @@ -61,6 +61,22 @@ def test_hello(): check_edge(edge, hello, world) check_edge_dir(edge, dx=0, dy=1) +def test_hello_rect(): + file = "./directed/hello_rect.gv.txt" + xml = graphviz2drawio.convert(file) + print(xml) + + root = ET.fromstring(xml) + elements = check_xml_top(root) + + hello = elements[3] + check_value(hello, "Hello") + assert "ellipse" not in hello.attrib["style"] + + world = elements[4] + check_value(world, "World") + assert "ellipse" not in world.attrib["style"] + def test_port(): file = "test/directed/port.gv.txt"