Source code for plaid.storage.hf_datasets.reader
"""Reader for Hugging Face ``datasets`` (hf_datasets) storage.

- If the environment variable `HF_ENDPOINT` is set, a private Hugging Face mirror is used:
    - Streaming is disabled.
    - The dataset is downloaded locally via `snapshot_download` and loaded from disk.
- If `HF_ENDPOINT` is not set, the dataset is loaded from the public Hugging Face Hub:
    - If the dataset is already cached locally, it is loaded from disk.
    - Otherwise, it is loaded from the Hub, optionally in streaming mode.
"""
# -*- coding: utf-8 -*-
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
#
#
import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional, Union
import datasets
from datasets import load_dataset, load_from_disk
from huggingface_hub import snapshot_download
from plaid.storage.common.reader import load_infos_from_hub
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# ------------------------------------------------------
# Load from disk
# ------------------------------------------------------
[docs]
# Type alias used in this module as the return annotation of `init_datasetdict_from_disk`.
# NOTE(review): `datasets.load_from_disk` on a directory written by
# `DatasetDict.save_to_disk` returns a `datasets.DatasetDict`, which is itself a
# mapping of split name -> `datasets.Dataset`. This alias may therefore be intended
# as `dict[str, datasets.Dataset]`; confirm against other users of the alias before changing.
HFDatasetDict = dict[str, datasets.DatasetDict]
[docs]
def init_datasetdict_from_disk(path: Union[str, Path]) -> HFDatasetDict:
    """Load the dataset dictionary stored in the ``data`` sub-directory of ``path``.

    The ``data`` sub-directory is expected to have been produced by
    ``save_to_disk`` (which is how ``download_datasetdict_from_hub`` writes it).

    Args:
        path (Union[str, Path]): Root directory of the local dataset; the saved
            dataset itself lives in ``<path>/data``.

    Returns:
        HFDatasetDict: The object returned by ``datasets.load_from_disk`` for
        ``<path>/data``.
    """
    data_dir = Path(path).joinpath("data")
    return load_from_disk(dataset_path=str(data_dir))
# ------------------------------------------------------
# Load from hub
# ------------------------------------------------------
[docs]
def download_datasetdict_from_hub(
    repo_id: str,
    local_dir: Union[str, Path],
    split_ids: Optional[dict[str, int]] = None,  # noqa: ARG001
    features: Optional[list[str]] = None,  # noqa: ARG001
    overwrite: bool = False,
) -> Path:  # pragma: no cover (not tested in unit tests)
    """Downloads a dataset from Hugging Face Hub to a local directory.

    Only the repository files matching ``data/*`` are fetched (via
    ``snapshot_download``) into a temporary directory. Split names are read
    from the repository infos (``infos["num_samples"]``). Each split is loaded
    from its parquet files, and the resulting dataset dictionary is written
    with ``save_to_disk`` to ``<local_dir>/data``. The temporary directory,
    which also holds the ``load_dataset`` cache, is removed afterwards.

    Args:
        repo_id (str): The repository ID on Hugging Face Hub.
        local_dir (Union[str, Path]): Local directory to download to.
        split_ids (Optional[dict[str, int]]): Unused parameter for split selection.
        features (Optional[list[str]]): Unused parameter for feature selection.
        overwrite (bool): Whether to delete an existing ``local_dir`` before downloading.

    Returns:
        Path: ``local_dir`` as a ``Path``; the dataset is saved in its ``data``
        sub-directory.

    Raises:
        ValueError: If ``local_dir`` is an existing, non-empty directory and
            ``overwrite`` is False.
    """
    output_folder = Path(local_dir)
    if output_folder.is_dir():
        if overwrite:
            shutil.rmtree(output_folder)
            logger.warning("Existing %s directory has been reset.", output_folder)
        elif any(output_folder.iterdir()):
            raise ValueError(
                f"directory {output_folder} already exists and is not empty. Set `overwrite` to True if needed."
            )

    with tempfile.TemporaryDirectory() as tmp_dir:
        snapshot_download(
            repo_id=repo_id,
            repo_type="dataset",
            allow_patterns=["data/*"],
            local_dir=tmp_dir,
        )
        infos = load_infos_from_hub(repo_id=repo_id)
        split_names = list(infos["num_samples"].keys())
        base = Path(tmp_dir) / "data"
        # NOTE(review): `{sn}*.parquet` is a prefix glob. A split whose name is a
        # prefix of another (e.g. "train" vs "train_small") would also pick up the
        # other split's files. If shards follow the Hub's
        # `<split>-XXXXX-of-YYYYY.parquet` naming, `{sn}-*.parquet` would be
        # stricter. Confirm against the writer before changing.
        data_files = {sn: str(base / f"{sn}*.parquet") for sn in split_names}
        datasetdict = load_dataset("parquet", data_files=data_files, cache_dir=tmp_dir)
        # Save while still inside the `with`: load_dataset was given
        # cache_dir=tmp_dir, so its cache disappears when the temp dir is removed.
        datasetdict.save_to_disk(str(output_folder / "data"))

    return output_folder
[docs]
def init_datasetdict_streaming_from_hub(
    repo_id: str,
    split_ids: Optional[dict[str, int]] = None,  # noqa: ARG001
    features: Optional[list[str]] = None,
) -> datasets.IterableDatasetDict:  # pragma: no cover
    """Open a dataset from the public Hugging Face Hub in streaming mode.

    Streaming is refused whenever the ``HF_ENDPOINT`` environment variable is
    set to a non-blank value (i.e. a private mirror is configured).

    Args:
        repo_id (str): The repository ID on Hugging Face Hub.
        split_ids (Optional[dict[str, int]]): Unused parameter for split selection.
        features (Optional[list[str]]): Optional list of columns to load; forwarded
            to ``load_dataset`` as ``columns``.

    Returns:
        datasets.IterableDatasetDict: The streaming dataset dictionary.

    Raises:
        RuntimeError: If ``HF_ENDPOINT`` is set to a non-blank value.
    """
    # A whitespace-only HF_ENDPOINT is treated the same as an unset one.
    if os.getenv("HF_ENDPOINT", "").strip():
        raise RuntimeError("Streaming mode not compatible with private mirror.")
    return load_dataset(repo_id, streaming=True, columns=features)