Source code for rail.plotting.pz_plotters

from __future__ import annotations

import os
from typing import Any

import numpy as np
from astropy.stats import biweight_location, biweight_scale
from ceci.config import StageParameter
from matplotlib import colors
from matplotlib import pyplot as plt
from scipy.stats import sigmaclip

from .dataset import RailDataset
from .dataset_holder import RailDatasetHolder
from .plot_holder import RailPlotHolder
from .plotter import RailPlotter


class RailPZPointEstimateDataset(RailDataset):
    """Dataset holding one vector of p(z) point estimates together with
    the corresponding true redshifts and magnitudes
    """

    # Expected type of each named input array.
    data_types = {
        "truth": np.ndarray,
        "pointEstimate": np.ndarray,
        "magnitude": np.ndarray,
    }
class RailPZMultiPointEstimateDataset(RailDataset):
    """Dataset holding several named vectors of p(z) point estimates
    together with the corresponding true redshifts
    """

    # `pointEstimates` maps a string key to one vector of estimates.
    data_types = {
        "truth": np.ndarray,
        "pointEstimates": dict[str, np.ndarray],
    }
class PZPlotterPointEstimateVsTrueHist2D(RailPlotter):
    """Class to make a 2D histogram of p(z) point estimates
    versus true redshift
    """

    config_options: dict[str, StageParameter] = RailPlotter.config_options.copy()
    config_options.update(
        z_min=StageParameter(float, 0.0, fmt="%0.2f", msg="Minimum Redshift"),
        z_max=StageParameter(float, 3.0, fmt="%0.2f", msg="Maximum Redshift"),
        n_zbins=StageParameter(int, 150, fmt="%i", msg="Number of z bins"),
        n_clip=StageParameter(
            int, 3, fmt="%i", msg="Number of sigma cliping for outliers"
        ),
        abs_out_thresh=StageParameter(
            float, 0.2, fmt="%0.2f", msg="Threshold for the absolute outlier rate"
        ),
    )

    input_type = RailPZPointEstimateDataset

    def _make_2d_hist_plot(
        self,
        prefix: str,
        truth: np.ndarray,
        pointEstimate: np.ndarray,
        dataset_holder: RailDatasetHolder | None = None,
    ) -> RailPlotHolder:
        """Draw the estimate-vs-truth 2D histogram and wrap it in a holder.

        Parameters
        ----------
        prefix
            Prefix passed to `_make_full_plot_name` to build the plot name
        truth
            True redshifts
        pointEstimate
            Point estimates, element-wise matched to `truth`
        dataset_holder
            Optional holder recorded on the returned `RailPlotHolder`

        Returns
        -------
        RailPlotHolder
            Holder wrapping the newly created figure
        """
        figure, axes = plt.subplots(figsize=(7, 6))
        bin_edges = np.linspace(
            self.config.z_min, self.config.z_max, self.config.n_zbins + 1
        )

        # Summary statistics are computed on the normalised residual.
        dz = (pointEstimate - truth) / (1 + truth)
        mean, _mean_err, std, outlier_rate, abs_outlier_rate = (
            self.get_biweight_mean_sigma_outlier(dz, nclip=self.config.n_clip)
        )
        mean, std, outlier_rate, abs_outlier_rate = (
            round(mean, 4),
            round(std, 4),
            round(outlier_rate, 4),
            round(abs_outlier_rate, 4),
        )

        h = axes.hist2d(
            truth,
            pointEstimate,
            bins=(bin_edges, bin_edges),
            norm=colors.LogNorm(),
            cmap="gray",
        )

        # Identity line plus two lines shifted by -/+ 3*std.  The end points
        # lie 10 units outside [z_min, z_max].
        # NOTE(review): std is the scatter of (zphot - ztrue) / (1 + ztrue),
        # but the shifted lines are drawn at a constant +/- 3*std offset rather
        # than +/- 3*std*(1 + ztrue) -- confirm this is intended.
        diagonal = [self.config.z_min - 10, self.config.z_max + 10]
        for offset in (0.0, -3 * std, 3 * std):
            axes.plot(
                diagonal,
                [z + offset for z in diagonal],
                "--",
                color="red",
            )

        # Invisible, empty artist whose only purpose is to carry the
        # statistics text into the legend box.
        axes.plot(
            [],
            [],
            ".",
            alpha=0.0,
            label=rf"$\Delta z = {mean} $"
            + "\n"
            + rf"$\sigma z = {std} $"
            + "\n"
            + rf"outlier rate (>3$\sigma$) = {outlier_rate}"
            + "\n"
            + f"outlier rate (>{self.config.abs_out_thresh}) = {abs_outlier_rate}",
        )

        # Use the axes object directly instead of pyplot's "current axes"
        # state, so the labels/legend cannot end up on another figure.
        axes.set_xlabel("True Redshift")
        axes.set_ylabel("Estimated Redshift")
        cb = figure.colorbar(h[3], ax=axes)
        cb.set_label("Density")
        axes.legend()

        plot_name = self._make_full_plot_name(prefix, "")
        return RailPlotHolder(
            name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
        )

    def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
        """Build the plot, or a path-only placeholder when `find_only` is set.

        Recognised keyword arguments: `find_only` (default False), `figtype`
        (default "png"), `dataset_holder`, and -- only when actually plotting --
        `truth` and `pointEstimate`.

        Returns
        -------
        dict[str, RailPlotHolder]
            Single-entry dictionary keyed by the plot name
        """
        find_only = kwargs.get("find_only", False)
        figtype = kwargs.get("figtype", "png")
        dataset_holder = kwargs.get("dataset_holder")

        if find_only:
            plot_name = self._make_full_plot_name(prefix, "")
            assert dataset_holder
            plot = RailPlotHolder(
                name=plot_name,
                path=os.path.join(
                    dataset_holder.config.name, f"{plot_name}.{figtype}"
                ),
                plotter=self,
                dataset_holder=dataset_holder,
            )
        else:
            # The data arrays are looked up only here: the find_only branch
            # never uses them, so it should not fail when they are absent.
            plot = self._make_2d_hist_plot(
                prefix=prefix,
                truth=kwargs["truth"],
                pointEstimate=kwargs["pointEstimate"],
                dataset_holder=dataset_holder,
            )
        return {plot.name: plot}

    def get_biweight_mean_sigma_outlier(
        self, subset: np.ndarray, nclip: int = 3
    ) -> tuple[float, float, float, float, float]:
        """Compute sigma-clipped biweight statistics and outlier rates.

        Parameters
        ----------
        subset
            Values to summarise (the caller in this class passes the
            normalised residuals)
        nclip
            Number of additional 3-sigma clipping passes applied after the
            initial one

        Returns
        -------
        tuple[float, float, float, float, float]
            In order: biweight location of the clipped sample; biweight scale
            divided by sqrt(clipped sample size); biweight scale; fraction of
            `subset` with absolute value above 3 * scale; fraction of `subset`
            with absolute value above `config.abs_out_thresh`
        """
        # One initial pass followed by `nclip` further passes.
        # NOTE(review): scipy's sigmaclip is documented to iterate until no
        # points are rejected, so the extra passes are presumably no-ops --
        # confirm before simplifying.
        subset_clip, _, _ = sigmaclip(subset, low=3, high=3)
        for _ in range(nclip):
            subset_clip, _, _ = sigmaclip(subset_clip, low=3, high=3)

        mean = biweight_location(subset_clip)
        # Computed once and reused for both the returned scale and the
        # 3-sigma outlier threshold.
        std = biweight_scale(subset_clip)

        # Rates are fractions of the *unclipped* sample.
        abs_values = np.abs(subset)
        n_total = len(subset)
        outlier_rate = np.sum(abs_values > 3 * std) / n_total
        abs_outlier_rate = np.sum(abs_values > self.config.abs_out_thresh) / n_total

        return (
            mean,
            std / np.sqrt(len(subset_clip)),
            std,
            outlier_rate,
            abs_outlier_rate,
        )
