Skip to content

Commit

Permalink
Prevent dragging and dropping alternatives to the Scenario tree from …
Browse files Browse the repository at this point in the history
…corrupting the database (#2419)
  • Loading branch information
soininen authored Nov 16, 2023
2 parents db6522a + 7329256 commit ee5a65f
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 62 deletions.
11 changes: 6 additions & 5 deletions spinetoolbox/spine_db_editor/mvcmodels/alternative_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""Contains alternative tree model."""
import json
import pickle
from collections import defaultdict

from PySide6.QtCore import QMimeData, QByteArray
from .tree_model_base import TreeModelBase
Expand Down Expand Up @@ -38,17 +39,17 @@ def mimeData(self, indexes):
Returns:
QMimeData: MIME data
"""
d = {}
d = defaultdict(list)
# We have two columns and consequently usually twice the same item per row.
# Make items unique without losing order using a dictionary trick.
items = list(dict.fromkeys(self.item_from_index(ind) for ind in indexes))
for item in items:
db_item = item.parent_item
db_key = self.db_mngr.db_map_key(db_item.db_map)
d.setdefault(db_key, []).append(item.id)
d[db_key].append(item.name)
mime = QMimeData()
mime.setText(two_column_as_csv(indexes))
mime.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(d)))
mime.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(d)))
return mime

def paste_alternative_mime_data(self, mime_data, database_item):
Expand All @@ -58,7 +59,7 @@ def paste_alternative_mime_data(self, mime_data, database_item):
mime_data (QMimeData): mime data
database_item (alternative_item.DBItem): target database item
"""
alternative_data = json.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data())
alternative_data = pickle.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data())
names_to_descriptions = {}
for db_key in alternative_data:
db_map = self.db_mngr.db_map_from_key(db_key)
Expand Down
44 changes: 19 additions & 25 deletions spinetoolbox/spine_db_editor/mvcmodels/scenario_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""Contains scenario tree model."""
import json
import pickle

from PySide6.QtCore import QMimeData, Qt, QByteArray
from spinetoolbox.helpers import unique_name
Expand Down Expand Up @@ -61,7 +61,7 @@ def mimeData(self, indexes):
db_item = item.parent_item
db_key = self.db_mngr.db_map_key(db_item.db_map)
scenario_data.setdefault(db_key, []).append(item.id)
mime.setData(mime_types.SCENARIO_DATA, QByteArray(json.dumps(scenario_data)))
mime.setData(mime_types.SCENARIO_DATA, QByteArray(pickle.dumps(scenario_data)))
mime.setText(two_column_as_csv(scenario_indexes))
return mime
alternative_indexes = []
Expand All @@ -79,7 +79,7 @@ def mimeData(self, indexes):
db_item = item.parent_item.parent_item
db_key = self.db_mngr.db_map_key(db_item.db_map)
alternative_data.setdefault(db_key, []).append(item.alternative_id)
mime.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(alternative_data)))
mime.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(alternative_data)))
mime.setText(two_column_as_csv(alternative_indexes))
return mime
return None
Expand All @@ -90,8 +90,8 @@ def canDropMimeData(self, mime_data, drop_action, row, column, parent):
if not mime_data.hasFormat(mime_types.ALTERNATIVE_DATA):
return False
try:
payload = json.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data())
except json.JSONDecodeError:
payload = pickle.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data())
except pickle.UnpicklingError:
return False
if not isinstance(payload, dict):
return False
Expand Down Expand Up @@ -123,15 +123,7 @@ def dropMimeData(self, mime_data, drop_action, row, column, parent):
# on a wrong tree item (bug in Qt or canDropMimeData()?).
# In those cases the type of scen_item is StandardTreeItem or ScenarioRootItem.
return False
old_alternative_id_list = list(scenario_item.alternative_id_list)
if row == -1:
row = len(old_alternative_id_list)
_db_map_key, alternative_ids = json.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data()).popitem()
alternative_id_list = [id_ for id_ in old_alternative_id_list[:row] if id_ not in alternative_ids]
alternative_id_list += alternative_ids
alternative_id_list += [id_ for id_ in old_alternative_id_list[row:] if id_ not in alternative_ids]
db_item = {"id": scenario_item.id, "alternative_id_list": alternative_id_list}
self.db_mngr.set_scenario_alternatives({scenario_item.db_map: [db_item]})
self.paste_alternative_mime_data(mime_data, row, scenario_item)
return True

