Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use clap::Parser;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Cli {
/// Choose a model
#[clap(short, long)]
pub model: Option<String>,
/// Add a GPT prompt
#[clap(short, long)]
pub prompt: Option<String>,
Expand All @@ -15,6 +18,9 @@ pub struct Cli {
/// List all roles
#[clap(long)]
pub list_roles: bool,
/// List all models
#[clap(long)]
pub list_models: bool,
/// Select a role
#[clap(short, long)]
pub role: Option<String>,
Expand Down
4 changes: 2 additions & 2 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use tokio::time::sleep;

const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
const MODEL: &str = "gpt-3.5-turbo";

#[derive(Debug)]
pub struct ChatGptClient {
Expand Down Expand Up @@ -137,9 +136,10 @@ impl ChatGptClient {
}

fn request_builder(&self, content: &str, stream: bool) -> Result<RequestBuilder> {
let (model, _) = self.config.read().get_model();
let messages = self.config.read().build_messages(content)?;
let mut body = json!({
"model": MODEL,
"model": model,
"messages": messages,
});

Expand Down
6 changes: 1 addition & 5 deletions src/config/conversation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::message::{num_tokens_from_messages, Message, MessageRole, MAX_TOKENS};
use super::message::{num_tokens_from_messages, Message, MessageRole};
use super::role::Role;

use anyhow::{bail, Result};
Expand Down Expand Up @@ -87,8 +87,4 @@ impl Conversation {
}
messages
}

pub fn reamind_tokens(&self) -> usize {
MAX_TOKENS.saturating_sub(self.tokens)
}
}
6 changes: 2 additions & 4 deletions src/config/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use crate::utils::count_tokens;
use anyhow::{bail, Result};
use serde::{Deserialize, Serialize};

pub const MAX_TOKENS: usize = 4096;

#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Message {
pub role: MessageRole,
Expand All @@ -28,9 +26,9 @@ pub enum MessageRole {
User,
}

pub fn within_max_tokens_limit(messages: &[Message]) -> Result<()> {
pub fn within_max_tokens_limit(messages: &[Message], max_tokens: usize) -> Result<()> {
let tokens = num_tokens_from_messages(messages);
if tokens >= MAX_TOKENS {
if tokens >= max_tokens {
bail!("Exceed max tokens limit")
}
Ok(())
Expand Down
55 changes: 46 additions & 9 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ use std::{
sync::Arc,
};

/// Supported OpenAI chat models, as `(model_name, max_tokens)` pairs.
/// The second element is the model's context-window limit used for
/// token budgeting (see `set_model` / `get_reamind_tokens`).
pub const MODELS: [(&str, usize); 3] = [
("gpt-4", 8192),
("gpt-4-32k", 32768),
("gpt-3.5-turbo", 4096),
];

const CONFIG_FILE_NAME: &str = "config.yaml";
const ROLES_FILE_NAME: &str = "roles.yaml";
const HISTORY_FILE_NAME: &str = "history.txt";
Expand All @@ -42,6 +48,9 @@ const SET_COMPLETIONS: [&str; 9] = [
pub struct Config {
/// Openai api key
pub api_key: Option<String>,
/// Openai model
#[serde(rename(serialize = "model", deserialize = "model"))]
pub model_name: Option<String>,
/// What sampling temperature to use, between 0 and 2
pub temperature: Option<f64>,
/// Whether to persistently save chat messages
Expand All @@ -65,12 +74,15 @@ pub struct Config {
/// Current conversation
#[serde(skip)]
pub conversation: Option<Conversation>,
#[serde(skip)]
pub model: (String, usize),
}

impl Default for Config {
fn default() -> Self {
Self {
api_key: None,
model_name: None,
temperature: None,
save: false,
highlight: true,
Expand All @@ -81,6 +93,7 @@ impl Default for Config {
roles: vec![],
role: None,
conversation: None,
model: ("gpt-3.5-turbo".into(), 4096),
}
}
}
Expand All @@ -105,6 +118,9 @@ impl Config {
if config.api_key.is_none() {
bail!("api_key not set");
}
if let Some(name) = config.model_name.clone() {
config.set_model(&name)?;
}
config.merge_env_vars();
config.maybe_proxy();
config.load_roles()?;
Expand Down Expand Up @@ -251,6 +267,10 @@ impl Config {
}
}

/// Returns the currently selected model as a `(name, max_tokens)` pair.
///
/// The name is cloned so callers get an owned copy; the token limit is `Copy`.
pub fn get_model(&self) -> (String, usize) {
    let (name, max_tokens) = &self.model;
    (name.clone(), *max_tokens)
}

pub fn build_messages(&self, content: &str) -> Result<Vec<Message>> {
let messages = if let Some(conversation) = self.conversation.as_ref() {
conversation.build_emssages(content)
Expand All @@ -260,11 +280,28 @@ impl Config {
let message = Message::new(content);
vec![message]
};
within_max_tokens_limit(&messages)?;
within_max_tokens_limit(&messages, self.model.1)?;

Ok(messages)
}

/// Switches the active model to `name`.
///
/// Looks `name` up in the supported-model table ([`MODELS`]) and stores the
/// `(name, max_tokens)` pair on `self.model`.
///
/// # Errors
///
/// Fails when `name` is not one of the supported models; the error message
/// now includes the rejected name so CLI users can see what was wrong.
pub fn set_model(&mut self, name: &str) -> Result<()> {
    // A single `find` yields the whole pair, avoiding the find+map chain.
    match MODELS.iter().find(|(model_name, _)| *model_name == name) {
        Some((model_name, max_tokens)) => {
            self.model = ((*model_name).to_string(), *max_tokens);
            Ok(())
        }
        None => bail!("Invalid model '{}'", name),
    }
}

/// Number of tokens still available: the model's limit minus the tokens
/// already consumed by the current conversation (saturating at zero).
/// With no active conversation the full model limit is returned.
///
/// NOTE(review): "reamind" looks like a typo for "remaining"; the name is
/// kept as-is because external callers depend on it.
pub fn get_reamind_tokens(&self) -> usize {
    let max_tokens = self.model.1;
    match self.conversation.as_ref() {
        Some(conversation) => max_tokens.saturating_sub(conversation.tokens),
        None => max_tokens,
    }
}

pub fn info(&self) -> Result<String> {
let file_info = |path: &Path| {
let state = if path.exists() { "" } else { " ⚠️" };
Expand All @@ -284,6 +321,7 @@ impl Config {
("roles_file", file_info(&Config::roles_file()?)),
("messages_file", file_info(&Config::messages_file()?)),
("api_key", self.get_api_key().to_string()),
("model", self.model.0.to_string()),
("temperature", temperature),
("save", self.save.to_string()),
("highlight", self.highlight.to_string()),
Expand All @@ -307,6 +345,7 @@ impl Config {
.collect();

completion.extend(SET_COMPLETIONS.map(|v| v.to_string()));
completion.extend(MODELS.map(|(v, _)| format!(".model {}", v)));
completion
}

Expand Down Expand Up @@ -359,14 +398,12 @@ impl Config {
}

pub fn start_conversation(&mut self) -> Result<()> {
if let Some(conversation) = self.conversation.as_ref() {
if conversation.reamind_tokens() > 0 {
let ans = Confirm::new("Already in a conversation, start a new one?")
.with_default(true)
.prompt()?;
if !ans {
return Ok(());
}
if self.conversation.is_some() && self.get_reamind_tokens() > 0 {
let ans = Confirm::new("Already in a conversation, start a new one?")
.with_default(true)
.prompt()?;
if !ans {
return Ok(());
}
}
self.conversation = Some(Conversation::new(self.role.clone()));
Expand Down
9 changes: 9 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ fn main() -> Result<()> {
.for_each(|v| println!("{}", v.name));
exit(0);
}
if cli.list_models {
config::MODELS
.iter()
.for_each(|(name, _)| println!("{}", name));
exit(0);
}
let role = match &cli.role {
Some(name) => Some(
config
Expand All @@ -44,6 +50,9 @@ fn main() -> Result<()> {
),
None => None,
};
if let Some(model) = &cli.model {
config.write().set_model(model)?;
}
config.write().role = role;
if cli.no_highlight {
config.write().highlight = false;
Expand Down
5 changes: 5 additions & 0 deletions src/repl/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::cell::RefCell;

pub enum ReplCmd {
Submit(String),
SetModel(String),
SetRole(String),
UpdateConfig(String),
Prompt(String),
Expand Down Expand Up @@ -65,6 +66,10 @@ impl ReplCmdHandler {
self.config.write().save_conversation(&input, &buffer)?;
*self.reply.borrow_mut() = buffer;
}
ReplCmd::SetModel(name) => {
self.config.write().set_model(&name)?;
print_now!("\n");
}
ReplCmd::SetRole(name) => {
let output = self.config.write().change_role(&name)?;
print_now!("{}\n\n", output.trim_end());
Expand Down
7 changes: 6 additions & 1 deletion src/repl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ use reedline::Signal;
use std::borrow::Cow;
use std::sync::Arc;

pub const REPL_COMMANDS: [(&str, &str); 11] = [
pub const REPL_COMMANDS: [(&str, &str); 12] = [
(".info", "Print the information"),
(".set", "Modify the configuration temporarily"),
(".model", "Choose a model"),
(".prompt", "Add a GPT prompt"),
(".role", "Select a role"),
(".clear role", "Clear the currently selected role"),
Expand Down Expand Up @@ -109,6 +110,10 @@ impl Repl {
self.editor.print_history()?;
print_now!("\n");
}
".model" => match args {
Some(name) => handler.handle(ReplCmd::SetModel(name.to_string()))?,
None => print_now!("Usage: .model <name>\n\n"),
},
".role" => match args {
Some(name) => handler.handle(ReplCmd::SetRole(name.to_string()))?,
None => print_now!("Usage: .role <name>\n\n"),
Expand Down
6 changes: 3 additions & 3 deletions src/repl/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ impl Prompt for ReplPrompt {
}

fn render_prompt_right(&self) -> Cow<str> {
if let Some(conversation) = self.config.read().conversation.as_ref() {
conversation.reamind_tokens().to_string().into()
} else {
if self.config.read().conversation.is_none() {
Cow::Borrowed("")
} else {
self.config.read().get_reamind_tokens().to_string().into()
}
}

Expand Down