Source code for rail.plotting.plot_group

"""Class to get track of groups of plots to make"""

from __future__ import annotations

import os
from typing import Any

from ceci.config import StageParameter
from jinja2 import Environment, FileSystemLoader, Template

from rail.projects.configurable import Configurable

from .dataset import RailDataset
from .dataset_factory import RailDatasetFactory
from .dataset_holder import RailDatasetHolder
from .plot_holder import RailPlotDict, RailPlotHolder
from .plotter import RailPlotter
from .plotter_factory import RailPlotterFactory

HTML_TEMPLATE_DIR = os.path.join(
    os.path.abspath(os.path.dirname(__file__)), "html_templates"
)


[docs] class RailPlotGroup(Configurable): """Class defining of a group on plots to make with a particular list of coherent datasets """ jinja_env: Environment | None = None jinja_template: Template | None = None jinja_index_template: Template | None = None config_options: dict[str, StageParameter] = dict( name=StageParameter( str, None, fmt="%s", required=True, msg="PlotGroupName name" ), plotter_list_name=StageParameter( str, None, fmt="%s", required=True, msg="PlotterList name" ), dataset_list_name=StageParameter( str, None, fmt="%s", required=True, msg="DatasetList name" ), outdir=StageParameter(str, ".", fmt="%s", msg="Output directory"), figtype=StageParameter(str, "png", fmt="%s", msg="Plot type"), ) yaml_tag = "PlotGroup" @classmethod def _load_jinja(cls) -> None: if cls.jinja_template is not None: # pragma: no cover return cls.jinja_env = Environment(loader=FileSystemLoader(HTML_TEMPLATE_DIR)) cls.jinja_template = cls.jinja_env.get_template("plot_group_table.html") cls.jinja_index_template = cls.jinja_env.get_template("plot_group_index.html") def __init__(self, **kwargs: Any) -> None: Configurable.__init__(self, **kwargs) self._plots: dict[str, RailPlotDict] = {} self._plotter_list: list[RailPlotter] = [] self._dataset_list: list[RailDatasetHolder] = [] def __repr__(self) -> str: return f"PlotGroup: {self.config.plotter_list_name}, DatasetList: {self.config.dataset_list_name}" @property def plotter_list(self) -> list[RailPlotter]: return self._plotter_list @property def dataset_list(self) -> list[RailDatasetHolder]: return self._dataset_list
[docs] def find_plot(self, dataset_name: str, plotter_name: str) -> RailPlotHolder: """Find a particular plot""" try: sub_dict = self._plots[dataset_name] except KeyError as msg: raise KeyError( f"Failed to find {dataset_name} in {list(self._plots.keys())}" ) from msg assert sub_dict.plots try: return sub_dict.plots[plotter_name] except KeyError as msg: raise KeyError( f"Failed to find {plotter_name} in {list(sub_dict.plots.keys())}" ) from msg
[docs] def find_plot_path(self, dataset_name: str, plotter_name: str) -> str | None: """Find a particular plot and get the path to the associated file""" return self.find_plot(dataset_name, plotter_name).path
[docs] def resolve(self) -> None: """Get the PlotterList and DatasetList""" plotter_factory = RailPlotterFactory.instance() dataset_factory = RailDatasetFactory.instance() plotter_list_wrapper = plotter_factory.get_plotter_list( self.config.plotter_list_name ) dataset_list_wrapper = dataset_factory.get_dataset_list( self.config.dataset_list_name ) plotter_data_holder_class = RailDatasetHolder.load_sub_class( plotter_list_wrapper.config.dataset_holder_class ) plotter_input_class = plotter_data_holder_class.output_type dataset_output_class = RailDataset.load_sub_class( dataset_list_wrapper.config.dataset_class ) if not issubclass( dataset_output_class, plotter_input_class ): # pragma: no cover raise TypeError( f"RailPlotGroup {self.config.name} dataset output class {dataset_output_class} " f"is not a subclass of plotter input class {plotter_input_class}" ) self._plotter_list = plotter_list_wrapper.resolve(plotter_factory) self._dataset_list = dataset_list_wrapper.resolve(dataset_factory)
[docs] def make_plots( self, ) -> dict[str, RailPlotDict]: """Make a set of plots Returns ------- out_dict: dict[str, RailPlotDict] Dictionary of the newly created figures """ self.resolve() self._plots.update( **RailPlotter.iterate(self._plotter_list, self._dataset_list) ) return self._plots
[docs] def find_plots( self, outdir: str, ) -> dict[str, RailPlotDict]: """Find a set of plots Parameters ---------- outdir: str Prepend this to the groups output dir Returns ------- out_dict: dict[str, Figure] Dictionary of the newly created figures """ self.resolve() self._plots.update( **RailPlotter.iterate( self._plotter_list, self._dataset_list, find_only=True, outdir=outdir, figtype=self.config.figtype, ) ) return self._plots
[docs] @classmethod def make_html_index( cls, outfile: str, output_pages: list[str], ) -> None: """Make the html index page for a list of other pages Parameters ---------- outfile: str Html file to write output_pages: list[str] Set of pages to include in the index """ cls._load_jinja() assert cls.jinja_index_template is not None # Render template data and save to HTML file output = cls.jinja_index_template.render(output_pages=output_pages, os=os) with open(outfile, "w", encoding="utf-8") as file: file.write(output)
[docs] def make_html( self, outfile: str, ) -> None: """Make the html page to display the plots made by this plot group Parameters ---------- outfile: str Html file to write """ self._load_jinja() assert self.jinja_template is not None # Render template data and save to HTML file output = self.jinja_template.render(plot_group=self, os=os) with open(outfile, "w", encoding="utf-8") as file: file.write(output)
[docs] def run( self, save_plots: bool = True, purge_plots: bool = True, find_only: bool = False, outdir: str | None = None, make_html: bool = False, output_html: str | None = None, ) -> dict[str, RailPlotDict]: """Make all the plots given the data Parameters ---------- save_plots: bool If true, save the plots to disk purge_plots: bool If true, delete the plots after saving find_only: bool If true, only look for existing plots make_html: bool If true, make an html table to browse plots outdir: str | None If set, prepend this to the groups output dir output_html: str | None Path for output html file Returns ------- out_dict: dict[str, Figure] Dictionary of the newly created figures """ if outdir is not None: output_dir = os.path.join(outdir, self.config.outdir) else: # pragma: no cover output_dir = self.config.outdir if find_only: self.find_plots( outdir=output_dir, ) else: self.make_plots() if save_plots: RailPlotter.write_plots( self._plots, output_dir, self.config.figtype, purge=purge_plots ) if make_html: if output_html is None: assert outdir output_html = os.path.join(outdir, f"plots_{self.config.name}.html") self.make_html(output_html) return self._plots