diff --git a/src/Model/Scenarios.py b/src/Model/Scenarios.py index 6b39c59e..254cdc44 100644 --- a/src/Model/Scenarios.py +++ b/src/Model/Scenarios.py @@ -56,6 +56,6 @@ class Scenarios(PamhyrModelDict): return None def new(self, parent): - new = Scenario(parent=parent, status=self._status) + new = Scenario(parent=parent) self.set(new._id, new) return new diff --git a/src/Model/Status.py b/src/Model/Status.py index 8675ac8d..db632db1 100644 --- a/src/Model/Status.py +++ b/src/Model/Status.py @@ -40,6 +40,7 @@ class StudyStatus(object): @scenario.setter def scenario(self, scenario): + logger.debug(f"Set scenario to {scenario}") self._scenario = scenario @property @@ -73,9 +74,4 @@ class StudyStatus(object): if self._saved: self.version += 1 - logger.debug( - "STATUS: Model status set as modified " + - f"at version {self.version}" - ) - self._saved = False diff --git a/src/Model/Study.py b/src/Model/Study.py index 49e0273d..a93d3943 100644 --- a/src/Model/Study.py +++ b/src/Model/Study.py @@ -62,14 +62,14 @@ class Study(SQLModel): if init_new: # Study data - self.scenarios = Scenarios(status=self.status) - self.scenarios.set( - 0, - Scenario( - id=0, name='default', - description='Default scenario', - ) + s0 = Scenario( + id=0, name='default', + description='Default scenario', ) + + self.scenarios = Scenarios(status=self.status) + self.scenarios[0] = s0 + self.status.scenario = s0 self._river = River(status=self.status) else: self._init_db_file(filename, is_new=False) @@ -225,7 +225,7 @@ class Study(SQLModel): "CREATE TABLE info(key TEXT NOT NULL UNIQUE, value TEXT NOT NULL)" ) self.execute( - "INSERT INTO info VALUES ('study_release', '0')" + "INSERT INTO info VALUES ('current_scenario', '0')" ) self.execute( "INSERT INTO info VALUES ('version', " + @@ -261,7 +261,7 @@ class Study(SQLModel): if version[0] == self._version: return True - logger.debug("Update database") + logger.info(f"Update database from {version[0]} to {self._version}") major, minor, release = version[0].split('.') if major == "0" and minor == "0" and int(release) < 10: @@ -269,6 +269,9 @@ class Study(SQLModel): "INSERT INTO info VALUES ('study_release', '0')" ) + if major == "0" and int(minor) <= 1: + self._add_into_info_if_not_exists('current_scenario', '0') + if major == "0" and int(minor) < 1: # Need to temporary disable the sqlite foreign keys # checking to update db dans change the table id fk to @@ -291,17 +294,26 @@ class Study(SQLModel): ) return True - logger.info("TODO: update failed") + logger.info("Update failed!") raise NotImplementedMethodeError(self, self._update) + def _add_into_info_if_not_exists(self, key, value): + rows = self.execute(f"SELECT value FROM info WHERE key='{key}'") + + if rows is None or len(rows) == 0: + self.execute( + f"INSERT INTO info VALUES ('{key}', '{value}')" + ) + @classmethod def _load(cls, filename): new = cls(init_new=False, filename=filename) + data = {"status": new.status} - version = new.execute( - "SELECT value FROM info WHERE key='study_release'" - ) - new.status.version = int(version[0]) + def sql_exec(sql): + return new.execute( + sql, fetch_one=False, commit=True + ) # TODO: Load metadata new.name = new.execute("SELECT value FROM info WHERE key='name'")[0] @@ -328,18 +340,20 @@ class Study(SQLModel): ) ) - data = {"status": new.status} - - def sql_exec(sql): - return new.execute( - sql, fetch_one=False, commit=True - ) - + # Scenarios new.scenarios = Scenarios._db_load( sql_exec, data=data ) + scenario_id = new.execute( + "SELECT value FROM info WHERE key='current_scenario'" + ) + logger.debug(f"Load with scenario {scenario_id[0]}") + + scenario = new.scenarios[int(scenario_id[0])] + new.status.scenario = scenario + # Load river data new._river = River._db_load( sql_exec, @@ -390,6 +404,13 @@ class Study(SQLModel): ) progress() + self.execute( + f"UPDATE info SET " + + f"value='{self.status.scenario_id}' " + + "WHERE key='current_scenario'" + ) + progress() + self._save_submodel( [self.scenarios, self._river], data=progress @@ -413,3 +434,27 @@ class Study(SQLModel): Nothing. """ self._close() + + def new_scenario_from_current(self, switch=True): + new = self.scenarios.new(self.status.scenario) + + if switch: + self.status.scenario = new + + return new + + def reload_from_scenario(self, scenario): + data = {"status": new.status} + + def sql_exec(sql): + return new.execute( + sql, fetch_one=False, commit=True + ) + + self.status.scenario = scenario + + # Reload river data + self._river = River._db_load( + sql_exec, + data=data + ) diff --git a/src/Model/Tools/PamhyrDict.py b/src/Model/Tools/PamhyrDict.py index 2885d364..1e81dc2b 100644 --- a/src/Model/Tools/PamhyrDict.py +++ b/src/Model/Tools/PamhyrDict.py @@ -59,6 +59,12 @@ class PamhyrModelDict(SQLSubModel): def __contains__(self, key): return key in self._dict + def __getitem__(self, key): + return self.get(key) + + def __setitem__(self, key, value): + self.set(key, value) + def set(self, key, new): self._dict[key] = new self._status.modified() diff --git a/src/Model/test_Model.py b/src/Model/test_Model.py index 8568a00a..5c6c118c 100644 --- a/src/Model/test_Model.py +++ b/src/Model/test_Model.py @@ -59,6 +59,55 @@ class StudyTestCase(unittest.TestCase): self.assertNotEqual(study.river, None) +class StudyScenarioTestCase(unittest.TestCase): + def test_create_study(self): + study = Study.new("foo", "bar") + self.assertEqual(study.name, "foo") + self.assertEqual(study.description, "bar") + self.assertEqual(study.status.scenario_id, 0) + + def test_create_new_scenario_study(self): + study = Study.new("foo", "bar") + old = study.status.scenario + new = study.new_scenario_from_current() + + self.assertEqual(study.name, "foo") + self.assertEqual(study.description, "bar") + + self.assertNotEqual(study.status.scenario_id, 0) + self.assertEqual(study.status.scenario_id, new.id) + + self.assertNotEqual(old, new) + + def test_open_study(self): + study = Study.open("../tests_cases/Enlargement/Enlargement.pamhyr") + self.assertNotEqual(study, None) + self.assertEqual(study.name, "Enlargement") + self.assertEqual(study.status.scenario_id, 0) + + def test_save_open_study(self): + study = Study.new("foo", "bar") + new = study.new_scenario_from_current() + nid = new.id + + dir = tempfile.mkdtemp() + f = os.path.join(dir, "foo.pamhyr") + + # Save study + study.filename = f + study.save() + study.close() + + # Reopen study + study = Study.open(f) + + # Check + self.assertNotEqual(study, None) + self.assertEqual(study.name, "foo") + self.assertEqual(study.description, "bar") + self.assertEqual(study.status.scenario_id, nid) + + class RiverTestCase(unittest.TestCase): def test_create_river(self): status = StudyStatus()