Skip to content

Commit cba091b

Browse files
committed
adds vision to chat
Signed-off-by: Ian Eaves <[email protected]>
1 parent 149c9f1 commit cba091b

File tree

17 files changed

+959
-338
lines changed

17 files changed

+959
-338
lines changed

pyproject.toml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,26 @@ dictionary = ".codespelldict"
7777
ignore-words-list = ["cann", "clos", "creat", "ro", "hastable", "shouldnot", "mountns", "passt" ,"assertin"]
7878
check-hidden = true
7979

80+
[tool.ruff]
line-length = 120
target-version = "py312"
# Ruff's `include` and `exclude` take glob patterns, not regular expressions.
include = ["*.py", "*.pyi"]
exclude = [
    ".git",
    ".tox",
    ".venv",
    ".history",
    "build",
    "dist",
    "docs",
    "hack",
    "venv",
]
95+
96+
[tool.ruff.format]
97+
preview = true
98+
quote-style = "preserve"
99+
80100
[tool.pytest.ini_options]
81101
testpaths = ["."]
82102
log_cli = true

ramalama/chat.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ramalama.config import CONFIG
1515
from ramalama.console import EMOJI, should_colorize
1616
from ramalama.engine import dry_run, stop_container
17-
from ramalama.file_upload.file_loader import FileUpLoader
17+
from ramalama.file_loaders.file_manager import OpanAIChatAPIMessageBuilder
1818
from ramalama.logger import logger
1919

2020

@@ -90,10 +90,9 @@ def prep_rag_message(self):
9090
if (context := getattr(self.args, "rag", None)) is None:
9191
return
9292

93-
if not (message_content := FileUpLoader(context).load()):
94-
return
95-
96-
self.conversation_history.append({"role": "system", "content": message_content})
93+
builder = OpanAIChatAPIMessageBuilder()
94+
messages = builder.load(context)
95+
self.conversation_history.extend(messages)
9796

9897
def handle_args(self):
9998
prompt = " ".join(self.args.ARGS) if self.args.ARGS else None

ramalama/file_loaders/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ramalama.file_loaders import file_manager, file_types
2+
3+
__all__ = ["file_manager", "file_types"]

ramalama/file_loaders/file_manager.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import os
2+
from abc import ABC, abstractmethod
3+
from string import Template
4+
from typing import Dict, List, Type
5+
from warnings import warn
6+
7+
from ramalama.file_loaders.file_types import base, image, txt
8+
9+
10+
class BaseFileManager(ABC):
    """
    Base class for managers that dispatch files to type-specific loaders.

    On construction, every loader class returned by ``get_loaders()`` is
    instantiated and registered under each of its file extensions. Extensions
    are lower-cased both when registering and when looking up, so matching is
    case-insensitive.
    """

    def __init__(self):
        # Maps lower-cased extension (including the leading dot) -> loader instance.
        self.loaders = {ext.lower(): loader() for loader in self.get_loaders() for ext in loader.file_extensions()}

    def _get_loader(self, file: str) -> base.BaseFileLoader:
        """
        Return the loader registered for the extension of ``file``.

        Raises:
            KeyError: if no loader is registered for the file's extension.
        """
        extension = os.path.splitext(file)[1].lower()
        try:
            return self.loaders[extension]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a message that names the file.
            raise KeyError(f"No loader registered for file extension {extension!r} (file: {file})") from None

    @abstractmethod
    def load(self, files: list[str]):
        """
        Load ``files`` and return their content in a manager-specific form.
        """

    @classmethod
    @abstractmethod
    def get_loaders(cls) -> List[Type[base.BaseFileLoader]]:
        """
        Return the loader classes this manager dispatches to.
        """
