Source code for pyswapp.utils.interactive

from contextlib import suppress

import warnings
import numbers

import matplotlib.pyplot as plt
import matplotlib.path as mpltPath
from matplotlib.backend_bases import MouseEvent

from .utils import *

from .physics import lorentzian_err, wavenumber, phase_velocity

import numpy as np
from collections.abc import Iterable

class DraggablePoints:
    """Editable set of points on a Matplotlib ``Axes``.

    Left click adds a point, or grabs the nearest existing point so it can be
    dragged; right click removes the nearest point. Points are kept in a
    ``{x: y}`` dict and drawn as one line sorted by x.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; its figure canvas receives the mouse events.
    points : dict, optional
        Initial ``{x: y}`` points.
    **kwargs
        ``color``, ``linestyle``, ``linewidth``, ``markersize`` and ``marker``
        style the line. Remaining keyword arguments are stored in
        ``self._kwargs`` for subclasses.

    Raises
    ------
    ValueError
        If ``ax`` is None.
    """

    # Maximum cursor-to-point distance, in display pixels, for a point to be
    # grabbed. Unlike ``distance_threshold`` (a squared distance in data units)
    # it does not depend on the axis scale. Set to 0/None to disable it.
    pixel_threshold = 5.0

    def __init__(self, ax, points=None, **kwargs):
        if ax is None:
            raise ValueError("Argument 'ax' must be a valid Matplotlib Axes instance.")
        self.ax = ax
        self.canvas = self.ax.figure.canvas
        # Hit radius as a squared distance in data units (see _find_neighbor_point).
        self.distance_threshold = 0.1
        self._dragging_point = None
        self._points = {}  # Dict: x -> y
        self._line = None
        self.logger = create_logging(name='Interactive')

        # Extract plotting kwargs; whatever is left is kept for subclasses.
        self._plot_kwargs = {
            'color': kwargs.pop('color', 'r'),
            'linestyle': kwargs.pop('linestyle', '--'),
            'linewidth': kwargs.pop('linewidth', 1),
            'markersize': kwargs.pop('markersize', 7),
            'marker': kwargs.pop('marker', 'x'),
        }
        self._kwargs = kwargs

        # Initialize points if provided
        if points:
            try:
                for x, y in points.items():
                    self._points[x] = y
            except Exception as e:
                self.logger.warning(f"Failed to load initial points: {e}")
            # Best effort: drop the most recently drawn line before redrawing.
            # NOTE(review): presumably the caller's earlier plot of these
            # points — confirm.
            try:
                if self.ax.lines:
                    self.ax.lines[-1].remove()
            except Exception:
                pass
            self._update_plot()
        self._init_plot()

    def _init_plot(self):
        """Connect matplotlib events."""
        self.canvas.mpl_connect('button_press_event', self._on_click)
        self.canvas.mpl_connect('button_release_event', self._on_release)
        self.canvas.mpl_connect('motion_notify_event', self._on_motion)

    def _update_plot(self, kwargs=None):
        """Redraw the points as a single line sorted by x.

        ``kwargs`` is accepted for signature compatibility and is unused.
        """
        if self._line is not None and self._line in self.ax.lines:
            self._line.remove()
        self._line = None
        try:
            if self._points:
                x, y = zip(*sorted(self._points.items()))
                self._line, = self.ax.plot(x, y, **self._plot_kwargs)
        except Exception as e:
            self.logger.warning(f"Failed to update plot: {e}")
        self.canvas.draw_idle()

    def _add_point(self, x, y=None):
        """Store ``y`` under key ``x`` and return ``(x, y)``."""
        self._points[x] = y
        return x, y

    def _remove_point(self, x, _):
        """Remove a point by its x coordinate (the second argument is ignored)."""
        if x in self._points:
            self._points.pop(x)

    def _find_neighbor_point(self, event):
        """Return the stored point closest to the cursor, or ``None``.

        A point is a hit when its squared data-space distance to the cursor is
        below ``distance_threshold`` or, when the event carries display
        coordinates (``event.x``/``event.y``), when it lies within
        ``pixel_threshold`` display pixels. The pixel test keeps points
        grabbable on axes whose data range is large. For points whose y value
        is a sequence, its first element is used.

        Returns
        -------
        tuple or None
            ``(x, y)`` of the hit point.
        """
        if not self._points or event.xdata is None or event.ydata is None:
            return None

        candidates = []
        for x, y in self._points.items():
            if isinstance(y, Iterable):
                try:
                    y = y[0]
                except (TypeError, IndexError, KeyError):
                    continue
            candidates.append((x, y))
        if not candidates:
            return None

        try:
            xy = np.asarray(candidates, dtype=float)
        except (TypeError, ValueError):
            # Non-numeric point data cannot be hit-tested.
            return None

        cursor = np.array([event.xdata, event.ydata], dtype=float)
        d2 = np.sum((xy - cursor) ** 2, axis=1)
        nearest = int(np.argmin(d2))
        if d2[nearest] < self.distance_threshold:
            return candidates[nearest]

        # Fall back to a scale-independent test in display (pixel) space.
        ex = getattr(event, 'x', None)
        ey = getattr(event, 'y', None)
        if not self.pixel_threshold or ex is None or ey is None:
            return None
        try:
            disp = np.asarray(self.ax.transData.transform(xy), dtype=float)
        except Exception:
            return None
        pd2 = np.sum((disp - np.array([ex, ey], dtype=float)) ** 2, axis=1)
        nearest = int(np.argmin(pd2))
        if pd2[nearest] <= self.pixel_threshold ** 2:
            return candidates[nearest]
        return None

    def _on_click(self, event):
        """Left click grabs the nearest point or adds one; right click removes one."""
        if event.inaxes != self.ax:
            return
        if event.button == 1:
            point = self._find_neighbor_point(event)
            if point is not None:
                self._dragging_point = point
            elif event.xdata is not None and event.ydata is not None:
                self._add_point(event.xdata, event.ydata)
                self._update_plot()
        elif event.button == 3:
            point = self._find_neighbor_point(event)
            if point is not None:
                self._remove_point(*point)
                self._update_plot()

    def _on_release(self, event):
        """Finish a left-button drag."""
        if event.button == 1 and self._dragging_point:
            self._dragging_point = None
            self._update_plot()

    def _on_motion(self, event):
        """Move the dragged point to the cursor position."""
        if not self._dragging_point or event.inaxes != self.ax:
            return
        if event.xdata is None or event.ydata is None:
            return
        self._remove_point(*self._dragging_point)
        self._dragging_point = self._add_point(event.xdata, event.ydata)
        self._update_plot()

    @property
    def points(self):
        """The ``{x: y}`` dict of current points."""
        return self._points
