Skip to content

Commit

Permalink
Add entities with elements from superclass in entity graph
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelma committed Nov 23, 2023
1 parent 93d4789 commit 2de83a6
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 16 deletions.
9 changes: 9 additions & 0 deletions spinetoolbox/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,3 +1615,12 @@ def solve_connection_file(connection_file, connection_file_dict):
fp.close()
return connection_file
return connection_file


def remove_first(lst, items):
for x in items:
try:
lst.remove(x)
break
except ValueError:
pass
10 changes: 8 additions & 2 deletions spinetoolbox/spine_db_editor/graphics_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,12 @@ def db_maps(self):
def entity_class_id(self, db_map):
return self.db_mngr.get_item(db_map, "entity", self.entity_id(db_map)).get("class_id")

def entity_class_ids(self, db_map):
return {self.entity_class_id(db_map)} | {
x["superclass_id"]
for x in db_map.get_items("superclass_subclass", subclass_id=self.entity_class_id(db_map))
}

def entity_id(self, db_map):
return dict(self.db_map_ids).get(db_map)

Expand Down Expand Up @@ -592,7 +598,7 @@ def _refresh_db_map_entity_class_lists(self):
for db_map, ents in self.db_mngr.find_cascading_entities(db_map_entity_ids).items():
for ent in ents:
entity_ids_per_class.setdefault((db_map, ent["class_id"]), set()).add(ent["id"])
db_map_entity_class_ids = {db_map: {self.entity_class_id(db_map)} for db_map in self.db_maps}
db_map_entity_class_ids = {db_map: self.entity_class_ids(db_map) for db_map in self.db_maps}
for db_map, ent_clss in self.db_mngr.find_cascading_entity_classes(db_map_entity_class_ids).items():
for ent_cls in ent_clss:
ent_cls = ent_cls._extended()
Expand Down Expand Up @@ -630,7 +636,7 @@ def _populate_connect_entities_menu(self, menu):
if not isinstance(item, EntityItem):
continue
for db_map in item.db_maps:
entity_class_ids_in_graph.setdefault(db_map, set()).add(item.entity_class_id(db_map))
entity_class_ids_in_graph.setdefault(db_map, set()).update(item.entity_class_ids(db_map))
action_name_icon_enabled = []
for name, db_map_ent_clss in self._db_map_entity_class_lists.items():
for db_map, ent_cls in db_map_ent_clss:
Expand Down
9 changes: 5 additions & 4 deletions spinetoolbox/spine_db_editor/widgets/add_items_dialogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def populate_table_view(self):
item.setFlags(Qt.ItemIsEnabled)
item.setCheckState(Qt.CheckState.Checked)
self.table_view.setItem(row, 0, item)
for column, element_name in enumerate(entity):
item = QTableWidgetItem(element_name)
for column, element_byname in enumerate(entity):
item = QTableWidgetItem(DB_ITEM_SEPARATOR.join(element_byname))
item.setFlags(Qt.ItemIsEnabled)
self.table_view.setItem(row, column + 1, item)
self.table_view.resizeColumnsToContents()
Expand Down Expand Up @@ -134,8 +134,9 @@ def get_db_map_data(self):
for row in range(self.table_view.rowCount()):
if self.table_view.item(row, 0).checkState() != Qt.CheckState.Checked:
continue
element_name_list = tuple(self.entities[row])
data.append({"class_name": self.entity_class["name"], "element_name_list": element_name_list})
element_byname_list = tuple(self.entities[row])
byname = tuple(x for byname in element_byname_list for x in byname)
data.append({"class_name": self.entity_class["name"], "byname": byname})
return {db_map: data for db_map in self.db_maps}


