Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def tune(
] = None,
retain_trials: bool = False,
packages_to_install: List[str] = None,
pip_index_url: str = "https://pypi.org/simple",
pip_index_urls : Optional[List[str]] = ["https://pypi.org/simple"],
metrics_collector_config: Dict[str, Any] = {"kind": "StdOut"},
):
"""
Expand Down Expand Up @@ -351,7 +351,7 @@ class name in this argument.
packages_to_install: List of Python packages to install in addition
to the base image packages. These packages are installed before
executing the objective function.
pip_index_url: The PyPI url from which to install Python packages.
pip_index_urls: List of PyPI urls from which to install Python packages.
metrics_collector_config: Specify the config of metrics collector,
for example, `metrics_collector_config = {"kind": "Push"}`.
Currently, we only support `StdOut` and `Push` metrics collector.
Expand Down Expand Up @@ -462,7 +462,7 @@ class name in this argument.
entrypoint,
input_params,
packages_to_install,
pip_index_url,
pip_index_urls,
)

# Generate container spec for PyTorchJob or Job.
Expand Down
26 changes: 26 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,20 @@ def create_experiment(
},
TEST_RESULT_SUCCESS,
),
(
"valid flow with pip_index_urls",
{
"name": "tune_test",
"objective": lambda x: print(f"a={x}"),
"parameters": {"a": katib.search.int(min=10, max=100)},
"packages_to_install": ["pandas", "numpy"],
"pip_index_urls": [
"https://pypi.org/simple",
"https://private-repo.com/simple",
],
},
TEST_RESULT_SUCCESS,
),
]


Expand Down Expand Up @@ -703,6 +717,18 @@ def test_tune(katib_client, test_name, kwargs, expected_output):
additional_metric_names=[],
)

elif test_name == "valid flow with pip_index_urls":
# Verify pip install command in container args.
args_content = "".join(
experiment.spec.trial_template.trial_spec.spec.template.spec.containers[
0
].args
)
assert (
"--index-url https://pypi.org/simple --extra-index-url https://private-repo.com/simple pandas numpy"
in args_content
)

except Exception as e:
assert type(e) is expected_output
print("test execution complete")
16 changes: 11 additions & 5 deletions sdk/python/v1beta1/kubeflow/katib/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,14 @@ def validate_objective_function(objective: Callable):
f"Current Objective arguments: {objective_signature}"
)

def format_pip_index_urls(pip_index_urls: List[str] = ["https://pypi.org/simple"]) -> str:
index_url = f'--index-url {pip_index_urls[0]}'
for url in pip_index_urls[1:]:
index_url += f' --extra-index-url {url}'
return index_url

def get_script_for_python_packages(packages_to_install, pip_index_url):

def get_script_for_python_packages(packages_to_install, pip_index_urls=["https://pypi.org/simple"]):
packages_str = " ".join([str(package) for package in packages_to_install])

script_for_python_packages = textwrap.dedent(
Expand All @@ -125,7 +131,7 @@ def get_script_for_python_packages(packages_to_install, pip_index_url):
fi

PIP_DISABLE_PIP_VERSION_CHECK=1 python3 -m pip install --prefer-binary --quiet \
--no-warn-script-location --index-url {pip_index_url} {packages_str}
--no-warn-script-location {format_pip_index_urls(pip_index_urls)} {packages_str}
"""
)

Expand Down Expand Up @@ -228,7 +234,7 @@ def get_exec_script_from_objective(
entrypoint: str,
input_params: Dict[str, Any],
packages_to_install: Optional[List[str]] = None,
pip_index_url: str = "https://pypi.org/simple",
pip_index_urls: Optional[List[str]] = ["https://pypi.org/simple"],
) -> str:
"""
Get executable script for container args from the given objective function and parameters.
Expand Down Expand Up @@ -272,7 +278,7 @@ def get_exec_script_from_objective(
# Install Python packages if that is required.
if packages_to_install is not None:
exec_script = (
get_script_for_python_packages(packages_to_install, pip_index_url)
get_script_for_python_packages(packages_to_install, pip_index_urls)
+ exec_script
)

Expand Down Expand Up @@ -350,4 +356,4 @@ def get_trial_template_with_pytorchjob(
trial_parameters=trial_parameters,
trial_spec=pytorchjob,
)
return trial_template
return trial_template
18 changes: 18 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/utils/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
from kubeflow.katib.utils import utils

@pytest.mark.parametrize(
"pip_index_urls, expected",
[
(["https://pypi.org/simple"],
"--index-url https://pypi.org/simple"),
(["https://pypi.org/simple", "https://private-repo.com/simple"],
"--index-url https://pypi.org/simple --extra-index-url https://private-repo.com/simple"),
(["https://pypi.org/simple", "https://private-repo.com/simple", "https://another-repo.com/simple"],
"--index-url https://pypi.org/simple --extra-index-url https://private-repo.com/simple --extra-index-url https://another-repo.com/simple"),
(None,
"--index-url https://pypi.org/simple"),
]
)
def test_format_pip_index_urls(pip_index_urls, expected):
assert utils.format_pip_index_urls(pip_index_urls) == expected
Loading