Skip to content

Commit ecee8df

Browse files
committed
Unify CodeFormer and GFPGAN restoration backends, use Spandrel for GFPGAN
1 parent c7fbff2 commit ecee8df

File tree

11 files changed

+290
-237
lines changed

11 files changed

+290
-237
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ notification.mp3
3737
/node_modules
3838
/package-lock.json
3939
/.coverage*
40+
/test/test_outputs

modules/codeformer_model.py

Lines changed: 40 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,140 +1,62 @@
1-
import os
1+
from __future__ import annotations
2+
3+
import logging
24

3-
import cv2
45
import torch
56

6-
import modules.face_restoration
7-
import modules.shared
8-
from modules import shared, devices, modelloader, errors
9-
from modules.paths import models_path
7+
from modules import (
8+
devices,
9+
errors,
10+
face_restoration,
11+
face_restoration_utils,
12+
modelloader,
13+
shared,
14+
)
15+
16+
logger = logging.getLogger(__name__)
1017

11-
model_dir = "Codeformer"
12-
model_path = os.path.join(models_path, model_dir)
1318
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
19+
model_download_name = 'codeformer-v0.1.0.pth'
1420

15-
codeformer = None
21+
# used by e.g. postprocessing_codeformer.py
22+
codeformer: face_restoration.FaceRestoration | None = None
1623

1724

18-
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
25+
class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration):
1926
def name(self):
2027
return "CodeFormer"
2128

22-
def __init__(self, dirname):
23-
self.net = None
24-
self.face_helper = None
25-
self.cmd_dir = dirname
26-
27-
def create_models(self):
28-
from facexlib.detection import retinaface
29-
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
30-
31-
if self.net is not None and self.face_helper is not None:
32-
self.net.to(devices.device_codeformer)
33-
return self.net, self.face_helper
34-
model_paths = modelloader.load_models(
35-
model_path,
36-
model_url,
37-
self.cmd_dir,
38-
download_name='codeformer-v0.1.0.pth',
29+
def load_net(self) -> torch.Module:
30+
for model_path in modelloader.load_models(
31+
model_path=self.model_path,
32+
model_url=model_url,
33+
command_path=self.model_path,
34+
download_name=model_download_name,
3935
ext_filter=['.pth'],
40-
)
41-
42-
if len(model_paths) != 0:
43-
ckpt_path = model_paths[0]
44-
else:
45-
print("Unable to load codeformer model.")
46-
return None, None
47-
net = modelloader.load_spandrel_model(ckpt_path, device=devices.device_codeformer)
48-
49-
if hasattr(retinaface, 'device'):
50-
retinaface.device = devices.device_codeformer
51-
52-
face_helper = FaceRestoreHelper(
53-
upscale_factor=1,
54-
face_size=512,
55-
crop_ratio=(1, 1),
56-
det_model='retinaface_resnet50',
57-
save_ext='png',
58-
use_parse=True,
59-
device=devices.device_codeformer,
60-
)
61-
62-
self.net = net
63-
self.face_helper = face_helper
64-
65-
def send_model_to(self, device):
66-
self.net.to(device)
67-
self.face_helper.face_det.to(device)
68-
self.face_helper.face_parse.to(device)
69-
70-
def restore(self, np_image, w=None):
71-
from torchvision.transforms.functional import normalize
72-
from basicsr.utils import img2tensor, tensor2img
73-
np_image = np_image[:, :, ::-1]
74-
75-
original_resolution = np_image.shape[0:2]
76-
77-
self.create_models()
78-
if self.net is None or self.face_helper is None:
79-
return np_image
80-
81-
self.send_model_to(devices.device_codeformer)
82-
83-
self.face_helper.clean_all()
84-
self.face_helper.read_image(np_image)
85-
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
86-
self.face_helper.align_warp_face()
87-
88-
for cropped_face in self.face_helper.cropped_faces:
89-
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
90-
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
91-
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
92-
93-
try:
94-
with torch.no_grad():
95-
res = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)
96-
if isinstance(res, tuple):
97-
output = res[0]
98-
else:
99-
output = res
100-
if not isinstance(res, torch.Tensor):
101-
raise TypeError(f"Expected torch.Tensor, got {type(res)}")
102-
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
103-
del output
104-
devices.torch_gc()
105-
except Exception:
106-
errors.report('Failed inference for CodeFormer', exc_info=True)
107-
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
108-
109-
restored_face = restored_face.astype('uint8')
110-
self.face_helper.add_restored_face(restored_face)
111-
112-
self.face_helper.get_inverse_affine(None)
113-
114-
restored_img = self.face_helper.paste_faces_to_input_image()
115-
restored_img = restored_img[:, :, ::-1]
36+
):
37+
return modelloader.load_spandrel_model(
38+
model_path,
39+
device=devices.device_codeformer,
40+
).model
41+
raise ValueError("No codeformer model found")
11642

