Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,13 @@ void step_callback(int step, int frame_count, sd_image_t* image, bool is_noisy,
}
}

#if defined(__unix__) || defined(__APPLE__) || defined(_POSIX_VERSION)
#define SD_ENABLE_SIGNAL_HANDLER
static void set_sigint_handler(sd_ctx_t* sd_ctx);
#else
#define set_sigint_handler(SD_CTX) ((void)SD_CTX)
#endif

int main(int argc, const char* argv[]) {
if (argc > 1 && std::string(argv[1]) == "--version") {
std::cout << version_string() << "\n";
Expand Down Expand Up @@ -577,6 +584,8 @@ int main(int argc, const char* argv[]) {
return 1;
}

set_sigint_handler(sd_ctx);

if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
}
Expand Down Expand Up @@ -648,6 +657,8 @@ int main(int argc, const char* argv[]) {
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
}

set_sigint_handler(nullptr);

if (results == nullptr) {
LOG_ERROR("generate failed");
free_sd_ctx(sd_ctx);
Expand Down Expand Up @@ -760,3 +771,59 @@ int main(int argc, const char* argv[]) {

return 0;
}

#ifdef SD_ENABLE_SIGNAL_HANDLER

#include <atomic>
#include <csignal>
#include <thread>
#include <unistd.h>

// this lock is needed to avoid a race condition between
// free_sd_ctx and a pending sd_cancel_generation call
std::atomic_flag signal_lock = ATOMIC_FLAG_INIT;
static int g_sigint_cnt;
static sd_ctx_t* g_sd_ctx;

static void sigint_handler(int /* signum */)
{
if (!signal_lock.test_and_set(std::memory_order_acquire)) {
if (g_sd_ctx != nullptr) {
if (g_sigint_cnt == 1) {
char msg[] = "\ngot SIGINT, cancelling new generations\n";
write(2, msg, sizeof(msg)-1);
/* first Ctrl‑C cancels only the remaining latents on a batch */
sd_cancel_generation(g_sd_ctx, SD_CANCEL_NEW_LATENTS);
++g_sigint_cnt;
} else {
char msg[] = "\ngot SIGINT, cancelling everything\n";
write(2, msg, sizeof(msg)-1);
/* cancels everything */
sd_cancel_generation(g_sd_ctx, SD_CANCEL_ALL);
}
}
signal_lock.clear(std::memory_order_release);
}
}

static void set_sigint_handler(sd_ctx_t* sd_ctx)
{
if (g_sigint_cnt == 0) {
g_sigint_cnt++;
struct sigaction sa{};
sa.sa_handler = sigint_handler;
sa.sa_flags = SA_RESTART;
sigaction(SIGINT, &sa, nullptr);
}

while (signal_lock.test_and_set(std::memory_order_acquire)) {
std::this_thread::yield();
}

g_sd_ctx = sd_ctx;

signal_lock.clear(std::memory_order_release);
}

#endif

49 changes: 49 additions & 0 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include "latent-preview.h"
#include "name_conversion.h"

#include <atomic>

const char* model_version_to_str[] = {
"SD 1.x",
"SD 1.x Inpaint",
Expand Down Expand Up @@ -91,6 +93,9 @@ void suppress_pp(int step, int steps, float time, void* data) {

/*=============================================== StableDiffusionGGML ================================================*/

static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
"sd_cancel_mode_t must be lock-free");

class StableDiffusionGGML {
public:
ggml_backend_t backend = nullptr; // general backend
Expand Down Expand Up @@ -139,6 +144,8 @@ class StableDiffusionGGML {

std::shared_ptr<Denoiser> denoiser = std::make_shared<CompVisDenoiser>();

std::atomic<sd_cancel_mode_t> cancellation_flag;

StableDiffusionGGML() = default;

~StableDiffusionGGML() {
Expand All @@ -154,6 +161,18 @@ class StableDiffusionGGML {
ggml_backend_free(backend);
}

void set_cancel_flag(enum sd_cancel_mode_t flag) {
cancellation_flag.store(flag, std::memory_order_release);
}

void reset_cancel_flag() {
set_cancel_flag(SD_CANCEL_RESET);
}

enum sd_cancel_mode_t get_cancel_flag() {
return cancellation_flag.load(std::memory_order_acquire);
}

void init_backend() {
#ifdef SD_USE_CUDA
LOG_DEBUG("Using CUDA backend");
Expand Down Expand Up @@ -1657,6 +1676,12 @@ class StableDiffusionGGML {
}

auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* {
enum sd_cancel_mode_t cancel_flag = get_cancel_flag();
if (cancel_flag != SD_CANCEL_RESET) {
LOG_DEBUG("cancelling latent decodings");
return nullptr;
}

auto sd_preview_cb = sd_get_preview_callback();
auto sd_preview_cb_data = sd_get_preview_callback_data();
auto sd_preview_mode = sd_get_preview_mode();
Expand Down Expand Up @@ -3120,6 +3145,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat);
}
for (int b = 0; b < batch_count; b++) {

if (sd_ctx->sd->get_cancel_flag() != SD_CANCEL_RESET) {
LOG_ERROR("cancelling generation");
break;
}

int64_t sampling_start = ggml_time_ms();
int64_t cur_seed = seed + b;
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed);
Expand Down Expand Up @@ -3181,6 +3212,12 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
LOG_INFO("decoding %zu latents", final_latents.size());
std::vector<struct ggml_tensor*> decoded_images; // collect decoded images
for (size_t i = 0; i < final_latents.size(); i++) {

if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling latent decodings");
break;
}

t1 = ggml_time_ms();
struct ggml_tensor* img = sd_ctx->sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */);
// print_ggml_tensor(img);
Expand Down Expand Up @@ -3216,6 +3253,16 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
return result_images;
}

void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode)
{
if (sd_ctx && sd_ctx->sd) {
if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) {
mode = SD_CANCEL_ALL;
}
sd_ctx->sd->set_cancel_flag(mode);
}
}

sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) {
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
int width = sd_img_gen_params->width;
Expand All @@ -3238,6 +3285,8 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
return nullptr;
}

sd_ctx->sd->reset_cancel_flag();

struct ggml_init_params params;
params.mem_size = static_cast<size_t>(1024 * 1024) * 1024; // 1G
params.mem_buffer = nullptr;
Expand Down
9 changes: 9 additions & 0 deletions stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,15 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);

enum sd_cancel_mode_t
{
SD_CANCEL_ALL,
SD_CANCEL_NEW_LATENTS,
SD_CANCEL_RESET
};

SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode);

SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);

Expand Down
Loading