class PZPlotterPointEstimateVsTrueProfile(RailPlotter):
    """Class to make a profile plot of p(z) point estimates
    versus true redshift
    """

    config_options: dict[str, StageParameter] = RailPlotter.config_options.copy()
    config_options.update(
        z_min=StageParameter(float, 0.0, fmt="%0.2f", msg="Minimum Redshift"),
        z_max=StageParameter(float, 3.0, fmt="%0.2f", msg="Maximum Redshift"),
        n_zbins=StageParameter(int, 150, fmt="%i", msg="Number of z bins"),
    )

    input_type = RailPZPointEstimateDataset

    def _make_2d_profile_plot(
        self,
        prefix: str,
        truth: np.ndarray,
        pointEstimate: np.ndarray,
        dataset_holder: RailDatasetHolder | None = None,
    ) -> RailPlotHolder:
        """Draw the per-bin mean offset and scatter of the point estimates.

        For every true-redshift bin the plot shows the mean point estimate
        minus the bin centre, with the standard deviation of the point
        estimates in that bin as the error bar.

        Parameters
        ----------
        prefix
            Prefix passed to `_make_full_plot_name` to build the plot name
        truth
            True redshifts, used for binning
        pointEstimate
            Point estimates, element-wise matched to `truth`
        dataset_holder
            Optional holder recorded on the returned `RailPlotHolder`

        Returns
        -------
        RailPlotHolder
            Holder wrapping the newly created figure
        """
        figure, axes = plt.subplots()
        n_bins = self.config.n_zbins
        bin_edges = np.linspace(self.config.z_min, self.config.z_max, n_bins + 1)
        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

        # np.digitize returns 1-based indices (0 means "below the first edge"),
        # so subtract one to obtain indices aligned with `bin_centers`.
        # Values below z_min map to -1 and values >= z_max map to n_bins; both
        # fall outside range(n_bins) and are therefore ignored.
        z_true_bin = np.digitize(truth, bin_edges) - 1

        # Empty bins stay NaN so they show up as gaps rather than as a
        # spurious "zero offset, zero scatter" measurement.
        means = np.full(n_bins, np.nan)
        stds = np.full(n_bins, np.nan)
        for i in range(n_bins):
            data = pointEstimate[z_true_bin == i]
            if data.size == 0:
                continue
            means[i] = np.mean(data) - bin_centers[i]
            stds[i] = np.std(data)

        axes.errorbar(
            bin_centers,
            means,
            stds,
        )
        # NOTE(review): the plotted quantity is (mean estimate - bin centre);
        # the y label below is kept as-is -- confirm whether it should change.
        axes.set_xlabel("True Redshift")
        axes.set_ylabel("Estimated Redshift")

        plot_name = self._make_full_plot_name(prefix, "")
        return RailPlotHolder(
            name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
        )

    def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
        """Build the plot, or a path-only placeholder when `find_only` is set.

        Recognised keyword arguments: `find_only` (default False), `figtype`
        (default "png"), `dataset_holder`, and -- only when actually plotting --
        `truth` and `pointEstimate`.

        Returns
        -------
        dict[str, RailPlotHolder]
            Single-entry dictionary keyed by the plot name
        """
        find_only = kwargs.get("find_only", False)
        figtype = kwargs.get("figtype", "png")
        dataset_holder = kwargs.get("dataset_holder")

        if find_only:
            assert dataset_holder
            plot_name = self._make_full_plot_name(prefix, "")
            plot = RailPlotHolder(
                name=plot_name,
                path=os.path.join(
                    dataset_holder.config.name, f"{plot_name}.{figtype}"
                ),
                plotter=self,
                dataset_holder=dataset_holder,
            )
        else:
            # Data arrays are only required when a figure is really drawn.
            plot = self._make_2d_profile_plot(
                prefix=prefix,
                truth=kwargs["truth"],
                pointEstimate=kwargs["pointEstimate"],
                dataset_holder=dataset_holder,
            )
        return {plot.name: plot}
