Skip to content

Commit 29d038a

Browse files
committed
Merge remote-tracking branch 'remotes/bluelovers/patch-3' into dev-local-202406
* remotes/bluelovers/patch-3: chore(js): avoid lots of `Wake Lock is not supported.`; Fix DAT models download (AUTOMATIC1111#16302)
2 parents c87bd49 + ee0ad5c commit 29d038a

File tree

5 files changed

+99
-29
lines changed

5 files changed

+99
-29
lines changed

javascript/progressbar.js

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,12 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
7979
var wakeLock = null;
8080

8181
var requestWakeLock = async function() {
82-
if (!opts.prevent_screen_sleep_during_generation || wakeLock) return;
82+
if (!opts.prevent_screen_sleep_during_generation || wakeLock !== null) return;
8383
try {
8484
wakeLock = await navigator.wakeLock.request('screen');
8585
} catch (err) {
8686
console.error('Wake Lock is not supported.');
87+
wakeLock = false;
8788
}
8889
};
8990

modules/dat_model.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,18 @@ def load_model(self, path):
4949
scaler.local_data_path = modelloader.load_file_from_url(
5050
scaler.data_path,
5151
model_dir=self.model_download_path,
52+
hash_prefix=scaler.sha256,
5253
)
54+
55+
if os.path.getsize(scaler.local_data_path) < 200:
56+
# Re-download if the file is too small, probably an LFS pointer
57+
scaler.local_data_path = modelloader.load_file_from_url(
58+
scaler.data_path,
59+
model_dir=self.model_download_path,
60+
hash_prefix=scaler.sha256,
61+
re_download=True,
62+
)
63+
5364
if not os.path.exists(scaler.local_data_path):
5465
raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}")
5566
return scaler
@@ -60,20 +71,23 @@ def get_dat_models(scaler):
6071
return [
6172
UpscalerData(
6273
name="DAT x2",
63-
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth",
74+
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x2.pth",
6475
scale=2,
6576
upscaler=scaler,
77+
sha256='7760aa96e4ee77e29d4f89c3a4486200042e019461fdb8aa286f49aa00b89b51',
6678
),
6779
UpscalerData(
6880
name="DAT x3",
69-
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth",
81+
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x3.pth",
7082
scale=3,
7183
upscaler=scaler,
84+
sha256='581973e02c06f90d4eb90acf743ec9604f56f3c2c6f9e1e2c2b38ded1f80d197',
7285
),
7386
UpscalerData(
7487
name="DAT x4",
75-
path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth",
88+
path="https://huggingface.co/w-e-w/DAT/resolve/main/experiments/pretrained_models/DAT/DAT_x4.pth",
7689
scale=4,
7790
upscaler=scaler,
91+
sha256='391a6ce69899dff5ea3214557e9d585608254579217169faf3d4c353caff049e',
7892
),
7993
]

modules/modelloader.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,37 +10,14 @@
1010

1111
from modules import shared
1212
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
13+
from modules.util import load_file_from_url # noqa, backwards compatibility
1314

1415
if TYPE_CHECKING:
1516
import spandrel
1617

1718
logger = logging.getLogger(__name__)
1819

1920

