Pamhyr2/src/View/Results/CustomPlot/Plot.py

408 lines
11 KiB
Python

# Plot.py -- Pamhyr
# Copyright (C) 2023-2024 INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
# -*- coding: utf-8 -*-
import logging
from datetime import datetime, timedelta
from functools import reduce

from tools import timer
from View.Tools.PamhyrPlot import PamhyrPlot
from View.Results.CustomPlot.Translate import CustomPlotTranslate
logger = logging.getLogger()

# Unit (y axis) of every variable that can be plotted.
# Several variables may share a unit (and therefore an axis). The numeric
# prefix makes a plain sort put the units in display order: the plot's
# draw code labels the main axes with the first sorted unit and gives each
# following unit its own twin axes.
unit = dict(
    elevation="0-meter",
    water_elevation="0-meter",
    discharge="1-m3s",
    velocity="2-ms",
)
class CustomPlot(PamhyrPlot):
    """Results plot of user-selected variables against rk or time.

    With ``x == "rk"`` the variables are drawn along one reach at a given
    timestamp; with ``x == "time"`` they are drawn over time at one
    profile of the reach. Each variable is drawn on the y axis of its unit
    (module-level ``unit`` table): the first unit in sorted order uses the
    canvas' main axes, every other unit gets its own ``twinx`` axes.
    """

    def __init__(self, x, y, reach, profile, timestamp,
                 data=None, canvas=None, trad=None,
                 toolbar=None, parent=None):
        """Create the plot.

        :param x: x axis kind, ``"rk"`` or ``"time"``
        :param y: names of the variables to plot (keys of ``unit``)
        :param reach: reach id, passed to ``results.river.reach``
        :param profile: profile id, passed to ``reach.profile``
        :param timestamp: timestamp used by the rk plot
        :param data: results object to plot (nothing is drawn while None)
        :param trad: NOTE(review): ignored, this plot always uses its own
            ``CustomPlotTranslate`` — confirm this is intended
        """
        super(CustomPlot, self).__init__(
            canvas=canvas,
            trad=CustomPlotTranslate(),
            data=data,
            toolbar=toolbar,
            parent=parent
        )

        self._x = x
        self._y = y
        self._reach = reach
        self._profile = profile
        self._timestamp = timestamp

        logger.debug(
            "Create custom plot for: " +
            f"{x} -> {','.join(y)}: " +
            f"reach={reach}, profile={profile}, " +
            f"timestamp={timestamp}"
        )

        # Distinct units needed by the selected variables, in display
        # order (the numeric prefix of the unit keys drives the sort).
        self._y_axes = sorted({unit[name] for name in self._y})
        # Twin axes created by draw(), keyed by unit; the first unit of
        # self._y_axes uses the main axes and never appears here.
        self._axes = {}

    def _axes_for(self, name):
        """Return the axes on which the variable ``name`` is drawn.

        Mirrors ``draw``: the first unit of ``self._y_axes`` is drawn on
        the canvas' main axes, the others on their twin axes.
        """
        unit_key = unit[name]
        if unit_key == self._y_axes[0]:
            return self.canvas.axes
        return self._axes[unit_key]

    def _shift_twin_spines(self):
        """Offset the right spine of each twin axes so they don't overlap."""
        for index, unit_key in enumerate(self._y_axes[1:]):
            self._axes[unit_key].spines['right'].set_position(
                ('outward', 60 * index)
            )

    def _draw_legend(self, lines):
        """Draw a single legend, on the main axes, for all plotted lines.

        :param lines: variable name -> list of artists returned by ``plot``
        """
        handles = [artist for name in lines for artist in lines[name]]
        labels = [self._trad[name] for name in lines]
        self.canvas.axes.legend(handles, labels, loc="best")

    def _draw_rk(self):
        """Plot the selected variables along the reach at the timestamp."""
        results = self.data
        reach = results.river.reach(self._reach)
        rk = reach.geometry.get_rk()
        z_min = reach.geometry.get_z_min()

        q = [p.get_ts_key(self._timestamp, "Q") for p in reach.profiles]
        z = [p.get_ts_key(self._timestamp, "Z") for p in reach.profiles]

        self._shift_twin_spines()

        lines = {}
        if "elevation" in self._y:
            lines["elevation"] = self._axes_for("elevation").plot(
                rk, z_min,
                color='grey', lw=1.,
            )

        if "water_elevation" in self._y:
            meter_axes = self._axes_for("water_elevation")
            lines["water_elevation"] = meter_axes.plot(
                rk, z, lw=1.,
                color='blue',
            )

            if "elevation" in self._y:
                # Shade the water between the bed and the water surface
                meter_axes.fill_between(
                    rk, z_min, z,
                    color='blue', alpha=0.5, interpolate=True
                )

        if "discharge" in self._y:
            lines["discharge"] = self._axes_for("discharge").plot(
                rk, q, lw=1.,
                color='r',
            )

        if "velocity" in self._y:
            v = [
                p.geometry.speed(q_p, z_p)
                for p, q_p, z_p in zip(reach.profiles, q, z)
            ]
            lines["velocity"] = self._axes_for("velocity").plot(
                rk, v, lw=1.,
                color='g',
            )

        self._draw_legend(lines)

    def _customize_x_axes_time(self, ts, mode="time"):
        """Put about five rotated tick labels on the time x axis.

        :param ts: sorted timestamps (seconds) plotted on the x axis
        :param mode: ``"time"`` labels ticks with the elapsed duration
            (only the days part once it reaches a day); any other value
            labels them with the local calendar date of the timestamp
        """
        nb = len(ts)
        # Keep one tick every `step` values (every value for short series)
        step = nb // 5 or nb
        fx = [t for i, t in enumerate(ts) if i % step == 0]

        if mode == "time":
            # A plain timedelta is independent of the local time zone
            # (no DST shift) and also accepts negative durations.
            xt = [
                str(timedelta(seconds=float(v))).split(",")[0]
                .replace("days", self._trad["days"])
                .replace("day", self._trad["day"])
                for v in fx
            ]
        else:
            xt = [str(datetime.fromtimestamp(v).date()) for v in fx]

        self.canvas.axes.set_xticks(ticks=fx, labels=xt, rotation=45)

    def _draw_time(self):
        """Plot the selected variables over time at the selected profile."""
        results = self.data
        reach = results.river.reach(self._reach)
        profile = reach.profile(self._profile)

        self._shift_twin_spines()

        ts = sorted(results.get("timestamps"))
        q = profile.get_key("Q")
        z = profile.get_key("Z")

        lines = {}
        if "elevation" in self._y:
            # Z min is constant in time: repeat it for every timestamp
            ts_z_min = [profile.geometry.z_min()] * len(ts)
            lines["elevation"] = self._axes_for("elevation").plot(
                ts, ts_z_min,
                color='grey', lw=1.
            )

        if "water_elevation" in self._y:
            meter_axes = self._axes_for("water_elevation")
            lines["water_elevation"] = meter_axes.plot(
                ts, z, lw=1.,
                color='b',
            )

            if "elevation" in self._y:
                # Shade the water between the bed and the water surface
                meter_axes.fill_between(
                    ts, ts_z_min, z,
                    color='blue', alpha=0.5, interpolate=True
                )

        if "discharge" in self._y:
            lines["discharge"] = self._axes_for("discharge").plot(
                ts, q, lw=1.,
                color='r',
            )

        if "velocity" in self._y:
            v = [profile.geometry.speed(q_t, z_t) for q_t, z_t in zip(q, z)]
            lines["velocity"] = self._axes_for("velocity").plot(
                ts, v, lw=1.,
                color='g',
            )

        self._customize_x_axes_time(ts)
        self._draw_legend(lines)

    @timer
    def draw(self):
        """(Re)draw the whole plot.

        Clears the main axes, creates (first call) or clears (later calls)
        one twin axes per extra unit, then dispatches on ``self._x``.
        Nothing is plotted while ``self.data`` is None.
        """
        self.canvas.axes.cla()
        self.canvas.axes.grid(color='grey', linestyle='--', linewidth=0.5)

        if self.data is None:
            return

        self.canvas.axes.set_xlabel(
            self._trad[self._x],
            color='black', fontsize=10
        )
        if self._y_axes:
            self.canvas.axes.set_ylabel(
                self._trad[self._y_axes[0]],
                color='black', fontsize=10
            )

        for unit_key in self._y_axes[1:]:
            ax = self._axes.get(unit_key)
            if ax is None:
                ax = self.canvas.axes.twinx()
                self._axes[unit_key] = ax
            else:
                # clear() wipes the y label (and may reset the axis side):
                # restore the right-side layout that twinx() set up.
                ax.clear()
                ax.yaxis.tick_right()
                ax.yaxis.set_label_position('right')

            ax.set_ylabel(
                self._trad[unit_key],
                color='black', fontsize=10
            )

        if self._x == "rk":
            self._draw_rk()
        elif self._x == "time":
            self._draw_time()

        self.canvas.figure.canvas.draw_idle()
        if self.toolbar is not None:
            self.toolbar.update()

    @timer
    def update(self):
        """Draw the plot while ``self._init`` is falsy.

        NOTE(review): ``self._init`` is not assigned in this class; once it
        is truthy this method does nothing — confirm that is intended.
        """
        if not self._init:
            self.draw()

    def set_reach(self, reach_id):
        """Select another reach (and its first profile), then update."""
        self._reach = reach_id
        self._profile = 0
        self.update()

    def set_profile(self, profile_id):
        """Select another profile; only the time plot depends on it."""
        self._profile = profile_id
        if self._x != "rk":
            self.update()

    def set_timestamp(self, timestamp):
        """Select another timestamp; the time plot does not depend on it."""
        self._timestamp = timestamp
        if self._x != "time":
            self.update()