class PZPlotterAccuraciesVsTrue(RailPlotter):
    """Class to make a plot of the accuracy of several algorithms
    versus true redshift
    """

    config_options: dict[str, StageParameter] = RailPlotter.config_options.copy()
    config_options.update(
        z_min=StageParameter(float, 0.0, fmt="%0.2f", msg="Minimum Redshift"),
        z_max=StageParameter(float, 3.0, fmt="%0.2f", msg="Maximum Redshift"),
        n_zbins=StageParameter(int, 150, fmt="%i", msg="Number of z bins"),
        delta_cutoff=StageParameter(
            float, 0.1, fmt="%0.2f", msg="Delta-Z Cutoff for accurary"
        ),
    )

    input_type = RailPZMultiPointEstimateDataset

    def _make_accuracy_plot(
        self,
        prefix: str,
        truth: np.ndarray,
        pointEstimates: dict[str, np.ndarray],
        dataset_holder: RailDatasetHolder | None = None,
    ) -> RailPlotHolder:
        """Draw one accuracy-vs-true-redshift curve per set of estimates.

        The accuracy in a bin is the fraction of objects in that bin whose
        |estimate - truth| does not exceed `config.delta_cutoff`.

        Parameters
        ----------
        prefix
            Prefix passed to `_make_full_plot_name` to build the plot name
        truth
            True redshifts, used for binning
        pointEstimates
            Mapping from a curve label to a vector of point estimates,
            element-wise matched to `truth`
        dataset_holder
            Optional holder recorded on the returned `RailPlotHolder`

        Returns
        -------
        RailPlotHolder
            Holder wrapping the newly created figure
        """
        figure, axes = plt.subplots()
        n_bins = self.config.n_zbins
        bin_edges = np.linspace(self.config.z_min, self.config.z_max, n_bins + 1)
        bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

        # np.digitize returns 1-based indices (0 means "below the first edge"),
        # so subtract one to obtain indices aligned with `bin_centers`.
        # Out-of-range values map to -1 or n_bins and are ignored below.
        z_true_bin = np.digitize(truth, bin_edges) - 1

        for key, val in pointEstimates.items():
            deltas = val - truth
            # Bins without any object stay NaN and appear as gaps.
            accuracy = np.full(n_bins, np.nan)
            for i in range(n_bins):
                data = deltas[z_true_bin == i]
                if data.size == 0:
                    continue
                accuracy[i] = np.mean(np.abs(data) <= self.config.delta_cutoff)
            axes.plot(
                bin_centers,
                accuracy,
                label=key,
            )

        axes.set_xlabel("True Redshift")
        # The curves are fractions, not redshifts.
        axes.set_ylabel(f"Fraction with |dz| <= {self.config.delta_cutoff}")
        if pointEstimates:
            # Each curve carries a label; show them so the algorithms can be
            # told apart.
            axes.legend()

        plot_name = self._make_full_plot_name(prefix, "")
        return RailPlotHolder(
            name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
        )

    def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
        """Build the plot, or a path-only placeholder when `find_only` is set.

        Recognised keyword arguments: `find_only` (default False), `figtype`
        (default "png"), `dataset_holder`, and -- only when actually plotting --
        `truth` and `pointEstimates`.

        Returns
        -------
        dict[str, RailPlotHolder]
            Single-entry dictionary keyed by the plot name
        """
        find_only = kwargs.get("find_only", False)
        figtype = kwargs.get("figtype", "png")
        dataset_holder = kwargs.get("dataset_holder")

        if find_only:
            plot_name = self._make_full_plot_name(prefix, "")
            assert dataset_holder
            plot = RailPlotHolder(
                name=plot_name,
                path=os.path.join(
                    dataset_holder.config.name, f"{plot_name}.{figtype}"
                ),
                plotter=self,
                dataset_holder=dataset_holder,
            )
        else:
            # Pass the inputs explicitly: forwarding **kwargs wholesale would
            # also forward control keys such as `find_only` / `figtype`, which
            # `_make_accuracy_plot` does not accept.
            plot = self._make_accuracy_plot(
                prefix=prefix,
                truth=kwargs["truth"],
                pointEstimates=kwargs["pointEstimates"],
                dataset_holder=dataset_holder,
            )
        return {plot.name: plot}