20-
def load_file_from_url(
21-
url: str,
22-
*,
23-
model_dir: str,
24-
progress: bool = True,
25-
file_name: str | None = None,
26-
hash_prefix: str | None = None,
27-
) -> str:
28-
"""Download a file from `url` into `model_dir`, using the file present if possible.
29-
30-
Returns the path to the downloaded file.
31-
"""
32-
os.makedirs(model_dir, exist_ok=True)
33-
if not file_name:
34-
parts = urlparse(url)
35-
file_name = os.path.basename(parts.path)
36-
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
37-
if not os.path.exists(cached_file):
38-
print(f'Downloading: "{url}" to {cached_file}\n')
39-
from torch.hub import download_url_to_file
40-
download_url_to_file(url, cached_file, progress=progress, hash_prefix=hash_prefix)
41-
return cached_file
42-
43-
4421
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None, hash_prefix=None) -> list:
4522
"""
4623
A one-and done loader to try finding the desired models in specified directories.

modules/upscaler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,14 @@ class UpscalerData:
9393
scaler: Upscaler = None
9494
model: None
9595

96-
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
96+
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None, sha256: str = None):
9797
self.name = name
9898
self.data_path = path
9999
self.local_data_path = path
100100
self.scaler = upscaler
101101
self.scale = scale
102102
self.model = model
103+
self.sha256 = sha256
103104

104105
def __repr__(self):
105106
return f"<UpscalerData name={self.name} path={self.data_path} scale={self.scale}>"

modules/util.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,80 @@ def open_folder(path):
211211
subprocess.Popen(["explorer.exe", subprocess.check_output(["wslpath", "-w", path])])
212212
else:
213213
subprocess.Popen(["xdg-open", path])
214+
215+
216+
def load_file_from_url(
217+
url: str,
218+
*,
219+
model_dir: str,
220+
progress: bool = True,
221+
file_name: str | None = None,
222+
hash_prefix: str | None = None,
223+
re_download: bool = False,
224+
) -> str:
225+
"""Download a file from `url` into `model_dir`, using the file present if possible.
226+
Returns the path to the downloaded file.
227+
228+
file_name: if specified, it will be used as the filename, otherwise the filename will be extracted from the url.
229+
file is downloaded to {file_name}.tmp then moved to the final location after download is complete.
230+
hash_prefix: sha256 hex string, if provided, the hash of the downloaded file will be checked against this prefix.
231+
if the hash does not match, the temporary file is deleted and a ValueError is raised.
232+
re_download: forcibly re-download the file even if it already exists.
233+
"""
234+
from urllib.parse import urlparse
235+
import requests
236+
try:
237+
from tqdm import tqdm
238+
except ImportError:
239+
class tqdm:
240+
def __init__(self, *args, **kwargs):
241+
pass
242+
243+
def update(self, n=1, *args, **kwargs):
244+
pass
245+
246+
def __enter__(self):
247+
return self
248+
249+
def __exit__(self, exc_type, exc_val, exc_tb):
250+
pass
251+
252+
if not file_name:
253+
parts = urlparse(url)
254+
file_name = os.path.basename(parts.path)
255+
256+
cached_file = os.path.abspath(os.path.join(model_dir, file_name))
257+
258+
if re_download or not os.path.exists(cached_file):
259+
os.makedirs(model_dir, exist_ok=True)
260+
temp_file = os.path.join(model_dir, f"{file_name}.tmp")
261+
print(f'\nDownloading: "{url}" to {cached_file}')
262+
response = requests.get(url, stream=True)
263+
response.raise_for_status()
264+
total_size = int(response.headers.get('content-length', 0))
265+
with tqdm(total=total_size, unit='B', unit_scale=True, desc=file_name, disable=not progress) as progress_bar:
266+
with open(temp_file, 'wb') as file:
267+
for chunk in response.iter_content(chunk_size=1024):
268+
if chunk:
269+
file.write(chunk)
270+
progress_bar.update(len(chunk))
271+
272+
if hash_prefix and not compare_sha256(temp_file, hash_prefix):
273+
print(f"Hash mismatch for {temp_file}. Deleting the temporary file.")
274+
os.remove(temp_file)
275+
raise ValueError(f"File hash does not match the expected hash prefix {hash_prefix}!")
276+
277+
os.rename(temp_file, cached_file)
278+
return cached_file
279+
280+
281+
def compare_sha256(file_path: str, hash_prefix: str) -> bool:
282+
"""Check if the SHA256 hash of the file matches the given prefix."""
283+
import hashlib
284+
hash_sha256 = hashlib.sha256()
285+
blksize = 1024 * 1024
286+
287+
with open(file_path, "rb") as f:
288+
for chunk in iter(lambda: f.read(blksize), b""):
289+
hash_sha256.update(chunk)
290+
return hash_sha256.hexdigest().startswith(hash_prefix.strip().lower())

0 commit comments

Comments (0)