1717from .body import Body , BodyResult , Keypoint
1818from .hand import Hand
1919from .face import Face
20- from .wholebody import Wholebody # DW Pose
21- from .types import PoseResult , HandResult , FaceResult
2220from modules import devices
2321from annotator .annotator_path import models_path
2422
25- from typing import Tuple , List , Callable , Union , Optional
23+ from typing import NamedTuple , Tuple , List , Callable , Union , Optional
2624
2725body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth"
2826hand_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth"
2927face_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth"
30- remote_dw_model_path = "https://huggingface.co/camenduru/DWPose/resolve/main/dw-ll_ucoco_384.pth"
3128
29+ HandResult = List [Keypoint ]
30+ FaceResult = List [Keypoint ]
31+
32+ class PoseResult (NamedTuple ):
33+ body : BodyResult
34+ left_hand : Union [HandResult , None ]
35+ right_hand : Union [HandResult , None ]
36+ face : Union [FaceResult , None ]
3237
3338def draw_poses (poses : List [PoseResult ], H , W , draw_body = True , draw_hand = True , draw_face = True ):
3439 """
@@ -174,8 +179,6 @@ def __init__(self):
174179 self .hand_estimation = None
175180 self .face_estimation = None
176181
177- self .dw_pose_estimation = None
178-
179182 def load_model (self ):
180183 """
181184 Load the Openpose body, hand, and face models.
@@ -195,17 +198,10 @@ def load_model(self):
195198 if not os .path .exists (face_modelpath ):
196199 from basicsr .utils .download_util import load_file_from_url
197200 load_file_from_url (face_model_path , model_dir = self .model_dir )
198-
201+
199202 self .body_estimation = Body (body_modelpath )
200203 self .hand_estimation = Hand (hand_modelpath )
201204 self .face_estimation = Face (face_modelpath )
202-
203- def load_dw_model (self ):
204- dw_modelpath = os .path .join (self .model_dir , "dw-ll_ucoco_384.pth" )
205- if not os .path .exists (dw_modelpath ):
206- from basicsr .utils .download_util import load_file_from_url
207- load_file_from_url (remote_dw_model_path , model_dir = self .model_dir )
208- self .dw_pose_estimation = Wholebody (dw_modelpath , device = self .device )
209205
210206 def unload_model (self ):
211207 """
@@ -215,11 +211,6 @@ def unload_model(self):
215211 self .body_estimation .model .to ("cpu" )
216212 self .hand_estimation .model .to ("cpu" )
217213 self .face_estimation .model .to ("cpu" )
218-
219- def unload_dw_model (self ):
220- if self .dw_pose_estimation is not None :
221- self .dw_pose_estimation .detector .to ("cpu" )
222- self .dw_pose_estimation .pose_estimator .to ("cpu" )
223214
224215 def detect_hands (self , body : BodyResult , oriImg ) -> Tuple [Union [HandResult , None ], Union [HandResult , None ]]:
225216 left_hand = None
@@ -278,7 +269,7 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P
278269
279270 self .body_estimation .model .to (self .device )
280271 self .hand_estimation .model .to (self .device )
281- self .face_estimation .model .to (self .device )
272+ self .face_estimation .model .to (self .device )
282273
283274 self .body_estimation .cn_device = self .device
284275 self .hand_estimation .cn_device = self .device
@@ -311,31 +302,10 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P
311302 ), left_hand , right_hand , face ))
312303
313304 return results
314-
315- def detect_poses_dw (self , oriImg ) -> List [PoseResult ]:
316- """
317- Detect poses in the given image using DW Pose:
318- https://github.com/IDEA-Research/DWPose
319-
320- Args:
321- oriImg (numpy.ndarray): The input image for pose detection.
322-
323- Returns:
324- List[PoseResult]: A list of PoseResult objects containing the detected poses.
325- """
326- if self .dw_pose_estimation is None :
327- self .load_dw_model ()
328-
329- self .dw_pose_estimation .detector .to (self .device )
330- self .dw_pose_estimation .pose_estimator .to (self .device )
331-
332- with torch .no_grad ():
333- keypoints_info = self .dw_pose_estimation (oriImg .copy ())
334- return Wholebody .format_result (keypoints_info )
335-
305+
336306 def __call__ (
337- self , oriImg , include_body = True , include_hand = False , include_face = False ,
338- use_dw_pose = False , json_pose_callback : Callable [[str ], None ] = None ,
307+ self , oriImg , include_body = True , include_hand = False , include_face = False ,
308+ json_pose_callback : Callable [[str ], None ] = None ,
339309 ):
340310 """
341311 Detect and draw poses in the given image.
@@ -345,19 +315,14 @@ def __call__(
345315 include_body (bool, optional): Whether to include body keypoints. Defaults to True.
346316 include_hand (bool, optional): Whether to include hand keypoints. Defaults to False.
347317 include_face (bool, optional): Whether to include face keypoints. Defaults to False.
348- use_dw_pose (bool, optional): Whether to use DW pose detection algorithm. Defaults to False.
349318 json_pose_callback (Callable, optional): A callback that accepts the pose JSON string.
350319
351320 Returns:
352321 numpy.ndarray: The image with detected and drawn poses.
353322 """
354323 H , W , _ = oriImg .shape
355-
356- if use_dw_pose :
357- poses = self .detect_poses_dw (oriImg )
358- else :
359- poses = self .detect_poses (oriImg , include_hand , include_face )
360-
324+ poses = self .detect_poses (oriImg , include_hand , include_face )
361325 if json_pose_callback :
362326 json_pose_callback (encode_poses_as_json (poses , H , W ))
363- return draw_poses (poses , H , W , draw_body = include_body , draw_hand = include_hand , draw_face = include_face )
327+ return draw_poses (poses , H , W , draw_body = include_body , draw_hand = include_hand , draw_face = include_face )
328+
0 commit comments