Dataset Examples

This Jupyter Notebook demonstrates various use cases for the Dataset class, including:

  1. Initializing an Empty Dataset and Adding Samples

  2. Retrieving and Manipulating Samples from a Dataset

  3. Performing Operations on the Dataset

  4. Saving and Loading Datasets from directories or files

This notebook provides detailed examples of using the Dataset class to manage data, Samples, and information within a PLAID Dataset. It is intended for documentation purposes and familiarization with the PLAID library.

Each section is documented and explained.

# Import required libraries
import copy
import json
import platform
import tempfile
from pathlib import Path

import numpy as np

# Import necessary libraries and functions
import Muscat.MeshContainers.ElementsDescription as ElementsDescription
from Muscat.Bridges.CGNSBridge import MeshToCGNS
from Muscat.MeshTools import MeshCreationTools as MCT

import plaid
from plaid import Dataset
from plaid import Sample
Kokkos::OpenMP::initialize WARNING: OMP_PROC_BIND environment variable not set
  In general, for best performance with OpenMP 4.0 or better set OMP_PROC_BIND=spread and OMP_PLACES=threads
  For best performance with OpenMP 3.1 set OMP_PROC_BIND=true
  For unit testing set OMP_PROC_BIND=false
# Print dict util
def dprint(name: str, dictio: dict, end: str = "\n"):
    """Pretty-print a mapping, one ``key : value`` pair per indented line.

    Args:
        name: Label written before the opening brace.
        dictio: Mapping whose items are printed in iteration order.
        end: Terminator written after the closing brace (newline by default).
    """
    indent = " " * 5  # matches print("    ", ...) which adds one separator space
    print(f"{name!s} {{")
    for label, item in dictio.items():
        print(f"{indent}{label!s} : {item!s}")
    print("}", end=end)

Section 1: Initializing an Empty Dataset and Constructing Samples

This section demonstrates how to initialize an empty Dataset and handle Samples.

Initialize an empty Dataset

# Create a Dataset with no Samples; its repr reports the number of samples,
# scalars and fields it holds.
print("#---# Empty Dataset")
dataset = Dataset()
print(f"{dataset=}")
#---# Empty Dataset
dataset=Dataset(0 samples, 0 scalars, 0 fields)

Create Sample

# Create the mesh stored in the Samples below: 5 points, 3 triangles and
# 2 bar (line) elements, converted to a CGNS tree.
points = np.array(
    [
        [0.0, 0.0],
        [1.0, 0.0],
        [1.0, 1.0],
        [0.0, 1.0],
        [0.5, 1.5],
    ]
)

# Triangle connectivity: each row holds 3 indices into `points`
triangles = np.array(
    [
        [0, 1, 2],
        [0, 2, 3],
        [2, 4, 3],
    ]
)

# Bar connectivity: each row holds 2 indices into `points`
bars = np.array([[0, 1], [0, 2]])

Mesh = MCT.CreateMeshOfTriangles(points, triangles)
# Append the bars to the mesh's Bar_2 element container; the second argument
# is presumably the elements' original ids — TODO confirm against Muscat docs
elbars = Mesh.GetElementsOfType(ElementsDescription.Bar_2)
elbars.AddNewElements(bars, [1, 2])
# CGNS representation of the mesh, later added to Samples via add_tree
cgns_mesh = MeshToCGNS(Mesh)

# Initialize an empty Sample; its repr reports counts of globals,
# timestamps and fields (all 0 here)
print("#---# Empty Sample")
sample_01 = Sample()
print(f"{sample_01 = }")
#---# Empty Sample
sample_01 = Sample(0 globals, 0 timestamps, 0 fields)
# Add a CGNS tree structure to the Sample. A deep copy is used so this Sample
# gets its own tree: `cgns_mesh` is reused for another Sample later.
sample_01.features.add_tree(copy.deepcopy(cgns_mesh))
print(f"{sample_01 = }")
sample_01 = Sample(0 globals, 1 timestamp, 1 field)
# Add a scalar named "rotation" (random normal value) to the Sample
sample_01.add_scalar("rotation", np.random.randn())
print(f"{sample_01 = }")
sample_01 = Sample(1 global, 1 timestamp, 1 field)

Display Sample CGNS tree

# Initialize a second Sample holding only a scalar (no CGNS tree);
# it is added to the Dataset in Section 2.
sample_02 = Sample()
sample_02.add_scalar("rotation", np.random.randn())