30+
31+
32+
class TextFileManager(BaseFileManager):
    """
    Manager that merges text files into one string, each preceded by a delimiter line.
    """

    def __init__(self, delim_string: str = "<!--start_document $name-->"):
        # ``$name`` in the delimiter is replaced with the file path when loading.
        self.document_delimiter: Template = Template(delim_string)
        super().__init__()

    @classmethod
    def get_loaders(cls) -> List[Type[base.BaseFileLoader]]:
        """
        Return the loader classes used for text files.
        """
        return [txt.TXTFileLoader]

    def load(self, files: list[str]) -> str:
        """
        Return the content of all ``files`` joined into a single string.

        Each file contributes a newline, its delimiter line, a newline and
        then the loaded content, in the order the files were given.
        """

        def render(path: str) -> str:
            file_loader = self._get_loader(path)
            header = self.document_delimiter.substitute(name=path)
            return f"\n{header}\n{file_loader.load(path)}"

        return "".join(render(path) for path in files)
52+
53+
54+
class ImageFileManager(BaseFileManager):
    """
    Manager that loads image files through the registered image loader.
    """

    @classmethod
    def get_loaders(cls) -> List[Type[base.BaseFileLoader]]:
        """
        Return the loader classes used for image files.
        """
        return [image.BasicImageFileLoader]

    def load(self, files: list[str]) -> list[str]:
        """
        Return one loaded entry per file, in the order the files were given.
        """
        loaded = []
        for path in files:
            file_loader = self._get_loader(path)
            loaded.append(file_loader.load(path))
        return loaded
64+
65+
66+
def unsupported_files_warning(unsupported_files: list[str], supported_extensions: list[str]) -> None:
    """
    Emit a warning listing files that were skipped because their type is unsupported.

    Args:
        unsupported_files: Paths of the files that were ignored.
        supported_extensions: Extensions that are supported; shown sorted.

    Nothing is emitted when ``unsupported_files`` is empty.
    """
    if not unsupported_files:
        return

    formatted_supported = ", ".join(sorted(supported_extensions))
    formatted_unsupported = "- " + "\n- ".join(unsupported_files)
    message = f"""
⚠️ Unsupported file types detected!

Ramalama supports the following file types:
{formatted_supported}

The following unsupported files were found and ignored:
{formatted_unsupported}
""".strip()
    # stacklevel=2 attributes the warning to the caller rather than to this helper.
    warn(message, stacklevel=2)
