Pamhyr2/src/View/Results/PlotSedReach.py

284 lines
7.8 KiB
Python

# PlotSedReach.py -- Pamhyr
# Copyright (C) 2024 INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# -*- coding: utf-8 -*-
import logging
from functools import reduce
from tools import timer
from View.Tools.PamhyrPlot import PamhyrPlot
from PyQt5.QtCore import (
QCoreApplication
)
# Module-level logger; no name is passed, so this is the root logger.
logger = logging.getLogger()

# Short alias of Qt's translation function, used for the axis labels.
_translate = QCoreApplication.translate
class PlotSedReach(PamhyrPlot):
    """Plot of the sediment-layer interfaces along one reach.

    For the selected reach, one line per interface is drawn against the
    profiles' ``rk`` positions. The uppermost line is drawn solid grey;
    the deeper ones are dashed and use the default color cycle.
    """

    def __init__(self, canvas=None, trad=None, toolbar=None,
                 results=None, reach_id=0, profile_id=0,
                 parent=None):
        """Create the plot.

        :param canvas, trad, toolbar, parent: forwarded to ``PamhyrPlot``.
        :param results: simulation results, or None. When given, they
            must provide ``get("timestamps")`` and ``river.reach(id)``.
        :param reach_id: index of the reach to display.
        :param profile_id: index of the current profile.
        """
        super().__init__(
            canvas=canvas,
            trad=trad,
            data=results,
            toolbar=toolbar,
            parent=parent
        )

        self._results = results
        # Display the last timestamp by default. This is None when there
        # are no results; draw() returns early in that case.
        self._current_timestamp = self._last_timestamp(results)
        self._current_reach_id = reach_id
        self._current_profile_id = profile_id

    @staticmethod
    def _last_timestamp(results):
        """Return the latest timestamp of ``results``, or None without results."""
        if results is None:
            return None
        return max(results.get("timestamps"))

    @property
    def results(self):
        return self.data

    @results.setter
    def results(self, results):
        self.data = results
        self._results = results
        self._current_timestamp = self._last_timestamp(results)

    # DEPRECATED version of sediment layer display
    # def _get_zsl(self, reach):
    #     rk = reach.geometry.get_rk()
    #     z_min = reach.geometry.get_z_min()
    #     z_max = reach.geometry.get_z_max()
    #     profiles_sl = list(
    #         map(
    #             # Get SL list for profile p
    #             lambda p: p.get_ts_key(self._current_timestamp, "sl"),
    #             reach.profiles
    #         )
    #     )
    #     max_sl_num = reduce(
    #         lambda acc, sl: max(acc, len(sl)),
    #         profiles_sl,
    #         0
    #     )
    #     sl = []
    #     for i in range(max_sl_num):
    #         cur = []
    #         for profile_sl in profiles_sl:
    #             if i < len(profile_sl):
    #                 cur.append(profile_sl[i][0])
    #             else:
    #                 cur.append(0)
    #         sl.append(cur)
    #     self.canvas.axes.set_xlim(
    #         left = min(rk) - 10, right = max(rk) + 10
    #     )
    #     # Dummy layer with height = 0
    #     f = list(map(lambda p: 0, reach.profiles))
    #     # We compute Z sediment layer in reverse order, from last layer to
    #     # fake river bottom
    #     r_sl = list(reversed(sl))
    #     z_sl = reduce(
    #         lambda acc, v: acc + [
    #             list(
    #                 map(lambda x, y: y + x, v, acc[-1])
    #             )
    #         ],
    #         r_sl,
    #         [f]
    #     )
    #     # We normalize Z coordinate to 0 (the maximum must be 0)
    #     f_z_max = max(z_sl[-1])
    #     z_sl = list(
    #         map(
    #             lambda p: list(map(lambda z: z - f_z_max, p)),
    #             z_sl
    #         )
    #     )
    #     # We apply the river geometry bottom height at each layer to
    #     # find the new river geometry
    #     z_sl = list(
    #         map(
    #             lambda sl: list(
    #                 map(lambda z, m: z + m, sl, z_min)
    #             ),
    #             z_sl
    #         )
    #     )
    #     return z_sl

    def get_zsl(self, reach):
        """Compute the elevation of every sediment-layer interface of ``reach``.

        The layers are first stacked downward from the profiles' minimum
        elevation (``geometry.get_z_min()``) using the thicknesses at time
        0.0. Each interface is then shifted by the thickness change of its
        layer between time 0.0 and the current timestamp.

        :param reach: result reach. Its ``profiles`` provide the ``"sl"``
            key through ``get_ts_key``, and its ``geometry`` provides
            ``get_z_min``.
        :return: list of lines, each holding one elevation per profile,
            ordered from the deepest interface to the uppermost one.
        """
        z_min = reach.geometry.get_z_min()

        # Layers of each profile: first element of the "sl" value at time
        # 0.0 (initial data) and at the current timestamp.
        # NOTE(review): assumes 0.0 is a timestamp stored in the results.
        profiles_sl_0 = [
            p.get_ts_key(0.0, "sl")[0] for p in reach.profiles
        ]
        profiles_sl = [
            p.get_ts_key(self._current_timestamp, "sl")[0]
            for p in reach.profiles
        ]

        max_sl_num = max((len(psl) for psl in profiles_sl), default=0)

        # Transpose to one row per layer index, holding the first component
        # of each profile's layer entry (used as the layer thickness below).
        # Missing layers count as 0.
        sl = []
        sl_0 = []
        for i in range(max_sl_num):
            cur = []
            cur_0 = []
            for profile_sl, profile_sl_0 in zip(profiles_sl, profiles_sl_0):
                if i < len(profile_sl_0):
                    cur_0.append(profile_sl_0[i][0])
                    # The current data may have fewer layers than the
                    # initial data for this profile. Treat the missing
                    # layer as zero-thickness instead of raising IndexError.
                    cur.append(
                        profile_sl[i][0] if i < len(profile_sl) else 0
                    )
                else:
                    cur.append(0)
                    cur_0.append(0)
            sl.append(cur)
            sl_0.append(cur_0)

        # Interfaces from the initial data, stacked downward from z_min:
        # z_sl[k + 1] = z_sl[k] - initial thickness of layer k.
        z_sl = [z_min]
        for layer_0 in sl_0:
            z_sl.append([z - th for th, z in zip(layer_0, z_sl[-1])])

        # Thickness change of each layer between the initial data and the
        # current timestamp (positive when the layer got thinner).
        d_sl = [
            [z0 - zi for z0, zi in zip(ln0, lni)]
            for ln0, lni in zip(sl_0, sl)
        ]

        # Shift each interface by the change of its layer. The deepest
        # interface has no layer of its own, so it is paired with zeros
        # (without them, zip would drop it).
        zeros = [0 for _ in reach.profiles]
        z_sl = [
            [zn - dn for zn, dn in zip(z, d)]
            for z, d in zip(z_sl, d_sl + [zeros])
        ]

        return z_sl[::-1]

    @timer
    def draw(self):
        """Clear the axes and redraw the current reach's sediment layers.

        Only the cleared grid remains when there are no results or when
        the current reach has no profile.
        """
        self.canvas.axes.cla()
        self.canvas.axes.grid(color='grey', linestyle='--', linewidth=0.5)

        if self.data is None:
            return

        reach = self._results.river.reach(self._current_reach_id)
        if reach.geometry.number_profiles == 0:
            return

        self.canvas.axes.set_xlabel(
            _translate("MainWindow_reach", "Position (m)"),
            color='black', fontsize=10
        )
        self.canvas.axes.set_ylabel(
            _translate("MainWindow_reach", "Height (m)"),
            color='black', fontsize=10
        )

        rk = reach.geometry.get_rk()
        z_sl = self.get_zsl(reach)

        # Draw: get_zsl returns the uppermost interface last. It is drawn
        # solid grey; color=None lets matplotlib use its color cycle for
        # the others.
        last = len(z_sl) - 1
        self.line_rk_sl = []
        for i, z in enumerate(z_sl):
            line, = self.canvas.axes.plot(
                rk, z,
                linestyle="solid" if i == last else "--",
                lw=1.8,
                color='grey' if i == last else None
            )
            self.line_rk_sl.append(line)

        self.canvas.figure.tight_layout()
        self.canvas.figure.canvas.draw_idle()
        if self.toolbar is not None:
            self.toolbar.update()

        self._init = False

    @timer
    def update(self, ind=None):
        """Redraw the plot when ``_init`` is falsy.

        NOTE(review): draw() resets ``_init`` to False, so as written
        every call after a draw redraws the whole plot.
        """
        if not self._init:
            self.draw()

    def set_reach(self, reach_id):
        """Select the reach ``reach_id``, reset the profile to 0, and redraw."""
        self._current_reach_id = reach_id
        self._current_profile_id = 0
        self.draw()

    def set_profile(self, profile_id):
        """Select the profile ``profile_id`` and redraw."""
        self._current_profile_id = profile_id
        self.draw()

    def set_timestamp(self, timestamp):
        """Display the results at ``timestamp`` and redraw."""
        self._current_timestamp = timestamp
        self.draw()