class PZPlotterBiweightStatsVsRedshift(RailPlotter):
    """Class to plot sigma-clipped biweight statistics (bias, scatter and
    outlier rate) of p(z) point estimates as a function of redshift
    """

    config_options: dict[str, StageParameter] = RailPlotter.config_options.copy()
    config_options.update(
        z_min=StageParameter(float, 0.0, fmt="%0.2f", msg="Minimum Redshift"),
        z_max=StageParameter(float, 3.0, fmt="%0.2f", msg="Maximum Redshift"),
        n_zbins=StageParameter(int, 15, fmt="%i", msg="Number of z bins"),
        n_clip=StageParameter(
            int, 3, fmt="%i", msg="Number of sigma cliping for outliers"
        ),
        zbin_type=StageParameter(
            str, "spec", fmt="%s", msg="Type of redshift binned by, 'spec' or 'phot'. "
        ),
    )

    input_type = RailPZPointEstimateDataset

    def _make_biweight_stats_plot(
        self,
        prefix: str,
        truth: np.ndarray,
        pointEstimate: np.ndarray,
        dataset_holder: RailDatasetHolder | None = None,
    ) -> RailPlotHolder:
        """Draw the two-panel statistics figure.

        The upper panel shows the per-bin values returned by `process_data`
        (bias with its error, scatter, outlier rate); the lower panel shows a
        2D histogram of the normalised residual with the per-bin percentile
        curves over-plotted.

        Parameters
        ----------
        prefix
            Prefix passed to `_make_full_plot_name` to build the plot name
        truth
            True redshifts
        pointEstimate
            Point estimates, element-wise matched to `truth`
        dataset_holder
            Optional holder recorded on the returned `RailPlotHolder`

        Returns
        -------
        RailPlotHolder
            Holder wrapping the newly created figure

        Raises
        ------
        ValueError
            If `config.zbin_type` is neither 'spec' nor 'phot'
        """
        dz = (pointEstimate - truth) / (1 + truth)

        if self.config.zbin_type == "spec":
            x_label = r"$z_{spec}$"
            z_x = truth
        elif self.config.zbin_type == "phot":  # pragma: no cover
            x_label = r"$z_{phot}$"
            z_x = pointEstimate
        else:  # pragma: no cover
            raise ValueError("`zbin_type` must be either 'spec' or 'phot'")

        results = self.process_data(
            pointEstimate,
            truth,
            # `process_data` treats `nbin` as the number of bin *edges*;
            # n_zbins bins need n_zbins + 1 edges.
            nbin=self.config.n_zbins + 1,
            low=self.config.z_min,
            high=self.config.z_max,
            nclip=self.config.n_clip,
        )

        figure, axes = plt.subplots(2, 1, figsize=(8, 6))
        figure.subplots_adjust(wspace=0.1, hspace=0.0)
        stats_ax, hist_ax = axes

        stats_ax.errorbar(
            results["z_mean"],
            results["biweight_mean"],
            results["biweight_std"],
            label="Bias",
        )
        stats_ax.plot(
            results["z_mean"], results["biweight_sigma"], label=r"$\sigma_z$"
        )
        stats_ax.plot(
            results["z_mean"], results["biweight_outlier"], label=r"Outlier rate"
        )
        stats_ax.set_title(
            f"Bias, Sigma, and Outlier rates w/ {self.config.n_clip} sigma clipping"
        )
        stats_ax.set_ylabel("Statistics")
        stats_ax.legend()
        # The panels touch (hspace=0), so the upper x ticks are hidden.
        stats_ax.tick_params(
            axis="x", which="both", bottom=False, top=False, labelbottom=False
        )
        stats_ax.set_xlim(self.config.z_min, self.config.z_max)

        bin_edges_z = np.linspace(self.config.z_min, self.config.z_max, 100 + 1)
        bin_edges_dz = np.linspace(np.min(dz), np.max(dz), 100 + 1)
        hist_ax.hist2d(
            z_x,
            dz,
            bins=(bin_edges_z, bin_edges_dz),
            norm=colors.LogNorm(),
            cmap="gray",
        )
        hist_ax.set_xlim(self.config.z_min, self.config.z_max)
        for qt in ["qt_95_low", "qt_68_low", "median", "qt_68_high", "qt_95_high"]:
            hist_ax.plot(
                results["z_mean"], results[qt], "--", color="blue", linewidth=2.0
            )
        hist_ax.set_xlabel(x_label)
        hist_ax.set_ylabel(r"$(z_{phot} - z_{spec})/(1+z_{spec})$")

        plot_name = self._make_full_plot_name(prefix, "")
        return RailPlotHolder(
            name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
        )

    def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
        """Build the plot, or a path-only placeholder when `find_only` is set.

        Recognised keyword arguments: `find_only` (default False), `figtype`
        (default "png"), `dataset_holder`, and -- only when actually plotting --
        `truth` and `pointEstimate`.

        Returns
        -------
        dict[str, RailPlotHolder]
            Single-entry dictionary keyed by the plot name
        """
        find_only = kwargs.get("find_only", False)
        figtype = kwargs.get("figtype", "png")
        dataset_holder = kwargs.get("dataset_holder")

        if find_only:
            plot_name = self._make_full_plot_name(prefix, "")
            assert dataset_holder
            plot = RailPlotHolder(
                name=plot_name,
                path=os.path.join(
                    dataset_holder.config.name, f"{plot_name}.{figtype}"
                ),
                plotter=self,
                dataset_holder=dataset_holder,
            )
        else:
            # Data arrays are only required when a figure is really drawn.
            plot = self._make_biweight_stats_plot(
                prefix=prefix,
                truth=kwargs["truth"],
                pointEstimate=kwargs["pointEstimate"],
                dataset_holder=dataset_holder,
            )
        return {plot.name: plot}

    def process_data(
        self,
        zphot: np.ndarray,
        specz: np.ndarray,
        low: float = 0.01,
        high: float = 2.0,
        nclip: int = 3,
        nbin: int = 101,
    ) -> dict[str, list[float]]:
        """Compute per-bin statistics of the normalised redshift residual.

        Objects are binned in `specz` when `config.zbin_type` is 'spec' and in
        `zphot` otherwise.  Bins that contain no object are skipped, so the
        returned lists may be shorter than the number of bins.

        Parameters
        ----------
        zphot
            Point estimates
        specz
            True redshifts, element-wise matched to `zphot`
        low
            Lower edge of the first bin
        high
            Upper edge of the last bin
        nclip
            Number of additional 3-sigma clipping passes applied after the
            initial one
        nbin
            Number of bin *edges*; the number of bins is `nbin - 1`

        Returns
        -------
        dict[str, list[float]]
            Equal-length lists keyed by "z_mean" (mean binning redshift in
            the bin), "biweight_mean", "biweight_std" (scale / sqrt(clipped
            sample size)), "biweight_sigma", "biweight_outlier" (fraction of
            the bin with |residual| > 3 * scale) and the percentile curves
            "qt_95_low", "qt_68_low", "median", "qt_68_high", "qt_95_high"
            (2.5, 16, 50, 84, 97.5 percentiles of the unclipped residuals)
        """
        dz = (zphot - specz) / (1 + specz)
        z_bins = np.linspace(low, high, nbin)  # Bin edges

        if self.config.zbin_type == "spec":
            zx = specz
        else:  # pragma: no cover
            zx = zphot

        # digitize is 1-based; shift so index i means [z_bins[i], z_bins[i+1]).
        bin_indices = np.digitize(zx, bins=z_bins) - 1

        stat_names = (
            "biweight_mean",
            "biweight_std",
            "biweight_sigma",
            "biweight_outlier",
            "qt_95_low",
            "qt_68_low",
            "median",
            "qt_68_high",
            "qt_95_high",
        )
        stats: dict[str, list[float]] = {name: [] for name in stat_names}
        z_mean: list[float] = []

        for i in range(len(z_bins) - 1):
            in_bin = bin_indices == i
            subset = dz[in_bin]
            if len(subset) < 1:  # pragma: no cover
                continue

            # One initial clipping pass followed by `nclip` further passes.
            subset_clip, _, _ = sigmaclip(subset, low=3, high=3)
            for _ in range(nclip):
                subset_clip, _, _ = sigmaclip(subset_clip, low=3, high=3)

            # The scale is needed three times below; compute it once.
            sigma = biweight_scale(subset_clip)
            stats["biweight_mean"].append(biweight_location(subset_clip))
            stats["biweight_std"].append(sigma / np.sqrt(len(subset_clip)))
            stats["biweight_sigma"].append(sigma)
            # Outliers are counted on the unclipped sample.
            stats["biweight_outlier"].append(
                np.sum(np.abs(subset) > 3 * sigma) / len(subset)
            )

            stats["qt_95_low"].append(np.percentile(subset, 2.5))
            stats["qt_68_low"].append(np.percentile(subset, 16))
            stats["median"].append(np.percentile(subset, 50))
            stats["qt_68_high"].append(np.percentile(subset, 84))
            stats["qt_95_high"].append(np.percentile(subset, 97.5))

            z_mean.append(np.mean(zx[in_bin]))

        return {"z_mean": z_mean, **stats}
