plaid.storage.zarr.writer

Zarr dataset writer module.

This module provides functionality for writing and managing datasets in Zarr format for the PLAID library. It includes utilities for generating datasets from sample generators, saving them to disk with optimized chunking, uploading to Hugging Face Hub, and configuring dataset cards with metadata and usage examples.

Key features:

  • Parallel and sequential dataset generation from generators
  • Automatic chunking for efficient storage
  • Integration with Hugging Face Hub for dataset sharing
  • Dataset card generation with splits, features, and documentation

plaid.storage.zarr.writer.write_sample

write_sample(
    split_root, sample, var_features_keys, sample_counter
)

Write a single PLAID sample to a Zarr group on disk.

This function serializes one Sample instance into a dedicated Zarr group under the given split root. Each sample is written as:

sample_<zero-padded index>/

Only variable features listed in var_features_keys are written. Feature paths are flattened before being used as Zarr array names.

Behavior
  • A new Zarr group named sample_{sample_counter:09d} is created.
  • Each selected feature is written as a Zarr array if its value is not None.
  • NumPy arrays with Unicode dtype (dtype.kind == 'U') are converted to UTF-8 encoded byte arrays to ensure stable storage (notably for Zarr v3).
  • Chunk sizes are automatically determined using _auto_chunks with a target chunk size of approximately 5 million elements.
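The naming and Unicode handling described above can be reproduced with NumPy alone. The following is a minimal sketch, not the library's code path: the array contents (`"inlet"`, `"wall"`) and the counter value are made up for illustration. It also shows a consequence of the conversion as written in the source on this page: the strings are concatenated without a separator, so element boundaries are not recoverable from the byte array alone.

```python
import numpy as np

# Group name: the sample index zero-padded to 9 digits.
sample_counter = 42  # illustrative value
group_name = f"sample_{sample_counter:09d}"

# A Unicode array (dtype.kind == "U"), with made-up contents.
names = np.array(["inlet", "wall"])
assert names.dtype.kind == "U"

# Same steps as the source: join all elements, encode to UTF-8,
# and view the bytes as a flat uint8 array.
joined = "".join(names.ravel().tolist())
encoded = np.frombuffer(joined.encode("utf-8"), dtype=np.uint8)

# Decoding gives back the concatenated text, not the original elements.
decoded = encoded.tobytes().decode("utf-8")
print(group_name, encoded.shape, decoded)
```

If the individual strings need to be recovered later, their lengths or a separator would have to be stored elsewhere; whether PLAID does so is not shown on this page.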

Parameters:

  • split_root (Any) –

    Open Zarr group corresponding to a dataset split (e.g. zarr.open_group(..., mode="a")).

  • sample (Sample) –

    PLAID Sample object to serialize.

  • var_features_keys (list[str]) –

    List of feature paths (as defined in the variable schema) to extract and write for this sample.

  • sample_counter (int) –

    Global index of the sample within the split, used to generate the group name and ensure deterministic ordering.

Notes
  • The function assumes split_root already exists and is writable.
  • No schema validation is performed at write time.
  • Missing features (None values) are silently skipped.
  • The function is side-effect only and returns None.
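The selection and skipping rules in these notes can be illustrated with plain dictionaries. This is a sketch under assumptions: the feature paths below are hypothetical, and a plain `dict` stands in for the mapping returned by `build_sample_dict`.

```python
# Stand-in for the flattened sample content (hypothetical paths and values).
sample_dict = {
    "Global/pressure": [1.0, 2.0],
    "Global/label": None,   # present but empty
    "Base/ignored": [0],    # not listed in the variable schema
}

# Only these paths are considered, in this order.
var_features_keys = ["Global/pressure", "Global/label", "Global/missing"]

# Same selection as the source: absent paths default to None.
sample_data = {path: sample_dict.get(path) for path in var_features_keys}

# Entries that are None (empty or absent) are skipped without any error.
written = [path for path, value in sample_data.items() if value is not None]
print(written)
```

Because nothing is validated at write time, a typo in a schema key produces no error: the feature is simply never written.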

Raises:

  • ContainsGroupError

    If a sample group with the same name already exists.

  • OSError / IOError

    If an underlying filesystem or Zarr write error occurs.

Source code in plaid/storage/zarr/writer.py
def write_sample(
    split_root: Any,
    sample: Sample,
    var_features_keys: list[str],
    sample_counter: int,
) -> None:
    """Write a single PLAID sample to a Zarr group on disk.

    This function serializes one ``Sample`` instance into a dedicated Zarr group
    under the given split root. Each sample is written as:

        sample_<zero-padded index>/

    Only variable features listed in ``var_features_keys`` are written. Feature
    paths are flattened before being used as Zarr array names.

    Behavior:
      - A new Zarr group named ``sample_{sample_counter:09d}`` is created.
      - Each selected feature is written as a Zarr array if its value is not ``None``.
      - NumPy arrays with Unicode dtype (``dtype.kind == 'U'``) are converted to
        UTF-8 encoded byte arrays to ensure stable storage (notably for Zarr v3).
      - Chunk sizes are automatically determined using ``_auto_chunks`` with a
        target chunk size of approximately 5 million elements.

    Args:
        split_root:
            Open Zarr group corresponding to a dataset split
            (e.g. ``zarr.open_group(..., mode="a")``).

        sample (Sample):
            PLAID ``Sample`` object to serialize.

        var_features_keys (list[str]):
            List of feature paths (as defined in the variable schema) to extract
            and write for this sample.

        sample_counter (int):
            Global index of the sample within the split, used to generate the
            group name and ensure deterministic ordering.

    Notes:
        - The function assumes ``split_root`` already exists and is writable.
        - No schema validation is performed at write time.
        - Missing features (``None`` values) are silently skipped.
        - The function is side-effect only and returns ``None``.

    Raises:
        zarr.errors.ContainsGroupError:
            If a sample group with the same name already exists.
        OSError / IOError:
            If an underlying filesystem or Zarr write error occurs.
    """
    sample_dict, _, _ = build_sample_dict(sample)
    sample_data = {path: sample_dict.get(path, None) for path in var_features_keys}

    g = split_root.create_group(f"sample_{sample_counter:09d}")
    for key, value in sample_data.items():
        if value is not None:
            if isinstance(value, np.ndarray) and value.dtype.kind == "U":
                # Unicode → UTF-8 bytes (stable Zarr V3)
                s = "".join(value.ravel().tolist())
                value = np.frombuffer(s.encode("utf-8"), dtype=np.uint8)

            g.create_array(
                flatten_path(key),
                data=value,
                chunks=_auto_chunks(value.shape, 5_000_000),
            )

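    # Rebind instead of deleting: `value` is unbound if no feature was iterated,
    # in which case `del value` would raise a NameError.
    value = None
    del sample_dict, sample_data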
    gc.collect()

plaid.storage.zarr.writer.generate_datasetdict_to_disk

generate_datasetdict_to_disk(
    output_folder,
    generators,
    variable_schema,
    gen_kwargs=None,
    num_proc=1,
    verbose=False,
)

Generates and saves a dataset dictionary to disk in Zarr format.

This function processes sample generators for different dataset splits, converts samples to dictionaries, and writes them to Zarr arrays on disk. It supports both sequential and parallel processing modes. In parallel mode, gen_kwargs must be provided with batch information for each split.

Parameters:

  • output_folder (Union[str, Path]) –

    Base directory where the dataset will be saved. A 'data' subdirectory will be created inside this folder.

  • generators (dict[str, Callable[..., Generator[Sample, None, None]]]) –

    Dictionary mapping split names (e.g., "train", "test") to generator functions that yield Sample objects.

  • variable_schema (dict[str, dict]) –

    Schema describing the structure and types of variables/features in the samples.

  • gen_kwargs (Optional[dict[str, dict[str, IndexArrayType]]], default: None ) –

    Per-split generator arguments. Each split must provide a non-empty "shards_ids" entry when num_proc > 1; in sequential mode, "shards_ids" is passed to the generator if present.

  • num_proc (int, default: 1 ) –

    Number of worker processes. Defaults to 1 (sequential). Values greater than 1 require gen_kwargs to provide non-empty "shards_ids" for each split.

  • verbose (bool, default: False ) –

    Whether to display progress bars during processing. Defaults to False.

Returns:

  • None ( None ) –

    This function does not return a value; it writes the dataset directly to disk.
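In parallel mode, each shard is given a starting sample index computed as a prefix sum of the preceding shard sizes, so sample names do not depend on which worker finishes first. The sketch below reproduces that bookkeeping with the standard library only. The shard contents are made up, and `start_indices` is a helper written for this example, not part of PLAID.

```python
from itertools import accumulate

# Shape expected for gen_kwargs: split name -> {"shards_ids": list of batches}.
# The ids themselves are illustrative.
gen_kwargs = {
    "train": {"shards_ids": [[0, 1, 2], [3, 4], [5, 6, 7, 8]]},
    "test": {"shards_ids": [[0, 1]]},
}


def start_indices(shards_ids):
    """Return the first sample index of each shard (exclusive prefix sums)."""
    sizes = [len(batch) for batch in shards_ids]
    # accumulate(..., initial=0) yields 0, s0, s0+s1, ...; drop the grand total.
    return list(accumulate(sizes, initial=0))[:-1]


train_shards = gen_kwargs["train"]["shards_ids"]
train_starts = start_indices(train_shards)
total_train = sum(len(batch) for batch in train_shards)

for start, batch in zip(train_starts, train_shards):
    last = start + len(batch) - 1
    print(f"shard {batch} -> sample_{start:09d} .. sample_{last:09d}")
```

The total (`total_train`) is the value the source uses for the progress bar; when no `"shards_ids"` are given, the total is unknown and the bar runs without one.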

Source code in plaid/storage/zarr/writer.py
def generate_datasetdict_to_disk(
    output_folder: Union[str, Path],
    generators: dict[str, Callable[..., Generator[Sample, None, None]]],
    variable_schema: dict[str, dict],
    gen_kwargs: Optional[dict[str, dict[str, Any]]] = None,
    num_proc: int = 1,
    verbose: bool = False,
) -> None:
    """Generates and saves a dataset dictionary to disk in Zarr format.

    This function processes sample generators for different dataset splits,
    converts samples to dictionaries, and writes them to Zarr arrays on disk.
    It supports both sequential and parallel processing modes. In parallel mode,
    gen_kwargs must be provided with batch information for each split.

    Args:
        output_folder (Union[str, Path]): Base directory where the dataset will be saved.
            A 'data' subdirectory will be created inside this folder.
        generators (dict[str, Callable[..., Generator[Sample, None, None]]]):
            Dictionary mapping split names (e.g., "train", "test") to generator
            functions that yield Sample objects.
        variable_schema (dict[str, dict]): Schema describing the structure and types
            of variables/features in the samples.
        gen_kwargs (Optional[dict[str, dict[str, IndexArrayType]]]): Per-split
            generator arguments. Each split must provide a non-empty "shards_ids"
            entry when num_proc > 1; in sequential mode, "shards_ids" is passed
            to the generator if present.
        num_proc (int, optional): Number of worker processes. Defaults to 1
            (sequential). Values greater than 1 require gen_kwargs to provide
            non-empty "shards_ids" for each split.
        verbose (bool, optional): Whether to display progress bars during processing.
            Defaults to False.

    Returns:
        None: This function does not return a value; it writes the dataset directly
            to disk.
    """
    output_folder = Path(output_folder) / "data"
    output_folder.mkdir(exist_ok=True, parents=True)

    var_features_keys = list(variable_schema.keys())
    gen_kwargs_ = gen_kwargs or {sn: {} for sn in generators.keys()}

    for split_name, gen_func in generators.items():
        split_root_path = str(output_folder / split_name)
        _ = zarr.open_group(split_root_path, mode="w")  # create/overwrite

        batch_ids_list = gen_kwargs_.get(split_name, {}).get("shards_ids", [])
        total_samples = (
            sum(len(batch) for batch in batch_ids_list) if batch_ids_list else None
        )

        if num_proc > 1:
            assert batch_ids_list, (
                f"Parallel mode requires gen_kwargs['{split_name}']['shards_ids'] "
                "to be provided and non-empty."
            )

            # deterministic start indices (prefix sums)
            start_indices = []
            s = 0
            for batch in batch_ids_list:
                start_indices.append(s)
                s += len(batch)

            jobs = [
                (split_root_path, gen_func, var_features_keys, batch, start_idx)
                for batch, start_idx in zip(batch_ids_list, start_indices)
            ]

            # If your platform/library stack is sensitive to fork, use spawn:
            # ctx = mp.get_context("spawn")
            # with ctx.Pool(processes=num_proc) as pool:
            with mp.Pool(processes=num_proc) as pool:
                with tqdm(
                    total=total_samples,
                    desc=f"Writing {split_name} split",
                    disable=not verbose,
                ) as pbar:
                    for written in pool.imap_unordered(
                        _zarr_worker_batch_job, jobs, chunksize=1
                    ):
                        pbar.update(written)

        else:
            # Sequential execution
            sample_counter = 0
            with tqdm(
                total=total_samples,
                desc=f"Writing {split_name} split",
                disable=not verbose,
            ) as pbar:
                if batch_ids_list:
                    for sample in gen_func(batch_ids_list):
                        write_sample(_, sample, var_features_keys, sample_counter)
                        sample_counter += 1
                        pbar.update(1)
                else:
                    for sample in gen_func():
                        write_sample(_, sample, var_features_keys, sample_counter)
                        sample_counter += 1
                        pbar.update(1)

plaid.storage.zarr.writer.push_local_datasetdict_to_hub

push_local_datasetdict_to_hub(
    repo_id, local_dir, num_workers=1
)

Pushes a local dataset directory to Hugging Face Hub.

This function uploads the contents of a local directory to a specified Hugging Face repository as a dataset. It uses the HfApi to handle large folder uploads with configurable parallelism.

Parameters:

  • repo_id (str) –

    The Hugging Face repository ID where the dataset will be uploaded (e.g., "username/dataset_name").

  • local_dir (str or Path) –

    Path to the local directory containing the dataset files to upload.

  • num_workers (int, default: 1 ) –

    Number of worker threads to use for uploading. Defaults to 1.

Returns:

  • None ( None ) –

    This function does not return a value; it uploads the dataset directly to Hugging Face Hub.
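The source passes `allow_patterns=["data/*"]` and `ignore_patterns=["*.tmp"]` to the upload call. The sketch below previews which relative paths such a filter would keep, assuming the patterns are applied `fnmatch`-style to paths relative to `local_dir`. That matching rule is an assumption about how `huggingface_hub` applies these patterns, and `would_upload` is a hypothetical helper, not a library function.

```python
from fnmatch import fnmatch

ALLOW_PATTERNS = ["data/*"]
IGNORE_PATTERNS = ["*.tmp"]


def would_upload(rel_path: str) -> bool:
    """Keep a path if it matches an allow pattern and no ignore pattern."""
    allowed = any(fnmatch(rel_path, pattern) for pattern in ALLOW_PATTERNS)
    ignored = any(fnmatch(rel_path, pattern) for pattern in IGNORE_PATTERNS)
    return allowed and not ignored


# Illustrative relative paths inside local_dir.
paths = [
    "data/train/sample_000000000/zarr.json",
    "data/train/partial.tmp",
    "README.md",
]
selected = [p for p in paths if would_upload(p)]
print(selected)
```

Under this assumption, anything outside `data/` (for example a locally written README) is not uploaded by this function.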

Source code in plaid/storage/zarr/writer.py
def push_local_datasetdict_to_hub(
    repo_id: str, local_dir: Union[str, Path], num_workers: int = 1
) -> None:  # pragma: no cover
    """Pushes a local dataset directory to Hugging Face Hub.

    This function uploads the contents of a local directory to a specified
    Hugging Face repository as a dataset. It uses the HfApi to handle large
    folder uploads with configurable parallelism.

    Args:
        repo_id (str): The Hugging Face repository ID where the dataset will be uploaded
            (e.g., "username/dataset_name").
        local_dir (str or Path): Path to the local directory containing the dataset files
            to upload.
        num_workers (int, optional): Number of worker threads to use for uploading.
            Defaults to 1.

    Returns:
        None: This function does not return a value; it uploads the dataset directly
            to Hugging Face Hub.
    """
    api = HfApi()
    api.upload_large_folder(
        folder_path=local_dir,
        repo_id=repo_id,
        repo_type="dataset",
        num_workers=num_workers,
        ignore_patterns=["*.tmp"],
        allow_patterns=["data/*"],
    )

plaid.storage.zarr.writer.configure_dataset_card

configure_dataset_card(
    repo_id,
    infos,
    local_dir,
    viewer=None,
    pretty_name=None,
    dataset_long_description=None,
    illustration_urls=None,
    arxiv_paper_urls=None,
)

Configures and pushes a dataset card to Hugging Face Hub for a Zarr-backed dataset.

This function generates a dataset card made of a YAML metadata header (license, splits, sizes, and data file configuration) followed by a Markdown body containing the dataset infos and usage examples. It detects splits and sample counts from the local directory structure, then pushes the card to the specified Hugging Face repository.

Parameters:

  • repo_id (str) –

    The Hugging Face repository ID where the dataset card will be pushed.

  • infos (Infos) –

    Dataset metadata, including legal information like license.

  • local_dir (Union[str, Path]) –

    Path to the local directory containing the dataset files, expected to have a 'data' subdirectory with split folders.

  • viewer (Optional[bool], default: None ) –

    Currently unused; the generated card always sets viewer: false.

  • pretty_name (Optional[str], default: None ) –

    A human-readable name for the dataset to display in the card.

  • dataset_long_description (Optional[str], default: None ) –

    A detailed description of the dataset to include in the card.

  • illustration_urls (Optional[list[str]], default: None ) –

    List of URLs to images that illustrate the dataset, displayed in the card.

  • arxiv_paper_urls (Optional[list[str]], default: None ) –

    List of arXiv URLs for papers related to the dataset, included as sources.

Returns:

  • None ( None ) –

    This function does not return a value; it pushes the dataset card directly to Hugging Face Hub.
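The split detection can be tried on a throwaway directory before pushing anything. The sketch below mirrors the counting logic in the source (sub-directories of `data/` are splits, `sample_*` sub-directories are examples, and sizes are summed over all files). `summarize_splits` is a helper written for this example and the file layout is made up; unlike the source, it sorts the split directories for a stable order.

```python
import tempfile
from pathlib import Path


def summarize_splits(local_dir):
    """Count samples and bytes per split under <local_dir>/data."""
    data_dir = Path(local_dir) / "data"
    summary = {}
    for split_dir in sorted(p for p in data_dir.iterdir() if p.is_dir()):
        num_examples = sum(
            1
            for p in split_dir.iterdir()
            if p.is_dir() and p.name.startswith("sample_")
        )
        num_bytes = sum(f.stat().st_size for f in split_dir.rglob("*") if f.is_file())
        summary[split_dir.name] = {
            "num_examples": num_examples,
            "num_bytes": num_bytes,
        }
    return summary


with tempfile.TemporaryDirectory() as tmp:
    # Fake layout: 3 train samples and 1 test sample, 10 bytes each.
    for split, count in {"train": 3, "test": 1}.items():
        for i in range(count):
            sample_dir = Path(tmp) / "data" / split / f"sample_{i:09d}"
            sample_dir.mkdir(parents=True)
            (sample_dir / "payload.bin").write_bytes(b"\x00" * 10)
    summary = summarize_splits(tmp)

print(summary)
```

The card reports the sum of the per-split byte counts as both `download_size` and `dataset_size`.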

Source code in plaid/storage/zarr/writer.py
def configure_dataset_card(
    repo_id: str,
    infos: Infos,
    local_dir: Union[str, Path],
    viewer: Optional[bool] = None,  # noqa: ARG001
    pretty_name: Optional[str] = None,
    dataset_long_description: Optional[str] = None,
    illustration_urls: Optional[list[str]] = None,
    arxiv_paper_urls: Optional[list[str]] = None,
) -> None:  # pragma: no cover
    """Configures and pushes a dataset card to Hugging Face Hub for a zarr backend dataset.

    This function generates a dataset card in YAML format with metadata, features,
    splits information, and usage examples. It automatically detects splits and
    sample counts from the local directory structure, then pushes the card to
    the specified Hugging Face repository.

    Args:
        repo_id (str): The Hugging Face repository ID where the dataset card will be pushed.
        infos (Infos): Dataset metadata,
            including legal information like license.
        local_dir (Union[str, Path]): Path to the local directory containing the
            dataset files, expected to have a 'data' subdirectory with split folders.
        viewer (Optional[bool]): Unused parameter for viewer configuration.
        pretty_name (Optional[str]): A human-readable name for the dataset to
            display in the card.
        dataset_long_description (Optional[str]): A detailed description of the
            dataset to include in the card.
        illustration_urls (Optional[list[str]]): List of URLs to images that
            illustrate the dataset, displayed in the card.
        arxiv_paper_urls (Optional[list[str]]): List of arXiv URLs for papers
            related to the dataset, included as sources.

    Returns:
        None: This function does not return a value; it pushes the dataset card
            directly to Hugging Face Hub.
    """
    infos_dict = infos.model_dump(exclude_none=True)
    dataset_card_str = """---
task_categories:
- graph-ml
tags:
- physics learning
- geometry learning
---
"""
    local_folder = Path(local_dir)
    split_names = [p.name for p in (local_folder / "data").iterdir() if p.is_dir()]

    nbe_samples = {}
    num_bytes = {}
    size_bytes = 0
    for sn in split_names:
        nbe_samples[sn] = sum(
            1
            for p in (local_folder / "data" / f"{sn}").iterdir()
            if p.is_dir() and p.name.startswith("sample_")
        )
        num_bytes[sn] = sum(
            f.stat().st_size
            for f in (local_folder / "data" / f"{sn}").rglob("*")
            if f.is_file()
        )
        size_bytes += num_bytes[sn]

    lines = dataset_card_str.splitlines()
    lines = [s for s in lines if not s.startswith("license")]

    indices = [i for i, line in enumerate(lines) if line.strip() == "---"]

    assert len(indices) >= 2, (
        "Cannot find two instances of '---', you should try to update a correct dataset_card."
    )
    lines = lines[: indices[1] + 1]

    count = 6
    lines.insert(count, f"license: {infos.license}")
    count += 1
    lines.insert(count, "viewer: false")
    count += 1
    if pretty_name:
        lines.insert(count, f"pretty_name: {pretty_name}")
        count += 1

    lines.insert(count, "dataset_info:")
    count += 1
    lines.insert(count, "  splits:")
    count += 1
    for sn in split_names:
        lines.insert(count, f"    - name: {sn}")
        count += 1
        lines.insert(count, f"      num_bytes: {num_bytes[sn]}")
        count += 1
        lines.insert(count, f"      num_examples: {nbe_samples[sn]}")
        count += 1
    lines.insert(count, f"  download_size: {size_bytes}")
    count += 1
    lines.insert(count, f"  dataset_size: {size_bytes}")
    count += 1
    lines.insert(count, "configs:")
    count += 1
    lines.insert(count, "- config_name: default")
    count += 1
    lines.insert(count, "  data_files:")
    count += 1
    for sn in split_names:
        lines.insert(count, f"  - split: {sn}")
        count += 1
        lines.insert(count, f"    path: data/{sn}/*")
        count += 1

    str__ = "\n".join(lines) + "\n"

    if illustration_urls:
        str__ += "<p align='center'>\n"
        for url in illustration_urls:
            str__ += f"<img src='{url}' alt='{url}' width='1000'/>\n"
        str__ += "</p>\n\n"

    str__ += (
        f"```yaml\n{yaml.dump(infos_dict, sort_keys=False, allow_unicode=True)}\n```"
    )

    str__ += """
This dataset was generated with [`plaid`](https://plaid-lib.readthedocs.io/); refer to its documentation for additional details on how to extract data from `plaid_sample` objects.

The simplest way to use this dataset is to first download it:
```python
from plaid.storage import download_from_hub

repo_id = "channel/dataset"
local_folder = "downloaded_dataset"

download_from_hub(repo_id, local_folder)
```

Then, to iterate over the dataset and instantiate samples:
```python
from plaid.storage import init_from_disk

local_folder = "downloaded_dataset"
split_name = "train"

datasetdict, converterdict = init_from_disk(local_folder)

dataset = datasetdict[split_name]
converter = converterdict[split_name]

for i in range(len(dataset)):
    plaid_sample = converter.to_plaid(dataset, i)
```

It is possible to stream the data directly:
```python
from plaid.storage import init_streaming_from_hub

repo_id = "channel/dataset"
split_name = "train"

datasetdict, converterdict = init_streaming_from_hub(repo_id)

dataset = datasetdict[split_name]
converter = converterdict[split_name]

for sample_raw in dataset:
    plaid_sample = converter.sample_to_plaid(sample_raw)
```

Sample features can then be retrieved as follows:
```python
from plaid.storage import load_problem_definitions_from_disk
local_folder = "downloaded_dataset"
pb_defs = load_problem_definitions_from_disk(local_folder)

# or
from plaid.storage import load_problem_definitions_from_hub
repo_id = "channel/dataset"
pb_defs = load_problem_definitions_from_hub(repo_id)


pb_def = next(iter(pb_defs.values()))

plaid_sample = ... # use a method from above to instantiate a plaid sample

for t in plaid_sample.get_all_time_values():
    for path in pb_def.input_features:
        feature = plaid_sample.get_feature_by_path(path=path, time=t)
        ...
    for path in pb_def.output_features:
        feature = plaid_sample.get_feature_by_path(path=path, time=t)
        ...
```
"""

    if dataset_long_description:
        str__ += f"""
### Dataset Description
{dataset_long_description}
"""

    if arxiv_paper_urls:
        str__ += """
### Dataset Sources

- **Papers:**
"""
        for url in arxiv_paper_urls:
            str__ += f"   - [arxiv]({url})\n"

    dataset_card = DatasetCard(str__)
    dataset_card.push_to_hub(repo_id)