Skip to content

Commit

Permalink
Fix entity duplication (#2425)
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Nov 20, 2023
2 parents 1dff6f2 + 7f2c364 commit f476c3c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 26 deletions.
6 changes: 2 additions & 4 deletions spinetoolbox/spine_db_editor/widgets/spine_db_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,12 +597,10 @@ def duplicate_entity(self, entity_item):
entity_item (EntityTreeItem of EntityItem)
"""
orig_name = entity_item.name
class_name = entity_item.parent_item.name
existing_names = {ent.name for ent in entity_item.parent_item.children}
dup_name = unique_name(orig_name, existing_names)
parcel = SpineDBParcel(self.db_mngr)
db_map_ent_ids = {db_map: {entity_item.db_map_id(db_map)} for db_map in entity_item.db_maps}
parcel.inner_push_entity_ids(db_map_ent_ids)
self.db_mngr.duplicate_entity(parcel.data, orig_name, dup_name, entity_item.db_maps)
self.db_mngr.duplicate_entity(orig_name, dup_name, class_name, entity_item.db_maps)

def duplicate_scenario(self, db_map, scen_id):
"""
Expand Down
59 changes: 42 additions & 17 deletions spinetoolbox/spine_db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
TimeSeriesVariableResolution,
to_database,
)
from spinedb_api.helpers import name_from_elements
from spinedb_api.parameter_value import load_db_value
from spinedb_api.parameter_value import join_value_and_type, split_value_and_type
from spinedb_api.spine_io.exporters.excel import export_spine_database_to_xlsx
Expand Down Expand Up @@ -1461,24 +1462,48 @@ def duplicate_scenario(self, scen_data, dup_name, db_map):
}
self.import_data({db_map: data}, command_text="Duplicate scenario")

def duplicate_entity(self, entity_data, orig_name, dup_name, db_maps):
def _replace_name(ent_name):
if ent_name == orig_name:
return dup_name
return tuple(name if name != orig_name else dup_name for name in ent_name)
def duplicate_entity(self, orig_name, dup_name, class_name, db_maps):
"""Duplicates entity, its parameter values and related multidimensional entities.
data = self._get_data_for_export(entity_data)
data = {
"entities": [
(cls_name, _replace_name(ent_name), el_name_list, description)
for (cls_name, ent_name, el_name_list, description) in data.get("entities", [])
],
"parameter_values": [
(cls_name, _replace_name(ent_name), param_name, val, alt)
for (cls_name, ent_name, param_name, val, alt) in data.get("parameter_values", [])
],
}
self.import_data({db_map: data for db_map in db_maps}, command_text="Duplicate entity")
Args:
orig_name (str): original entity's name
dup_name (str): duplicate's name
class_name (str): entity class name
db_maps (Iterable of DatabaseMapping): database mappings where duplication should take place
"""
dup_import_data = {}
for db_map in db_maps:
entity = db_map.get_entity_item(class_name=class_name, name=orig_name)
element_name_list = entity["element_name_list"]
if element_name_list:
first_import_entry = (class_name, dup_name, element_name_list, entity["description"])
else:
first_import_entry = (class_name, dup_name, entity["description"])
dup_entity_import_data = [first_import_entry]
for item in db_map.get_entity_items():
element_name_list = item["element_name_list"]
item_class_name = item["class_name"]
if orig_name in element_name_list and item_class_name != class_name:
index = item["dimension_name_list"].index(class_name)
name_list = element_name_list
dup_name_list = name_list[:index] + (dup_name,) + name_list[index + 1 :]
dup_entity_import_data.append(
(item_class_name, name_from_elements(dup_name_list), dup_name_list, item["description"])
)
dup_import_data[db_map] = {"entities": dup_entity_import_data}
dup_value_import_data = []
for item in db_map.get_parameter_value_items(entity_class_name=class_name, entity_name=orig_name):
dup_value_import_data.append(
(
class_name,
dup_name,
item["parameter_definition_name"],
item["parsed_value"],
item["alternative_name"],
)
)
dup_import_data[db_map].update(parameter_values=dup_value_import_data)
self.import_data(dup_import_data, command_text="Duplicate entity")

def _get_data_for_export(self, db_map_item_ids):
data = {}
Expand Down
35 changes: 30 additions & 5 deletions tests/spine_db_editor/widgets/test_SpineDBEditorWithDBMapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@
# Public License for more details. You should have received a copy of the GNU Lesser General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
Unit tests for SpineDBEditor classes.
"""
""" Unit tests for SpineDBEditor classes. """
import os.path
from tempfile import TemporaryDirectory
import unittest
from unittest import mock
import logging
import sys

from PySide6.QtCore import QItemSelectionModel
from PySide6.QtWidgets import QApplication

from spinetoolbox.spine_db_editor.widgets.spine_db_editor import SpineDBEditor
from tests.mock_helpers import TestSpineDBManager

Expand Down Expand Up @@ -87,9 +87,34 @@ def test_duplicate_object_in_object_tree_model(self):
root_item = self.spine_db_editor.entity_tree_model.root_item
fish_item = next(iter(item for item in root_item.children if item.display_data == "fish"))
nemo_item = fish_item.child(0)
self.spine_db_editor.duplicate_entity(nemo_item)
with mock.patch.object(self.db_mngr, "error_msg") as error_msg_signal:
self.spine_db_editor.duplicate_entity(nemo_item)
error_msg_signal.emit.assert_not_called()
self.assertEqual(fish_item.row_count(), 2)
nemo_dupe = fish_item.child(1)
self.assertEqual(nemo_dupe.display_data, "nemo (1)")
fish_dog_item = next(iter(item for item in root_item.children if item.display_data == "fish__dog"))
fish_dog_item.fetch_more()
self.assertEqual(fish_dog_item.row_count(), 2)
nemo_pluto_dupe = fish_dog_item.child(1)
self.assertEqual(nemo_pluto_dupe.display_data, "nemo (1)__pluto[nemo (1) ǀ pluto]")
root_index = self.spine_db_editor.entity_tree_model.index_from_item(root_item)
self.spine_db_editor.ui.treeView_entity.selectionModel().setCurrentIndex(
root_index, QItemSelectionModel.SelectionFlags.ClearAndSelect
)
while self.spine_db_editor.parameter_value_model.rowCount() != 3:
QApplication.processEvents()
expected = [
["fish", "nemo", "color", "Base", "orange", "db"],
["fish", "nemo (1)", "color", "Base", "orange", "db"],
[None, None, None, None, None, "db"],
]
for row in range(3):
for column in range(self.spine_db_editor.parameter_value_model.columnCount()):
with self.subTest(row=row, column=column):
self.assertEqual(
self.spine_db_editor.parameter_value_model.index(row, column).data(), expected[row][column]
)


if __name__ == '__main__':
Expand Down

0 comments on commit f476c3c

Please sign in to comment.