plaid.utils.split

Utility functions for splitting a Dataset.

Functions

split_dataset(→ dict[str, int])

Splits a Dataset into several sub-datasets.

mmd_subsample_fn(→ numpy.typing.NDArray[numpy.int64])

Selects samples in the input table by greedily minimizing the maximum mean discrepancy (MMD).

Module Contents

split_dataset(dset: plaid.Dataset, options: dict[str, Any]) → dict[str, int]

Splits a Dataset into several sub-datasets.

Parameters:
  • dset (Dataset) – dataset to split.

  • options (dict[str, Any]) – may have the keys ‘shuffle’, ‘split_sizes’, ‘split_ratios’ or ‘split_ids’:
      - ‘split_sizes’ is a dict[str, int]: split name -> number of samples in that split
      - ‘split_ratios’ is a dict[str, float]: split name -> fraction of the dataset's samples in that split
      - ‘split_ids’ is a dict[str, np.ndarray[int]]: split name -> ids of the samples in that split
      - ‘shuffle’ defaults to False when it is not set
      - if ‘split_ids’ is present, all other keys are ignored

Returns:

a dict mapping each split name to the ids of the samples in that split.

Return type:

dict

Raises:
  • ValueError – If a split is named ‘other’ (not authorized).

  • ValueError – If there are some ids out of bounds.

  • ValueError – If a split name appears in both ‘split_ratios’ and ‘split_sizes’.
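The three Raises conditions can be read as a pre-check on options. The sketch below re-expresses them with NumPy; the helper name check_split_options, the order of the checks and the treatment of negative ids as out of bounds are assumptions, not plaid's code. Following the parameter description, the ‘split_sizes’/‘split_ratios’ checks are skipped when ‘split_ids’ is present, since those keys are then ignored.

```python
from typing import Any

import numpy as np


def check_split_options(n_samples: int, options: dict[str, Any]) -> None:
    """Hypothetical pre-check mirroring the Raises section of split_dataset."""
    if "split_ids" in options:
        # Explicit ids: all other keys are ignored, so only the ids are checked.
        for name, ids in options["split_ids"].items():
            if name == "other":
                raise ValueError("split name 'other' is not authorized")
            ids = np.asarray(ids, dtype=int)
            # Assumption: negative ids count as out of bounds (no wrap-around).
            if ids.size and (ids.min() < 0 or ids.max() >= n_samples):
                raise ValueError(f"split {name!r} has ids out of bounds")
        return

    sizes = options.get("split_sizes", {})
    ratios = options.get("split_ratios", {})
    if "other" in sizes or "other" in ratios:
        raise ValueError("split name 'other' is not authorized")
    # A given split name may receive a size or a ratio, not both.
    both = set(sizes) & set(ratios)
    if both:
        raise ValueError(
            f"splits {sorted(both)} are in both 'split_ratios' and 'split_sizes'"
        )
```

For instance, check_split_options(2, {'split_sizes': {'train': 1, 'val': 1}}) returns silently, while naming a split ‘other’ raises ValueError.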

Example

from plaid.utils.split import split_dataset

# Given a dataset of 2 samples
print(dataset)
>>> Dataset(2 samples, 2 scalars, 2 fields)

options = {
    'shuffle':False,
    'split_sizes': {
        'train':1,
        'val':1,
        },
}
split = split_dataset(dataset, options)
print(split)
>>> {'train': [0], 'val': [1]}
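How the options could turn into sample ids is sketched below for a dataset of n_samples samples. This is a hedged reading of the parameter description, not plaid's implementation: the helper name assign_split_ids, truncating ratio × n_samples to an integer, filling splits in dict order (sizes first, then ratios), and collecting leftover samples under the reserved name ‘other’ are all assumptions.

```python
from typing import Any, Optional

import numpy as np


def assign_split_ids(
    n_samples: int, options: dict[str, Any], seed: Optional[int] = None
) -> dict[str, np.ndarray]:
    """Hypothetical sketch: map each split name to an array of sample ids."""
    # Explicit ids take precedence; every other key is ignored.
    if "split_ids" in options:
        return {name: np.asarray(ids, dtype=int) for name, ids in options["split_ids"].items()}

    all_ids = np.arange(n_samples)
    # 'shuffle' defaults to False when absent.
    if options.get("shuffle", False):
        all_ids = np.random.default_rng(seed).permutation(all_ids)

    # Sizes are taken as given; ratios are converted by truncation (assumption).
    sizes = dict(options.get("split_sizes", {}))
    for name, ratio in options.get("split_ratios", {}).items():
        sizes[name] = int(ratio * n_samples)
    if sum(sizes.values()) > n_samples:
        raise ValueError("requested splits exceed the number of samples")

    # Consecutive, non-overlapping slices of the (possibly shuffled) ids.
    splits, start = {}, 0
    for name, size in sizes.items():
        splits[name] = all_ids[start:start + size]
        start += size
    # Assumption: leftover samples go to the reserved 'other' split.
    if start < n_samples:
        splits["other"] = all_ids[start:]
    return splits
```

With the options of the example above, assign_split_ids(2, options) returns {'train': array([0]), 'val': array([1])}, and no ‘other’ split is created because every sample is assigned.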

mmd_subsample_fn(X: numpy.typing.NDArray[numpy.float64], size: int, initial_ids: list[int] | None = None, memory_safe: bool = False) → numpy.typing.NDArray[numpy.int64]

Selects samples in the input table by greedily minimizing the maximum mean discrepancy (MMD).

Parameters:
  • X (np.ndarray) – input table of shape (n_samples, n_features)

  • size (int) – number of samples to select

  • initial_ids (list[int]) – ids of the points used to initialize the greedy algorithm. Defaults to None.

  • memory_safe (bool) – if True, avoids a memory-expensive computation, which is useful for large tables. Defaults to False.

Returns:

array of the indices (row ids in X) of the selected samples

Return type:

np.ndarray
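The greedy criterion can be made concrete. For a kernel k, the squared MMD between a selected subset S of size m and the full table of n rows is (1/m²) Σ_{i,j∈S} k(x_i, x_j) - (2/(m·n)) Σ_{i∈S} Σ_j k(x_i, x_j) + const. When a candidate c is added (so m counts c), the only terms that depend on c are (2 Σ_{i∈S} k(x_c, x_i) + k(x_c, x_c))/m² - (2/m)·μ_c, where μ_c is the mean kernel value of x_c against all rows. The sketch below picks the candidate minimizing that quantity at each step. It is not plaid's implementation: the Gaussian kernel, its bandwidth rule and the name greedy_mmd_subsample are assumptions.

```python
import numpy as np


def greedy_mmd_subsample(X, size, initial_ids=None):
    """Sketch: greedily select `size` rows of X that minimize the squared MMD
    between the selected rows and the full table (Gaussian kernel assumed)."""
    n = X.shape[0]
    if size > n:
        raise ValueError("size cannot exceed the number of samples")

    # Dense n x n kernel matrix: simple, but O(n^2) memory.
    sq = np.sum(X**2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0)
    # Assumption: bandwidth = 2 * tr(Cov(X)), the expected squared distance
    # between two independent draws.
    sigma2 = 2.0 * X.var(axis=0).sum()
    K = np.exp(-d2 / sigma2)

    mu = K.mean(axis=1)  # mu[c] = (1/n) * sum_j k(x_c, x_j)
    diag = np.diag(K)    # k(x_c, x_c), all ones for this kernel

    selected = [int(i) for i in initial_ids] if initial_ids is not None else []
    # s[c] = sum over selected i of k(x_c, x_i)
    s = K[:, np.asarray(selected, dtype=int)].sum(axis=1)

    while len(selected) < size:
        m = len(selected) + 1
        # Part of MMD^2(S + {c}) that depends on the candidate c.
        score = (2.0 * s + diag) / m**2 - 2.0 * mu / m
        score[selected] = np.inf  # select without replacement
        c = int(np.argmin(score))
        selected.append(c)
        s += K[:, c]
    return np.asarray(selected, dtype=np.int64)
```

Called as greedy_mmd_subsample(X, 100) on a 1000 x 10 Gaussian table, it returns 100 distinct row ids; because the kernel is an assumption, they need not match the ids printed in the Example.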

Example

import numpy as np
from plaid.utils.split import mmd_subsample_fn

# Let X be drawn from a standard 10-dimensional Gaussian distribution
np.random.seed(0)
X = np.random.randn(1000, 10)
# Select 100 samples
idx = mmd_subsample_fn(X, size=100)
print(idx)
>>> [765 113 171 727 796 855 715 207 458 603  23 384 860   3 459 708 794 138
     221 639   8 816 619 806 398 236  36 404 167  87 201 676 961 624 556 840
     485 975 283 150 554 409  69 769 332 357 388 216 900 134  15 730  80 694
     251 714  11 817 525 382 328  67 356 514 597 668 959 260 968  26 209 789
     305 122 989 571 801 322  14 160 908  12   1 980 582 440  42 452 666 526
     290 231 712  21 606 575 656 950 879 948]
# In this simple Gaussian example, the means and standard deviations of the
# selected subsample should be close to those of the full sample
print(np.abs(np.mean(X[idx], axis=0) - np.mean(X, axis=0)))
>>> [0.00280955 0.00220179 0.01359079 0.00461107 0.0011997  0.01106616
     0.01157571 0.0061314  0.00813494 0.0026543]
print(np.abs(np.std(X[idx], axis=0) - np.std(X, axis=0)))
>>> [0.0067711  0.00316008 0.00860733 0.07130127 0.02858514 0.0173707
     0.00739646 0.03526784 0.0054039  0.00351996]
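The memory_safe parameter is documented only as avoiding a memory-expensive computation. Assuming that computation is a dense n x n kernel matrix in float64, it takes 8·n² bytes: 8 MB for the 1000-row example above, but 80 GB for 100,000 rows. The sketch below shows one way to avoid it by recomputing kernel rows on demand; the function name, the chunk parameter, the Gaussian kernel and its bandwidth rule are all assumptions, and it is not plaid's memory_safe code path.

```python
import numpy as np


def greedy_mmd_subsample_lowmem(X, size, initial_ids=None, chunk=1024):
    """Sketch of a low-memory greedy MMD subsampler (Gaussian kernel assumed).

    Kernel values are recomputed on demand, so peak memory is O(chunk * n)
    instead of the O(n^2) of a dense kernel matrix.
    """
    n = X.shape[0]
    if size > n:
        raise ValueError("size cannot exceed the number of samples")
    sq = np.sum(X**2, axis=1)
    # Assumption: bandwidth = 2 * tr(Cov(X)), the expected squared distance
    # between two independent draws; computing it needs no pairwise distances.
    sigma2 = 2.0 * X.var(axis=0).sum()

    def kernel_rows(rows):
        # k(x_r, x_j) for every r in rows and every j: shape (len(rows), n)
        d2 = sq[rows][:, None] + sq[None, :] - 2.0 * X[rows] @ X.T
        return np.exp(-np.maximum(d2, 0.0) / sigma2)

    # mu[c] = (1/n) * sum_j k(x_c, x_j), computed one chunk of rows at a time
    mu = np.empty(n)
    for start in range(0, n, chunk):
        rows = np.arange(start, min(start + chunk, n))
        mu[rows] = kernel_rows(rows).mean(axis=1)

    selected = [int(i) for i in initial_ids] if initial_ids is not None else []
    s = np.zeros(n)  # s[c] = sum over selected i of k(x_c, x_i)
    for i in selected:
        s += kernel_rows(np.array([i]))[0]

    while len(selected) < size:
        m = len(selected) + 1
        # Gaussian kernel: k(x_c, x_c) = 1 for every candidate c
        score = (2.0 * s + 1.0) / m**2 - 2.0 * mu / m
        score[selected] = np.inf  # select without replacement
        c = int(np.argmin(score))
        selected.append(c)
        s += kernel_rows(np.array([c]))[0]  # one kernel row per step
    return np.asarray(selected, dtype=np.int64)
```

Each greedy step costs one kernel row, i.e. O(n·d) work, so memory stays bounded by chunk·n kernel values at the price of recomputation.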