Skip to content

Commit 152a826

Browse files
author
Thomas Chaigneau
authored
Merge pull request #6 from chainyo/add-img-to-img
Add multiple diffusion pipelines to the API
2 parents bb269c7 + 08847e2 commit 152a826

File tree

8 files changed

+366
-123
lines changed

8 files changed

+366
-123
lines changed

config/api/.env.template

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ MAX_BATCH_SIZE=4
4141
# The max wait is the maximum time (in seconds) that the API will wait for the model to generate an image. It means that
4242
# even if the queue is not full, the API will wait MAX_WAIT seconds before processing the requests in the queue.
4343
MAX_WAIT=0.5
44+
# You can select the specific task that you want the model to perform from the following options: "text_to_image",
45+
# "image_to_image", "image_variation" or "super_resolution." Each task has a different purpose and will produce
46+
# different results. Check the Hugging Face Diffusers documentation for more information.
47+
# https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview
48+
TASK="text_to_image"
4449
# The model name is the name of the model that you want to use. You can find the list of available models on the Hugging
4550
# Face Hub, by filtering the models with the "text-to-image" tag.
4651
# Check here: https://huggingface.co/models?pipeline_tag=text-to-image

picaisso/api/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from os import getenv
44
from typing import Optional, Union
55

6+
from diffusion_service import TASK_MAPPING
67
from dotenv import load_dotenv
78
from loguru import logger
89
from pydantic import Field, validator
@@ -25,6 +26,7 @@ class Settings:
2526
# Model Configuration
2627
max_batch_size: int
2728
max_wait: float
29+
task: str
2830
model_name: str
2931
model_precision: str
3032
n_steps: int
@@ -70,6 +72,13 @@ def model_precision_must_be_valid(cls, value: str):
7072
raise ValueError("model_precision must be either `fp16`, `fp32` or `bf16`.")
7173
return value
7274

75+
@validator("task")
76+
def task_must_be_valid(cls, value: str):
77+
"""Check that the task is valid."""
78+
if value not in TASK_MAPPING.keys():
79+
raise ValueError(f"task must be one of {list(TASK_MAPPING.keys())}.")
80+
return value
81+
7382
def __post_init__(self):
7483
"""Post init hook."""
7584
self.using_s3 = all(
@@ -90,6 +99,7 @@ def __post_init__(self):
9099
load_dotenv()
91100

92101
settings = Settings(
102+
# General Configuration
93103
project_name=getenv("PROJECT_NAME", "PicAIsso"),
94104
version=getenv("VERSION", "1.0.0"),
95105
description=getenv(
@@ -98,15 +108,19 @@ def __post_init__(self):
98108
),
99109
api_prefix=getenv("API_PREFIX", "/api/v1"),
100110
debug=getenv("DEBUG", True),
111+
# Authentication
101112
username=getenv("USERNAME", None),
102113
password=getenv("PASSWORD", None),
103114
openssl_key=getenv("OPENSSL_KEY", None),
104115
algorithm=getenv("ALGORITHM", "HS256"),
116+
# Model Configuration
105117
max_batch_size=getenv("MAX_BATCH_SIZE", 1),
106118
max_wait=getenv("MAX_WAIT", 0.5),
119+
task=getenv("TASK", "text_to_image"),
107120
model_name=getenv("MODEL_NAME", "prompthero/openjourney"),
108121
model_precision=getenv("MODEL_PRECISION", "fp16"),
109122
n_steps=getenv("N_STEPS", 50),
123+
# S3 Configuration
110124
bucket_name=getenv("BUCKET_NAME", None),
111125
region_name=getenv("REGION_NAME", None),
112126
access_key_id=getenv("ACCESS_KEY_ID", None),

picaisso/api/diffusion_model.py

Lines changed: 0 additions & 106 deletions
This file was deleted.

0 commit comments

Comments
 (0)