class TXInteractive(DraggablePoints):
    """Two-point picker on a gather plotted as trace index (x) vs. pick value (y).

    Clicked x positions snap to the nearest trace index and at most two picks
    are kept. Keyboard shortcuts:

    * ``enter`` - close the figure.
    * ``e`` - set :attr:`terminate` and close the figure.
    * ``r`` - remove the velocity label, set :attr:`reset` and close.
    * ``ctrl+z`` - set :attr:`reset_last` and close.
    * ``t`` / ``b`` - select top / bottom mute (see :attr:`key`).
    * ``v`` - label the apparent velocity between the two picks.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
    receiver : sequence, optional
        Receiver positions indexed by trace number; defaults to
        ``data.receiver``. (Presumably metres, given the "m/s" label — confirm.)
    points : dict, optional
        Initial ``{trace_index: y}`` picks.
    data : object, optional
        Must provide ``receiver`` when ``receiver`` is not given.
    picks : optional
        Unused; kept for signature compatibility.
    **kwargs
        Line style options forwarded to ``DraggablePoints``.
    """

    def __init__(self, ax, receiver=None, points=None, data=None, picks=None, **kwargs):
        super().__init__(ax, points, **kwargs)
        self.data = data
        if receiver is None:
            receiver = self.data.receiver
        self.receiver = receiver
        # Trace indices that picked x positions snap to.
        self.x = np.arange(0, len(self.receiver), 1)
        self._init_param()
        self.canvas.mpl_connect('key_press_event', self._on_key)

    def _init_param(self):
        """Reset mode flags and velocity-estimation state."""
        self._key = 't'
        self._terminate = False
        self._reset = False
        self._reset_last = False
        self.slope = None
        self.tbox = None
        self.plot_slope = None

    def _add_point(self, x, y=None):
        """Add a pick, snapping ``x`` to the nearest trace index.

        At most two picks are kept: when two already exist, the left one is
        replaced if ``x`` is at or left of it, otherwise the right one.

        Returns
        -------
        tuple
            The stored ``(x, y)``.
        """
        # Nearest trace (ties go to the lower index). A searchsorted lookup
        # would return the insertion index, i.e. always round up.
        x = self.x[int(np.argmin(np.abs(self.x - x)))]
        if len(self._points) == 2:
            left, right = sorted(self._points)
            self._points.pop(left if x <= left else right)
        self._points[x] = y
        return x, y

    def _estimate_velocity(self):
        """Label the apparent velocity between the two picks.

        Sets :attr:`slope` (difference in y per receiver distance) and places a
        text box showing ``1 / slope`` at the mean of the picks. Does nothing
        unless exactly two picks and receiver positions exist; logs a warning
        instead of dividing by zero when the picks share a receiver position
        or a y value.
        """
        if len(self._points) != 2 or self.receiver is None:
            return
        idx, t = zip(*sorted(self._points.items()))
        dx = self.receiver[int(idx[1])] - self.receiver[int(idx[0])]
        dt = t[1] - t[0]
        if dx == 0 or dt == 0:
            self.logger.warning(
                "Cannot estimate velocity: picks share the same receiver position or time."
            )
            return
        self.slope = dt / dx
        props = dict(boxstyle='round', facecolor='white', alpha=0.5)
        self.tbox = self.ax.text(np.mean(idx), np.mean(t), f'v={round(1/self.slope,2)} m/s', bbox=props)

    def _remove_tbox(self):
        """Remove the velocity label, tolerating one that is already gone."""
        if self.tbox is not None:
            # remove() raises ValueError when the artist was already removed.
            with suppress(ValueError, NotImplementedError):
                self.tbox.remove()
            self.tbox = None

    def _on_key(self, event):
        """Handle the keyboard shortcuts listed in the class docstring."""
        key = event.key
        if key == 'enter':
            plt.close()
        elif key == 'e':
            self._terminate = True
            plt.close()
        elif key == 'r':
            self._remove_tbox()
            self._reset = True
            plt.close()
        elif key == 'ctrl+z':
            self._reset_last = True
            plt.close()
        elif key in ('t', 'b'):
            self._remove_tbox()
            self._key = key
            title = 'Top mute active' if key == 't' else 'Bottom mute active'
            self.ax.set_title(title, fontweight='bold')
            self.canvas.draw_idle()
        elif key == 'v':
            self._remove_tbox()
            self._estimate_velocity()
            self.ax.set_title('Velocity estimation', fontweight='bold')
            self.canvas.draw_idle()

    def update_plot(self):
        """Clear all picks and redraw, if any picks exist."""
        if self._points:
            # reset plot
            self._points = {}
            self._update_plot()

    # def filter(self):
    #     kwargs = self._kwargs
    #     taper_type = kwargs.pop('taper_type', 'tukey')
    #     taper = kwargs.pop('tapering', 'mild')
    #     if self._points:
    #         self.data.linear_mute(self._points, key=self._key, taper=taper, taper_type=taper_type)
    #         # reset plot
    #         self._points = {}
    #         self._update_plot()
    #     return self.data

    @property
    def key(self):
        """Active mute mode: 't' (top) or 'b' (bottom)."""
        return self._key

    @property
    def terminate(self):
        """True once 'e' was pressed."""
        return self._terminate

    @property
    def reset(self):
        """True once 'r' was pressed."""
        return self._reset

    @property
    def reset_last(self):
        """True once 'ctrl+z' was pressed."""
        return self._reset_last

    @property
    def picks(self):
        """The ``{trace_index: y}`` picks."""
        return self._points