# Initialize a third empty Sample
print("#---# Empty Sample")
sample_03 = Sample()
sample_03.add_scalar("speed", np.random.randn())
# Reuse sample_01's rotation so both Samples share the same value
sample_03.add_scalar("rotation", sample_01.get_scalar("rotation"))
# Independent copy of the mesh tree for this Sample
sample_03.features.add_tree(copy.deepcopy(cgns_mesh))

# Show Sample CGNS content
sample_03.show_tree()
#---# Empty Sample
 CGNSLibraryVersion : (1,) [4.] float32 CGNSLibraryVersion_t
 Global : (2,) [1 1] int32 CGNSBase_t
|_  Time : (1,) [1] int32 BaseIterativeData_t
   |_  IterationValues : (1,) [1] int32 DataArray_t
   |_  TimeValues : (1,) [0.] float64 DataArray_t
|_  speed : (1,) [0.0943413] float64 DataArray_t
|_  rotation : (1,) [-1.57178081] float64 DataArray_t
 Base_2_2 : (2,) [2 2] int32 CGNSBase_t
|_  2D : None Family_t
|_  Zone : (1, 3) [[5 3 0]] int64 Zone_t
   |_  ZoneType : (12,) Unstructured |S1 ZoneType_t
   |_  GridCoordinates : None GridCoordinates_t
      |_  CoordinateX : (5,) [0.  1.  1.  0.  0.5] float64 DataArray_t
      |_  CoordinateY : (5,) [0.  0.  1.  1.  1.5] float64 DataArray_t
   |_  Elements_BAR_2 : (2,) [3 0] int32 Elements_t
      |_  ElementRange : (2,) [1 2] int64 IndexRange_t
      |_  ElementConnectivity : (4,) [1 2 1 3] int64 DataArray_t
   |_  Elements_TRI_3 : (2,) [5 0] int32 Elements_t
      |_  ElementRange : (2,) [3 5] int64 IndexRange_t
      |_  ElementConnectivity : (9,) [1 ... 4] int64 DataArray_t
   |_  VertexFields : None FlowSolution_t
      |_  GridLocation : (6,) Vertex |S1 GridLocation_t
      |_  OriginalIds : (5,) [1 2 3 4 5] int64 DataArray_t
   |_  CellCenterFields : None FlowSolution_t
      |_  GridLocation : (10,) CellCenter |S1 GridLocation_t
      |_  OriginalIds : (3,) [1 2 3] int64 DataArray_t
   |_  FaceCenterFields : None FlowSolution_t
      |_  GridLocation : (10,) FaceCenter |S1 GridLocation_t
      |_  OriginalIds : (2,) [2 3] int64 DataArray_t
   |_  FamilyName : (2,) 2D |S1 FamilyName_t
|_  Time : (1,) [1] int32 BaseIterativeData_t
   |_  IterationValues : (1,) [1] int32 DataArray_t
   |_  TimeValues : (1,) [0.] float64 DataArray_t
# Add a field to the third Sample: one value per mesh point (sized from
# `points` rather than a hard-coded count), stored in zone "Zone" of base
# "Base_2_2" as shown in the tree above
sample_03.add_field("temperature", np.random.rand(len(points)), zone_name="Zone", base_name="Base_2_2")
sample_03.show_tree()
 CGNSLibraryVersion : (1,) [4.] float32 CGNSLibraryVersion_t
 Global : (2,) [1 1] int32 CGNSBase_t
|_  Time : (1,) [1] int32 BaseIterativeData_t
   |_  IterationValues : (1,) [1] int32 DataArray_t
   |_  TimeValues : (1,) [0.] float64 DataArray_t
|_  speed : (1,) [0.0943413] float64 DataArray_t
|_  rotation : (1,) [-1.57178081] float64 DataArray_t
 Base_2_2 : (2,) [2 2] int32 CGNSBase_t