Expand Down
10 changes: 6 additions & 4 deletions spinetoolbox/spine_db_editor/widgets/custom_qgraphicsviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from contextlib import contextmanager
import numpy as np
from PySide6.QtCore import Qt, QTimeLine, Signal, Slot, QRectF, QRunnable, QThreadPool
from PySide6.QtWidgets import QMenu, QGraphicsView, QInputDialog, QColorDialog, QMessageBox, QLineEdit, QGraphicsScene
from PySide6.QtWidgets import QMenu, QInputDialog, QColorDialog, QMessageBox, QLineEdit, QGraphicsScene
from PySide6.QtGui import QCursor, QPainter, QIcon, QAction, QPageSize, QPixmap
from PySide6.QtPrintSupport import QPrinter
from PySide6.QtSvg import QSvgGenerator
from ...helpers import CharIconEngine
from ...helpers import CharIconEngine, remove_first
from ...widgets.custom_qgraphicsviews import CustomQGraphicsView
from ...widgets.custom_qwidgets import ToolBarWidgetAction, HorizontalSpinBox
from ..graphics_items import EntityItem, CrossHairsArcItem, BgItem, ArcItem
Expand Down Expand Up @@ -805,7 +805,9 @@ def clear_cross_hairs_items(self):

def _cross_hairs_has_valid_target(self):
db_map = self.entity_class["db_map"]
return self._hovered_ent_item.entity_class_id(db_map) in self.entity_class["dimension_ids_to_go"]
return any(
id_ in self.entity_class["dimension_ids_to_go"] for id_ in self._hovered_ent_item.entity_class_ids(db_map)
)

def mousePressEvent(self, event):
"""Handles relationship creation if one it's in process."""
Expand All @@ -817,7 +819,7 @@ def mousePressEvent(self, event):
return
if self._cross_hairs_has_valid_target():
db_map = self.entity_class["db_map"]
self.entity_class["dimension_ids_to_go"].remove(self._hovered_ent_item.entity_class_id(db_map))
remove_first(self.entity_class["dimension_ids_to_go"], self._hovered_ent_item.entity_class_ids(db_map))
if self.entity_class["dimension_ids_to_go"]:
# Add hovered as member and keep going, we're not done yet
ch_ent_item = self.cross_hairs_items[1]
Expand Down
13 changes: 7 additions & 6 deletions spinetoolbox/spine_db_editor/widgets/graph_view_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from spinedb_api import from_database
from spinedb_api.parameter_value import IndexedValue, TimeSeries
from ...widgets.custom_qgraphicsscene import CustomGraphicsScene
from ...helpers import get_save_file_name_in_last_dir, get_open_file_name_in_last_dir, busy_effect
from ...helpers import get_save_file_name_in_last_dir, get_open_file_name_in_last_dir, busy_effect, remove_first
from ...fetch_parent import FlexibleFetchParent
from ..graphics_items import EntityItem, ArcItem, CrossHairsItem, CrossHairsEntityItem, CrossHairsArcItem
from .graph_layout_generator import GraphLayoutGenerator, GraphLayoutGeneratorRunnable
Expand Down Expand Up @@ -760,7 +760,7 @@ def start_connecting_entities(self, db_map, entity_class, ent_item):
ent_item (..graphics_items.EntityItem)
"""
dimension_ids_to_go = entity_class["dimension_id_list"].copy()
dimension_ids_to_go.remove(ent_item.entity_class_id(db_map))
remove_first(dimension_ids_to_go, ent_item.entity_class_ids(db_map))
entity_class["dimension_ids_to_go"] = dimension_ids_to_go
entity_class["db_map"] = db_map
db_map_ids = ((db_map, None),)
Expand All @@ -784,12 +784,13 @@ def finalize_connecting_entities(self, entity_class, *entity_items):
"""
db_map = entity_class["db_map"]
entities = set()
dimension_id_list = entity_class["dimension_id_list"]
for item_permutation in itertools.permutations(entity_items):
if [item.entity_class_id(db_map) for item in item_permutation] == dimension_id_list:
dimension_id_lists = list(itertools.product(*[item.entity_class_ids(db_map) for item in item_permutation]))
if tuple(entity_class["dimension_id_list"]) in dimension_id_lists:
element_name_list = tuple(item.entity_name for item in item_permutation)
if not db_map.get_item("entity", class_name=entity_class["name"], byname=element_name_list):
entities.add(element_name_list)
if not db_map.get_item("entity", class_name=entity_class["name"], element_name_list=element_name_list):
element_byname_list = tuple(item.byname for item in item_permutation)
entities.add(element_byname_list)
if not entities:
return
dialog = AddReadyEntitiesDialog(self, entity_class, list(entities), self.db_mngr, db_map, commit_data=False)
Expand Down

0 comments on commit 2de83a6

Please sign in to comment.