Skip to content

Commit 1ef97b2

Browse files
authored
feat: support multiple models (#71)
1 parent 4a74f5c commit 1ef97b2

File tree

9 files changed

+80
-24
lines changed

9 files changed

+80
-24
lines changed

src/cli.rs

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@ use clap::Parser;
33
#[derive(Parser, Debug)]
44
#[command(author, version, about, long_about = None)]
55
pub struct Cli {
6+
/// Choose a model
7+
#[clap(short, long)]
8+
pub model: Option<String>,
69
/// Add a GPT prompt
710
#[clap(short, long)]
811
pub prompt: Option<String>,
@@ -15,6 +18,9 @@ pub struct Cli {
1518
/// List all roles
1619
#[clap(long)]
1720
pub list_roles: bool,
21+
/// List all models
22+
#[clap(long)]
23+
pub list_models: bool,
1824
/// Select a role
1925
#[clap(short, long)]
2026
pub role: Option<String>,

src/client.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@ use tokio::time::sleep;
1212

1313
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
1414
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
15-
const MODEL: &str = "gpt-3.5-turbo";
1615

1716
#[derive(Debug)]
1817
pub struct ChatGptClient {
@@ -137,9 +136,10 @@ impl ChatGptClient {
137136
}
138137

139138
fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
139+
let (model, _) = self.config.read().get_model();
140140
let messages = self.config.read().build_messages(content)?;
141141
let mut body = json!({
142-
"model": MODEL,
142+
"model": model,
143143
"messages": messages,
144144
});
145145

src/config/conversation.rs

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
use super::message::{num_tokens_from_messages, Message, MessageRole, MAX_TOKENS};
1+
use super::message::{num_tokens_from_messages, Message, MessageRole};
22
use super::role::Role;
33

44
use anyhow::{bail, Result};
@@ -87,8 +87,4 @@ impl Conversation {
8787
}
8888
messages
8989
}
90-
91-
pub fn reamind_tokens(&self) -> usize {
92-
MAX_TOKENS.saturating_sub(self.tokens)
93-
}
9490
}

src/config/message.rs

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,6 @@ use crate::utils::count_tokens;
33
use anyhow::{bail, Result};
44
use serde::{Deserialize, Serialize};
55

6-
pub const MAX_TOKENS: usize = 4096;
7-
86
#[derive(Debug, Clone, Deserialize, Serialize)]
97
pub struct Message {
108
pub role: MessageRole,
@@ -28,9 +26,9 @@ pub enum MessageRole {
2826
User,
2927
}
3028

31-
pub fn within_max_tokens_limit(messages: &[Message]) -> Result<()> {
29+
pub fn within_max_tokens_limit(messages: &[Message], max_tokens: usize) -> Result<()> {
3230
let tokens = num_tokens_from_messages(messages);
33-
if tokens >= MAX_TOKENS {
31+
if tokens >= max_tokens {
3432
bail!("Exceed max tokens limit")
3533
}
3634
Ok(())

src/config/mod.rs

Lines changed: 46 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,12 @@ use std::{
2121
sync::Arc,
2222
};
2323

24+
pub const MODELS: [(&str, usize); 3] = [
25+
("gpt-4", 8192),
26+
("gpt-4-32k", 32768),
27+
("gpt-3.5-turbo", 4096),
28+
];
29+
2430
const CONFIG_FILE_NAME: &str = "config.yaml";
2531
const ROLES_FILE_NAME: &str = "roles.yaml";
2632
const HISTORY_FILE_NAME: &str = "history.txt";
@@ -42,6 +48,9 @@ const SET_COMPLETIONS: [&str; 9] = [
4248
pub struct Config {
4349
/// Openai api key
4450
pub api_key: Option<String>,
51+
/// Openai model
52+
#[serde(rename(serialize = "model", deserialize = "model"))]
53+
pub model_name: Option<String>,
4554
/// What sampling temperature to use, between 0 and 2
4655
pub temperature: Option<f64>,
4756
/// Whether to persistently save chat messages
@@ -65,12 +74,15 @@ pub struct Config {
6574
/// Current conversation
6675
#[serde(skip)]
6776
pub conversation: Option<Conversation>,
77+
#[serde(skip)]
78+
pub model: (String, usize),
6879
}
6980

7081
impl Default for Config {
7182
fn default() -> Self {
7283
Self {
7384
api_key: None,
85+
model_name: None,
7486
temperature: None,
7587
save: false,
7688
highlight: true,
@@ -81,6 +93,7 @@ impl Default for Config {
8193
roles: vec![],
8294
role: None,
8395
conversation: None,
96+
model: ("gpt-3.5-turbo".into(), 4096),
8497
}
8598
}
8699
}
@@ -105,6 +118,9 @@ impl Config {
105118
if config.api_key.is_none() {
106119
bail!("api_key not set");
107120
}
121+
if let Some(name) = config.model_name.clone() {
122+
config.set_model(&name)?;
123+
}
108124
config.merge_env_vars();
109125
config.maybe_proxy();
110126
config.load_roles()?;
@@ -251,6 +267,10 @@ impl Config {
251267
}
252268
}
253269

270+
pub fn get_model(&self) -> (String, usize) {
271+
self.model.clone()
272+
}
273+
254274
pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
255275
let messages = if let Some(conversation) = self.conversation.as_ref() {
256276
conversation.build_emssages(content)
@@ -260,11 +280,28 @@ impl Config {
260280
let message = Message::new(content);
261281
vec![message]
262282
};
263-
within_max_tokens_limit(&messages)?;
283+
within_max_tokens_limit(&messages, self.model.1)?;
264284

265285
Ok(messages)
266286
}
267287

288+
pub fn set_model(&mut self, name: &str) -> Result<()> {
289+
if let Some(token) = MODELS.iter().find(|(v, _)| *v == name).map(|(_, v)| *v) {
290+
self.model = (name.to_string(), token);
291+
} else {
292+
bail!("Invalid model")
293+
}
294+
Ok(())
295+
}
296+
297+
pub fn get_reamind_tokens(&self) -> usize {
298+
let mut tokens = self.model.1;
299+
if let Some(conversation) = self.conversation.as_ref() {
300+
tokens = tokens.saturating_sub(conversation.tokens);
301+
}
302+
tokens
303+
}
304+
268305
pub fn info(&self) -> Result<String> {
269306
let file_info = |path: &Path| {
270307
let state = if path.exists() { "" } else { " ⚠️" };
@@ -284,6 +321,7 @@ impl Config {
284321
("roles_file", file_info(&Config::roles_file()?)),
285322
("messages_file", file_info(&Config::messages_file()?)),
286323
("api_key", self.get_api_key().to_string()),
324+
("model", self.model.0.to_string()),
287325
("temperature", temperature),
288326
("save", self.save.to_string()),
289327
("highlight", self.highlight.to_string()),
@@ -307,6 +345,7 @@ impl Config {
307345
.collect();
308346

309347
completion.extend(SET_COMPLETIONS.map(|v| v.to_string()));
348+
completion.extend(MODELS.map(|(v, _)| format!(".model {}", v)));
310349
completion
311350
}
312351

@@ -359,14 +398,12 @@ impl Config {
359398
}
360399

361400
pub fn start_conversation(&mut self) -> Result<()> {
362-
if let Some(conversation) = self.conversation.as_ref() {
363-
if conversation.reamind_tokens() > 0 {
364-
let ans = Confirm::new("Already in a conversation, start a new one?")
365-
.with_default(true)
366-
.prompt()?;
367-
if !ans {
368-
return Ok(());
369-
}
401+
if self.conversation.is_some() && self.get_reamind_tokens() > 0 {
402+
let ans = Confirm::new("Already in a conversation, start a new one?")
403+
.with_default(true)
404+
.prompt()?;
405+
if !ans {
406+
return Ok(());
370407
}
371408
}
372409
self.conversation = Some(Conversation::new(self.role.clone()));

src/main.rs

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,12 @@ fn main() -> Result<()> {
3535
.for_each(|v| println!("{}", v.name));
3636
exit(0);
3737
}
38+
if cli.list_models {
39+
config::MODELS
40+
.iter()
41+
.for_each(|(name, _)| println!("{}", name));
42+
exit(0);
43+
}
3844
let role = match &cli.role {
3945
Some(name) => Some(
4046
config
@@ -44,6 +50,9 @@ fn main() -> Result<()> {
4450
),
4551
None => None,
4652
};
53+
if let Some(model) = &cli.model {
54+
config.write().set_model(model)?;
55+
}
4756
config.write().role = role;
4857
if cli.no_highlight {
4958
config.write().highlight = false;

src/repl/handler.rs

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ use std::cell::RefCell;
1212

1313
pub enum ReplCmd {
1414
Submit(String),
15+
SetModel(String),
1516
SetRole(String),
1617
UpdateConfig(String),
1718
Prompt(String),
@@ -65,6 +66,10 @@ impl ReplCmdHandler {
6566
self.config.write().save_conversation(&input, &buffer)?;
6667
*self.reply.borrow_mut() = buffer;
6768
}
69+
ReplCmd::SetModel(name) => {
70+
self.config.write().set_model(&name)?;
71+
print_now!("\n");
72+
}
6873
ReplCmd::SetRole(name) => {
6974
let output = self.config.write().change_role(&name)?;
7075
print_now!("{}\n\n", output.trim_end());

src/repl/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,10 @@ use reedline::Signal;
1919
use std::borrow::Cow;
2020
use std::sync::Arc;
2121

22-
pub const REPL_COMMANDS: [(&str, &str); 11] = [
22+
pub const REPL_COMMANDS: [(&str, &str); 12] = [
2323
(".info", "Print the information"),
2424
(".set", "Modify the configuration temporarily"),
25+
(".model", "Choose a model"),
2526
(".prompt", "Add a GPT prompt"),
2627
(".role", "Select a role"),
2728
(".clear role", "Clear the currently selected role"),
@@ -109,6 +110,10 @@ impl Repl {
109110
self.editor.print_history()?;
110111
print_now!("\n");
111112
}
113+
".model" => match args {
114+
Some(name) => handler.handle(ReplCmd::SetModel(name.to_string()))?,
115+
None => print_now!("Usage: .model <name>\n\n"),
116+
},
112117
".role" => match args {
113118
Some(name) => handler.handle(ReplCmd::SetRole(name.to_string()))?,
114119
None => print_now!("Usage: .role <name>\n\n"),

src/repl/prompt.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -76,10 +76,10 @@ impl Prompt for ReplPrompt {
7676
}
7777

7878
fn render_prompt_right(&self) -> Cow<str> {
79-
if let Some(conversation) = self.config.read().conversation.as_ref() {
80-
conversation.reamind_tokens().to_string().into()
81-
} else {
79+
if self.config.read().conversation.is_none() {
8280
Cow::Borrowed("")
81+
} else {
82+
self.config.read().get_reamind_tokens().to_string().into()
8383
}
8484
}
8585

0 commit comments

Comments (0)