Source code for plaid.utils.split

"""Utility functions for splitting a Dataset and subsampling tables of samples."""

# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#

# %% Imports

import logging
from typing import Any, Optional

import numpy as np
from numpy.typing import NDArray
from scipy.spatial.distance import cdist

from plaid import Dataset

# Module-level logger named after this module; the functions below use it to
# emit warnings about ignored or suspicious options.
logger = logging.getLogger(__name__)


# %% Functions


def split_dataset(dset: "Dataset", options: dict[str, Any]) -> dict[str, Any]:
    """Compute the sample ids belonging to each split of a Dataset.

    Only the sample ids and the length of ``dset`` are read; neither ``dset``
    nor ``options`` is modified. Samples that are not assigned to any requested
    split are gathered in an extra split named ``'other'``.

    Args:
        dset(Dataset): dataset to be split.
        options(dict[str,Any]): may have keys 'shuffle', 'split_sizes', 'split_ratios' or 'split_ids':
            - 'split_sizes' is supposed to be a dict[str,int]: split name -> size of the split
            - 'split_ratios' is supposed to be a dict[str,float]: split name -> ratio of the dataset size
              (the resulting size is ``int(len(dset) * ratio)``)
            - 'split_ids' is supposed to be a dict[str,np.ndarray(int)]: split name -> ids of samples in the split
            - if 'shuffle' is not set, it is supposed to be False
            - 'split_ids' cannot be combined with 'split_sizes' or 'split_ratios';
              when 'split_ids' is present, 'shuffle' is ignored
            - any other key is ignored with a warning

    Returns:
        dict[str, Any]: split name -> ids of the samples in that split. With
        'split_ids', the values are the objects given by the caller and 'other'
        is a NumPy array. With sizes/ratios, the values are consecutive slices
        of ``dset.get_sample_ids()`` (or of a NumPy permutation of it when
        'shuffle' is true).

    Raises:
        ValueError: If a split is named 'other' (not authorized).
        ValueError: If some ids of 'split_ids' are negative or >= ``len(dset)``.
        AssertionError: If 'split_ids' is combined with 'split_sizes' or 'split_ratios'.
        AssertionError: If a split option is not a dict, a ratio is not a float or a size is not an int.
        AssertionError: If a split name is in both 'split_ratios' and 'split_sizes'.
        AssertionError: If the requested sizes add up to more than ``len(dset)``.

    Example:
        .. code-block:: python

            # Given a dataset of 2 samples
            print(dataset)
            >>> Dataset(2 samples, 2 scalars, 2 fields)

            options = {
                'shuffle':False,
                'split_sizes': {
                    'train':1,
                    'val':1,
                },
            }

            split = split_dataset(dataset, options)
            print(split)
            >>> {'train': [0], 'val': [1]}
    """
    _splits = {}
    all_ids = dset.get_sample_ids()
    total_size = len(dset)

    def check_options_validity(split_option: dict) -> None:
        """Check that one split description is a dict and does not use the reserved name."""
        assert isinstance(split_option, dict), "split option must be a dictionary"
        if "other" in split_option:
            raise ValueError("name 'other' is not authorized for a split")

    # Unknown keys are reported but otherwise ignored
    authorized_task = ("split_ids", "split_ratios", "split_sizes", "shuffle")
    for task in options:
        if task not in authorized_task:
            logger.warning(
                f"option {task} is not authorized. {task} key will be ignored"
            )

    has_ids = "split_ids" in options
    has_ratios_or_sizes = "split_ratios" in options or "split_sizes" in options
    assert not (has_ids and has_ratios_or_sizes), (
        "split by id cannot exist with split by ratios or sizes"
    )

    # First case: the caller gives the ids of each split explicitly
    if has_ids:
        check_options_validity(options["split_ids"])
        if "shuffle" in options:
            logger.warning(
                "options has key 'split_ids' and 'shuffle' -> 'shuffle' key will be ignored"
            )

        # Work on a shallow copy: adding 'other' to the caller's dict would make
        # a second call with the same options fail on the reserved name.
        _splits.update(options["split_ids"])
        id_arrays = list(_splits.values())
        n_given_ids = sum(len(ids) for ids in id_arrays)

        # np.concatenate rejects an empty sequence, so handle "no split" apart
        if id_arrays:
            used_ids = np.unique(np.concatenate(id_arrays))
        else:
            used_ids = np.array([], dtype=int)

        if used_ids.size > 0 and (
            np.min(used_ids) < 0 or np.max(used_ids) >= total_size
        ):
            raise ValueError(
                "there are some ids out of bounds -> min/max:{}/{} | dataset len:{}".format(
                    np.min(used_ids), np.max(used_ids), total_size
                )
            )

        # Fewer unique ids than given ids means some id appears more than once.
        # Counted on the caller's splits only, i.e. before 'other' is added.
        if len(used_ids) < n_given_ids:
            logger.warning("there are some ids present in several splits")

        other_ids = np.setdiff1d(all_ids, used_ids)
        if len(other_ids) > 0:
            _splits["other"] = other_ids

        return _splits

    # Second case: splits are described by ratios and/or sizes
    shuffle = options.get("shuffle", False)

    # Leading 0 so that the cumulative sum directly yields the slice bounds
    split_sizes = [0]
    split_names = []

    if "split_ratios" in options:
        check_options_validity(options["split_ratios"])
        for key, value in options["split_ratios"].items():
            assert isinstance(value, float), "split ratios must be floats"
            split_names.append(key)
            split_sizes.append(int(total_size * value))

    if "split_sizes" in options:
        check_options_validity(options["split_sizes"])
        for key, value in options["split_sizes"].items():
            assert (
                "split_ratios" not in options or key not in options["split_ratios"]
            ), "a split name cannot be in both 'split_ratios' and 'split_sizes'"
            assert isinstance(value, int), "split sizes must be ints"
            split_names.append(key)
            split_sizes.append(value)

    requested_size = np.sum(split_sizes)
    assert requested_size <= total_size, "splits are larger than the dataset"
    if requested_size < total_size:
        split_names.append("other")
        split_sizes.append(total_size - requested_size)

    slices = np.cumsum(split_sizes)
    if shuffle:
        all_ids = np.random.permutation(all_ids)

    for name, start, stop in zip(split_names, slices[:-1], slices[1:]):
        _splits[name] = all_ids[start:stop]

    return _splits