|_  2D : None Family_t
|_  Zone : (1, 3) [[5 3 0]] int64 Zone_t
   |_  ZoneType : (12,) Unstructured |S1 ZoneType_t
   |_  GridCoordinates : None GridCoordinates_t
      |_  CoordinateX : (5,) [0.  1.  1.  0.  0.5] float64 DataArray_t
      |_  CoordinateY : (5,) [0.  0.  1.  1.  1.5] float64 DataArray_t
   |_  Elements_BAR_2 : (2,) [3 0] int32 Elements_t
      |_  ElementRange : (2,) [1 2] int64 IndexRange_t
      |_  ElementConnectivity : (4,) [1 2 1 3] int64 DataArray_t
   |_  Elements_TRI_3 : (2,) [5 0] int32 Elements_t
      |_  ElementRange : (2,) [3 5] int64 IndexRange_t
      |_  ElementConnectivity : (9,) [1 ... 4] int64 DataArray_t
   |_  VertexFields : None FlowSolution_t
      |_  GridLocation : (6,) Vertex |S1 GridLocation_t
      |_  OriginalIds : (5,) [1 2 3 4 5] int64 DataArray_t
      |_  temperature : (5,) [0.27334053 0.31500634 0.5064709  0.76036868 0.69596234] float64 DataArray_t
   |_  CellCenterFields : None FlowSolution_t
      |_  GridLocation : (10,) CellCenter |S1 GridLocation_t
      |_  OriginalIds : (3,) [1 2 3] int64 DataArray_t
   |_  FaceCenterFields : None FlowSolution_t
      |_  GridLocation : (10,) FaceCenter |S1 GridLocation_t
      |_  OriginalIds : (2,) [2 3] int64 DataArray_t
   |_  FamilyName : (2,) 2D |S1 FamilyName_t
|_  Time : (1,) [1] int32 BaseIterativeData_t
   |_  IterationValues : (1,) [1] int32 DataArray_t
   |_  TimeValues : (1,) [0.] float64 DataArray_t

Get Sample data

# Print sample general data
print(f"{sample_03 = }", end="\n\n")

# Print sample scalar data
print(f"{sample_03.get_scalar_names() = }")
print(f"{sample_03.get_scalar('speed') = }")
print(f"{sample_03.get_scalar('rotation') = }", end="\n\n")

# Print sample field data (field names also include 'OriginalIds', which
# comes from the CGNS mesh tree)
print(f"{sample_03.get_field_names() = }")
print(f"{sample_03.get_field('temperature') = }")
sample_03 = Sample(2 globals, 1 timestamp, 2 fields)

sample_03.get_scalar_names() = ['speed', 'rotation']
sample_03.get_scalar('speed') = np.float64(0.09434129711096854)
sample_03.get_scalar('rotation') = np.float64(-1.5717808098126538)

sample_03.get_field_names() = ['OriginalIds', 'temperature']
sample_03.get_field('temperature') = array([0.27334053, 0.31500634, 0.5064709 , 0.76036868, 0.69596234])

Section 2: Performing Operations on the Dataset

This section demonstrates how to add Samples to the Dataset, add information, and access data.

Add Samples in the Dataset

# Add Samples by id in the Dataset (keyword and positional forms)
dataset.set_sample(id=0, sample=sample_01)
dataset.set_sample(1, sample_02)

# Add unique Sample and automatically create its id; the assigned id is
# returned (2 here, after ids 0 and 1)
added_sample_id = dataset.add_sample(sample_03)
print(f"{added_sample_id = }")
added_sample_id = 2

Iterate through Samples in the Dataset

# Iterating over a Dataset yields its Samples
for sample in dataset:
    print(sample)
Sample(1 global, 1 timestamp, 1 field)
Sample(1 global, 1 timestamp, 0 fields)
Sample(2 globals, 1 timestamp, 2 fields)

Add information to the Dataset and display it

# Add a single information entry ("owner") under the "legal" category
dataset.add_info("legal", "owner", "Safran")

# Retrieve dataset information (`json` is imported at the top of the notebook)
dataset_info = dataset.get_infos()
print("dataset info =", json.dumps(dataset_info, sort_keys=False, indent=4), end="\n\n")

# Overwrite information: the logger warns that existing infos are replaced
infos = {"legal": {"owner": "Safran", "license": "CC0"}}
dataset.set_infos(infos)

# Retrieve dataset information again to see the overwritten content
dataset_info = dataset.get_infos()
print("dataset info =", json.dumps(dataset_info, sort_keys=False, indent=4), end="\n\n")

# Add a whole category of information to the Dataset at once
dataset.add_infos("data_description", {"number_of_samples": 0, "number_of_splits": 0})

