from __future__ import annotations
from typing import Any
from ceci.config import StageParameter
from rail.projects import RailProject, path_funcs
from .dataset import RailDataset
from .dataset_factory import RailDatasetFactory
from .data_extraction_funcs import get_ztrue_and_magntidues
from .dataset_holder import (
DatasetSplitMode,
RailDatasetHolder,
RailDatasetListHolder,
RailProjectHolder,
)
from .cat_plotters import RailCatTruthAndMagnitudesDataset
# (stray "[docs]" Sphinx viewcode link marker removed; it is an HTML-extraction artifact, not code)
class RailCatTruthAndMagntiduesDataHolder(RailDatasetHolder):
    """Holder that extracts true redshifts and observed magnitudes for a catalog.

    The catalog is identified inside a RailProject by a selection, a flavor
    and a file tag.  The extraction itself is delegated to
    ``get_ztrue_and_magntidues``; the result is described as a dict holding

    * the true redshifts (``np.ndarray``), and
    * the magnitudes in the various filters (``np.ndarray``).

    NOTE(review): the exact key names of that dict are defined by
    ``get_ztrue_and_magntidues`` / ``RailCatTruthAndMagnitudesDataset``, which
    are not shown in this module -- confirm them there.

    The misspelling in the class name ("Magntidues") is kept on purpose:
    renaming the class would break existing callers and configuration files.
    """

    # Configuration parameters; every one of them is required.
    config_options: dict[str, StageParameter] = dict(
        name=StageParameter(str, None, fmt="%s", required=True, msg="Dataset name"),
        project=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject name"
        ),
        selection=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject data selection"
        ),
        flavor=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject analysis flavor"
        ),
        tag=StageParameter(
            str, None, fmt="%s", required=True, msg="RailProject file tag"
        ),
    )

    # Names and types of the inputs associated with the extraction function.
    extractor_inputs: dict = {
        "project": RailProject,
        "selection": str,
        "flavor": str,
        "tag": str,
    }

    # Dataset type associated with this holder.
    output_type: type[RailDataset] = RailCatTruthAndMagnitudesDataset

    def __init__(self, **kwargs: Any):
        """Build the holder; ``kwargs`` are forwarded to ``RailDatasetHolder``."""
        super().__init__(**kwargs)
        self._project: RailProject | None = None

    def __repr__(self) -> str:
        """Return ``"<ClassName> ( <project>, <selection>_<flavor>_<tag>)"``."""
        cfg = self.config
        return (
            f"{self.__class__.__name__} "
            "( "
            f"{cfg.project}, "
            f"{cfg.selection}_{cfg.flavor}_{cfg.tag}"
            ")"
        )

    def _get_data(self, **kwargs: Any) -> dict[str, Any] | None:
        """Extract the data by forwarding ``kwargs`` to ``get_ztrue_and_magntidues``."""
        return get_ztrue_and_magntidues(**kwargs)

    @classmethod
    def generate_dataset_dict(
        cls,
        **kwargs: Any,
    ) -> tuple[
        list[RailProjectHolder], list[RailDatasetHolder], list[RailDatasetListHolder]
    ]:
        """Build the holders for every requested (flavor, selection) catalog.

        Parameters
        ----------
        **kwargs
            The recognised keywords are:

            dataset_list_name: str, optional
                Name for the resulting DatasetList(s); defaults to
                ``"<project name>_catalog"`` when missing or empty

            project_file: str
                Config file for project to inspect (required)

            selections: list[str], optional
                Selections to use; ``None`` or a list containing ``"all"``
                means every selection of the project

            flavors: list[str], optional
                Flavors to use; ``None`` or a list containing ``"all"``
                means every flavor of the project

            split_mode: DatasetSplitMode, optional
                How datasets are grouped into dataset lists.  With
                ``DatasetSplitMode.no_split`` all datasets go into one list
                named ``dataset_list_name``; with any other value (default
                ``DatasetSplitMode.by_algo``) one list named
                ``"<dataset_list_name>_<selection>_<flavor>"`` is built per
                (selection, flavor) pair

            tag: str, optional
                File tag, defaults to ``"test"``

        Returns
        -------
        list[RailProjectHolder]
            Underlying RailProjects

        list[RailDatasetHolder]
            Extracted datasets

        list[RailDatasetListHolder]
            Extracted dataset lists (empty lists are left out)

        Raises
        ------
        KeyError
            If ``project_file`` is not given
        """
        dataset_list_name: str | None = kwargs.get("dataset_list_name")
        project_file = kwargs["project_file"]
        project = RailProject.load_config(project_file)

        selections = kwargs.get("selections")
        flavors = kwargs.get("flavors")
        split_mode = kwargs.get("split_mode", DatasetSplitMode.by_algo)
        # The tag does not change inside the loops below, so read it only once.
        tag = kwargs.get("tag", "test")

        flavor_dict = project.get_flavors()
        if flavors is None or "all" in flavors:
            flavors = list(flavor_dict.keys())
        if selections is None or "all" in selections:
            selections = list(project.get_selections().keys())

        project_name = project.name
        if not dataset_list_name:
            dataset_list_name = f"{project_name}_catalog"

        projects: list[RailProjectHolder] = [
            RailProjectHolder(
                name=project_name,
                yaml_file=project_file,
            )
        ]
        datasets: list[RailDatasetHolder] = []
        dataset_lists: list[RailDatasetListHolder] = []

        no_split = split_mode == DatasetSplitMode.no_split
        dataset_list_dict: dict[str, list[str]] = {}
        if no_split:  # pragma: no cover
            # Single list that collects every dataset.
            dataset_list_dict[dataset_list_name] = []

        for key in flavors:
            for selection_ in selections:
                if no_split:  # pragma: no cover
                    # Keep filling the combined list.  Previously the key was
                    # rebound unconditionally, which left the combined list
                    # empty and made ``no_split`` behave like a split mode.
                    dataset_key = dataset_list_name
                else:
                    dataset_key = f"{dataset_list_name}_{selection_}_{key}"
                    dataset_list_dict[dataset_key] = []

                path = path_funcs.get_z_true_path(
                    project,
                    selection=selection_,
                    flavor=key,
                    tag=tag,
                )
                if path is None:  # pragma: no cover
                    # No path for this combination: nothing to register.
                    continue

                dataset_name = f"{selection_}_{key}"
                datasets.append(
                    cls(
                        name=dataset_name,
                        project=project_name,
                        flavor=key,
                        tag=tag,
                        selection=selection_,
                    )
                )
                dataset_list_dict[dataset_key].append(dataset_name)

        for ds_name, ds_list in dataset_list_dict.items():
            # Skip empty lists
            if not ds_list:  # pragma: no cover
                continue
            dataset_lists.append(
                RailDatasetListHolder(
                    name=ds_name,
                    dataset_class=cls.output_type.full_class_name(),
                    datasets=ds_list,
                )
            )

        return (projects, datasets, dataset_lists)