16
16
import multiprocessing
17
17
import textwrap
18
18
import time
19
- from typing import Any , Callable , Dict , List , Optional
19
+ from typing import Any , Callable , Dict , List , Optional , Union
20
20
21
21
import grpc
22
22
import kubeflow .katib .katib_api_pb2 as katib_api_pb2
@@ -147,6 +147,7 @@ def tune(
147
147
retain_trials : bool = False ,
148
148
packages_to_install : List [str ] = None ,
149
149
pip_index_url : str = "https://pypi.org/simple" ,
150
+ resources_per_trial : Union [dict , client .V1ResourceRequirements , None ] = None ,
150
151
):
151
152
"""Create HyperParameter Tuning Katib Experiment from the objective function.
152
153
@@ -182,6 +183,20 @@ def tune(
182
183
to the base image packages. These packages are installed before
183
184
executing the objective function.
184
185
pip_index_url: The PyPI url from which to install Python packages.
186
+ resources_per trial: A parameter that lets you specify how much
187
+ resources each trial container should have. You can either specify a
188
+ kubernetes.client.V1ResourceRequirements object (documented here:
189
+ https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ResourceRequirements.md)
190
+ or a dictionary that includes one or more of the following keys:
191
+ `cpu`, `memory`, or `gpu` (other keys will be ignored). Appropriate
192
+ values for these keys are documented here:
193
+ https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
194
+ For example:
195
+ {
196
+ "cpu": "1",
197
+ "memory": "2Gi",
198
+ }
199
+ This parameter is optional and defaults to None.
185
200
186
201
Raises:
187
202
ValueError: Objective function has invalid arguments.
@@ -280,6 +295,23 @@ def tune(
280
295
+ exec_script
281
296
)
282
297
298
+ resources = client .V1ResourceRequirements ()
299
+ if isinstance (resources_per_trial , dict ):
300
+ requests = {
301
+ "cpu" : "200m" ,
302
+ "memory" : "256Mi" ,
303
+ }
304
+ if "gpu" in resources_per_trial :
305
+ resources_per_trial ["nvidia.com/gpu" ] = resources_per_trial .pop ("gpu" )
306
+ requests .update (resources_per_trial )
307
+
308
+ resources = client .V1ResourceRequirements (
309
+ requests = requests ,
310
+ limits = requests ,
311
+ )
312
+ else :
313
+ resources = resources_per_trial
314
+
283
315
# Create Trial specification.
284
316
trial_spec = client .V1Job (
285
317
api_version = "batch/v1" ,
@@ -297,6 +329,7 @@ def tune(
297
329
image = base_image ,
298
330
command = ["bash" , "-c" ],
299
331
args = [exec_script ],
332
+ resources = resources ,
300
333
)
301
334
],
302
335
),
0 commit comments