Source code for plaid.utils.stats

"""Utility functions for computing statistics on datasets."""

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#

# %% Imports

import copy
import logging
import sys
from typing import Union

if sys.version_info >= (3, 11):
    from typing import Self
else:  # pragma: no cover
    from typing import TypeVar

[docs] Self = TypeVar("Self")
import numpy as np from plaid import Dataset, Sample from plaid.constants import CGNS_FIELD_LOCATIONS logger = logging.getLogger(__name__) # %% Functions
def aggregate_stats(
    sizes: np.ndarray, means: np.ndarray, vars: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Compute aggregated statistics of a batch of already computed statistics.

    The original samples are not needed: the global sample count, mean and
    (population) variance are rebuilt from the per-batch sizes, means and
    variances.

    cf: [Variance from (cardinal,mean,variance) of several statistical series](https://fr.wikipedia.org/wiki/Variance_(math%C3%A9matiques)#Formules)

    Args:
        sizes (np.ndarray): Number of samples of each batch. Expect shape
            (n_batches,) or (n_batches, 1).
        means (np.ndarray): Mean of each batch. Expect shape (n_batches, n_features).
        vars (np.ndarray): Variance of each batch. Expect shape (n_batches, n_features).

    Returns:
        tuple[np.ndarray,np.ndarray,np.ndarray]: A tuple containing the aggregated statistics in the following order:
            - Total number of samples in all batches (numpy scalar).
            - Weighted mean calculated from the batch means, shape (1, n_features).
            - Weighted variance calculated from the batch variances, considering
              the means, shape (1, n_features).

    Raises:
        AssertionError: If the shapes of the inputs are inconsistent.
    """
    sizes = np.asarray(sizes)
    means = np.asarray(means)
    vars = np.asarray(vars)

    # Both the flat layout used by the callers and the column layout are accepted.
    assert sizes.ndim == 1 or (sizes.ndim == 2 and sizes.shape[1] == 1)
    assert means.ndim == 2
    assert len(sizes) == len(means)
    assert means.shape == vars.shape

    # Column vector so that the weights broadcast over the feature axis.
    weights = sizes.reshape((-1, 1))
    total_n_samples = np.sum(weights)
    total_mean = np.sum(weights * means, axis=0, keepdims=True) / total_n_samples
    # Within-batch variance plus the squared offset of each batch mean
    # from the global mean.
    spread = vars + (total_mean - means) ** 2
    total_var = np.sum(weights * spread, axis=0, keepdims=True) / total_n_samples
    return total_n_samples, total_mean, total_var
# %% Classes
class OnlineStatistics:
    """OnlineStatistics is a class for computing online statistics of numpy arrays.

    This class computes running statistics (min, max, mean, variance, std) for
    streaming data without storing all samples in memory.

    Two modes exist. While every added array has the same number of columns,
    statistics are kept per feature (``n_features > 1``) and batches are
    weighted by their number of samples. Once shapes differ, or when
    ``n_features == 1``, statistics are kept over all values and batches are
    weighted by their number of points.

    Example:
        >>> stats = OnlineStatistics()
        >>> stats.add_samples(np.array([[1, 2], [3, 4]]))
        >>> stats.add_samples(np.array([[5, 6]]))
        >>> print(stats.get_stats()['mean'])
        [[3. 4.]]
    """

    def __init__(self) -> None:
        """Initialize an empty OnlineStatistics object."""
        # Number of samples seen so far.
        self.n_samples: int = 0
        # Number of columns of the statistics arrays (1 once flattened).
        self.n_features: Union[int, None] = None
        # Total number of scalar values seen so far.
        self.n_points: Union[int, None] = None
        # Running statistics, each of shape (1, n_features) once data was added.
        self.min: Union[np.ndarray, None] = None
        self.max: Union[np.ndarray, None] = None
        self.mean: Union[np.ndarray, None] = None
        self.var: Union[np.ndarray, None] = None
        self.std: Union[np.ndarray, None] = None

    def _is_empty(self) -> bool:
        """Tell whether no statistics have been accumulated yet."""
        return self.n_samples == 0 or any(
            stat is None for stat in (self.min, self.max, self.mean, self.var)
        )

    def add_samples(self, x: np.ndarray, n_samples: Union[int, None] = None) -> None:
        """Add samples to compute statistics for.

        Args:
            x (np.ndarray): The input numpy array containing samples data. Expect 2D arrays with shape (n_samples, n_features).
                1D arrays are reshaped to a column vector, unless per-feature
                statistics already exist, in which case they are treated as a
                single sample. Arrays with more than 2 dimensions are reshaped
                to (-1, x.shape[-1]).
            n_samples (int, optional): The number of samples in the input array. If not provided, it will be inferred from the shape of `x`.
                Use this argument when the input array has already been flattened because of shape inconsistencies.
                It is overridden when the number of columns of `x` does not
                match the existing statistics.

        Raises:
            TypeError: Raised when input is not a numpy array.
            ValueError: Raised when input contains NaN or Inf values.
        """
        # Validate input
        if not isinstance(x, np.ndarray):
            raise TypeError("Input must be a numpy array")
        if np.any(~np.isfinite(x)):
            raise ValueError("Input contains NaN or Inf values")

        if x.ndim == 1:
            # A row (one sample) only if per-feature stats already exist,
            # otherwise a column vector by default.
            if self.min is not None and self.min.shape[1] != 1:
                x = x.reshape((1, -1))
            else:
                x = x.reshape((-1, 1))
        elif x.ndim > 2:
            # (n_samples, n_points, n_features) -> (n_samples * n_points, n_features)
            x = x.reshape((-1, x.shape[-1]))

        if self.n_features is None:
            self.n_features = x.shape[1]

        if x.shape[1] != self.n_features:
            # Per-feature statistics are no longer possible with this new
            # shape: switch to statistics over all values.
            self.flatten_array()
            n_samples = x.shape[0]
            x = x.reshape((-1, 1))

        added_n_samples = len(x) if n_samples is None else n_samples
        added_n_points = x.size
        added_min = np.min(x, axis=0, keepdims=True)
        added_max = np.max(x, axis=0, keepdims=True)
        added_mean = np.mean(x, axis=0, keepdims=True)
        added_var = np.var(x, axis=0, keepdims=True)

        if self._is_empty():
            self.n_samples = added_n_samples
            self.n_points = added_n_points
            self.min = added_min
            self.max = added_max
            self.mean = added_mean
            self.var = added_var
        else:
            self.min = np.minimum(self.min, added_min)
            self.max = np.maximum(self.max, added_max)
            batch_means = np.concatenate([self.mean, added_mean])
            batch_vars = np.concatenate([self.var, added_var])
            if self.n_features > 1:
                # Per-feature statistics: each batch weighs its number of samples.
                self.n_points += added_n_points
                total, self.mean, self.var = aggregate_stats(
                    np.array([self.n_samples, added_n_samples]),
                    batch_means,
                    batch_vars,
                )
                self.n_samples = int(total)
            else:
                # Statistics over all values: each batch weighs its number of points.
                self.n_samples += added_n_samples
                total, self.mean, self.var = aggregate_stats(
                    np.array([self.n_points, added_n_points]),
                    batch_means,
                    batch_vars,
                )
                self.n_points = int(total)

        self.std = np.sqrt(self.var)

    def merge_stats(self, other: Self) -> None:
        """Merge statistics from another instance.

        An empty `other` is ignored, and an empty `self` adopts a copy of the
        statistics of `other`. When both hold data with different
        ``n_features``, both are flattened first (`other` is left untouched).

        Args:
            other (Self): The other instance to merge statistics from.

        Raises:
            TypeError: Raised when `other` is not an instance of the same class.
        """
        if not isinstance(other, self.__class__):
            raise TypeError("Can only merge with another instance of the same class")

        if other._is_empty():
            return

        if self._is_empty():
            was_flat = self.n_features == 1
            for attr in (
                "n_samples",
                "n_points",
                "n_features",
                "min",
                "max",
                "mean",
                "var",
                "std",
            ):
                setattr(self, attr, copy.deepcopy(getattr(other, attr)))
            if was_flat:
                # Keep the flat mode requested before any data was added.
                self.flatten_array()
            return

        if self.n_features != other.n_features:
            # flatten both
            self.flatten_array()
            other = copy.deepcopy(other)
            other.flatten_array()

        assert self.min.shape == other.min.shape, (
            "Shape mismatch in OnlineStatistics merging"
        )

        self.min = np.minimum(self.min, other.min)
        self.max = np.maximum(self.max, other.max)
        batch_means = np.concatenate([self.mean, other.mean])
        batch_vars = np.concatenate([self.var, other.var])
        if self.n_features > 1:
            # Per-feature statistics: weights are numbers of samples.
            self.n_points += other.n_points
            total, self.mean, self.var = aggregate_stats(
                np.array([self.n_samples, other.n_samples]),
                batch_means,
                batch_vars,
            )
            self.n_samples = int(total)
        else:
            # Statistics over all values: weights are numbers of points,
            # consistently with add_samples.
            self.n_samples = int(self.n_samples + other.n_samples)
            total, self.mean, self.var = aggregate_stats(
                np.array([self.n_points, other.n_points]),
                batch_means,
                batch_vars,
            )
            self.n_points = int(total)
        self.std = np.sqrt(self.var)

    def flatten_array(self) -> None:
        """When a shape incoherence is detected, you should call this function.

        Collapses the per-feature statistics into statistics over all values
        and sets ``n_features`` to 1. Calling it on already flat statistics
        is a no-op; on an empty object it only records the flat mode.
        """
        if self.n_features == 1:
            # Already flat: re-aggregating would overwrite n_points with n_samples.
            return
        if self._is_empty():
            self.n_features = 1
            return

        self.min = np.min(self.min, keepdims=True).reshape(1, 1)
        self.max = np.max(self.max, keepdims=True).reshape(1, 1)
        assert self.mean.shape == self.var.shape
        # Every feature column is seen as one batch of n_samples values.
        total, self.mean, self.var = aggregate_stats(
            np.array([self.n_samples] * self.n_features),
            self.mean.reshape(-1, 1),
            self.var.reshape(-1, 1),
        )
        self.n_points = int(total)
        self.std = np.sqrt(self.var)
        self.n_features = 1

    def get_stats(self) -> dict[str, Union[int, np.ndarray]]:
        """Get computed statistics.

        Returns:
            dict[str, Union[int, np.ndarray]]: A dictionary containing computed statistics.
            The shapes of the arrays depend on the input data and may vary.
            Values are None (0 for "n_samples") while no sample has been added.
        """
        return {
            "n_samples": self.n_samples,
            "n_points": self.n_points,
            "n_features": self.n_features,
            "min": self.min,
            "max": self.max,
            "mean": self.mean,
            "var": self.var,
            "std": self.std,
        }
class Stats:
    """Class for aggregating and computing statistics across datasets.

    The Stats class processes both scalar and field data from samples or datasets,
    computing running statistics like min, max, mean, variance and standard deviation.

    Attributes:
        _stats (dict[str, OnlineStatistics]): Dictionary mapping data identifiers to their statistics
        _feature_is_flattened (dict[str, bool]): Identifiers whose fields were
            found with inconsistent shapes and are now accumulated over all values
    """

    def __init__(self):
        """Initialize an empty Stats object."""
        self._stats: dict[str, OnlineStatistics] = {}
        self._feature_is_flattened: dict[str, bool] = {}

    def add_dataset(self, dset: Dataset) -> None:
        """Add a dataset to compute statistics for.

        Args:
            dset (Dataset): The dataset to add.
        """
        self.add_samples(dset)

    def add_samples(self, samples: Union[list[Sample], Dataset]) -> None:
        """Add samples or a dataset to compute statistics for.

        Compute stats for each features present in the samples among scalars and fields.
        For fields, as long as the added samples have the same shape as the existing ones,
        the stats will be computed per-coordinates (n_features=x.shape[-1]).
        But as soon as the shapes differ, the stats and added fields will be flattened (n_features=1),
        then stats will be computed over all values of the field.

        Args:
            samples (Union[list[Sample], Dataset]): List of samples or dataset to process

        Raises:
            TypeError: If samples is not a list[Sample] or Dataset
            ValueError: If a sample contains invalid data
        """
        # Input validation
        if not isinstance(samples, (list, Dataset)):
            raise TypeError("samples must be a list[Sample] or Dataset")

        # Gather the arrays of every sample, grouped by identifier
        new_data: dict[str, list] = {}
        for sample in samples:
            self._process_scalar_data(sample, new_data)
            self._process_field_data(sample, new_data)

            # ---# SpatialSupport (Meshes)
            # TODO

            # ---# TemporalSupport
            # TODO

            # ---# Categorical
            # TODO

        # Update statistics
        self._update_statistics(new_data)

    def get_stats(
        self, identifiers: Union[list[str], None] = None
    ) -> dict[str, dict[str, np.ndarray]]:
        """Get computed statistics for specified data identifiers.

        Args:
            identifiers (list[str], optional): List of data identifiers to retrieve.
                If None, returns statistics for all identifiers.
                Unknown identifiers are silently skipped.

        Returns:
            dict[str, dict[str, np.ndarray]]: Dictionary mapping identifiers to their statistics
        """
        if identifiers is None:
            identifiers = self.get_available_statistics()

        return {
            identifier: dict(self._stats[identifier].get_stats())
            for identifier in identifiers
            if identifier in self._stats
        }

    def get_available_statistics(self) -> list[str]:
        """Get list of data identifiers with computed statistics.

        Returns:
            list[str]: Sorted list of data identifiers
        """
        return sorted(self._stats)

    def clear_statistics(self) -> None:
        """Clear all computed statistics."""
        self._stats.clear()
        # Stale flags would force freshly added fields into flattened mode.
        self._feature_is_flattened.clear()

    def merge_stats(self, other: Self) -> None:
        """Merge statistics from another Stats object.

        Identifiers only present in `other` are deep-copied, the others are
        merged. Flattened flags are propagated so that later additions keep
        using the same mode as the merged statistics.

        Args:
            other (Stats): Stats object to merge with

        Raises:
            TypeError: If other is not a Stats object
        """
        if not isinstance(other, Stats):
            raise TypeError("Can only merge with another Stats object")

        for name, other_stats in other._stats.items():
            flattened = self._feature_is_flattened.get(
                name, False
            ) or other._feature_is_flattened.get(name, False)

            if name not in self._stats:
                self._stats[name] = copy.deepcopy(other_stats)
                if self._feature_is_flattened.get(name, False):
                    # This object already expects flat data for that identifier.
                    self._stats[name].flatten_array()
            else:
                own_stats = self._stats[name]
                # Merging statistics of different widths flattens them.
                if own_stats.n_features != other_stats.n_features:
                    flattened = True
                own_stats.merge_stats(other_stats)

            if flattened:
                self._feature_is_flattened[name] = True

    def _process_scalar_data(self, sample: Sample, data_dict: dict[str, list]) -> None:
        """Process scalar data from a sample.

        Each scalar is stored as an array of shape (1, n_values); scalars
        whose value is None are skipped but their name is still registered.

        Args:
            sample (Sample): Sample containing scalar data
            data_dict (dict[str, list]): Dictionary to store processed data
        """
        for name in sample.get_scalar_names():
            collected = data_dict.setdefault(name, [])
            value = sample.get_scalar(name)
            if value is not None:
                collected.append(np.array(value).reshape((1, -1)))

    def _process_field_data(self, sample: Sample, data_dict: dict[str, list]) -> None:
        """Process field data from a sample.

        Fields are keyed by "<base>/<zone>/<location>/<field>" (all time steps
        share the same key). A field is stored as a row of shape (1, n_points)
        until a different shape is met for its key; from then on the key is
        marked as flattened and fields are stored as columns (n_points, 1).

        Args:
            sample (Sample): Sample containing field data
            data_dict (dict[str, list]): Dictionary to store processed data
        """
        for time in sample.features.get_all_time_values():
            for base_name in sample.features.get_base_names(time=time):
                for zone_name in sample.features.get_zone_names(
                    base_name=base_name, time=time
                ):
                    for location in CGNS_FIELD_LOCATIONS:
                        for field_name in sample.get_field_names(
                            location=location,
                            zone_name=zone_name,
                            base_name=base_name,
                            time=time,
                        ):
                            stat_key = (
                                f"{base_name}/{zone_name}/{location}/{field_name}"
                            )
                            collected = data_dict.setdefault(stat_key, [])

                            field = sample.get_field(
                                field_name,
                                location=location,
                                zone_name=zone_name,
                                base_name=base_name,
                                time=time,
                            )
                            # Check before reshaping: a missing field is skipped.
                            if field is None:
                                continue
                            field = field.reshape((1, -1))

                            # Compare with the first array gathered for this
                            # key to detect a shape inconsistency.
                            if (
                                collected
                                and not self._feature_is_flattened.get(stat_key, False)
                                and field.shape != collected[0].shape
                            ):
                                # set this stat as flattened
                                self._feature_is_flattened[stat_key] = True
                                # flatten corresponding stat
                                if stat_key in self._stats:
                                    self._stats[stat_key].flatten_array()

                            if self._feature_is_flattened.get(stat_key, False):
                                field = field.reshape((-1, 1))
                            collected.append(field)

    def _update_statistics(self, new_data: dict[str, list]) -> None:
        """Update running statistics with new data.

        Args:
            new_data (dict[str, list]): Dictionary containing new data to process
        """
        for name, list_of_arrays in new_data.items():
            if not list_of_arrays:
                continue

            # internal check, should never happen if self._process_* functions work correctly
            for sample_id, array in enumerate(list_of_arrays):
                assert isinstance(array, np.ndarray)
                assert array.ndim == 2, (
                    f"for feature <{name}> -> {sample_id=}: {list_of_arrays[sample_id].ndim=} should be 2"
                )

            if self._feature_is_flattened.get(name, False):
                # Every array becomes a column; the number of samples can no
                # longer be read from the concatenated shape.
                n_samples = len(list_of_arrays)
                arrays = [array.reshape((-1, 1)) for array in list_of_arrays]
            else:
                n_samples = None
                arrays = list_of_arrays

            data = np.concatenate(arrays)
            assert data.ndim == 2

            # Register a new accumulator only once the data was accepted, so
            # that a rejected batch does not leave an empty entry behind.
            online_stats = self._stats.get(name)
            if online_stats is None:
                online_stats = OnlineStatistics()
            online_stats.add_samples(data, n_samples=n_samples)
            self._stats[name] = online_stats
# # old version of _update_statistics logic # for name in new_data: # # new_shapes = [value.shape for value in new_data[name] if value.shape!=new_data[name][0].shape] # # has_same_shape = (len(new_shapes)==0) # has_same_shape = True # if has_same_shape: # new_data[name] = np.array(new_data[name]) # else: # pragma: no cover ### remove "no cover" when "has_same_shape = True" is no longer used # if name in self._stats: # self._stats[name].flatten_array() # new_data[name] = np.concatenate( # [np.ravel(value) for value in new_data[name]] # ) # if new_data[name].ndim == 1: # new_data[name] = new_data[name].reshape((-1, 1)) # if name not in self._stats: # self._stats[name] = OnlineStatistics() # self._stats[name].add_samples(new_data[name]) # TODO: WRITE TWO FUNCTIONS: # - compute_stats(samples) -> stats # - aggregate_stats(list[stats]) # TODO: reuse this? better suited to heterogeneous data # def _compute_scalars_stats_(self) -> None: # nb_samples_with_scalars = 0 # scalars_have_timestamps = False # full_scalars = [] # full_scalars_timestamps = [] # for sample in self.samples: # if 'scalars' in sample._data: # nb_samples_with_scalars += 1 # if isinstance(sample._data['scalars'], dict): # scalars_have_timestamps = True # for k in sample._data['scalars']: # full_scalars_timestamps.append(k) # for val in sample._data['scalars'].values(): # full_scalars.append(val) # elif isinstance(sample._data['scalars'], tuple): # scalars_have_timestamps = True # full_scalars_timestamps.append(sample._data['scalars'][0]) # full_scalars.append(sample._data['scalars'][1]) # else: # full_scalars.append(sample._data['scalars']) # if nb_samples_with_scalars>0: # full_scalars = np.array(full_scalars) # logger.debug("full_scalars.shape: {}".format(full_scalars.shape)) # self._stats['scalars'] = { # 'min': np.min(full_scalars, axis=0), # 'max': np.max(full_scalars, axis=0), # 'mean': np.mean(full_scalars, axis=0), # 'std': np.std(full_scalars, axis=0), # 'var': np.var(full_scalars, axis=0), # } 
# if scalars_have_timestamps: # full_scalars_timestamps = np.array(full_scalars_timestamps) # logger.debug("full_scalars_timestamps.shape: {}".format(full_scalars_timestamps.shape)) # self._stats['scalars_timestamps'] = { # 'min': np.min(full_scalars_timestamps), # 'max': np.max(full_scalars_timestamps), # 'mean': np.mean(full_scalars_timestamps), # 'std': np.std(full_scalars_timestamps), # 'var': np.var(full_scalars_timestamps), # } # def _compute_fields_stats_(self) -> None: # nb_samples_with_fields = 0 # fields_have_timestamps = False # full_fields = [] # full_fields_timestamps = [] # for sample in self.samples: # if 'fields' in sample._data: # nb_samples_with_fields += 1 # if isinstance(sample._data['fields'], dict): # fields_have_timestamps = True # for k in sample._data['fields']: # full_fields_timestamps.append(k) # for val in sample._data['fields'].values(): # full_fields.append(val) # elif isinstance(sample._data['fields'], tuple): # fields_have_timestamps = True # full_fields_timestamps.append(sample._data['fields'][0]) # full_fields.append(sample._data['fields'][1]) # else: # full_fields.append(sample._data['fields']) # if nb_samples_with_fields>0: # full_fields = np.concatenate(full_fields, axis=0) # logger.debug("full_fields.shape: {}".format(full_fields.shape)) # self._stats['fields'] = { # 'min': np.min(full_fields, axis=0), # 'max': np.max(full_fields, axis=0), # 'mean': np.mean(full_fields, axis=0), # 'std': np.std(full_fields, axis=0), # 'var': np.var(full_fields, axis=0), # } # if fields_have_timestamps: # full_fields_timestamps = np.array(full_fields_timestamps) # logger.debug("full_fields_timestamps.shape: {}".format(full_fields_timestamps.shape)) # self._stats['fields_timestamps'] = { # 'min': np.min(full_fields_timestamps), # 'max': np.max(full_fields_timestamps), # 'mean': np.mean(full_fields_timestamps), # 'std': np.std(full_fields_timestamps), # 'var': np.var(full_fields_timestamps), # } # def _compute_mesh_stats_(self) -> None: # 
nb_samples_with_mesh = 0 # mesh_have_timestamps = False # full_mesh = [] # full_mesh_timestamps = [] # for sample in self.samples: # if 'mesh' in sample._data: # nb_samples_with_mesh += 1 # if isinstance(sample._data['mesh'], dict): # mesh_have_timestamps = True # for k in sample._data['mesh']: # full_mesh_timestamps.append(k) # for val in sample._data['mesh'].values(): # full_mesh.append(val) # elif isinstance(sample._data['mesh'], tuple): # mesh_have_timestamps = True # full_mesh_timestamps.append(sample._data['mesh'][0]) # full_mesh.append(sample._data['mesh'][1]) # else: # full_mesh.append(sample._data['mesh']) # if nb_samples_with_mesh>0: # full_mesh = np.array(full_mesh) # logger.debug("full_mesh.shape: {}".format(full_mesh.shape)) # self._stats['mesh'] = { # 'min': np.min(full_mesh, axis=0), # 'max': np.max(full_mesh, axis=0), # 'mean': np.mean(full_mesh, axis=0), # 'std': np.std(full_mesh, axis=0), # 'var': np.var(full_mesh, axis=0), # } # if mesh_have_timestamps: # full_mesh_timestamps = np.array(full_mesh_timestamps) # logger.debug("full_mesh_timestamps.shape: {}".format(full_mesh_timestamps.shape)) # self._stats['mesh_timestamps'] = { # 'min': np.min(full_mesh_timestamps), # 'max': np.max(full_mesh_timestamps), # 'mean': np.mean(full_mesh_timestamps), # 'std': np.std(full_mesh_timestamps), # 'var': np.var(full_mesh_timestamps), # }