Skip to content

plaid.storage.zarr.reader

plaid.storage.zarr.reader

Zarr dataset reader module.

This module provides functionality for reading and streaming datasets stored in Zarr format for the PLAID library. It includes utilities for loading datasets from local disk or streaming directly from Hugging Face Hub, with support for selective loading of splits and features.

Key features: - Local dataset loading from disk - Streaming datasets from Hugging Face Hub - Selective loading of splits and features - ZarrDataset class for convenient data access

plaid.storage.zarr.reader.ZarrDataset

ZarrDataset(zarr_group, path)

A dataset class for accessing Zarr-stored data.

This class provides a convenient interface for accessing samples stored in Zarr format. It wraps a Zarr group and provides dictionary-like access to samples, along with additional metadata fields.

Initialize a :class:ZarrDataset.

Parameters:

  • zarr_group (Group) –

    The underlying Zarr group containing the data.

  • path (Union[str, Path]) –

    Path to the dataset root (local directory or remote identifier). Stored on the instance as self.path.

Source code in plaid/storage/zarr/reader.py
def __init__(self, zarr_group: zarr.Group, path: Union[str, Path]) -> None:
    """Initialize a :class:`ZarrDataset`.

    Args:
        zarr_group (zarr.Group): The underlying Zarr group containing the data.
        path (Union[str, Path]): Path to the dataset root (local directory or remote
            identifier). Stored on the instance as ``self.path``.
    """
    self.zarr_group = zarr_group
    self.path = path

    self.ids = sorted(int(name[7:]) for name, _ in zarr_group.groups())

plaid.storage.zarr.reader.ZarrDataset.__iter__

__iter__()

Iterate over all samples in the dataset.

Yields:

  • dict[str, Any]

    dict[str, Any]: Dictionary containing sample data.

Source code in plaid/storage/zarr/reader.py
def __iter__(self) -> Iterator[dict[str, Any]]:
    """Iterate over all samples in the dataset.

    Yields:
        dict[str, Any]: Dictionary containing sample data.
    """
    for idx in self.ids:
        yield self[idx]

plaid.storage.zarr.reader.ZarrDataset.__getitem__

__getitem__(idx)

Get a sample by index.

Parameters:

  • idx (int) –

    Sample index.

Returns:

  • dict[str, Any]

    dict[str, Any]: Dictionary containing sample data.

Source code in plaid/storage/zarr/reader.py
def __getitem__(self, idx: int) -> dict[str, Any]:
    """Get a sample by index.

    Args:
        idx: Sample index.

    Returns:
        dict[str, Any]: Dictionary containing sample data.
    """
    zarr_sample = self.zarr_group[f"sample_{idx:09d}"]
    return {
        unflatten_path(path): zarr_sample[path] for path in zarr_sample.array_keys()
    }

plaid.storage.zarr.reader.ZarrDataset.__len__

__len__()

Get the number of samples in the dataset.

Returns:

  • int ( int ) –

    Number of samples.

Source code in plaid/storage/zarr/reader.py
def __len__(self) -> int:
    """Get the number of samples in the dataset.

    Returns:
        int: Number of samples.
    """
    return len(self.ids)

plaid.storage.zarr.reader.ZarrDataset.__repr__

__repr__()

String representation of the dataset.

Returns:

  • str ( str ) –

    String representation.

Source code in plaid/storage/zarr/reader.py
def __repr__(self) -> str:
    """String representation of the dataset.

    Returns:
        str: String representation.
    """
    return f"<ZarrDataset {repr(self.zarr_group)} | ids={self.ids}>"

plaid.storage.zarr.reader.sample_generator

sample_generator(repo_id, split, ids, selected_features)

Generates samples from a Zarr dataset on Hugging Face Hub.

Parameters:

  • repo_id (str) –

    The Hugging Face repository ID.

  • split (str) –

    The dataset split name.

  • ids (Iterable[int]) –

    Iterable of sample IDs to generate.

  • selected_features (list[str]) –

    List of features to include.

Yields:

  • dict ( dict[str, Any] ) –

    Dictionary mapping feature names to Zarr arrays.

