Skip to content

plaid.storage.cgns.writer

plaid.storage.cgns.writer

CGNS dataset writer module.

This module provides functionality for writing PLAID datasets in CGNS format. It includes utilities for generating datasets from sample generators, saving them to disk, uploading them to the Hugging Face Hub, and configuring dataset cards.

plaid.storage.cgns.writer.generate_datasetdict_to_disk

generate_datasetdict_to_disk(
    output_folder,
    generators,
    variable_schema=None,
    gen_kwargs=None,
    num_proc=1,
    verbose=False,
)

Generates and saves a dataset to disk in CGNS format.

Parameters:

  • output_folder (Union[str, Path]) –

    Base directory to save the dataset.

  • generators (dict[str, Callable[..., Generator[Sample, None, None]]]) –

    Dict of split generators.

  • variable_schema (Optional[dict[str, dict]], default: None ) –

    Variable schema; not used by the CGNS writer.

  • gen_kwargs (Optional[dict[str, dict[str, Any]]], default: None ) –

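    Optional per-split generator kwargs. Only the `shards_ids` entry (a list of shard-ID batches) is read: it is passed to the generator in sequential mode, split across workers in parallel mode, and required when `num_proc > 1`.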

  • num_proc (int, default: 1 ) –

    Number of processes.

  • verbose (bool, default: False ) –

    Whether to display a progress bar.

Source code in plaid/storage/cgns/writer.py
def generate_datasetdict_to_disk(
    output_folder: Union[str, Path],
    generators: dict[str, Callable[..., Generator[Sample, None, None]]],
    variable_schema: Optional[dict[str, dict]] = None,  # noqa: ARG001
    gen_kwargs: Optional[dict[str, dict[str, Any]]] = None,
    num_proc: int = 1,
    verbose: bool = False,
) -> None:
    """Write every split produced by ``generators`` to disk in CGNS format.

    Each split is written below ``<output_folder>/data/<split_name>/``.
    Sequential mode saves each yielded sample with ``Sample.save_to_dir`` into
    ``sample_<index:09d>``. Parallel mode hands each shard batch, together with
    the index of its first sample, to ``_cgns_worker_batch_job``.

    Args:
        output_folder: Base directory; samples go under its ``data`` subfolder.
        generators: Mapping from split name to a generator function yielding samples.
        variable_schema: Not used by the CGNS writer.
        gen_kwargs: Optional per-split kwargs. Only the ``"shards_ids"`` entry
            (a list of shard-ID batches) is read.
        num_proc: Number of worker processes. Values above 1 enable parallel
            mode, which requires a non-empty ``gen_kwargs[split]["shards_ids"]``.
        verbose: Whether to display a tqdm progress bar.
    """
    data_root = Path(output_folder) / "data"
    data_root.mkdir(exist_ok=True, parents=True)

    per_split_kwargs = gen_kwargs or {name: {} for name in generators}

    for split_name, gen_func in generators.items():
        split_path = data_root / split_name
        split_path.mkdir(exist_ok=True, parents=True)

        shard_batches = per_split_kwargs.get(split_name, {}).get("shards_ids", [])
        # Without shard batches the sample count is unknown; tqdm accepts None.
        total_samples = sum(map(len, shard_batches)) if shard_batches else None
        progress_kwargs = {
            "total": total_samples,
            "desc": f"Writing {split_name} split",
            "disable": not verbose,
        }

        if num_proc > 1:
            assert shard_batches, (
                f"Parallel mode requires gen_kwargs['{split_name}']['shards_ids'] "
                "to be provided and non-empty."
            )

            # Each job carries the index of its batch's first sample (a running
            # prefix sum of batch sizes), so the start indices are deterministic
            # even though batches finish out of order. Presumably the worker uses
            # this index to number its sample folders.
            jobs = []
            offset = 0
            for batch in shard_batches:
                jobs.append((str(split_path), gen_func, batch, offset))
                offset += len(batch)

            # NOTE: if the stack is sensitive to fork, mp.get_context("spawn").Pool
            # can be used in place of mp.Pool.
            with mp.Pool(processes=num_proc) as pool, tqdm(**progress_kwargs) as pbar:
                for written in pool.imap_unordered(
                    _cgns_worker_batch_job, jobs, chunksize=1
                ):
                    # Each worker reports how many samples it wrote.
                    pbar.update(written)
            continue

        # Sequential: one generator call for the whole split; samples are
        # numbered in the order they are yielded.
        with tqdm(**progress_kwargs) as pbar:
            samples = gen_func(shard_batches) if shard_batches else gen_func()
            for index, sample in enumerate(samples):
                sample.save_to_dir(split_path / f"sample_{index:09d}")
                pbar.update(1)

plaid.storage.cgns.writer.push_local_datasetdict_to_hub

push_local_datasetdict_to_hub(
    repo_id, local_dir, num_workers=1
)

Pushes the `data/` folder of a local dataset directory to the Hugging Face Hub (files matching `*.tmp` are skipped).

Parameters:

  • repo_id (str) –

    The repository ID.

  • local_dir (Union[str, Path]) –

    Local directory path.

  • num_workers (int, default: 1 ) –

    Number of upload workers.

Source code in plaid/storage/cgns/writer.py
def push_local_datasetdict_to_hub(
    repo_id: str, local_dir: Union[str, Path], num_workers: int = 1
) -> None:  # pragma: no cover
    """Upload the ``data/`` tree of a local dataset folder to the Hugging Face Hub.

    Args:
        repo_id: Target dataset repository ID on the Hub.
        local_dir: Local dataset root. Only paths matching ``data/*`` are
            uploaded, and files matching ``*.tmp`` are skipped.
        num_workers: Number of upload workers forwarded to ``upload_large_folder``.
    """
    upload_options = {
        "repo_id": repo_id,
        "repo_type": "dataset",
        "folder_path": local_dir,
        "allow_patterns": ["data/*"],
        "ignore_patterns": ["*.tmp"],
        "num_workers": num_workers,
    }
    HfApi().upload_large_folder(**upload_options)

plaid.storage.cgns.writer.configure_dataset_card

configure_dataset_card(
    repo_id,
    infos,
    local_dir,
    viewer=None,
    pretty_name=None,
    dataset_long_description=None,
    illustration_urls=None,
    arxiv_paper_urls=None,
)

Configures and pushes a dataset card to the Hugging Face Hub for a CGNS backend dataset.

This function generates a dataset card in YAML format with metadata, features, splits information, and usage examples. It automatically detects splits and sample counts from the local directory structure, then pushes the card to the specified Hugging Face repository.

Parameters:

  • repo_id (str) –

    The Hugging Face repository ID where the dataset card will be pushed.

  • infos (Infos) –

    Dataset metadata, including legal information like license.

  • local_dir (Union[str, Path]) –

    Path to the local directory containing the dataset files, expected to have a 'data' subdirectory with split folders.

  • viewer (Optional[bool], default: None ) –

    Unused parameter for viewer configuration.

  • pretty_name (Optional[str], default: None ) –

    A human-readable name for the dataset to display in the card.

  • dataset_long_description (Optional[str], default: None ) –

    A detailed description of the dataset to include in the card.

  • illustration_urls (Optional[list[str]], default: None ) –

    List of URLs to images that illustrate the dataset, displayed in the card.

  • arxiv_paper_urls (Optional[list[str]], default: None ) –

    List of arXiv URLs for papers related to the dataset, included as sources.

Returns:

  • None ( None ) –

    This function does not return a value; it pushes the dataset card directly to the Hugging Face Hub.

Source code in plaid/storage/cgns/writer.py
def configure_dataset_card(
    repo_id: str,
    infos: Infos,
    local_dir: Union[str, Path],
    viewer: Optional[bool] = None,  # noqa: ARG001
    pretty_name: Optional[str] = None,
    dataset_long_description: Optional[str] = None,
    illustration_urls: Optional[list[str]] = None,
    arxiv_paper_urls: Optional[list[str]] = None,
) -> None:  # pragma: no cover
    """Configures and pushes a dataset card to the Hugging Face Hub for a CGNS backend dataset.

    The card starts with a YAML front-matter block. It holds the task
    categories, tags, license, ``viewer: false``, an optional pretty name,
    per-split byte sizes and sample counts, and a ``default`` config mapping
    each split to ``data/<split>/*``. After the front matter come optional
    illustrations, a YAML dump of ``infos``, usage examples, and an optional
    description and paper list.

    Splits and sample counts are detected from ``<local_dir>/data``. Split
    names are sorted so the generated card does not depend on filesystem
    listing order. The usage examples reference ``repo_id`` and a split that
    actually exists (``"train"`` when present).

    Args:
        repo_id (str): The Hugging Face repository ID where the dataset card will be pushed.
        infos (Infos): Dataset metadata, including legal information like license.
        local_dir (Union[str, Path]): Path to the local directory containing the
            dataset files, expected to have a 'data' subdirectory with split folders.
        viewer (Optional[bool]): Unused parameter for viewer configuration.
        pretty_name (Optional[str]): A human-readable name for the dataset to
            display in the card.
        dataset_long_description (Optional[str]): A detailed description of the
            dataset to include in the card.
        illustration_urls (Optional[list[str]]): List of URLs to images that
            illustrate the dataset, displayed in the card.
        arxiv_paper_urls (Optional[list[str]]): List of arXiv URLs for papers
            related to the dataset, included as sources.

    Returns:
        None: This function does not return a value; it pushes the dataset card
            directly to the Hugging Face Hub.
    """
    infos_dict = infos.model_dump(exclude_none=True)
    dataset_card_str = """---
task_categories:
- graph-ml
tags:
- physics learning
- geometry learning
---
"""
    data_dir = Path(local_dir) / "data"
    # Sorted so the card is reproducible: iterdir() order is filesystem-dependent.
    split_names = sorted(p.name for p in data_dir.iterdir() if p.is_dir())

    nbe_samples: dict[str, int] = {}
    num_bytes: dict[str, int] = {}
    for sn in split_names:
        split_dir = data_dir / sn
        # Each sample is stored as its own ``sample_*`` directory.
        nbe_samples[sn] = sum(
            1
            for p in split_dir.iterdir()
            if p.is_dir() and p.name.startswith("sample_")
        )
        # Total on-disk size of every file below the split folder.
        num_bytes[sn] = sum(
            f.stat().st_size for f in split_dir.rglob("*") if f.is_file()
        )
    size_bytes = sum(num_bytes.values())

    lines = [s for s in dataset_card_str.splitlines() if not s.startswith("license")]

    indices = [i for i, line in enumerate(lines) if line.strip() == "---"]
    assert len(indices) >= 2, (
        "Cannot find two instances of '---', you should try to update a correct dataset_card."
    )
    # Keep only the front matter, then splice the generated metadata in just
    # before its closing '---' (derived from the template instead of hard-coded).
    closing = indices[1]
    lines = lines[: closing + 1]

    metadata = [f"license: {infos.license}", "viewer: false"]
    if pretty_name:
        metadata.append(f"pretty_name: {pretty_name}")
    metadata += ["dataset_info:", "  splits:"]
    for sn in split_names:
        metadata += [
            f"    - name: {sn}",
            f"      num_bytes: {num_bytes[sn]}",
            f"      num_examples: {nbe_samples[sn]}",
        ]
    metadata += [
        f"  download_size: {size_bytes}",
        f"  dataset_size: {size_bytes}",
        "configs:",
        "- config_name: default",
        "  data_files:",
    ]
    for sn in split_names:
        metadata += [f"  - split: {sn}", f"    path: data/{sn}/*"]
    lines[closing:closing] = metadata

    str__ = "\n".join(lines) + "\n"

    if illustration_urls:
        str__ += "<p align='center'>\n"
        for url in illustration_urls:
            str__ += f"<img src='{url}' alt='{url}' width='1000'/>\n"
        str__ += "</p>\n\n"

    str__ += (
        f"```yaml\n{yaml.dump(infos_dict, sort_keys=False, allow_unicode=True)}\n```"
    )

    # Use a split that exists in this dataset so the examples run as written.
    if "train" in split_names or not split_names:
        example_split = "train"
    else:
        example_split = split_names[0]

    # The examples previously indexed with an undefined name `split`; they now
    # define and use `split_name` and point at the actual repository.
    str__ += f"""
This dataset was generated with [`plaid`](https://plaid-lib.readthedocs.io/), we refer to this documentation for additional details on how to extract data from `plaid_sample` objects.

The simplest way to use this dataset is to first download it:
```python
from plaid.storage import download_from_hub

repo_id = "{repo_id}"
local_folder = "downloaded_dataset"

download_from_hub(repo_id, local_folder)
```

Then, to iterate over the dataset and instantiate samples:
```python
from plaid.storage import init_from_disk

local_folder = "downloaded_dataset"
split_name = "{example_split}"

datasetdict, converterdict = init_from_disk(local_folder)

dataset = datasetdict[split_name]
converter = converterdict[split_name]

for i in range(len(dataset)):
    plaid_sample = converter.to_plaid(dataset, i)
```

It is possible to stream the data directly:
```python
from plaid.storage import init_streaming_from_hub

repo_id = "{repo_id}"
split_name = "{example_split}"

datasetdict, converterdict = init_streaming_from_hub(repo_id)

dataset = datasetdict[split_name]
converter = converterdict[split_name]

for sample_raw in dataset:
    plaid_sample = converter.sample_to_plaid(sample_raw)
```

Sample features can then be retrieved as follows:
```python
from plaid.storage import load_problem_definitions_from_disk
local_folder = "downloaded_dataset"
pb_defs = load_problem_definitions_from_disk(local_folder)

# or
from plaid.storage import load_problem_definitions_from_hub
repo_id = "{repo_id}"
pb_defs = load_problem_definitions_from_hub(repo_id)


pb_def = next(iter(pb_defs.values()))

plaid_sample = ... # use a method from above to instantiate a plaid sample

for t in plaid_sample.get_all_time_values():
    for path in pb_def.input_features:
        feature = plaid_sample.get_feature_by_path(path=path, time=t)
        ...
    for path in pb_def.output_features:
        feature = plaid_sample.get_feature_by_path(path=path, time=t)
        ...
```
"""

    if dataset_long_description:
        str__ += f"""
### Dataset Description
{dataset_long_description}
"""

    if arxiv_paper_urls:
        str__ += """
### Dataset Sources

- **Papers:**
"""
        for url in arxiv_paper_urls:
            str__ += f"   - [arxiv]({url})\n"

    dataset_card = DatasetCard(str__)
    dataset_card.push_to_hub(repo_id)