Source code for plaid.storage.hf_datasets.bridge

"""HF Datasets bridge utilities.

This module provides bridge functions for converting between PLAID datasets/samples
and Hugging Face Datasets format. It includes utilities for feature type conversion,
dataset generation from PLAID objects, and sample reconstruction.
"""

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#

from functools import partial
from typing import Any, Callable, Generator, Optional

import datasets
import numpy as np
import pyarrow as pa
from datasets import Features, Sequence, Value

from plaid import Sample
from plaid.storage.common.preprocessor import build_sample_dict
from plaid.types import IndexType


def convert_dtype_to_hf_feature(feature_type: dict[str, Any]):
    """Build the Hugging Face feature describing a single PLAID variable.

    The scalar element type is wrapped in one ``Sequence`` layer per array
    dimension, e.g. ``{"dtype": "float32", "ndim": 2}`` becomes
    ``Sequence(Sequence(Value("float32")))``.

    Args:
        feature_type (dict): Type description with a ``"dtype"`` entry (passed
            to ``datasets.Value``) and an ``"ndim"`` entry (number of nested
            ``Sequence`` wrappers to apply).

    Returns:
        Value or Sequence: ``Value(dtype)`` when ``ndim`` is 0, otherwise that
        value nested inside ``ndim`` ``Sequence`` wrappers.
    """
    dtype_name, depth = feature_type["dtype"], feature_type["ndim"]
    hf_feature = Value(dtype_name)
    # Innermost wrapper first: each pass adds one outer dimension.
    for _ in range(depth):
        hf_feature = Sequence(hf_feature)
    return hf_feature
def convert_to_hf_feature(variable_schema: dict[str, dict]):
    """Translate a whole PLAID variable schema into Hugging Face ``Features``.

    Args:
        variable_schema (dict[str, dict]): Mapping from variable name to its
            type description (``{"dtype": ..., "ndim": ...}``).

    Returns:
        Features: One HF feature per variable, keyed and ordered like
        ``variable_schema``.
    """
    hf_schema = {}
    for var_name, type_info in variable_schema.items():
        hf_schema[var_name] = convert_dtype_to_hf_feature(type_info)
    return Features(hf_schema)
def generator_to_datasetdict(
    generators: dict[str, Callable[..., Generator[Sample, None, None]]],
    variable_schema: dict,
    cache_dir: str,
    gen_kwargs: Optional[dict[str, dict[str, list[IndexType]]]] = None,
    processes_number: int = 1,
    writer_batch_size: int = 1,
) -> datasets.DatasetDict:
    """Convert PLAID dataset generators into a Hugging Face `datasets.DatasetDict`.

    This function takes generator functions that yield PLAID samples and
    converts them into a Hugging Face DatasetDict. Each generator corresponds
    to a split (e.g., "train", "test"). Each sample is flattened with
    ``build_sample_dict`` and projected onto the columns of
    ``variable_schema``. Columns missing from a sample are stored as None.

    Args:
        generators (dict[str, Callable[..., Generator[Sample, None, None]]]):
            Mapping from split names (e.g., "train", "test") to generator
            functions. Each generator function must yield PLAID Sample objects
            that will be converted to the Hugging Face format.
        variable_schema (dict[str, dict]): Dictionary defining the schema of
            variables/features in the dataset. Maps feature names to their type
            information (dtype and ndim).
        cache_dir (str): Directory path used as cache directory for the Hugging
            Face dataset generation process.
        gen_kwargs (dict[str, dict[str, list[IndexType]]], optional): Optional
            mapping from split names to dictionaries of keyword arguments to be
            passed to each generator function. Useful for passing
            split-specific parameters like sample indices. Default is None,
            which uses empty kwargs for every split. A split absent from the
            mapping also gets empty kwargs.
        processes_number (int, optional): Number of parallel processes to use
            when materializing the dataset from the generators. Default is 1
            (no parallelization).
        writer_batch_size (int, optional): Batch size used when writing samples
            to disk in Hugging Face format. Default is 1.

    Returns:
        datasets.DatasetDict: A Hugging Face DatasetDict containing one Dataset
        per split, where each dataset contains the samples generated by the
        corresponding generator.

    Example:
        >>> def train_generator():
        ...     for sample in train_samples:
        ...         yield sample
        >>> def test_generator():
        ...     for sample in test_samples:
        ...         yield sample
        >>> variable_schema = {
        ...     "velocity_x": {"dtype": "float32", "ndim": 2},
        ...     "velocity_y": {"dtype": "float32", "ndim": 2}
        ... }
        >>> ds_dict = generator_to_datasetdict(
        ...     generators={"train": train_generator, "test": test_generator},
        ...     variable_schema=variable_schema,
        ...     cache_dir="/tmp/hf_cache",
        ...     processes_number=4,
        ...     writer_batch_size=10
        ... )
        >>> print(ds_dict)
        DatasetDict({
            train: Dataset({
                features: ['velocity_x', 'velocity_y'],
                num_rows: ...
            }),
            test: Dataset({
                features: ['velocity_x', 'velocity_y'],
                num_rows: ...
            })
        })
    """
    hf_features = convert_to_hf_feature(variable_schema)
    all_features_keys = list(variable_schema.keys())

    def generator_fn(gen_func, all_features_keys, **kwargs):
        # Flatten each PLAID sample, then keep exactly the schema columns so
        # every yielded row matches `hf_features`.
        for sample in gen_func(**kwargs):
            hf_sample, _, _ = build_sample_dict(sample)
            yield {path: hf_sample.get(path, None) for path in all_features_keys}

    # Loop-invariant: the same wrapper serves every split; only the
    # per-split gen_kwargs differ.
    gen = partial(generator_fn, all_features_keys=all_features_keys)
    # None (or an empty mapping) means "no extra kwargs". Splits missing from
    # a user-provided mapping fall back to empty kwargs instead of KeyError.
    split_kwargs = gen_kwargs or {}

    _dict = {}
    for split_name, gen_func in generators.items():
        _dict[split_name] = datasets.Dataset.from_generator(
            generator=gen,
            gen_kwargs={"gen_func": gen_func, **split_kwargs.get(split_name, {})},
            features=hf_features,
            cache_dir=cache_dir,
            num_proc=processes_number,
            writer_batch_size=writer_batch_size,
            split=datasets.splits.NamedSplit(split_name),
        )
    return datasets.DatasetDict(_dict)