def paste_alternative_mime_data(self, mime_data, row, scenario_item):
Expand All @@ -145,16 +137,18 @@ def paste_alternative_mime_data(self, mime_data, row, scenario_item):
old_alternative_id_list = list(scenario_item.alternative_id_list)
if row == -1:
row = len(old_alternative_id_list)
data_to_add = {}
for db_map_key, alternative_ids in json.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data()).items():
new_alternative_ids = []
for db_map_key, alternative_names in pickle.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data()).items():
target_db_map = self.db_mngr.db_map_from_key(db_map_key)
if target_db_map != scenario_item.db_map:
continue
alternative_id_list = [id_ for id_ in old_alternative_id_list[:row] if id_ not in alternative_ids]
alternative_id_list += alternative_ids
alternative_id_list += [id_ for id_ in old_alternative_id_list[row:] if id_ not in alternative_ids]
data_to_add[target_db_map] = [{"id": scenario_item.id, "alternative_id_list": alternative_id_list}]
self.db_mngr.set_scenario_alternatives(data_to_add)
for name in alternative_names:
new_alternative_ids.append(scenario_item.db_map.get_alternative_item(name=name)["id"])
alternative_id_list = [id_ for id_ in old_alternative_id_list[:row] if id_ not in new_alternative_ids]
alternative_id_list += new_alternative_ids
alternative_id_list += [id_ for id_ in old_alternative_id_list[row:] if id_ not in new_alternative_ids]
db_item = {"id": scenario_item.id, "alternative_id_list": alternative_id_list}
self.db_mngr.set_scenario_alternatives({scenario_item.db_map: [db_item]})

def paste_scenario_mime_data(self, mime_data, db_item):
"""Adds scenarios and their alternatives from MIME data to the model.
Expand All @@ -168,15 +162,15 @@ def paste_scenario_mime_data(self, mime_data, db_item):
alternative_names_by_scenario = {}
existing_scenarios = {i["name"] for i in self.db_mngr.get_items(db_item.db_map, "scenario")}
existing_alternatives = {i["name"] for i in self.db_mngr.get_items(db_item.db_map, "alternative")}
for db_map_key, scenario_ids in json.loads(mime_data.data(mime_types.SCENARIO_DATA).data()).items():
for db_map_key, scenario_names in pickle.loads(mime_data.data(mime_types.SCENARIO_DATA).data()).items():
db_map = self.db_mngr.db_map_from_key(db_map_key)
if db_map is db_item.db_map:
continue
for id_ in scenario_ids:
scenario_data = self.db_mngr.get_item(db_map, "scenario", id_)
for name in scenario_names:
scenario_data = db_map.get_scenario_item(name=name)
if scenario_data["name"] in existing_scenarios:
continue
alternative_id_list = self.db_mngr.get_scenario_alternative_id_list(db_map, id_)
alternative_id_list = self.db_mngr.get_scenario_alternative_id_list(db_map, scenario_data["id"])
for alternative_id in alternative_id_list:
alternative_db_item = self.db_mngr.get_item(db_map, "alternative", alternative_id)
alternative_names_by_scenario.setdefault(scenario_data["name"], []).append(
Expand Down
6 changes: 3 additions & 3 deletions tests/spine_db_editor/mvcmodels/test_alternative_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""Unit tests for :class:`AlternativeModel`."""
import json
from pathlib import Path
import pickle
from tempfile import TemporaryDirectory
import unittest
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -120,8 +120,8 @@ def test_mimeData(self):
self.assertTrue(mime_data.hasText())
self.assertEqual(mime_data.text(), "Base\tBase alternative\r\n")
self.assertTrue(mime_data.hasFormat(mime_types.ALTERNATIVE_DATA))
alternative_data = json.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data())
self.assertEqual(alternative_data, {self._db_mngr.db_map_key(self._db_map): [1]})
alternative_data = pickle.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data())
self.assertEqual(alternative_data, {self._db_mngr.db_map_key(self._db_map): ["Base"]})


