|
1 |
| -import os |
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import logging |
2 | 4 |
|
3 |
| -import cv2 |
4 | 5 | import torch
|
5 | 6 |
|
6 |
| -import modules.face_restoration |
7 |
| -import modules.shared |
8 |
| -from modules import shared, devices, modelloader, errors |
9 |
| -from modules.paths import models_path |
| 7 | +from modules import ( |
| 8 | + devices, |
| 9 | + errors, |
| 10 | + face_restoration, |
| 11 | + face_restoration_utils, |
| 12 | + modelloader, |
| 13 | + shared, |
| 14 | +) |
| 15 | + |
| 16 | +logger = logging.getLogger(__name__) |
10 | 17 |
|
11 |
| -model_dir = "Codeformer" |
12 |
| -model_path = os.path.join(models_path, model_dir) |
13 | 18 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
| 19 | +model_download_name = 'codeformer-v0.1.0.pth' |
14 | 20 |
|
15 |
| -codeformer = None |
| 21 | +# used by e.g. postprocessing_codeformer.py |
| 22 | +codeformer: face_restoration.FaceRestoration | None = None |
16 | 23 |
|
17 | 24 |
|
18 |
| -class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration): |
| 25 | +class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): |
19 | 26 | def name(self):
|
20 | 27 | return "CodeFormer"
|
21 | 28 |
|
22 |
| - def __init__(self, dirname): |
23 |
| - self.net = None |
24 |
| - self.face_helper = None |
25 |
| - self.cmd_dir = dirname |
26 |
| - |
27 |
| - def create_models(self): |
28 |
| - from facexlib.detection import retinaface |
29 |
| - from facexlib.utils.face_restoration_helper import FaceRestoreHelper |
30 |
| - |
31 |
| - if self.net is not None and self.face_helper is not None: |
32 |
| - self.net.to(devices.device_codeformer) |
33 |
| - return self.net, self.face_helper |
34 |
| - model_paths = modelloader.load_models( |
35 |
| - model_path, |
36 |
| - model_url, |
37 |
| - self.cmd_dir, |
38 |
| - download_name='codeformer-v0.1.0.pth', |
| 29 | + def load_net(self) -> torch.nn.Module: |
| 30 | + for model_path in modelloader.load_models( |
| 31 | + model_path=self.model_path, |
| 32 | + model_url=model_url, |
| 33 | + command_path=self.model_path, |
| 34 | + download_name=model_download_name, |
39 | 35 | ext_filter=['.pth'],
|
40 |
| - ) |
41 |
| - |
42 |
| - if len(model_paths) != 0: |
43 |
| - ckpt_path = model_paths[0] |
44 |
| - else: |
45 |
| - print("Unable to load codeformer model.") |
46 |
| - return None, None |
47 |
| - net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer) |
48 |
| - |
49 |
| - if hasattr(retinaface, 'device'): |
50 |
| - retinaface.device = devices.device_codeformer |
51 |
| - |
52 |
| - face_helper = FaceRestoreHelper( |
53 |
| - upscale_factor=1, |
54 |
| - face_size=512, |
55 |
| - crop_ratio=(1, 1), |
56 |
| - det_model='retinaface_resnet50', |
57 |
| - save_ext='png', |
58 |
| - use_parse=True, |
59 |
| - device=devices.device_codeformer, |
60 |
| - ) |
61 |
| - |
62 |
| - self.net = net |
63 |
| - self.face_helper = face_helper |
64 |
| - |
65 |
| - def send_model_to(self, device): |
66 |
| - self.net.to(device) |
67 |
| - self.face_helper.face_det.to(device) |
68 |
| - self.face_helper.face_parse.to(device) |
69 |
| - |
70 |
| - def restore(self, np_image, w=None): |
71 |
| - from torchvision.transforms.functional import normalize |
72 |
| - from basicsr.utils import img2tensor, tensor2img |
73 |
| - np_image = np_image[:, :, ::-1] |
74 |
| - |
75 |
| - original_resolution = np_image.shape[0:2] |
76 |
| - |
77 |
| - self.create_models() |
78 |
| - if self.net is None or self.face_helper is None: |
79 |
| - return np_image |
80 |
| - |
81 |
| - self.send_model_to(devices.device_codeformer) |
82 |
| - |
83 |
| - self.face_helper.clean_all() |
84 |
| - self.face_helper.read_image(np_image) |
85 |
| - self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) |
86 |
| - self.face_helper.align_warp_face() |
87 |
| - |
88 |
| - for cropped_face in self.face_helper.cropped_faces: |
89 |
| - cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) |
90 |
| - normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) |
91 |
| - cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer) |
92 |
| - |
93 |
| - try: |
94 |
| - with torch.no_grad(): |
95 |
| - res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True) |
96 |
| - if isinstance(res, tuple): |
97 |
| - output = res[0] |
98 |
| - else: |
99 |
| - output = res |
100 |
| - if not isinstance(res, torch.Tensor): |
101 |
| - raise TypeError(f"Expected torch.Tensor, got {type(res)}") |
102 |
| - restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) |
103 |
| - del output |
104 |
| - devices.torch_gc() |
105 |
| - except Exception: |
106 |
| - errors.report('Failed inference for CodeFormer', exc_info=True) |
107 |
| - restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) |
108 |
| - |
109 |
| - restored_face = restored_face.astype('uint8') |
110 |
| - self.face_helper.add_restored_face(restored_face) |
111 |
| - |
112 |
| - self.face_helper.get_inverse_affine(None) |
113 |
| - |
114 |
| - restored_img = self.face_helper.paste_faces_to_input_image() |
115 |
| - restored_img = restored_img[:, :, ::-1] |
| 36 | + ): |
| 37 | + return modelloader.load_spandrel_model( |
| 38 | + model_path, |
| 39 | + device=devices.device_codeformer, |
| 40 | + ).model |
| 41 | + raise ValueError("No codeformer model found") |
116 | 42 |
|
117 |
| - if original_resolution != restored_img.shape[0:2]: |
118 |
| - restored_img = cv2.resize( |
119 |
| - restored_img, |
120 |
| - (0, 0), |
121 |
| - fx=original_resolution[1]/restored_img.shape[1], |
122 |
| - fy=original_resolution[0]/restored_img.shape[0], |
123 |
| - interpolation=cv2.INTER_LINEAR, |
124 |
| - ) |
| 43 | + def get_device(self): |
| 44 | + return devices.device_codeformer |
125 | 45 |
|
126 |
| - self.face_helper.clean_all() |
| 46 | + def restore(self, np_image, w: float | None = None): |
| 47 | + if w is None: |
| 48 | + w = getattr(shared.opts, "code_former_weight", 0.5) |
127 | 49 |
|
128 |
| - if shared.opts.face_restoration_unload: |
129 |
| - self.send_model_to(devices.cpu) |
| 50 | + def restore_face(cropped_face_t): |
| 51 | + assert self.net is not None |
| 52 | + return self.net(cropped_face_t, w=w, adain=True)[0] |
130 | 53 |
|
131 |
| - return restored_img |
| 54 | + return self.restore_with_helper(np_image, restore_face) |
132 | 55 |
|
133 | 56 |
|
134 |
| -def setup_model(dirname): |
135 |
| - os.makedirs(model_path, exist_ok=True) |
| 57 | +def setup_model(dirname: str) -> None: |
| 58 | + global codeformer |
136 | 59 | try:
|
137 |
| - global codeformer |
138 | 60 | codeformer = FaceRestorerCodeFormer(dirname)
|
139 | 61 | shared.face_restorers.append(codeformer)
|
140 | 62 | except Exception:
|
|
0 commit comments