# Source code for plaid.utils.init_with_tabular
"""Utility functions to initialize a Dataset with tabular data."""
# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#
# %% Imports
import logging
import numpy as np
from plaid import Dataset, Sample
# from plaid.quantity import QuantityValueType
logger = logging.getLogger(__name__)
# %% Functions
# [docs]
def initialize_dataset_with_tabular_data(
    tabular_data: dict[str, np.ndarray],
) -> Dataset:
    """Initialize a Dataset with tabular data.

    This function takes a dictionary of tabular data where keys represent scalar names,
    and values are numpy arrays of the same length. It creates a Dataset and adds one
    Sample per row index; each sample receives, for every scalar name, the value found
    at that index in the corresponding array.

    Args:
        tabular_data (dict[str,np.ndarray]): A dictionary of scalar names and corresponding numpy arrays.

    Returns:
        Dataset: A Dataset initialized with the tabular data.

    Raises:
        AssertionError: If the lengths of the numpy arrays in tabular data are not identical,
            or if ``tabular_data`` is empty (there is then no common length to use).

    Example:
        .. code-block:: python

            import numpy as np
            from plaid.utils.init_with_tabular import initialize_dataset_with_tabular_data
            tabular_data = {'feature1': np.array([1, 2, 3]), 'feature2': np.array([4, 5, 6])}
            dataset = initialize_dataset_with_tabular_data(tabular_data)
    """
    # Distinct column lengths: exactly one value means every column has the same size.
    lengths = {len(column) for column in tabular_data.values()}
    if len(lengths) != 1:
        # Raised explicitly (rather than through an `assert` statement) so the
        # validation is not stripped when Python runs with -O; the exception
        # type and message are kept for callers that already catch them.
        raise AssertionError("sizes not identical in tabular data")
    (nb_samples,) = lengths

    dataset = Dataset()
    for i in range(nb_samples):
        sample = Sample()
        for scalar_name, column in tabular_data.items():
            sample.add_scalar(scalar_name, column[i])
        dataset.add_sample(sample)
    # TODO: for now we loop over the samples one by one; there is probably a
    # better way to do this, but the API is simple.
    return dataset
# def initialize_quantity_dataset_with_tabular_data(tabular_data:dict[str,Union[list[QuantityValueType],np.ndarray]]) -> Dataset:
# """_summary_
# Args:
# tabular_data (dict[str,Union[list[QuantityValueType],np.ndarray]]):
# `feature_name` -> tabular values
# Returns:
# Dataset
# """
# lengths = [len(value) for value in tabular_data.values()]
# assert len(list(set(lengths))) == 1, "sizes not identical in tabular data"
# #---# Adds data to collection
# data_collection = DataCollection()
# for name in tabular_data:
# storage = data_collection.add_storage('quantity', name)
# storage.add_values(tabular_data[name])
# #---# Link samples to data in collection
# dataset = Dataset()
# nb_samples = lengths[0]
# for i_samp in range(nb_samples):
# sample = Sample(data_collection = data_collection)
# for feature_name in tabular_data:
# sample.link_to_value("quantity", feature_name, i_samp)
# dataset.add_sample(sample)
# return dataset
# %% Classes