Skip to content

Commit 8be2d2c

Browse files
committed
refactor: clean up async tests and remove redundant comments
1 parent 5278a0d commit 8be2d2c

File tree

3 files changed

+10
-51
lines changed

3 files changed

+10
-51
lines changed

src/functions_framework/aio/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858

5959

6060
async def _crash_handler(request, exc):
61-
# Log the exception
6261
logger = logging.getLogger()
6362
tb_lines = traceback.format_exception(type(exc), exc, exc.__traceback__)
6463
tb_text = "".join(tb_lines)
@@ -72,7 +71,6 @@ async def _crash_handler(request, exc):
7271
log_entry = {"message": error_msg, "levelname": "ERROR"}
7372
logger.error(json.dumps(log_entry))
7473
else:
75-
# Execution ID logging not enabled, log plain text
7674
logger.error(error_msg)
7775

7876
headers = {_FUNCTION_STATUS_HEADER_FIELD: _CRASH}
@@ -191,19 +189,16 @@ def _enable_execution_id_logging():
191189

192190

193191
def _configure_app_execution_id_logging():
194-
# Logging needs to be configured before app logger is accessed
195192
import logging
196193
import logging.config
197194

198-
# Configure root logger to use our custom handler
199195
root_logger = logging.getLogger()
200196
root_logger.setLevel(logging.INFO)
201197

202198
# Remove existing handlers
203199
for handler in root_logger.handlers[:]:
204200
root_logger.removeHandler(handler)
205201

206-
# Add our custom handler that adds execution ID
207202
handler = logging.StreamHandler(
208203
execution_id.LoggingHandlerAddExecutionId(sys.stderr)
209204
)

src/functions_framework/execution_id.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,9 @@ def __init__(self, execution_id=None, span_id=None):
5151

5252