class TestAlternativeModelWithTwoDatabases(unittest.TestCase):
Expand Down
59 changes: 30 additions & 29 deletions tests/spine_db_editor/mvcmodels/test_scenario_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################
"""Unit tests for ``scenario_model`` module."""
import json
from pathlib import Path
import pickle
from tempfile import TemporaryDirectory
import unittest
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -142,7 +142,7 @@ def test_mimeData(self):
self.assertTrue(mime_data.hasText())
self.assertEqual(mime_data.text(), "Base\tBase alternative\r\n")
self.assertTrue(mime_data.hasFormat(mime_types.ALTERNATIVE_DATA))
data = json.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data())
data = pickle.loads(mime_data.data(mime_types.ALTERNATIVE_DATA).data())
self.assertEqual(data, {self._db_mngr.db_map_key(self._db_map): [1]})

def test_canDropMimeData_returns_true_when_dropping_alternative_to_empty_scenario(self):
Expand All @@ -155,8 +155,8 @@ def test_canDropMimeData_returns_true_when_dropping_alternative_to_empty_scenari
scenario_index = model.index(0, 0, root_index)
self.assertEqual(scenario_index.data(), "my_scenario")
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map): [1]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map): ["Base"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
self.assertTrue(model.canDropMimeData(mime_data, Qt.DropAction.CopyAction, -1, -1, scenario_index))

def test_dropMimeData_adds_alternative_to_model(self):
Expand All @@ -169,8 +169,8 @@ def test_dropMimeData_adds_alternative_to_model(self):
scenario_index = model.index(0, 0, root_index)
self.assertEqual(scenario_index.data(), "my_scenario")
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map): [1]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map): ["Base"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
self.assertTrue(model.dropMimeData(mime_data, Qt.DropAction.CopyAction, -1, -1, scenario_index))
self._fetch_recursively(model)
model_data = model_data_to_dict(model)
Expand All @@ -196,7 +196,7 @@ def test_dropMimeData_adds_alternative_to_model(self):
self.assertEqual(model_data, expected)

def test_dropMimeData_reorders_alternatives(self):
self._db_mngr.add_alternatives({self._db_map: [{"name": "alternative_1", "id": 2}]})
self._db_mngr.add_alternatives({self._db_map: [{"name": "alternative_1"}]})
model = ScenarioModel(self._db_editor, self._db_mngr, self._db_map)
model.build_tree()
self._fetch_recursively(model)
Expand All @@ -206,8 +206,8 @@ def test_dropMimeData_reorders_alternatives(self):
scenario_index = model.index(0, 0, root_index)
self.assertEqual(scenario_index.data(), "my_scenario")
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map): [1]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map): ["Base"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
self.assertTrue(model.dropMimeData(mime_data, Qt.DropAction.CopyAction, -1, -1, scenario_index))
self._fetch_recursively(model)
model_data = model_data_to_dict(model)
Expand All @@ -232,8 +232,8 @@ def test_dropMimeData_reorders_alternatives(self):
]
self.assertEqual(model_data, expected)
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map): [2]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map): ["alternative_1"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
self.assertTrue(model.dropMimeData(mime_data, Qt.DropAction.CopyAction, 0, 0, scenario_index))
self._fetch_recursively(model)
model_data = model_data_to_dict(model)
Expand All @@ -259,8 +259,8 @@ def test_dropMimeData_reorders_alternatives(self):
]
self.assertEqual(model_data, expected)
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map): [1]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map): ["Base"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
self.assertTrue(model.dropMimeData(mime_data, Qt.DropAction.CopyAction, 0, 0, scenario_index))
self._fetch_recursively(model)
model_data = model_data_to_dict(model)
Expand All @@ -287,7 +287,7 @@ def test_dropMimeData_reorders_alternatives(self):
self.assertEqual(model_data, expected)

def test_paste_alternative_mime_data(self):
self._db_mngr.add_alternatives({self._db_map: [{"name": "alternative_1", "id": 2}]})
self._db_mngr.add_alternatives({self._db_map: [{"name": "alternative_1"}]})
model = ScenarioModel(self._db_editor, self._db_mngr, self._db_map)
model.build_tree()
self._fetch_recursively(model)
Expand All @@ -298,8 +298,8 @@ def test_paste_alternative_mime_data(self):
scenario_index = model.index(0, 0, root_index)
self.assertEqual(scenario_index.data(), "my_scenario")
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map): [2]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map): ["alternative_1"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
scenario_item = model.item_from_index(scenario_index)
model.paste_alternative_mime_data(mime_data, -1, scenario_item)
self._fetch_recursively(model)
Expand All @@ -318,7 +318,7 @@ def test_paste_alternative_mime_data(self):
self.assertEqual(model_data, expected)

def test_paste_alternative_mime_data_ranks_alternatives(self):
self._db_mngr.add_alternatives({self._db_map: [{"name": "alternative_1", "id": 2}]})
self._db_mngr.add_alternatives({self._db_map: [{"name": "alternative_1"}]})
model = ScenarioModel(self._db_editor, self._db_mngr, self._db_map)
model.build_tree()
self._fetch_recursively(model)
Expand All @@ -329,8 +329,8 @@ def test_paste_alternative_mime_data_ranks_alternatives(self):
scenario_index = model.index(0, 0, root_index)
self.assertEqual(scenario_index.data(), "my_scenario")
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map): [1]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map): ["Base"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
scenario_item = model.item_from_index(scenario_index)
model.paste_alternative_mime_data(mime_data, -1, scenario_item)
self._fetch_recursively(model)
Expand All @@ -355,8 +355,8 @@ def test_paste_alternative_mime_data_ranks_alternatives(self):
]
]
self.assertEqual(model_data, expected)
data = {self._db_mngr.db_map_key(self._db_map): [2]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map): ["alternative_1"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
scenario_item = model.item_from_index(scenario_index)
model.paste_alternative_mime_data(mime_data, 0, scenario_item)
self._fetch_recursively(model)
Expand Down Expand Up @@ -456,7 +456,7 @@ def tearDown(self):
self._temp_dir.cleanup()

def test_paste_alternative_mime_data_doesnt_paste_across_databases(self):
self._db_mngr.add_alternatives({self._db_map1: [{"name": "alternative_1", "id": 2}]})
self._db_mngr.add_alternatives({self._db_map1: [{"name": "alternative_1"}]})
model = ScenarioModel(self._db_editor, self._db_mngr, self._db_map1, self._db_map2)
model.build_tree()
self._fetch_recursively(model)
Expand All @@ -467,8 +467,8 @@ def test_paste_alternative_mime_data_doesnt_paste_across_databases(self):
scenario_index = model.index(0, 0, root_index)
self.assertEqual(scenario_index.data(), "my_scenario")
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map1): [2]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map1): ["alternative_1"]}
mime_data.setData(mime_types.ALTERNATIVE_DATA, QByteArray(pickle.dumps(data)))
scenario_item = model.item_from_index(scenario_index)
model.paste_alternative_mime_data(mime_data, -1, scenario_item)
self._fetch_recursively(model)
Expand All @@ -488,15 +488,16 @@ def test_paste_alternative_mime_data_doesnt_paste_across_databases(self):
self.assertEqual(model_data, expected)

