|
4 | 4 | import shutil
|
5 | 5 | import tempfile
|
6 | 6 | import time
|
| 7 | +from contextlib import contextmanager |
7 | 8 | from hashlib import sha256
|
8 | 9 | from multiprocessing import Pool
|
9 | 10 | from pathlib import Path
|
10 | 11 | from unittest import TestCase
|
11 | 12 | from unittest.mock import patch
|
12 | 13 |
|
| 14 | +import boto3 |
13 | 15 | import dill
|
14 | 16 | import pyarrow as pa
|
15 | 17 | import pytest
|
16 | 18 | import requests
|
| 19 | +from moto.server import ThreadedMotoServer |
17 | 20 |
|
18 | 21 | import datasets
|
19 | 22 | from datasets import config, load_dataset, load_from_disk
|
@@ -1648,6 +1651,39 @@ def test_load_from_disk_with_default_in_memory(
|
1648 | 1651 | _ = load_from_disk(dataset_path)
|
1649 | 1652 |
|
1650 | 1653 |
|
@contextmanager
def moto_server():
    """Run a local moto server and point AWS clients at it for the duration.

    While the context is active, ``os.environ`` is patched so that AWS SDK
    clients resolve their endpoint to the local server and sign requests
    with dummy credentials. The original environment is restored and the
    server is stopped on exit, even if the body raises.
    """
    mock_aws_env = {
        # NOTE(review): assumes ThreadedMotoServer() listens on its default
        # port 5000 -- keep this URL in sync if the server is configured
        # differently.
        "AWS_ENDPOINT_URL": "http://localhost:5000",
        "AWS_DEFAULT_REGION": "us-east-1",
        # Dummy credentials: without them, clients fail with a missing
        # credentials error on machines that have no AWS configuration, and
        # on machines that do, real credentials would be used to sign the
        # requests sent to the mock server.
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_SECURITY_TOKEN": "testing",
        "AWS_SESSION_TOKEN": "testing",
    }
    with patch.dict(os.environ, mock_aws_env):
        server = ThreadedMotoServer()
        server.start()
        try:
            yield
        finally:
            server.stop()
| 1663 | + |
| 1664 | + |
def test_load_file_from_s3():
    """Check that a CSV file stored in a (mocked) S3 bucket can be loaded.

    A CSV object is uploaded to a bucket on the local moto server, then read
    back through ``datasets.load_dataset`` using an ``s3://`` URL.
    """
    with moto_server():
        # Create a mock S3 bucket
        bucket_name = "test-bucket"
        s3 = boto3.client("s3", region_name="us-east-1")
        s3.create_bucket(Bucket=bucket_name)

        # Upload a file to the mock bucket
        key = "test-file.csv"
        csv_data = "Name\nPatrick\nMat"
        s3.put_object(Bucket=bucket_name, Key=key, Body=csv_data)

        # Load the file from the mock bucket. The URL is derived from the
        # bucket and key used for the upload so the two cannot drift apart.
        ds = datasets.load_dataset(
            "csv", data_files={"train": f"s3://{bucket_name}/{key}"}
        )

        # Check if the loaded content matches the original content
        assert list(ds["train"]) == [{"Name": "Patrick"}, {"Name": "Mat"}]
| 1685 | + |
| 1686 | + |
1651 | 1687 | @pytest.mark.integration
|
1652 | 1688 | def test_remote_data_files():
|
1653 | 1689 | repo_id = "hf-internal-testing/raw_jsonl"
|
|
0 commit comments