2 changes: 2 additions & 0 deletions docs/source/package_reference/loading_methods.rst
@@ -10,6 +10,8 @@ Datasets

.. autofunction:: nlp.load_dataset

.. autofunction:: nlp.load_from_disk

Metrics
~~~~~~~~~~~~~~~~~~~~~

4 changes: 3 additions & 1 deletion docs/source/package_reference/main_classes.rst
@@ -21,6 +21,7 @@ The base class :class:`nlp.Dataset` implements a Dataset backed by an Apache Arr
__len__, __iter__, formatted_as, set_format, reset_format,
__getitem__, cleanup_cache_files,
map, filter, select, sort, shuffle, train_test_split, shard, export,
save_to_disk, load_from_disk,
add_faiss_index, add_faiss_index_from_external_arrays, save_faiss_index, load_faiss_index,
add_elasticsearch_index,
list_indexes, get_index, drop_index, search, search_batch, get_nearest_examples, get_nearest_examples_batch,
@@ -41,7 +42,8 @@ It also has dataset transform methods like map or filter, to process all the spl
unique, flatten_,
cleanup_cache_files,
map, filter, sort, shuffle, set_format, reset_format, formatted_as,
cast_, remove_columns_, rename_column_
cast_, remove_columns_, rename_column_,
save_to_disk, load_from_disk,


``Features``
19 changes: 19 additions & 0 deletions docs/source/processing.rst
@@ -464,6 +464,25 @@ When you have several :obj:`nlp.Dataset` objects that share the same column type
>>> bert_dataset = concatenate_datasets([bookcorpus, wiki])


Saving a processed dataset on disk and reloading it
----------------------------------------------------

Once you have your final dataset, you can save it on disk and reload it later using :obj:`nlp.load_from_disk`.
Saving a dataset creates a directory with various files:

- arrow files: they contain your dataset's data
- dataset_info.json: contains the description, citation, etc. of the dataset
- state.json: contains the list of arrow files and other information such as the dataset format type, if any (``torch`` or ``tensorflow`` for example)

.. code-block::

>>> encoded_dataset.save_to_disk("path/of/my/dataset/directory")
>>> ...
>>> from nlp import load_from_disk
>>> reloaded_encoded_dataset = load_from_disk("path/of/my/dataset/directory")
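
After :func:`nlp.Dataset.save_to_disk`, the dataset directory contains the files listed above. As a minimal sketch of the layout (assuming the dataset is stored in a single arrow file and has no indices mapping; an ``indices.arrow`` file may also be present, and the exact arrow file names depend on how the dataset was built):

.. code-block::

>>> import os
>>> sorted(os.listdir("path/of/my/dataset/directory"))
['dataset.arrow', 'dataset_info.json', 'state.json']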

Both :obj:`nlp.Dataset` and :obj:`nlp.DatasetDict` objects can be saved on disk, using :func:`nlp.Dataset.save_to_disk` and :func:`nlp.DatasetDict.save_to_disk` respectively.
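
For a :obj:`nlp.DatasetDict`, each split is saved in its own sub-directory along with a small ``dataset_dict.json`` file listing the splits. A minimal sketch, using a hypothetical ``encoded_dataset_dict`` with a ``train`` split:

.. code-block::

>>> encoded_dataset_dict.save_to_disk("path/of/my/dataset/dict/directory")
>>> ...
>>> from nlp import load_from_disk
>>> reloaded_encoded_dataset_dict = load_from_disk("path/of/my/dataset/dict/directory")
>>> reloaded_encoded_dataset_dict["train"]  # each split is reloaded as a nlp.Dataset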

Controlling the cache behavior
-----------------------------------

2 changes: 1 addition & 1 deletion src/nlp/__init__.py
@@ -43,7 +43,7 @@
)
from .info import DatasetInfo, MetricInfo
from .inspect import inspect_dataset, inspect_metric, list_datasets, list_metrics
from .load import import_main_class, load_dataset, load_metric, prepare_module
from .load import import_main_class, load_dataset, load_from_disk, load_metric, prepare_module
from .metric import Metric
from .splits import NamedSplit, Split, SplitBase, SplitDict, SplitGenerator, SplitInfo, SubSplitInfo, percent
from .utils import *
82 changes: 82 additions & 0 deletions src/nlp/arrow_dataset.py
@@ -19,10 +19,12 @@
import contextlib
import json
import os
import pickle
import shutil
import tempfile
from collections import defaultdict
from collections.abc import Iterable, Mapping
from dataclasses import asdict
from functools import partial
from math import ceil, floor
from multiprocessing import Pool, RLock
@@ -337,6 +339,8 @@ def from_dict(

def __getstate__(self):
state = dict(self.__dict__)
state["_info"] = json.dumps(asdict(state["_info"]))
state["_split"] = str(state["_split"]) if state["_split"] is not None else None
if self._data_files:
state["_data"] = None
if self._indices_data_files:
@@ -347,8 +351,12 @@ def __setstate__(self, state):
assert (
state.get("_data") is not None or state.get("_data_files") is not None
), "tried to unpickle a dataset without arrow_table or data_files"
state = dict(state)
state["_info"] = DatasetInfo.from_dict(json.loads(state["_info"]))
state["_split"] = NamedSplit(state["_split"]) if state["_split"] is not None else None
self.__dict__ = state
reader = ArrowReader("", self.info)
# Read arrow tables
if self._data is None and self._data_files:
tables = []
for data_file, inplace_hist_per_file in zip(self._data_files, self._inplace_history):
@@ -366,6 +374,79 @@ def __setstate__(self, state):
if self._indices is None and self._indices_data_files:
self._indices = reader._read_files(self._indices_data_files)

def save_to_disk(self, dataset_path: str):
"""
Save the dataset in a dataset directory

Args:
dataset_path (``str``): path of the dataset directory where the dataset will be saved to
"""
assert (
not self.list_indexes()
), "please remove all the indexes using `dataset.drop_index` before saving a dataset"
# Round-trip through pickle so we work on an independent, file-backed copy and leave the caller's dataset unchanged
self = pickle.loads(pickle.dumps(self))
os.makedirs(dataset_path, exist_ok=True)
# Write indices if needed
if self._indices is not None:
if not self._indices_data_files:
cache_file_name = os.path.join(dataset_path, "indices.arrow")
writer = ArrowWriter(path=cache_file_name)
writer.write_table(self._indices)
writer.finalize()
self._indices_data_files = [{"filename": cache_file_name}]
# Write dataset if needed
if not self._data_files or any(len(h["transforms"]) > 0 for h in self._inplace_history):
cache_file_name = os.path.join(dataset_path, "dataset.arrow")
writer = ArrowWriter(path=cache_file_name)
writer.write_table(self._data)
writer.finalize()
self._data_files = [{"filename": cache_file_name}]
self._inplace_history = [{"transforms": []}]
# Copy all files into the dataset directory
for data_file in self._data_files + self._indices_data_files:
# Copy file to destination directory
src = data_file["filename"]
filename = src.split("/")[-1]
dest = os.path.join(dataset_path, filename)
if src != dest:
shutil.copy(src, dest)
# Change path to relative path from inside the destination directory
data_file["filename"] = filename
# Get state
state = self.__getstate__()
dataset_info = json.loads(state.pop("_info"))
assert state.get("_data") is None, "arrow table needs to be memory mapped"
assert state.get("_indices") is None, "arrow table needs to be memory mapped"
assert all(
len(h["transforms"]) == 0 for h in state.get("_inplace_history", [])
), "in-place history needs to be empty"
# Serialize state
with open(os.path.join(dataset_path, "state.json"), "w") as state_file:
json.dump(state, state_file, indent=2, sort_keys=True)
with open(os.path.join(dataset_path, "dataset_info.json"), "w") as dataset_info_file:
json.dump(dataset_info, dataset_info_file, indent=2, sort_keys=True)
logger.info("Dataset saved in {}".format(dataset_path))

@staticmethod
def load_from_disk(dataset_path: str) -> "Dataset":
"""Load the dataset from a dataset directory

Args:
dataset_path (``str``): path of the dataset directory where the dataset will be loaded from
"""
with open(os.path.join(dataset_path, "state.json"), "r") as state_file:
state = json.load(state_file)
with open(os.path.join(dataset_path, "dataset_info.json"), "r") as dataset_info_file:
dataset_info = json.load(dataset_info_file)
state["_info"] = json.dumps(dataset_info)
dataset = Dataset.from_dict({})
state = {k: state[k] for k in dataset.__dict__.keys()} # in case we add new fields
# Change path to absolute path
for data_file in state.get("_data_files", []) + state.get("_indices_data_files", []):
data_file["filename"] = os.path.join(dataset_path, data_file["filename"])
dataset.__setstate__(state)
return dataset

@property
def data(self) -> pa.Table:
"""The Apache Arrow table backing the dataset."""
@@ -1473,6 +1554,7 @@ def flatten_indices(
"""

return self.map(
batched=True, # for speed
keep_in_memory=keep_in_memory,
cache_file_name=cache_file_name,
writer_batch_size=writer_batch_size,
5 changes: 4 additions & 1 deletion src/nlp/arrow_reader.py
@@ -160,6 +160,9 @@ def _read_files(self, files) -> pa.Table:
"""
assert len(files) > 0 and all(isinstance(f, dict) for f in files), "please provide valid file information"
pa_tables = []
files = copy.deepcopy(files)
for f in files:
f.update(filename=os.path.join(self._path, f["filename"]))
for f_dict in files:
pa_table: pa.Table = self._get_dataset_from_filename(f_dict)
pa_tables.append(pa_table)
@@ -218,10 +221,10 @@ def read_files(
kwargs to build a Dataset instance.
"""
# Prepend path to filename
pa_table = self._read_files(files)
files = copy.deepcopy(files)
for f in files:
f.update(filename=os.path.join(self._path, f["filename"]))
pa_table = self._read_files(files)
dataset_kwargs = dict(arrow_table=pa_table, data_files=files, info=self._info, split=original_instructions)
return dataset_kwargs

27 changes: 27 additions & 0 deletions src/nlp/dataset_dict.py
@@ -1,4 +1,6 @@
import contextlib
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
@@ -457,3 +459,28 @@ def shuffle(
for k, dataset in self.items()
}
)

def save_to_disk(self, dataset_dict_path: str):
"""
Save the dataset dict in a dataset dict directory.

Args:
dataset_dict_path (``str``): path of the dataset dict directory where the dataset dict will be saved to
"""
os.makedirs(dataset_dict_path, exist_ok=True)
json.dump({"splits": list(self)}, open(os.path.join(dataset_dict_path, "dataset_dict.json"), "w"))
for k, dataset in self.items():
dataset.save_to_disk(os.path.join(dataset_dict_path, k))

@staticmethod
def load_from_disk(dataset_dict_path: str) -> "DatasetDict":
"""
Load the dataset dict from a dataset dict directory

Args:
dataset_dict_path (``str``): path of the dataset dict directory where the dataset dict will be loaded from
"""
dataset_dict = DatasetDict()
for k in json.load(open(os.path.join(dataset_dict_path, "dataset_dict.json"), "r"))["splits"]:
dataset_dict[k] = Dataset.load_from_disk(os.path.join(dataset_dict_path, k))
return dataset_dict
24 changes: 24 additions & 0 deletions src/nlp/load.py
@@ -557,3 +557,27 @@ def load_dataset(
builder_instance._save_infos()

return ds


def load_from_disk(dataset_path: str) -> Union[Dataset, DatasetDict]:
"""
Load a dataset that was previously saved using ``dataset.save_to_disk(dataset_path)``.

Args:
dataset_path (``str``): path of a Dataset directory or a DatasetDict directory

Returns:
``nlp.Dataset`` or ``nlp.DatasetDict``
if `dataset_path` is the path of a dataset directory, the requested ``nlp.Dataset``;
if `dataset_path` is the path of a dataset dict directory, a ``nlp.DatasetDict`` with each split.
"""
if not os.path.isdir(dataset_path):
raise FileNotFoundError("Directory {} not found".format(dataset_path))
if os.path.exists(os.path.join(dataset_path, "dataset_info.json")):
return Dataset.load_from_disk(dataset_path)
elif os.path.exists(os.path.join(dataset_path, "dataset_dict.json")):
return DatasetDict.load_from_disk(dataset_path)
else:
raise FileNotFoundError(
"Directory {} is neither a dataset directory nor a dataset dict directory.".format(dataset_path)
)
64 changes: 63 additions & 1 deletion tests/test_arrow_dataset.py
@@ -8,7 +8,7 @@
import pyarrow as pa

import nlp.arrow_dataset
from nlp import concatenate_datasets
from nlp import concatenate_datasets, load_from_disk
from nlp.arrow_dataset import Dataset
from nlp.features import ClassLabel, Features, Sequence, Value
from nlp.info import DatasetInfo
@@ -116,6 +116,68 @@ def test_dummy_dataset_pickle_memory_mapped(self):
self.assertEqual(dset[0]["filename"], "my_name-train_0")
self.assertEqual(dset["filename"][0], "my_name-train_0")

def test_dummy_dataset_serialize(self):
with tempfile.TemporaryDirectory() as tmp_dir:

dset = self._create_dummy_dataset().select(range(10))
dataset_path = os.path.join(tmp_dir, "my_dataset")
dset.save_to_disk(dataset_path)
dset = dset.load_from_disk(dataset_path)

self.assertEqual(len(dset), 10)
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
self.assertEqual(dset[0]["filename"], "my_name-train_0")
self.assertEqual(dset["filename"][0], "my_name-train_0")

def test_dummy_dataset_serialize_memory_mapped(self):
with tempfile.TemporaryDirectory() as tmp_dir:

dset = (
self._create_dummy_dataset().map(cache_file_name=os.path.join(tmp_dir, "test.arrow")).select(range(10))
)
dset._data = Unpicklable() # check that we don't pickle the entire table

dataset_path = os.path.join(tmp_dir, "my_dataset")
dset.save_to_disk(dataset_path)
dset = dset.load_from_disk(dataset_path)

self.assertEqual(len(dset), 10)
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
self.assertEqual(dset[0]["filename"], "my_name-train_0")
self.assertEqual(dset["filename"][0], "my_name-train_0")

with tempfile.TemporaryDirectory() as tmp_dir:

dset = (
self._create_dummy_dataset()
.map(cache_file_name=os.path.join(tmp_dir, "test.arrow"))
.select(range(10), indices_cache_file_name=os.path.join(tmp_dir, "ind.arrow"))
)
dset._data = Unpicklable()
dset._indices = Unpicklable()

dataset_path = os.path.join(tmp_dir, "my_dataset")
dset.save_to_disk(dataset_path)
dset = dset.load_from_disk(dataset_path)

self.assertEqual(len(dset), 10)
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
self.assertEqual(dset[0]["filename"], "my_name-train_0")
self.assertEqual(dset["filename"][0], "my_name-train_0")

def test_dummy_dataset_load_from_disk(self):
with tempfile.TemporaryDirectory() as tmp_dir:

dset = self._create_dummy_dataset().select(range(10))
dataset_path = os.path.join(tmp_dir, "my_dataset")
dset.save_to_disk(dataset_path)
dset = load_from_disk(dataset_path)

self.assertEqual(len(dset), 10)
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
self.assertEqual(dset[0]["filename"], "my_name-train_0")
self.assertEqual(dset["filename"][0], "my_name-train_0")

def test_from_pandas(self):
data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
df = pd.DataFrame.from_dict(data)
31 changes: 30 additions & 1 deletion tests/test_dataset_dict.py
@@ -5,7 +5,7 @@
import numpy as np
import pandas as pd

from nlp import Features, Sequence, Value
from nlp import Features, Sequence, Value, load_from_disk
from nlp.arrow_dataset import Dataset
from nlp.dataset_dict import DatasetDict

@@ -276,3 +276,32 @@ def test_check_values_type(self):
self.assertRaises(TypeError, dsets.filter, lambda x: True)
self.assertRaises(TypeError, dsets.shuffle)
self.assertRaises(TypeError, dsets.sort, "filename")

def test_serialization(self):
with tempfile.TemporaryDirectory() as tmp_dir:
dsets = self._create_dummy_dataset_dict()
dsets.save_to_disk(tmp_dir)
dsets = DatasetDict.load_from_disk(tmp_dir)
self.assertListEqual(sorted(dsets), ["test", "train"])
self.assertEqual(len(dsets["train"]), 30)
self.assertListEqual(dsets["train"].column_names, ["filename"])
self.assertEqual(len(dsets["test"]), 30)
self.assertListEqual(dsets["test"].column_names, ["filename"])

del dsets["test"]
dsets.save_to_disk(tmp_dir)
dsets = DatasetDict.load_from_disk(tmp_dir)
self.assertListEqual(sorted(dsets), ["train"])
self.assertEqual(len(dsets["train"]), 30)
self.assertListEqual(dsets["train"].column_names, ["filename"])

def test_load_from_disk(self):
with tempfile.TemporaryDirectory() as tmp_dir:
dsets = self._create_dummy_dataset_dict()
dsets.save_to_disk(tmp_dir)
dsets = load_from_disk(tmp_dir)
self.assertListEqual(sorted(dsets), ["test", "train"])
self.assertEqual(len(dsets["train"]), 30)
self.assertListEqual(dsets["train"].column_names, ["filename"])
self.assertEqual(len(dsets["test"]), 30)
self.assertListEqual(dsets["test"].column_names, ["filename"])