# Pretty print dataset information
dataset.print_infos()
[2026-03-25 17:13:57,088:WARNING:dataset.py:set_infos(974)]:infos not empty, replacing it anyway
dataset info = {
    "plaid": {
        "version": "0.1.dev50+gdcf1785f7.d20260325"
    },
    "legal": {
        "owner": "Safran"
    }
}

dataset info = {
    "legal": {
        "owner": "Safran",
        "license": "CC0"
    },
    "plaid": {
        "version": "0.1.dev50+gdcf1785f7.d20260325"
    }
}

*********************** dataset infos **********************
legal
  legal:Safran
  legal:CC0
plaid
  plaid:0.1.dev50+gdcf1785f7.d20260325
data_description
  data_description:0
  data_description:0
************************************************************

Get a list of specific Samples in a Dataset

# get_samples returns a mapping from sample id to Sample for the requested ids
get_samples_from_ids = dataset.get_samples(ids=[0, 1])
dprint("get samples from ids =", get_samples_from_ids)
get samples from ids = {
     0 : Sample(1 global, 1 timestamp, 1 field)
     1 : Sample(1 global, 1 timestamp, 0 fields)
}

Get the list of Sample ids in a Dataset

# Print the ids of all Samples stored in the Dataset
print("get_sample_ids =", dataset.get_sample_ids())
get_sample_ids = [0, 1, 2]

Add a list of Samples to a Dataset

# Create a new Dataset (rebinding `dataset`) and add multiple samples at once;
# add_samples returns the list of assigned ids
dataset = Dataset()
samples = [sample_01, sample_02, sample_03]
added_ids = dataset.add_samples(samples)
print(f"{added_ids = }")
print(f"{dataset = }")
added_ids = [0, 1, 2]
dataset = Dataset(3 samples, 2 scalars, 1 field)

Access Sample data through the Dataset

# Access Sample data with indexes through the Dataset: both dataset(i) and
# dataset[i] return the Sample with id i
print(f"{dataset(0) = }")  # call strategy
print(f"{dataset[1] = }")  # getitem strategy
print(f"{dataset[2] = }", end="\n\n")

# Scalar names available in each Sample
print("scalar of the first sample = ", dataset[0].get_scalar_names())
print("scalar of the second sample = ", dataset[1].get_scalar_names())
print("scalar of the third sample = ", dataset[2].get_scalar_names())
dataset(0) = Sample(1 global, 1 timestamp, 1 field)
dataset[1] = Sample(1 global, 1 timestamp, 0 fields)
dataset[2] = Sample(2 globals, 1 timestamp, 2 fields)

scalar of the first sample =  ['rotation']
scalar of the second sample =  ['rotation']
scalar of the third sample =  ['speed', 'rotation']
# Read each Sample's "rotation" scalar through the Dataset
print(f"{dataset[0].get_scalar('rotation') = }")
print(f"{dataset[1].get_scalar('rotation') = }")
print(f"{dataset[2].get_scalar('rotation') = }")
dataset[0].get_scalar('rotation') = np.float64(-1.5717808098126538)
dataset[1].get_scalar('rotation') = np.float64(0.41231653932210466)
dataset[2].get_scalar('rotation') = np.float64(-1.5717808098126538)

Get Dataset scalars in tabular format

# Print the names of all scalars present in the Dataset
print(f"{dataset.get_scalar_names() = }", end="\n\n")

# get_scalars_to_tabular returns a dict {scalar name: array with one entry per
# Sample}; Samples lacking a scalar get NaN (see the "speed" output)
dprint("get rotation scalar = ", dataset.get_scalars_to_tabular(["rotation"]))
dprint("get speed scalar = ", dataset.get_scalars_to_tabular(["speed"]), end="\n\n")

# Get specific scalars in tabular format; with no names, all scalars are returned
dprint("get specific scalars =", dataset.get_scalars_to_tabular(["speed", "rotation"]))
dprint("get all scalars =", dataset.get_scalars_to_tabular())
dataset.get_scalar_names() = ['rotation', 'speed']

get rotation scalar =  {
     rotation : [-1.57178081  0.41231654 -1.57178081]
}
get speed scalar =  {
     speed : [      nan       nan 0.0943413]
}