5353
def _get_current_context():
54-
# First try to get from async context
5554
context = execution_context_var.get()
5655
if context is not None:
5756
return context
58-
# Fall back to Flask context for sync
5957
return ( # pragma: no cover
6058
flask.g.execution_id_context
6159
if flask.has_request_context() and "execution_id_context" in flask.g
@@ -64,8 +62,6 @@ def _get_current_context():
6462

6563

6664
def _set_current_context(context):
67-
# Set in both contexts to support both sync and async
68-
# Set in contextvars for async
6965
execution_context_var.set(context)
7066
# Also set in Flask context if available for sync
7167
if flask.has_request_context():
@@ -82,8 +78,7 @@ def _generate_execution_id():
8278
def _extract_context_from_headers(headers):
8379
"""Extract execution context from request headers."""
8480
execution_id = headers.get(EXECUTION_ID_REQUEST_HEADER)
85-
86-
# Try to get span ID from trace context header
81+
8782
trace_context = re.match(
8883
_TRACE_CONTEXT_REGEX_PATTERN,
8984
headers.get(TRACE_CONTEXT_REQUEST_HEADER, ""),
@@ -113,7 +108,6 @@ def __init__(self, app):
113108

114109
async def __call__(self, scope, receive, send):
115110
if scope["type"] == "http":
116-
# Extract existing execution ID or generate a new one
117111
execution_id_header = b"function-execution-id"
118112
trace_context_header = b"x-cloud-trace-context"
119113
execution_id = None
@@ -127,22 +121,18 @@ async def __call__(self, scope, receive, send):
127121

128122
if not execution_id:
129123
execution_id = _generate_execution_id()
130-
# Add the execution ID to headers
131124
new_headers = list(scope.get("headers", []))
132125
new_headers.append(
133126
(execution_id_header, execution_id.encode("latin-1"))
134127
)
135128
scope["headers"] = new_headers
136129

137-
# Store execution context in ASGI scope for recovery in case of context loss
138-
# Parse trace context to extract span ID
139130
span_id = None
140131
if trace_context:
141132
trace_match = re.match(_TRACE_CONTEXT_REGEX_PATTERN, trace_context)
142133
if trace_match:
143134
span_id = trace_match.group("span_id")
144135

145-
# Store in scope for potential recovery
146136
scope["execution_context"] = ExecutionContext(execution_id, span_id)
147137

148138
await self.app(scope, receive, send) # pragma: no cover
@@ -169,9 +159,7 @@ def wrapper(*args, **kwargs):
169159

170160
with stderr_redirect, stdout_redirect:
171161
result = view_function(*args, **kwargs)
172-
173162
# Context cleanup happens automatically via Flask's request context
174-
# No need to manually clean up flask.g
175163
return result
176164

177165
return wrapper
@@ -195,14 +183,10 @@ def set_execution_context_async(enable_id_logging=False):
195183
def decorator(view_function):
196184
@functools.wraps(view_function)
197185
async def async_wrapper(request, *args, **kwargs):
198-
# Extract execution context from headers
199186
context = _extract_context_from_headers(request.headers)
200-
201-
# Set context using contextvars
202187
token = execution_context_var.set(context)
203188

204189
with stderr_redirect, stdout_redirect:
205-
# Handle both sync and async functions
206190
if inspect.iscoroutinefunction(view_function):
207191
result = await view_function(request, *args, **kwargs)
208192
else:
@@ -215,10 +199,7 @@ async def async_wrapper(request, *args, **kwargs):
215199

216200
@functools.wraps(view_function)
217201
def sync_wrapper(request, *args, **kwargs): # pragma: no cover
218-
# For sync functions, we still need to set up the context
219202
context = _extract_context_from_headers(request.headers)
220-
221-
# Set context using contextvars
222203
token = execution_context_var.set(context)
223204

224205
with stderr_redirect, stdout_redirect:
@@ -229,7 +210,6 @@ def sync_wrapper(request, *args, **kwargs): # pragma: no cover
229210
execution_context_var.reset(token)
230211
return result
231212

232-
# Return appropriate wrapper based on whether the function is async
233213
if inspect.iscoroutinefunction(view_function):
234214
return async_wrapper
235215
else:

tests/test_execution_id_async.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import asyncio
1514
import json
1615
import pathlib
17-
import re
18-
import sys
19-
20-
from functools import partial
2116
from unittest.mock import Mock
2217

2318
import pytest
@@ -32,8 +27,7 @@
3227
TEST_SPAN_ID = "123456"
3328

3429

35-
@pytest.mark.asyncio
36-
async def test_async_user_function_can_retrieve_execution_id_from_header():
30+
def test_async_user_function_can_retrieve_execution_id_from_header():
3731
source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py"
3832
target = "async_function"
3933
app = create_asgi_app(target, source)
@@ -49,8 +43,7 @@ async def test_async_user_function_can_retrieve_execution_id_from_header():
4943
assert resp.json()["execution_id"] == TEST_EXECUTION_ID
5044

5145

52-
@pytest.mark.asyncio
53-
async def test_async_uncaught_exception_in_user_function_sets_execution_id(
46+
def test_async_uncaught_exception_in_user_function_sets_execution_id(
5447
capsys, monkeypatch
5548
):
5649
monkeypatch.setenv("LOG_EXECUTION_ID", "true")
@@ -71,8 +64,7 @@ async def test_async_uncaught_exception_in_user_function_sets_execution_id(
7164
assert f'"execution_id": "{TEST_EXECUTION_ID}"' in record.err
7265

7366

74-
@pytest.mark.asyncio
75-
async def test_async_print_from_user_function_sets_execution_id(capsys, monkeypatch):
67+
def test_async_print_from_user_function_sets_execution_id(capsys, monkeypatch):
7668
monkeypatch.setenv("LOG_EXECUTION_ID", "true")
7769
source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py"
7870
target = "async_print_message"
@@ -91,8 +83,7 @@ async def test_async_print_from_user_function_sets_execution_id(capsys, monkeypa
9183
assert '"message": "some-message"' in record.out
9284

9385

94-
@pytest.mark.asyncio
95-
async def test_async_log_from_user_function_sets_execution_id(capsys, monkeypatch):
86+
def test_async_log_from_user_function_sets_execution_id(capsys, monkeypatch):
9687
monkeypatch.setenv("LOG_EXECUTION_ID", "true")
9788
source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py"
9889
target = "async_log_message"
@@ -111,8 +102,7 @@ async def test_async_log_from_user_function_sets_execution_id(capsys, monkeypatc
111102
assert '"custom-field": "some-message"' in record.err
112103

113104

114-
@pytest.mark.asyncio
115-
async def test_async_user_function_can_retrieve_generated_execution_id(monkeypatch):
105+
def test_async_user_function_can_retrieve_generated_execution_id(monkeypatch):
116106
monkeypatch.setattr(
117107
execution_id, "_generate_execution_id", lambda: TEST_EXECUTION_ID
118108
)
@@ -130,8 +120,7 @@ async def test_async_user_function_can_retrieve_generated_execution_id(monkeypat
130120
assert resp.json()["execution_id"] == TEST_EXECUTION_ID
131121

132122

133-
@pytest.mark.asyncio
134-
async def test_async_does_not_set_execution_id_when_not_enabled(capsys):
123+
def test_async_does_not_set_execution_id_when_not_enabled(capsys):
135124
source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py"
136125
target = "async_print_message"
137126
app = create_asgi_app(target, source)
@@ -149,8 +138,7 @@ async def test_async_does_not_set_execution_id_when_not_enabled(capsys):
149138
assert "some-message" in record.out
150139

151140

152-
@pytest.mark.asyncio
153-
async def test_async_concurrent_requests_maintain_separate_execution_ids(
141+
def test_async_concurrent_requests_maintain_separate_execution_ids(
154142
capsys, monkeypatch
155143
):
156144
monkeypatch.setenv("LOG_EXECUTION_ID", "true")
@@ -236,8 +224,7 @@ def make_request(client, message, exec_id):
236224
),
237225
],
238226
)
239-
@pytest.mark.asyncio
240-
async def test_async_set_execution_context_headers(
227+
def test_async_set_execution_context_headers(
241228
headers, expected_execution_id, expected_span_id, should_generate
242229
):
243230
source = TEST_FUNCTIONS_DIR / "execution_id" / "async_main.py"
@@ -260,8 +247,6 @@ async def test_async_set_execution_context_headers(
260247
@pytest.mark.asyncio
261248
async def test_crash_handler_without_context_sets_execution_id():
262249
"""Test that crash handler returns proper error response with crash header."""
263-
from unittest.mock import Mock
264-
265250
from functions_framework.aio import _crash_handler
266251

267252
# Create a mock request
@@ -281,8 +266,7 @@ async def test_crash_handler_without_context_sets_execution_id():
281266
assert response.headers["X-Google-Status"] == "crash"
282267

283268

284-
@pytest.mark.asyncio
285-
async def test_async_decorator_with_sync_function():
269+
def test_async_decorator_with_sync_function():
286270
"""Test that the async decorator handles sync functions properly."""
287271
from functions_framework.execution_id import set_execution_context_async
288272

0 commit comments

Comments
 (0)