plaid.utils.split
=================

.. py:module:: plaid.utils.split

.. autoapi-nested-parse::

   Utility functions for splitting a Dataset.


Functions
---------

.. autoapisummary::

   plaid.utils.split.split_dataset
   plaid.utils.split.mmd_subsample_fn


Module Contents
---------------

.. py:function:: split_dataset(dset: plaid.Dataset, options: dict[str, Any]) -> dict[str, int]

   Splits a Dataset into several sub-Datasets.

   :param dset: dataset to be split.
   :type dset: Dataset
   :param options: may have keys 'shuffle', 'split_sizes', 'split_ratios' or 'split_ids':

                   - 'split_sizes' is expected to be a dict[str,int]: split name -> size of the split
                   - 'split_ratios' is expected to be a dict[str,float]: split name -> size of the split as a ratio of the dataset
                   - 'split_ids' is expected to be a dict[str,np.ndarray(int)]: split name -> ids of the samples in the split
                   - if 'shuffle' is not set, it defaults to False
                   - if 'split_ids' is present, the other keys are ignored
   :type options: dict[str,Any]

   :returns: the dataset with splits.
   :rtype: Dataset

   :raises ValueError: If a split is named 'other' (not authorized).
   :raises ValueError: If some ids are out of bounds.
   :raises ValueError: If some split names appear in both 'split_ratios' and 'split_sizes'.

   .. rubric:: Example

   .. code-block:: python

      from plaid.utils.split import split_dataset

      # Given a dataset of 2 samples
      print(dataset)
      >>> Dataset(2 samples, 2 scalars, 2 fields)
      options = {
          'shuffle': False,
          'split_sizes': {
              'train': 1,
              'val': 1,
          },
      }
      split = split_dataset(dataset, options)
      print(split)
      >>> {'train': [0], 'val': [1]}


.. py:function:: mmd_subsample_fn(X: numpy.typing.NDArray[numpy.float64], size: int, initial_ids: Optional[list[int]] = None, memory_safe: bool = False) -> numpy.typing.NDArray[numpy.int64]

   Selects samples in the input table by greedily minimizing the maximum mean discrepancy (MMD).

   :param X: input table of shape n_samples x n_features
   :type X: np.ndarray
   :param size: number of samples to select
   :type size: int
   :param initial_ids: a list of ids of points used to initialize the greedy algorithm. Defaults to None.
   :type initial_ids: list[int]
   :param memory_safe: if True, avoids a memory-expensive computation. Useful for large tables. Defaults to False.
   :type memory_safe: bool

   :returns: array of the ids of the selected samples
   :rtype: np.ndarray

   .. rubric:: Example

   .. code-block:: python

      import numpy as np
      from plaid.utils.split import mmd_subsample_fn

      # Let X be drawn from a standard 10-dimensional Gaussian distribution
      np.random.seed(0)
      X = np.random.randn(1000, 10)
      # Select 100 particles
      idx = mmd_subsample_fn(X, size=100)
      print(idx)
      >>> [765 113 171 727 796 855 715 207 458 603
            23 384 860   3 459 708 794 138 221 639
             8 816 619 806 398 236  36 404 167  87
           201 676 961 624 556 840 485 975 283 150
           554 409  69 769 332 357 388 216 900 134
            15 730  80 694 251 714  11 817 525 382
           328  67 356 514 597 668 959 260 968  26
           209 789 305 122 989 571 801 322  14 160
           908  12   1 980 582 440  42 452 666 526
           290 231 712  21 606 575 656 950 879 948]
      # In this simple Gaussian example, the means and standard deviations of the
      # selected subsample should be close to the ones of the original sample
      print(np.abs(np.mean(X[idx], axis=0) - np.mean(X, axis=0)))
      >>> [0.00280955 0.00220179 0.01359079 0.00461107 0.0011997 0.01106616
           0.01157571 0.0061314 0.00813494 0.0026543]
      print(np.abs(np.std(X[idx], axis=0) - np.std(X, axis=0)))
      >>> [0.0067711 0.00316008 0.00860733 0.07130127 0.02858514 0.0173707
           0.00739646 0.03526784 0.0054039 0.00351996]
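
The greedy MMD selection described for ``mmd_subsample_fn`` can be sketched as follows. This is a minimal illustration, not plaid's implementation: the Gaussian kernel, its bandwidth ``gamma``, and the helper names ``rbf_kernel``, ``mmd2`` and ``greedy_mmd_subsample`` are assumptions made for the example, and the sketch ignores ``initial_ids`` and ``memory_safe``. At each step it adds the candidate row that yields the smallest squared MMD between the full table and the enlarged subsample.

```python
import numpy as np


def rbf_kernel(A, B, gamma):
    """Gaussian kernel matrix k(a, b) = exp(-gamma * ||a - b||^2) (assumed kernel)."""
    sq = (A**2).sum(axis=1)[:, None] + (B**2).sum(axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-gamma * np.maximum(sq, 0.0))


def mmd2(K, ids):
    """Squared MMD between the full sample and the subsample `ids`, given kernel matrix K."""
    ids = np.asarray(ids)
    return K.mean() - 2.0 * K[:, ids].mean() + K[np.ix_(ids, ids)].mean()


def greedy_mmd_subsample(X, size, gamma=None):
    """Hypothetical helper: greedily pick `size` distinct rows of X minimizing MMD^2."""
    n, d = X.shape
    gamma = 1.0 / d if gamma is None else gamma  # assumed default bandwidth
    K = rbf_kernel(X, X, gamma)        # full n x n kernel matrix (memory heavy)
    mu = K.mean(axis=0)                # mu[c] = mean_i K[i, c]
    diag = np.diag(K).copy()           # K[c, c]
    cross = np.zeros(n)                # cross[c] = sum over selected j of K[j, c]
    available = np.ones(n, dtype=bool)
    selected = []
    for m in range(1, size + 1):
        # Part of MMD^2(S + {c}) that depends on the candidate c;
        # terms that are identical for every candidate are dropped.
        score = -2.0 * mu / m + (2.0 * cross + diag) / m**2
        score[~available] = np.inf     # never pick the same row twice
        best = int(np.argmin(score))
        selected.append(best)
        available[best] = False
        cross += K[best]
    return np.array(selected, dtype=np.int64)


# Usage: compare the greedy subsample with an arbitrary one of the same size.
rng = np.random.default_rng(0)
X = rng.standard_normal((200, 3))
idx = greedy_mmd_subsample(X, size=20)
K = rbf_kernel(X, X, 1.0 / X.shape[1])
print(mmd2(K, idx), mmd2(K, np.arange(20)))
```

Because the sketch builds the full n x n kernel matrix, its memory cost grows quadratically with the number of rows, which is presumably the kind of cost that ``memory_safe=True`` is meant to avoid.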