get specific scalars = {
     speed : [      nan       nan 0.0943413]
     rotation : [-1.57178081  0.41231654 -1.57178081]
}
get all scalars = {
     rotation : [-1.57178081  0.41231654 -1.57178081]
     speed : [      nan       nan 0.0943413]
}
# Get all scalars as a single np.ndarray (one row per Sample, one column per
# scalar name)
print("get all scalar arrays = ", dataset.get_scalars_to_tabular(as_nparray=True))
get all scalar arrays =  [[-1.57178081         nan]
 [ 0.41231654         nan]
 [-1.57178081  0.0943413 ]]

Get Dataset fields

# Print the field names reported by the Dataset
print("fields in the dataset = ", dataset.get_field_names())
fields in the dataset =  ['OriginalIds']

Section 3: Various operations on the Dataset

This section demonstrates operations like merging datasets, adding tabular scalars, and setting information.

Initialize a Dataset with a list of Samples

# Create another Dataset filled with scalar-only Samples
other_dataset = Dataset()
nb_samples = 3
samples = []
for _ in range(nb_samples):
    sample = Sample()
    # rotation drawn in [1, 2), random_name drawn in [-1, 0)
    sample.add_scalar("rotation", np.random.rand() + 1.0)
    sample.add_scalar("random_name", np.random.rand() - 1.0)
    samples.append(sample)

# Add a list of Samples
other_dataset.add_samples(samples)
print(f"{other_dataset = }")
other_dataset = Dataset(3 samples, 2 scalars, 0 fields)

Merge two Datasets

# Merge the other dataset into the main dataset (the sample count grows from
# 3 to 6 in the output); scalars missing from a Sample show up as NaN
print(f"before merge: {dataset = }")
dataset.merge_dataset(other_dataset)
print(f"after merge: {dataset = }", end="\n\n")

dprint("dataset scalars = ", dataset.get_scalars_to_tabular())
before merge: dataset = Dataset(3 samples, 2 scalars, 1 field)
after merge: dataset = Dataset(6 samples, 3 scalars, 1 field)

dataset scalars =  {
     random_name : [        nan         nan         nan -0.14262502 -0.77784574 -0.8639243 ]
     rotation : [-1.57178081  0.41231654 -1.57178081  1.25120913  1.35902375  1.68057017]
     speed : [      nan       nan 0.0943413       nan       nan       nan]
}

Add tabular scalars to a Dataset

# Adding tabular scalars to the dataset: a (3, 2) array with one column per
# name; the Dataset grows by 3 Samples (6 -> 9 in the output), presumably one
# per row
new_scalars = np.random.rand(3, 2)
dataset.add_tabular_scalars(new_scalars, names=["Tu", "random_name"])

print(f"{dataset = }")
dprint("dataset scalars =", dataset.get_scalars_to_tabular())
dataset = Dataset(9 samples, 4 scalars, 1 field)
dataset scalars = {
     Tu : [       nan        nan        nan        nan        nan        nan
 0.02082681 0.62775241 0.09914959]
     random_name : [        nan         nan         nan -0.14262502 -0.77784574 -0.8639243
  0.91749203  0.71411328  0.23167516]
     rotation : [-1.57178081  0.41231654 -1.57178081  1.25120913  1.35902375  1.68057017
         nan         nan         nan]
     speed : [      nan       nan 0.0943413       nan       nan       nan       nan
       nan       nan]
}

Set additional information on a Dataset

# Set several information categories at once, then pretty-print them
infos = {
    "legal": {"owner": "Safran", "license": "CC0"},
    "data_production": {"type": "simulation", "simulator": "dummy"},
}
dataset.set_infos(infos)
dataset.print_infos()
*********************** dataset infos **********************
legal
  legal:Safran
  legal:CC0
data_production
  data_production:simulation
  data_production:dummy
plaid
  plaid:0.1.dev50+gdcf1785f7.d20260325
************************************************************

Section 4: Saving and Loading Dataset

This section demonstrates how to save and load a Dataset from a directory or file.

Save a Dataset as a file tree

# Build a randomly-named throw-away directory path in the platform's temp
# folder (tempfile.gettempdir() instead of a hard-coded "/tmp", which does not
# exist on Windows); `tmpdir` stays a str as before
run_id = np.random.randint(low=1, high=2_000_000_000)
tmpdir = str(Path(tempfile.gettempdir()) / f"test_safe_to_delete_{run_id}")
print(f"Save dataset in: {tmpdir}")

dataset.save_to_dir(tmpdir)
Save dataset in: /tmp/test_safe_to_delete_1856155734