class FKFilterInteractive(DraggablePoints):
    """Interactive boundary drawing used for FK filtering.

    Keyboard shortcuts: ``e`` sets :attr:`terminate` and closes the figure;
    ``t`` / ``b`` select the top / bottom filter (see :attr:`key`) and update
    the axes title.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
    points : dict, optional
        Initial ``{x: y}`` points.
    data : object, optional
        Stored on ``self.data``.
    picks : optional
        Unused; kept for signature compatibility.
    **kwargs
        Line style options forwarded to ``DraggablePoints``.
    """

    def __init__(self, ax, points=None, data=None, picks=None, **kwargs):
        super().__init__(ax, points, **kwargs)
        self.data = data
        self._init_param()
        self.canvas.mpl_connect('key_press_event', self._on_key)

    def _init_param(self):
        """Select the top filter and clear the terminate flag."""
        self._terminate = False
        self._key = 't'

    def update_plot(self):
        """Discard every drawn point and redraw; no-op when there are none."""
        if not self._points:
            return
        self._points = {}
        self._update_plot()

    def _on_key(self, event):
        """Dispatch keyboard shortcuts."""
        pressed = event.key
        if pressed == 'e':
            self._terminate = True
            plt.close()
            return
        mode_titles = {'t': 'Top filter active', 'b': 'Bottom filter active'}
        new_title = mode_titles.get(pressed)
        if new_title is None:
            return
        self._key = pressed
        self.ax.set_title(new_title, fontweight='bold')
        self.canvas.draw_idle()

    @property
    def key(self):
        """Active filter side: 't' (top) or 'b' (bottom)."""
        return self._key

    @property
    def terminate(self):
        """True once 'e' was pressed."""
        return self._terminate

    @property
    def picks(self):
        """The ``{x: y}`` boundary points."""
        return self._points
