33from os import getenv
44from typing import Optional , Union
55
6+ from diffusion_service import TASK_MAPPING
67from dotenv import load_dotenv
78from loguru import logger
89from pydantic import Field , validator
@@ -25,6 +26,7 @@ class Settings:
2526 # Model Configuration
2627 max_batch_size : int
2728 max_wait : float
29+ task : str
2830 model_name : str
2931 model_precision : str
3032 n_steps : int
@@ -70,6 +72,13 @@ def model_precision_must_be_valid(cls, value: str):
7072 raise ValueError ("model_precision must be either `fp16`, `fp32` or `bf16`." )
7173 return value
7274
75+ @validator ("task" )
76+ def task_must_be_valid (cls , value : str ):
77+ """Check that the task is valid."""
78+ if value not in TASK_MAPPING .keys ():
79+ raise ValueError (f"task must be one of { list (TASK_MAPPING .keys ())} ." )
80+ return value
81+
7382 def __post_init__ (self ):
7483 """Post init hook."""
7584 self .using_s3 = all (
@@ -90,6 +99,7 @@ def __post_init__(self):
9099load_dotenv ()
91100
92101settings = Settings (
102+ # General Configuration
93103 project_name = getenv ("PROJECT_NAME" , "PicAIsso" ),
94104 version = getenv ("VERSION" , "1.0.0" ),
95105 description = getenv (
@@ -98,15 +108,19 @@ def __post_init__(self):
98108 ),
99109 api_prefix = getenv ("API_PREFIX" , "/api/v1" ),
100110 debug = getenv ("DEBUG" , True ),
111+ # Authentication
101112 username = getenv ("USERNAME" , None ),
102113 password = getenv ("PASSWORD" , None ),
103114 openssl_key = getenv ("OPENSSL_KEY" , None ),
104115 algorithm = getenv ("ALGORITHM" , "HS256" ),
116+ # Model Configuration
105117 max_batch_size = getenv ("MAX_BATCH_SIZE" , 1 ),
106118 max_wait = getenv ("MAX_WAIT" , 0.5 ),
119+ task = getenv ("TASK" , "text_to_image" ),
107120 model_name = getenv ("MODEL_NAME" , "prompthero/openjourney" ),
108121 model_precision = getenv ("MODEL_PRECISION" , "fp16" ),
109122 n_steps = getenv ("N_STEPS" , 50 ),
123+ # S3 Configuration
110124 bucket_name = getenv ("BUCKET_NAME" , None ),
111125 region_name = getenv ("REGION_NAME" , None ),
112126 access_key_id = getenv ("ACCESS_KEY_ID" , None ),
0 commit comments