117-
if original_resolution != restored_img.shape[0:2]:
118-
restored_img = cv2.resize(
119-
restored_img,
120-
(0, 0),
121-
fx=original_resolution[1]/restored_img.shape[1],
122-
fy=original_resolution[0]/restored_img.shape[0],
123-
interpolation=cv2.INTER_LINEAR,
124-
)
43+
def get_device(self):
44+
return devices.device_codeformer
12545

126-
self.face_helper.clean_all()
46+
def restore(self, np_image, w: float | None = None):
47+
if w is None:
48+
w = getattr(shared.opts, "code_former_weight", 0.5)
12749

128-
if shared.opts.face_restoration_unload:
129-
self.send_model_to(devices.cpu)
50+
def restore_face(cropped_face_t):
51+
assert self.net is not None
52+
return self.net(cropped_face_t, w=w, adain=True)[0]
13053

131-
return restored_img
54+
return self.restore_with_helper(np_image, restore_face)
13255

13356

134-
def setup_model(dirname):
135-
os.makedirs(model_path, exist_ok=True)
57+
def setup_model(dirname: str) -> None:
58+
global codeformer
13659
try:
137-
global codeformer
13860
codeformer = FaceRestorerCodeFormer(dirname)
13961
shared.face_restorers.append(codeformer)
14062
except Exception:

