Skip to content

Commit abd614d

Browse files
[SDK] Enable resource specification for trial containers
Co-authored-by: shipengcheng1230 <[email protected]>
1 parent c749d27 commit abd614d

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import multiprocessing
1717
import textwrap
1818
import time
19-
from typing import Any, Callable, Dict, List, Optional
19+
from typing import Any, Callable, Dict, List, Optional, Union
2020

2121
import grpc
2222
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
@@ -147,6 +147,7 @@ def tune(
147147
retain_trials: bool = False,
148148
packages_to_install: List[str] = None,
149149
pip_index_url: str = "https://pypi.org/simple",
150+
resources_per_trial: Union[dict, client.V1ResourceRequirements, None] = None,
150151
):
151152
"""Create HyperParameter Tuning Katib Experiment from the objective function.
152153
@@ -182,6 +183,20 @@ def tune(
182183
to the base image packages. These packages are installed before
183184
executing the objective function.
184185
pip_index_url: The PyPI url from which to install Python packages.
186+
resources_per trial: A parameter that lets you specify how much
187+
resources each trial container should have. You can either specify a
188+
kubernetes.client.V1ResourceRequirements object (documented here:
189+
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ResourceRequirements.md)
190+
or a dictionary that includes one or more of the following keys:
191+
`cpu`, `memory`, or `gpu` (other keys will be ignored). Appropriate
192+
values for these keys are documented here:
193+
https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
194+
For example:
195+
{
196+
"cpu": "1",
197+
"memory": "2Gi",
198+
}
199+
This parameter is optional and defaults to None.
185200
186201
Raises:
187202
ValueError: Objective function has invalid arguments.
@@ -280,6 +295,23 @@ def tune(
280295
+ exec_script
281296
)
282297

298+
resources = client.V1ResourceRequirements()
299+
if isinstance(resources_per_trial, dict):
300+
requests = {
301+
"cpu": "200m",
302+
"memory": "256Mi",
303+
}
304+
if "gpu" in resources_per_trial:
305+
resources_per_trial["nvidia.com/gpu"] = resources_per_trial.pop("gpu")
306+
requests.update(resources_per_trial)
307+
308+
resources = client.V1ResourceRequirements(
309+
requests=requests,
310+
limits=requests,
311+
)
312+
else:
313+
resources = resources_per_trial
314+
283315
# Create Trial specification.
284316
trial_spec = client.V1Job(
285317
api_version="batch/v1",
@@ -297,6 +329,7 @@ def tune(
297329
image=base_image,
298330
command=["bash", "-c"],
299331
args=[exec_script],
332+
resources=resources,
300333
)
301334
],
302335
),

0 commit comments

Comments
 (0)