Source code for plaid.storage.zarr.writer

"""Zarr dataset writer module.

This module provides functionality for writing and managing datasets in Zarr format
for the PLAID library. It includes utilities for generating datasets from sample
generators, saving them to disk with optimized chunking, uploading to Hugging Face
Hub, and configuring dataset cards with metadata and usage examples.

Key features:
- Parallel and sequential dataset generation from generators
- Automatic chunking for efficient storage
- Integration with Hugging Face Hub for dataset sharing
- Dataset card generation with splits, features, and documentation
"""

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#

import gc
import multiprocessing as mp
from pathlib import Path
from typing import Callable, Generator, Optional, Union

import numpy as np
import yaml
import zarr
from huggingface_hub import DatasetCard, HfApi
from tqdm import tqdm

from plaid import Sample
from plaid.storage.common.bridge import flatten_path
from plaid.storage.common.preprocessor import build_sample_dict
from plaid.types import IndexType


def _auto_chunks(shape: tuple[int, ...], target_n: int) -> tuple[int, ...]:
    """Computes automatic chunk sizes for Zarr arrays based on shape and target size.

    Args:
        shape (tuple[int, ...]): The shape of the array.
        target_n (int): The target number of elements per chunk.

    Returns:
        tuple[int, ...]: The computed chunk sizes.
    """
    # ensure pure Python ints
    target_n = int(target_n)
    shape = tuple(int(s) for s in shape)

    # elements in one "row"
    elems_per_slice = int(np.prod(shape[1:]) or 1)

    rows = max(1, target_n // elems_per_slice)
    rows = min(rows, shape[0])  # cannot exceed the dimension size

    return (rows,) + shape[1:]


[docs] def write_sample(split_root, sample, var_features_keys, sample_counter): """Write a single PLAID sample to a Zarr group on disk. This function serializes one ``Sample`` instance into a dedicated Zarr group under the given split root. Each sample is written as: sample_<zero-padded index>/ Only variable features listed in ``var_features_keys`` are written. Feature paths are flattened before being used as Zarr array names. Behavior: - A new Zarr group named ``sample_{sample_counter:09d}`` is created. - Each selected feature is written as a Zarr array if its value is not ``None``. - NumPy arrays with Unicode dtype (``dtype.kind == 'U'``) are converted to UTF-8 encoded byte arrays to ensure stable storage (notably for Zarr v3). - Chunk sizes are automatically determined using ``_auto_chunks`` with a target chunk size of approximately 5 million elements. Args: split_root: Open Zarr group corresponding to a dataset split (e.g. ``zarr.open_group(..., mode="a")``). sample (Sample): PLAID ``Sample`` object to serialize. var_features_keys (list[str]): List of feature paths (as defined in the variable schema) to extract and write for this sample. sample_counter (int): Global index of the sample within the split, used to generate the group name and ensure deterministic ordering. Notes: - The function assumes ``split_root`` already exists and is writable. - No schema validation is performed at write time. - Missing features (``None`` values) are silently skipped. - The function is side-effect only and returns ``None``. Raises: zarr.errors.ContainsGroupError: If a sample group with the same name already exists. OSError / IOError: If an underlying filesystem or Zarr write error occurs. """ sample_dict, _, _ = build_sample_dict(sample) sample_data = {path: sample_dict.get(path, None) for path in var_features_keys} g = split_root.create_group(f"sample_{sample_counter:09d}") for key, value in sample_data.items(): if value is not None: if isinstance(value, np.ndarray) and value.dtype.kind == "U": # Unicode → UTF-8 bytes (stable Zarr V3) s = "".join(value.ravel().tolist()) value = np.frombuffer(s.encode("utf-8"), dtype=np.uint8) g.create_array( flatten_path(key), data=value, chunks=_auto_chunks(value.shape, 5_000_000), ) del sample_dict, sample_data, value gc.collect()
def _zarr_worker_batch_job(args) -> int: # pragma: no cover """Worker job for one shard/batch. args = (split_root_path, gen_func, var_features_keys, batch, start_index) Returns number of samples written. """ split_root_path, gen_func, var_features_keys, batch, start_index = args # split_root = zarr.open_group(split_root_path, mode="a") store = zarr.storage.LocalStore(split_root_path) split_root = zarr.group(store=store) sample_counter = start_index written = 0 # Keep original semantics: gen_func expects a list of batches for sample in gen_func([batch]): write_sample(split_root, sample, var_features_keys, sample_counter) sample_counter += 1 written += 1 return written
[docs] def generate_datasetdict_to_disk( output_folder: Union[str, Path], generators: dict[str, Callable[..., Generator[Sample, None, None]]], variable_schema: dict[str, dict], gen_kwargs: Optional[dict[str, dict[str, list[IndexType]]]] = None, num_proc: int = 1, verbose: bool = False, ) -> None: """Generates and saves a dataset dictionary to disk in Zarr format. This function processes sample generators for different dataset splits, converts samples to dictionaries, and writes them to Zarr arrays on disk. It supports both sequential and parallel processing modes. In parallel mode, gen_kwargs must be provided with batch information for each split. Args: output_folder (Union[str, Path]): Base directory where the dataset will be saved. A 'data' subdirectory will be created inside this folder. generators (dict[str, Callable[..., Generator[Sample, None, None]]]): Dictionary mapping split names (e.g., "train", "test") to generator functions that yield Sample objects. variable_schema (dict[str, dict]): Schema describing the structure and types of variables/features in the samples. gen_kwargs (Optional[dict[str, dict[str, list[IndexType]]]]): Optional generator arguments for parallel processing. Must include "shards_ids" for each split when num_proc > 1. Required for parallel execution. num_proc (int, optional): Number of processes to use for parallel processing. Defaults to 1 (sequential). Must be > 1 only when gen_kwargs is provided. verbose (bool, optional): Whether to display progress bars during processing. Defaults to False. Returns: None: This function does not return a value; it writes the dataset directly to disk. """ output_folder = Path(output_folder) / "data" output_folder.mkdir(exist_ok=True, parents=True) var_features_keys = list(variable_schema.keys()) gen_kwargs_ = gen_kwargs or {sn: {} for sn in generators.keys()} for split_name, gen_func in generators.items(): split_root_path = str(output_folder / split_name) _ = zarr.open_group(split_root_path, mode="w") # create/overwrite batch_ids_list = gen_kwargs_.get(split_name, {}).get("shards_ids", []) total_samples = ( sum(len(batch) for batch in batch_ids_list) if batch_ids_list else None ) if num_proc > 1: assert batch_ids_list, ( f"Parallel mode requires gen_kwargs['{split_name}']['shards_ids'] " "to be provided and non-empty." ) # deterministic start indices (prefix sums) start_indices = [] s = 0 for batch in batch_ids_list: start_indices.append(s) s += len(batch) jobs = [ (split_root_path, gen_func, var_features_keys, batch, start_idx) for batch, start_idx in zip(batch_ids_list, start_indices) ] # If your platform/library stack is sensitive to fork, use spawn: # ctx = mp.get_context("spawn") # with ctx.Pool(processes=num_proc) as pool: with mp.Pool(processes=num_proc) as pool: with tqdm( total=total_samples, desc=f"Writing {split_name} split", disable=not verbose, ) as pbar: for written in pool.imap_unordered( _zarr_worker_batch_job, jobs, chunksize=1 ): pbar.update(written) else: # Sequential execution sample_counter = 0 with tqdm( total=total_samples, desc=f"Writing {split_name} split", disable=not verbose, ) as pbar: if batch_ids_list: for sample in gen_func(batch_ids_list): write_sample(_, sample, var_features_keys, sample_counter) sample_counter += 1 pbar.update(1) else: for sample in gen_func(): write_sample(_, sample, var_features_keys, sample_counter) sample_counter += 1 pbar.update(1)
[docs] def push_local_datasetdict_to_hub( repo_id: str, local_dir: Union[str, Path], num_workers: int = 1 ) -> None: # pragma: no cover """Pushes a local dataset directory to Hugging Face Hub. This function uploads the contents of a local directory to a specified Hugging Face repository as a dataset. It uses the HfApi to handle large folder uploads with configurable parallelism. Args: repo_id (str): The Hugging Face repository ID where the dataset will be uploaded (e.g., "username/dataset_name"). local_dir (str or Path): Path to the local directory containing the dataset files to upload. num_workers (int, optional): Number of worker threads to use for uploading. Defaults to 1. Returns: None: This function does not return a value; it uploads the dataset directly to Hugging Face Hub. """ api = HfApi() api.upload_large_folder( folder_path=local_dir, repo_id=repo_id, repo_type="dataset", num_workers=num_workers, ignore_patterns=["*.tmp"], allow_patterns=["data/*"], )
[docs] def configure_dataset_card( repo_id: str, infos: dict[str, dict[str, str]], local_dir: Union[str, Path], viewer: Optional[bool] = None, # noqa: ARG001 pretty_name: Optional[str] = None, dataset_long_description: Optional[str] = None, illustration_urls: Optional[list[str]] = None, arxiv_paper_urls: Optional[list[str]] = None, ) -> None: # pragma: no cover """Configures and pushes a dataset card to Hugging Face Hub for a zarr backend dataset. This function generates a dataset card in YAML format with metadata, features, splits information, and usage examples. It automatically detects splits and sample counts from the local directory structure, then pushes the card to the specified Hugging Face repository. Args: repo_id (str): The Hugging Face repository ID where the dataset card will be pushed. infos (dict[str, dict[str, str]]): Dictionary containing dataset metadata, including legal information like license. local_dir (Union[str, Path]): Path to the local directory containing the dataset files, expected to have a 'data' subdirectory with split folders. variable_schema (Optional[dict]): Schema describing the variables/features in the dataset, used to generate the features section in the card. viewer (Optional[bool]): Unused parameter for viewer configuration. pretty_name (Optional[str]): A human-readable name for the dataset to display in the card. dataset_long_description (Optional[str]): A detailed description of the dataset to include in the card. illustration_urls (Optional[list[str]]): List of URLs to images that illustrate the dataset, displayed in the card. arxiv_paper_urls (Optional[list[str]]): List of arXiv URLs for papers related to the dataset, included as sources. Returns: None: This function does not return a value; it pushes the dataset card directly to Hugging Face Hub. """ dataset_card_str = """--- task_categories: - graph-ml tags: - physics learning - geometry learning --- """ local_folder = Path(local_dir) split_names = [p.name for p in (local_folder / "data").iterdir() if p.is_dir()] nbe_samples = {} num_bytes = {} size_bytes = 0 for sn in split_names: nbe_samples[sn] = sum( 1 for p in (local_folder / "data" / f"{sn}").iterdir() if p.is_dir() and p.name.startswith("sample_") ) num_bytes[sn] = sum( f.stat().st_size for f in (local_folder / "data" / f"{sn}").rglob("*") if f.is_file() ) size_bytes += num_bytes[sn] lines = dataset_card_str.splitlines() lines = [s for s in lines if not s.startswith("license")] indices = [i for i, line in enumerate(lines) if line.strip() == "---"] assert len(indices) >= 2, ( "Cannot find two instances of '---', you should try to update a correct dataset_card." ) lines = lines[: indices[1] + 1] count = 6 lines.insert(count, f"license: {infos['legal']['license']}") count += 1 lines.insert(count, "viewer: false") count += 1 if pretty_name: lines.insert(count, f"pretty_name: {pretty_name}") count += 1 lines.insert(count, "dataset_info:") count += 1 lines.insert(count, " splits:") count += 1 for sn in split_names: lines.insert(count, f" - name: {sn}") count += 1 lines.insert(count, f" num_bytes: {num_bytes[sn]}") count += 1 lines.insert(count, f" num_examples: {nbe_samples[sn]}") count += 1 lines.insert(count, f" download_size: {size_bytes}") count += 1 lines.insert(count, f" dataset_size: {size_bytes}") count += 1 lines.insert(count, "configs:") count += 1 lines.insert(count, "- config_name: default") count += 1 lines.insert(count, " data_files:") count += 1 for sn in split_names: lines.insert(count, f" - split: {sn}") count += 1 lines.insert(count, f" path: data/{sn}/*") count += 1 str__ = "\n".join(lines) + "\n" if illustration_urls: str__ += "<p align='center'>\n" for url in illustration_urls: str__ += f"<img src='{url}' alt='{url}' width='1000'/>\n" str__ += "</p>\n\n" str__ += f"```yaml\n{yaml.dump(infos, sort_keys=False, allow_unicode=True)}\n```" str__ += """ This dataset was generated with [`plaid`](https://plaid-lib.readthedocs.io/), we refer to this documentation for additional details on how to extract data from `plaid_sample` objects. The simplest way to use this dataset is to first download it: ```python from plaid.storage import download_from_hub repo_id = "channel/dataset" local_folder = "downloaded_dataset" download_from_hub(repo_id, local_folder) ``` Then, to iterate over the dataset and instantiate samples: ```python from plaid.storage import init_from_disk local_folder = "downloaded_dataset" split_name = "train" datasetdict, converterdict = init_from_disk(local_folder) dataset = datasetdict[split] converter = converterdict[split] for i in range(len(dataset)): plaid_sample = converter.to_plaid(dataset, i) ``` It is possible to stream the data directly: ```python from plaid.storage import init_streaming_from_hub repo_id = "channel/dataset" datasetdict, converterdict = init_streaming_from_hub(repo_id) dataset = datasetdict[split] converter = converterdict[split] for sample_raw in dataset: plaid_sample = converter.sample_to_plaid(sample_raw) ``` Plaid samples' features can be retrieved like the following: ```python from plaid.storage import load_problem_definitions_from_disk local_folder = "downloaded_dataset" pb_defs = load_problem_definitions_from_disk(local_folder) # or from plaid.storage import load_problem_definitions_from_hub repo_id = "channel/dataset" pb_defs = load_problem_definitions_from_hub(repo_id) pb_def = pb_defs[0] plaid_sample = ... # use a method from above to instantiate a plaid sample for t in plaid_sample.get_all_time_values(): for path in pb_def.get_in_features_identifiers(): plaid_sample.get_feature_by_path(path=path, time=t) for path in pb_def.get_out_features_identifiers(): plaid_sample.get_feature_by_path(path=path, time=t) ``` """ if dataset_long_description: str__ += f""" ### Dataset Description {dataset_long_description} """ if arxiv_paper_urls: str__ += """ ### Dataset Sources - **Papers:** """ for url in arxiv_paper_urls: str__ += f" - [arxiv]({url})\n" dataset_card = DatasetCard(str__) dataset_card.push_to_hub(repo_id)