Skip to content

Commit 1136c27

Browse files
fix: add missing uv index env var for auth (#751)
1 parent 24cfa65 commit 1136c27

File tree

5 files changed

+288
-23
lines changed

5 files changed

+288
-23
lines changed

safety/tool/auth.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import base64
2+
import json
3+
4+
import typer
5+
6+
7+
def index_credentials(ctx: typer.Context) -> str:
8+
"""
9+
Returns the index credentials for the current context.
10+
This should be used together with user:index_credential for index
11+
basic auth.
12+
13+
Args:
14+
ctx (typer.Context): The context.
15+
16+
Returns:
17+
str: The index credentials.
18+
"""
19+
api_key = None
20+
token = None
21+
22+
if auth := getattr(ctx.obj, "auth", None):
23+
client = auth.client
24+
token = client.token.get("access_token") if client.token else None
25+
api_key = client.api_key
26+
27+
auth_envelop = json.dumps(
28+
{
29+
"version": "1.0",
30+
"access_token": token,
31+
"api_key": api_key,
32+
"project_id": ctx.obj.project.id if ctx.obj.project else None,
33+
}
34+
)
35+
return base64.urlsafe_b64encode(auth_envelop.encode("utf-8")).decode("utf-8")

safety/tool/pip/main.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import base64
2-
import json
31
import logging
42
import re
53
import shutil
@@ -19,6 +17,7 @@
1917
from safety.tool.resolver import get_unwrapped_command
2018

2119
from safety.console import main_console
20+
from safety.tool.auth import index_credentials
2221

2322

2423
logger = logging.getLogger(__name__)
@@ -137,26 +136,6 @@ def reset_system(cls, console: Console = main_console):
137136
except Exception:
138137
console.print("Failed to reset PIP global settings.")
139138

140-
@classmethod
141-
def index_credentials(cls, ctx: typer.Context):
142-
api_key = None
143-
token = None
144-
145-
if auth := getattr(ctx.obj, "auth", None):
146-
client = auth.client
147-
token = client.token.get("access_token") if client.token else None
148-
api_key = client.api_key
149-
150-
auth_envelop = json.dumps(
151-
{
152-
"version": "1.0",
153-
"access_token": token,
154-
"api_key": api_key,
155-
"project_id": ctx.obj.project.id if ctx.obj.project else None,
156-
}
157-
)
158-
return base64.urlsafe_b64encode(auth_envelop.encode("utf-8")).decode("utf-8")
159-
160139
@classmethod
161140
def default_index_url(cls) -> str:
162141
return "https://pypi.org/simple/"
@@ -168,7 +147,7 @@ def build_index_url(cls, ctx: typer.Context, index_url: Optional[str]) -> str:
168147

169148
url = urlsplit(index_url)
170149

171-
encoded_auth = cls.index_credentials(ctx)
150+
encoded_auth = index_credentials(ctx)
172151
netloc = f"user:{encoded_auth}@{url.netloc}"
173152

174153
if type(url.netloc) is bytes:

safety/tool/uv/command.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from typing import List
2+
3+
import typer
24
from safety.tool.intents import ToolIntentionType
35
from safety.tool.pip.parser import PipParser
6+
from safety.tool.auth import index_credentials
47
from ..pip.command import PipCommand, PipInstallCommand, PipGenericCommand
58
from safety_schemas.models.events.types import ToolType
69

@@ -35,6 +38,18 @@ def should_track_state(self) -> bool:
3538

3639
return any(cmd in command_str for cmd in package_modifying_commands)
3740

41+
def env(self, ctx: typer.Context) -> dict:
42+
env = super().env(ctx)
43+
44+
env.update(
45+
{
46+
"UV_INDEX_SAFETY_USERNAME": "user",
47+
"UV_INDEX_SAFETY_PASSWORD": index_credentials(ctx),
48+
}
49+
)
50+
51+
return env
52+
3853
@classmethod
3954
def from_args(cls, args: List[str], **kwargs):
4055
pip_parser = PipParser()

tests/tool/test_auth.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# type: ignore
2+
import unittest
3+
import json
4+
import base64
5+
from unittest.mock import MagicMock
6+
7+
import typer
8+
from safety.tool.auth import index_credentials
9+
10+
11+
class TestIndexCredentials(unittest.TestCase):
12+
"""
13+
Test cases for index_credentials function.
14+
"""
15+
16+
def test_index_credentials_with_full_auth_object(self):
17+
"""
18+
Test index_credentials when ctx.obj.auth is fully populated with token and api_key.
19+
"""
20+
ctx = MagicMock(spec=typer.Context)
21+
ctx.obj = MagicMock()
22+
ctx.obj.auth.client.token = {"access_token": "test_token"}
23+
ctx.obj.auth.client.api_key = "test_api_key"
24+
ctx.obj.project.id = "test_project_id"
25+
26+
result = index_credentials(ctx)
27+
28+
decoded = json.loads(
29+
base64.urlsafe_b64decode(result.encode("utf-8")).decode("utf-8")
30+
)
31+
32+
self.assertEqual(decoded["version"], "1.0")
33+
self.assertEqual(decoded["access_token"], "test_token")
34+
self.assertEqual(decoded["api_key"], "test_api_key")
35+
self.assertEqual(decoded["project_id"], "test_project_id")
36+
37+
def test_index_credentials_with_missing_token(self):
38+
"""
39+
Test index_credentials when token is None but api_key is present.
40+
"""
41+
ctx = MagicMock(spec=typer.Context)
42+
ctx.obj = MagicMock()
43+
ctx.obj.auth.client.token = None
44+
ctx.obj.auth.client.api_key = "test_api_key"
45+
ctx.obj.project.id = "test_project_id"
46+
47+
result = index_credentials(ctx)
48+
49+
decoded = json.loads(
50+
base64.urlsafe_b64decode(result.encode("utf-8")).decode("utf-8")
51+
)
52+
self.assertEqual(decoded["version"], "1.0")
53+
self.assertIsNone(decoded["access_token"])
54+
self.assertEqual(decoded["api_key"], "test_api_key")
55+
self.assertEqual(decoded["project_id"], "test_project_id")
56+
57+
def test_index_credentials_with_missing_api_key(self):
58+
"""
59+
Test index_credentials when api_key is None but token is present.
60+
"""
61+
ctx = MagicMock(spec=typer.Context)
62+
ctx.obj = MagicMock()
63+
ctx.obj.auth.client.token = {"access_token": "test_token"}
64+
ctx.obj.auth.client.api_key = None
65+
ctx.obj.project.id = "test_project_id"
66+
67+
result = index_credentials(ctx)
68+
69+
decoded = json.loads(
70+
base64.urlsafe_b64decode(result.encode("utf-8")).decode("utf-8")
71+
)
72+
73+
self.assertEqual(decoded["version"], "1.0")
74+
self.assertEqual(decoded["access_token"], "test_token")
75+
self.assertIsNone(decoded["api_key"])
76+
self.assertEqual(decoded["project_id"], "test_project_id")
77+
78+
def test_index_credentials_with_no_auth(self):
79+
"""
80+
Test index_credentials when ctx.obj.auth is None.
81+
"""
82+
83+
ctx = MagicMock(spec=typer.Context)
84+
ctx.obj = MagicMock()
85+
ctx.obj.auth = None
86+
ctx.obj.project.id = "test_project_id"
87+
88+
result = index_credentials(ctx)
89+
90+
decoded = json.loads(
91+
base64.urlsafe_b64decode(result.encode("utf-8")).decode("utf-8")
92+
)
93+
94+
self.assertEqual(decoded["version"], "1.0")
95+
self.assertIsNone(decoded["access_token"])
96+
self.assertIsNone(decoded["api_key"])
97+
self.assertEqual(decoded["project_id"], "test_project_id")
98+
99+
def test_index_credentials_with_no_project(self):
100+
"""
101+
Test index_credentials when ctx.obj.project is None.
102+
"""
103+
104+
ctx = MagicMock(spec=typer.Context)
105+
ctx.obj = MagicMock()
106+
ctx.obj.auth.client.token = {"access_token": "test_token"}
107+
ctx.obj.auth.client.api_key = "test_api_key"
108+
ctx.obj.project = None
109+
110+
result = index_credentials(ctx)
111+
112+
decoded = json.loads(
113+
base64.urlsafe_b64decode(result.encode("utf-8")).decode("utf-8")
114+
)
115+
116+
self.assertEqual(decoded["version"], "1.0")
117+
self.assertEqual(decoded["access_token"], "test_token")
118+
self.assertEqual(decoded["api_key"], "test_api_key")
119+
self.assertIsNone(decoded["project_id"])
120+
121+
def test_index_credentials_correct_encoding(self):
122+
"""
123+
Test that index_credentials correctly encodes the credentials in base64url format.
124+
"""
125+
126+
ctx = MagicMock(spec=typer.Context)
127+
ctx.obj = MagicMock()
128+
ctx.obj.auth.client.token = {"access_token": "test_token"}
129+
ctx.obj.auth.client.api_key = "test_api_key"
130+
ctx.obj.project.id = "test_project_id"
131+
132+
result = index_credentials(ctx)
133+
134+
expected_json = json.dumps(
135+
{
136+
"version": "1.0",
137+
"access_token": "test_token",
138+
"api_key": "test_api_key",
139+
"project_id": "test_project_id",
140+
}
141+
)
142+
143+
expected_encoded = base64.urlsafe_b64encode(
144+
expected_json.encode("utf-8")
145+
).decode("utf-8")
146+
147+
self.assertEqual(result, expected_encoded)

tests/tool/uv/test_uv_command.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# type: ignore
2+
import unittest
3+
from unittest.mock import patch, MagicMock
4+
5+
import typer
6+
from safety.tool.uv.command import UvCommand
7+
8+
9+
class TestUvCommand(unittest.TestCase):
10+
"""
11+
Test cases for UvCommand functionality.
12+
"""
13+
14+
def setUp(self):
15+
"""
16+
Set up test environment before each test method.
17+
"""
18+
self.command = UvCommand(["uv", "pip", "install", "package"])
19+
20+
def test_env_preserves_existing_variables(self):
21+
"""
22+
Test that env() method does not replace existing environment variables.
23+
"""
24+
ctx = MagicMock(spec=typer.Context)
25+
26+
existing_env = {
27+
"EXISTING_VAR": "existing_value",
28+
"ANOTHER_VAR": "another_value",
29+
}
30+
31+
with patch(
32+
"safety.tool.pip.command.PipCommand.env", return_value=existing_env
33+
) as mock_super_env:
34+
with patch(
35+
"safety.tool.uv.command.index_credentials",
36+
return_value="mock_credentials",
37+
):
38+
result_env = self.command.env(ctx)
39+
40+
self.assertEqual(result_env["EXISTING_VAR"], "existing_value")
41+
self.assertEqual(result_env["ANOTHER_VAR"], "another_value")
42+
mock_super_env.assert_called_once_with(ctx)
43+
44+
def test_env_adds_uv_credentials_variables(self):
45+
"""
46+
Test that env() method always adds UV_INDEX_SAFETY_USERNAME and
47+
UV_INDEX_SAFETY_PASSWORD environment variables.
48+
"""
49+
ctx = MagicMock(spec=typer.Context)
50+
51+
mock_credentials = "mock_credentials_value"
52+
53+
with patch(
54+
"safety.tool.uv.command.index_credentials", return_value=mock_credentials
55+
):
56+
with patch("safety.tool.pip.command.PipCommand.env", return_value={}):
57+
result_env = self.command.env(ctx)
58+
59+
self.assertIn("UV_INDEX_SAFETY_USERNAME", result_env)
60+
self.assertIn("UV_INDEX_SAFETY_PASSWORD", result_env)
61+
self.assertEqual(result_env["UV_INDEX_SAFETY_USERNAME"], "user")
62+
self.assertEqual(
63+
result_env["UV_INDEX_SAFETY_PASSWORD"], mock_credentials
64+
)
65+
66+
def test_env_combines_parent_env_with_uv_credentials(self):
67+
"""
68+
Test that env() method properly combines parent environment with UV credentials.
69+
"""
70+
ctx = MagicMock(spec=typer.Context)
71+
72+
existing_env = {"EXISTING_VAR": "existing_value", "PATH": "/usr/bin:/bin"}
73+
74+
mock_credentials = "mock_credentials_value"
75+
76+
with patch(
77+
"safety.tool.uv.command.index_credentials", return_value=mock_credentials
78+
):
79+
with patch(
80+
"safety.tool.pip.command.PipCommand.env", return_value=existing_env
81+
):
82+
result_env = self.command.env(ctx)
83+
84+
self.assertEqual(result_env["EXISTING_VAR"], "existing_value")
85+
self.assertEqual(result_env["PATH"], "/usr/bin:/bin")
86+
self.assertEqual(result_env["UV_INDEX_SAFETY_USERNAME"], "user")
87+
self.assertEqual(
88+
result_env["UV_INDEX_SAFETY_PASSWORD"], mock_credentials
89+
)

0 commit comments

Comments
 (0)