Source code for rail.projects.subsampler

from __future__ import annotations

import os
from typing import Any

import numpy as np
import pyarrow.dataset as ds
import pyarrow.parquet as pq
from ceci.config import StageParameter

from .configurable import Configurable
from .dynamic_class import DynamicClass


[docs] class RailSubsampler(Configurable, DynamicClass): """Base class for subsampling ata The main function in this class is: run(...) This function will take the input files and make a single output file """ config_options: dict[str, StageParameter] = {} sub_classes: dict[str, type[DynamicClass]] = {} def __init__(self, **kwargs: Any) -> None: """C'tor Parameters ---------- **kwargs: Configuration parameters for this plotter, must match class.config_options data members """ DynamicClass.__init__(self) Configurable.__init__(self, **kwargs)
[docs] def run( self, input_files: list[str], output: str, ) -> None: """Subsample the data Parameters ---------- input_files: list[str] Input files to subsamle output: str Path to the output file """ raise NotImplementedError()
[docs] class RandomSubsampler(RailSubsampler): """Pick a random subsample of the data""" config_options: dict[str, StageParameter] = dict( name=StageParameter(str, None, fmt="%s", required=True, msg="Subsampler Name"), seed=StageParameter(int, 1234, fmt="%i", msg="Random number seed"), num_objects=StageParameter( int, None, fmt="%i", required=True, msg="Number of output objects" ), )
[docs] def run( self, input_files: list[str], output: str, ) -> None: dataset = ds.dataset(input_files) num_rows = dataset.count_rows() print("num rows", num_rows) rng = np.random.default_rng(self.config.seed) print("sampling", self.config.num_objects) size = min(self.config.num_objects, num_rows) indices = rng.choice(num_rows, size=size, replace=False) subset = dataset.take(indices) print("writing", output) output_dir = os.path.dirname(output) os.makedirs(output_dir, exist_ok=True) pq.write_table( subset, output, ) print("done")