"""Plotting helpers for pyswapp (module ``pyswapp.utils.plotting``)."""

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import *
from matplotlib.cm import ScalarMappable, get_cmap
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle

# %% helper functions for plotting
def calculate_new_limit(fixed, dependent, limit):
    """Calculate the min/max of the dependent axis given limits on the fixed axis.

    Parameters
    ----------
    fixed : array_like
        Coordinates along the axis whose limits are given.
    dependent : array_like
        Coordinates along the axis to be rescaled (same length as ``fixed``).
    limit : sequence of two floats
        ``(lower, upper)`` limits of the fixed axis.

    Returns
    -------
    low, high
        Range of ``dependent`` over the samples whose ``fixed`` coordinate lies
        strictly inside ``limit``. If there are at most two samples, or none
        lies inside ``limit``, the range spanned by the first and last
        dependent values is returned (ordered so ``low <= high``). If that
        range is exactly ``(0, 1)`` -- an axhline in the autoscale direction --
        ``(inf, -inf)`` is returned so it does not widen the new limits.
    """
    # asarray so plain lists (e.g. the data of a Line2D built from lists)
    # support the element-wise comparisons below.
    fixed = np.asarray(fixed)
    dependent = np.asarray(dependent)

    window = None
    if len(fixed) > 2:
        mask = (fixed > limit[0]) & (fixed < limit[1])
        window = dependent[mask]

    if window is not None and window.size > 0:
        low, high = window.min(), window.max()
    else:
        # Fall back to the end points; order them so decreasing data does
        # not produce an inverted (low > high) range.
        low, high = sorted((dependent[0], dependent[-1]))

    if low == 0.0 and high == 1.0:
        # This is an axhline in the autoscale direction: ignore it.
        low = np.inf
        high = -np.inf
    return low, high
def get_xy(artist):
    """Get the x and y data coordinates of a plotted artist.

    Parameters
    ----------
    artist : matplotlib artist
        Either a collection-like artist exposing ``get_offsets`` (e.g. the
        result of ``ax.scatter``) or a line-like artist exposing
        ``get_xdata``/``get_ydata`` (e.g. a ``Line2D``).

    Returns
    -------
    x, y
        The artist's data coordinates.

    Raises
    ------
    TypeError
        If the artist exposes neither interface.
    """
    # Dispatch on the accessors the artist actually provides rather than on
    # str(artist): a Line2D's string form contains its label, so a label
    # such as "Collection" used to be routed to the wrong branch.
    if hasattr(artist, "get_offsets"):
        x, y = artist.get_offsets().T
        return x, y
    if hasattr(artist, "get_xdata") and hasattr(artist, "get_ydata"):
        return artist.get_xdata(), artist.get_ydata()
    raise TypeError(
        f"cannot extract xy data from artist of type {type(artist).__name__}"
    )
