1717from  .body  import  Body , BodyResult , Keypoint 
1818from  .hand  import  Hand 
1919from  .face  import  Face 
20- from  .types  import  PoseResult , HandResult , FaceResult 
2120from  modules  import  devices 
2221from  annotator .annotator_path  import  models_path 
2322
24- from  typing  import  Tuple , List , Callable , Union , Optional 
23+ from  typing  import  NamedTuple ,  Tuple , List , Callable , Union , Optional 
2524
2625body_model_path  =  "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth" 
2726hand_model_path  =  "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth" 
2827face_model_path  =  "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth" 
29- remote_dw_model_path  =  "https://huggingface.co/camenduru/DWPose/resolve/main/dw-ll_ucoco_384.pth" 
3028
29+ HandResult  =  List [Keypoint ]
30+ FaceResult  =  List [Keypoint ]
31+ 
32+ class  PoseResult (NamedTuple ):
33+     body : BodyResult 
34+     left_hand : Union [HandResult , None ]
35+     right_hand : Union [HandResult , None ]
36+     face : Union [FaceResult , None ]
3137
3238def  draw_poses (poses : List [PoseResult ], H , W , draw_body = True , draw_hand = True , draw_face = True ):
3339    """ 
@@ -173,8 +179,6 @@ def __init__(self):
173179        self .hand_estimation  =  None 
174180        self .face_estimation  =  None 
175181
176-         self .dw_pose_estimation  =  None 
177- 
178182    def  load_model (self ):
179183        """ 
180184        Load the Openpose body, hand, and face models. 
@@ -194,19 +198,10 @@ def load_model(self):
194198        if  not  os .path .exists (face_modelpath ):
195199            from  basicsr .utils .download_util  import  load_file_from_url 
196200            load_file_from_url (face_model_path , model_dir = self .model_dir )
197-          
201+ 
198202        self .body_estimation  =  Body (body_modelpath )
199203        self .hand_estimation  =  Hand (hand_modelpath )
200204        self .face_estimation  =  Face (face_modelpath )
201-     
202-     def  load_dw_model (self ):
203-         from  .wholebody  import  Wholebody  # DW Pose 
204-         
205-         dw_modelpath  =  os .path .join (self .model_dir , "dw-ll_ucoco_384.pth" )
206-         if  not  os .path .exists (dw_modelpath ):
207-             from  basicsr .utils .download_util  import  load_file_from_url 
208-             load_file_from_url (remote_dw_model_path , model_dir = self .model_dir )
209-         self .dw_pose_estimation  =  Wholebody (dw_modelpath , device = self .device )
210205
211206    def  unload_model (self ):
212207        """ 
@@ -216,11 +211,6 @@ def unload_model(self):
216211            self .body_estimation .model .to ("cpu" )
217212            self .hand_estimation .model .to ("cpu" )
218213            self .face_estimation .model .to ("cpu" )
219-     
220-     def  unload_dw_model (self ):
221-         if  self .dw_pose_estimation  is  not None :
222-             self .dw_pose_estimation .detector .to ("cpu" )
223-             self .dw_pose_estimation .pose_estimator .to ("cpu" )
224214
225215    def  detect_hands (self , body : BodyResult , oriImg ) ->  Tuple [Union [HandResult , None ], Union [HandResult , None ]]:
226216        left_hand  =  None 
@@ -279,7 +269,7 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P
279269
280270        self .body_estimation .model .to (self .device )
281271        self .hand_estimation .model .to (self .device )
282-         self .face_estimation .model .to (self .device )         
272+         self .face_estimation .model .to (self .device )
283273
284274        self .body_estimation .cn_device  =  self .device 
285275        self .hand_estimation .cn_device  =  self .device 
@@ -312,33 +302,10 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P
312302                ), left_hand , right_hand , face ))
313303
314304            return  results 
315-     
316-     def  detect_poses_dw (self , oriImg ) ->  List [PoseResult ]:
317-         """ 
318-         Detect poses in the given image using DW Pose: 
319-         https://github.com/IDEA-Research/DWPose 
320- 
321-         Args: 
322-             oriImg (numpy.ndarray): The input image for pose detection. 
323- 
324-         Returns: 
325-             List[PoseResult]: A list of PoseResult objects containing the detected poses. 
326-         """ 
327-         from  .wholebody  import  Wholebody  # DW Pose 
328- 
329-         if  self .dw_pose_estimation  is  None :
330-             self .load_dw_model ()
331- 
332-         self .dw_pose_estimation .detector .to (self .device )
333-         self .dw_pose_estimation .pose_estimator .to (self .device )
334- 
335-         with  torch .no_grad ():
336-             keypoints_info  =  self .dw_pose_estimation (oriImg .copy ())
337-             return  Wholebody .format_result (keypoints_info )
338- 
305+         
339306    def  __call__ (
340-             self , oriImg , include_body = True , include_hand = False , include_face = False ,  
341-             use_dw_pose = False ,  json_pose_callback : Callable [[str ], None ] =  None ,
307+             self , oriImg , include_body = True , include_hand = False , include_face = False ,
308+             json_pose_callback : Callable [[str ], None ] =  None ,
342309        ):
343310        """ 
344311        Detect and draw poses in the given image. 
@@ -348,19 +315,14 @@ def __call__(
348315            include_body (bool, optional): Whether to include body keypoints. Defaults to True. 
349316            include_hand (bool, optional): Whether to include hand keypoints. Defaults to False. 
350317            include_face (bool, optional): Whether to include face keypoints. Defaults to False. 
351-             use_dw_pose (bool, optional): Whether to use DW pose detection algorithm. Defaults to False. 
352318            json_pose_callback (Callable, optional): A callback that accepts the pose JSON string. 
353319
354320        Returns: 
355321            numpy.ndarray: The image with detected and drawn poses. 
356322        """ 
357323        H , W , _  =  oriImg .shape 
358- 
359-         if  use_dw_pose :
360-             poses  =  self .detect_poses_dw (oriImg )
361-         else :
362-             poses  =  self .detect_poses (oriImg , include_hand , include_face )
363- 
324+         poses  =  self .detect_poses (oriImg , include_hand , include_face )
364325        if  json_pose_callback :
365326            json_pose_callback (encode_poses_as_json (poses , H , W ))
366-         return  draw_poses (poses , H , W , draw_body = include_body , draw_hand = include_hand , draw_face = include_face )
327+         return  draw_poses (poses , H , W , draw_body = include_body , draw_hand = include_hand , draw_face = include_face ) 
328+                      
0 commit comments