class DCPickingInteractive(DraggablePoints):
    """Draw boundaries on a dispersion image and extract dispersion curves.

    Each clicked point carries ``[y, y_upper, y_lower]``; the band between the
    bounds forms a polygon. Extraction takes, for every x column inside the
    polygon, the y of maximum power and stores it as the picks of the
    current mode.

    Controls: left click adds / drags a point, right click removes one, scroll
    up / down changes the bound width, ``0``-``9`` select the mode, ``p``
    extracts, ``r`` discards the current mode's picks and the boundary, ``d``
    discards the boundary, ``e`` sets :attr:`terminate` and closes the figure.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
    points : dict, optional
        Initial ``{x: [y, y_upper, y_lower]}`` boundary points; used only when
        ``picks`` is None.
    data : object, optional
        In the 'FK' domain it must expose ``FK_data`` with keys 'freq', 'kw'
        and 'FK_abs'; otherwise the attributes ``frequency``, ``velocity``,
        ``dispersive_energy`` and ``offset`` are read.
    freq, vel, power, offsets : optional
        Used instead of ``data`` when ``data`` is not given.
    picks : dict, optional
        Existing ``{mode: {'frequency': ..., 'velocity': ...}}`` picks to show.
    **kwargs
        ``domain`` ('FV' by default, or 'FK'); the remaining ``show_legend``,
        ``markersize`` and ``alpha`` style the picks drawn from ``picks``.
    """

    def __init__(self, ax, points=None, data=None, freq=None, vel=None, power=None,
                 offsets=None, picks=None, **kwargs):
        super().__init__(ax)
        self.domain = kwargs.pop('domain', 'FV')
        # 'lor': Lorentzian error computed from offsets; a number: fixed error.
        self._err = 'lor' if self.domain == 'FV' else 5

        # load data
        self._offsets = None  # the FK branch below provides no offsets
        try:
            if data:
                if self.domain == 'FK':
                    self._ydata = data.FK_data['freq']
                    self._xdata = data.FK_data['kw']
                    self._power = data.FK_data['FK_abs']
                else:
                    self._xdata = getattr(data, 'frequency', None)
                    self._ydata = getattr(data, 'velocity', None)
                    self._power = getattr(data, 'dispersive_energy', None)
                    self._offsets = getattr(data, 'offset', None)
            else:
                self._xdata = freq
                self._ydata = vel
                self._power = power
                self._offsets = offsets
        except Exception as e:
            self.logger.warning(f"Failed to extract data attributes: {e}")
            self._xdata = self._ydata = self._power = self._offsets = None

        self._init_param()
        # No-op when the base constructor already connected the events.
        self._init_plot()

        # Load picks if provided
        if picks is not None:
            self._picks = picks
            self._update_plot(**kwargs)
        elif points:
            for x, y in points.items():
                self._points[x] = y
            # Clear previously drawn lines before drawing the boundary.
            try:
                for line in list(ax.lines):
                    line.remove()
            except Exception:
                pass
            self._update_polygons()
            self._update_plot()

    def _init_plot(self):
        """Connect mouse, keyboard and scroll events, only once.

        The base-class constructor calls this (overridden) method as well;
        without the guard every click, key press and scroll would be handled
        twice (e.g. one scroll step would change the bounds by two steps).
        """
        if getattr(self, '_event_cids', None):
            return
        self._event_cids = [
            self.canvas.mpl_connect('button_release_event', self._on_release),
            self.canvas.mpl_connect('button_press_event', self._on_click),
            self.canvas.mpl_connect('motion_notify_event', self._on_motion),
            self.canvas.mpl_connect('key_press_event', self._on_key),
            self.canvas.mpl_connect('scroll_event', self._on_scroll),
        ]

    def _init_param(self):
        """Initialize picking state."""
        self._reset = False
        self._terminate = False
        self._mode = 0
        self._polygons = {}
        self._picks = {}
        self._ppid = 0  # next slot in the 5-entry _picks_prior history
        self._picks_prior = {}
        self._boundary_strength = 0.3
        self._minVelErr = 20
        self._boundary_min = self._minVelErr
        self._points_prior = {}
        self._pick_lines = {}
        self._line = None
        self._upper_bound_line = None
        self._lower_bound_line = None

    def _update_plot(self, **kwargs):
        """Redraw after an event.

        With no boundary points the stored picks are drawn and the boundary
        lines are emptied; otherwise the boundary points and their bounds are
        drawn. Accepts ``show_legend`` (True), ``markersize`` (3) and
        ``alpha`` (1) for the pick markers.
        """
        show_legend = kwargs.pop('show_legend', True)
        marker_size = kwargs.pop('markersize', 3)
        alpha = kwargs.pop('alpha', 1)
        try:
            if not self._points:
                self._draw_picks(show_legend, marker_size, alpha)
            elif not self._draw_boundary():
                return
            self.canvas.draw_idle()
        except Exception as e:
            self.logger.error(f"Error in _update_plot: {e}")

    def _draw_picks(self, show_legend, marker_size, alpha):
        """Empty the boundary lines and draw each mode's picks not yet shown."""
        cmap = getattr(plt.cm, 'Greys', plt.cm.viridis)
        color = cmap(np.linspace(0, 1, 10))

        # Drop the current mode's line so newly extracted picks replace it.
        old_line = self._pick_lines.pop(self._mode, None)
        if old_line is not None:
            try:
                old_line.remove()
            except Exception:
                pass

        try:
            if self._pick_lines:
                self.ax.legend(loc='upper right')
            else:
                lg = self.ax.get_legend()
                if lg:
                    lg.remove()
        except Exception:
            pass

        for artist in (self._line, self._upper_bound_line, self._lower_bound_line):
            if artist is not None:
                try:
                    artist.set_data([], [])
                except Exception:
                    pass

        # Use a local loop variable: iterating with self._mode would overwrite
        # the mode the user selected.
        for mode, mode_picks in self._picks.items():
            if mode in self._pick_lines or self._reset:
                continue
            try:
                f = mode_picks['frequency']
                v = mode_picks['velocity']
                if self.domain == 'FK':
                    x_data, y_data = wavenumber(f, v), f
                else:
                    x_data, y_data = f, v
                pick_line, = self.ax.plot(
                    x_data, y_data, marker="s", markersize=marker_size,
                    markeredgecolor='k', color=color[mode], linewidth=0,
                    alpha=alpha, label=f'Mode {mode}'
                )
                if show_legend:
                    self.ax.legend(loc='upper right')
                self._pick_lines[mode] = pick_line
            except Exception as e:
                self.logger.warning(f"Failed plotting picks: {e}")

    def _draw_boundary(self):
        """Draw the boundary points and bounds; return False if points are malformed."""
        try:
            x, raw_y = zip(*sorted(self._points.items()))
            y_data = [val[0] for val in raw_y]
            y_upper = [val[1] for val in raw_y]
            y_lower = [val[2] for val in raw_y]
        except Exception:
            return False
        try:
            if self._line is None:
                self._line, = self.ax.plot(
                    x, y_data, marker="o", markersize=7, markeredgecolor='k',
                    color='white', linewidth=0
                )
                marker_style = dict(color="r", marker="o", markersize=3,
                                    linestyle='--', linewidth=0.8)
                self._upper_bound_line, = self.ax.plot(x, y_upper, **marker_style)
                self._lower_bound_line, = self.ax.plot(x, y_lower, **marker_style)
            else:
                self._line.set_data(x, y_data)
                self._upper_bound_line.set_data(x, y_upper)
                self._lower_bound_line.set_data(x, y_lower)
        except Exception as e:
            self.logger.warning(f"Failed updating boundary plot: {e}")
        return True

    def _add_point(self, x, y=None):
        """Add a boundary point stored as ``[y, y + err, y - err]``.

        ``x`` may be a MouseEvent, whose data coordinates are then used.
        ``err`` is the Lorentzian error when ``self._err == 'lor'`` and offsets
        are known, ``self._err`` when it is a number, and 10 otherwise.

        Returns
        -------
        tuple
            ``(x, y)``, or ``(None, None)`` if the event coordinates are not numeric.
        """
        try:
            if isinstance(x, MouseEvent):
                x, y = float(x.xdata), float(x.ydata)
        except Exception:
            return None, None
        # NOTE(review): unlike _update_bounds this ignores _boundary_strength
        # and _boundary_min, so points added after scrolling get different
        # bounds than the existing ones — confirm whether that is intended.
        if (self._err == 'lor') and (self._offsets is not None):
            err = lorentzian_err(self._offsets, y, x)
        elif isinstance(self._err, numbers.Number):
            err = self._err
        else:
            err = 10
        self._points[x] = [y, y + err, y - err]
        return x, y

    def _update_bounds(self):
        """Recompute every point's bounds from the current boundary strength."""
        for x, values in sorted(self._points.items()):
            y = values[0]
            if (self._err == 'lor') and (self._offsets is not None):
                err = lorentzian_err(
                    self._offsets, y, x, a=self._boundary_strength, minvelerr=self._boundary_min
                )
            elif isinstance(self._err, numbers.Number):
                err = self._err * self._boundary_strength
            else:
                err = 10 * self._boundary_strength
            self._points[x] = [y, y + err, y - err]

    def _update_polygons(self):
        """Rebuild the boundary polygon from the current points.

        The polygon runs along the upper bounds from right to left, then along
        the lower bounds from left to right. With no points it is reset to ``{}``.
        """
        if not self._points:
            self._polygons = {}
            return
        x, values = zip(*sorted(self._points.items()))
        x = np.asarray(x)
        y_lo = np.array([p[2] for p in values])
        y_up = np.array([p[1] for p in values])
        bound_lo = np.column_stack((x, y_lo))
        bound_up = np.flipud(np.column_stack((x, y_up)))
        self._polygons = np.vstack((bound_up, bound_lo))

    def _extract_dccurve(self):
        """Pick, per x column inside the boundary, the y of maximum power.

        The polygon is rebuilt from the current points first, so drags and
        removals are honoured. Results are stored in ``self._picks[self._mode]``
        and in the 5-slot history ``self._picks_prior``.
        """
        if self._xdata is None or self._ydata is None or self._power is None:
            self.logger.warning("Frequency, velocity, and power must be defined.")
            return
        self._update_polygons()
        if len(self._polygons) == 0:
            return
        xdata = np.asarray(self._xdata)
        ydata = np.asarray(self._ydata)
        power = np.asarray(self._power)

        fgrid, vgrid = np.meshgrid(xdata, ydata, indexing='xy')
        path = mpltPath.Path(self._polygons)
        # Test grid points inside the polygon. Assumes power has shape
        # (len(ydata), len(xdata)), matching meshgrid 'xy' — confirm for FK data.
        grid_points = np.column_stack((fgrid.ravel(), vgrid.ravel()))
        flags = path.contains_points(grid_points).reshape(power.shape)
        if not flags.any():
            return
        masked_power = np.where(flags, power, -np.inf)
        peaks_idx = masked_power.argmax(axis=0)
        cols = np.flatnonzero(flags.any(axis=0))
        x_pick = xdata[cols]
        y_pick = ydata[peaks_idx[cols]]
        if self.domain == 'FK':
            # x is wavenumber, y is frequency: keep one pick per frequency.
            freq_pick, uniq_id = np.unique(y_pick, return_index=True)
            vel_pick = phase_velocity(y_pick, x_pick)[uniq_id]
        else:
            freq_pick = x_pick
            vel_pick = y_pick
        self._picks[self._mode] = {'frequency': freq_pick, 'velocity': vel_pick}
        self._picks_prior[self._ppid] = {'frequency': freq_pick, 'velocity': vel_pick}
        self._ppid = (self._ppid + 1) % 5

    def interact(self):
        """Extract the dispersion curve within the boundary, then clear the boundary."""
        if self._points:
            self._extract_dccurve()
            self._points = {}
            self._update_plot()

    def _on_click(self, event):
        """Left click grabs or adds a point; right click removes the nearest one."""
        if event.inaxes != self.ax:
            return
        if event.button == 1:
            point = self._find_neighbor_point(event)
            if point is not None:
                self._dragging_point = point
            else:
                self._add_point(event)
                self._update_plot()
                if len(self._points) > 1:
                    self._update_polygons()
        elif event.button == 3:
            point = self._find_neighbor_point(event)
            if point is not None:
                self._remove_point(*point)
                self._update_polygons()
                self._update_plot()

    def _on_release(self, event):
        """End a drag and rebuild the polygon so it follows the moved point."""
        dragged = event.button == 1 and self._dragging_point is not None
        super()._on_release(event)
        if dragged:
            self._update_polygons()

    def _on_scroll(self, event):
        """Scroll up widens the bounds, scroll down narrows them."""
        if not self._points:
            return
        increment_strength = 0.05
        increment_min = 0.5
        if event.button == 'up':
            if self._boundary_strength <= 1:
                self._boundary_strength += increment_strength
            elif self._boundary_min > 1 + increment_min:
                self._boundary_min -= increment_min
        else:
            if self._boundary_strength > increment_strength:
                self._boundary_strength -= increment_strength
            if self._boundary_min < self._minVelErr:
                self._boundary_min += increment_min
        self._update_bounds()
        self._update_polygons()
        self._update_plot()

    def _on_key(self, event):
        """Handle mode selection (0-9) and the p / r / d / e shortcuts."""
        key = event.key or ''  # event.key can be None for some keys
        if key.isdecimal():
            self._mode = int(key)
        elif key == 'p':
            self.interact()
        elif key == 'r':
            self._reset = True
            self._picks.pop(self._mode, None)
            self._polygons = {}
            self._points = {}
            self._update_plot()
            self._reset = False
        elif key == 'd':
            self._points = {}
            self._polygons = {}
            self._update_plot()
        elif key == 'e':
            self._terminate = True
            plt.close()

    @property
    def polygons(self):
        """Boundary polygon vertices as an (n, 2) array, or ``{}`` when unset."""
        return self._polygons

    @property
    def picks(self):
        """``{mode: {'frequency': ..., 'velocity': ...}}`` extracted picks."""
        return self._picks

    @property
    def terminate(self):
        """True once 'e' was pressed."""
        return self._terminate