def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Parameters
    ----------
    N : int
        Number of discrete colour bins.
    base_cmap : str, Colormap or None, optional
        Colormap to sample; ``None`` selects Matplotlib's default colormap.

    Returns
    -------
    LinearSegmentedColormap
        Colormap named ``<base name><N>`` built from ``N`` colours sampled
        evenly across ``base_cmap``.
    """
    # pyplot.get_cmap accepts a name, a Colormap instance or None. The
    # previously used plt.cm.get_cmap (matplotlib.cm.get_cmap) was
    # deprecated in Matplotlib 3.7 and removed in 3.9.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)
def draw1DColumn(ax, x, val, thk=None, depth=None, width=1, vmin=1, vmax=1000, cmap=None):
    """Draw a 1D column of colour-coded rectangles from thicknesses or depths.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    x : float
        Horizontal centre of the column.
    val : array_like
        One value per cell; mapped to colour.
    thk : array_like, optional
        Layer thicknesses, used only when ``depth`` is None. The derived
        boundaries start at 0 and end with an extra bottom cell reaching
        1.5 times the total thickness, so ``val`` may hold ``len(thk) + 1``
        entries.
    depth : array_like, optional
        Cell boundaries: cell ``i`` spans ``depth[i]`` to ``depth[i + 1]``,
        so at least ``len(val) + 1`` entries are required.
    width : float, optional
        Column width.
    vmin, vmax : float, optional
        Limits of the colour scale.
    cmap : str or Colormap, optional
        Colormap; the collection's default is used when None.

    Returns
    -------
    PatchCollection
        The collection that was added to ``ax``.

    Raises
    ------
    ValueError
        If neither ``depth`` nor ``thk`` is given.
    """
    if depth is None:
        if thk is None:
            raise ValueError("draw1DColumn needs either 'depth' or 'thk'")
        depth = np.hstack((0., np.cumsum(thk), np.sum(thk) * 1.5))

    left = x - width / 2.
    recs = [
        Rectangle((left, depth[i]), width, depth[i + 1] - depth[i])
        for i in range(len(val))
    ]

    pp = PatchCollection(recs)
    col = ax.add_collection(pp)
    # No visible outlines between the cells.
    pp.set_edgecolor(None)
    pp.set_linewidths(0.0)
    if cmap is not None:
        pp.set_cmap(cmap)
    pp.set_norm(Normalize(vmin, vmax))
    pp.set_array(np.asarray(val, dtype=float))
    pp.set_clim(vmin, vmax)
    return col
def _cell_edges(starts):
    """Turn sorted cell start positions into ``len(starts) + 1`` cell edges.

    The final cell is given the same extent as the one before it so that the
    last sample is actually drawn; with a single sample there is no spacing
    to copy and that cell keeps zero extent.
    """
    starts = np.asarray(starts, dtype=float)
    if starts.size > 1:
        last_edge = starts[-1] + (starts[-1] - starts[-2])
    else:
        last_edge = starts[-1]
    return np.hstack((starts, last_edge))


def plot_vphase(ax, val, f, lam=None, xmid=0, vmin=100, vmax=1000, y_value='lam', width=1, **kwargs):
    """Draw a dispersion curve as a 1D colour column.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw into.
    val : array_like
        Phase velocities, used as colour values.
    f : array_like
        Frequencies of the samples in ``val``.
    lam : array_like, optional
        Wavelengths; computed as ``val / f`` when not given.
    xmid : float, optional
        Horizontal centre of the column.
    vmin, vmax : float, optional
        Limits of the colour scale.
    y_value : str, optional
        ``'lam'`` plots against wavelength; any other value plots against
        frequency.
    width : float, optional
        Column width.
    **kwargs
        Only ``cmap`` is used (default ``'viridis'``).

    Returns
    -------
    ax, cmap
        The axes and the colormap used.
    """
    cmap = kwargs.setdefault('cmap', 'viridis')
    # asarray so plain lists work with the division and fancy indexing below.
    val = np.asarray(val)
    f = np.asarray(f)

    if y_value == 'lam':
        positions = np.asarray(lam) if lam is not None else val / f
        ylabel = r'wavelength (m)'
    else:
        positions = f
        ylabel = r'frequency (Hz)'

    sort_idx = np.argsort(positions)
    positions = positions[sort_idx]
    values = val[sort_idx]
    # Previously the last position was simply repeated, giving the final
    # cell zero height so the last sample was never visible.
    draw1DColumn(ax, xmid, values, depth=_cell_edges(positions), cmap=cmap,
                 vmin=vmin, vmax=vmax, width=width)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('x (m)')
    ax.grid(True, linestyle=':')
    return ax, cmap
def plot_colorBar(ax, vmin, vmax, orientation='vertical', size=0.2, pad=None, **kwargs):
    """Create and plot a colorbar in a new axes appended to ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes next to which the colorbar is placed.
    vmin, vmax : float
        Limits of the colour scale.
    orientation : str, optional
        ``'horizontal'`` places the bar below ``ax``; anything else places
        a vertical bar to the right.
    size, pad : optional
        Passed to ``append_axes``; ``pad`` defaults to 0.5 (horizontal) or
        0.1 (vertical).
    **kwargs
        ``cmap`` (str or Colormap, default ``'viridis'``) and ``label``
        (default ``'phase velocity (m/s)'``) are consumed here; everything
        else is forwarded to ``Figure.colorbar``.

    Returns
    -------
    Colorbar
        The created colorbar.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    divider = make_axes_locatable(ax)
    if orientation == 'horizontal':
        cax = divider.append_axes("bottom", size=size, pad=0.5 if pad is None else pad)
    else:
        cax = divider.append_axes("right", size=size, pad=0.1 if pad is None else pad)

    # pop rather than setdefault: these are handled here and need not be
    # forwarded to colorbar.
    cmap = kwargs.pop('cmap', 'viridis')
    label = kwargs.pop('label', 'phase velocity (m/s)')
    norm = plt.Normalize(vmin, vmax)
    # ScalarMappable resolves names itself and also accepts Colormap
    # instances (e.g. from discrete_cmap), which mpl.colormaps[...] rejects.
    sm = ScalarMappable(norm=norm, cmap=cmap)
    cbar = ax.figure.colorbar(sm, cax=cax, orientation=orientation, **kwargs)
    cbar.set_label(label)
    return cbar
def plot_tomo2D(dphi, phi_model, recs_plot, phi_vel, f, axes=None, outfile=None):
    """Plot phase differences and phase velocities side by side.

    Parameters
    ----------
    dphi : array_like
        Phase differences (rad), drawn as black markers.
    phi_model : array_like
        Modelled phase differences, drawn as a red line.
    recs_plot : array_like
        Offsets (m) of the phase-velocity samples.
    phi_vel : array_like
        Phase velocities (m/s) at ``recs_plot``.
    f : float
        Frequency (Hz), shown in the legend.
    axes : sequence of two Axes, optional
        Target axes; a new 1x2 figure is created when None.
    outfile : str or path-like, optional
        If given, the figure is saved there and closed; otherwise it is
        shown.
    """
    if axes is None:
        fig, ax = plt.subplots(1, 2, figsize=(6, 2))
    else:
        ax = axes
        # ``axes`` is a pair of Axes (indexed below); the container itself
        # has no ``figure`` attribute, so take it from an individual Axes.
        fig = ax[0].figure

    # NOTE(review): dphi/phi_model are plotted against their sample index
    # while the axis is labelled "x (m)" -- confirm this is intended.
    ax[0].plot(dphi, color='k', marker='o', markersize=5)
    ax[0].plot(phi_model, color='r')
    ax[0].set_xlabel("x (m)")
    ax[0].set_ylabel("phase differences (rad)")
    ax[0].grid()

    ax[1].scatter(recs_plot, phi_vel, s=15, c='darkgrey', marker='o',
                  edgecolor='k', linewidth=0.2, zorder=-2,
                  label=f'f = {round(f)} Hz')
    ax[1].set_xlim([np.min(recs_plot), np.max(recs_plot)])
    ax[1].set_ylabel("phase velocity (m/s)")
    ax[1].set_xlabel("offset (m)")
    ax[1].legend(loc='lower right', frameon=True)
    ax[1].grid()

    # Act on this figure explicitly; plt.* would target the "current"
    # figure, which need not be the one owning ``axes``.
    fig.tight_layout()
    if outfile:
        fig.savefig(outfile)
        plt.close(fig)
    else:
        plt.show()