Skip to content

Commit bf3c8c2

Browse files
committed
Issue 6598: load_dataset broken for data_files on s3
1 parent 7ae4314 commit bf3c8c2

File tree

3 files changed

+41
-0
lines changed

3 files changed

+41
-0
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@
172172
"jax>=0.3.14; sys_platform != 'win32'",
173173
"jaxlib>=0.3.14; sys_platform != 'win32'",
174174
"lz4",
175+
"moto[server]",
175176
"pyspark>=3.4", # https://issues.apache.org/jira/browse/SPARK-40991 fixed in 3.4.0
176177
"py7zr",
177178
"rarfile>=4.0",

src/datasets/utils/file_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,10 @@ def get_from_cache(
568568
if scheme == "ftp":
569569
connected = ftp_head(url)
570570
elif scheme not in ("http", "https"):
571+
if scheme in ("s3", "s3a") and storage_options is not None and "hf" in storage_options:
572+
# Issue 6071: **storage_options is passed to botocore.session.Session()
573+
# and must not contain keys that become invalid kwargs.
574+
del storage_options["hf"]
571575
response = fsspec_head(url, storage_options=storage_options)
572576
# s3fs uses "ETag", gcsfs uses "etag"
573577
etag = (response.get("ETag", None) or response.get("etag", None)) if use_etag else None

tests/test_load.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
import shutil
55
import tempfile
66
import time
7+
from contextlib import contextmanager
78
from hashlib import sha256
89
from multiprocessing import Pool
910
from pathlib import Path
1011
from unittest import TestCase
1112
from unittest.mock import patch
1213

14+
import boto3
1315
import dill
1416
import pyarrow as pa
1517
import pytest
1618
import requests
19+
from moto.server import ThreadedMotoServer
1720

1821
import datasets
1922
from datasets import config, load_dataset, load_from_disk
@@ -1648,6 +1651,39 @@ def test_load_from_disk_with_default_in_memory(
16481651
_ = load_from_disk(dataset_path)
16491652

16501653

1654+
@contextmanager
def moto_server():
    """Run a ThreadedMotoServer for the duration of the context.

    AWS_ENDPOINT_URL is patched into the environment so that boto3/s3fs
    clients created inside the context talk to the local mock server
    instead of real AWS. The server is always stopped on exit, and the
    environment patch is reverted afterwards.
    """
    env_patch = patch.dict(os.environ, {"AWS_ENDPOINT_URL": "http://localhost:5000"})
    with env_patch:
        mock_server = ThreadedMotoServer()
        mock_server.start()
        try:
            yield
        finally:
            # Stop the server even if the body raised, so no thread leaks
            # across tests.
            mock_server.stop()
1663+
1664+
1665+
def test_load_file_from_s3():
    """Regression test for issue 6598: load_dataset with s3:// data_files.

    Spins up a local moto S3 server, uploads a small CSV, and checks that
    ``datasets.load_dataset`` can read it back through the ``s3://`` scheme.
    """
    with moto_server():
        # Create a mock S3 bucket
        bucket_name = "test-bucket"
        s3 = boto3.client("s3", region_name="us-east-1")
        s3.create_bucket(Bucket=bucket_name)

        # Upload a file to the mock bucket
        key = "test-file.csv"
        csv_data = "Name\nPatrick\nMat"
        s3.put_object(Bucket=bucket_name, Key=key, Body=csv_data)

        # Load the file from the mock bucket. Build the URI from the same
        # constants used to create the object so the test cannot drift if
        # the bucket or key name is changed above.
        ds = datasets.load_dataset("csv", data_files={"train": f"s3://{bucket_name}/{key}"})

        # Check if the loaded content matches the original content
        assert list(ds["train"]) == [{"Name": "Patrick"}, {"Name": "Mat"}]
1685+
1686+
16511687
@pytest.mark.integration
16521688
def test_remote_data_files():
16531689
repo_id = "hf-internal-testing/raw_jsonl"

0 commit comments

Comments
 (0)