Source code for rail.plotting.cat_plotters

from __future__ import annotations

import os
from typing import Any

import numpy as np
from ceci.config import StageParameter

from .dataset import RailDataset
from .dataset_holder import RailDatasetHolder
from .plot_holder import RailPlotHolder
from .plotter import RailPlotter
from . import plotting_functions
from . import utility_functions


[docs] class RailCatTruthDataset(RailDataset): """Dataset to hold any array of true redshifts""" data_types = dict(truth=np.ndarray)
[docs] class RailCatMagnitudesDataset(RailDataset): """Dataset to hold any array of magntidues and list of band names""" data_types = dict(magnitudes=np.ndarray, bands=list)
[docs] class RailCatTruthAndMagnitudesDataset(RailCatTruthDataset, RailCatMagnitudesDataset): """Dataset to hold any array of magntidues and list of band names and true (or spec) redshifts """ data_types = RailCatMagnitudesDataset.data_types.copy() data_types.update(**RailCatTruthDataset.data_types)
[docs] class CatPlotterTruth(RailPlotter): """Class to make a histogram magnitudes in each band""" config_options: dict[str, StageParameter] = RailPlotter.config_options.copy() config_options.update( z_min=StageParameter(float, 0.0, fmt="%0.2f", msg="Minimum Redshift"), z_max=StageParameter(float, 3.0, fmt="%0.2f", msg="Maximum Redshift"), n_zbins=StageParameter(int, 151, fmt="%i", msg="Number of Redshift bins"), ) input_type = RailCatTruthDataset def _make_hist_plot( self, prefix: str, truth: np.ndarray, dataset_holder: RailDatasetHolder | None = None, ) -> RailPlotHolder: figure = plotting_functions.plot_true_nz(truth) plot_name = self._make_full_plot_name(prefix, "") return RailPlotHolder( name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder ) def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]: find_only = kwargs.get("find_only", False) figtype = kwargs.get("figtype", "png") dataset_holder = kwargs.get("dataset_holder") out_dict: dict[str, RailPlotHolder] = {} truth: np.ndarray = kwargs["truth"] if find_only: plot_name = self._make_full_plot_name(prefix, "") assert dataset_holder plot = RailPlotHolder( name=plot_name, path=os.path.join(dataset_holder.config.name, f"{plot_name}.{figtype}"), plotter=self, dataset_holder=dataset_holder, ) else: plot = self._make_hist_plot( prefix=prefix, truth=truth, dataset_holder=dataset_holder, ) out_dict[plot.name] = plot return out_dict
[docs] class CatPlotterMagntidues(RailPlotter): """Class to make a histogram magnitudes in each band""" config_options: dict[str, StageParameter] = RailPlotter.config_options.copy() config_options.update( mag_min=StageParameter(float, 18.0, fmt="%0.2f", msg="Minimum Magnitude"), mag_max=StageParameter(float, 28.0, fmt="%0.2f", msg="Maximum Magnitude"), n_magbins=StageParameter(int, 101, fmt="%i", msg="Number of magnitude bins"), ) input_type = RailCatMagnitudesDataset def _make_hist_plots( self, prefix: str, magnitudes: np.ndarray, bands: list[str], dataset_holder: RailDatasetHolder | None = None, ) -> RailPlotHolder: xbins = np.linspace( self.config.mag_min, self.config.mag_max, self.config.n_magbins ) figure = plotting_functions.plot_feature_histograms( magnitudes, bands, bins=xbins ) plot_name = self._make_full_plot_name(prefix, "") return RailPlotHolder( name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder ) def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]: find_only = kwargs.get("find_only", False) figtype = kwargs.get("figtype", "png") dataset_holder = kwargs.get("dataset_holder") out_dict: dict[str, RailPlotHolder] = {} magnitudes: np.ndarray = kwargs["magnitudes"] bands: list[str] = kwargs["bands"] if find_only: plot_name = self._make_full_plot_name(prefix, "") assert dataset_holder plot = RailPlotHolder( name=plot_name, path=os.path.join(dataset_holder.config.name, f"{plot_name}.{figtype}"), plotter=self, dataset_holder=dataset_holder, ) else: plot = self._make_hist_plots( prefix=prefix, magnitudes=magnitudes, bands=bands, dataset_holder=dataset_holder, ) out_dict[plot.name] = plot return out_dict
[docs] class CatPlotterMagntiduesVsTruth(RailPlotter): """Class to make 2D histograms of magntidue versus true redshift in each band """ config_options: dict[str, StageParameter] = RailPlotter.config_options.copy() config_options.update( z_min=StageParameter(float, 0.0, fmt="%0.2f", msg="Minimum Redshift"), z_max=StageParameter(float, 3.0, fmt="%0.2f", msg="Maximum Redshift"), n_zbins=StageParameter(int, 151, fmt="%i", msg="Number of Redshift bins"), mag_min=StageParameter(float, 18.0, fmt="%0.2f", msg="Minimum Magnitude"), mag_max=StageParameter(float, 25.0, fmt="%0.2f", msg="Maximum Magnitude"), n_magbins=StageParameter(int, 141, fmt="%i", msg="Number of magnitude bins"), ) input_type = RailCatTruthAndMagnitudesDataset def _make_2d_hist_plots( self, prefix: str, truth: np.ndarray, magnitudes: np.ndarray, bands: list[str], dataset_holder: RailDatasetHolder | None = None, ) -> RailPlotHolder: xbins = np.linspace(self.config.z_min, self.config.z_min, self.config.n_zbins) ybins = np.linspace( self.config.mag_min, self.config.mag_max, self.config.n_magbins ) figure = plotting_functions.plot_feature_target_hist2d( magnitudes, truth, bands, bins=(xbins, ybins) ) plot_name = self._make_full_plot_name(prefix, "") return RailPlotHolder( name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder ) def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]: find_only = kwargs.get("find_only", False) figtype = kwargs.get("figtype", "png") dataset_holder = kwargs.get("dataset_holder") out_dict: dict[str, RailPlotHolder] = {} truth: np.ndarray = kwargs["truth"] magnitudes: np.ndarray = kwargs["magnitudes"] bands: list[str] = kwargs["bands"] if find_only: plot_name = self._make_full_plot_name(prefix, "") assert dataset_holder plot = RailPlotHolder( name=plot_name, path=os.path.join(dataset_holder.config.name, f"{plot_name}.{figtype}"), plotter=self, dataset_holder=dataset_holder, ) else: plot = self._make_2d_hist_plots( prefix=prefix, truth=truth, magnitudes=magnitudes, bands=bands, dataset_holder=dataset_holder, ) out_dict[plot.name] = plot return out_dict
[docs] class CatPlotterColorsVsTruth(RailPlotter): """Class to make 2D histograms of colors versus true redshift in each band """ config_options: dict[str, StageParameter] = RailPlotter.config_options.copy() config_options.update( z_min=StageParameter(float, 0.0, fmt="%0.2f", msg="Minimum Redshift"), z_max=StageParameter(float, 3.0, fmt="%0.2f", msg="Maximum Redshift"), n_zbins=StageParameter(int, 151, fmt="%i", msg="Number of Redshift bins"), color_min=StageParameter(float, -2.0, fmt="%0.2f", msg="Minimum color"), color_max=StageParameter(float, 2.0, fmt="%0.2f", msg="Maximum color"), n_colorbins=StageParameter(int, 161, fmt="%i", msg="Number of color bins"), ) input_type = RailCatTruthAndMagnitudesDataset def _make_2d_hist_plots( self, prefix: str, truth: np.ndarray, magnitudes: np.ndarray, bands: list[str], dataset_holder: RailDatasetHolder | None = None, ) -> RailPlotHolder: colors = utility_functions.adjacent_band_colors(magnitudes) color_names = [] for i in range(len(bands) - 1): color_names.append(f"{bands[i+1]} - {bands[i]}") xbins = np.linspace(self.config.z_min, self.config.z_min, self.config.n_zbins) ybins = np.linspace( self.config.color_min, self.config.color_max, self.config.n_colorbins ) figure = plotting_functions.plot_feature_target_hist2d( colors, truth, color_names, bins=(xbins, ybins) ) plot_name = self._make_full_plot_name(prefix, "") return RailPlotHolder( name=plot_name, figure=figure, plotter=self, dataset_holder=dataset_holder ) def _make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]: find_only = kwargs.get("find_only", False) figtype = kwargs.get("figtype", "png") dataset_holder = kwargs.get("dataset_holder") out_dict: dict[str, RailPlotHolder] = {} truth: np.ndarray = kwargs["truth"] magnitudes: np.ndarray = kwargs["magnitudes"] bands: list[str] = kwargs["bands"] if find_only: plot_name = self._make_full_plot_name(prefix, "") assert dataset_holder plot = RailPlotHolder( name=plot_name, path=os.path.join(dataset_holder.config.name, f"{plot_name}.{figtype}"), plotter=self, dataset_holder=dataset_holder, ) else: plot = self._make_2d_hist_plots( prefix=prefix, truth=truth, magnitudes=magnitudes, bands=bands, dataset_holder=dataset_holder, ) out_dict[plot.name] = plot return out_dict