Source code in plaid/storage/zarr/reader.py
def sample_generator(
    repo_id: str, split: str, ids: Iterable[int], selected_features: list[str]
) -> Iterator[dict[str, Any]]:  # pragma: no cover
    """Generates samples from a Zarr dataset on Hugging Face Hub.

    Args:
        repo_id (str): The Hugging Face repository ID.
        split (str): The dataset split name.
        ids (Iterable[int]): Iterable of sample IDs to generate.
        selected_features (list[str]): List of features to include.

    Yields:
        dict: Dictionary mapping feature names to Zarr arrays.
    """
    base_url = (
        f"https://huggingface.co/datasets/{repo_id}/resolve/main/data/{split}/sample_"
    )
    for idx in ids:
        sample = {}
        for feat in selected_features:
            flatten_feat = flatten_path(feat)
            mapper = fsspec.get_mapper(base_url + f"{idx:09d}/{flatten_feat}")
            sample[feat] = zarr.open(mapper, mode="r")

        yield sample

plaid.storage.zarr.reader.create_zarr_iterable_dataset

create_zarr_iterable_dataset(
    repo_id, split, ids, selected_features
)

Creates an IterableDataset from Zarr data on Hugging Face Hub.

Parameters:

  • repo_id (str) –

    The Hugging Face repository ID.

  • split (str) –

    The dataset split name.

  • ids (Iterable[int]) –

    Iterable of sample IDs.

  • selected_features (list[str]) –

    List of features to include.

Returns:

  • IterableDataset ( IterableDataset ) –

    An iterable dataset for streaming access.

Source code in plaid/storage/zarr/reader.py
def create_zarr_iterable_dataset(
    repo_id: str, split: str, ids: Iterable[int], selected_features: list[str]
) -> IterableDataset:  # pragma: no cover
    """Creates an IterableDataset from Zarr data on Hugging Face Hub.

    Args:
        repo_id (str): The Hugging Face repository ID.
        split (str): The dataset split name.
        ids (Iterable[int]): Iterable of sample IDs.
        selected_features (list[str]): List of features to include.

    Returns:
        IterableDataset: An iterable dataset for streaming access.
    """

    def wrapped_gen():
        yield from sample_generator(repo_id, split, ids, selected_features)

    return IterableDataset.from_generator(
        wrapped_gen,
        split=NamedSplit(split),
        features=None,
    )

plaid.storage.zarr.reader.init_datasetdict_from_disk

init_datasetdict_from_disk(path)

Initializes dataset dictionaries from local Zarr files.

Parameters:

  • path (Union[str, Path]) –

    Path to the local directory containing the dataset.

Returns:

  • ZarrDatasetDict ( ZarrDatasetDict ) –

    Dictionary mapping split names to ZarrDataset objects.

Source code in plaid/storage/zarr/reader.py
def init_datasetdict_from_disk(
    path: Union[str, Path],
) -> ZarrDatasetDict:
    """Initializes dataset dictionaries from local Zarr files.

    Args:
        path (Union[str, Path]): Path to the local directory containing the dataset.

    Returns:
        ZarrDatasetDict: Dictionary mapping split names to ZarrDataset objects.
    """
    local_path = Path(path) / "data"
    split_names = [p.name for p in local_path.iterdir() if p.is_dir()]
    return {
        sn: ZarrDataset(
            zarr.open(zarr.storage.LocalStore(local_path / sn), mode="r"), path
        )
        for sn in split_names
    }

plaid.storage.zarr.reader.download_datasetdict_from_hub

download_datasetdict_from_hub(
    repo_id,
    local_dir,
    split_ids=None,
    features=None,
    overwrite=False,
)

Downloads dataset from Hugging Face Hub to local directory.

Parameters:

  • repo_id (str) –

    The Hugging Face repository ID.

  • local_dir (Union[str, Path]) –

    Local directory to download to.

  • split_ids (Optional[dict[str, Iterable[int]]], default: None ) –

    Optional split IDs for selective download.

  • features (Optional[list[str]], default: None ) –

    Optional features for selective download.

  • overwrite (bool, default: False ) –

    Whether to overwrite existing directory.

Returns:

  • str ( str ) –

    Path to the local directory where the dataset has been downloaded.

