@@ -6,10 +6,11 @@
 from typing import Union
 from pydantic import BaseModel, field_validator
 import pathlib
+import av
 
 from .interface import Pipeline
 from comfystream.pipeline import Pipeline as ComfyStreamPipeline
-from trickle import VideoFrame, VideoOutput
+from trickle import VideoFrame, VideoOutput, AudioFrame, AudioOutput
 
 import logging
 
@@ -64,22 +65,22 @@ async def initialize(self, **params):
         await self.pipeline.warm_video()
         logging.info("Pipeline initialization and warmup complete")
 
+
     async def put_video_frame(self, frame: VideoFrame, request_id: str):
-        # Convert VideoFrame to format expected by comfystream
-        image_np = np.array(frame.image.convert("RGB")).astype(np.float32) / 255.0
-        frame.side_data.input = torch.tensor(image_np).unsqueeze(0)
-        frame.side_data.skipped = True
-        frame.side_data.request_id = request_id
-        await self.pipeline.put_video_frame(frame)
-
-    async def get_processed_video_frame(self) -> VideoOutput:
-        processed_frame = await self.pipeline.get_processed_video_frame()
-        # Convert back to VideoOutput format
-        result_tensor = processed_frame.side_data.input
-        result_tensor = result_tensor.squeeze(0)
-        result_image_np = (result_tensor * 255).byte()
-        result_image = Image.fromarray(result_image_np.cpu().numpy())
-        return VideoOutput(processed_frame, processed_frame.side_data.request_id).replace_image(result_image)
+        await self.pipeline.put_video_frame(self._convert_to_av_frame(frame))
+
+    async def put_audio_frame(self, frame: AudioFrame, request_id: str):
+        await self.pipeline.put_audio_frame(self._convert_to_av_frame(frame))
+
+    async def get_processed_video_frame(self, request_id: str) -> VideoOutput:
+        av_frame = await self.pipeline.get_processed_video_frame()
+        video_frame = VideoFrame.from_av_video(av_frame)
+        video_frame.side_data.request_id = request_id
+        return VideoOutput(video_frame).replace_image(av_frame.to_image())
+
+    async def get_processed_audio_frame(self, request_id: str) -> AudioOutput:
+        av_frame = await self.pipeline.get_processed_audio_frame()
+        return AudioOutput(av_frame, request_id)
 
     async def update_params(self, **params):
         new_params = ComfyUIParams(**params)
@@ -91,3 +92,22 @@ async def stop(self):
         logging.info("Stopping ComfyUI pipeline")
         await self.pipeline.cleanup()
         logging.info("ComfyUI pipeline stopped")
+
+    def _convert_to_av_frame(self, frame: Union[VideoFrame, AudioFrame]) -> Union[av.VideoFrame, av.AudioFrame]:
+        """Convert trickle frame to av frame"""
+        if isinstance(frame, VideoFrame):
+            av_frame = av.VideoFrame.from_ndarray(
+                np.array(frame.image.convert("RGB")),
+                format='rgb24'
+            )
+        elif isinstance(frame, AudioFrame):
+            av_frame = av.AudioFrame.from_ndarray(
+                frame.samples.reshape(-1, 1),
+                layout='mono',
+                rate=frame.rate
+            )
+
+        # Common frame properties
+        av_frame.pts = frame.timestamp
+        av_frame.time_base = frame.time_base
+        return av_frame
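
As a rough sanity check of the video branch of the new helper, the sketch below re-states it as a standalone function and round-trips a frame through PyAV and Pillow. It is not part of the commit: StubVideoFrame and convert_video_frame are hypothetical names, and the image, timestamp, and time_base attributes are assumed to behave like those of the trickle VideoFrame used in the diff.

# Hypothetical round-trip of the video path, mirroring _convert_to_av_frame above.
# StubVideoFrame stands in for trickle.VideoFrame and only carries the attributes
# the helper reads; none of this is part of the commit.
from fractions import Fraction

import av
import numpy as np
from PIL import Image


class StubVideoFrame:
    def __init__(self, image, timestamp, time_base):
        self.image = image          # PIL.Image, as frame.image is used above
        self.timestamp = timestamp  # becomes av_frame.pts
        self.time_base = time_base  # becomes av_frame.time_base


def convert_video_frame(frame):
    # Same conversion as the VideoFrame branch of _convert_to_av_frame.
    av_frame = av.VideoFrame.from_ndarray(
        np.array(frame.image.convert("RGB")),  # (height, width, 3) uint8
        format="rgb24",
    )
    av_frame.pts = frame.timestamp
    av_frame.time_base = frame.time_base
    return av_frame


frame = StubVideoFrame(Image.new("RGB", (64, 48), "red"), timestamp=0, time_base=Fraction(1, 30))
av_frame = convert_video_frame(frame)
print(av_frame.width, av_frame.height, av_frame.pts, av_frame.time_base)  # 64 48 0 1/30
restored = av_frame.to_image()  # back to a PIL image, as get_processed_video_frame does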