
Commit 78514e0

Add files via upload
1 parent 82a973c commit 78514e0

File tree: 2 files changed, +541 −0 lines changed

Lines changed: 395 additions & 0 deletions
@@ -0,0 +1,395 @@
# simple_karras_exponential_scheduler.py
import os
import random
import logging
from datetime import datetime

import torch
import yaml
from k_diffusion.sampling import get_sigmas_karras, get_sigmas_exponential
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

class CustomLogger:
    def __init__(self, log_name, print_to_console=False, debug_enabled=False):
        self.print_to_console = print_to_console  # prints to console
        self.debug_enabled = debug_enabled  # logs debug messages

        # Create folders for generation info and error logs
        gen_log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'simple_kes_generation')
        error_log_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'simple_kes_error')

        os.makedirs(gen_log_dir, exist_ok=True)
        os.makedirs(error_log_dir, exist_ok=True)

        # Get current time in HH-MM-SS format
        current_time = datetime.now().strftime('%H-%M-%S')

        # Create file paths for the log files
        gen_log_file_path = os.path.join(gen_log_dir, f'{current_time}.log')
        error_log_file_path = os.path.join(error_log_dir, f'{current_time}.log')

        # Set up generation logger
        #self.gen_logger = logging.getLogger(f'{log_name}_generation')
        self.gen_logger = logging.getLogger('simple_kes_generation')
        self.gen_logger.setLevel(logging.DEBUG)
        self._setup_file_handler(self.gen_logger, gen_log_file_path)

        # Set up error logger
        self.error_logger = logging.getLogger(f'{log_name}_error')
        self.error_logger.setLevel(logging.ERROR)
        self._setup_file_handler(self.error_logger, error_log_file_path)

        # Prevent log propagation to the root logger (important to avoid accidental console logging)
        self.gen_logger.propagate = False
        self.error_logger.propagate = False

        # Optionally print to console
        if self.print_to_console:
            self._setup_console_handler(self.gen_logger)
            self._setup_console_handler(self.error_logger)

    def _setup_file_handler(self, logger, file_path):
        """Set up a file handler for logging to a file."""
        file_handler = logging.FileHandler(file_path, mode='a')
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    def _setup_console_handler(self, logger):
        """Optionally set up a console handler for logging to the console."""
        console_handler = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    def log_debug(self, message):
        """Log a debug message."""
        if self.debug_enabled:
            self.gen_logger.debug(message)

    def log_info(self, message):
        """Log an info message."""
        self.gen_logger.info(message)
    info = log_info  # alias so custom_logger.info(...) also works

    def log_error(self, message):
        """Log an error message."""
        self.error_logger.error(message)

    def enable_console_logging(self):
        """Enable console logging dynamically."""
        if not any(isinstance(handler, logging.StreamHandler) for handler in self.gen_logger.handlers):
            self._setup_console_handler(self.gen_logger)

        if not any(isinstance(handler, logging.StreamHandler) for handler in self.error_logger.handlers):
            self._setup_console_handler(self.error_logger)


# Usage example
custom_logger = CustomLogger('simple_kes', print_to_console=False, debug_enabled=True)

# Logging examples
#custom_logger.log_debug("Debug message: Using default sigma_min: 0.01")
#custom_logger.info("Info message: Step completed successfully.")
#custom_logger.log_error("Error message: Something went wrong!")
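# Additional example (optional, uses the method defined above): console output can also
# be switched on later at runtime.
#custom_logger.enable_console_logging()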

class ConfigManagerYaml:
    def __init__(self, config_path):
        self.config_path = config_path
        self.config_data = self.load_config()  # Initialize config_data here

    def load_config(self):
        try:
            with open(self.config_path, 'r') as f:
                user_config = yaml.safe_load(f)
            return user_config or {}  # guard against an empty YAML file returning None
        except FileNotFoundError:
            print(f"Config file not found: {self.config_path}. Using empty config.")
            return {}
        except yaml.YAMLError as e:
            print(f"Error loading config file: {e}")
            return {}

# ConfigWatcher monitors the config file and reloads it while the program is running,
# so you can keep working without restarting the program.
class ConfigWatcher(FileSystemEventHandler):
    def __init__(self, config_manager, config_path):
        self.config_manager = config_manager
        self.config_path = config_path

    def on_modified(self, event):
        if event.src_path == self.config_path:
            logging.info(f"Config file {self.config_path} modified. Reloading config.")
            self.config_manager.config_data = self.config_manager.load_config()


def start_config_watcher(config_manager, config_path):
    event_handler = ConfigWatcher(config_manager, config_path)
    observer = Observer()
    observer.schedule(event_handler, os.path.dirname(config_path), recursive=False)
    observer.start()
    return observer

"""
Scheduler function that blends sigma sequences using Karras and Exponential methods with adaptive parameters.

Parameters are dynamically updated if the config file changes during execution.
"""
# If a user config is provided, the default config is updated with the user's values
config_path = "modules/simple_kes_scheduler.yaml"
config_manager = ConfigManagerYaml(config_path)

# Start watching for config changes
observer = start_config_watcher(config_manager, config_path)

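# A minimal sketch (illustrative, not shipped with this commit) of the YAML layout that
# modules/simple_kes_scheduler.yaml is expected to follow, based on the keys read below.
# Each parameter, its *_rand flag, and the *_rand_min / *_rand_max bounds are optional;
# anything omitted falls back to the defaults defined in this file.
#
#   scheduler:
#     randomize: false
#     sigma_min: 0.01
#     sigma_max: 50
#     sigma_min_rand: false
#     sigma_min_rand_min: 0.001
#     sigma_min_rand_max: 0.05
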
def get_random_or_default(config, key_prefix, default_value, global_randomize):
    """Helper function to either randomize a value based on conditions or return the default."""
    # Randomize if the global flag or the individual per-parameter flag is on
    randomize_flag = global_randomize or config['scheduler'].get(f'{key_prefix}_rand', False)

    if randomize_flag:
        # Use the configured min/max for randomization, falling back to +/-20% of the default
        rand_min = config['scheduler'].get(f'{key_prefix}_rand_min', default_value * 0.8)
        rand_max = config['scheduler'].get(f'{key_prefix}_rand_max', default_value * 1.2)
        value = random.uniform(rand_min, rand_max)
        custom_logger.info(f"Randomized {key_prefix}: {value}")
    else:
        value = default_value
        custom_logger.info(f"Using default {key_prefix}: {value}")

    return value

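# A small illustrative sketch (hypothetical values, not part of the commit): with
# global_randomize=False but sigma_min_rand set to true in the config, only sigma_min is
# drawn from [sigma_min_rand_min, sigma_min_rand_max]; every other parameter keeps its
# default or config-defined value.
#
#   example_cfg = {'scheduler': {'sigma_min_rand': True,
#                                'sigma_min_rand_min': 0.005,
#                                'sigma_min_rand_max': 0.02}}
#   sigma_min = get_random_or_default(example_cfg, 'sigma_min', 0.01, global_randomize=False)
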

def simple_karras_exponential_scheduler(
    n, device, sigma_min=0.01, sigma_max=50, start_blend=0.1, end_blend=0.5,
    sharpness=0.95, early_stopping_threshold=0.01, update_interval=10, initial_step_size=0.9,
    final_step_size=0.2, initial_noise_scale=1.25, final_noise_scale=0.8, smooth_blend_factor=11,
    step_size_factor=0.8, noise_scale_factor=0.9, randomize=False, user_config=None
):
    """
    Scheduler function that blends sigma sequences using Karras and Exponential methods with adaptive parameters.

    Parameters:
        n (int): Number of steps.
        device (torch.device): The device on which to perform computations (e.g., 'cuda' or 'cpu').
        sigma_min (float): Minimum sigma value.
        sigma_max (float): Maximum sigma value.
        start_blend (float): Initial blend factor for dynamic blending.
        end_blend (float): Final blend factor for dynamic blending.
        sharpness (float): Sharpening factor to be applied adaptively.
        early_stopping_threshold (float): Threshold to trigger early stopping.
        update_interval (int): Interval at which to update blend factors.
        initial_step_size (float): Initial step size for adaptive step size calculation.
        final_step_size (float): Final step size for adaptive step size calculation.
        initial_noise_scale (float): Initial noise scale factor.
        final_noise_scale (float): Final noise scale factor.
        smooth_blend_factor (float): Scaling factor for the sigmoid that smooths the blend.
        step_size_factor (float): Adjust to compensate for over-smoothing.
        noise_scale_factor (float): Adjust to provide more variation.
        randomize (bool): Randomize all parameters within their configured ranges.
        user_config (dict): Optional user configuration that overrides the defaults.

    Returns:
        torch.Tensor: A tensor of blended sigma values.
    """
    #debug_log("Entered simple_karras_exponential_scheduler function")
    default_config = {
        "debug": False,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "sigma_min": 0.01,
        "sigma_max": 50,  # if sigma_max is too low the resulting picture may be undesirable
        "start_blend": 0.1,
        "end_blend": 0.5,
        "sharpness": 0.95,
        "early_stopping_threshold": 0.01,
        "update_interval": 10,
        "initial_step_size": 0.9,
        "final_step_size": 0.2,
        "initial_noise_scale": 1.25,
        "final_noise_scale": 0.8,
        "smooth_blend_factor": 11,
        "step_size_factor": 0.8,  # suggested value to avoid oversmoothing
        "noise_scale_factor": 0.9,  # suggested value to add more variation
        "randomize": False,
        "sigma_min_rand": False,
        "sigma_min_rand_min": 0.001,
        "sigma_min_rand_max": 0.05,
        "sigma_max_rand": False,
        "sigma_max_rand_min": 0.05,
        "sigma_max_rand_max": 0.20,
        "start_blend_rand": False,
        "start_blend_rand_min": 0.05,
        "start_blend_rand_max": 0.2,
        "end_blend_rand": False,
        "end_blend_rand_min": 0.4,
        "end_blend_rand_max": 0.6,
        "sharpness_rand": False,
        "sharpness_rand_min": 0.85,
        "sharpness_rand_max": 1.0,
        "early_stopping_rand": False,
        "early_stopping_rand_min": 0.001,
        "early_stopping_rand_max": 0.02,
        "update_interval_rand": False,
        "update_interval_rand_min": 5,
        "update_interval_rand_max": 10,
        "initial_step_rand": False,
        "initial_step_rand_min": 0.7,
        "initial_step_rand_max": 1.0,
        "final_step_rand": False,
        "final_step_rand_min": 0.1,
        "final_step_rand_max": 0.3,
        "initial_noise_rand": False,
        "initial_noise_rand_min": 1.0,
        "initial_noise_rand_max": 1.5,
        "final_noise_rand": False,
        "final_noise_rand_min": 0.6,
        "final_noise_rand_max": 1.0,
        "smooth_blend_factor_rand": False,
        "smooth_blend_factor_rand_min": 6,
        "smooth_blend_factor_rand_max": 11,
        "step_size_factor_rand": False,
        "step_size_factor_rand_min": 0.65,
        "step_size_factor_rand_max": 0.85,
        "noise_scale_factor_rand": False,
        "noise_scale_factor_rand_min": 0.75,
        "noise_scale_factor_rand_max": 0.95,
    }
custom_logger.info(f"Default Config create {default_config}")
275+
for key, value in default_config.items():
276+
custom_logger.info(f"Default Config - {key}: {value}")
277+
278+
#config = config_manager.load_config()
279+
config = config_manager.load_config().get('scheduler', {})
280+
global_randomize = config.get('randomize', randomize)
281+
282+
custom_logger.info(f"Config loaded from yaml {config}")
283+
for key, value in config.items():
284+
custom_logger.info(f"Config - {key}: {value}")
285+
286+
# Check if the scheduler config is available in the YAML file
287+
scheduler_config = config.get('scheduler', {})
288+
if not scheduler_config:
289+
raise ValueError("Scheduler configuration is missing from the config file.")
290+
291+
for key, value in scheduler_config.items():
292+
custom_logger.info(f"Scheduler Config before update - {key}: {value}")
293+
for key, value in scheduler_config.items():
294+
if key in default_config:
295+
default_config[key] = value
296+
custom_logger.info(f"Overriding default config: {key} = {value}")
297+
else:
298+
debug.log(f"Ignoring unknown config option: {key}")
299+
# Now using default_config, updated with valid YAML values
300+
custom_logger.info(f"Final Config after overriding: {default_config}")
301+
    # Read the global randomization flag from the scheduler config
    randomize = scheduler_config.get('randomize', False)

    # Resolve each parameter with get_random_or_default.
    # If the global randomize flag is false, each parameter's own *_rand flag is checked; when that
    # flag is true, only that parameter is randomized while the others keep their default or
    # config-defined values.
    sigma_min = get_random_or_default(config, 'sigma_min', sigma_min, global_randomize)
    sigma_max = get_random_or_default(config, 'sigma_max', sigma_max, global_randomize)
    start_blend = get_random_or_default(config, 'start_blend', start_blend, global_randomize)
    end_blend = get_random_or_default(config, 'end_blend', end_blend, global_randomize)
    sharpness = get_random_or_default(config, 'sharpness', sharpness, global_randomize)
    early_stopping_threshold = get_random_or_default(config, 'early_stopping', early_stopping_threshold, global_randomize)
    update_interval = get_random_or_default(config, 'update_interval', update_interval, global_randomize)
    initial_step_size = get_random_or_default(config, 'initial_step', initial_step_size, global_randomize)
    final_step_size = get_random_or_default(config, 'final_step', final_step_size, global_randomize)
    initial_noise_scale = get_random_or_default(config, 'initial_noise', initial_noise_scale, global_randomize)
    final_noise_scale = get_random_or_default(config, 'final_noise', final_noise_scale, global_randomize)
    smooth_blend_factor = get_random_or_default(config, 'smooth_blend_factor', smooth_blend_factor, global_randomize)
    step_size_factor = get_random_or_default(config, 'step_size_factor', step_size_factor, global_randomize)
    noise_scale_factor = get_random_or_default(config, 'noise_scale_factor', noise_scale_factor, global_randomize)

    # Expand sigma_max slightly to account for smoother transitions
    sigma_max = sigma_max * 1.1
    custom_logger.info(f"Using device: {device}")

    # Generate sigma sequences using Karras and Exponential methods
    try:
        sigmas_karras = get_sigmas_karras(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
        sigmas_exponential = get_sigmas_exponential(n=n, sigma_min=sigma_min, sigma_max=sigma_max, device=device)
        config = config_manager.config_data.get('scheduler', {})

        # Match lengths of the sigma sequences
        target_length = min(len(sigmas_karras), len(sigmas_exponential))
        sigmas_karras = sigmas_karras[:target_length]
        sigmas_exponential = sigmas_exponential[:target_length]

        custom_logger.info(f"Generated sigma sequences. Karras: {sigmas_karras}, Exponential: {sigmas_exponential}")
        if sigmas_karras is None:
            raise ValueError(f"Sigmas Karras: {sigmas_karras} Failed to generate or assign sigmas correctly.")
        if sigmas_exponential is None:
            raise ValueError(f"Sigmas Exponential: {sigmas_exponential} Failed to generate or assign sigmas correctly.")
        #sigmas_karras = torch.zeros(n).to(device)
        #sigmas_exponential = torch.zeros(n).to(device)
    except Exception as e:
        custom_logger.log_error(f"Error generating sigmas: {e}")
        raise
    finally:
        # Stop the config watcher when done
        observer.stop()
        observer.join()

    # Define progress and initialize the blended sigma tensor
    progress = torch.linspace(0, 1, len(sigmas_karras)).to(device)
    custom_logger.info(f"Progress created {progress}")
    custom_logger.info(f"Progress Using device: {device}")

    sigs = torch.zeros_like(sigmas_karras).to(device)
    custom_logger.info(f"Sigs created {sigs}")
    custom_logger.info(f"Sigs Using device: {device}")

    # Iterate through each step, dynamically adjusting the blend factor, step size, and noise scaling
    for i in range(len(sigmas_karras)):
        # Adaptive step size and blend factor calculations
        step_size = initial_step_size * (1 - progress[i]) + final_step_size * progress[i] * step_size_factor  # 0.8 default, adjusted to avoid over-smoothing
        custom_logger.info(f"Step_size created {step_size}")
        dynamic_blend_factor = start_blend * (1 - progress[i]) + end_blend * progress[i]
        custom_logger.info(f"Dynamic_blend_factor created {dynamic_blend_factor}")
        noise_scale = initial_noise_scale * (1 - progress[i]) + final_noise_scale * progress[i] * noise_scale_factor  # 0.9 default, adjusted to keep more variation
        custom_logger.info(f"noise_scale created {noise_scale}")

        # Calculate smooth blending between the two sigma sequences
        smooth_blend = torch.sigmoid((dynamic_blend_factor - 0.5) * smooth_blend_factor)  # larger scaling factor smooths transitions more
        custom_logger.info(f"smooth_blend created {smooth_blend}")

        # Compute the blended sigma value
        blended_sigma = sigmas_karras[i] * (1 - smooth_blend) + sigmas_exponential[i] * smooth_blend
        custom_logger.info(f"blended_sigma created {blended_sigma}")

        # Apply step size and noise scaling
        sigs[i] = blended_sigma * step_size * noise_scale

    # Optional: adaptive sharpening based on sigma values
    sharpen_mask = torch.where(sigs < sigma_min * 1.5, sharpness, 1.0).to(device)
    custom_logger.info(f"sharpen_mask created {sharpen_mask} with device {device}")
    sigs = sigs * sharpen_mask

    # Early stopping criteria based on sigma convergence
    change = torch.abs(sigs[1:] - sigs[:-1])
    if torch.all(change < early_stopping_threshold):
        custom_logger.info("Early stopping criteria met.")
        return sigs[:len(change) + 1].to(device)

    if torch.isnan(sigs).any() or torch.isinf(sigs).any():
        raise ValueError("Invalid sigma values detected (NaN or Inf).")

    return sigs.to(device)
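
A minimal usage sketch (not part of the commit), assuming the file is importable as a module and that modules/simple_kes_scheduler.yaml is present with a scheduler: section (with an empty or missing config the function raises a ValueError):

import torch
from simple_karras_exponential_scheduler import simple_karras_exponential_scheduler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 30 steps on whichever device is available; all other parameters keep their defaults
sigmas = simple_karras_exponential_scheduler(n=30, device=device)
print(sigmas.shape, sigmas[:5])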
