Skip to content

plaid.storage.common.preprocessor

plaid.storage.common.preprocessor

Common preprocessing utilities.

This module provides utilities for preprocessing PLAID samples into formats suitable for storage, including flattening CGNS trees, inferring data types, and handling parallel processing of sample shards.

plaid.storage.common.preprocessor.infer_dtype

infer_dtype(value)

Infer canonical dtype schema from a value.

Source code in plaid/storage/common/preprocessor.py
def infer_dtype(value: Any) -> dict[str, int | str]:
    """Infer canonical dtype schema from a value."""
    if value is None:  # pragma: no cover
        return {"dtype": "null", "ndim": 0}

    # Scalars
    if np.isscalar(value):  # pragma: no cover
        raise ValueError("CGNS should return arrays")

    # Arrays / lists
    elif isinstance(value, (list, tuple, np.ndarray)):
        arr = np.array(value)
        dtype = arr.dtype
        if np.issubdtype(dtype, np.floating):
            dt = "float32"
        elif np.issubdtype(dtype, np.int32):
            dt = "int32"
        elif np.issubdtype(dtype, np.int64):
            dt = "int64"
        elif np.issubdtype(dtype, np.str_):
            dt = "string"
        elif np.issubdtype(dtype, np.dtype("S1")):
            dt = "S1"
        else:  # pragma: no cover
            raise ValueError(f"Unrecognized scalar dtype: {dtype}")
        return {"dtype": dt, "ndim": arr.ndim}
        # arr = np.array(value)
        # return {"dtype": str(arr.dtype), "ndim": arr.ndim}

    raise TypeError(f"Unsupported type: {type(value)}")  # pragma: no cover

plaid.storage.common.preprocessor.build_sample_dict

build_sample_dict(sample)

Flatten a PLAID Sample's CGNS trees into Hugging Face–compatible arrays and metadata.

The function traverses every CGNS tree stored in sample.data (keyed by time), produces a flattened mapping path -> primitive value for each time, and then builds compact numpy arrays suitable for storage in a Hugging Face Dataset. Repeated value blocks that are identical across times are deduplicated and referenced by start/end indices; companion "_times" arrays describe, per time, the slice indices into the concatenated arrays.

Parameters:

  • sample (Sample) –

    A PLAID Sample whose features contain one or more CGNS trees (sample.data maps time -> CGNSTree).

Returns:

  • tuple ( tuple[dict[str, Any], set[str], dict[str, str]] ) –
    • sample_dict (dict[str, Any]): Mapping of flattened CGNS paths to either a numpy array (concatenation of per-time blocks) or None. For each path there is also an entry "_times" containing a flattened numpy array of triplets [time, start, end] (end == -1 indicates the block extends to the end of the array).
    • all_paths (list[str]): Sorted list of all considered variable feature paths (excluding Time-related nodes and CGNSLibraryVersion).
    • sample_cgns_types (dict[str, str]): Mapping from path to CGNS node type (metadata produced by flatten_cgns_tree).
Note
  • Byte-array encoded strings (dtype "|S1") are handled by reassembling and storing the string as a single-element numpy array; a sha256 hash is used for deduplication.
  • Deduplication reduces storage when identical blocks recur across times.
  • Paths containing "/Time" or "CGNSLibraryVersion" are ignored for variable features.
