2
2
import json
3
3
import os
4
4
import shutil
5
+ import subprocess
6
+ import socket
7
+ import sys
8
+ import time
9
+ import webbrowser
5
10
from datetime import datetime
6
11
from multiprocessing import cpu_count
7
12
10
15
11
16
from common .constants import LATEST_VERSION
12
17
from common .log import logger
18
+ from common .stdout_wrapper import SAFE_STDOUT
13
19
from common .subprocess_utils import run_script_with_log , second_elem_of
14
20
15
21
logger_handler = None
22
+ tensorboard_executed = False
16
23
17
24
# Get path settings
18
25
with open (os .path .join ("configs" , "paths.yml" ), "r" , encoding = "utf-8" ) as f :
@@ -316,6 +323,46 @@ def train(model_name, skip_style=False, use_jp_extra=True, speedup=False):
316
323
return True , "Success: 学習が完了しました"
317
324
318
325
326
+ def wait_for_tensorboard (port = 6006 , timeout = 10 ):
327
+ start_time = time .time ()
328
+ while True :
329
+ try :
330
+ with socket .create_connection (("localhost" , port ), timeout = 1 ):
331
+ return True # ポートが開いている場合
332
+ except OSError :
333
+ pass # ポートがまだ開いていない場合
334
+
335
+ if time .time () - start_time > timeout :
336
+ return False # タイムアウト
337
+
338
+ time .sleep (0.1 )
339
+
340
+
341
+ def run_tensorboard (model_name ):
342
+ global tensorboard_executed
343
+ if not tensorboard_executed :
344
+ python = sys .executable
345
+ tensorboard_cmd = [
346
+ python ,
347
+ "-m" ,
348
+ "tensorboard.main" ,
349
+ "--logdir" ,
350
+ f"Data/{ model_name } /models" ,
351
+ ]
352
+ subprocess .Popen (
353
+ tensorboard_cmd ,
354
+ stdout = SAFE_STDOUT , # type: ignore
355
+ stderr = SAFE_STDOUT , # type: ignore
356
+ )
357
+ yield gr .Button ("起動中…" )
358
+ if wait_for_tensorboard ():
359
+ tensorboard_executed = True
360
+ else :
361
+ logger .error ("Tensorboard did not start in the expected time." )
362
+ webbrowser .open ("http://localhost:6006" )
363
+ yield gr .Button ("Tensorboardを開く" )
364
+
365
+
319
366
initial_md = f"""
320
367
# Style-Bert-VITS2 ver { LATEST_VERSION } 学習用WebUI
321
368
@@ -369,7 +416,7 @@ def train(model_name, skip_style=False, use_jp_extra=True, speedup=False):
369
416
"""
370
417
371
418
if __name__ == "__main__" :
372
- with gr .Blocks (theme = "NoCrypt/miku" ) as app :
419
+ with gr .Blocks (theme = "NoCrypt/miku" ). queue () as app :
373
420
gr .Markdown (initial_md )
374
421
with gr .Accordion (label = "データの前準備" , open = False ):
375
422
gr .Markdown (prepare_md )
@@ -548,7 +595,7 @@ def train(model_name, skip_style=False, use_jp_extra=True, speedup=False):
548
595
style_gen_btn = gr .Button (value = "実行" , variant = "primary" )
549
596
info_style = gr .Textbox (label = "状況" )
550
597
gr .Markdown ("## 学習" )
551
- with gr .Row (variant = "panel" ):
598
+ with gr .Row ():
552
599
skip_style = gr .Checkbox (
553
600
label = "スタイルファイルの生成をスキップする" ,
554
601
info = "学習再開の場合の場合はチェックしてください" ,
@@ -564,7 +611,8 @@ def train(model_name, skip_style=False, use_jp_extra=True, speedup=False):
564
611
visible = False , # Experimental
565
612
)
566
613
train_btn = gr .Button (value = "学習を開始する" , variant = "primary" )
567
- info_train = gr .Textbox (label = "状況" )
614
+ tensorboard_btn = gr .Button (value = "Tensorboardを開く" )
615
+ info_train = gr .Textbox (label = "状況" )
568
616
569
617
preprocess_button .click (
570
618
second_elem_of (preprocess_all ),
@@ -635,6 +683,10 @@ def train(model_name, skip_style=False, use_jp_extra=True, speedup=False):
635
683
inputs = [model_name , skip_style , use_jp_extra_train , speedup ],
636
684
outputs = [info_train ],
637
685
)
686
+ tensorboard_btn .click (
687
+ run_tensorboard , inputs = [model_name ], outputs = [tensorboard_btn ]
688
+ )
689
+
638
690
use_jp_extra .change (
639
691
lambda x : gr .Checkbox (value = x ),
640
692
inputs = [use_jp_extra ],
0 commit comments