Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,5 @@ annotator/downloads/

# test results and expectations
web_tests/results/
web_tests/expectations/
web_tests/expectations/
*_diff.png
69 changes: 52 additions & 17 deletions annotator/openpose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,18 @@
from .body import Body, BodyResult, Keypoint
from .hand import Hand
from .face import Face
from .wholebody import Wholebody # DW Pose
from .types import PoseResult, HandResult, FaceResult
from modules import devices
from annotator.annotator_path import models_path

from typing import NamedTuple, Tuple, List, Callable, Union, Optional
from typing import Tuple, List, Callable, Union, Optional

body_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/body_pose_model.pth"
hand_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/hand_pose_model.pth"
face_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/facenet.pth"
remote_dw_model_path = "https://huggingface.co/camenduru/DWPose/resolve/main/dw-ll_ucoco_384.pth"

HandResult = List[Keypoint]
FaceResult = List[Keypoint]

class PoseResult(NamedTuple):
body: BodyResult
left_hand: Union[HandResult, None]
right_hand: Union[HandResult, None]
face: Union[FaceResult, None]

def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True):
"""
Expand Down Expand Up @@ -179,6 +174,8 @@ def __init__(self):
self.hand_estimation = None
self.face_estimation = None

self.dw_pose_estimation = None

def load_model(self):
"""
Load the Openpose body, hand, and face models.
Expand All @@ -198,10 +195,17 @@ def load_model(self):
if not os.path.exists(face_modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(face_model_path, model_dir=self.model_dir)

self.body_estimation = Body(body_modelpath)
self.hand_estimation = Hand(hand_modelpath)
self.face_estimation = Face(face_modelpath)

def load_dw_model(self):
dw_modelpath = os.path.join(self.model_dir, "dw-ll_ucoco_384.pth")
if not os.path.exists(dw_modelpath):
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(remote_dw_model_path, model_dir=self.model_dir)
self.dw_pose_estimation = Wholebody(dw_modelpath, device=self.device)

def unload_model(self):
"""
Expand All @@ -211,6 +215,11 @@ def unload_model(self):
self.body_estimation.model.to("cpu")
self.hand_estimation.model.to("cpu")
self.face_estimation.model.to("cpu")

def unload_dw_model(self):
if self.dw_pose_estimation is not None:
self.dw_pose_estimation.detector.to("cpu")
self.dw_pose_estimation.pose_estimator.to("cpu")

def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]:
left_hand = None
Expand Down Expand Up @@ -269,7 +278,7 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P

self.body_estimation.model.to(self.device)
self.hand_estimation.model.to(self.device)
self.face_estimation.model.to(self.device)
self.face_estimation.model.to(self.device)

self.body_estimation.cn_device = self.device
self.hand_estimation.cn_device = self.device
Expand Down Expand Up @@ -302,10 +311,31 @@ def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[P
), left_hand, right_hand, face))

return results


def detect_poses_dw(self, oriImg) -> List[PoseResult]:
"""
Detect poses in the given image using DW Pose:
https://github.com/IDEA-Research/DWPose

Args:
oriImg (numpy.ndarray): The input image for pose detection.

Returns:
List[PoseResult]: A list of PoseResult objects containing the detected poses.
"""
if self.dw_pose_estimation is None:
self.load_dw_model()

self.dw_pose_estimation.detector.to(self.device)
self.dw_pose_estimation.pose_estimator.to(self.device)

with torch.no_grad():
keypoints_info = self.dw_pose_estimation(oriImg.copy())
return Wholebody.format_result(keypoints_info)

def __call__(
self, oriImg, include_body=True, include_hand=False, include_face=False,
json_pose_callback: Callable[[str], None] = None,
self, oriImg, include_body=True, include_hand=False, include_face=False,
use_dw_pose=False, json_pose_callback: Callable[[str], None] = None,
):
"""
Detect and draw poses in the given image.
Expand All @@ -315,14 +345,19 @@ def __call__(
include_body (bool, optional): Whether to include body keypoints. Defaults to True.
include_hand (bool, optional): Whether to include hand keypoints. Defaults to False.
include_face (bool, optional): Whether to include face keypoints. Defaults to False.
use_dw_pose (bool, optional): Whether to use DW pose detection algorithm. Defaults to False.
json_pose_callback (Callable, optional): A callback that accepts the pose JSON string.

Returns:
numpy.ndarray: The image with detected and drawn poses.
"""
H, W, _ = oriImg.shape
poses = self.detect_poses(oriImg, include_hand, include_face)

if use_dw_pose:
poses = self.detect_poses_dw(oriImg)
else:
poses = self.detect_poses(oriImg, include_hand, include_face)

if json_pose_callback:
json_pose_callback(encode_poses_as_json(poses, H, W))
return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)

return draw_poses(poses, H, W, draw_body=include_body, draw_hand=include_hand, draw_face=include_face)
19 changes: 1 addition & 18 deletions annotator/openpose/body.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,7 @@

from . import util
from .model import bodypose_model

class Keypoint(NamedTuple):
x: float
y: float
score: float = 1.0
id: int = -1


class BodyResult(NamedTuple):
# Note: Using `Union` instead of `|` operator as the ladder is a Python
# 3.10 feature.
# Annotator code should be Python 3.8 Compatible, as controlnet repo uses
# Python 3.8 environment.
# https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6
keypoints: List[Union[Keypoint, None]]
total_score: float = 0.0
total_parts: int = 0

from .types import Keypoint, BodyResult

class Body(object):
def __init__(self, model_path):
Expand Down
Loading