Skip to content

Commit b1a56ba

Browse files
authored
fix: resolve issue where rest transport is not used in certain tests (#1231)
1 parent 9801fde commit b1a56ba

File tree

8 files changed

+348
-141
lines changed

8 files changed

+348
-141
lines changed

packages/gapic-generator/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,32 @@ def test__get_default_mtls_endpoint():
8585
assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi
8686

8787

88-
@pytest.mark.parametrize("client_class", [
89-
{{ service.client_name }},
88+
@pytest.mark.parametrize("client_class,transport_name", [
89+
{% if 'grpc' in opts.transport %}
90+
({{ service.client_name }}, "grpc"),
91+
{% endif %}
92+
{% if 'rest' in opts.transport %}
93+
({{ service.client_name }}, "rest"),
94+
{% endif %}
9095
])
91-
def test_{{ service.client_name|snake_case }}_from_service_account_info(client_class):
96+
def test_{{ service.client_name|snake_case }}_from_service_account_info(client_class, transport_name):
9297
creds = ga_credentials.AnonymousCredentials()
9398
with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory:
9499
factory.return_value = creds
95100
info = {"valid": True}
96-
client = client_class.from_service_account_info(info)
101+
client = client_class.from_service_account_info(info, transport=transport_name)
97102
assert client.transport._credentials == creds
98103
assert isinstance(client, client_class)
99104

100105
{% if service.host %}
101-
assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'
106+
assert client.transport._host == (
107+
'{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'
108+
{% if 'rest' in opts.transport %}
109+
if transport_name in ['grpc', 'grpc_asyncio']
110+
else
111+
'https://{{ service.host }}'
112+
{% endif %}
113+
)
102114
{% endif %}
103115

104116

@@ -122,23 +134,35 @@ def test_{{ service.client_name|snake_case }}_service_account_always_use_jwt(tra
122134
use_jwt.assert_not_called()
123135

124136

125-
@pytest.mark.parametrize("client_class", [
126-
{{ service.client_name }},
137+
@pytest.mark.parametrize("client_class,transport_name", [
138+
{% if 'grpc' in opts.transport %}
139+
({{ service.client_name }}, "grpc"),
140+
{% endif %}
141+
{% if 'rest' in opts.transport %}
142+
({{ service.client_name }}, "rest"),
143+
{% endif %}
127144
])
128-
def test_{{ service.client_name|snake_case }}_from_service_account_file(client_class):
145+
def test_{{ service.client_name|snake_case }}_from_service_account_file(client_class, transport_name):
129146
creds = ga_credentials.AnonymousCredentials()
130147
with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
131148
factory.return_value = creds
132-
client = client_class.from_service_account_file("dummy/file/path.json")
149+
client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name)
133150
assert client.transport._credentials == creds
134151
assert isinstance(client, client_class)
135152

136-
client = client_class.from_service_account_json("dummy/file/path.json")
153+
client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name)
137154
assert client.transport._credentials == creds
138155
assert isinstance(client, client_class)
139156

140157
{% if service.host %}
141-
assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'
158+
assert client.transport._host == (
159+
'{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'
160+
{% if 'rest' in opts.transport %}
161+
if transport_name in ['grpc', 'grpc_asyncio']
162+
else
163+
'https://{{ service.host }}'
164+
{% endif %}
165+
)
142166
{% endif %}
143167

144168

@@ -1853,23 +1877,53 @@ def test_{{ service.name|snake_case }}_rest_lro_client():
18531877
{%- endif %}
18541878
{% endif %} {# rest #}
18551879

1856-
def test_{{ service.name|snake_case }}_host_no_port():
1880+
@pytest.mark.parametrize("transport_name", [
1881+
{% if 'grpc' in opts.transport %}
1882+
"grpc",
1883+
{% endif %}
1884+
{% if 'rest' in opts.transport %}
1885+
"rest",
1886+
{% endif %}
1887+
])
1888+
def test_{{ service.name|snake_case }}_host_no_port(transport_name):
18571889
{% with host = (service.host|default('localhost', true)).split(':')[0] %}
18581890
client = {{ service.client_name }}(
18591891
credentials=ga_credentials.AnonymousCredentials(),
18601892
client_options=client_options.ClientOptions(api_endpoint='{{ host }}'),
1893+
transport=transport_name,
1894+
)
1895+
assert client.transport._host == (
1896+
'{{ host }}:443'
1897+
{% if 'rest' in opts.transport %}
1898+
if transport_name in ['grpc', 'grpc_asyncio']
1899+
else 'https://{{ host }}'
1900+
{% endif %}
18611901
)
1862-
assert client.transport._host == '{{ host }}:443'
18631902
{% endwith %}
18641903

18651904

1866-
def test_{{ service.name|snake_case }}_host_with_port():
1905+
@pytest.mark.parametrize("transport_name", [
1906+
{% if 'grpc' in opts.transport %}
1907+
"grpc",
1908+
{% endif %}
1909+
{% if 'rest' in opts.transport %}
1910+
"rest",
1911+
{% endif %}
1912+
])
1913+
def test_{{ service.name|snake_case }}_host_with_port(transport_name):
18671914
{% with host = (service.host|default('localhost', true)).split(':')[0] %}
18681915
client = {{ service.client_name }}(
18691916
credentials=ga_credentials.AnonymousCredentials(),
18701917
client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'),
1918+
transport=transport_name,
1919+
)
1920+
assert client.transport._host == (
1921+
'{{ host }}:8000'
1922+
{% if 'rest' in opts.transport %}
1923+
if transport_name in ['grpc', 'grpc_asyncio']
1924+
else 'https://{{ host }}:8000'
1925+
{% endif %}
18711926
)
1872-
assert client.transport._host == '{{ host }}:8000'
18731927
{% endwith %}
18741928

18751929
{% if 'grpc' in opts.transport %}

packages/gapic-generator/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2

Lines changed: 69 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,23 +89,33 @@ def test__get_default_mtls_endpoint():
8989
assert {{ service.client_name }}._get_default_mtls_endpoint(non_googleapi) == non_googleapi
9090

9191

92-
@pytest.mark.parametrize("client_class", [
93-
{{ service.client_name }},
92+
@pytest.mark.parametrize("client_class,transport_name", [
9493
{% if 'grpc' in opts.transport %}
95-
{{ service.async_client_name }},
94+
({{ service.client_name }}, "grpc"),
95+
({{ service.async_client_name }}, "grpc_asyncio"),
96+
{% endif %}
97+
{% if 'rest' in opts.transport %}
98+
({{ service.client_name }}, "rest"),
9699
{% endif %}
97100
])
98-
def test_{{ service.client_name|snake_case }}_from_service_account_info(client_class):
101+
def test_{{ service.client_name|snake_case }}_from_service_account_info(client_class, transport_name):
99102
creds = ga_credentials.AnonymousCredentials()
100103
with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory:
101104
factory.return_value = creds
102105
info = {"valid": True}
103-
client = client_class.from_service_account_info(info)
106+
client = client_class.from_service_account_info(info, transport=transport_name)
104107
assert client.transport._credentials == creds
105108
assert isinstance(client, client_class)
106109

107110
{% if service.host %}
108-
assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'
111+
assert client.transport._host == (
112+
'{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'
113+
{% if 'rest' in opts.transport %}
114+
if transport_name in ['grpc', 'grpc_asyncio']
115+
else
116+
'https://{{ service.host }}'
117+
{% endif %}
118+
)
109119
{% endif %}
110120

111121

@@ -130,26 +140,36 @@ def test_{{ service.client_name|snake_case }}_service_account_always_use_jwt(tra
130140
use_jwt.assert_not_called()
131141

132142

133-
@pytest.mark.parametrize("client_class", [
134-
{{ service.client_name }},
143+
@pytest.mark.parametrize("client_class,transport_name", [
135144
{% if 'grpc' in opts.transport %}
136-
{{ service.async_client_name }},
145+
({{ service.client_name }}, "grpc"),
146+
({{ service.async_client_name }}, "grpc_asyncio"),
147+
{% endif %}
148+
{% if 'rest' in opts.transport %}
149+
({{ service.client_name }}, "rest"),
137150
{% endif %}
138151
])
139-
def test_{{ service.client_name|snake_case }}_from_service_account_file(client_class):
152+
def test_{{ service.client_name|snake_case }}_from_service_account_file(client_class, transport_name):
140153
creds = ga_credentials.AnonymousCredentials()
141154
with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
142155
factory.return_value = creds
143-
client = client_class.from_service_account_file("dummy/file/path.json")
156+
client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name)
144157
assert client.transport._credentials == creds
145158
assert isinstance(client, client_class)
146159

147-
client = client_class.from_service_account_json("dummy/file/path.json")
160+
client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name)
148161
assert client.transport._credentials == creds
149162
assert isinstance(client, client_class)
150163

151164
{% if service.host %}
152-
assert client.transport._host == '{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'
165+
assert client.transport._host == (
166+
'{{ service.host }}{% if ":" not in service.host %}:443{% endif %}'
167+
{% if 'rest' in opts.transport %}
168+
if transport_name in ['grpc', 'grpc_asyncio']
169+
else
170+
'https://{{ service.host }}'
171+
{% endif %}
172+
)
153173
{% endif %}
154174

155175

@@ -2323,23 +2343,54 @@ def test_{{ service.name|snake_case }}_rest_lro_client():
23232343

23242344
{% endif %} {# rest #}
23252345

2326-
def test_{{ service.name|snake_case }}_host_no_port():
2346+
@pytest.mark.parametrize("transport_name", [
2347+
{% if 'grpc' in opts.transport %}
2348+
"grpc",
2349+
"grpc_asyncio",
2350+
{% endif %}
2351+
{% if 'rest' in opts.transport %}
2352+
"rest",
2353+
{% endif %}
2354+
])
2355+
def test_{{ service.name|snake_case }}_host_no_port(transport_name):
23272356
{% with host = (service.host|default('localhost', true)).split(':')[0] %}
23282357
client = {{ service.client_name }}(
23292358
credentials=ga_credentials.AnonymousCredentials(),
23302359
client_options=client_options.ClientOptions(api_endpoint='{{ host }}'),
2360+
transport=transport_name,
2361+
)
2362+
assert client.transport._host == (
2363+
'{{ host }}:443'
2364+
{% if 'rest' in opts.transport %}
2365+
if transport_name in ['grpc', 'grpc_asyncio']
2366+
else 'https://{{ host }}'
2367+
{% endif %}
23312368
)
2332-
assert client.transport._host == '{{ host }}:443'
23332369
{% endwith %}
23342370

2335-
2336-
def test_{{ service.name|snake_case }}_host_with_port():
2371+
@pytest.mark.parametrize("transport_name", [
2372+
{% if 'grpc' in opts.transport %}
2373+
"grpc",
2374+
"grpc_asyncio",
2375+
{% endif %}
2376+
{% if 'rest' in opts.transport %}
2377+
"rest",
2378+
{% endif %}
2379+
])
2380+
def test_{{ service.name|snake_case }}_host_with_port(transport_name):
23372381
{% with host = (service.host|default('localhost', true)).split(':')[0] %}
23382382
client = {{ service.client_name }}(
23392383
credentials=ga_credentials.AnonymousCredentials(),
23402384
client_options=client_options.ClientOptions(api_endpoint='{{ host }}:8000'),
2385+
transport=transport_name,
2386+
)
2387+
assert client.transport._host == (
2388+
'{{ host }}:8000'
2389+
{% if 'rest' in opts.transport %}
2390+
if transport_name in ['grpc', 'grpc_asyncio']
2391+
else 'https://{{ host }}:8000'
2392+
{% endif %}
23412393
)
2342-
assert client.transport._host == '{{ host }}:8000'
23432394
{% endwith %}
23442395

23452396
{% if 'grpc' in opts.transport %}

packages/gapic-generator/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -76,20 +76,22 @@ def test__get_default_mtls_endpoint():
7676
assert AssetServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi
7777

7878

79-
@pytest.mark.parametrize("client_class", [
80-
AssetServiceClient,
81-
AssetServiceAsyncClient,
79+
@pytest.mark.parametrize("client_class,transport_name", [
80+
(AssetServiceClient, "grpc"),
81+
(AssetServiceAsyncClient, "grpc_asyncio"),
8282
])
83-
def test_asset_service_client_from_service_account_info(client_class):
83+
def test_asset_service_client_from_service_account_info(client_class, transport_name):
8484
creds = ga_credentials.AnonymousCredentials()
8585
with mock.patch.object(service_account.Credentials, 'from_service_account_info') as factory:
8686
factory.return_value = creds
8787
info = {"valid": True}
88-
client = client_class.from_service_account_info(info)
88+
client = client_class.from_service_account_info(info, transport=transport_name)
8989
assert client.transport._credentials == creds
9090
assert isinstance(client, client_class)
9191

92-
assert client.transport._host == 'cloudasset.googleapis.com:443'
92+
assert client.transport._host == (
93+
'cloudasset.googleapis.com:443'
94+
)
9395

9496

9597
@pytest.mark.parametrize("transport_class,transport_name", [
@@ -108,23 +110,25 @@ def test_asset_service_client_service_account_always_use_jwt(transport_class, tr
108110
use_jwt.assert_not_called()
109111

110112

111-
@pytest.mark.parametrize("client_class", [
112-
AssetServiceClient,
113-
AssetServiceAsyncClient,
113+
@pytest.mark.parametrize("client_class,transport_name", [
114+
(AssetServiceClient, "grpc"),
115+
(AssetServiceAsyncClient, "grpc_asyncio"),
114116
])
115-
def test_asset_service_client_from_service_account_file(client_class):
117+
def test_asset_service_client_from_service_account_file(client_class, transport_name):
116118
creds = ga_credentials.AnonymousCredentials()
117119
with mock.patch.object(service_account.Credentials, 'from_service_account_file') as factory:
118120
factory.return_value = creds
119-
client = client_class.from_service_account_file("dummy/file/path.json")
121+
client = client_class.from_service_account_file("dummy/file/path.json", transport=transport_name)
120122
assert client.transport._credentials == creds
121123
assert isinstance(client, client_class)
122124

123-
client = client_class.from_service_account_json("dummy/file/path.json")
125+
client = client_class.from_service_account_json("dummy/file/path.json", transport=transport_name)
124126
assert client.transport._credentials == creds
125127
assert isinstance(client, client_class)
126128

127-
assert client.transport._host == 'cloudasset.googleapis.com:443'
129+
assert client.transport._host == (
130+
'cloudasset.googleapis.com:443'
131+
)
128132

129133

130134
def test_asset_service_client_get_transport_class():
@@ -3866,20 +3870,33 @@ def test_asset_service_grpc_transport_client_cert_source_for_mtls(
38663870
)
38673871

38683872

3869-
def test_asset_service_host_no_port():
3873+
@pytest.mark.parametrize("transport_name", [
3874+
"grpc",
3875+
"grpc_asyncio",
3876+
])
3877+
def test_asset_service_host_no_port(transport_name):
38703878
client = AssetServiceClient(
38713879
credentials=ga_credentials.AnonymousCredentials(),
38723880
client_options=client_options.ClientOptions(api_endpoint='cloudasset.googleapis.com'),
3881+
transport=transport_name,
3882+
)
3883+
assert client.transport._host == (
3884+
'cloudasset.googleapis.com:443'
38733885
)
3874-
assert client.transport._host == 'cloudasset.googleapis.com:443'
3875-
38763886

3877-
def test_asset_service_host_with_port():
3887+
@pytest.mark.parametrize("transport_name", [
3888+
"grpc",
3889+
"grpc_asyncio",
3890+
])
3891+
def test_asset_service_host_with_port(transport_name):
38783892
client = AssetServiceClient(
38793893
credentials=ga_credentials.AnonymousCredentials(),
38803894
client_options=client_options.ClientOptions(api_endpoint='cloudasset.googleapis.com:8000'),
3895+
transport=transport_name,
3896+
)
3897+
assert client.transport._host == (
3898+
'cloudasset.googleapis.com:8000'
38813899
)
3882-
assert client.transport._host == 'cloudasset.googleapis.com:8000'
38833900

38843901
def test_asset_service_grpc_transport_channel():
38853902
channel = grpc.secure_channel('http://localhost/', grpc.local_channel_credentials())

0 commit comments

Comments
 (0)