diff --git a/spinetoolbox/helpers.py b/spinetoolbox/helpers.py index 15d131b36..c7a01996a 100644 --- a/spinetoolbox/helpers.py +++ b/spinetoolbox/helpers.py @@ -1615,3 +1615,12 @@ def solve_connection_file(connection_file, connection_file_dict): fp.close() return connection_file return connection_file + + +def remove_first(lst, items): + for x in items: + try: + lst.remove(x) + break + except ValueError: + pass diff --git a/spinetoolbox/spine_db_editor/graphics_items.py b/spinetoolbox/spine_db_editor/graphics_items.py index b8cab3f6f..ff6aa6330 100644 --- a/spinetoolbox/spine_db_editor/graphics_items.py +++ b/spinetoolbox/spine_db_editor/graphics_items.py @@ -157,6 +157,12 @@ def db_maps(self): def entity_class_id(self, db_map): return self.db_mngr.get_item(db_map, "entity", self.entity_id(db_map)).get("class_id") + def entity_class_ids(self, db_map): + return {self.entity_class_id(db_map)} | { + x["superclass_id"] + for x in db_map.get_items("superclass_subclass", subclass_id=self.entity_class_id(db_map)) + } + def entity_id(self, db_map): return dict(self.db_map_ids).get(db_map) @@ -592,7 +598,7 @@ def _refresh_db_map_entity_class_lists(self): for db_map, ents in self.db_mngr.find_cascading_entities(db_map_entity_ids).items(): for ent in ents: entity_ids_per_class.setdefault((db_map, ent["class_id"]), set()).add(ent["id"]) - db_map_entity_class_ids = {db_map: {self.entity_class_id(db_map)} for db_map in self.db_maps} + db_map_entity_class_ids = {db_map: self.entity_class_ids(db_map) for db_map in self.db_maps} for db_map, ent_clss in self.db_mngr.find_cascading_entity_classes(db_map_entity_class_ids).items(): for ent_cls in ent_clss: ent_cls = ent_cls._extended() @@ -630,7 +636,7 @@ def _populate_connect_entities_menu(self, menu): if not isinstance(item, EntityItem): continue for db_map in item.db_maps: - entity_class_ids_in_graph.setdefault(db_map, set()).add(item.entity_class_id(db_map)) + entity_class_ids_in_graph.setdefault(db_map, set()).update(item.entity_class_ids(db_map)) action_name_icon_enabled = [] for name, db_map_ent_clss in self._db_map_entity_class_lists.items(): for db_map, ent_cls in db_map_ent_clss: diff --git a/spinetoolbox/spine_db_editor/widgets/add_items_dialogs.py b/spinetoolbox/spine_db_editor/widgets/add_items_dialogs.py index 592a14747..c7fc8213f 100644 --- a/spinetoolbox/spine_db_editor/widgets/add_items_dialogs.py +++ b/spinetoolbox/spine_db_editor/widgets/add_items_dialogs.py @@ -95,8 +95,8 @@ def populate_table_view(self): item.setFlags(Qt.ItemIsEnabled) item.setCheckState(Qt.CheckState.Checked) self.table_view.setItem(row, 0, item) - for column, element_name in enumerate(entity): - item = QTableWidgetItem(element_name) + for column, element_byname in enumerate(entity): + item = QTableWidgetItem(DB_ITEM_SEPARATOR.join(element_byname)) item.setFlags(Qt.ItemIsEnabled) self.table_view.setItem(row, column + 1, item) self.table_view.resizeColumnsToContents() @@ -134,8 +134,9 @@ def get_db_map_data(self): for row in range(self.table_view.rowCount()): if self.table_view.item(row, 0).checkState() != Qt.CheckState.Checked: continue - element_name_list = tuple(self.entities[row]) - data.append({"class_name": self.entity_class["name"], "element_name_list": element_name_list}) + element_byname_list = tuple(self.entities[row]) + byname = tuple(x for byname in element_byname_list for x in byname) + data.append({"class_name": self.entity_class["name"], "byname": byname}) return {db_map: data for db_map in self.db_maps} diff --git a/spinetoolbox/spine_db_editor/widgets/custom_qgraphicsviews.py b/spinetoolbox/spine_db_editor/widgets/custom_qgraphicsviews.py index c89d2a88e..7c254bb9a 100644 --- a/spinetoolbox/spine_db_editor/widgets/custom_qgraphicsviews.py +++ b/spinetoolbox/spine_db_editor/widgets/custom_qgraphicsviews.py @@ -19,11 +19,11 @@ from contextlib import contextmanager import numpy as np from PySide6.QtCore import Qt, QTimeLine, Signal, Slot, QRectF, QRunnable, QThreadPool -from PySide6.QtWidgets import QMenu, QGraphicsView, QInputDialog, QColorDialog, QMessageBox, QLineEdit, QGraphicsScene +from PySide6.QtWidgets import QMenu, QInputDialog, QColorDialog, QMessageBox, QLineEdit, QGraphicsScene from PySide6.QtGui import QCursor, QPainter, QIcon, QAction, QPageSize, QPixmap from PySide6.QtPrintSupport import QPrinter from PySide6.QtSvg import QSvgGenerator -from ...helpers import CharIconEngine +from ...helpers import CharIconEngine, remove_first from ...widgets.custom_qgraphicsviews import CustomQGraphicsView from ...widgets.custom_qwidgets import ToolBarWidgetAction, HorizontalSpinBox from ..graphics_items import EntityItem, CrossHairsArcItem, BgItem, ArcItem @@ -805,7 +805,9 @@ def clear_cross_hairs_items(self): def _cross_hairs_has_valid_target(self): db_map = self.entity_class["db_map"] - return self._hovered_ent_item.entity_class_id(db_map) in self.entity_class["dimension_ids_to_go"] + return any( + id_ in self.entity_class["dimension_ids_to_go"] for id_ in self._hovered_ent_item.entity_class_ids(db_map) + ) def mousePressEvent(self, event): """Handles relationship creation if one it's in process.""" @@ -817,7 +819,7 @@ def mousePressEvent(self, event): return if self._cross_hairs_has_valid_target(): db_map = self.entity_class["db_map"] - self.entity_class["dimension_ids_to_go"].remove(self._hovered_ent_item.entity_class_id(db_map)) + remove_first(self.entity_class["dimension_ids_to_go"], self._hovered_ent_item.entity_class_ids(db_map)) if self.entity_class["dimension_ids_to_go"]: # Add hovered as member and keep going, we're not done yet ch_ent_item = self.cross_hairs_items[1] diff --git a/spinetoolbox/spine_db_editor/widgets/graph_view_mixin.py b/spinetoolbox/spine_db_editor/widgets/graph_view_mixin.py index 42691ac7b..502843a13 100644 --- a/spinetoolbox/spine_db_editor/widgets/graph_view_mixin.py +++ b/spinetoolbox/spine_db_editor/widgets/graph_view_mixin.py @@ -20,7 +20,7 @@ from spinedb_api import from_database from spinedb_api.parameter_value import IndexedValue, TimeSeries from ...widgets.custom_qgraphicsscene import CustomGraphicsScene -from ...helpers import get_save_file_name_in_last_dir, get_open_file_name_in_last_dir, busy_effect +from ...helpers import get_save_file_name_in_last_dir, get_open_file_name_in_last_dir, busy_effect, remove_first from ...fetch_parent import FlexibleFetchParent from ..graphics_items import EntityItem, ArcItem, CrossHairsItem, CrossHairsEntityItem, CrossHairsArcItem from .graph_layout_generator import GraphLayoutGenerator, GraphLayoutGeneratorRunnable @@ -760,7 +760,7 @@ def start_connecting_entities(self, db_map, entity_class, ent_item): ent_item (..graphics_items.EntityItem) """ dimension_ids_to_go = entity_class["dimension_id_list"].copy() - dimension_ids_to_go.remove(ent_item.entity_class_id(db_map)) + remove_first(dimension_ids_to_go, ent_item.entity_class_ids(db_map)) entity_class["dimension_ids_to_go"] = dimension_ids_to_go entity_class["db_map"] = db_map db_map_ids = ((db_map, None),) @@ -784,12 +784,13 @@ def finalize_connecting_entities(self, entity_class, *entity_items): """ db_map = entity_class["db_map"] entities = set() - dimension_id_list = entity_class["dimension_id_list"] for item_permutation in itertools.permutations(entity_items): - if [item.entity_class_id(db_map) for item in item_permutation] == dimension_id_list: + dimension_id_lists = list(itertools.product(*[item.entity_class_ids(db_map) for item in item_permutation])) + if tuple(entity_class["dimension_id_list"]) in dimension_id_lists: element_name_list = tuple(item.entity_name for item in item_permutation) - if not db_map.get_item("entity", class_name=entity_class["name"], byname=element_name_list): - entities.add(element_name_list) + if not db_map.get_item("entity", class_name=entity_class["name"], element_name_list=element_name_list): + element_byname_list = tuple(item.byname for item in item_permutation) + entities.add(element_byname_list) if not entities: return dialog = AddReadyEntitiesDialog(self, entity_class, list(entities), self.db_mngr, db_map, commit_data=False)