Source code in plaid/storage/common/preprocessor.py
def build_sample_dict(
    sample: Sample,
) -> tuple[dict[str, Any], set[str], dict[str, str]]:
    """Flatten a PLAID Sample's CGNS trees into Hugging Face–compatible arrays and metadata.

    The function traverses every CGNS tree stored in sample.data (keyed by time),
    produces a flattened mapping path -> primitive value for each time, and then builds
    compact numpy arrays suitable for storage in a Hugging Face Dataset. Repeated value
    blocks that are identical across times are deduplicated and referenced by start/end
    indices; companion "<path>_times" arrays describe, per time, the slice indices into
    the concatenated arrays.

    Args:
        sample (Sample): A PLAID Sample whose features contain one or more CGNS trees
            (sample.data maps time -> CGNSTree).

    Returns:
        tuple:
            - sample_dict (dict[str, Any]): Mapping of flattened CGNS paths to either a
              numpy array (concatenation of per-time blocks) or None. For each path
              there is also an entry "<path>_times" containing a flattened numpy array
              of triplets [time, start, end] (end == -1 indicates the block extends to
              the end of the array).
            - all_paths (list[str]): Sorted list of all considered variable feature paths
              (excluding Time-related nodes and CGNSLibraryVersion).
            - sample_cgns_types (dict[str, str]): Mapping from path to CGNS node type
              (metadata produced by flatten_cgns_tree).

    Note:
        - Byte-array encoded strings (dtype ``"|S1"``) are handled by reassembling and
          storing the string as a single-element numpy array; a sha256 hash is used
          for deduplication.
        - Deduplication reduces storage when identical blocks recur across times.
        - Paths containing "/Time" or "CGNSLibraryVersion" are ignored for variable features.
    """
    sample_flat_trees = {}
    sample_cgns_types = {}
    all_paths = set()

    # --- Flatten CGNS trees ---
    for time, tree in sample.data.items():
        flat, cgns_types = flatten_cgns_tree(tree)
        sample_flat_trees[time] = flat

        all_paths.update(
            k for k in flat.keys() if "/Time" not in k and "CGNSLibraryVersion" not in k
        )

        sample_cgns_types.update(cgns_types)

    sample_dict = {}

    for path in all_paths:
        sample_dict[path] = None
        sample_dict[path + "_times"] = None

        known_values = {}
        values_acc, times_acc = [], []
        current_length = 0

        for time, flat in sample_flat_trees.items():
            if path not in flat:
                continue  # pragma: no cover
            value = flat[path]

            # Handle byte-array encoded strings
            if (
                isinstance(value, np.ndarray)
                and value.dtype == np.dtype("|S1")
                and value.ndim == 1
            ):
                value_str = b"".join(value).decode("ascii")
                value_np = np.array([value_str])
                key = hashlib.sha256(value_str.encode("ascii")).hexdigest()
                size = 1
            elif value is not None:
                value_np = value
                key = hashlib.sha256(value.tobytes()).hexdigest()
                size = (
                    value.shape[-1]
                    if isinstance(value, np.ndarray) and value.ndim >= 1
                    else 1
                )
            else:
                continue

            # Deduplicate identical arrays
            if key in known_values:
                start, end = known_values[key]  # pragma: no cover
            else:
                start, end = current_length, current_length + size
                known_values[key] = (start, end)
                values_acc.append(value_np)
                current_length = end

            times_acc.append([time, start, end])

        # Build arrays
        if values_acc:
            try:
                sample_dict[path] = np.hstack(values_acc)
            except Exception:  # pragma: no cover
                sample_dict[path] = np.concatenate(
                    [np.atleast_1d(x) for x in values_acc]
                )

            if len(known_values) == 1:
                for t in times_acc:
                    t[-1] = -1
            sample_dict[path + "_times"] = np.array(times_acc).flatten()
        else:
            sample_dict[path] = None
            sample_dict[path + "_times"] = None

    # Convert lists to numpy arrays
    for k, v in sample_dict.items():
        if isinstance(v, list):
            sample_dict[k] = np.array(v)  # pragma: no cover

    return sample_dict, all_paths, sample_cgns_types

plaid.storage.common.preprocessor.process_shard

process_shard(
    generator_fn, progress, n_proc, shard_ids=None
)

Process a single shard of sample ids and collect per-shard metadata.

This function drives a shard-level pass over samples produced by generator_fn. For each sample it: - flattens the sample into Hugging Face friendly arrays (build_sample_dict), - collects observed flattened paths, - aggregates CGNS type metadata, - infers Hugging Face feature types for each path, - detects per-path constants using a content hash, - updates progress (either a multiprocessing.Queue or a tqdm progress bar).

Parameters:

  • shard_ids (IndexArrayType, default: None ) –

    Sequence of sample ids (a single shard) to process.

  • generator_fn (Callable) –

    Generator function accepting a list of shard id sequences and yielding Sample objects for those ids.

  • progress (Any) –

    Progress reporter; either a multiprocessing.Queue (for parallel execution) or a tqdm progress bar object (for sequential execution).

  • n_proc (int) –

    Number of worker processes used by the caller (used to decide how to report progress).

