Source code for rail.plotting.plotting_functions

import numpy as np
from matplotlib.figure import Figure

import matplotlib.pyplot as plt
from matplotlib import cm


def get_subplot_nrow_ncol(nfig: int) -> tuple[int, int]:
    """Choose a sub-plot grid shape able to hold ``nfig`` plots

    Parameters
    ----------
    nfig:
        Number of figures (sub-plots) that have to fit in the grid

    Returns
    -------
    Number of rows and columns as (nrow, ncol)

    Raises
    ------
    ValueError
        If there is no predefined layout for ``nfig`` (only 1 through 16
        are supported)
    """
    # Each entry maps a group of plot counts onto the grid used for all of them.
    grouped_layouts = (
        ((1,), (1, 1)),
        ((2,), (1, 2)),
        ((3,), (1, 3)),
        ((4,), (2, 2)),
        ((5, 6), (2, 3)),
        ((7, 8), (2, 4)),
        ((9,), (3, 3)),
        ((10, 11, 12), (3, 4)),
        ((13, 14, 15, 16), (4, 4)),
    )
    layouts = {
        count: grid for counts, grid in grouped_layouts for count in counts
    }
    grid = layouts.get(nfig)
    if grid is None:  # pragma: no cover
        raise ValueError(
            f"Sorry, Phillipe. I'm not going to put {nfig} subplots in one figure"
        )
    return grid
def plot_feature_histograms(
    data: np.ndarray,
    labels: list[str] | None = None,
    bins: int | np.ndarray = 100,
) -> Figure:
    """Plot Histograms of the features being used to train a ML algorithm
    on a single, busy figure

    One sub-plot is made per feature (last axis of ``data``); the grid shape
    comes from ``get_subplot_nrow_ncol``.

    Parameters
    ----------
    data:
        Input data; feature ``i`` is read as ``data[:, i]``, so a 2-D
        ``[N_objects, N_features]`` array is expected
    labels:
        Labels for the various features, used as the x-axis label of the
        matching sub-plot
    bins:
        Bins for the histogram, passed to ``Axes.hist``

    Returns
    -------
    Figure with requested plots
    """
    fig = plt.figure(figsize=(8, 8))
    n_features = data.shape[-1]
    nrow, ncol = get_subplot_nrow_ncol(n_features)
    # squeeze=False keeps ``axs`` a 2-D array even for the 1xN grids used for
    # 1-3 features; by default matplotlib would return a bare Axes or a 1-D
    # array there and the [row][col] indexing below would fail.
    axs = fig.subplots(nrow, ncol, squeeze=False)
    for ifeature in range(n_features):
        # Fill the grid row by row.
        irow, icol = divmod(ifeature, ncol)
        ax = axs[irow][icol]
        ax.hist(data[:, ifeature], bins=bins)
        if labels is not None:
            ax.set_xlabel(labels[ifeature])
    return fig
def plot_true_nz(targets: np.ndarray, bins: int | np.ndarray = 100) -> Figure:
    """Plot the true NZ as a histogram of the target values

    Parameters
    ----------
    targets:
        Input data (the target redshifts)
    bins:
        Bins for the histogram, passed to ``Axes.hist``; the default of 100
        matches the previously hard-coded value

    Returns
    -------
    Figure with requested plot
    """
    fig = plt.figure(figsize=(8, 8))
    ax = fig.subplots(1, 1)
    ax.hist(targets, bins=bins)
    return fig
def plot_feature_target_hist2d(
    data: np.ndarray,
    targets: np.ndarray,
    labels: list[str] | None = None,
    bins: tuple[int | np.ndarray, int | np.ndarray] = (100, 100),
) -> Figure:
    """Plot input data v. target redshift value as 2D histogram

    Parameters
    ----------
    data:
        Input data [N_objects, N_features]
    targets:
        Target redshifts [N_objects]
    labels:
        Labels for the data columns [N_features], used as the y-axis label
        of the matching sub-plot
    bins:
        Bins for the histogram, passed to ``Axes.hist2d`` as (x bins, y bins)

    Returns
    -------
    Figure with requested plots

    Notes
    -----
    This will create N_features sub-plots; in each one the target redshift
    is on the x-axis and the feature value on the y-axis.
    """
    fig = plt.figure(figsize=(8, 8))
    n_features = data.shape[-1]
    nrow, ncol = get_subplot_nrow_ncol(n_features)
    # squeeze=False keeps ``axs`` 2-D even for the 1xN grids used for
    # 1-3 features, so the [row][col] indexing below always works.
    axs = fig.subplots(nrow, ncol, squeeze=False)
    for ifeature in range(n_features):
        irow, icol = divmod(ifeature, ncol)
        ax = axs[irow][icol]
        # hist2d(x, y): redshift goes on x, the feature on y, so the feature
        # label belongs on the y-axis.
        ax.hist2d(targets, data[:, ifeature], bins=bins)
        ax.set_xlabel("redshift")
        if labels is not None:
            ax.set_ylabel(labels[ifeature])
    return fig
