"""Source code for rail.plotting.plot_group_factory: factory for RailPlotGroup objects."""

from __future__ import annotations

import re
from typing import Any

import yaml

from rail.projects.factory_mixin import RailFactoryMixin
from rail.projects.configurable import Configurable

from .dataset_factory import RailDatasetFactory
from .dataset_holder import RailDatasetHolder, RailDatasetListHolder, RailProjectHolder
from .plot_group import RailPlotGroup
from .plotter import RailPlotterList
from .plotter_factory import RailPlotterFactory


class RailPlotGroupFactory(RailFactoryMixin):
    """Factory class to make plot_groups

    The yaml file should look something like this:

    .. highlight:: yaml
    .. code-block:: yaml

      Includes:
        - <path_to_yaml_file_defining_plotter_lists>
        - <path_to_yaml_file_defining_dataset_lists>

      PlotGroups:
        - PlotGroup:
            name: some_name
            plotter_list_name: nice_plots
            dataset_list_name: nice_data
        - PlotGroup:
            name: some_other_name
            plotter_list_name: janky_plots
            dataset_list_name: janky_data
    """

    yaml_tag: str = "PlotGroups"

    client_classes = [RailPlotGroup]

    _instance: RailPlotGroupFactory | None = None

    def __init__(self) -> None:
        """C'tor, build an empty RailPlotGroupFactory"""
        RailFactoryMixin.__init__(self)
        # Dictionary of RailPlotGroup objects, keyed by name
        self._plot_groups = self.add_dict(RailPlotGroup)

    @staticmethod
    def _to_config_dir_path(path: str) -> str:
        """Rewrite ``path`` so it refers to ``${RAIL_PROJECT_CONFIG_DIR}``

        Everything up to and including the last ``rail_project_config`` in
        ``path`` (the pattern is greedy) is replaced by the literal string
        ``${RAIL_PROJECT_CONFIG_DIR}``.  Paths without ``rail_project_config``
        are returned unchanged.
        """
        return re.sub(r".*rail_project_config", "${RAIL_PROJECT_CONFIG_DIR}", path)

    @classmethod
    def make_plot_groups(
        cls,
        plotter_list: RailPlotterList,
        **kwargs: Any,
    ) -> dict:
        """Extract the datasets from a project and construct the appropriate
        PlotGroups, using the shared factory instance

        Parameters
        ----------
        plotter_list:
            The plotter list the PlotGroups will use

        **kwargs:
            Passed to the generate_dataset_dict() functions of the holder
            classes; ``outdir`` and ``figtype`` are also read to configure
            the PlotGroups

        Returns
        -------
        dict:
            Projects: the extracted RailProjects
            Datasets: the extracted DatasetHolders
            DatasetLists: the extracted dataset lists
            PlotGroups: the constructed PlotGroups
        """
        if cls._instance is None:
            cls._instance = RailPlotGroupFactory()
        return cls._instance.make_plot_groups_instance(
            plotter_list,
            **kwargs,
        )

    @classmethod
    def make_yaml_for_project(
        cls,
        output_yaml: str,
        plotter_yaml_path: str,
        project_yaml_path: str,
        **kwargs: Any,
    ) -> None:
        """Construct a yaml file defining plot groups for a project

        Parameters
        ----------
        output_yaml:
            Path to the output file

        plotter_yaml_path:
            Path to the yaml file defining the plotter_lists

        project_yaml_path:
            Path to the underlying project file

        **kwargs:
            Forwarded to make_yaml_for_project_instance(); ``outdir`` and
            ``figtype`` configure the PlotGroups, and all kwargs are passed
            on to the dataset holders' generate_dataset_dict()
        """
        if cls._instance is None:
            cls._instance = RailPlotGroupFactory()
        cls._instance.make_yaml_for_project_instance(
            output_yaml,
            plotter_yaml_path,
            project_yaml_path,
            **kwargs,
        )

    @classmethod
    def make_yaml_for_dataset_list(
        cls,
        output_yaml: str,
        plotter_yaml_path: str,
        dataset_yaml_path: str,
        plotter_list_name: str,
        output_prefix: str = "",
        dataset_list_name: list[str] | None = None,
    ) -> None:
        """Construct a yaml file defining plot groups for dataset lists

        Parameters
        ----------
        output_yaml:
            Path to the output file

        plotter_yaml_path:
            Path to the yaml file defining the plotter_lists

        dataset_yaml_path:
            Path to the yaml file defining the datasets

        plotter_list_name:
            Name of plotter list to use

        output_prefix:
            Prefix for PlotGroup names we construct

        dataset_list_name:
            Names of dataset lists to use; if empty or None, every dataset
            list defined in dataset_yaml_path is used
        """
        if cls._instance is None:
            cls._instance = RailPlotGroupFactory()
        cls._instance.make_yaml_for_dataset_list_instance(
            output_yaml=output_yaml,
            plotter_yaml_path=plotter_yaml_path,
            dataset_yaml_path=dataset_yaml_path,
            plotter_list_name=plotter_list_name,
            output_prefix=output_prefix,
            dataset_list_name=dataset_list_name,
        )

    @classmethod
    def get_plot_groups(cls) -> dict[str, RailPlotGroup]:
        """Return the dict of all the RailPlotGroup, keyed by name"""
        return cls.instance().plot_groups

    @classmethod
    def get_plot_group_names(cls) -> list[str]:
        """Return the names of all the RailPlotGroup"""
        return list(cls.instance().plot_groups)

    @classmethod
    def add_plot_group(cls, plot_group: RailPlotGroup) -> None:
        """Add a particular RailPlotGroup to the factory"""
        cls.instance().add_to_dict(plot_group)

    @classmethod
    def get_plot_group(cls, key: str) -> RailPlotGroup:
        """Return a RailPlotGroup by name

        Raises
        ------
        KeyError:
            If no RailPlotGroup with that name is registered
        """
        return cls.instance().plot_groups[key]

    @property
    def plot_groups(self) -> dict[str, RailPlotGroup]:
        """Return the dictionary of RailPlotGroup, keyed by name"""
        return self._plot_groups

    def make_yaml_for_dataset_list_instance(
        self,
        output_yaml: str,
        plotter_yaml_path: str,
        dataset_yaml_path: str,
        plotter_list_name: str,
        output_prefix: str = "",
        dataset_list_name: list[str] | None = None,
    ) -> None:
        """Construct a yaml file defining plot groups for dataset lists

        Both the plotter and dataset factories are cleared and re-loaded
        from the given yaml files. One PlotGroup named
        ``{output_prefix}{dataset_list}_{plotter_list_name}`` is written for
        each selected dataset list.

        Parameters
        ----------
        output_yaml: str
            Path to the output file

        plotter_yaml_path: str
            Path to the yaml file defining the plotter_lists

        dataset_yaml_path: str
            Path to the yaml file defining the datasets

        plotter_list_name: str
            Name of plotter list to use

        output_prefix: str=""
            Prefix for PlotGroup names we construct

        dataset_list_name: list[str] | None
            Names of dataset lists to use; if empty or None, every dataset
            list known to RailDatasetFactory is used
        """
        RailPlotterFactory.clear()
        RailPlotterFactory.load_yaml(plotter_yaml_path)
        RailDatasetFactory.clear()
        RailDatasetFactory.load_yaml(dataset_yaml_path)

        # Fail early if the requested plotter list does not exist / is empty
        plotter_list = RailPlotterFactory.get_plotter_list(plotter_list_name)
        assert plotter_list

        if not dataset_list_name:  # pragma: no cover
            dataset_list_name = RailDatasetFactory.get_dataset_list_names()

        plot_groups: list[RailPlotGroup] = [
            RailPlotGroup(
                name=f"{output_prefix}{ds_name}_{plotter_list_name}",
                plotter_list_name=plotter_list_name,
                dataset_list_name=ds_name,
            )
            for ds_name in dataset_list_name
        ]

        output: dict[str, Any] = dict(
            Includes=[
                self._to_config_dir_path(plotter_yaml_path),
                self._to_config_dir_path(dataset_yaml_path),
            ],
            PlotGroups=[plot_group_.to_yaml_dict() for plot_group_ in plot_groups],
        )
        with open(output_yaml, "w", encoding="utf-8") as fout:
            yaml.dump(output, fout)

    def make_yaml_for_project_instance(
        self,
        output_yaml: str,
        plotter_yaml_path: str,
        project_yaml_path: str,
        **kwargs: Any,
    ) -> None:
        """Construct a yaml file defining plot groups for a project

        For every plotter list defined in plotter_yaml_path, datasets are
        extracted from the project and one PlotGroup is built per dataset
        list. The projects, datasets, dataset lists and plot groups from all
        plotter lists are merged and written to output_yaml.

        Parameters
        ----------
        output_yaml:
            Path to the output file

        plotter_yaml_path:
            Path to the yaml file defining the plotter_lists

        project_yaml_path:
            Path to the underlying project file

        **kwargs:
            Passed to make_plot_groups_instance(); ``outdir`` (default ".")
            and ``figtype`` (default "png") configure the PlotGroups, and
            all kwargs are forwarded to generate_dataset_dict()
        """
        RailPlotterFactory.load_yaml(plotter_yaml_path)
        plotter_lists = RailPlotterFactory.get_plotter_lists()

        projects: list[list[RailProjectHolder]] = []
        datasets: list[list[RailDatasetHolder]] = []
        dataset_lists: list[list[RailDatasetListHolder]] = []
        plot_groups: list[list[RailPlotGroup]] = []

        for plotter_list in plotter_lists.values():
            plot_group_stuff = self.make_plot_groups_instance(
                plotter_list,
                project_file=project_yaml_path,
                **kwargs,
            )
            projects.append(plot_group_stuff["Projects"])
            datasets.append(plot_group_stuff["Datasets"])
            dataset_lists.append(plot_group_stuff["DatasetLists"])
            # make_plot_groups_instance already builds one RailPlotGroup per
            # dataset list, with outdir / figtype taken from kwargs; reuse
            # those rather than re-building them here.
            plot_groups.append(plot_group_stuff["PlotGroups"])

        merged_projects = Configurable.merge_named_lists(projects)
        merged_datasets = Configurable.merge_named_lists(datasets)
        merged_dataset_lists = Configurable.merge_named_lists(dataset_lists)
        merged_plot_groups = Configurable.merge_named_lists(plot_groups)

        # Order matters for readability of the output: projects, then
        # datasets, then dataset lists
        data_yaml_list: list[dict[str, Any]] = [
            item_.to_yaml_dict()
            for item_ in [*merged_projects, *merged_datasets, *merged_dataset_lists]
        ]
        plot_group_yaml_list: list[dict[str, Any]] = [
            plot_group_.to_yaml_dict() for plot_group_ in merged_plot_groups
        ]

        output_yaml_dict: dict[str, list] = dict(
            Includes=[self._to_config_dir_path(plotter_yaml_path)],
            Data=data_yaml_list,
            PlotGroups=plot_group_yaml_list,
        )
        with open(output_yaml, "w", encoding="utf-8") as fout:
            yaml.dump(output_yaml_dict, fout)

    def make_plot_groups_instance(
        self,
        plotter_list: RailPlotterList,
        **kwargs: Any,
    ) -> dict:
        """Extract the datasets from a project and construct the appropriate
        PlotGroups

        The dataset holder class named by the plotter list's config is used
        to generate the datasets. One PlotGroup named
        ``{plotter_list}_{dataset_list}`` is built per dataset list and
        registered with this factory.

        Parameters
        ----------
        plotter_list:
            The plotter list the PlotGroups will use

        **kwargs:
            Passed to the generate_dataset_dict() functions of the holder
            classes; ``outdir`` (default ".") and ``figtype`` (default "png")
            are also read to configure the PlotGroups

        Returns
        -------
        dict:
            Projects: the extracted RailProjects
            Datasets: the extracted DatasetHolders
            DatasetLists: the extracted dataset lists
            PlotGroups: the constructed PlotGroups

            All four are empty lists if the holder class does not implement
            generate_dataset_dict().
        """
        dataset_holder_class = RailDatasetHolder.load_sub_class(
            plotter_list.config.dataset_holder_class
        )
        try:
            (
                projects,
                datasets,
                dataset_lists,
            ) = dataset_holder_class.generate_dataset_dict(**kwargs)
        except NotImplementedError:
            return dict(
                Projects=[],
                Datasets=[],
                DatasetLists=[],
                PlotGroups=[],
            )

        output_data = dict(
            Projects=projects,
            Datasets=datasets,
            DatasetLists=dataset_lists,
        )
        plot_groups_list: list[RailPlotGroup] = []
        outdir = kwargs.get("outdir", ".")
        figtype = kwargs.get("figtype", "png")

        for dslist_ in dataset_lists:
            dataset_list_name = dslist_.config.name
            plot_group = RailPlotGroup(
                name=f"{plotter_list.config.name}_{dataset_list_name}",
                plotter_list_name=plotter_list.config.name,
                dataset_list_name=dataset_list_name,
                outdir=outdir,
                figtype=figtype,
            )
            # Best effort: a name clash with an already-registered group is
            # reported but does not stop the group being returned
            try:
                self.add_plot_group(plot_group)
            except KeyError as msg:
                print(msg)
            plot_groups_list.append(plot_group)

        output_data.update(
            PlotGroups=plot_groups_list,
        )
        return output_data