Returns:

  • tuple ( tuple[set[str], dict[str, str], dict[str, Any], dict[str, dict[str, Union[str, bool, int, set]]], int] ) –
    • split_all_paths (set[str]): Set of all flattened feature paths observed in the shard.
    • shard_global_cgns_types (dict[str, str]): Mapping path -> CGNS node type observed in the shard.
    • shard_global_feature_types (dict[str, Union[Value, Sequence]]): Inferred feature types per path.
    • split_constant_leaves (dict[str, dict]): Per-path metadata for constant detection. Each entry is a dict with keys "hash" (str), "constant" (bool) and "count" (int).
    • n_samples_processed (int): Number of samples processed in this shard.

Raises:

  • ValueError

    If inconsistent feature types are detected for the same path within the shard.

Source code in plaid/storage/common/preprocessor.py
def process_shard(
    generator_fn: Callable[..., Any],
    progress: Any,
    n_proc: int,
    shard_ids: Optional[Sequence[Any]] = None,
) -> tuple[
    set[str],
    dict[str, str],
    dict[str, Any],
    dict[str, dict[str, Union[str, bool, int, set]]],
    int,
]:
    """Process a single shard of sample ids and collect per-shard metadata.

    This function drives a shard-level pass over samples produced by `generator_fn`.
    For each sample it:
    - flattens the sample into Hugging Face friendly arrays (build_sample_dict),
    - collects observed flattened paths,
    - aggregates CGNS type metadata,
    - infers Hugging Face feature types for each path,
    - detects per-path constants using a content hash,
    - updates progress (either a multiprocessing.Queue or a tqdm progress bar).

    Args:
        shard_ids (IndexArrayType): Sequence of sample ids (a single shard) to process.
        generator_fn (Callable): Generator function accepting a list of shard id sequences
            and yielding Sample objects for those ids.
        progress (Any): Progress reporter; either a multiprocessing.Queue (for parallel
            execution) or a tqdm progress bar object (for sequential execution).
        n_proc (int): Number of worker processes used by the caller (used to decide
            how to report progress).

    Returns:
        tuple:
            - split_all_paths (set[str]): Set of all flattened feature paths observed in the shard.
            - shard_global_cgns_types (dict[str, str]): Mapping path -> CGNS node type observed in the shard.
            - shard_global_feature_types (dict[str, Union[Value, Sequence]]): Inferred feature types per path.
            - split_constant_leaves (dict[str, dict]): Per-path metadata for constant detection. Each entry
              is a dict with keys "hash" (str), "constant" (bool) and "count" (int).
            - n_samples_processed (int): Number of samples processed in this shard.

    Raises:
        ValueError: If inconsistent feature types are detected for the same path within the shard.
    """
    split_constant_leaves = {}
    split_all_paths = set()
    shard_global_cgns_types = {}
    shard_global_feature_types = {}

    if shard_ids is not None:
        generator = generator_fn([shard_ids])  # pragma: no cover
    else:
        generator = generator_fn()

    n_samples = 0
    for sample in generator:
        sample_dict, all_paths, sample_cgns_types = build_sample_dict(sample)

        split_all_paths.update(sample_dict.keys())
        shard_global_cgns_types.update(sample_cgns_types)

        # Feature type inference
        for path in all_paths:
            value = sample_dict[path]
            if value is None:
                continue
            inferred_dtype = infer_dtype(value)
            if path not in shard_global_feature_types:
                shard_global_feature_types[path] = inferred_dtype
            elif shard_global_feature_types[path] != inferred_dtype:
                raise ValueError(
                    f"Feature type mismatch for {path} in shard"
                )  # pragma: no cover

        # Constant detection using **hash only**
        for path, value in sample_dict.items():
            h = _hash_value(value)
            if path not in split_constant_leaves:
                split_constant_leaves[path] = {"hashes": {h}, "count": 1}
            else:
                entry = split_constant_leaves[path]
                entry["hashes"].add(h)
                entry["count"] += 1

        # Progress
        if n_proc > 1:
            progress.put(1)  # pragma: no cover
        else:
            progress.update(1)

        n_samples += 1

    return (
        split_all_paths,
        shard_global_cgns_types,
        shard_global_feature_types,
        split_constant_leaves,
        n_samples,
    )

