|  | 
|  | 1 | +# Copyright 2025 Google LLC | 
|  | 2 | +# | 
|  | 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); | 
|  | 4 | +# you may not use this file except in compliance with the License. | 
|  | 5 | +# You may obtain a copy of the License at | 
|  | 6 | +# | 
|  | 7 | +#     http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 8 | +# | 
|  | 9 | +# Unless required by applicable law or agreed to in writing, software | 
|  | 10 | +# distributed under the License is distributed on an "AS IS" BASIS, | 
|  | 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
|  | 12 | +# See the License for the specific language governing permissions and | 
|  | 13 | +# limitations under the License. | 
|  | 14 | +# | 
|  | 15 | + | 
|  | 16 | +"""Tests for closing the clients and context managers.""" | 
|  | 17 | +import asyncio | 
|  | 18 | +from unittest import mock | 
|  | 19 | + | 
|  | 20 | +from google.oauth2 import credentials | 
|  | 21 | +import pytest | 
|  | 22 | +try: | 
|  | 23 | +  import aiohttp | 
|  | 24 | +  AIOHTTP_NOT_INSTALLED = False | 
|  | 25 | +except ImportError: | 
|  | 26 | +  AIOHTTP_NOT_INSTALLED = True | 
|  | 27 | +  aiohttp = mock.MagicMock() | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +from ... import _api_client as api_client | 
|  | 31 | +from ... import Client | 
|  | 32 | + | 
|  | 33 | + | 
|  | 34 | +requires_aiohttp = pytest.mark.skipif( | 
|  | 35 | +    AIOHTTP_NOT_INSTALLED, reason='aiohttp is not installed, skipping test.' | 
|  | 36 | +) | 
|  | 37 | + | 
|  | 38 | + | 
|  | 39 | +def test_close_httpx_client(): | 
|  | 40 | +  """Tests that the httpx client is closed when the client is closed.""" | 
|  | 41 | +  api_client.has_aiohttp = False | 
|  | 42 | +  client = Client( | 
|  | 43 | +      vertexai=True, | 
|  | 44 | +      project='test_project', | 
|  | 45 | +      location='global', | 
|  | 46 | +  ) | 
|  | 47 | +  client.close() | 
|  | 48 | +  assert client._api_client._httpx_client.is_closed | 
|  | 49 | + | 
|  | 50 | + | 
|  | 51 | +def test_httpx_client_context_manager(): | 
|  | 52 | +  """Tests that the httpx client is closed when the client is closed.""" | 
|  | 53 | +  api_client.has_aiohttp = False | 
|  | 54 | +  with Client( | 
|  | 55 | +      vertexai=True, | 
|  | 56 | +      project='test_project', | 
|  | 57 | +      location='global', | 
|  | 58 | +  ) as client: | 
|  | 59 | +    pass | 
|  | 60 | +    assert not client._api_client._httpx_client.is_closed | 
|  | 61 | + | 
|  | 62 | +  assert client._api_client._httpx_client.is_closed | 
|  | 63 | + | 
|  | 64 | + | 
|  | 65 | +@pytest.mark.asyncio | 
|  | 66 | +async def test_aclose_httpx_client(): | 
|  | 67 | +  """Tests that the httpx async client is closed when the client is closed.""" | 
|  | 68 | +  api_client.has_aiohttp = False | 
|  | 69 | +  async_client = Client( | 
|  | 70 | +      vertexai=True, | 
|  | 71 | +      project='test_project', | 
|  | 72 | +      location='global', | 
|  | 73 | +  ).aio | 
|  | 74 | +  await async_client.aclose() | 
|  | 75 | +  assert async_client._api_client._async_httpx_client.is_closed | 
|  | 76 | + | 
|  | 77 | + | 
|  | 78 | +@pytest.mark.asyncio | 
|  | 79 | +async def test_async_httpx_client_context_manager(): | 
|  | 80 | +  """Tests that the httpx async client is closed when the client is closed.""" | 
|  | 81 | +  api_client.has_aiohttp = False | 
|  | 82 | +  async with Client( | 
|  | 83 | +      vertexai=True, | 
|  | 84 | +      project='test_project', | 
|  | 85 | +      location='global', | 
|  | 86 | +  ).aio as async_client: | 
|  | 87 | +    pass | 
|  | 88 | +    assert not async_client._api_client._async_httpx_client.is_closed | 
|  | 89 | + | 
|  | 90 | +  assert async_client._api_client._async_httpx_client.is_closed | 
|  | 91 | + | 
|  | 92 | + | 
|  | 93 | +@requires_aiohttp | 
|  | 94 | +@pytest.mark.asyncio | 
|  | 95 | +async def test_aclose_aiohttp_session(): | 
|  | 96 | +  """Tests that the aiohttp session is closed when the client is closed.""" | 
|  | 97 | +  api_client.has_aiohttp = True | 
|  | 98 | +  async_client = Client( | 
|  | 99 | +      vertexai=True, | 
|  | 100 | +      project='test_project', | 
|  | 101 | +      location='global', | 
|  | 102 | +  ).aio | 
|  | 103 | +  await async_client.aclose() | 
|  | 104 | +  assert async_client._api_client._aiohttp_session is None | 
|  | 105 | + | 
|  | 106 | + | 
|  | 107 | +@requires_aiohttp | 
|  | 108 | +@pytest.fixture | 
|  | 109 | +def mock_request(): | 
|  | 110 | +  mock_aiohttp_response = mock.Mock(spec=aiohttp.ClientSession.request) | 
|  | 111 | +  mock_aiohttp_response.return_value = mock_aiohttp_response | 
|  | 112 | +  yield mock_aiohttp_response | 
|  | 113 | + | 
|  | 114 | + | 
|  | 115 | +def _patch_auth_default(): | 
|  | 116 | +  return mock.patch( | 
|  | 117 | +      'google.auth.default', | 
|  | 118 | +      return_value=(credentials.Credentials('magic_token'), 'test_project'), | 
|  | 119 | +      autospec=True, | 
|  | 120 | +  ) | 
|  | 121 | + | 
|  | 122 | + | 
|  | 123 | +async def _aiohttp_async_response(status: int): | 
|  | 124 | +  """Has to return a coroutine hence async.""" | 
|  | 125 | +  response = mock.Mock(spec=aiohttp.ClientResponse) | 
|  | 126 | +  response.status = status | 
|  | 127 | +  response.headers = {'status-code': str(status)} | 
|  | 128 | +  response.json.return_value = {} | 
|  | 129 | +  response.text.return_value = 'test' | 
|  | 130 | +  return response | 
|  | 131 | + | 
|  | 132 | + | 
|  | 133 | +@requires_aiohttp | 
|  | 134 | +@mock.patch.object(aiohttp.ClientSession, 'request', autospec=True) | 
|  | 135 | +def test_aiohttp_session_context_manager(mock_request): | 
|  | 136 | +  """Tests that the aiohttp session is closed when the client is closed.""" | 
|  | 137 | +  api_client.has_aiohttp = True | 
|  | 138 | +  async def run(): | 
|  | 139 | +    mock_request.side_effect = ( | 
|  | 140 | +        aiohttp.ClientConnectorError( | 
|  | 141 | +            connection_key=aiohttp.client_reqrep.ConnectionKey( | 
|  | 142 | +                'localhost', 80, False, True, None, None, None | 
|  | 143 | +            ), | 
|  | 144 | +            os_error=OSError, | 
|  | 145 | +        ), | 
|  | 146 | +        _aiohttp_async_response(200), | 
|  | 147 | +    ) | 
|  | 148 | +    with _patch_auth_default(): | 
|  | 149 | +      async with Client( | 
|  | 150 | +          vertexai=True, | 
|  | 151 | +          project='test_project', | 
|  | 152 | +          location='global', | 
|  | 153 | +      ).aio as async_client: | 
|  | 154 | +        # aiohttp session is created in the first request instead of client | 
|  | 155 | +        # initialization. | 
|  | 156 | +        _ = await async_client._api_client._async_request_once( | 
|  | 157 | +            api_client.HttpRequest( | 
|  | 158 | +                method='GET', | 
|  | 159 | +                url='https://example.com', | 
|  | 160 | +                headers={}, | 
|  | 161 | +                data=None, | 
|  | 162 | +                timeout=None, | 
|  | 163 | +            ) | 
|  | 164 | +        ) | 
|  | 165 | +        assert async_client._api_client._aiohttp_session is not None | 
|  | 166 | +        assert not async_client._api_client._aiohttp_session.closed | 
|  | 167 | + | 
|  | 168 | +      assert async_client._api_client._aiohttp_session.closed | 
|  | 169 | + | 
|  | 170 | +  asyncio.run(run()) | 
0 commit comments