@@ -21,6 +21,12 @@ use std::{
21
21
sync:: Arc ,
22
22
} ;
23
23
24
+ pub const MODELS : [ ( & str , usize ) ; 3 ] = [
25
+ ( "gpt-4" , 8192 ) ,
26
+ ( "gpt-4-32k" , 32768 ) ,
27
+ ( "gpt-3.5-turbo" , 4096 ) ,
28
+ ] ;
29
+
24
30
const CONFIG_FILE_NAME : & str = "config.yaml" ;
25
31
const ROLES_FILE_NAME : & str = "roles.yaml" ;
26
32
const HISTORY_FILE_NAME : & str = "history.txt" ;
@@ -42,6 +48,9 @@ const SET_COMPLETIONS: [&str; 9] = [
42
48
pub struct Config {
43
49
/// Openai api key
44
50
pub api_key : Option < String > ,
51
+ /// Openai model
52
+ #[ serde( rename( serialize = "model" , deserialize = "model" ) ) ]
53
+ pub model_name : Option < String > ,
45
54
/// What sampling temperature to use, between 0 and 2
46
55
pub temperature : Option < f64 > ,
47
56
/// Whether to persistently save chat messages
@@ -65,12 +74,15 @@ pub struct Config {
65
74
/// Current conversation
66
75
#[ serde( skip) ]
67
76
pub conversation : Option < Conversation > ,
77
+ #[ serde( skip) ]
78
+ pub model : ( String , usize ) ,
68
79
}
69
80
70
81
impl Default for Config {
71
82
fn default ( ) -> Self {
72
83
Self {
73
84
api_key : None ,
85
+ model_name : None ,
74
86
temperature : None ,
75
87
save : false ,
76
88
highlight : true ,
@@ -81,6 +93,7 @@ impl Default for Config {
81
93
roles : vec ! [ ] ,
82
94
role : None ,
83
95
conversation : None ,
96
+ model : ( "gpt-3.5-turbo" . into ( ) , 4096 ) ,
84
97
}
85
98
}
86
99
}
@@ -105,6 +118,9 @@ impl Config {
105
118
if config. api_key . is_none ( ) {
106
119
bail ! ( "api_key not set" ) ;
107
120
}
121
+ if let Some ( name) = config. model_name . clone ( ) {
122
+ config. set_model ( & name) ?;
123
+ }
108
124
config. merge_env_vars ( ) ;
109
125
config. maybe_proxy ( ) ;
110
126
config. load_roles ( ) ?;
@@ -251,6 +267,10 @@ impl Config {
251
267
}
252
268
}
253
269
270
+ pub fn get_model ( & self ) -> ( String , usize ) {
271
+ self . model . clone ( )
272
+ }
273
+
254
274
pub fn build_messages ( & self , content : & str ) -> Result < Vec < Message > > {
255
275
let messages = if let Some ( conversation) = self . conversation . as_ref ( ) {
256
276
conversation. build_emssages ( content)
@@ -260,11 +280,28 @@ impl Config {
260
280
let message = Message :: new ( content) ;
261
281
vec ! [ message]
262
282
} ;
263
- within_max_tokens_limit ( & messages) ?;
283
+ within_max_tokens_limit ( & messages, self . model . 1 ) ?;
264
284
265
285
Ok ( messages)
266
286
}
267
287
288
+ pub fn set_model ( & mut self , name : & str ) -> Result < ( ) > {
289
+ if let Some ( token) = MODELS . iter ( ) . find ( |( v, _) | * v == name) . map ( |( _, v) | * v) {
290
+ self . model = ( name. to_string ( ) , token) ;
291
+ } else {
292
+ bail ! ( "Invalid model" )
293
+ }
294
+ Ok ( ( ) )
295
+ }
296
+
297
+ pub fn get_reamind_tokens ( & self ) -> usize {
298
+ let mut tokens = self . model . 1 ;
299
+ if let Some ( conversation) = self . conversation . as_ref ( ) {
300
+ tokens = tokens. saturating_sub ( conversation. tokens ) ;
301
+ }
302
+ tokens
303
+ }
304
+
268
305
pub fn info ( & self ) -> Result < String > {
269
306
let file_info = |path : & Path | {
270
307
let state = if path. exists ( ) { "" } else { " ⚠️" } ;
@@ -284,6 +321,7 @@ impl Config {
284
321
( "roles_file" , file_info( & Config :: roles_file( ) ?) ) ,
285
322
( "messages_file" , file_info( & Config :: messages_file( ) ?) ) ,
286
323
( "api_key" , self . get_api_key( ) . to_string( ) ) ,
324
+ ( "model" , self . model. 0 . to_string( ) ) ,
287
325
( "temperature" , temperature) ,
288
326
( "save" , self . save. to_string( ) ) ,
289
327
( "highlight" , self . highlight. to_string( ) ) ,
@@ -307,6 +345,7 @@ impl Config {
307
345
. collect ( ) ;
308
346
309
347
completion. extend ( SET_COMPLETIONS . map ( |v| v. to_string ( ) ) ) ;
348
+ completion. extend ( MODELS . map ( |( v, _) | format ! ( ".model {}" , v) ) ) ;
310
349
completion
311
350
}
312
351
@@ -359,14 +398,12 @@ impl Config {
359
398
}
360
399
361
400
pub fn start_conversation ( & mut self ) -> Result < ( ) > {
362
- if let Some ( conversation) = self . conversation . as_ref ( ) {
363
- if conversation. reamind_tokens ( ) > 0 {
364
- let ans = Confirm :: new ( "Already in a conversation, start a new one?" )
365
- . with_default ( true )
366
- . prompt ( ) ?;
367
- if !ans {
368
- return Ok ( ( ) ) ;
369
- }
401
+ if self . conversation . is_some ( ) && self . get_reamind_tokens ( ) > 0 {
402
+ let ans = Confirm :: new ( "Already in a conversation, start a new one?" )
403
+ . with_default ( true )
404
+ . prompt ( ) ?;
405
+ if !ans {
406
+ return Ok ( ( ) ) ;
370
407
}
371
408
}
372
409
self . conversation = Some ( Conversation :: new ( self . role . clone ( ) ) ) ;
0 commit comments