"""Zarr dataset writer module.
This module provides functionality for writing and managing datasets in Zarr format
for the PLAID library. It includes utilities for generating datasets from sample
generators, saving them to disk with optimized chunking, uploading to Hugging Face
Hub, and configuring dataset cards with metadata and usage examples.
Key features:
- Parallel and sequential dataset generation from generators
- Automatic chunking for efficient storage
- Integration with Hugging Face Hub for dataset sharing
- Dataset card generation with splits, features, and documentation
"""
# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#
import gc
import multiprocessing as mp
from pathlib import Path
from typing import Callable, Generator, Optional, Union
import numpy as np
import yaml
import zarr
from huggingface_hub import DatasetCard, HfApi
from tqdm import tqdm
from plaid import Sample
from plaid.storage.common.bridge import flatten_path
from plaid.storage.common.preprocessor import build_sample_dict
from plaid.types import IndexType
def _auto_chunks(shape: tuple[int, ...], target_n: int) -> tuple[int, ...]:
    """Compute a Zarr chunk shape holding roughly ``target_n`` elements.

    Only the leading dimension is split: each chunk spans the full extent of
    every trailing dimension and as many leading "rows" as fit in
    ``target_n`` elements (at least one, at most ``shape[0]``).

    Args:
        shape (tuple[int, ...]): The shape of the array. May be empty (0-d
            array) or contain zero-length dimensions.
        target_n (int): The target number of elements per chunk.

    Returns:
        tuple[int, ...]: The chunk shape, with the same number of dimensions
        as ``shape``. Every entry is a plain Python ``int`` >= 1; a 0-d
        ``shape`` yields ``()``.
    """
    # Ensure pure Python ints, in case NumPy integers were passed in.
    target_n = int(target_n)
    shape = tuple(int(s) for s in shape)

    # A 0-d array has no leading dimension to split; indexing shape[0]
    # below would raise IndexError.
    if not shape:
        return ()

    # Number of elements in one "row" (one index along the leading axis).
    # The `or 1` guards against a zero-length trailing dimension.
    elems_per_slice = int(np.prod(shape[1:]) or 1)

    # As many rows as fit in the target, but never more than the array has
    # and never fewer than one (a chunk dimension of 0 is meaningless).
    rows = max(1, min(target_n // elems_per_slice, shape[0]))

    # Clamp trailing dimensions as well so empty arrays still get a
    # well-formed chunk shape.
    return (rows,) + tuple(max(1, s) for s in shape[1:])
[docs]
def write_sample(
    split_root: "zarr.Group",
    sample: "Sample",
    var_features_keys: list[str],
    sample_counter: int,
) -> None:
    """Write a single PLAID sample to a Zarr group on disk.

    The sample is converted with ``build_sample_dict`` and stored in a new
    sub-group of ``split_root`` named ``sample_{sample_counter:09d}``. Only
    the features listed in ``var_features_keys`` are written; each feature
    path is passed through ``flatten_path`` to obtain the Zarr array name.

    Behavior:
        - Features that are absent from the sample dictionary, or whose value
          is ``None``, are skipped silently.
        - NumPy arrays with Unicode dtype (``dtype.kind == 'U'``) are stored
          as a flat ``uint8`` array holding the UTF-8 encoding of all
          elements concatenated together.
        - Chunk shapes come from ``_auto_chunks`` with a target of
          5,000,000 elements per chunk.
        - ``gc.collect()`` is invoked once at the end of every call.

    Args:
        split_root: Open, writable Zarr group for a dataset split; the
            sample group is created inside it.
        sample (Sample): PLAID ``Sample`` object to serialize.
        var_features_keys (list[str]): Feature paths to extract and write
            for this sample. May be empty, in which case only the (empty)
            sample group is created.
        sample_counter (int): Index of the sample within the split, used to
            build the zero-padded group name.

    Returns:
        None. The function is used for its side effects only.

    Raises:
        Exception: Errors raised by Zarr or the underlying store (for
            instance when the sample group cannot be created) propagate
            unchanged; no schema validation is performed here.
    """
    sample_dict, _, _ = build_sample_dict(sample)
    sample_data = {path: sample_dict.get(path) for path in var_features_keys}

    group = split_root.create_group(f"sample_{sample_counter:09d}")

    # Pre-bind the loop variable so the `del` below stays valid even when
    # no feature is selected (otherwise it raises UnboundLocalError).
    value = None
    for key, value in sample_data.items():
        if value is None:
            continue

        if isinstance(value, np.ndarray) and value.dtype.kind == "U":
            # Unicode -> UTF-8 bytes (stable with Zarr v3).
            # NOTE(review): elements are concatenated without a separator,
            # so boundaries between multiple strings are not recoverable
            # from the stored bytes alone - confirm the reader expects this.
            text = "".join(value.ravel().tolist())
            value = np.frombuffer(text.encode("utf-8"), dtype=np.uint8)

        group.create_array(
            flatten_path(key),
            data=value,
            chunks=_auto_chunks(value.shape, 5_000_000),
        )

    # Drop the references to the (potentially large) sample data before
    # forcing a collection, so the memory can be reclaimed right away.
    del sample_dict, sample_data, value
    gc.collect()
def _zarr_worker_batch_job(args) -> int:  # pragma: no cover
    """Write every sample of one shard/batch and report how many were written.

    Args:
        args: A 5-tuple
            ``(split_root_path, gen_func, var_features_keys, batch, start_index)``.
            ``split_root_path`` is the path of the split's Zarr group,
            ``gen_func`` is called with ``[batch]`` to produce the samples,
            and ``start_index`` is the index given to the first sample of
            this batch; following samples get consecutive indices.

    Returns:
        int: Number of samples written by this job.
    """
    root_path, gen_func, feature_keys, batch, first_index = args

    # Obtain a group handle on the split's local store for this job.
    split_root = zarr.group(store=zarr.storage.LocalStore(root_path))

    n_written = 0
    # Keep original semantics: gen_func expects a list of batches.
    for offset, sample in enumerate(gen_func([batch])):
        write_sample(split_root, sample, feature_keys, first_index + offset)
        n_written = offset + 1
    return n_written
[docs]
def generate_datasetdict_to_disk(
    output_folder: Union[str, Path],
    generators: dict[str, Callable[..., Generator[Sample, None, None]]],
    variable_schema: dict[str, dict],
    gen_kwargs: Optional[dict[str, dict[str, list[IndexType]]]] = None,
    num_proc: int = 1,
    verbose: bool = False,
) -> None:
    """Generate a dataset dictionary and save it to disk in Zarr format.

    For every split in ``generators`` a Zarr group is created (overwriting
    any existing one) at ``<output_folder>/data/<split_name>`` and each
    generated sample is written into it with ``write_sample``.

    Two execution modes are supported:

    - Sequential (``num_proc <= 1``): the split generator is called with the
      split's ``shards_ids`` when they are provided and non-empty, otherwise
      with no argument. Samples are numbered in generation order.
    - Parallel (``num_proc > 1``): one job per shard is dispatched to a
      ``multiprocessing.Pool``. Each job receives a start index equal to the
      cumulative length of the preceding shards, so sample numbering does
      not depend on the order in which jobs complete.

    Args:
        output_folder (Union[str, Path]): Base directory where the dataset
            will be saved. A 'data' subdirectory is created inside it.
        generators (dict[str, Callable[..., Generator[Sample, None, None]]]):
            Mapping from split names (e.g., "train", "test") to generator
            functions that yield Sample objects.
        variable_schema (dict[str, dict]): Schema of the variable features;
            its keys are the feature paths written for every sample.
        gen_kwargs (Optional[dict[str, dict[str, list[IndexType]]]]):
            Optional per-split generator arguments. The "shards_ids" entry
            of a split is the list of batches handed to its generator and is
            required (non-empty) for every split when ``num_proc > 1``.
        num_proc (int, optional): Number of worker processes. Defaults to 1
            (sequential).
        verbose (bool, optional): Whether to display progress bars.
            Defaults to False.

    Returns:
        None: The dataset is written directly to disk.

    Raises:
        AssertionError: If ``num_proc > 1`` and a split has no non-empty
            "shards_ids" entry in ``gen_kwargs``. The check runs before
            anything is created or overwritten on disk.
    """
    var_features_keys = list(variable_schema.keys())
    gen_kwargs_ = gen_kwargs or {sn: {} for sn in generators}
    shards_by_split = {
        split_name: gen_kwargs_.get(split_name, {}).get("shards_ids", [])
        for split_name in generators
    }

    # Validate up front, before any split is created or overwritten. The
    # exception is raised explicitly (rather than via `assert`) so the check
    # is not stripped when Python runs with -O; the exception type and
    # message are kept for backward compatibility.
    if num_proc > 1:
        for split_name, batch_ids_list in shards_by_split.items():
            if not batch_ids_list:
                raise AssertionError(
                    f"Parallel mode requires gen_kwargs['{split_name}']['shards_ids'] "
                    "to be provided and non-empty."
                )

    output_folder = Path(output_folder) / "data"
    output_folder.mkdir(exist_ok=True, parents=True)

    for split_name, gen_func in generators.items():
        split_root_path = str(output_folder / split_name)
        split_root = zarr.open_group(split_root_path, mode="w")  # create/overwrite

        batch_ids_list = shards_by_split[split_name]
        # The total is only known when shards are given; otherwise tqdm
        # shows a bar without a total.
        total_samples = (
            sum(len(batch) for batch in batch_ids_list) if batch_ids_list else None
        )
        progress_options = {
            "total": total_samples,
            "desc": f"Writing {split_name} split",
            "disable": not verbose,
        }

        if num_proc > 1:
            # Deterministic start indices (prefix sums of the shard lengths).
            jobs = []
            next_start = 0
            for batch in batch_ids_list:
                jobs.append(
                    (split_root_path, gen_func, var_features_keys, batch, next_start)
                )
                next_start += len(batch)

            # If your platform/library stack is sensitive to fork, a "spawn"
            # context (mp.get_context("spawn")) can be used instead.
            with mp.Pool(processes=num_proc) as pool:
                with tqdm(**progress_options) as pbar:
                    for written in pool.imap_unordered(
                        _zarr_worker_batch_job, jobs, chunksize=1
                    ):
                        pbar.update(written)
        else:
            # Sequential execution.
            with tqdm(**progress_options) as pbar:
                samples = gen_func(batch_ids_list) if batch_ids_list else gen_func()
                for sample_counter, sample in enumerate(samples):
                    write_sample(split_root, sample, var_features_keys, sample_counter)
                    pbar.update(1)
[docs]
def push_local_datasetdict_to_hub(
    repo_id: str, local_dir: Union[str, Path], num_workers: int = 1
) -> None:  # pragma: no cover
    """Upload a local dataset directory to a Hugging Face Hub dataset repo.

    The upload is delegated to ``HfApi.upload_large_folder`` with
    ``repo_type="dataset"``. Only paths matching ``data/*`` are allowed and
    paths matching ``*.tmp`` are ignored.

    Args:
        repo_id (str): The Hugging Face repository ID that receives the
            dataset (e.g., "username/dataset_name").
        local_dir (str or Path): Path to the local directory containing the
            dataset files to upload.
        num_workers (int, optional): Value forwarded as ``num_workers`` to
            the upload call. Defaults to 1.

    Returns:
        None: The dataset is uploaded directly to Hugging Face Hub.
    """
    upload_options = {
        "folder_path": local_dir,
        "repo_id": repo_id,
        "repo_type": "dataset",
        "num_workers": num_workers,
        "ignore_patterns": ["*.tmp"],
        "allow_patterns": ["data/*"],
    }
    HfApi().upload_large_folder(**upload_options)