def test_paste_scenario_mime_data(self):
self._db_mngr.add_scenarios({self._db_map1: [{"name": "my_scenario", "id": 1}]})
self._db_mngr.add_alternatives({self._db_map1: [{"name": "alternative_1", "id": 2}]})
self._db_mngr.set_scenario_alternatives({self._db_map1: [{"id": 1, "alternative_id_list": [2, 1]}]})
self._db_mngr.add_scenarios({self._db_map1: [{"name": "my_scenario"}]})
self._db_mngr.add_alternatives({self._db_map1: [{"name": "alternative_1"}]})
scenario_id = self._db_map1.get_scenario_item(name="my_scenario")["id"]
self._db_mngr.set_scenario_alternatives({self._db_map1: [{"id": scenario_id, "alternative_name_list": ["alternative_1", "Base"]}]})
model = ScenarioModel(self._db_editor, self._db_mngr, self._db_map1, self._db_map2)
model.build_tree()
self._fetch_recursively(model)
mime_data = QMimeData()
data = {self._db_mngr.db_map_key(self._db_map1): [1]}
mime_data.setData(mime_types.SCENARIO_DATA, QByteArray(json.dumps(data)))
data = {self._db_mngr.db_map_key(self._db_map1): ["my_scenario"]}
mime_data.setData(mime_types.SCENARIO_DATA, QByteArray(pickle.dumps(data)))
root_index = model.index(1, 0)
self.assertEqual(root_index.data(), "test_db_2")
db_item = model.item_from_index(root_index)
Expand Down

0 comments on commit ee5a65f

Please sign in to comment.