def mmd_subsample_fn(
    X: NDArray[np.float64],
    size: int,
    initial_ids: Optional[list[int]] = None,
    memory_safe: bool = False,
) -> NDArray[np.int64]:
    """Selects samples in the input table by greedily minimizing the maximum mean discrepancy (MMD).

    The kernel used is ``k(x, y) = ||x|| + ||y|| - ||x - y||`` (Euclidean norms).
    Samples are picked one at a time; each pick is the row of ``X`` minimizing
    the greedy criterion computed from the samples selected so far.

    Args:
        X(np.ndarray): input table of shape n_samples x n_features
        size(int): number of samples to select
        initial_ids(list[int]): a list of ids of points to initialize the greedy algorithm.
            They are kept, in order, as the first selected samples. Defaults to None.
        memory_safe(bool): if True, avoids building the full n_samples x n_samples
            distance matrix by processing one row at a time. Useful for large tables.
            Defaults to False.

    Returns:
        np.ndarray: array of ids (row indices of ``X``) of the selected samples,
        of length ``size``. Empty if ``size`` is 0.

    Raises:
        AssertionError: If ``size`` is larger than the number of rows of ``X``.
        AssertionError: If ``initial_ids`` is not empty and does not contain
            strictly fewer than ``size`` ids.

    Example:
        .. code-block:: python

            # Let X be drawn from a standard 10-dimensional Gaussian distribution
            np.random.seed(0)
            X = np.random.randn(1000,10)

            # Select 100 particles
            idx = mmd_subsample_fn(X, size=100)
            print(idx)
            >>> [765 113 171 727 796 855 715 207 458 603  23 384 860   3 459 708 794 138
                 221 639   8 816 619 806 398 236  36 404 167  87 201 676 961 624 556 840
                 485 975 283 150 554 409  69 769 332 357 388 216 900 134  15 730  80 694
                 251 714  11 817 525 382 328  67 356 514 597 668 959 260 968  26 209 789
                 305 122 989 571 801 322  14 160 908  12   1 980 582 440  42 452 666 526
                 290 231 712  21 606 575 656 950 879 948]

            # In this simple Gaussian example, the means and standard deviations of the
            # selected subsample should be close to the ones of the original sample
            print(np.abs(np.mean(X[idx], axis=0) - np.mean(X, axis=0)))
            >>> [0.00280955 0.00220179 0.01359079 0.00461107 0.0011997  0.01106616
                 0.01157571 0.0061314  0.00813494 0.0026543]
            print(np.abs(np.std(X[idx], axis=0) - np.std(X, axis=0)))
            >>> [0.0067711  0.00316008 0.00860733 0.07130127 0.02858514 0.0173707
                 0.00739646 0.03526784 0.0054039  0.00351996]
    """
    n = X.shape[0]
    assert size <= n

    # len() instead of truthiness so that array-like initial_ids are accepted
    n_init = 0 if initial_ids is None else len(initial_ids)
    if n_init > 0:
        assert n_init < size

    idx = np.zeros(size, dtype=np.int64)
    if size == 0:
        # Nothing to select; the code below would index column 0 of an
        # (n, 0) array and fail.
        return idx

    # Precompute norms and the mean of each row of the kernel matrix
    norms = np.linalg.norm(X, axis=1)
    if memory_safe:
        # One kernel row at a time: never holds the full n x n matrix
        k0_mean = np.zeros(n)
        for i in range(n):
            kxy = norms[i : i + 1, None] + norms[None, :] - cdist(X[i : i + 1], X)
            k0_mean[i] = np.mean(kxy)
    else:
        dist_matrix = cdist(X, X)
        gram_matrix = norms[:, None] + norms[None, :] - dist_matrix
        k0_mean = np.mean(gram_matrix, axis=1)

    # Column 0 holds k(x, x) = 2 ||x||; column i >= 1 holds the kernel between
    # every row of X and the (i-1)-th selected sample.
    k0 = np.zeros((n, size))
    k0[:, 0] = 2.0 * norms

    if n_init > 0:
        idx[:n_init] = initial_ids
    else:
        # No initialization given: the first sample is chosen by the criterion
        idx[0] = np.argmin(k0[:, 0] - 2.0 * k0_mean)
        n_init = 1

    for i in range(1, size):
        previous = idx[i - 1]
        dist = np.linalg.norm(X - X[previous], axis=1)
        k0[:, i] = -dist + norms[previous] + norms
        # Positions below n_init are already fixed (initial ids or first pick)
        if i >= n_init:
            idx[i] = np.argmin(
                k0[:, 0]
                + 2.0 * np.sum(k0[:, 1 : (i + 1)], axis=1)
                - 2.0 * (i + 1) * k0_mean
            )
    return idx