Skip to content

Commit 11eeebb

Browse files
feat(sdk): Support multiple pip index URLs. (#2566)
* feat: add extra index urls to the katib tune Signed-off-by: Andrey Velichkevich <[email protected]> * Remove test with None index Signed-off-by: Andrey Velichkevich <[email protected]> --------- Signed-off-by: Andrey Velichkevich <[email protected]> Co-authored-by: Andrey Velichkevich <[email protected]>
1 parent 7ab0939 commit 11eeebb

File tree

4 files changed

+98
-29
lines changed

4 files changed

+98
-29
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def tune(
215215
] = None,
216216
retain_trials: bool = False,
217217
packages_to_install: List[str] = None,
218-
pip_index_url: str = "https://pypi.org/simple",
218+
pip_index_urls: Optional[List[str]] = ["https://pypi.org/simple"],
219219
metrics_collector_config: Dict[str, Any] = {"kind": "StdOut"},
220220
trial_active_deadline_seconds: Optional[int] = None,
221221
):
@@ -353,7 +353,7 @@ class name in this argument.
353353
packages_to_install: List of Python packages to install in addition
354354
to the base image packages. These packages are installed before
355355
executing the objective function.
356-
pip_index_url: The PyPI url from which to install Python packages.
356+
pip_index_urls: List of PyPI urls from which to install Python packages.
357357
metrics_collector_config: Specify the config of metrics collector,
358358
for example, `metrics_collector_config = {"kind": "Push"}`.
359359
Currently, we only support `StdOut` and `Push` metrics collector.
@@ -469,7 +469,7 @@ class name in this argument.
469469
entrypoint,
470470
input_params,
471471
packages_to_install,
472-
pip_index_url,
472+
pip_index_urls,
473473
)
474474

475475
# Generate container spec for PyTorchJob or Job.

sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,20 @@ def create_experiment(
569569
},
570570
TEST_RESULT_SUCCESS,
571571
),
572+
(
573+
"valid flow with pip_index_urls",
574+
{
575+
"name": "tune_test",
576+
"objective": lambda x: print(f"a={x}"),
577+
"parameters": {"a": katib.search.int(min=10, max=100)},
578+
"packages_to_install": ["pandas", "numpy"],
579+
"pip_index_urls": [
580+
"https://pypi.org/simple",
581+
"https://private-repo.com/simple",
582+
],
583+
},
584+
TEST_RESULT_SUCCESS,
585+
),
572586
]
573587

574588

@@ -682,32 +696,39 @@ def read_namespaced_pod_log_response(*args, **kwargs):
682696

