Skip to content

Commit 81a8b86

Browse files
committed
Adds file upload feature
Signed-off-by: Ian Eaves <[email protected]>
1 parent 370f1cc commit 81a8b86

21 files changed

+1021
-1
lines changed

docs/ramalama-chat.1.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,13 @@ Show this help message and exit
2929
#### **--prefix**
3030
Prefix for the user prompt (default: 🦭 > )
3131

32+
#### **--rag**=path
33+
A file or directory of files to be loaded and provided as local context in the chat history.
34+
3235
#### **--url**=URL
3336
The host to send requests to (default: http://127.0.0.1:8080)
3437

38+
3539
## EXAMPLES
3640

3741
Communicate with the default local OpenAI REST API. (http://127.0.0.1:8080)

ramalama/chat.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from ramalama.config import CONFIG
1515
from ramalama.console import EMOJI, should_colorize
1616
from ramalama.engine import dry_run, stop_container
17+
from ramalama.file_upload.file_loader import FileUpLoader
1718
from ramalama.logger import logger
1819

1920

@@ -72,6 +73,16 @@ def __init__(self, args):
7273
self.prompt = args.prefix
7374

7475
self.url = f"{args.url}/chat/completions"
76+
self.prep_rag_message()
77+
78+
def prep_rag_message(self):
79+
if (context := getattr(self.args, "rag", None)) is None:
80+
return
81+
82+
if not (message_content := FileUpLoader(context).load()):
83+
return
84+
85+
self.conversation_history.append({"role": "system", "content": message_content})
7586

7687
def handle_args(self):
7788
prompt = " ".join(self.args.ARGS) if self.args.ARGS else None

ramalama/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ def chat_parser(subparsers):
905905
help='possible values are "never", "always" and "auto".',
906906
)
907907
parser.add_argument("--prefix", type=str, help="prefix for the user prompt", default=default_prefix())
908+
parser.add_argument("--rag", type=str, help="a file or directory to use as context for the chat")
908909
parser.add_argument("--url", type=str, default="http://127.0.0.1:8080/v1", help="the url to send requests to")
909910
parser.add_argument("MODEL", completer=local_models) # positional argument
910911
parser.add_argument(
@@ -925,6 +926,7 @@ def run_parser(subparsers):
925926
)
926927
parser.add_argument("--prefix", type=str, help="prefix for the user prompt", default=default_prefix())
927928
parser.add_argument("MODEL", completer=local_models) # positional argument
929+
928930
parser.add_argument(
929931
"ARGS",
930932
nargs="*",

ramalama/file_upload/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ramalama.file_upload import file_loader, file_types
2+
3+
__all__ = ["file_loader", "file_types"]

ramalama/file_upload/file_loader.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
from string import Template
3+
from typing import Type
4+
from warnings import warn
5+
6+
from ramalama.file_upload.file_types import base, txt
7+
8+
SUPPORTED_EXTENSIONS = {
9+
'.txt': txt.TXTFileUpload,
10+
'.sh': txt.TXTFileUpload,
11+
'.md': txt.TXTFileUpload,
12+
'.yaml': txt.TXTFileUpload,
13+
'.yml': txt.TXTFileUpload,
14+
'.json': txt.TXTFileUpload,
15+
'.csv': txt.TXTFileUpload,
16+
'.toml': txt.TXTFileUpload,
17+
}
18+
19+
20+
class BaseFileUploader:
21+
"""
22+
Base class for file upload handlers.
23+
This class should be extended by specific file type handlers.
24+
"""
25+
26+
def __init__(self, files: list[Type[base.BaseFileUpload]], delim_string: str = "<!--start_document $name-->"):
27+
self.files = files
28+
self.document_delimiter: Template = Template(delim_string)
29+
30+
def load(self) -> str:
31+
"""
32+
Generate the output string by concatenating the processed files.
33+
"""
34+
output = (f"\n{self.document_delimiter.substitute(name=f.file)}\n{f.load()}" for f in self.files)
35+
return "".join(output)
36+
37+
38+
class FileUpLoader(BaseFileUploader):
39+
def __init__(self, file_path: str):
40+
if not os.path.exists(file_path):
41+
raise ValueError(f"{file_path} does not exist.")
42+
43+
if not os.path.isdir(file_path):
44+
files = [file_path]
45+
else:
46+
files = [
47+
os.path.join(file_path, f) for f in os.listdir(file_path) if os.path.isfile(os.path.join(file_path, f))
48+
]
49+
50+
extensions = [os.path.splitext(f)[1].lower() for f in files]
51+
52+
if set(extensions) - set(SUPPORTED_EXTENSIONS):
53+
warning_message = (
54+
f"Unsupported file types found: {set(extensions) - set(SUPPORTED_EXTENSIONS)}\n"
55+
f"Supported types are: {set(SUPPORTED_EXTENSIONS.keys())}"
56+
)
57+
warn(warning_message)
58+
59+
files = [SUPPORTED_EXTENSIONS[ext](file=f) for ext, f in zip(extensions, files) if ext in SUPPORTED_EXTENSIONS]
60+
super().__init__(files=files)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from ramalama.file_upload.file_types import base, pdf, txt
2+
3+
__all__ = ["base", "pdf", "txt"]
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class BaseFileUpload(ABC):
5+
"""
6+
Base class for file upload handlers.
7+
This class should be extended by specific file type handlers.
8+
"""
9+
10+
def __init__(self, file):
11+
self.file = file
12+
13+
@abstractmethod
14+
def load(self) -> str:
15+
"""
16+
Load the content of the file.
17+
This method should be implemented by subclasses to handle specific file types.
18+
"""
19+
pass
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from ramalama.file_upload.file_types.base import BaseFileUpload
2+
3+
4+
class PDFFileUpload(BaseFileUpload):
5+
def load(self) -> str:
6+
"""
7+
Load the content of the PDF file.
8+
This method should be implemented to handle PDF file reading.
9+
"""
10+
return ""
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from ramalama.file_upload.file_types.base import BaseFileUpload
2+
3+
4+
class TXTFileUpload(BaseFileUpload):
5+
def load(self) -> str:
6+
"""
7+
Load the content of the text file.
8+
"""
9+
10+
# TODO: Support for non-default encodings?
11+
with open(self.file, 'r') as f:
12+
return f.read()

test/unit/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# # tests/conftest.py
2+
# import pytest
3+
4+
# @pytest.fixture(autouse=True)
5+
# def set_container_engine_env(monkeypatch):
6+
# monkeypatch.setenv("RAMALAMA_CONTAINER_ENGINE", "docker")

0 commit comments

Comments
 (0)