Source code for rail.plotting.dataset_factory

from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypeVar

from rail.projects import RailProject
from rail.projects.factory_mixin import RailFactoryMixin

from .dataset_holder import RailDatasetHolder, RailDatasetListHolder, RailProjectHolder

if TYPE_CHECKING:
    from rail.projects.configurable import Configurable

    C = TypeVar("C", bound="Configurable")


class RailDatasetFactory(RailFactoryMixin):
    """Factory class to make datasets

    Expected usage is that user will define a yaml file with the various
    datasets that they wish to use with the following example syntax:

    .. highlight:: yaml
    .. code-block:: yaml

      Data:
        - Project:
            name: some_project
            yaml_file: /path/to/rail_project_file
        - Dataset:
            name: gold_baseline_test
            class: rail.plotting.project_dataset_holder.RailProjectDatasetHolder
            extractor: rail.plotting.pz_data_extractor.PZPointEstimateDataExtractor
            project: some_project
            selection: gold
            flavor: baseline
            tag: test
            algos: ['all']
        - Dataset:
            name: blend_baseline_test
            class: rail.plotting.project_dataset_holder.RailProjectDatasetHolder
            extractor: rail.plotting.pz_data_extractor.PZPointEstimateDataExtractor
            project: some_project
            selection: blend
            flavor: baseline
            tag: test
            algos: ['all']

    And group them into lists of datasets that can be run over particular
    types of data, using the following example syntax:

    .. highlight:: yaml
    .. code-block:: yaml

      Data:
        - DatasetList:
            name: baseline_test
            datasets:
              - gold_baseline_test
              - blend_baseline_test
    """

    # Top-level yaml key under which this factory's entries are listed
    yaml_tag: str = "Data"

    # Configurable classes whose yaml entries this factory can build
    client_classes = [RailProjectHolder, RailDatasetHolder, RailDatasetListHolder]

    # Singleton slot; presumably populated by RailFactoryMixin.instance() — TODO confirm
    _instance: RailDatasetFactory | None = None

    def __init__(self) -> None:
        """C'tor, build an empty RailDatasetFactory"""
        RailFactoryMixin.__init__(self)
        # One registry per client class; the getters below look entries up by name
        self._projects = self.add_dict(RailProjectHolder)
        self._datasets = self.add_dict(RailDatasetHolder)
        self._dataset_lists = self.add_dict(RailDatasetListHolder)

    @classmethod
    def get_projects(cls) -> dict[str, RailProjectHolder]:
        """Return the dict of all the projects

        The values are the ``RailProjectHolder`` objects registered with the
        factory, keyed by name.
        """
        return cls.instance().projects

    @classmethod
    def get_project_names(cls) -> list[str]:
        """Return the names of all the projects"""
        return list(cls.instance().projects)

    @classmethod
    def get_project(cls, key: str) -> RailProjectHolder:
        """Return a project holder by name

        Parameters
        ----------
        key: str
            Name of the project to return

        Returns
        -------
        project: RailProjectHolder
            Project holder in question

        Raises
        ------
        KeyError
            If no project with that name is registered; the message lists
            the registered names
        """
        projects = cls.instance().projects
        try:
            return projects[key]
        except KeyError as msg:
            raise KeyError(
                f"Project named {key} not found in RailDatasetFactory "
                f"{list(projects)}"
            ) from msg

    @classmethod
    def get_datasets(cls) -> dict[str, RailDatasetHolder]:
        """Return the dict of all the datasets"""
        return cls.instance().datasets

    @classmethod
    def get_dataset_names(cls) -> list[str]:
        """Return the names of the datasets"""
        return list(cls.instance().datasets)

    @classmethod
    def get_dataset_lists(cls) -> dict[str, RailDatasetListHolder]:
        """Return the dict of all the dataset lists"""
        return cls.instance().dataset_lists

    @classmethod
    def get_dataset_list_names(cls) -> list[str]:
        """Return the names of the dataset lists"""
        return list(cls.instance().dataset_lists)

    @classmethod
    def get_dataset(cls, name: str) -> RailDatasetHolder:
        """Get a dataset by its assigned name

        Parameters
        ----------
        name: str
            Name of the dataset to return

        Returns
        -------
        dataset: RailDatasetHolder
            Dataset in question

        Raises
        ------
        KeyError
            If no dataset with that name is registered; the message lists
            the registered names
        """
        datasets = cls.instance().datasets
        try:
            return datasets[name]
        except KeyError as msg:
            raise KeyError(
                f"Dataset named {name} not found in RailDatasetFactory "
                f"{list(datasets)}"
            ) from msg

    @classmethod
    def get_dataset_list(cls, name: str) -> RailDatasetListHolder:
        """Get a list of datasets by its assigned name

        Parameters
        ----------
        name: str
            Name of the dataset list to return

        Returns
        -------
        datasets: RailDatasetListHolder
            Dataset list in question

        Raises
        ------
        KeyError
            If no dataset list with that name is registered; the message
            lists the registered names
        """
        dataset_lists = cls.instance().dataset_lists
        try:
            return dataset_lists[name]
        except KeyError as msg:
            raise KeyError(
                f"DatasetList named {name} not found in RailDatasetFactory "
                f"{list(dataset_lists)}"
            ) from msg

    @classmethod
    def add_project(cls, project_holder: RailProjectHolder) -> None:
        """Add a particular RailProjectHolder to the factory"""
        cls.instance().add_to_dict(project_holder)

    @classmethod
    def add_dataset(cls, dataset_holder: RailDatasetHolder) -> None:
        """Add a particular RailDatasetHolder to the factory"""
        cls.instance().add_to_dict(dataset_holder)

    @classmethod
    def add_dataset_list(cls, dataset_list: RailDatasetListHolder) -> None:
        """Add a particular RailDatasetListHolder to the factory"""
        cls.instance().add_to_dict(dataset_list)

    @property
    def projects(self) -> dict[str, RailProjectHolder]:
        """Return the dictionary of RailProjectHolders"""
        return self._projects

    @property
    def datasets(self) -> dict[str, RailDatasetHolder]:
        """Return the dictionary of individual datasets"""
        return self._datasets

    @property
    def dataset_lists(self) -> dict[str, RailDatasetListHolder]:
        """Return the dictionary of lists of datasets"""
        return self._dataset_lists

    def load_object_from_yaml_tag(
        self, configurable_class: type[C], yaml_tag: dict[str, Any]
    ) -> None:
        """Build one object from its yaml entry and register it

        ``RailDatasetHolder`` entries are created from the dict, added to
        this factory, and only then resolved. All other classes are
        delegated to ``RailFactoryMixin.load_object_from_yaml_tag``.

        Parameters
        ----------
        configurable_class: type[C]
            Class to build, one of ``client_classes``
        yaml_tag: dict[str, Any]
            Configuration dict read from the yaml entry
        """
        if configurable_class is RailDatasetHolder:
            the_object = RailDatasetHolder.create_from_dict(yaml_tag)
            # Registered before resolve(); presumably resolve() may consult the
            # factory's registries — NOTE(review): confirm the ordering matters
            self.add_to_dict(the_object)
            the_object.resolve()
            return
        RailFactoryMixin.load_object_from_yaml_tag(self, configurable_class, yaml_tag)