# Friction.py -- Pamhyr
# Copyright (C) 2023-2025  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

import logging

from tools import trace, timer

from Model.Tools.PamhyrDB import SQLSubModel
from Model.Scenario import Scenario

from numpy import interp

logger = logging.getLogger()


class Friction(SQLSubModel):
    """Friction definition over an rk interval of a reach.

    A friction stores a reach, a ``[begin_rk, end_rk]`` interval and two
    strickler references (begin and end).  Only the begin strickler is
    currently exposed: the ``end_strickler`` getter and ``get_friction``
    both read the begin strickler.
    """

    def __init__(self, id: int = -1, reach=None,
                 status=None, owner_scenario=-1):
        """Create a friction with no strickler and a ``[0.0, 0.0]`` interval.

        Args:
            id: Identifier forwarded to the ``SQLSubModel`` constructor.
            reach: Reach this friction applies to, or None.
            status: Study status object forwarded to the parent class.
            owner_scenario: Owner scenario forwarded to the parent class.
        """
        super(Friction, self).__init__(
            id=id, status=status,
            owner_scenario=owner_scenario
        )

        self._reach = reach
        self._begin_rk = 0.0
        self._end_rk = 0.0
        self._begin_strickler = None
        self._end_strickler = None

    @classmethod
    def _db_create(cls, execute, ext=""):
        """Create the ``friction{ext}`` table.

        Args:
            execute: Callable running one SQL statement.
            ext: Suffix appended to the table name (``"_tmp"`` is used
                during the 0.2.0 migration).

        Returns:
            True when ``ext`` is non-empty, otherwise the result of
            ``cls._create_submodel(execute)``.
        """
        execute(f"""
          CREATE TABLE friction{ext} (
            {cls.create_db_add_pamhyr_id()},
            deleted BOOLEAN NOT NULL DEFAULT FALSE,
            ind INTEGER NOT NULL,
            begin_rk REAL NOT NULL,
            end_rk REAL NOT NULL,
            reach INTEGER NOT NULL,
            begin_strickler INTEGER NOT NULL,
            end_strickler INTEGER NOT NULL,
            {Scenario.create_db_add_scenario()},
            {Scenario.create_db_add_scenario_fk()},
            FOREIGN KEY(reach) REFERENCES river_reach(pamhyr_id),
            FOREIGN KEY(begin_strickler) REFERENCES stricklers(pamhyr_id),
            FOREIGN KEY(end_strickler) REFERENCES stricklers(pamhyr_id),
            PRIMARY KEY(pamhyr_id, scenario)
          )
        """)

        # A suffixed table is a temporary copy: sub-models are not created.
        if ext != "":
            return True

        return cls._create_submodel(execute)

    @classmethod
    def _db_update(cls, execute, version, data=None):
        """Migrate the ``friction`` table from database ``version``.

        Args:
            execute: Callable running one SQL statement.
            version: Version string of the form ``"major.minor.release"``.
            data: Migration context passed on to the helpers.

        Returns:
            The result of ``cls._update_submodel(execute, version, data)``.
        """
        major, minor, release = version.strip().split(".")

        if major == minor == "0":
            if int(release) < 11:
                execute("ALTER TABLE friction " +
                        "RENAME COLUMN begin_kp TO begin_rk")
                execute("ALTER TABLE friction RENAME COLUMN end_kp TO end_rk")

            cls._db_update_to_0_2_0(execute, data)

        if major == "0" and minor == "1":
            if int(release) < 2:
                execute(
                    "ALTER TABLE friction " +
                    "ADD COLUMN deleted BOOLEAN NOT NULL DEFAULT FALSE"
                )

        return cls._update_submodel(execute, version, data)

    @classmethod
    def _db_update_to_0_2_0(cls, execute, data):
        """Rebuild the ``friction`` table with pamhyr_id and scenario columns.

        The table is extended in place, copied into a freshly created
        ``friction_tmp`` table, dropped, and the copy is renamed back.
        Reach and strickler references are then rewritten using the
        ``data['id2pid']`` mappings.
        """
        table = "friction"
        reachs = data['id2pid']['river_reach']
        stricklers = data['id2pid']['stricklers']

        cls.update_db_add_pamhyr_id(execute, table, data)
        Scenario.update_db_add_scenario(execute, table)

        cls._db_create(execute, ext="_tmp")

        execute(
            f"INSERT INTO {table}_tmp " +
            "(pamhyr_id, ind, begin_rk, end_rk, reach, " +
            "begin_strickler, end_strickler, scenario) " +
            "SELECT pamhyr_id, ind, begin_rk, end_rk, reach, " +
            "begin_strickler, end_strickler, scenario " +
            f"FROM {table}"
        )

        execute(f"DROP TABLE {table}")
        execute(f"ALTER TABLE {table}_tmp RENAME TO {table}")

        cls._db_update_to_0_2_0_set_reach_pid(execute, table, reachs)
        cls._db_update_to_0_2_0_set_stricklers_pid(execute, table, stricklers)

    @classmethod
    def _db_update_to_0_2_0_set_stricklers_pid(cls, execute,
                                               table, stricklers):
        """Rewrite the strickler columns of ``table`` through ``stricklers``.

        Args:
            execute: Callable running one SQL statement.
            table: Name of the table to update.
            stricklers: Mapping indexed by the stored strickler id and
                giving the replacement value; ``-1`` is kept as ``-1``.
        """
        els = execute(
            f"SELECT pamhyr_id, begin_strickler, end_strickler FROM {table}"
        )

        for pid, b_s_id, e_s_id in els:
            b_s = stricklers[b_s_id] if b_s_id != -1 else -1
            e_s = stricklers[e_s_id] if e_s_id != -1 else -1

            execute(
                f"UPDATE {table} " +
                f"SET begin_strickler = {b_s}, " +
                f"end_strickler = {e_s} " +
                f"WHERE pamhyr_id = {pid}"
            )

    @staticmethod
    def _find_strickler(stricklers, pid):
        """Return the strickler whose ``id`` equals ``pid``, or None.

        ``-1`` is the stored marker for "no strickler".  A pid that matches
        nothing is logged and yields None instead of leaking StopIteration.
        """
        if pid == -1:
            return None

        found = next((s for s in stricklers if s.id == pid), None)
        if found is None:
            logger.warning("Friction: strickler %s not found", pid)

        return found

    @classmethod
    def _db_load(cls, execute, data=None):
        """Load the frictions of ``data["reach"]`` for ``data["scenario"]``.

        Rows whose pamhyr_id is already in ``data['loaded_pid']`` are
        skipped, every loaded pamhyr_id is added to that set, and the
        method then recurses on ``scenario.parent`` until the scenario is
        None.

        Args:
            execute: Callable running one SQL statement and returning rows.
            data: Dict providing ``scenario``, ``loaded_pid``, ``reach``,
                ``stricklers`` and ``status``.

        Returns:
            List of the ``Friction`` objects created.
        """
        new = []
        scenario = data["scenario"]
        loaded = data['loaded_pid']

        if scenario is None:
            return new

        reach = data["reach"]
        stricklers = data["stricklers"].stricklers

        table = execute(
            "SELECT pamhyr_id, deleted, begin_rk, end_rk, " +
            "begin_strickler, end_strickler, scenario " +
            "FROM friction " +
            f"WHERE reach = {reach.pamhyr_id} " +
            f"AND scenario = {scenario.id} " +
            f"AND pamhyr_id NOT IN ({', '.join(map(str, loaded))}) " +
            "ORDER BY ind ASC"
        )

        for row in table:
            (pid, deleted_flag, begin_rk, end_rk,
             b_s_pid, e_s_pid, owner_scenario) = row

            deleted = (deleted_flag == 1)

            # Each end resolves its own pid; the end lookup previously
            # reused the begin pid and its result was discarded.
            begin_strickler = cls._find_strickler(stricklers, int(b_s_pid))
            end_strickler = cls._find_strickler(stricklers, int(e_s_pid))

            new_friction = cls(
                id=pid, status=data['status'],
                owner_scenario=owner_scenario
            )
            if deleted:
                new_friction.set_as_deleted()

            # The reach is set first so the rk setters can check bounds.
            new_friction.reach = reach
            new_friction.begin_rk = begin_rk
            new_friction.end_rk = end_rk
            new_friction.begin_strickler = begin_strickler
            new_friction.end_strickler = end_strickler

            loaded.add(pid)
            new.append(new_friction)

        # Continue with the parent scenario, then restore the caller's
        # scenario even if the recursive load fails.
        data["scenario"] = scenario.parent
        try:
            new += cls._db_load(execute, data)
        finally:
            data["scenario"] = scenario

        return new

    def must_be_saved(self):
        """Tell whether this friction must be written for the scenario.

        True when the begin strickler is owned by the current scenario,
        otherwise whatever the parent class decides.  A friction without a
        begin strickler defers entirely to the parent class.
        """
        ssi = self._status.scenario_id
        strickler = self._begin_strickler

        # NOTE: the end strickler is intentionally not checked here, it
        # is not exposed by the end_strickler getter either.
        return (
            (strickler is not None and strickler._owner_scenario == ssi)
            or super(Friction, self).must_be_saved()
        )

    def _db_save(self, execute, data=None):
        """Insert this friction into the ``friction`` table.

        Args:
            execute: Callable running one SQL statement.
            data: Dict providing ``ind``, the value of the ``ind`` column.

        Returns:
            True, whether or not a row was written.
        """
        if not self.must_be_saved():
            return True

        ind = data["ind"]

        # -1 is the stored marker for a missing strickler.
        b_s_id = -1
        e_s_id = -1

        if self._begin_strickler is not None:
            b_s_id = self._begin_strickler.id
        if self._end_strickler is not None:
            e_s_id = self._end_strickler.id

        execute(
            "INSERT INTO " +
            "friction(pamhyr_id, deleted, ind, begin_rk, end_rk, " +
            "reach, begin_strickler, end_strickler, scenario) " +
            "VALUES (" +
            f"{self.id}, {self._db_format(self.is_deleted())}, " +
            f"{ind}, {self._begin_rk}, {self._end_rk}, " +
            f"{self._reach.id}, {b_s_id}, {e_s_id}, " +
            f"{self._status.scenario_id}" +
            ")"
        )

        return True

    @property
    def reach(self):
        """Reach this friction applies to, or None."""
        return self._reach

    @reach.setter
    def reach(self, reach):
        self._reach = reach

        # An interval still at its initial [0.0, 0.0] value is expanded to
        # the rk extent of the newly assigned reach.
        if (reach is not None and
                self._begin_rk == 0.0 and
                self._end_rk == 0.0):
            self._begin_rk = self._reach.reach.get_rk_min()
            self._end_rk = self._reach.reach.get_rk_max()

        self._status.modified()

    def has_reach(self):
        """Return True when a reach is set."""
        return self._reach is not None

    def has_coefficient(self):
        """Return True when a begin strickler is set."""
        return (
            self._begin_strickler is not None
        )

    def is_full_defined(self):
        """Return True when both a reach and a begin strickler are set."""
        return self.has_reach() and self.has_coefficient()

    @property
    def begin_rk(self):
        """Lower bound of the rk interval."""
        return self._begin_rk

    @begin_rk.setter
    def begin_rk(self, begin_rk):
        if self._reach is None:
            self._begin_rk = begin_rk
        else:
            # Values outside the rk extent of the reach are ignored.
            _min = self._reach.reach.get_rk_min()
            _max = self._reach.reach.get_rk_max()

            if _min <= begin_rk <= _max:
                self._begin_rk = begin_rk

        self._status.modified()

    @property
    def end_rk(self):
        """Upper bound of the rk interval."""
        return self._end_rk

    @end_rk.setter
    def end_rk(self, end_rk):
        if self._reach is None:
            self._end_rk = end_rk
        else:
            # Values outside the rk extent of the reach are ignored.
            _min = self._reach.reach.get_rk_min()
            _max = self._reach.reach.get_rk_max()

            if _min <= end_rk <= _max:
                self._end_rk = end_rk

        self._status.modified()

    def __contains__(self, rk):
        """Support ``rk in friction``; same as ``contains_rk``."""
        return self.contains_rk(rk)

    def contains_rk(self, rk):
        """Return True when ``begin_rk <= rk <= end_rk``."""
        return (
            self._begin_rk <= rk <= self._end_rk
        )

    @property
    def begin_strickler(self):
        """Strickler applied by this friction, or None."""
        return self._begin_strickler

    @begin_strickler.setter
    def begin_strickler(self, strickler):
        self._begin_strickler = strickler
        self._status.modified()

    @property
    def end_strickler(self):
        """Return the begin strickler.

        The value stored by the setter is kept in ``_end_strickler`` (and
        written by ``_db_save``) but is not exposed by this getter.
        """
        # return self._end_strickler
        return self._begin_strickler

    @end_strickler.setter
    def end_strickler(self, strickler):
        self._end_strickler = strickler
        self._status.modified()

    def get_friction(self, rk):
        """Return the ``(minor, medium)`` values of the begin strickler.

        Args:
            rk: Position to evaluate.

        Returns:
            A ``(minor, medium)`` tuple, or None when ``rk`` is outside
            ``[begin_rk, end_rk]`` or no begin strickler is set.
        """
        if not self.contains_rk(rk):
            return None

        strickler = self.begin_strickler
        if strickler is None:
            return None

        minor = strickler.minor
        medium = strickler.medium

        # NOTE: interpolation between both ends is currently disabled:
        # minor = interp(rk,
        #                [self.begin_rk, self.end_rk],
        #                [self.begin_strickler.minor,
        #                 self.end_strickler.minor])
        # medium = interp(rk,
        #                 [self.begin_rk, self.end_rk],
        #                 [self.begin_strickler.medium,
        #                  self.end_strickler.medium])

        return minor, medium