1
+ import json
2
+ import os
3
+ import time
4
+
5
+ from model_handler .constant import GORILLA_TO_OPENAPI
1
6
from model_handler .gpt_handler import OpenAIHandler
2
7
from model_handler .model_style import ModelStyle
3
- import os , json
8
+ from model_handler . utils import convert_to_tool , language_specific_pre_processing
4
9
from openai import OpenAI
5
10
6
11
7
12
class FireworkAIHandler (OpenAIHandler ):
8
- def __init__ (self , model_name , temperature = 0.7 , top_p = 1 , max_tokens = 1000 ) -> None :
13
+ def __init__ (self , model_name , temperature = 0.0 , top_p = 1 , max_tokens = 1000 ) -> None :
9
14
super ().__init__ (model_name , temperature , top_p , max_tokens )
10
- self .model_name = "accounts/fireworks/models/firefunction-v1-FC"
11
15
self .model_style = ModelStyle .FIREWORK_AI
16
+ self .temperature = 0.0
12
17
13
18
self .client = OpenAI (
14
19
base_url = "https://api.fireworks.ai/inference/v1" ,
@@ -19,11 +24,54 @@ def write(self, result, file_to_open):
19
24
# This method is used to write the result to the file.
20
25
if not os .path .exists ("./result" ):
21
26
os .mkdir ("./result" )
22
- if not os .path .exists ("./result/fire-function-v1-FC " ):
23
- os .mkdir ("./result/fire-function-v1-FC " )
27
+ if not os .path .exists (f "./result/{ self . model_name } " ):
28
+ os .mkdir (f "./result/{ self . model_name } " )
24
29
with open (
25
- "./result/fire-function-v1-FC /"
30
+ f "./result/{ self . model_name } /"
26
31
+ file_to_open .replace (".json" , "_result.json" ),
27
32
"a+" ,
28
33
) as f :
29
34
f .write (json .dumps (result ) + "\n " )
35
+
36
+ def inference (self , prompt , functions , test_category ):
37
+ functions = language_specific_pre_processing (functions , test_category , True )
38
+ if type (functions ) is not list :
39
+ functions = [functions ]
40
+ message = [{"role" : "user" , "content" : prompt }]
41
+ oai_tool = convert_to_tool (
42
+ functions , GORILLA_TO_OPENAPI , self .model_style , test_category , True
43
+ )
44
+ start_time = time .time ()
45
+ model_name = self .model_name .replace ("-FC" , "" )
46
+ model_name = f"accounts/fireworks/models/{ model_name } "
47
+ if len (oai_tool ) > 0 :
48
+ response = self .client .chat .completions .create (
49
+ messages = message ,
50
+ model = model_name ,
51
+ temperature = self .temperature ,
52
+ max_tokens = self .max_tokens ,
53
+ top_p = self .top_p ,
54
+ tools = oai_tool ,
55
+ frequency_penalty = 0.4 ,
56
+ )
57
+ else :
58
+ response = self .client .chat .completions .create (
59
+ messages = message ,
60
+ model = model_name ,
61
+ temperature = self .temperature ,
62
+ max_tokens = self .max_tokens ,
63
+ top_p = self .top_p ,
64
+ )
65
+ latency = time .time () - start_time
66
+ try :
67
+ result = [
68
+ {func_call .function .name : func_call .function .arguments }
69
+ for func_call in response .choices [0 ].message .tool_calls
70
+ ]
71
+ except :
72
+ result = response .choices [0 ].message .content
73
+ metadata = {}
74
+ metadata ["input_tokens" ] = response .usage .prompt_tokens
75
+ metadata ["output_tokens" ] = response .usage .completion_tokens
76
+ metadata ["latency" ] = latency
77
+ return result , metadata
0 commit comments