16
16
import multiprocessing
17
17
import textwrap
18
18
import time
19
- from typing import Any , Callable , Dict , List , Optional
19
+ from typing import Any , Callable , Dict , List , Optional , Union
20
20
21
21
import grpc
22
22
import kubeflow .katib .katib_api_pb2 as katib_api_pb2
@@ -144,6 +144,7 @@ def tune(
144
144
max_trial_count : int = None ,
145
145
parallel_trial_count : int = None ,
146
146
max_failed_trial_count : int = None ,
147
+ resources_per_trial : Union [dict , client .V1ResourceRequirements , None ] = None ,
147
148
retain_trials : bool = False ,
148
149
packages_to_install : List [str ] = None ,
149
150
pip_index_url : str = "https://pypi.org/simple" ,
@@ -177,6 +178,21 @@ def tune(
177
178
values check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#configuration-spec.
178
179
parallel_trial_count: Number of Trials that Experiment runs in parallel.
179
180
max_failed_trial_count: Maximum number of Trials allowed to fail.
181
+ resources_per_trial: A parameter that lets you specify how many
182
+ resources each trial container should have. You can either specify a
183
+ kubernetes.client.V1ResourceRequirements object (documented here:
184
+ https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ResourceRequirements.md)
185
+ or a dictionary that includes one or more of the following keys:
186
+ `cpu`, `memory`, or `gpu` (other keys are passed through unchanged). Appropriate
187
+ values for these keys are documented here:
188
+ https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
189
+ For example:
190
+ {
191
+ "cpu": "1",
192
+ "gpu": "1",
193
+ "memory": "2Gi",
194
+ }
195
+ This parameter is optional and defaults to None.
180
196
retain_trials: Whether Trials' resources (e.g. pods) are deleted after Succeeded state.
181
197
packages_to_install: List of Python packages to install in addition
182
198
to the base image packages. These packages are installed before
@@ -280,6 +296,20 @@ def tune(
280
296
+ exec_script
281
297
)
282
298
299
+ if isinstance (resources_per_trial , dict ):
300
+ requests = {
301
+ "cpu" : "200m" ,
302
+ "memory" : "256Mi" ,
303
+ }
304
+ if "gpu" in resources_per_trial :
305
+ resources_per_trial ["nvidia.com/gpu" ] = resources_per_trial .pop ("gpu" )
306
+ requests .update (resources_per_trial )
307
+
308
+ resources_per_trial = client .V1ResourceRequirements (
309
+ requests = requests ,
310
+ limits = requests ,
311
+ )
312
+
283
313
# Create Trial specification.
284
314
trial_spec = client .V1Job (
285
315
api_version = "batch/v1" ,
@@ -297,6 +327,7 @@ def tune(
297
327
image = base_image ,
298
328
command = ["bash" , "-c" ],
299
329
args = [exec_script ],
330
+ resources = resources_per_trial ,
300
331
)
301
332
],
302
333
),
0 commit comments