Skip to content

[Misc] Replace cuda hard code with current_platform #16983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,8 +1221,9 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
ray.shutdown()
gc.collect()
from vllm.platforms import current_platform
if not current_platform.is_cpu():
torch.cuda.empty_cache()
empty_cache = current_platform.empty_cache
if empty_cache is not None:
empty_cache()
try:
torch._C._host_emptyCache()
except AttributeError:
Expand Down
5 changes: 4 additions & 1 deletion vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def set_forward_context(attn_metadata: Any,
# we use synchronous scheduling right now,
# adding a sync point here should not affect
# scheduling of the next batch
torch.cuda.synchronize()
from vllm.platforms import current_platform
synchronize = current_platform.synchronize
if synchronize is not None:
synchronize()
now = time.perf_counter()
# time measurement is in milliseconds
batchsize_forward_time[batchsize].append(
Expand Down
8 changes: 4 additions & 4 deletions vllm/spec_decode/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
"""Copy rejection/typical-acceptance sampling metrics
(number of accepted tokens, etc) to CPU asynchronously.

Returns a CUDA event recording when the copy is complete.
Returns a device event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream())
self._copy_stream.wait_stream(current_platform.current_stream())

with torch.cuda.stream(self._copy_stream):
with current_platform.stream(self._copy_stream):
self._aggregate_num_accepted_tokens.copy_(
self.spec_decode_sampler.num_accepted_tokens,
non_blocking=True)
Expand All @@ -142,7 +142,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
self._aggregate_num_draft_tokens = (
self.spec_decode_sampler.num_draft_tokens)

aggregate_metrics_ready = torch.cuda.Event()
aggregate_metrics_ready = current_platform.Event()
aggregate_metrics_ready.record(self._copy_stream)

return aggregate_metrics_ready
Expand Down