diff --git a/src/Meshing/test_Meshing.py b/src/Meshing/test_Meshing.py index 1b880f84..0abc41f8 100644 --- a/src/Meshing/test_Meshing.py +++ b/src/Meshing/test_Meshing.py @@ -21,7 +21,6 @@ import unittest import warnings import tempfile -from Model.Saved import SavedStatus from Model.Study import Study from Model.River import River diff --git a/src/Model/Network/test_network.py b/src/Model/Network/test_network.py index 23cbdcce..fb68b6d5 100644 --- a/src/Model/Network/test_network.py +++ b/src/Model/Network/test_network.py @@ -20,7 +20,7 @@ import os import unittest import tempfile -from Model.Saved import SavedStatus +from Model.Status import StudyStatus from Model.Network.Graph import Graph from Model.Network.Edge import Edge @@ -28,7 +28,7 @@ from Model.Network.Node import Node def new_graph(): - status = SavedStatus() + status = StudyStatus() g = Graph(status=status) return g diff --git a/src/Model/Scenario.py b/src/Model/Scenario.py index ed89a6e7..3c9097f2 100644 --- a/src/Model/Scenario.py +++ b/src/Model/Scenario.py @@ -30,8 +30,7 @@ class Scenario(SQLSubModel): name: str = "", description: str = "", revision: int = 0, - parent=None, - status=None): + parent=None): super(Scenario, self).__init__() self._set_id(id) @@ -40,7 +39,6 @@ class Scenario(SQLSubModel): self._description = description self._revision = revision self._parent = parent - self._status = status def _set_id(self, id): if id == -1: @@ -52,6 +50,10 @@ class Scenario(SQLSubModel): self._id + 1, Scenario._id_cnt + 1 ) + @property + def id(self): + return self._id + @classmethod def _db_create(cls, execute): execute(""" @@ -133,8 +135,7 @@ class Scenario(SQLSubModel): new = cls( id=id, name=name, description=desc, - revision=revi, parent=parent, - status=data["status"] + revision=revi, parent=parent ) scenarios[id] = new @@ -177,6 +178,10 @@ class Scenario(SQLSubModel): def revision(self): return self._revision + @revision.setter + def revision(self, revision): + self._revision = revision + @property def parent(self): return self._parent diff --git a/src/Model/Saved.py b/src/Model/Status.py similarity index 57% rename from src/Model/Saved.py rename to src/Model/Status.py index 6ee80df1..8675ac8d 100644 --- a/src/Model/Saved.py +++ b/src/Model/Status.py @@ -1,4 +1,4 @@ -# Saved.py -- Pamhyr model status class +# Status.py -- Pamhyr model status class # Copyright (C) 2023-2024 INRAE # # This program is free software: you can redistribute it and/or modify @@ -21,19 +21,46 @@ import logging logger = logging.getLogger() -class SavedStatus(object): - def __init__(self, version=0): - super(SavedStatus, self).__init__() - self._version = version +class StudyStatus(object): + def __init__(self, scenario=None): + super(StudyStatus, self).__init__() + self._scenario = scenario self._saved = True + @property + def scenario_id(self): + if self._scenario is None: + return -1 + + return self._scenario.id + + @property + def scenario(self): + return self._scenario + + @scenario.setter + def scenario(self, scenario): + self._scenario = scenario + @property def version(self): - return self._version + if self._scenario is None: + return 0 + + return self._scenario.revision @version.setter def version(self, version): - self._version = version + if self._scenario is None: + return + + self._scenario.revision = version + + def str_display(self): + if self._scenario is None: + return "" + + return f"{self._scenario.name}" def is_saved(self): return self._saved @@ -44,11 +71,11 @@ class SavedStatus(object): def modified(self): if self._saved: - self._version += 1 + self.version += 1 logger.debug( "STATUS: Model status set as modified " + - f"at version {self._version}" + f"at version {self.version}" ) self._saved = False diff --git a/src/Model/Stricklers/StricklersList.py b/src/Model/Stricklers/StricklersList.py index 37d82033..3cb53abe 100644 --- a/src/Model/Stricklers/StricklersList.py +++ b/src/Model/Stricklers/StricklersList.py @@ -18,7 +18,6 @@ from tools import trace, timer -from Model.Saved import SavedStatus from Model.Tools.PamhyrList import PamhyrModelList from Model.Stricklers.Stricklers import Stricklers diff --git a/src/Model/Study.py b/src/Model/Study.py index dda0a7f2..49e0273d 100644 --- a/src/Model/Study.py +++ b/src/Model/Study.py @@ -26,8 +26,7 @@ from tools import timer, timestamp from Model.Tools.PamhyrDB import SQLModel from Model.Scenarios import Scenarios from Model.Scenario import Scenario -from Model.Saved import SavedStatus -from Model.Serializable import Serializable +from Model.Status import StudyStatus from Model.Except import NotImplementedMethodeError from Model.River import River @@ -52,7 +51,7 @@ class Study(SQLModel): self._filename = filename super(Study, self).__init__(filename=filename) - self.status = SavedStatus() + self.status = StudyStatus() # Study general information self._name = "" @@ -69,7 +68,6 @@ class Study(SQLModel): Scenario( id=0, name='default', description='Default scenario', - status=self.status, ) ) self._river = River(status=self.status) diff --git a/src/Model/test_Model.py b/src/Model/test_Model.py index b1c90fb0..8568a00a 100644 --- a/src/Model/test_Model.py +++ b/src/Model/test_Model.py @@ -20,7 +20,7 @@ import os import unittest import tempfile -from Model.Saved import SavedStatus +from Model.Status import StudyStatus from Model.Study import Study from Model.River import River @@ -61,13 +61,13 @@ class StudyTestCase(unittest.TestCase): class RiverTestCase(unittest.TestCase): def test_create_river(self): - status = SavedStatus() + status = StudyStatus() river = River(status=status) self.assertNotEqual(river, None) def test_create_river_nodes(self): - status = SavedStatus() + status = StudyStatus() river = River(status=status) self.assertNotEqual(river, None) @@ -86,7 +86,7 @@ class RiverTestCase(unittest.TestCase): self.assertEqual(nodes[2], n2) def test_create_river_edges(self): - status = SavedStatus() + status = StudyStatus() river = River(status=status) self.assertNotEqual(river, None) diff --git a/src/Scripts/P3DST.py b/src/Scripts/P3DST.py index dccf3249..37a4afc9 100644 --- a/src/Scripts/P3DST.py +++ b/src/Scripts/P3DST.py @@ -24,7 +24,7 @@ from numpy import mean from Scripts.AScript import AScript -from Model.Saved import SavedStatus +from Model.Status import StudyStatus from Model.Geometry.Reach import Reach logger = logging.getLogger() @@ -81,7 +81,7 @@ class Script3DST(AScript): return 1 try: - status = SavedStatus() + status = StudyStatus() my_reach = Reach(status=status) my_reach.import_geometry(st_file)