plaid.storage.common.preprocessor.preprocess_splits

preprocess_splits(
    generators, gen_kwargs=None, num_proc=1, verbose=True
)

Pre-process dataset splits: inspect samples to infer features, constants and CGNS metadata.

This function iterates over the provided split generators (optionally in parallel), flattens each PLAID sample into Hugging Face friendly arrays, detects constant CGNS leaves (features identical across all samples in a split), infers global Hugging Face feature types, and aggregates CGNS type metadata.

The work is sharded per-split and each shard is processed by process_shard. In parallel mode, progress is updated via a multiprocessing.Queue; otherwise a tqdm progress bar is used.

Parameters:

  • generators (dict[str, Callable]) –

    Mapping from split name to a generator function. Each generator must accept a single argument (a sequence of shard ids) and yield PLAID samples.

  • gen_kwargs (dict[str, dict[str, list[IndexArrayType]]], default: None ) –

    Per-split kwargs used to drive generator invocation (e.g. {"train": {"shards_ids": [...]}}).

  • num_proc (int, default: 1 ) –

    Number of worker processes to use for shard-level parallelism. Defaults to 1.

  • verbose (bool, default: True ) –

    If True, displays progress bars. Defaults to True.

Returns:

  • tuple ( tuple[dict[str, set[str]], dict[str, dict[str, Any]], dict[str, set[str]], dict[str, str], dict[str, Any], dict[str, int]] ) –
    • split_all_paths (dict[str, set[str]]): For each split, the set of all observed flattened feature paths (including "_times" keys).
    • split_flat_cst (dict[str, dict[str, Any]]): For each split, a mapping of constant feature path -> value (constant parts of the tree).
    • split_var_path (dict[str, set[str]]): For each split, the set of variable feature paths (non-constant).
    • global_cgns_types (dict[str, str]): Aggregated mapping from flattened path -> CGNS node type.
    • global_feature_types (dict[str, Union[Value, Sequence]]): Aggregated inferred Hugging Face feature types for each variable path.
    • split_n_samples (dict[str, int]): For each split, the total number of samples processed.

Raises:

  • ValueError

    If inconsistent feature types or CGNS types are detected across shards/splits.