class PZPlotterBiweightStatsVsMag(RailPlotter):
    """Class to plot sigma-clipped biweight statistics (bias, scatter and
    outlier rate) of p(z) point estimates as a function of magnitude
    """

    config_options: dict[str, StageParameter] = RailPlotter.config_options.copy()
    config_options.update(
        mag_min=StageParameter(float, 18, fmt="%0.2f", msg="Minimum Magnitude"),
        mag_max=StageParameter(float, 25, fmt="%0.2f", msg="Maximum Magnitude"),
        n_magbins=StageParameter(int, 10, fmt="%i", msg="Number of magnitude bins"),
        n_clip=StageParameter(
            int, 3, fmt="%i", msg="Number of sigma cliping for outliers"
        ),
    )

    input_type = RailPZPointEstimateDataset

    def _make_biweight_stats_plot(
        self,
        prefix: str,
        truth: np.ndarray,
        pointEstimate: np.ndarray,
        magnitude: np.ndarray,
        dataset_holder: RailDatasetHolder | None = None,
    ) -> RailPlotHolder:
        """Draw the two-panel statistics figure.

        The upper panel shows the per-bin values returned by `process_data`
        (bias with its error, scatter, outlier rate); the lower panel shows a
        2D histogram of the normalised residual versus magnitude with the
        per-bin percentile curves over-plotted.

        Parameters
        ----------
        prefix
            Prefix passed to `_make_full_plot_name` to build the plot name
        truth
            True redshifts
        pointEstimate
            Point estimates, element-wise matched to `truth`
        magnitude
            Magnitudes, element-wise matched to `truth`
        dataset_holder
            Optional holder recorded on the returned `RailPlotHolder`

        Returns
        -------
        RailPlotHolder
            Holder wrapping the newly created figure
        """
        dz = (pointEstimate - truth) / (1 + truth)

        results = self.process_data(
            pointEstimate,
            truth,
            magnitude,
            # `process_data` treats `nbin` as the number of bin *edges*;
            # n_magbins bins need n_magbins + 1 edges.
            nbin=self.config.n_magbins + 1,
            low=self.config.mag_min,
            high=self.config.mag_max,
            nclip=self.config.n_clip,
        )

        figure, axes = plt.subplots(2, 1, figsize=(8, 6))
        figure.subplots_adjust(wspace=0.1, hspace=0.0)
        stats_ax, hist_ax = axes

        stats_ax.errorbar(
            results["mag_mean"],
            results["biweight_mean"],
            results["biweight_std"],
            label="Bias",
        )
        stats_ax.plot(
            results["mag_mean"], results["biweight_sigma"], label=r"$\sigma_z$"
        )
        stats_ax.plot(
            results["mag_mean"], results["biweight_outlier"], label=r"Outlier rate"
        )
        stats_ax.set_title(
            f"Bias, Sigma, and Outlier rates w/ {self.config.n_clip} sigma clipping"
        )
        stats_ax.set_ylabel("Statistics")
        stats_ax.legend()
        # The panels touch (hspace=0), so the upper x ticks are hidden.
        stats_ax.tick_params(
            axis="x", which="both", bottom=False, top=False, labelbottom=False
        )
        stats_ax.set_xlim(self.config.mag_min, self.config.mag_max)

        bin_edges_mag = np.linspace(self.config.mag_min, self.config.mag_max, 100 + 1)
        bin_edges_dz = np.linspace(np.min(dz), np.max(dz), 100 + 1)
        hist_ax.hist2d(
            magnitude,
            dz,
            bins=(bin_edges_mag, bin_edges_dz),
            norm=colors.LogNorm(),
            cmap="gray",
        )
        hist_ax.set_xlim(self.config.mag_min, self.config.mag_max)
        for qt in ["qt_95_low", "qt_68_low", "median", "qt_68_high", "qt_95_high"]:
            hist_ax.plot(
                results["mag_mean"], results[qt], "--", color="blue", linewidth=2.0
            )
        hist_ax.set_xlabel("Magnitude")
        hist_ax.set_ylabel(r"$(z_{phot} - z_{spec})/(1+z_{spec})$")

        plot_name = self._make_full_plot_name(prefix, "")
        return RailPlotHolder(
            name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder
        )

    def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
        """Build the plot, or a path-only placeholder when `find_only` is set.

        Recognised keyword arguments: `find_only` (default False), `figtype`
        (default "png"), `dataset_holder`, and -- only when actually plotting --
        `truth`, `pointEstimate` and `magnitude`.

        Returns
        -------
        dict[str, RailPlotHolder]
            Single-entry dictionary keyed by the plot name
        """
        find_only = kwargs.get("find_only", False)
        figtype = kwargs.get("figtype", "png")
        dataset_holder = kwargs.get("dataset_holder")

        if find_only:
            plot_name = self._make_full_plot_name(prefix, "")
            assert dataset_holder
            plot = RailPlotHolder(
                name=plot_name,
                path=os.path.join(
                    dataset_holder.config.name, f"{plot_name}.{figtype}"
                ),
                plotter=self,
                dataset_holder=dataset_holder,
            )
        else:
            # Data arrays are only required when a figure is really drawn.
            plot = self._make_biweight_stats_plot(
                prefix=prefix,
                truth=kwargs["truth"],
                pointEstimate=kwargs["pointEstimate"],
                magnitude=kwargs["magnitude"],
                dataset_holder=dataset_holder,
            )
        return {plot.name: plot}

    def process_data(
        self,
        zphot: np.ndarray,
        specz: np.ndarray,
        mag: np.ndarray,
        low: float = 0.01,
        high: float = 2.0,
        nclip: int = 3,
        nbin: int = 101,
    ) -> dict[str, list[float] | np.ndarray]:
        """Compute per-magnitude-bin statistics of the normalised residual.

        Every bin produces exactly one entry in each returned sequence; bins
        whose sigma-clipped sample is empty are filled with NaN.

        Parameters
        ----------
        zphot
            Point estimates
        specz
            True redshifts, element-wise matched to `zphot`
        mag
            Magnitudes used for binning, element-wise matched to `zphot`
        low
            Lower edge of the first bin
        high
            Upper edge of the last bin
        nclip
            Number of additional 3-sigma clipping passes applied after the
            initial one
        nbin
            Number of bin *edges*; the number of bins is `nbin - 1`

        Returns
        -------
        dict[str, list[float] | np.ndarray]
            "mag_mean" (array of bin centres) plus equal-length lists keyed by
            "biweight_mean", "biweight_std" (scale / sqrt(clipped sample
            size)), "biweight_sigma", "biweight_outlier" (fraction of the bin
            with |residual| > 3 * scale) and the percentile curves
            "qt_95_low", "qt_68_low", "median", "qt_68_high", "qt_95_high"
            (2.5, 16, 50, 84, 97.5 percentiles of the unclipped residuals)
        """
        dz = (zphot - specz) / (1 + specz)
        mag_bins = np.linspace(low, high, nbin)  # Bin edges

        # digitize is 1-based; shift so index i means
        # [mag_bins[i], mag_bins[i+1]).
        bin_indices = np.digitize(mag, bins=mag_bins) - 1

        stat_names = (
            "biweight_mean",
            "biweight_std",
            "biweight_sigma",
            "biweight_outlier",
            "qt_95_low",
            "qt_68_low",
            "median",
            "qt_68_high",
            "qt_95_high",
        )
        stats: dict[str, list[float]] = {name: [] for name in stat_names}

        for i in range(len(mag_bins) - 1):
            subset = dz[bin_indices == i]

            # One initial clipping pass followed by `nclip` further passes.
            subset_clip, _, _ = sigmaclip(subset, low=3, high=3)
            for _ in range(nclip):
                subset_clip, _, _ = sigmaclip(subset_clip, low=3, high=3)

            if len(subset_clip) == 0:  # pragma: no cover
                # Keep the lists aligned with the bin centres.
                for values in stats.values():
                    values.append(np.nan)
                continue

            # The scale is needed three times below; compute it once.
            sigma = biweight_scale(subset_clip)
            stats["biweight_mean"].append(biweight_location(subset_clip))
            stats["biweight_std"].append(sigma / np.sqrt(len(subset_clip)))
            stats["biweight_sigma"].append(sigma)
            # Outliers are counted on the unclipped sample.
            stats["biweight_outlier"].append(
                np.sum(np.abs(subset) > 3 * sigma) / len(subset)
            )

            stats["qt_95_low"].append(np.percentile(subset, 2.5))
            stats["qt_68_low"].append(np.percentile(subset, 16))
            stats["median"].append(np.percentile(subset, 50))
            stats["qt_68_high"].append(np.percentile(subset, 84))
            stats["qt_95_high"].append(np.percentile(subset, 97.5))

        mag_mean = (mag_bins[:-1] + mag_bins[1:]) / 2
        return {"mag_mean": mag_mean, **stats}