Get the number of Samples that can be loaded from a directory

# Query how many Samples are stored in the saved directory
nb_samples = plaid.get_number_of_samples(tmpdir)
print(f"{nb_samples = }")
nb_samples = 9

Load a Dataset from a directory via initialization

# Passing a saved directory to the constructor loads it
loaded_dataset_from_init = Dataset(tmpdir)
print(f"{loaded_dataset_from_init = }")

# Multi-process loading (processes_number presumably sets the number of worker
# processes) is only demonstrated on Linux here
if platform.system() == "Linux":
    multi_process_loaded_dataset = Dataset(tmpdir, processes_number=3)
    print(f"{multi_process_loaded_dataset = }")
loaded_dataset_from_init = Dataset(9 samples, 4 scalars, 1 field)
multi_process_loaded_dataset = Dataset(9 samples, 4 scalars, 1 field)

Load a Dataset from a directory via the Dataset class

# Load through the class-level constructor Dataset.load_from_dir
loaded_dataset_from_class = Dataset.load_from_dir(tmpdir)
print(f"{loaded_dataset_from_class = }")

# Same load with multiple processes, only demonstrated on Linux here
if platform.system() == "Linux":
    multi_process_loaded_dataset = Dataset.load_from_dir(tmpdir, processes_number=3)
    print(f"{multi_process_loaded_dataset = }")
loaded_dataset_from_class = Dataset(9 samples, 4 scalars, 1 field)
multi_process_loaded_dataset = Dataset(9 samples, 4 scalars, 1 field)

Load the dataset from a directory via a Dataset instance

# Load into an existing (empty) Dataset instance
loaded_dataset_from_instance = Dataset()
loaded_dataset_from_instance.load(tmpdir)

print(f"{loaded_dataset_from_instance = }")

# Same load with multiple processes, only demonstrated on Linux here
if platform.system() == "Linux":
    multi_process_loaded_dataset = Dataset()
    multi_process_loaded_dataset.load(tmpdir, processes_number=3)
    print(f"{multi_process_loaded_dataset = }")
loaded_dataset_from_instance = Dataset(9 samples, 4 scalars, 1 field)
multi_process_loaded_dataset = Dataset(9 samples, 4 scalars, 1 field)

Save the dataset to a TAR (Tape Archive) file

# Build a randomly-named throw-away directory in the platform's temp folder
# (portable, unlike a hard-coded "/tmp") and the archive path inside it
tmpdir = Path(tempfile.gettempdir()) / f"test_safe_to_delete_{np.random.randint(low=1, high=2_000_000_000)}"
tmpfile = tmpdir / "test_file.plaid"

print(f"Save dataset in: {tmpfile}")
# `Dataset.save` is deprecated since v0.1.10 (removed in v0.2.0): use save_to_file
dataset.save_to_file(tmpfile)
Save dataset in: /tmp/test_safe_to_delete_1525417865/test_file.plaid
/tmp/ipykernel_3232/4262429027.py:5: DeprecationWarning: `Dataset.save(...)` is deprecated, use instead `Dataset.save_to_file(...)` [since v0.1.10] (will be removed in v0.2.0)
  dataset.save(tmpfile)
[2026-03-25 17:13:57,557:WARNING:dataset.py:save_to_dir(1153)]:Version mismatch: Dataset was loaded from version 0.1.dev50+gdcf1785f7.d20260325, and will be saved with version: 0.1.dev50+gdcf1785f7.d20260325

Load the dataset from a TAR (Tape Archive) file via a Dataset instance

# Load the saved archive file into an existing Dataset instance (load accepts
# a file path as well as a directory)
new_dataset = Dataset()
new_dataset.load(tmpfile)

# Compare the original and reloaded Datasets
print(f"{dataset = }")
print(f"{new_dataset = }")
dataset = Dataset(9 samples, 4 scalars, 1 field)
new_dataset = Dataset(9 samples, 4 scalars, 1 field)

Load the dataset from a TAR (Tape Archive) file via initialization

# Passing the archive file path to the constructor loads it directly
new_dataset = Dataset(tmpfile)

# Compare the original and reloaded Datasets
print(f"{dataset = }")
print(f"{new_dataset = }")
dataset = Dataset(9 samples, 4 scalars, 1 field)
new_dataset = Dataset(9 samples, 4 scalars, 1 field)