683697
@pytest.fixture
684698
def katib_client():
685-
with patch(
686-
"kubernetes.client.CustomObjectsApi",
687-
return_value=Mock(
688-
create_namespaced_custom_object=Mock(
689-
side_effect=create_namespaced_custom_object_response
690-
),
691-
get_namespaced_custom_object=Mock(
692-
side_effect=get_namespaced_custom_object_response
699+
with (
700+
patch(
701+
"kubernetes.client.CustomObjectsApi",
702+
return_value=Mock(
703+
create_namespaced_custom_object=Mock(
704+
side_effect=create_namespaced_custom_object_response
705+
),
706+
get_namespaced_custom_object=Mock(
707+
side_effect=get_namespaced_custom_object_response
708+
),
693709
),
694710
),
695-
), patch("kubernetes.config.load_kube_config", return_value=Mock()), patch(
696-
"kubeflow.katib.katib_api_pb2_grpc.DBManagerStub",
697-
return_value=Mock(
698-
GetObservationLog=Mock(side_effect=get_observation_log_response)
699-
),
700-
), patch(
701-
"kubernetes.client.CoreV1Api",
702-
return_value=Mock(
703-
create_namespaced_persistent_volume_claim=Mock(
704-
side_effect=create_namespaced_persistent_volume_claim_response
711+
patch("kubernetes.config.load_kube_config", return_value=Mock()),
712+
patch(
713+
"kubeflow.katib.katib_api_pb2_grpc.DBManagerStub",
714+
return_value=Mock(
715+
GetObservationLog=Mock(side_effect=get_observation_log_response)
705716
),
706-
list_namespaced_persistent_volume_claim=Mock(
707-
side_effect=list_namespaced_persistent_volume_claim_response
717+
),
718+
patch(
719+
"kubernetes.client.CoreV1Api",
720+
return_value=Mock(
721+
create_namespaced_persistent_volume_claim=Mock(
722+
side_effect=create_namespaced_persistent_volume_claim_response
723+
),
724+
list_namespaced_persistent_volume_claim=Mock(
725+
side_effect=list_namespaced_persistent_volume_claim_response
726+
),
727+
list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response),
728+
read_namespaced_pod_log=Mock(
729+
side_effect=read_namespaced_pod_log_response
730+
),
708731
),
709-
list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response),
710-
read_namespaced_pod_log=Mock(side_effect=read_namespaced_pod_log_response),
711732
),
712733
):
713734
client = KatibClient()
@@ -860,6 +881,18 @@ def test_tune(katib_client, test_name, kwargs, expected_output):
860881
additional_metric_names=[],
861882
)
862883

884+
elif test_name == "valid flow with pip_index_urls":
885+
# Verify pip install command in container args.
886+
args_content = "".join(
887+
experiment.spec.trial_template.trial_spec.spec.template.spec.containers[
888+
0
889+
].args
890+
)
891+
assert (
892+
"--index-url https://pypi.org/simple --extra-index-url "
893+
"https://private-repo.com/simple pandas numpy" in args_content
894+
)
895+
863896
except Exception as e:
864897
assert type(e) is expected_output
865898
print("test execution complete")

sdk/python/v1beta1/kubeflow/katib/utils/utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,18 @@ def validate_objective_function(objective: Callable):
115115
)
116116

117117

118-
def get_script_for_python_packages(packages_to_install, pip_index_url):
118+
def format_pip_index_urls(
119+
pip_index_urls: List[str] = ["https://pypi.org/simple"],
120+
) -> str:
121+
index_url = f"--index-url {pip_index_urls[0]}"
122+
for url in pip_index_urls[1:]:
123+
index_url += f" --extra-index-url {url}"
124+
return index_url
125+
126+
127+
def get_script_for_python_packages(
128+
packages_to_install, pip_index_urls=["https://pypi.org/simple"]
129+
):
119130
packages_str = " ".join([str(package) for package in packages_to_install])
120131

121132
script_for_python_packages = textwrap.dedent(
@@ -125,7 +136,7 @@ def get_script_for_python_packages(packages_to_install, pip_index_url):
125136
fi
126137
127138
PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --prefer-binary --quiet \
128-
--no-warn-script-location --index-url {pip_index_url} {packages_str}
139+
--no-warn-script-location {format_pip_index_urls(pip_index_urls)} {packages_str}
129140
"""
130141
)
131142

@@ -233,7 +244,7 @@ def get_exec_script_from_objective(
233244
entrypoint: str,
234245
input_params: Dict[str, Any],
235246
packages_to_install: Optional[List[str]] = None,
236-
pip_index_url: str = "https://pypi.org/simple",
247+
pip_index_urls: Optional[List[str]] = ["https://pypi.org/simple"],
237248
) -> str:
238249
"""
239250
Get executable script for container args from the given objective function and parameters.
@@ -277,7 +288,7 @@ def get_exec_script_from_objective(
277288
# Install Python packages if that is required.
278289
if packages_to_install is not None:
279290
exec_script = (
280-
get_script_for_python_packages(packages_to_install, pip_index_url)
291+
get_script_for_python_packages(packages_to_install, pip_index_urls)
281292
+ exec_script
282293
)
283294

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import pytest
2+
from kubeflow.katib.utils import utils
3+
4+
5+
@pytest.mark.parametrize(
6+
"pip_index_urls, expected",
7+
[
8+
(["https://pypi.org/simple"], "--index-url https://pypi.org/simple"),
9+
(
10+
["https://pypi.org/simple", "https://private-repo.com/simple"],
11+
"--index-url https://pypi.org/simple --extra-index-url https://private-repo.com/simple",
12+
),
13+
(
14+
[
15+
"https://pypi.org/simple",
16+
"https://private-repo.com/simple",
17+
"https://another-repo.com/simple",
18+
],
19+
"--index-url https://pypi.org/simple --extra-index-url https://private-repo.com/simple "
20+
"--extra-index-url https://another-repo.com/simple",
21+
),
22+
],
23+
)
24+
def test_format_pip_index_urls(pip_index_urls, expected):
25+
assert utils.format_pip_index_urls(pip_index_urls) == expected

0 commit comments

Comments
 (0)