# -*- coding: utf-8 -*-## This file is subject to the terms and conditions defined in# file 'LICENSE.txt', which is part of this source code package.##"""Reader for hf dataset storage.- If the environment variable `HF_ENDPOINT` is set, uses a private Hugging Face mirror. - Streaming is disabled. - The dataset is downloaded locally via `snapshot_download` and loaded from disk.- If `HF_ENDPOINT` is not set, attempts to load from the public Hugging Face hub. - If the dataset is already cached locally, loads from disk. - Otherwise, loads from the hub, optionally using streaming mode."""importloggingimportosimportshutilfrompathlibimportPathfromtypingimportOptional,Unionimportdatasetsfromdatasetsimportload_dataset,load_from_diskfromhuggingface_hubimportsnapshot_downloadfromplaid.storage.common.readerimportload_infos_from_disklogger=logging.getLogger(__name__)# ------------------------------------------------------# Load from disk# ------------------------------------------------------
def init_datasetdict_from_disk(path: Union[str, Path]) -> datasets.DatasetDict:
    """Build a DatasetDict from dataset files stored under ``<path>/data``.

    Two on-disk layouts are handled:

    - If ``<path>/data/dataset_dict.json`` exists, the ``data`` directory is
      passed to ``load_from_disk``.
    - Otherwise the split names are read from the ``"num_samples"`` entry
      returned by ``load_infos_from_disk(path)``, and each split is loaded
      from the parquet files matching ``<path>/data/<split>*.parquet``.

    Args:
        path (Union[str, Path]): Path to the directory containing the dataset files.

    Returns:
        datasets.DatasetDict: The loaded dataset dictionary.
    """
    data_dir = Path(path) / "data"

    if not (data_dir / "dataset_dict.json").is_file():  # pragma: no cover
        # No marker file: this is a dataset downloaded from the hub, stored as
        # parquet files whose names start with the split name.
        infos = load_infos_from_disk(path)
        parquet_globs = {}
        for split in infos["num_samples"].keys():
            parquet_globs[split] = str(data_dir / f"{split}*.parquet")
        return load_dataset("parquet", data_files=parquet_globs)

    # Marker file present: this is a dataset generated and saved locally.
    return load_from_disk(dataset_path=str(data_dir))
# ------------------------------------------------------
# Load from hub
# ------------------------------------------------------
def download_datasetdict_from_hub(
    repo_id: str,
    local_dir: Union[str, Path],
    split_ids: Optional[dict[str, int]] = None,  # noqa: ARG001
    features: Optional[list[str]] = None,  # noqa: ARG001
    overwrite: bool = False,
) -> str:  # pragma: no cover (not tested in unit tests)
    """Downloads a dataset from Hugging Face Hub to a local directory.

    Only the files matching ``data/*`` in the dataset repository are requested.

    Args:
        repo_id (str): The repository ID on Hugging Face Hub.
        local_dir (Union[str, Path]): Local directory to download to.
        split_ids (Optional[dict[str, int]]): Unused parameter for split selection.
        features (Optional[list[str]]): Unused parameter for feature selection.
        overwrite (bool): Whether to delete an already existing ``local_dir``
            before downloading.

    Returns:
        str: The value returned by ``snapshot_download`` (path to the downloaded dataset).

    Raises:
        NotADirectoryError: If ``local_dir`` exists but is not a directory.
        ValueError: If ``local_dir`` is an existing, non-empty directory and
            ``overwrite`` is False.
    """
    output_folder = Path(local_dir)

    # Fail early with a clear message instead of letting the download crash
    # later on an obscure filesystem error when the target is a regular file.
    if output_folder.exists() and not output_folder.is_dir():
        raise NotADirectoryError(
            f"{output_folder} already exists and is not a directory."
        )

    if output_folder.is_dir():
        if overwrite:
            shutil.rmtree(output_folder)
            logger.warning(f"Existing {output_folder} directory has been reset.")
        elif any(output_folder.iterdir()):
            # An existing but empty directory is accepted as a download target.
            raise ValueError(
                f"directory {output_folder} already exists and is not empty. Set `overwrite` to True if needed."
            )

    return snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        allow_patterns=["data/*"],
        local_dir=local_dir,
    )
def init_datasetdict_streaming_from_hub(
    repo_id: str,
    split_ids: Optional[dict[str, int]] = None,  # noqa: ARG001
    features: Optional[list[str]] = None,
) -> datasets.IterableDatasetDict:  # pragma: no cover
    """Open a Hugging Face Hub dataset in streaming mode.

    Streaming is refused when the ``HF_ENDPOINT`` environment variable holds a
    non-blank value, i.e. when a private mirror is configured.

    Args:
        repo_id (str): The repository ID on Hugging Face Hub.
        split_ids (Optional[dict[str, int]]): Unused parameter for split selection.
        features (Optional[list[str]]): Optional list of features to load;
            forwarded to ``load_dataset`` as ``columns``.

    Returns:
        datasets.IterableDatasetDict: The streaming dataset dictionary.

    Raises:
        RuntimeError: If ``HF_ENDPOINT`` is set to a non-blank value.
    """
    mirror_url = os.environ.get("HF_ENDPOINT", "")
    if mirror_url.strip():
        raise RuntimeError("Streaming mode not compatible with private mirror.")

    return load_dataset(repo_id, streaming=True, columns=features)