Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4",
"moto[server]",
"pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
"py7zr",
"rarfile>=4.0",
Expand Down
2 changes: 0 additions & 2 deletions src/datasets/download/download_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,6 @@ def __post_init__(self, use_auth_token):
FutureWarning,
)
self.token = use_auth_token
if "hf" not in self.storage_options:
self.storage_options["hf"] = {"token": self.token, "endpoint": config.HF_ENDPOINT}

def copy(self) -> "DownloadConfig":
return self.__class__(**{k: copy.deepcopy(v) for k, v in self.__dict__.items()})
Expand Down
36 changes: 36 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
import shutil
import tempfile
import time
from contextlib import contextmanager
from hashlib import sha256
from multiprocessing import Pool
from pathlib import Path
from unittest import TestCase
from unittest.mock import patch

import boto3
import dill
import pyarrow as pa
import pytest
import requests
from moto.server import ThreadedMotoServer

import datasets
from datasets import config, load_dataset, load_from_disk
Expand Down Expand Up @@ -1648,6 +1651,39 @@ def test_load_from_disk_with_default_in_memory(
_ = load_from_disk(dataset_path)


@contextmanager
def moto_server():
with patch.dict(os.environ, {"AWS_ENDPOINT_URL": "http://localhost:5000"}):
server = ThreadedMotoServer()
server.start()
try:
yield
finally:
server.stop()


def test_load_file_from_s3():
# we need server mode here because of an aiobotocore incompatibility with moto.mock_aws
# (https://github.com/getmoto/moto/issues/6836)
with moto_server():
# Create a mock S3 bucket
bucket_name = "test-bucket"
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket=bucket_name)

# Upload a file to the mock bucket
key = "test-file.csv"
csv_data = "Island\nIsabela\nBaltra"

s3.put_object(Bucket=bucket_name, Key=key, Body=csv_data)

# Load the file from the mock bucket
ds = datasets.load_dataset("csv", data_files={"train": "s3://test-bucket/test-file.csv"})

# Check if the loaded content matches the original content
assert list(ds["train"]) == [{"Island": "Isabela"}, {"Island": "Baltra"}]


@pytest.mark.integration
def test_remote_data_files():
repo_id = "hf-internal-testing/raw_jsonl"
Expand Down