# PlotXY.py -- Pamhyr
# Copyright (C) 2023-2024  INRAE
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

# -*- coding: utf-8 -*-

from tools import timer, trace
from View.Tools.PamhyrPlot import PamhyrPlot
from matplotlib import collections
import numpy as np

from PyQt5.QtCore import (
    QCoreApplication, Qt, QItemSelectionModel,
    QItemSelection, QItemSelectionRange,
)
from PyQt5.QtWidgets import QApplication

_translate = QCoreApplication.translate


class PlotXY(PamhyrPlot):
    """XY plot of the profiles, their end-point lines and the guidelines.

    Each profile is drawn as one segment of a pickable ``LineCollection``;
    picking a profile with the mouse updates the row selection of the
    associated table.
    """

    def __init__(self, canvas=None, trad=None, data=None, toolbar=None,
                 table=None, parent=None):
        super(PlotXY, self).__init__(
            canvas=canvas,
            trad=trad,
            data=data,
            toolbar=toolbar,
            table=table,
            parent=parent
        )

        self.line_xy = []
        self.line_gl = []
        # Created up front so update_lr() never meets a missing attribute.
        self.line_lr = []

        self.label_x = self._trad["x"]
        self.label_y = self._trad["y"]

        self.before_plot_selected = None
        self.plot_selected = None
        self.after_plot_selected = None
        self.parent = parent
        self.line_xy_collection = None
        self._table = table
        self._colors = []
        self._style = []

    # ------------------------------------------------------------------
    # Picking and table selection
    # ------------------------------------------------------------------

    def _selected_rows(self):
        """Return the sorted, de-duplicated row numbers currently selected
        in ``self.parent.tableView``."""
        return sorted(
            {index.row() for index in self.parent.tableView.selectedIndexes()}
        )

    def onpick(self, event):
        """Handle a matplotlib pick event on the profile collection.

        Only left clicks inside this plot's axes, with no modifier, Ctrl
        or Shift, are handled:

        * no modifier: forward the picked profile index to
          ``parent.select_row_profile_slider``;
        * Ctrl: toggle the picked profile in the table selection;
        * Shift: select the contiguous row range covering the current
          selection and the picked profile.
        """
        mouse = event.mouseevent
        if mouse.inaxes != self.canvas.axes:
            return
        if mouse.button.value != 1:
            return

        modifiers = QApplication.keyboardModifiers()
        if modifiers not in [Qt.ControlModifier,
                             Qt.NoModifier,
                             Qt.ShiftModifier]:
            return

        if self._table is None:
            return

        ind, _point = self._closest_section(event)

        self._table.blockSignals(True)
        try:
            if modifiers == Qt.ControlModifier:
                rows = self._selected_rows()
                if ind in rows:
                    rows.remove(ind)
                    self._select_in_table(rows)
                else:
                    self._select_in_table(rows + [ind])
            elif modifiers == Qt.ShiftModifier:
                # The range must span the whole current selection plus the
                # picked row, so take the true min/max of all of them.
                rows = self._selected_rows() + [ind]
                self._select_range_in_table(min(rows), max(rows))
            else:
                self.parent.select_row_profile_slider(ind)
        finally:
            # Always re-enable the table's signals, even if a handler raised.
            self._table.blockSignals(False)

    def _closest_section(self, event):
        """Return ``(profile_index, vertex)`` of the picked vertex nearest
        to the mouse.

        Only the segments listed in ``event.ind`` are considered. The x
        offset is divided by the ratio of the visible x span to the visible
        y span, so both axes weigh in proportion to their current limits.
        """
        axes = self.canvas.axes
        mx = event.mouseevent.xdata
        my = event.mouseevent.ydata
        bx, by = axes.get_xlim(), axes.get_ylim()
        ratio = (bx[0] - bx[1]) / (by[0] - by[1])

        segments = event.artist.get_segments()

        def dist_mouse(candidate):
            x, y = candidate[1]
            return (((mx - x) / ratio) ** 2) + ((my - y) ** 2)

        # A generator avoids the quadratic cost of repeated list concatenation.
        ind, point = min(
            ((i, vertex) for i in event.ind for vertex in segments[i]),
            key=dist_mouse,
        )
        # event.ind holds numpy integers; hand a plain int to Qt and callers.
        return int(ind), point

    def _select_in_table(self, ind):
        """Replace the table selection with the rows in ``ind`` and scroll
        to the last of them.

        :param ind: list of row numbers to select.
        """
        if self._table is None:
            return

        self._table.setFocus()
        if len(ind) == 0:
            # NOTE(review): an empty list leaves the current selection
            # untouched, so Ctrl-clicking the only selected profile does not
            # deselect it -- confirm this is intended.
            return

        model = self._table.model()
        selection = QItemSelection()
        for row in ind:
            selection.append(QItemSelectionRange(model.index(row, 0)))

        # ClearAndSelect already includes Select.
        self._table.selectionModel().select(
            selection,
            QItemSelectionModel.Rows | QItemSelectionModel.ClearAndSelect
        )
        self._table.scrollTo(model.index(ind[-1], 0))

    def _select_range_in_table(self, ind1, ind2):
        """Replace the table selection with rows ``ind1`` to ``ind2``
        (inclusive) and scroll to ``ind2``."""
        if self._table is None:
            return

        self._table.setFocus()
        model = self._table.model()
        selection = QItemSelection(model.index(ind1, 0), model.index(ind2, 0))
        self._table.selectionModel().select(
            selection,
            QItemSelectionModel.Rows | QItemSelectionModel.ClearAndSelect
        )
        self._table.scrollTo(model.index(ind2, 0))

    # ------------------------------------------------------------------
    # Drawing
    # ------------------------------------------------------------------

    @timer
    def draw(self):
        """Reset the axes and draw profiles, end-point lines and guidelines.

        ``self._init`` becomes True only after a full draw, so ``update()``
        falls back to ``draw()`` until profiles have actually been drawn.
        """
        self.init_axes()

        if self.data is None:
            self.idle()
            return

        if self.data.number_profiles == 0:
            self._init = False
            self.idle()
            return

        self.draw_xy()
        self.draw_lr()
        self.draw_gl()

        self.idle()
        self._init = True

    def draw_xy(self):
        """Draw every profile as one segment of a pickable LineCollection,
        coloured according to the current table selection."""
        self.line_xy = [
            np.column_stack(xy)
            for xy in zip(self.data.get_x(), self.data.get_y())
        ]

        self._colors, self._style = self.color_hightlight()
        self.line_xy_collection = collections.LineCollection(
            self.line_xy,
            colors=self._colors,
            linestyle=self._style,
            picker=10,  # pick tolerance, in points
        )
        self.canvas.axes.add_collection(self.line_xy_collection)

    def color_hightlight(self):
        """Return ``(colors, styles)`` lists with one entry per profile.

        Selected profiles get ``color_plot_current``. The profile just
        before the first selected one gets ``color_plot_previous`` and the
        profile just after the last selected one gets ``color_plot_next``;
        both of these are dashed. Every other profile uses ``color_plot``
        with a solid line.
        """
        rows = self._selected_rows()
        count = len(self._data)
        colors = [self.color_plot] * count
        style = ["-"] * count

        if rows:
            for row in rows:
                colors[row] = self.color_plot_current

            first, last = rows[0], rows[-1]
            if first > 0:
                colors[first - 1] = self.color_plot_previous
                style[first - 1] = "--"
            if last < count - 1:
                colors[last + 1] = self.color_plot_next
                style[last + 1] = "--"

        return colors, style

    def draw_lr(self):
        """Draw two dotted lines, one joining the first point of every
        profile and one joining the last point, and store them in
        ``self.line_lr``."""
        lx, ly, rx, ry = [], [], [], []
        for x, y in zip(self.data.get_x(), self.data.get_y()):
            # A profile without points has no end points to join.
            if len(x) == 0 or len(y) == 0:
                continue
            lx.append(x[0])
            ly.append(y[0])
            rx.append(x[-1])
            ry.append(y[-1])

        self.line_lr = [
            self.canvas.axes.plot(
                px, py,
                color=self.color_plot_river_bottom,
                linestyle="dotted",
                lw=1.,
            )
            for px, py in ((lx, ly), (rx, ry))
        ]

    def _plot_guideline(self, ind, x, y):
        """Plot guideline number ``ind`` and return the list produced by
        ``axes.plot``.

        Colours cycle through ``self.colors``. The line style advances each
        time the colours wrap around, and it wraps around itself once
        ``self.linestyle`` is exhausted instead of raising IndexError.
        """
        ncolors = len(self.colors)
        return self.canvas.axes.plot(
            x, y,
            color=self.colors[ind % ncolors],
            linestyle=self.linestyle[(ind // ncolors) % len(self.linestyle)],
        )

    def draw_gl(self):
        """Draw every guideline provided by the data object."""
        x_complete = self.data.get_guidelines_x()
        y_complete = self.data.get_guidelines_y()

        self.line_gl = [
            self._plot_guideline(ind, x, y)
            for ind, (x, y) in enumerate(zip(x_complete, y_complete))
        ]

    # ------------------------------------------------------------------
    # Incremental update
    # ------------------------------------------------------------------

    @timer
    def update(self):
        """Refresh the plot in place, or do a full ``draw()`` if nothing
        has been drawn yet.

        NOTE(review): the profile geometry (the ``line_xy_collection``
        segments) is not refreshed here; only its colours and styles are,
        via ``update_current()``. Confirm that geometry edits go through
        ``draw()``.
        """
        if not self._init:
            self.draw()
            return

        if self.data is None:
            return

        self.update_lr()
        self.update_gl()
        self.update_current()

        self.update_idle()

    def update_gl(self):
        """Recompute the guidelines and synchronise the drawn lines with them.

        Existing lines are moved in place. New guidelines are plotted with
        the same colour/style scheme as ``draw_gl``. Lines whose guideline
        no longer exists are removed from the axes.
        """
        self.data.compute_guidelines()
        guidelines = list(zip(self.data.get_guidelines_x(),
                              self.data.get_guidelines_y()))

        for ind, (x, y) in enumerate(guidelines):
            if ind < len(self.line_gl):
                self.line_gl[ind][0].set_data(x, y)
            else:
                self.line_gl.append(self._plot_guideline(ind, x, y))

        # Otherwise lines of vanished guidelines would stay drawn with stale
        # coordinates.
        for line in self.line_gl[len(guidelines):]:
            line[0].remove()
        del self.line_gl[len(guidelines):]

    def update_current(self):
        """Re-apply the selection highlight colours and styles to the
        profiles when ``self._current_data_update`` is set."""
        if not self._current_data_update or self.line_xy_collection is None:
            return

        self._colors, self._style = self.color_hightlight()
        self.line_xy_collection.set_colors(self._colors)
        self.line_xy_collection.set_linestyle(self._style)

    def update_lr(self):
        """Remove the end-point lines from the axes and draw them again."""
        for line in self.line_lr:
            line[0].remove()
        self.draw_lr()