def to_var_sample_dict(
    ds: datasets.Dataset,
    i: int,
    features: Optional[list[str]] = None,
    enforce_shapes: bool = True,
) -> dict[str, Optional[np.ndarray]]:
    """Convert a Hugging Face dataset row to a variable sample dict.

    Extracts the features that vary in the dataset from row ``i`` and converts
    each cell to NumPy.

    Args:
        ds (datasets.Dataset): The Hugging Face dataset.
        i (int): Row index into the Arrow table ``ds.data``.
        features (Iterable[str], optional): Feature names (column keys) to
            extract. Any iterable is accepted and is materialized once.
            Defaults to every column of ``ds.data``.
        enforce_shapes (bool): If True, cells holding nested lists are stacked
            into a dense ``np.ndarray``. Flat numeric cells are returned as
            zero-copy views when Arrow allows it, and copied otherwise (e.g.
            booleans, strings, nulls). If False, every cell is converted with
            a copying ``to_numpy``.

    Returns:
        dict[str, Optional[np.ndarray]]: Mapping from feature name to its
        value, or None for null cells, in the order of ``features``.

    Raises:
        KeyError: If some requested features are not columns of the dataset.

    Note:
        NOTE(review): ``ds.data`` is the backing Arrow table. Presumably an
        indices mapping created by ``select``/``shuffle``/``filter`` is not
        applied here, so ``i`` is a physical row index — confirm with callers.
    """
    table = ds.data
    if features is None:
        features = table.column_names
    else:
        # Materialize once: a one-shot iterator would otherwise be exhausted
        # by the membership check below and produce an empty result.
        features = list(features)
        missing = set(features) - set(table.column_names)
        if missing:  # pragma: no cover
            raise KeyError(f"Missing features in hf_dataset: {sorted(missing)}")

    return {
        name: _cell_to_numpy(table[name][i], enforce_shapes) for name in features
    }


def _cell_to_numpy(scalar: pa.Scalar, enforce_shapes: bool) -> Optional[np.ndarray]:
    """Convert one Arrow list-typed cell to a NumPy array, or None if null."""
    # A null-typed cell has no `.values`; handle it the same way in both modes.
    if isinstance(scalar, pa.NullScalar):
        return None  # pragma: no cover
    values = scalar.values
    if values is None:  # null list cell
        return None  # pragma: no cover
    if not enforce_shapes:
        return values.to_numpy(zero_copy_only=False)
    if isinstance(values, (pa.ListArray, pa.LargeListArray, pa.FixedSizeListArray)):
        # Nested lists come back as an object array of sub-arrays; stack them
        # into one dense array. NOTE(review): np.stack raises ValueError when
        # the cell holds zero sub-lists.
        return np.stack(values.to_numpy(zero_copy_only=False))
    if isinstance(values, (pa.StringArray, pa.LargeStringArray)):  # pragma: no cover
        return values.to_numpy(zero_copy_only=False)
    try:
        # Preferred: a read-only view over the Arrow buffer, with no copy.
        return values.to_numpy(zero_copy_only=True)
    except pa.ArrowInvalid:
        # Bit-packed booleans and arrays containing nulls cannot be viewed
        # without a copy.
        return values.to_numpy(zero_copy_only=False)
def sample_to_var_sample_dict(
    hf_sample: dict[str, Any],
) -> dict[str, Any]:
    """Turn a decoded Hugging Face sample into a variable sample dict.

    Every non-None value is converted to a fresh ``np.ndarray`` via
    ``np.array``. None values are kept as None. Key order is preserved.

    Args:
        hf_sample (dict): The HF sample dictionary (feature name -> value).

    Returns:
        dict: The processed variable sample dictionary.
    """
    return {
        key: (None if raw is None else np.array(raw))
        for key, raw in hf_sample.items()
    }