Skip to content

Commit 11dc4d1

Browse files
authored
refactor: optimize ctrl+c/ctrl+d abort handling (#27)
1 parent 1640456 commit 11dc4d1

File tree

5 files changed

+104
-48
lines changed

5 files changed

+104
-48
lines changed

src/client.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use crate::config::SharedConfig;
2-
use crate::repl::ReplyStreamHandler;
2+
use crate::repl::{ReplyStreamHandler, SharedAbortSignal};
33

44
use anyhow::{anyhow, Context, Result};
55
use eventsource_stream::Eventsource;
66
use futures_util::StreamExt;
77
use reqwest::{Client, Proxy, RequestBuilder};
88
use serde_json::{json, Value};
9-
use std::sync::atomic::{AtomicBool, Ordering};
10-
use std::{sync::Arc, time::Duration};
9+
use std::time::Duration;
1110
use tokio::runtime::Runtime;
1211
use tokio::time::sleep;
1312

@@ -43,27 +42,27 @@ impl ChatGptClient {
4342
prompt: Option<String>,
4443
handler: &mut ReplyStreamHandler,
4544
) -> Result<()> {
46-
async fn watch_ctrlc(ctrlc: Arc<AtomicBool>) {
45+
async fn watch_abort(abort: SharedAbortSignal) {
4746
loop {
48-
if ctrlc.load(Ordering::SeqCst) {
47+
if abort.aborted() {
4948
break;
5049
}
5150
sleep(Duration::from_millis(100)).await;
5251
}
5352
}
54-
let ctrlc = handler.get_ctrlc();
53+
let abort = handler.get_abort();
5554
self.runtime.block_on(async {
5655
tokio::select! {
5756
ret = self.send_message_streaming_inner(input, prompt, handler) => {
5857
handler.done();
5958
ret.with_context(|| "Failed to send message streaming")
6059
}
61-
_ = watch_ctrlc(ctrlc.clone()) => {
60+
_ = watch_abort(abort.clone()) => {
6261
handler.done();
6362
Ok(())
6463
},
6564
_ = tokio::signal::ctrl_c() => {
66-
ctrlc.store(true, Ordering::SeqCst);
65+
abort.set_ctrlc();
6766
Ok(())
6867
}
6968
}

src/render/mod.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
mod markdown;
22

33
pub use self::markdown::MarkdownRender;
4-
use crate::repl::ReplyStreamEvent;
4+
use crate::repl::{ReplyStreamEvent, SharedAbortSignal};
55

66
use anyhow::Result;
77
use crossbeam::channel::Receiver;
@@ -13,20 +13,16 @@ use crossterm::{
1313
};
1414
use std::{
1515
io::{self, Stdout, Write},
16-
sync::{
17-
atomic::{AtomicBool, Ordering},
18-
Arc,
19-
},
2016
time::{Duration, Instant},
2117
};
2218
use unicode_width::UnicodeWidthStr;
2319

24-
pub fn render_stream(rx: Receiver<ReplyStreamEvent>, ctrlc: Arc<AtomicBool>) -> Result<()> {
20+
pub fn render_stream(rx: Receiver<ReplyStreamEvent>, abort: SharedAbortSignal) -> Result<()> {
2521
enable_raw_mode()?;
2622
let mut stdout = io::stdout();
2723
queue!(stdout, event::DisableMouseCapture)?;
2824

29-
let ret = render_stream_inner(rx, ctrlc, &mut stdout);
25+
let ret = render_stream_inner(rx, abort, &mut stdout);
3026

3127
queue!(stdout, event::DisableMouseCapture)?;
3228
disable_raw_mode()?;
@@ -36,7 +32,7 @@ pub fn render_stream(rx: Receiver<ReplyStreamEvent>, ctrlc: Arc<AtomicBool>) ->
3632

3733
pub fn render_stream_inner(
3834
rx: Receiver<ReplyStreamEvent>,
39-
ctrlc: Arc<AtomicBool>,
35+
abort: SharedAbortSignal,
4036
writer: &mut Stdout,
4137
) -> Result<()> {
4238
let mut last_tick = Instant::now();
@@ -45,7 +41,7 @@ pub fn render_stream_inner(
4541
let mut markdown_render = MarkdownRender::new();
4642
let terminal_columns = terminal::size()?.0;
4743
loop {
48-
if ctrlc.load(Ordering::SeqCst) {
44+
if abort.aborted() {
4945
return Ok(());
5046
}
5147

@@ -89,7 +85,11 @@ pub fn render_stream_inner(
8985
if let Event::Key(key) = event::read()? {
9086
match key.code {
9187
KeyCode::Char('c') if key.modifiers == KeyModifiers::CONTROL => {
92-
ctrlc.store(true, Ordering::SeqCst);
88+
abort.set_ctrlc();
89+
return Ok(());
90+
}
91+
KeyCode::Char('d') if key.modifiers == KeyModifiers::CONTROL => {
92+
abort.set_ctrld();
9393
return Ok(());
9494
}
9595
_ => {}

src/repl/abort.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use std::sync::{
2+
atomic::{AtomicBool, Ordering},
3+
Arc,
4+
};
5+
6+
pub type SharedAbortSignal = Arc<AbortSignal>;
7+
8+
pub struct AbortSignal {
9+
ctrlc: AtomicBool,
10+
ctrld: AtomicBool,
11+
}
12+
13+
impl AbortSignal {
14+
pub fn new() -> SharedAbortSignal {
15+
Arc::new(Self {
16+
ctrlc: AtomicBool::new(false),
17+
ctrld: AtomicBool::new(false),
18+
})
19+
}
20+
21+
pub fn aborted(&self) -> bool {
22+
if self.aborted_ctrlc() {
23+
return true;
24+
}
25+
if self.aborted_ctrld() {
26+
return true;
27+
}
28+
false
29+
}
30+
31+
pub fn aborted_ctrlc(&self) -> bool {
32+
self.ctrlc.load(Ordering::SeqCst)
33+
}
34+
35+
pub fn aborted_ctrld(&self) -> bool {
36+
self.ctrld.load(Ordering::SeqCst)
37+
}
38+
39+
pub fn reset(&self) {
40+
self.ctrlc.store(false, Ordering::SeqCst);
41+
self.ctrld.store(false, Ordering::SeqCst);
42+
}
43+
44+
pub fn set_ctrlc(&self) {
45+
self.ctrlc.store(true, Ordering::SeqCst);
46+
}
47+
48+
pub fn set_ctrld(&self) {
49+
self.ctrld.store(true, Ordering::SeqCst);
50+
}
51+
}

src/repl/handler.rs

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ use crossbeam::channel::{unbounded, Sender};
88
use crossbeam::sync::WaitGroup;
99
use std::cell::RefCell;
1010
use std::fs::File;
11-
use std::sync::atomic::AtomicBool;
12-
use std::sync::Arc;
1311
use std::thread::spawn;
1412

13+
use super::abort::SharedAbortSignal;
14+
1515
pub enum ReplCmd {
1616
Submit(String),
1717
SetRole(String),
@@ -25,7 +25,7 @@ pub struct ReplCmdHandler {
2525
client: ChatGptClient,
2626
config: SharedConfig,
2727
state: RefCell<ReplCmdHandlerState>,
28-
ctrlc: Arc<AtomicBool>,
28+
abort: SharedAbortSignal,
2929
}
3030

3131
pub struct ReplCmdHandlerState {
@@ -34,9 +34,12 @@ pub struct ReplCmdHandlerState {
3434
}
3535

3636
impl ReplCmdHandler {
37-
pub fn init(client: ChatGptClient, config: SharedConfig) -> Result<Self> {
37+
pub fn init(
38+
client: ChatGptClient,
39+
config: SharedConfig,
40+
abort: SharedAbortSignal,
41+
) -> Result<Self> {
3842
let save_file = config.as_ref().borrow().open_message_file()?;
39-
let ctrlc = Arc::new(AtomicBool::new(false));
4043
let state = RefCell::new(ReplCmdHandlerState {
4144
save_file,
4245
reply: String::new(),
@@ -45,7 +48,7 @@ impl ReplCmdHandler {
4548
client,
4649
config,
4750
state,
48-
ctrlc,
51+
abort,
4952
})
5053
}
5154

@@ -61,15 +64,15 @@ impl ReplCmdHandler {
6164
let highlight = self.config.borrow().highlight;
6265
let mut stream_handler = if highlight {
6366
let (tx, rx) = unbounded();
64-
let ctrlc = self.ctrlc.clone();
67+
let abort = self.abort.clone();
6568
let wg = wg.clone();
6669
spawn(move || {
67-
let _ = render_stream(rx, ctrlc);
70+
let _ = render_stream(rx, abort);
6871
drop(wg);
6972
});
70-
ReplyStreamHandler::new(Some(tx), self.ctrlc.clone())
73+
ReplyStreamHandler::new(Some(tx), self.abort.clone())
7174
} else {
72-
ReplyStreamHandler::new(None, self.ctrlc.clone())
75+
ReplyStreamHandler::new(None, self.abort.clone())
7376
};
7477
self.client
7578
.send_message_streaming(&input, prompt, &mut stream_handler)?;
@@ -109,23 +112,19 @@ impl ReplCmdHandler {
109112
pub fn get_reply(&self) -> String {
110113
self.state.borrow().reply.to_string()
111114
}
112-
113-
pub fn get_ctrlc(&self) -> Arc<AtomicBool> {
114-
self.ctrlc.clone()
115-
}
116115
}
117116

118117
pub struct ReplyStreamHandler {
119118
sender: Option<Sender<ReplyStreamEvent>>,
120119
buffer: String,
121-
ctrlc: Arc<AtomicBool>,
120+
abort: SharedAbortSignal,
122121
}
123122

124123
impl ReplyStreamHandler {
125-
pub fn new(sender: Option<Sender<ReplyStreamEvent>>, ctrlc: Arc<AtomicBool>) -> Self {
124+
pub fn new(sender: Option<Sender<ReplyStreamEvent>>, abort: SharedAbortSignal) -> Self {
126125
Self {
127126
sender,
128-
ctrlc,
127+
abort,
129128
buffer: String::new(),
130129
}
131130
}
@@ -157,8 +156,8 @@ impl ReplyStreamHandler {
157156
&self.buffer
158157
}
159158

160-
pub fn get_ctrlc(&self) -> Arc<AtomicBool> {
161-
self.ctrlc.clone()
159+
pub fn get_abort(&self) -> SharedAbortSignal {
160+
self.abort.clone()
162161
}
163162
}
164163

src/repl/mod.rs

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod abort;
12
mod handler;
23
mod init;
34

@@ -8,9 +9,9 @@ use crate::utils::{copy, dump};
89

910
use anyhow::{Context, Result};
1011
use reedline::{DefaultPrompt, Reedline, Signal};
11-
use std::sync::atomic::Ordering;
1212
use std::sync::Arc;
1313

14+
pub use self::abort::*;
1415
pub use self::handler::*;
1516

1617
pub const REPL_COMMANDS: [(&str, &str, bool); 12] = [
@@ -35,23 +36,27 @@ pub struct Repl {
3536

3637
impl Repl {
3738
pub fn run(&mut self, client: ChatGptClient, config: SharedConfig) -> Result<()> {
38-
let handler = ReplCmdHandler::init(client, config)?;
39+
let abort = AbortSignal::new();
40+
let handler = ReplCmdHandler::init(client, config, abort.clone())?;
3941
dump(
4042
format!("Welcome to aichat {}", env!("CARGO_PKG_VERSION")),
4143
1,
4244
);
4345
dump("Type \".help\" for more information.", 1);
44-
let mut current_ctrlc = false;
46+
let mut already_ctrlc = false;
4547
let handler = Arc::new(handler);
4648
loop {
47-
let handler_ctrlc = handler.get_ctrlc();
48-
if handler_ctrlc.load(Ordering::SeqCst) {
49-
handler_ctrlc.store(false, Ordering::SeqCst);
50-
current_ctrlc = true
49+
if abort.aborted_ctrld() {
50+
break;
5151
}
52-
match self.editor.read_line(&self.prompt) {
52+
if abort.aborted_ctrlc() && !already_ctrlc {
53+
already_ctrlc = true;
54+
}
55+
let sig = self.editor.read_line(&self.prompt);
56+
match sig {
5357
Ok(Signal::Success(line)) => {
54-
current_ctrlc = false;
58+
already_ctrlc = false;
59+
abort.reset();
5560
match self.handle_line(handler.clone(), line) {
5661
Ok(quit) => {
5762
if quit {
@@ -65,14 +70,16 @@ impl Repl {
6570
}
6671
}
6772
Ok(Signal::CtrlC) => {
68-
if !current_ctrlc {
69-
current_ctrlc = true;
73+
abort.set_ctrlc();
74+
if !already_ctrlc {
75+
already_ctrlc = true;
7076
dump("(To exit, press Ctrl+C again or Ctrl+D or type .exit)", 2);
7177
} else {
7278
break;
7379
}
7480
}
7581
Ok(Signal::CtrlD) => {
82+
abort.set_ctrld();
7683
break;
7784
}
7885
_ => {}

0 commit comments

Comments
 (0)