Dataset Splitting Examples

This Jupyter Notebook demonstrates how to use the split module of the PLAID library. It includes examples of:

  1. Initializing a Dataset

  2. Splitting a Dataset with ratios

  3. Splitting a Dataset with fixed sizes

  4. Splitting a Dataset with ratios and fixed sizes

  5. Splitting a Dataset with custom split IDs

The split_dataset function divides a dataset into named subsets such as training, validation, and test sets. In every example below it returns a dictionary that maps each subset name to an array of sample IDs, and the samples not assigned to any named subset are collected under the key "other".

Each section is documented and explained.

# Import required libraries
import numpy as np

# Import the PLAID helpers used in this notebook
from plaid.utils.init_with_tabular import initialize_dataset_with_tabular_data
from plaid.utils.split import split_dataset


# Helper that prints a dictionary with one key-value pair per line
def dprint(name: str, dictio: dict):
    print(name, "{")
    for key, value in dictio.items():
        print("    ", key, ":", value)

    print("}")

Section 1: Initialize Dataset

In this section, we create a dataset with random tabular data for testing purposes. The dataset will be used for subsequent splitting.

# Create a dataset with random tabular data for testing purposes
nb_scalars = 7
nb_samples = 70
tabular_data = {f"scalar_{j}": np.random.randn(nb_samples) for j in range(nb_scalars)}
dataset = initialize_dataset_with_tabular_data(tabular_data)

print(f"{dataset = }")
dataset = Dataset(70 samples, 7 scalars, 0 fields)

Section 2: Splitting a Dataset with Ratios

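In this section, we split the dataset using specified ratios: 80% of the samples go to the training set and 10% to the validation set. No ratio is given for a test set, so the remaining 10% of the samples are returned under "other". Setting "shuffle" to True shuffles the sample IDs during the split.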

print("# First split")
options = {
    "shuffle": True,
    "split_ratios": {
        "train": 0.8,
        "val": 0.1,
    },
}

split = split_dataset(dataset, options)
dprint("split =", split)
# First split
split = {
     train : [ 0 17 14 57 43 16 49 27 53 51 25 62 60 48 21 41 20 59 22 63 52 26 19 47
 50 13 10 23  1 56 64 31  6 30 32  7 12 42 66 67 15 61 65 58  8  5 34 28
 45 46 55  2 24 35 29 44]
     val : [69 37 39 40 38 11  9]
     other : [36 33 68  4  3 54 18]
}
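
As a quick sanity check on the output above, the realized share of each subset can be recomputed from the subset lengths. The sketch below is plain Python and does not call PLAID; the helper name realized_ratios is hypothetical, and the stand-in split only reproduces the subset sizes printed above (56, 7, and 7 IDs), not the actual shuffled IDs.

```python
def realized_ratios(split: dict) -> dict:
    """Return the fraction of all sample IDs that falls into each subset."""
    total = sum(len(ids) for ids in split.values())
    return {name: len(ids) / total for name, ids in split.items()}


# Stand-in with the same subset sizes as the first split (IDs are not shuffled here)
example_split = {
    "train": list(range(0, 56)),
    "val": list(range(56, 63)),
    "other": list(range(63, 70)),
}

ratios = realized_ratios(example_split)
for name, ratio in ratios.items():
    print(f"{name}: {ratio:.2f}")
```

With 70 samples, 56 training IDs correspond to the requested 0.8 and 7 validation IDs to 0.1, which leaves 7 IDs for "other".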

Section 3: Splitting a Dataset with Fixed Sizes

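In this section, we split the dataset into training, validation, and test sets with a fixed number of samples for each set. The requested sizes (14, 8, and 5) add up to 27, so the remaining 43 of the 70 samples are returned under "other". As before, "shuffle" controls whether the sample IDs are shuffled during the split.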

print("# Second split")
options = {
    "shuffle": True,
    "split_sizes": {
        "train": 14,
        "val": 8,
        "test": 5,
    },
}

split = split_dataset(dataset, options)
dprint("split =", split)
# Second split
split = {
     train : [ 5 34 23 33 16 60 18 68  2 31 19 47 10 17]
     val : [ 0  8 32 54 55 46 49 37]
     test : [44  4  6 30 65]
     other : [21 61 11 25 58  3 12 67 20 45 57 22 15 66  1 53 69 56 27 26 42 59 40 62
 41  9 39 51  7 38 63 14 35 52 24 29 48 36 28 50 64 43 13]
}
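
One plausible way to obtain fixed-size subsets is to shuffle the sample IDs once and then cut consecutive slices of the requested lengths. The sketch below illustrates that idea with the standard library only. It is not PLAID's implementation: the function split_by_sizes and its seed parameter are hypothetical names introduced for this illustration.

```python
import random


def split_by_sizes(nb_samples: int, sizes: dict, shuffle: bool = True, seed=None) -> dict:
    """Cut consecutive slices of the (optionally shuffled) IDs, one slice per subset."""
    requested = sum(sizes.values())
    if requested > nb_samples:
        raise ValueError(f"requested {requested} samples, only {nb_samples} available")

    ids = list(range(nb_samples))
    if shuffle:
        # A dedicated Random instance keeps the global random state untouched
        random.Random(seed).shuffle(ids)

    split, start = {}, 0
    for name, size in sizes.items():
        split[name] = ids[start:start + size]
        start += size

    # Whatever was not requested is kept under "other", mirroring the output above
    if start < nb_samples:
        split["other"] = ids[start:]
    return split


sketch_split = split_by_sizes(70, {"train": 14, "val": 8, "test": 5}, seed=0)
print({name: len(ids) for name, ids in sketch_split.items()})
# prints {'train': 14, 'val': 8, 'test': 5, 'other': 43}
```

Because every ID appears in exactly one slice, the subsets produced this way are disjoint by construction.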

Section 4: Splitting a Dataset with Ratios and Fixed Sizes

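In this section, we combine both approaches: the training and test sets are defined by ratios (0.7 and 0.1), while the validation set is defined by a fixed size of 7 samples. With 70 samples this gives 49 training, 7 test, and 7 validation samples, and the remaining 7 are returned under "other". As before, "shuffle" controls whether the sample IDs are shuffled during the split.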

print("# Third split")
options = {
    "shuffle": True,
    "split_ratios": {
        "train": 0.7,
        "test": 0.1,
    },
    "split_sizes": {"val": 7},
}

split = split_dataset(dataset, options)
dprint("split =", split)
# Third split
split = {
     train : [58  7 46 31 57 23 41 40 33 56 52 11 43 54 42 19 22  5 48  0 10 34  3  9
  1 28 55 21 30 47 50 27  2 63 13 35 18 25 38 24 67  4  6 14 26 65 17  8
 39]
     test : [32 29 36 51 45 59 61]
     val : [12 49 53 15 60 68 62]
     other : [44 16 66 69 20 64 37]
}
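
The subset sizes in the output above follow from simple arithmetic, which the sketch below reproduces. It assumes that a ratio is converted to a count by rounding ratio * nb_samples; the source does not state how PLAID rounds, and for the values used here the products are whole numbers up to floating-point error, so the choice does not matter. The function resolve_counts is a hypothetical helper, not part of PLAID.

```python
def resolve_counts(nb_samples: int, split_ratios=None, split_sizes=None) -> dict:
    """Merge ratio-based and size-based requests into one count per subset."""
    split_ratios = split_ratios or {}
    split_sizes = split_sizes or {}

    duplicated = set(split_ratios) & set(split_sizes)
    if duplicated:
        raise ValueError(f"subsets defined twice: {sorted(duplicated)}")

    # Assumption of this sketch: ratios are rounded to the nearest integer count
    counts = {name: round(ratio * nb_samples) for name, ratio in split_ratios.items()}
    counts.update(split_sizes)

    assigned = sum(counts.values())
    if assigned > nb_samples:
        raise ValueError(f"requested {assigned} samples, only {nb_samples} available")
    if assigned < nb_samples:
        counts["other"] = nb_samples - assigned
    return counts


resolved = resolve_counts(70, {"train": 0.7, "test": 0.1}, {"val": 7})
print(resolved)
# prints {'train': 49, 'test': 7, 'val': 7, 'other': 7}
```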

Section 5: Splitting a Dataset with Custom Split IDs

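In this section, we split the dataset based on custom sample IDs for each set. Here we specify the sample IDs for the training, validation, and prediction sets. The validation IDs (30 to 59) and the prediction IDs (25 to 34) overlap on IDs 30 to 34, which is why the output below contains a warning. IDs that are not listed in any set (20 to 24 and 60 to 69) are returned under "other".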

print("# Fourth split")
options = {
    "split_ids": {
        "train": np.arange(20),
        "val": np.arange(30, 60),
        "predict": np.arange(25, 35),
    },
}

split = split_dataset(dataset, options)
dprint("split =", split)
[2026-03-25 17:16:16,649:WARNING:split.py:split_dataset(114)]:there are some ids present in several splits
# Fourth split
split = {
     train : [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
     val : [30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
 54 55 56 57 58 59]
     predict : [25 26 27 28 29 30 31 32 33 34]
     other : [20 21 22 23 24 60 61 62 63 64 65 66 67 68 69]
}
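
The warning in the output above reports that some IDs appear in several splits but does not list them. A small helper can locate the overlapping IDs before calling split_dataset. The sketch below uses only the standard library; find_overlaps is a hypothetical name, and plain ranges stand in for the np.arange arrays used above.

```python
from itertools import combinations


def find_overlaps(split_ids: dict) -> dict:
    """Return the IDs shared by each pair of subsets, keyed by the pair of names."""
    overlaps = {}
    for (name_a, ids_a), (name_b, ids_b) in combinations(split_ids.items(), 2):
        shared = sorted(set(ids_a) & set(ids_b))
        if shared:
            overlaps[(name_a, name_b)] = shared
    return overlaps


custom_ids = {
    "train": range(20),
    "val": range(30, 60),
    "predict": range(25, 35),
}

overlaps = find_overlaps(custom_ids)
print(overlaps)
# prints {('val', 'predict'): [30, 31, 32, 33, 34]}
```

An empty result means the custom subsets are disjoint.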