Source code in plaid/storage/common/preprocessor.py
def preprocess_splits(
    generators: dict[str, Callable[..., Generator[Sample, None, None]]],
    gen_kwargs: Optional[dict[str, dict[str, Any]]] = None,
    num_proc: int = 1,
    verbose: bool = True,
) -> tuple[
    dict[str, set[str]],
    dict[str, dict[str, Any]],
    dict[str, set[str]],
    dict[str, str],
    dict[str, Any],
    dict[str, int],
]:
    """Pre-process dataset splits: inspect samples to infer features, constants and CGNS metadata.

    This function iterates over the provided split generators (optionally in parallel),
    flattens each PLAID sample into Hugging Face friendly arrays, detects constant
    CGNS leaves (features identical across all samples in a split), infers global
    Hugging Face feature types, and aggregates CGNS type metadata.

    The work is sharded per-split and each shard is processed by `process_shard`.
    In parallel mode, progress is updated via a multiprocessing.Queue; otherwise a
    tqdm progress bar is used.

    Args:
        generators (dict[str, Callable]):
            Mapping from split name to a generator function. Each generator must
            accept a single argument (a sequence of shard ids) and yield PLAID samples.
        gen_kwargs (dict[str, dict[str, list[IndexArrayType]]]):
            Per-split kwargs used to drive generator invocation (e.g. {"train": {"shards_ids": [...]}}).
        num_proc (int, optional):
            Number of worker processes to use for shard-level parallelism. Defaults to 1.
        verbose (bool, optional):
            If True, displays progress bars. Defaults to True.

    Returns:
        tuple:
            - split_all_paths (dict[str, set[str]]):
                For each split, the set of all observed flattened feature paths (including "_times" keys).
            - split_flat_cst (dict[str, dict[str, Any]]):
                For each split, a mapping of constant feature path -> value (constant parts of the tree).
            - split_var_path (dict[str, set[str]]):
                For each split, the set of variable feature paths (non-constant).
            - global_cgns_types (dict[str, str]):
                Aggregated mapping from flattened path -> CGNS node type.
            - global_feature_types (dict[str, Union[Value, Sequence]]):
                Aggregated inferred Hugging Face feature types for each variable path.
            - split_n_samples (dict[str, int]):
                For each split, the total number of samples processed.

    Raises:
        ValueError: If inconsistent feature types or CGNS types are detected across shards/splits.
    """
    global_cgns_types = {}
    global_feature_types = {}
    split_flat_cst = {}
    split_var_path = {}
    split_all_paths = {}
    split_n_samples = {}

    gen_kwargs_ = gen_kwargs or {split_name: {} for split_name in generators.keys()}

    for split_name, generator_fn in generators.items():
        shards_ids_list = gen_kwargs_[split_name].get("shards_ids", [None])
        n_proc = max(1, num_proc or len(shards_ids_list))

        shards_data = []

        if n_proc == 1:
            with tqdm(
                disable=not verbose,
                desc=f"Pre-process split {split_name}",
            ) as pbar:
                for shard_ids in shards_ids_list:
                    shards_data.append(
                        process_shard(generator_fn, pbar, n_proc=1, shard_ids=shard_ids)
                    )

        else:  # pragma: no cover
            # Parallel execution
            manager = mp.Manager()
            progress_queue = manager.Queue()
            shards_data = []

            try:
                with mp.Pool(n_proc) as pool:
                    results = [
                        pool.apply_async(
                            _process_shard_debug,
                            args=(generator_fn, progress_queue, n_proc, shard_ids),
                        )
                        for shard_ids in shards_ids_list
                    ]

                    total_samples = sum(len(shard) for shard in shards_ids_list)
                    completed = 0

                    with tqdm(
                        total=total_samples,
                        disable=not verbose,
                        desc=f"Pre-process split {split_name}",
                    ) as pbar:
                        while completed < total_samples:
                            try:
                                increment = progress_queue.get(timeout=0.5)
                                pbar.update(increment)
                                completed += increment
                            except Empty:
                                # Check for any crashed workers
                                for r in results:
                                    if r.ready():
                                        try:
                                            r.get(
                                                timeout=0.1
                                            )  # will raise worker exception if any
                                        except Exception as e:
                                            raise RuntimeError(f"Worker crashed: {e}")

                    # Collect all results
                    for r in results:
                        shards_data.append(r.get())

            finally:
                manager.shutdown()

        # Merge shard results
        split_all_paths[split_name] = set()
        split_constant_hashes = {}
        n_samples_total = 0

        for (
            all_paths,
            shard_cgns,
            shard_features,
            shard_constants,
            n_samples,
        ) in shards_data:
            split_all_paths[split_name].update(all_paths)
            global_cgns_types.update(shard_cgns)

            for path, inferred_dtype in shard_features.items():
                if path not in global_feature_types:
                    global_feature_types[path] = inferred_dtype
                elif global_feature_types[path] != inferred_dtype:
                    raise ValueError(  # pragma: no cover
                        f"Feature type mismatch for {path} in split {split_name}"
                    )

            for path, entry in shard_constants.items():
                if path not in split_constant_hashes:
                    split_constant_hashes[path] = entry
                else:  # pragma: no cover
                    existing = split_constant_hashes[path]
                    existing["hashes"].update(entry["hashes"])
                    existing["count"] += entry["count"]

            n_samples_total += n_samples

        split_n_samples[split_name] = n_samples_total

        # Retrieve **values** only for constant paths from first sample
        if gen_kwargs:
            first_sample = next(generator_fn([shards_ids_list[0]]))  # pragma: no cover
        else:
            first_sample = next(generator_fn())
        sample_dict, _, _ = build_sample_dict(first_sample)

        # Determine truly constant paths (same hash across all samples). A split
        # with a single sample has no cross-sample repetition to prove that a
        # value is constant. Keep only None-valued structural paths as constants
        # so sample-based backends still have typed per-sample data columns.
        # this make possible to work with dataset with only one sample
        if n_samples_total <= 1:
            constant_paths = [p for p, value in sample_dict.items() if value is None]
        else:
            constant_paths = [
                p
                for p, entry in split_constant_hashes.items()
                if len(entry["hashes"]) == 1 and entry["count"] == n_samples_total
            ]

        split_flat_cst[split_name] = {p: sample_dict[p] for p in sorted(constant_paths)}
        split_var_path[split_name] = {
            p
            for p in split_all_paths[split_name]
            if p not in split_flat_cst[split_name]
        }

    global_feature_types = {
        p: global_feature_types[p] for p in sorted(global_feature_types)
    }

    return (
        split_all_paths,
        split_flat_cst,
        split_var_path,
        global_cgns_types,
        global_feature_types,
        split_n_samples,
    )

