5
5
If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
6
6
"""
7
7
8
+ import base64
8
9
import dataclasses
9
10
from enum import auto , IntEnum
11
+ from io import BytesIO
10
12
from typing import List , Any , Dict , Union , Tuple
11
13
12
14
@@ -34,6 +36,9 @@ class SeparatorStyle(IntEnum):
34
36
YUAN2 = auto ()
35
37
36
38
39
# Placeholder token spliced into a prompt once per attached image; the
# `$$...$$` wrapping makes an accidental literal collision with user text unlikely.
IMAGE_PLACEHOLDER_STR = "$$<image>$$"
40
+
41
+
37
42
@dataclasses .dataclass
38
43
class Conversation :
39
44
"""A class that manages prompt templates and keeps all conversation history."""
@@ -47,6 +52,7 @@ class Conversation:
47
52
# The names of two roles
48
53
roles : Tuple [str ] = ("USER" , "ASSISTANT" )
49
54
# All messages. Each item is (role, message).
55
+ # Each message is either a string or a tuple of (string, List[image_url]).
50
56
messages : List [List [str ]] = ()
51
57
# The number of few shot examples
52
58
offset : int = 0
@@ -77,6 +83,7 @@ def get_prompt(self) -> str:
77
83
if message :
78
84
if type (message ) is tuple :
79
85
message , images = message
86
+ message = IMAGE_PLACEHOLDER_STR * len (images ) + message
80
87
ret += role + ": " + message + seps [i % 2 ]
81
88
else :
82
89
ret += role + ":"
@@ -289,11 +296,52 @@ def update_last_message(self, message: str):
289
296
"""
290
297
self .messages [- 1 ][1 ] = message
291
298
299
+ def convert_image_to_base64 (self , image ):
300
+ """Given an image, return the base64 encoded image string."""
301
+ from PIL import Image
302
+ import requests
303
+
304
+ # Load image if it has not been loaded in yet
305
+ if type (image ) == str :
306
+ if image .startswith ("http://" ) or image .startswith ("https://" ):
307
+ response = requests .get (image )
308
+ image = Image .open (BytesIO (response .content )).convert ("RGB" )
309
+ elif "base64" in image :
310
+ # OpenAI format is: data:image/jpeg;base64,{base64_encoded_image_str}
311
+ return image .split ("," )[1 ]
312
+ else :
313
+ image = Image .open (image ).convert ("RGB" )
314
+
315
+ max_hw , min_hw = max (image .size ), min (image .size )
316
+ aspect_ratio = max_hw / min_hw
317
+ max_len , min_len = 2048 , 2048
318
+ shortest_edge = int (min (max_len / aspect_ratio , min_len , min_hw ))
319
+ longest_edge = int (shortest_edge * aspect_ratio )
320
+ W , H = image .size
321
+ if longest_edge != max (image .size ):
322
+ if H > W :
323
+ H , W = longest_edge , shortest_edge
324
+ else :
325
+ H , W = shortest_edge , longest_edge
326
+ image = image .resize ((W , H ))
327
+
328
+ buffered = BytesIO ()
329
+ image .save (buffered , format = "PNG" )
330
+ img_b64_str = base64 .b64encode (buffered .getvalue ()).decode ()
331
+
332
+ return img_b64_str
333
+
292
334
def to_gradio_chatbot (self ):
293
335
"""Convert the conversation to gradio chatbot format."""
294
336
ret = []
295
337
for i , (role , msg ) in enumerate (self .messages [self .offset :]):
296
338
if i % 2 == 0 :
339
+ if type (msg ) is tuple :
340
+ msg , image = msg
341
+ img_b64_str = image [0 ] # Only one image on gradio at one time
342
+ img_str = f'<img src="data:image/jpeg;base64,{ img_b64_str } " alt="user upload image" />'
343
+ msg = img_str + msg .replace ("<image>\n " , "" ).strip ()
344
+
297
345
ret .append ([msg , None ])
298
346
else :
299
347
ret [- 1 ][- 1 ] = msg
@@ -314,6 +362,12 @@ def to_openai_api_messages(self):
314
362
ret .append ({"role" : "assistant" , "content" : msg })
315
363
return ret
316
364
365
+ def extract_text_from_messages (self ):
366
+ return [
367
+ (role , message [0 ]) if type (message ) is tuple else (role , message )
368
+ for role , message in self .messages
369
+ ]
370
+
317
371
def copy (self ):
318
372
return Conversation (
319
373
name = self .name ,
@@ -334,7 +388,7 @@ def dict(self):
334
388
"template_name" : self .name ,
335
389
"system_message" : self .system_message ,
336
390
"roles" : self .roles ,
337
- "messages" : self .messages ,
391
+ "messages" : self .extract_text_from_messages () ,
338
392
"offset" : self .offset ,
339
393
}
340
394
0 commit comments