from __future__ import annotations
import os
from typing import TYPE_CHECKING, Any
from ceci.config import StageParameter
from rail.projects.configurable import Configurable
from rail.projects.dynamic_class import DynamicClass
from .dataset import RailDataset
from .dataset_holder import RailDatasetHolder
from .plot_holder import RailPlotDict
if TYPE_CHECKING:
from .plot_holder import RailPlotHolder
from .plotter_factory import RailPlotterFactory
[docs]
class RailPlotter(Configurable, DynamicClass):
"""Base class for making matplotlib plot
The main function in this class is:
.. highlight:: python
.. code-block:: python
run(prefix: str, kwargs**: Any) -> dict[str, RailPlotHolder]
This function will make a set of plots and return them in a dict.
prefix is string that gets prepended to plot names.
The data to be plotted is passed in via the kwargs.
Sub-classes should implement
.. highlight:: python
.. code-block:: python
config_options: dict[str, ceci.StageParameter]
that will be used to configure things like the axes binning, selection functions,
and other plot-specfic options
.. highlight:: python
.. code-block:: python
input_type: RailPZPointEstimateDataset
that specifics the inputs
that the sub-classes expect, this is used the check the kwargs
that are passed to the `run` function.
A function:
.. highlight:: python
.. code-block:: python
_make_plots(self, prefix: str, **kwargs: Any) -> dict[str, RailPlotHolder]:
That actually makes the plots. It does not need to do the checking
that the correct kwargs have been given.
"""
input_type: type[RailDataset] = RailDataset
sub_classes: dict[str, type[DynamicClass]] = {}
yaml_tag = "Plotter"
[docs]
@staticmethod
def iterate_plotters(
name: str,
plotters: list[RailPlotter],
prefix: str,
dataset: RailDatasetHolder,
**kwargs: Any,
) -> RailPlotDict:
"""Utility function to run several plotters on the same data
Parameters
----------
name: str
Name to give to the RailPlotDict
plotters: list[RailPlotter]
Plotters to run
prefix: str
Prefix to append to plot names, e.g., the p(z) algorithm or
analysis 'flavor'
kwargs: dict[str, Any]
Used to pass the data to make the plots
Returns
-------
out_dict: RailPlotDict
Dictionary of the newly created figures
"""
out_dict: dict[str, RailPlotHolder] = {}
extra_args: dict[str, Any] = dict(dataset_holder=dataset)
for plotter_ in plotters:
out_dict.update(
plotter_.run(prefix, **dataset.resolve(), **kwargs, **extra_args)
)
return RailPlotDict(name=name, plots=out_dict)
[docs]
@staticmethod
def iterate(
plotters: list[RailPlotter],
datasets: list[RailDatasetHolder],
**kwargs: Any,
) -> dict[str, RailPlotDict]:
"""Utility function to several plotters of several data sets
Parameters
----------
plotters: list[RailPlotter]
Plotters to run
datasets: list[RailDatasetHolder]
Prefixes and datasets to iterate over
Returns
-------
out_dict: dict[str, RailPlotDict]
Dictionary of the newly created figures
"""
out_dict: dict[str, RailPlotDict] = {}
for val in datasets:
out_dict[val.config.name] = RailPlotter.iterate_plotters(
val.config.name, plotters, "", val, **kwargs
)
return out_dict
[docs]
@staticmethod
def write_plots(
fig_dict: dict[str, RailPlotDict],
outdir: str = ".",
figtype: str = "png",
purge: bool = False,
) -> None:
"""Utility function to write several plots do disk
Parameters
----------
fig_dict: dict[str, RailPlotDict]
Dictionary of figures to write
outdir: str
Directory to write figures in
figtype: str
Type of figures to write, e.g., png, pdf...
purge: bool
Delete figure after saving
"""
for key, val in fig_dict.items():
try:
os.makedirs(outdir)
except Exception:
pass
out_path = os.path.join(outdir, key)
val.savefigs(out_path, figtype=figtype, purge=purge)
def __init__(self, **kwargs: Any):
"""C'tor
Parameters
----------
kwargs: Any
Configuration parameters for this plotter, must match
class.config_options data members
"""
DynamicClass.__init__(self)
Configurable.__init__(self, **kwargs)
def __repr__(self) -> str:
return f"{type(self)}"
[docs]
def run(
self,
prefix: str,
**kwargs: dict[str, Any],
) -> dict[str, RailPlotHolder]:
"""Make all the plots given the data
Parameters
----------
prefix: str
Prefix to append to plot names, e.g., the p(z) algorithm or
analysis 'flavor'
kwargs: dict[str, Any]
Used to pass the data to make the plots
Returns
-------
out_dict: dict[str, RailPlotHolder]
Dictionary of the newly created figures
"""
self._validate_inputs(**kwargs)
return self._make_plots(prefix, **kwargs)
def _make_full_plot_name(self, prefix: str, plot_name: str) -> str:
"""Create the make for a specific plot
Parameters
----------
prefix: str
Prefix to append to plot names, e.g., the p(z) algorithm or
analysis 'flavor'
plot_name: str
Specific name for a particular plot
Returns
-------
plot_name: str
Plot name, following the pattern f"{prefix}{self._name}{plot_name}"
"""
return f"{prefix}{self.config.name}{plot_name}"
[docs]
def to_yaml_dict(self) -> dict[str, dict[str, Any]]:
"""Create a yaml-convertable dict for this object"""
yaml_dict = Configurable.to_yaml_dict(self)
yaml_dict[self.yaml_tag].update(class_name=f"{self.full_class_name()}")
return yaml_dict
@classmethod
def _validate_inputs(cls, **kwargs: Any) -> None:
cls.input_type.validate_inputs(**kwargs)
def _make_plots(
self,
prefix: str,
**kwargs: Any,
) -> dict[str, RailPlotHolder]:
raise NotImplementedError()
[docs]
class RailPlotterList(Configurable):
"""The class collects a set of plotter that can all run on the same data.
E.g., plotters that can all run on a dict that looks like
`{truth:np.ndarray, pointEstimates: np.ndarray}` could be put into a PlotterList.
This make it easier to collect similar types of plots.
"""
config_options: dict[str, StageParameter] = dict(
name=StageParameter(str, None, fmt="%s", required=True, msg="PlotterList name"),
dataset_holder_class=StageParameter(
str,
None,
fmt="%s",
required=True,
msg="Dataset holder that provides datset types expected by plotter on the list",
),
plotters=StageParameter(
list,
[],
fmt="%s",
msg="List of plotter to include",
),
)
yaml_tag = "PlotterList"
def __init__(self, **kwargs: Any):
"""C'tor
Parameters
----------
kwargs: Any
Configuration parameters for this RailPlotterListHolder, must match
class.config_options data members
"""
Configurable.__init__(self, **kwargs)
def __repr__(self) -> str:
return f"{self.config.plotters}"
[docs]
def resolve(self, plotter_factory: RailPlotterFactory) -> list[RailPlotter]:
"""Extract the plotters
Paramters
---------
plotter_factory:
Factory used to make the plotters.
Returns
-------
list[RailPlotter]
Requested plotters.
Notes
-----
This will enforce that each plotter expects the compatible dataset_types
"""
the_list: list[RailPlotter] = []
dataset_holder_class = RailDatasetHolder.load_sub_class(
self.config.dataset_holder_class
)
dataset_class = dataset_holder_class.output_type
for name_ in self.config.plotters:
a_plotter = plotter_factory.get_plotter(name_)
if not issubclass(dataset_class, a_plotter.input_type): # pragma: no cover
raise TypeError(
f"PlotterList dataset_class {dataset_class} is "
f"not a subclass of Plotter.input_type {a_plotter.input_type}."
)
the_list.append(a_plotter)
return the_list