81+
82+
83+
class OpanAIChatAPIMessageBuilder:
    """
    Build chat message dicts from a single file or a directory tree.

    Text files are merged into one message whose ``content`` is a string.
    Image files are gathered into one message whose ``content`` is a list of
    ``image_url`` parts. Both messages use the ``system`` role.
    """

    def __init__(self):
        self.text_manager = TextFileManager()
        self.image_manager = ImageFileManager()

    def partition_files(self, file_path: str) -> tuple[list[str], list[str], list[str]]:
        """
        Split the file(s) at ``file_path`` into text, image and unsupported groups.

        Args:
            file_path: A single file, or a directory that is walked recursively.

        Returns:
            A ``(text_files, image_files, unsupported_files)`` tuple of path lists.

        Raises:
            ValueError: if ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise ValueError(f"{file_path} does not exist.")

        if os.path.isdir(file_path):
            # Sorted so the result does not depend on the filesystem's directory listing order.
            candidates = sorted(os.path.join(root, name) for root, _, names in os.walk(file_path) for name in names)
        else:
            candidates = [file_path]

        text_files: list[str] = []
        image_files: list[str] = []
        unsupported_files: list[str] = []

        for candidate in candidates:
            # Lower-cased so extension matching is case-insensitive.
            extension = os.path.splitext(candidate)[1].lower()
            if extension in self.text_manager.loaders:
                text_files.append(candidate)
            elif extension in self.image_manager.loaders:
                image_files.append(candidate)
            else:
                unsupported_files.append(candidate)

        return text_files, image_files, unsupported_files

    def supported_extensions(self) -> set[str]:
        """
        Return every extension handled by the text or the image manager.
        """
        return self.text_manager.loaders.keys() | self.image_manager.loaders.keys()

    def load(self, file_path: str) -> list[dict]:
        """
        Load ``file_path`` and return the resulting list of chat messages.

        Unsupported files are reported through a warning and otherwise
        ignored. The text message, if any, precedes the image message.

        Raises:
            ValueError: if ``file_path`` does not exist.
        """
        text_files, image_files, unsupported_files = self.partition_files(file_path)

        if unsupported_files:
            unsupported_files_warning(unsupported_files, list(self.supported_extensions()))

        messages: list[dict] = []
        if text_files:
            messages.append({"role": "system", "content": self.text_manager.load(text_files)})
        if image_files:
            image_parts = [
                {"type": "image_url", "image_url": {"url": content}}
                for content in self.image_manager.load(image_files)
            ]
            messages.append({"role": "system", "content": image_parts})
        return messages


# Correctly spelled alias; the original misspelled name stays because existing code imports it.
OpenAIChatAPIMessageBuilder = OpanAIChatAPIMessageBuilder
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ramalama.file_loaders.file_types import base, image, txt
2+
3+
__all__ = ["base", "txt", "image"]
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
from abc import ABC, abstractmethod
22

33

4-
class BaseFileUpload(ABC):
4+
class BaseFileLoader(ABC):
    """
    Abstract interface for file-type loaders.

    Concrete loaders supply two static methods: ``load`` to read a file and
    ``file_extensions`` to report which extensions they handle.
    """

    @staticmethod
    @abstractmethod
    def load(file: str) -> str:
        """
        Return the content of ``file``.

        Subclasses implement the reading logic for their specific file type.
        """

    @staticmethod
    @abstractmethod
    def file_extensions() -> set[str]:
        """
        Return the file extensions supported by this loader.
        """
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import base64
2+
import mimetypes
3+
4+
from ramalama.file_loaders.file_types.base import BaseFileLoader
5+
6+
7+
class BasicImageFileLoader(BaseFileLoader):
    """
    Loader that returns an image file as a base64-encoded ``data:`` URL.
    """

    @staticmethod
    def file_extensions() -> set[str]:
        """
        Return the image file extensions handled by this loader.
        """
        return {
            ".jpg",
            ".jpeg",
            ".png",
            ".gif",
            ".bmp",
            ".tiff",
            ".tif",
            ".webp",
            ".ico",
        }

    @staticmethod
    def load(file: str) -> str:
        """
        Return the image at ``file`` as a ``data:<mime>;base64,<payload>`` URL.

        The MIME type is guessed from the file name. If it cannot be guessed,
        ``application/octet-stream`` is used.
        """
        mime_type, _ = mimetypes.guess_type(file)
        if mime_type is None:
            # Without a fallback the URL would contain the literal text "None" as its MIME type.
            mime_type = "application/octet-stream"

        with open(file, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")

        return f"data:{mime_type};base64,{data}"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from ramalama.file_loaders.file_types.base import BaseFileLoader
2+
3+
4+
class PDFFileLoader(BaseFileLoader):
    """
    Placeholder loader for PDF files; no PDF reading is performed.
    """

    @staticmethod
    def file_extensions() -> set[str]:
        """
        Return the file extensions handled by this loader.
        """
        extensions = {".pdf"}
        return extensions

    @staticmethod
    def load(file: str) -> str:
        """
        Placeholder for PDF loading.

        The file is not opened; an empty string is returned for any input.
        """
        content = ""
        return content
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from ramalama.file_loaders.file_types.base import BaseFileLoader
2+
3+
4+
class TXTFileLoader(BaseFileLoader):
    """
    Loader for plain-text style files (text, scripts, markup and config formats).
    """

    @staticmethod
    def file_extensions() -> set[str]:
        """
        Return the text file extensions handled by this loader.
        """
        return {
            ".txt",
            ".sh",
            ".md",
            ".yaml",
            ".yml",
            ".json",
            ".csv",
            ".toml",
        }

    @staticmethod
    def load(file: str) -> str:
        """
        Return the full content of the text file at ``file``.

        The file is decoded as UTF-8 so the result does not depend on the
        platform's locale encoding.

        Raises:
            UnicodeDecodeError: if the file is not valid UTF-8.
        """

        # TODO: Support for non-default encodings?
        with open(file, "r", encoding="utf-8") as f:
            return f.read()

ramalama/file_upload/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

0 commit comments

Comments
 (0)