Skip to content

Commit 453508f

Browse files
authored
adding fine tune example with s3 as the dataset store (kubeflow/trainer#2006)
* s3 as dataset source code review changes
  Signed-off-by: deepanker13 <[email protected]>
* fixing python black test
  Signed-off-by: deepanker13 <[email protected]>
* removing conflicts in example file
  Signed-off-by: deepanker13 <[email protected]>
* retriggering CI
  Signed-off-by: deepanker13 <[email protected]>
* removing dummy keys
  Signed-off-by: deepanker13 <[email protected]>
* code review change for adding s3 keys block
  Signed-off-by: deepanker13 <[email protected]>
---------
Signed-off-by: deepanker13 <[email protected]>
1 parent a3efe86 commit 453508f

File tree

1 file changed

+23
-10
lines changed
  • python/kubeflow/storage_initializer

1 file changed

+23
-10
lines changed

python/kubeflow/storage_initializer/s3.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,31 @@ def download_dataset(self):
4242
import boto3
4343

4444
# Create an S3 client for Nutanix Object Store/S3
45-
s3_client = boto3.client(
46-
"s3",
45+
s3_client = boto3.Session(
4746
aws_access_key_id=self.config.access_key,
4847
aws_secret_access_key=self.config.secret_key,
49-
endpoint_url=self.config.endpoint_url,
5048
region_name=self.config.region_name,
5149
)
50+
s3_resource = s3_client.resource("s3", endpoint_url=self.config.endpoint_url)
51+
# Get the bucket object
52+
bucket = s3_resource.Bucket(self.config.bucket_name)
5253

53-
# Download the file
54-
s3_client.download_file(
55-
self.config.bucket_name,
56-
self.config.file_key,
57-
os.path.join(VOLUME_PATH_DATASET, self.config.file_key),
58-
)
59-
print(f"File downloaded to: {VOLUME_PATH_DATASET}")
54+
# Filter objects with the specified prefix
55+
objects = bucket.objects.filter(Prefix=self.config.file_key)
56+
# Iterate over filtered objects
57+
for obj in objects:
58+
# Extract the object key (filename)
59+
obj_key = obj.key
60+
path_components = obj_key.split(os.path.sep)
61+
path_excluded_first_last_parts = os.path.sep.join(path_components[1:-1])
62+
63+
# Create directories if they don't exist
64+
os.makedirs(
65+
os.path.join(VOLUME_PATH_DATASET, path_excluded_first_last_parts),
66+
exist_ok=True,
67+
)
68+
69+
# Download the file
70+
file_path = os.path.sep.join(path_components[1:])
71+
bucket.download_file(obj_key, os.path.join(VOLUME_PATH_DATASET, file_path))
72+
print(f"Files downloaded")

0 commit comments

Comments
 (0)