modules/face_restoration_utils.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import os
5+
from functools import cached_property
6+
from typing import TYPE_CHECKING, Callable
7+
8+
import cv2
9+
import numpy as np
10+
import torch
11+
12+
from modules import devices, errors, face_restoration, shared
13+
14+
if TYPE_CHECKING:
15+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
def create_face_helper(device) -> FaceRestoreHelper:
21+
from facexlib.detection import retinaface
22+
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
23+
if hasattr(retinaface, 'device'):
24+
retinaface.device = device
25+
return FaceRestoreHelper(
26+
upscale_factor=1,
27+
face_size=512,
28+
crop_ratio=(1, 1),
29+
det_model='retinaface_resnet50',
30+
save_ext='png',
31+
use_parse=True,
32+
device=device,
33+
)
34+
35+
36+
def restore_with_face_helper(
37+
np_image: np.ndarray,
38+
face_helper: FaceRestoreHelper,
39+
restore_face: Callable[[np.ndarray], np.ndarray],
40+
) -> np.ndarray:
41+
"""
42+
Find faces in the image using face_helper, restore them using restore_face, and paste them back into the image.
43+
44+
`restore_face` should take a cropped face image and return a restored face image.
45+
"""
46+
from basicsr.utils import img2tensor, tensor2img
47+
from torchvision.transforms.functional import normalize
48+
np_image = np_image[:, :, ::-1]
49+
original_resolution = np_image.shape[0:2]
50+
51+
try:
52+
logger.debug("Detecting faces...")
53+
face_helper.clean_all()
54+
face_helper.read_image(np_image)
55+
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
56+
face_helper.align_warp_face()
57+
logger.debug("Found %d faces, restoring", len(face_helper.cropped_faces))
58+
for cropped_face in face_helper.cropped_faces:
59+
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
60+
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
61+
cropped_face_t = cropped_face_t.unsqueeze(0).to(devices.device_codeformer)
62+
63+
try:
64+
with torch.no_grad():
65+
restored_face = tensor2img(
66+
restore_face(cropped_face_t),
67+
rgb2bgr=True,
68+
min_max=(-1, 1),
69+
)
70+
devices.torch_gc()
71+
except Exception:
72+
errors.report('Failed face-restoration inference', exc_info=True)
73+
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
74+
75+
restored_face = restored_face.astype('uint8')
76+
face_helper.add_restored_face(restored_face)
77+
78+
logger.debug("Merging restored faces into image")
79+
face_helper.get_inverse_affine(None)
80+
img = face_helper.paste_faces_to_input_image()
81+
img = img[:, :, ::-1]
82+
if original_resolution != img.shape[0:2]:
83+
img = cv2.resize(
84+
img,
85+
(0, 0),
86+
fx=original_resolution[1] / img.shape[1],
87+
fy=original_resolution[0] / img.shape[0],
88+
interpolation=cv2.INTER_LINEAR,
89+
)
90+
logger.debug("Face restoration complete")
91+
finally:
92+
face_helper.clean_all()
93+
return img
94+
95+
96+
class CommonFaceRestoration(face_restoration.FaceRestoration):
97+
net: torch.Module | None
98+
model_url: str
99+
model_download_name: str
100+
101+
def __init__(self, model_path: str):
102+
super().__init__()
103+
self.net = None
104+
self.model_path = model_path
105+
os.makedirs(model_path, exist_ok=True)
106+
107+
@cached_property
108+
def face_helper(self) -> FaceRestoreHelper:
109+
return create_face_helper(self.get_device())
110+
111+
def send_model_to(self, device):
112+
if self.net:
113+
logger.debug("Sending %s to %s", self.net, device)
114+
self.net.to(device)
115+
if self.face_helper:
116+
logger.debug("Sending face helper to %s", device)
117+
self.face_helper.face_det.to(device)
118+
self.face_helper.face_parse.to(device)
119+
120+
def get_device(self):
121+
raise NotImplementedError("get_device must be implemented by subclasses")
122+
123+
def load_net(self) -> torch.Module:
124+
raise NotImplementedError("load_net must be implemented by subclasses")
125+
126+
def restore_with_helper(
127+
self,
128+
np_image: np.ndarray,
129+
restore_face: Callable[[np.ndarray], np.ndarray],
130+
) -> np.ndarray:
131+
try:
132+
if self.net is None:
133+
self.net = self.load_net()
134+
except Exception:
135+
logger.warning("Unable to load face-restoration model", exc_info=True)
136+
return np_image
137+
138+
try:
139+
self.send_model_to(self.get_device())
140+
return restore_with_face_helper(np_image, self.face_helper, restore_face)
141+
finally:
142+
if shared.opts.face_restoration_unload:
143+
self.send_model_to(devices.cpu)
144+
145+
146+
def patch_facexlib(dirname: str) -> None:
147+
import facexlib.detection
148+
import facexlib.parsing
149+
150+
det_facex_load_file_from_url = facexlib.detection.load_file_from_url
151+
par_facex_load_file_from_url = facexlib.parsing.load_file_from_url
152+
153+
def update_kwargs(kwargs):
154+
return dict(kwargs, save_dir=dirname, model_dir=None)
155+
156+
def facex_load_file_from_url(**kwargs):
157+
return det_facex_load_file_from_url(**update_kwargs(kwargs))
158+
159+
def facex_load_file_from_url2(**kwargs):
160+
return par_facex_load_file_from_url(**update_kwargs(kwargs))
161+
162+
facexlib.detection.load_file_from_url = facex_load_file_from_url
163+
facexlib.parsing.load_file_from_url = facex_load_file_from_url2

0 commit comments

Comments (0)