Source code for plaid.storage.common.bridge

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#

"""Common bridge utilities.

This module provides bridge functions for converting between PLAID samples and
storage formats, including flattening/unflattening and sample reconstruction.
"""

from typing import Any, Optional

import numpy as np

from plaid import Sample
from plaid.containers.features import SampleFeatures
from plaid.storage.common.preprocessor import build_sample_dict
from plaid.storage.common.tree_handling import unflatten_cgns_tree


def _split_dict(d: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split a dictionary into values and times based on key suffixes.

    Args:
        d: Dictionary with keys that may end with '_times'.

    Returns:
        tuple: (vals, times) where vals has non-times keys, times has times keys.
    """
    vals = {}
    times = {}
    for k, v in d.items():
        if k.endswith("_times"):
            times[k[:-6]] = v
        else:
            vals[k] = v
    return vals, times


def _split_dict_feat(
    d: dict[str, Any], features_set: set[str]
) -> tuple[dict[str, Any], dict[str, Any]]:  # pragma: no cover
    """Split a dictionary into values and times, filtering by features set.

    Args:
        d: Dictionary with keys.
        features_set: Set of feature names to include.

    Returns:
        tuple: (vals, times) filtered by features_set.
    """
    vals = {}
    times = {}
    for k, v in d.items():
        if k.endswith("_times") and k[:-6] in features_set:
            times[k[:-6]] = v
        elif k in features_set:
            vals[k] = v
    return vals, times


def to_sample_dict(
    var_sample_dict: dict[str, Any],
    flat_cst: dict[str, Any],
    cgns_types: dict[str, str],
    features: Optional[list[str]] = None,
) -> dict[float, dict[str, Any]]:
    """Convert variable sample dict to time-based sample dict.

    Constant and variable features are merged (constants override variables
    on key collision), sorted by path, and each value is sliced per time step
    according to its ``<path>_times`` companion. The companion is reshaped to
    rows of ``(time, start, end)``; ``end == -1`` means "up to the end".

    Args:
        var_sample_dict: Variable features dictionary.
        flat_cst: Constant features dictionary. May be empty.
        cgns_types: CGNS types dictionary, looked up for paths whose value is
            ``None``.
        features: Optional list of features to include. When given, only these
            paths (and their ``_times`` companions) are considered.

    Returns:
        dict: Time-based sample dictionary mapping each time value to a flat
        tree ``{path: values}``. Every tree additionally receives, per base
        (first path component), the ``Time``, ``Time/IterationValues`` and
        ``Time/TimeValues`` entries, a ``CGNSLibraryVersion`` entry, and a
        ``None`` entry for every ``None``-valued path whose CGNS type is
        neither ``DataArray_t`` nor ``IndexArray_t``.

    Raises:
        AssertionError: If the first value of ``flat_cst`` is itself a dict,
            if the sorted value and times paths do not line up, or if a
            ``None`` value has a non-``None`` times entry.
        KeyError: If a ``None``-valued path is missing from ``cgns_types``.
    """
    # ``next(..., None)`` lets an empty ``flat_cst`` through instead of raising
    # StopIteration; there is nothing to sanity-check in that case.
    first_cst_value = next(iter(flat_cst.values()), None)
    assert not isinstance(first_cst_value, dict), (
        "did you provide the complete `flat_cst` instead of the one for the considered split?"
    )

    if features is None:
        flat_cst_val, flat_cst_tim = _split_dict(flat_cst)
        row_val, row_tim = _split_dict(var_sample_dict)
    else:  # pragma: no cover
        features_set = set(features)
        flat_cst_val, flat_cst_tim = _split_dict_feat(flat_cst, features_set)
        row_val, row_tim = _split_dict_feat(var_sample_dict, features_set)

    row_val.update(flat_cst_val)
    row_tim.update(flat_cst_tim)

    # Sort both mappings by path so they can be walked in lockstep below.
    row_val = {p: row_val[p] for p in sorted(row_val)}
    row_tim = {p: row_tim[p] for p in sorted(row_tim)}

    sample_flat_trees: dict[float, dict[str, Any]] = {}
    paths_none: dict[str, None] = {}
    # NOTE(review): zip stops at the shorter mapping, so a trailing path that
    # exists in only one of the two mappings is skipped without an error --
    # confirm whether that can happen with real schemas.
    for (path_t, times_struc), (path_v, val) in zip(row_tim.items(), row_val.items()):
        assert path_t == path_v, "did you forget to specify the features arg?"
        if val is None:
            assert times_struc is None
            if path_v not in paths_none and cgns_types[path_v] not in (
                "DataArray_t",
                "IndexArray_t",
            ):
                paths_none[path_v] = None
            continue

        # One row per time step: (time, start index, end index).
        times_struc = np.array(times_struc, dtype=np.float64).reshape((-1, 3))
        for time, start_raw, end_raw in times_struc:
            start = int(start_raw)
            end = int(end_raw)
            if end == -1:
                end = None  # open-ended slice
            # Multi-dimensional values are sliced along their second axis.
            values = val[:, start:end] if val.ndim > 1 else val[start:end]

            if isinstance(values[0], str):
                # Strings are stored as an array of single ASCII bytes.
                values = np.frombuffer(
                    values[0].encode("ascii", "strict"), dtype="|S1"
                )

            sample_flat_trees.setdefault(time, {})[path_v] = values

    for time, tree in sample_flat_trees.items():
        # Ordered de-duplication keeps the inserted key order reproducible
        # (a plain ``set`` would iterate in hash order).
        bases = dict.fromkeys(k.split("/")[0] for k in tree)
        for base in bases:
            tree[f"{base}/Time"] = np.array([1], dtype=np.int32)
            tree[f"{base}/Time/IterationValues"] = np.array([1], dtype=np.int32)
            tree[f"{base}/Time/TimeValues"] = np.array([time], dtype=np.float64)
        tree["CGNSLibraryVersion"] = np.array([4.0], dtype=np.float32)
        tree.update(paths_none)

    return sample_flat_trees


def to_plaid_sample(
    sample_dict: dict[float, dict[str, Any]],
    cgns_types: dict[str, str],
) -> Sample:
    """Build a PLAID Sample from a time-indexed dictionary of flat trees.

    Args:
        sample_dict: Time-based sample dictionary; each value is a flat tree
            that is passed to ``unflatten_cgns_tree``.
        cgns_types: CGNS types dictionary, forwarded to
            ``unflatten_cgns_tree`` for every time step.

    Returns:
        Sample: A Sample created with ``path=None`` whose features wrap the
        unflattened trees, keyed by the same time values as ``sample_dict``.
    """
    trees_by_time = {
        time: unflatten_cgns_tree(flat_tree, cgns_types)
        for time, flat_tree in sample_dict.items()
    }
    features = SampleFeatures(trees_by_time)
    return Sample(path=None, features=features)


def plaid_to_sample_dict(
    sample: Sample, variable_schema: dict[str, Any], constant_schema: dict[str, Any]
) -> dict[str, Any]:
    """Flatten a PLAID Sample into a dict restricted to the schema paths.

    Args:
        sample: The PLAID Sample.
        variable_schema: Variable schema dictionary; its keys are the variable
            feature paths to extract.
        constant_schema: Constant schema dictionary; its keys are the constant
            feature paths to extract.

    Returns:
        dict[str, Any]: One entry per schema path, constants first and then
        variables. Paths absent from the flattened sample map to ``None``; if
        a path appears in both schemas, the variable entry wins.
    """
    # Only the first element returned by ``build_sample_dict`` is used here.
    flat_sample, _, _ = build_sample_dict(sample)
    constants = {path: flat_sample.get(path) for path in constant_schema}
    variables = {path: flat_sample.get(path) for path in variable_schema}
    return {**constants, **variables}