plaid.storage.common.preprocessor.preprocess

preprocess(
    generators, gen_kwargs=None, num_proc=1, verbose=True
)

Preprocess generators to extract schemas and metadata.

Parameters:

  • generators (dict[str, Callable[..., Generator[Sample, None, None]]]) –

    Dict of split generators.

  • gen_kwargs (Optional[dict[str, dict[str, Any]]], default: None ) –

    Optional generator kwargs for parallel processing.

  • num_proc (int, default: 1 ) –

    Number of processes.

  • verbose (bool, default: True ) –

    Whether to show progress.

Returns:

  • tuple ( tuple[dict[str, dict[str, Any]], dict[str, Any], dict[str, dict[str, Any]], dict[str, int], dict[str, str]] ) –

    A 5-tuple (split_flat_cst, variable_schema, constant_schema, split_n_samples, global_cgns_types).

Source code in plaid/storage/common/preprocessor.py
def preprocess(
    generators: dict[str, Callable[..., Generator[Sample, None, None]]],
    gen_kwargs: Optional[dict[str, dict[str, Any]]] = None,
    num_proc: int = 1,
    verbose: bool = True,
) -> tuple[
    dict[str, dict[str, Any]],
    dict[str, Any],
    dict[str, dict[str, Any]],
    dict[str, int],
    dict[str, str],
]:
    """Preprocess generators to extract schemas and metadata.

    Args:
        generators: Dict of split generators.
        gen_kwargs: Optional generator kwargs for parallel processing.
        num_proc: Number of processes.
        verbose: Whether to show progress.

    Returns:
        tuple: A 5-tuple ``(split_flat_cst, variable_schema, constant_schema,
            split_n_samples, global_cgns_types)``.
    """
    (
        split_all_paths,
        split_flat_cst,
        split_var_path,
        global_cgns_types,
        global_feature_types,
        split_n_samples,
    ) = preprocess_splits(generators, gen_kwargs, num_proc, verbose)

    # --- build features ---
    var_features = sorted(list(set().union(*split_var_path.values())))
    # if len(var_features) == 0:  # pragma: no cover
    #    raise ValueError(
    #        "no variable feature found, is your dataset variable through samples?"
    #    )

    for split_name in split_flat_cst.keys():
        for path in var_features:
            if not path.endswith("_times") and path not in split_all_paths[split_name]:
                split_flat_cst[split_name][path + "_times"] = None  # pragma: no cover
            if path in split_flat_cst[split_name]:
                split_flat_cst[split_name].pop(path)  # pragma: no cover

    cst_features_by_split = {
        split_name: sorted(list(cst.keys()))
        for split_name, cst in split_flat_cst.items()
    }
    # var_features = [path for path in var_features if not path.endswith("_times")]
    # cst_features_by_split = {
    #     split: [path for path in cst_features if not path.endswith("_times")]
    #     for split, cst_features in cst_features_by_split.items()
    # }

    def _build_schema(feature_list: list[str], split_flat_cst=None) -> dict:
        schema = {}
        for path in feature_list:
            if path.endswith("_times"):
                if split_flat_cst and split_flat_cst[path] is None:
                    schema[path] = {"dtype": None}
                else:
                    schema[path] = {"dtype": "float64", "ndim": 1}
            elif path in global_feature_types:
                schema[path] = global_feature_types[path]
            else:
                schema[path] = {"dtype": None}
        return schema

    variable_schema = _build_schema(var_features)
    constant_schema = {
        split: _build_schema(cst_features_by_split[split], flat_cst)
        for split, flat_cst in split_flat_cst.items()
    }

    return (
        split_flat_cst,
        variable_schema,
        constant_schema,
        split_n_samples,
        global_cgns_types,
    )