Source code in plaid/storage/zarr/reader.py
def download_datasetdict_from_hub(
    repo_id: str,
    local_dir: Union[str, Path],
    split_ids: Optional[dict[str, Iterable[int]]] = None,
    features: Optional[list[str]] = None,
    overwrite: bool = False,
) -> str:  # pragma: no cover
    """Downloads dataset from Hugging Face Hub to local directory.

    Args:
        repo_id (str): The Hugging Face repository ID.
        local_dir (Union[str, Path]): Local directory to download to.
        split_ids (Optional[dict[str, Iterable[int]]]): Optional split IDs for selective download.
        features (Optional[list[str]]): Optional features for selective download.
        overwrite (bool): Whether to overwrite existing directory.

    Returns:
        str: Path to the local directory where the dataset has been downloaded.
    """
    output_folder = Path(local_dir)

    if output_folder.is_dir():
        if overwrite:
            shutil.rmtree(local_dir)
            logger.warning(f"Existing {local_dir} directory has been reset.")
        elif any(local_dir.iterdir()):
            raise ValueError(
                f"directory {local_dir} already exists and is not empty. Set `overwrite` to True if needed."
            )

    allow_patterns, ignore_patterns = _zarr_patterns(repo_id, split_ids, features)

    return snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
        local_dir=local_dir,
    )

plaid.storage.zarr.reader.init_datasetdict_streaming_from_hub

init_datasetdict_streaming_from_hub(
    repo_id, split_ids=None, features=None
)

Initializes streaming dataset dictionaries from Hugging Face Hub.

This function creates iterable datasets that stream Zarr data directly from the Hugging Face Hub without downloading files locally. It supports selective loading of specific splits and features for memory-efficient data access. Note that streaming mode is not compatible with private Hugging Face mirrors.

Parameters:

  • repo_id (str) –

    The Hugging Face repository ID (e.g., "username/dataset_name").

  • split_ids (Optional[dict[str, Iterable[int]]], default: None ) –

    Optional dictionary mapping split names to lists of sample IDs to include. If None, all samples from all splits are included.

  • features (Optional[list[str]], default: None ) –

    Optional list of feature names to include. If None, all features from the variable schema are included.

Returns:

  • dict[str, IterableDataset]

    dict[str, IterableDataset]: Dictionary mapping split names to IterableDataset objects for streaming data access.

Source code in plaid/storage/zarr/reader.py
def init_datasetdict_streaming_from_hub(
    repo_id: str,
    split_ids: Optional[dict[str, Iterable[int]]] = None,
    features: Optional[list[str]] = None,
) -> dict[str, IterableDataset]:  # pragma: no cover
    """Initializes streaming dataset dictionaries from Hugging Face Hub.

    This function creates iterable datasets that stream Zarr data directly from
    the Hugging Face Hub without downloading files locally. It supports selective
    loading of specific splits and features for memory-efficient data access.
    Note that streaming mode is not compatible with private Hugging Face mirrors.

    Args:
        repo_id (str): The Hugging Face repository ID (e.g., "username/dataset_name").
        split_ids (Optional[dict[str, Iterable[int]]]): Optional dictionary mapping split names
            to lists of sample IDs to include. If None, all samples from all splits
            are included.
        features (Optional[list[str]]): Optional list of feature names to include.
            If None, all features from the variable schema are included.

    Returns:
        dict[str, IterableDataset]: Dictionary mapping split names to IterableDataset
            objects for streaming data access.
    """
    hf_endpoint = os.getenv("HF_ENDPOINT", "").strip()
    if hf_endpoint:
        raise RuntimeError("Streaming mode not compatible with private mirror.")

    if features is not None:
        selected_features = features
    else:
        yaml_path = hf_hub_download(
            repo_id=repo_id,
            filename="variable_schema.yaml",
            repo_type="dataset",
        )
        with open(yaml_path, "r", encoding="utf-8") as f:
            variable_schema = yaml.safe_load(f)
        selected_features = list(variable_schema.keys())

    if split_ids is not None:
        selected_ids = split_ids
    else:
        infos = load_infos_from_hub(repo_id=repo_id)
        selected_ids = {
            split: range(n_samples) for split, n_samples in infos.num_samples.items()
        }

    return {
        split: create_zarr_iterable_dataset(repo_id, split, ids, selected_features)
        for split, ids in selected_ids.items()
    }