Skip to content

Commit cd12c0e

Browse files
Merge pull request #14425 from akx/spandrel
Use Spandrel for upscaling and face restoration architectures
2 parents 05230c0 + 4ad0c0c commit cd12c0e

29 files changed

+609
-3927
lines changed

.github/workflows/run_tests.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ jobs:
2020
cache-dependency-path: |
2121
**/requirements*txt
2222
launch.py
23+
- name: Cache models
24+
id: cache-models
25+
uses: actions/cache@v3
26+
with:
27+
path: models
28+
key: "2023-12-30"
2329
- name: Install test dependencies
2430
run: pip install wait-for-it -r requirements-test.txt
2531
env:
@@ -33,6 +39,8 @@ jobs:
3339
TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu
3440
WEBUI_LAUNCH_LIVE_OUTPUT: "1"
3541
PYTHONUNBUFFERED: "1"
42+
- name: Print installed packages
43+
run: pip freeze
3644
- name: Start test server
3745
run: >
3846
python -m coverage run

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ notification.mp3
3737
/node_modules
3838
/package-lock.json
3939
/.coverage*
40+
/test/test_outputs

extensions-builtin/ScuNET/scripts/scunet_model.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77

88
import modules.upscaler
99
from modules import devices, modelloader, script_callbacks, errors
10-
from scunet_model_arch import SCUNet
1110

12-
from modules.modelloader import load_file_from_url
1311
from modules.shared import opts
1412

1513

@@ -120,17 +118,10 @@ def load_model(self, path: str):
120118
device = devices.get_device_for('scunet')
121119
if path.startswith("http"):
122120
# TODO: this doesn't use `path` at all?
123-
filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
121+
filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
124122
else:
125123
filename = path
126-
model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
127-
model.load_state_dict(torch.load(filename), strict=True)
128-
model.eval()
129-
for _, v in model.named_parameters():
130-
v.requires_grad = False
131-
model = model.to(device)
132-
133-
return model
124+
return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet')
134125

135126

136127
def on_ui_settings():

extensions-builtin/ScuNET/scunet_model_arch.py

Lines changed: 0 additions & 268 deletions
This file was deleted.

0 commit comments

Comments
 (0)