def plot_colors_v_redshifts_with_templates(
    redshifts: np.ndarray,
    color_data: np.ndarray,
    zmax: float = 4.0,
    templates: dict | None = None,
    labels: list[str] | None = None,
) -> Figure:  # pragma: no cover
    """Plot each color against redshift, optionally overlaying template tracks

    One sub-plot is made per color (last axis of ``color_data``); the grid
    shape comes from ``get_subplot_nrow_ncol``.  Every panel shows x in
    [0, zmax] and y in [-3, 3].

    Parameters
    ----------
    redshifts:
        Redshift of each object (x-axis values)
    color_data:
        Colors of each object; color ``i`` is read as ``color_data[:, i]``
    zmax:
        Upper limit of the redshift axis; template points with redshift
        >= ``zmax`` are not drawn
    templates:
        Optional template tracks to overlay, keyed by a name used as the line
        label.  Judging from the indexing below, each value is expected to be
        a sequence where ``val[0]`` is an array of redshifts, ``val[2][i]`` is
        the matching array for color ``i`` and ``val[3]`` is a number used to
        choose the line color (presumably the template index) -- TODO confirm
        against the code that builds ``templates``
    labels:
        Labels for the colors, used as the y-axis label of each sub-plot

    Returns
    -------
    Figure with requested plots
    """
    fig = plt.figure(figsize=(8, 8))
    n_colors = color_data.shape[-1]
    nrow, ncol = get_subplot_nrow_ncol(n_colors)
    # squeeze=False keeps ``axs`` 2-D even for the 1xN grids used for
    # 1-3 colors, so the [row][col] indexing below always works.
    axs = fig.subplots(nrow, ncol, squeeze=False)
    for icolor in range(n_colors):
        irow, icol = divmod(icolor, ncol)
        ax = axs[irow][icol]
        ax.scatter(redshifts, color_data[:, icolor], color="black", s=1)
        ax.set_xlim(0, zmax)
        ax.set_ylim(-3.0, 3.0)
        if templates is not None:
            for key, val in templates.items():
                mask = val[0] < zmax
                ax.plot(
                    val[0][mask],
                    val[2][icolor][mask],
                    label=key,
                    c=cm.rainbow(1.0 - val[3] / len(templates)),
                )
        ax.set_xlabel("redshift")
        if labels is not None:
            ax.set_ylabel(labels[icolor])
    return fig
def plot_colors_v_colors_with_templates(
    color_data: np.ndarray,
    zmax: float = 4.0,
    templates: dict | None = None,
    labels: list[str] | None = None,
) -> Figure:  # pragma: no cover
    """Plot pairs of colors against each other, optionally with template tracks

    The figure is an ``(n_colors - 1) x (n_colors - 1)`` grid.  The panel in
    row ``r``, column ``c`` shows color ``c + 1`` on the x-axis against color
    ``r`` on the y-axis, so every panel in a row shares its y variable and
    every panel in a column shares its x variable.  Only panels with
    ``c >= r`` are filled, which shows each pair of distinct colors once.
    All panels span [-3, 3] on both axes.

    Parameters
    ----------
    color_data:
        Colors of each object; color ``i`` is read as ``color_data[:, i]``
    zmax:
        Template points with redshift >= ``zmax`` are not drawn
    templates:
        Optional template tracks to overlay, keyed by a name used as the line
        label.  Judging from the indexing below, each value is expected to be
        a sequence where ``val[0]`` is an array of redshifts, ``val[2][i]`` is
        the matching array for color ``i`` and ``val[3]`` is a number used to
        choose the line color (presumably the template index) -- TODO confirm
        against the code that builds ``templates``
    labels:
        Labels for the colors, used as the axis labels of the panels

    Returns
    -------
    Figure with requested plots
    """
    fig = plt.figure(figsize=(8, 8))
    n_colors = color_data.shape[-1]
    nrow, ncol = n_colors - 1, n_colors - 1
    # squeeze=False keeps ``axs`` 2-D even for the 1x1 grid produced by two
    # colors, so the [row][col] indexing below always works.
    axs = fig.subplots(nrow, ncol, squeeze=False)
    for irow in range(nrow):
        for icol in range(ncol):
            ax = axs[irow][icol]
            # Column index selects the x color, row index the y color; the
            # labels and the plotted data both use these same indices.
            x_index = icol + 1
            y_index = irow
            ax.set_xlim(-3.0, 3.0)
            ax.set_ylim(-3.0, 3.0)
            if labels is not None:
                ax.set_ylabel(labels[y_index])
                ax.set_xlabel(labels[x_index])
            if icol < irow:
                # Below the diagonal a panel would repeat a pair already shown
                # (or plot a color against itself), so it is left empty.
                continue
            ax.scatter(
                color_data[:, x_index], color_data[:, y_index], color="black", s=1
            )
            if templates is not None:
                for key, val in templates.items():
                    mask = val[0] < zmax
                    ax.plot(
                        val[2][x_index][mask],
                        val[2][y_index][mask],
                        label=key,
                        c=cm.rainbow(1.0 - val[3] / len(templates)),
                    )
    return fig