Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from src.health import health_router
from src.infini_gram_exception_handler import infini_gram_engine_exception_handler
from src.infinigram import infinigram_router
from src.documents import documents_router

# If LOG_FORMAT is "google:json" emit log message as JSON in a format Google Cloud can parse.
fmt = os.getenv("LOG_FORMAT")
Expand Down Expand Up @@ -51,6 +52,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, Any]:

app.include_router(health_router)
app.include_router(router=infinigram_router)
app.include_router(router=documents_router)
app.include_router(router=attribution_router)

tracer_provider = TracerProvider()
Expand Down
2 changes: 1 addition & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ version = "0.1.0"
requires-python = ">=3.12"
dependencies = [
"fastapi==0.111.0",
"infini-gram",
"infini-gram==2.5.2",
"numpy<2.0.0",
"opentelemetry-api==1.30.0",
"opentelemetry-exporter-gcp-trace==1.9.0",
Expand Down
58 changes: 42 additions & 16 deletions api/src/documents/documents_router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Annotated, TypeAlias

from fastapi import APIRouter, Depends, Query
from infini_gram_processor.models import GetDocumentByIndexRequest
from infini_gram_processor.models import (
GetDocumentByIndexRequest,
GetDocumentByPointerRequest,
GetDocumentByRankRequest,
)

from src.documents.documents_service import (
DocumentsService,
Expand Down Expand Up @@ -52,34 +56,56 @@ def search_documents(
return result


@documents_router.get("/{index}/documents/{document_index}", tags=["documents"])
@documents_router.post("/{index}/get_document_by_rank", tags=["documents"])
def get_document_by_rank(
    documents_service: DocumentsServiceDependency,
    body: GetDocumentByRankRequest,
) -> InfiniGramDocumentResponse:
    # Unpack the request body into the keyword arguments that
    # DocumentsService.get_document_by_rank takes, and hand its response
    # straight back to the caller.
    return documents_service.get_document_by_rank(
        shard=body.shard,
        rank=body.rank,
        needle_length=body.needle_length,
        maximum_context_length=body.maximum_context_length,
    )


@documents_router.post("/{index}/get_document_by_pointer", tags=["documents"])
def get_document_by_pointer(
    documents_service: DocumentsServiceDependency,
    body: GetDocumentByPointerRequest,
) -> InfiniGramDocumentResponse:
    # Unpack the request body into the keyword arguments that
    # DocumentsService.get_document_by_pointer takes, and hand its response
    # straight back to the caller.
    return documents_service.get_document_by_pointer(
        shard=body.shard,
        pointer=body.pointer,
        needle_length=body.needle_length,
        maximum_context_length=body.maximum_context_length,
    )


@documents_router.post("/{index}/get_document_by_index", tags=["documents"])
def get_document_by_index(
documents_service: DocumentsServiceDependency,
document_index: int,
maximum_document_display_length: MaximumDocumentDisplayLengthType = 10,
body: GetDocumentByIndexRequest,
) -> InfiniGramDocumentResponse:
result = documents_service.get_document_by_index(
document_index=int(document_index),
maximum_context_length=maximum_document_display_length,
document_index=body.document_index,
maximum_context_length=body.maximum_context_length,
)

return result


@documents_router.get("/{index}/documents", tags=["documents"])
@documents_router.post("/{index}/get_documents_by_index", tags=["documents"])
def get_documents_by_index(
documents_service: DocumentsServiceDependency,
document_indexes: Annotated[list[int], Query()],
maximum_document_display_length: MaximumDocumentDisplayLengthType = 10,
body: list[GetDocumentByIndexRequest],
) -> InfiniGramDocumentsResponse:
result = documents_service.get_multiple_documents_by_index(
document_requests=[
GetDocumentByIndexRequest(
document_index=document_index,
maximum_context_length=maximum_document_display_length,
)
for document_index in document_indexes
],
document_requests=body,
)

return result
93 changes: 93 additions & 0 deletions api/src/documents/documents_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
BaseInfiniGramResponse,
Document,
GetDocumentByIndexRequest,
GetDocumentByPointerRequest,
GetDocumentByRankRequest,
)
from opentelemetry import trace

Expand Down Expand Up @@ -73,6 +75,97 @@ def search_documents(
page_count=ceil(search_documents_result.total_documents / page_size),
)

@tracer.start_as_current_span("documents_service/get_document_by_rank")
def get_document_by_rank(
    self, shard: int, rank: int, needle_length: int, maximum_context_length: int
) -> InfiniGramDocumentResponse:
    """Fetch one document through the processor's rank lookup.

    All four arguments are forwarded unchanged to
    ``self.infini_gram_processor.get_document_by_rank``. The fields of the
    document it returns are copied into an ``InfiniGramDocumentResponse``
    that also carries the processor's ``index``.
    """
    processor = self.infini_gram_processor
    found = processor.get_document_by_rank(
        shard=shard,
        rank=rank,
        needle_length=needle_length,
        maximum_context_length=maximum_context_length,
    )

    return InfiniGramDocumentResponse(
        index=processor.index,
        document_index=found.document_index,
        document_length=found.document_length,
        display_length=found.display_length,
        needle_offset=found.needle_offset,
        metadata=found.metadata,
        token_ids=found.token_ids,
        text=found.text,
    )

@tracer.start_as_current_span("documents_service/get_multiple_documents_by_rank")
def get_multiple_documents_by_rank(
    self,
    document_requests: Iterable[GetDocumentByRankRequest],
) -> InfiniGramDocumentsResponse:
    """Resolve a batch of rank requests into a single documents response.

    The requests are passed as-is to
    ``self.infini_gram_processor.get_documents_by_ranks``. Every document
    it yields is copied field by field into a ``Document``, and the
    resulting list is returned together with the processor's ``index``.
    """
    processor = self.infini_gram_processor
    fetched = processor.get_documents_by_ranks(
        document_requests=document_requests,
    )

    # Build the response documents before reading the index, matching the
    # order in which the processor is consulted.
    converted = [
        Document(
            document_index=item.document_index,
            document_length=item.document_length,
            display_length=item.display_length,
            needle_offset=item.needle_offset,
            metadata=item.metadata,
            token_ids=item.token_ids,
            text=item.text,
        )
        for item in fetched
    ]

    return InfiniGramDocumentsResponse(
        index=processor.index,
        documents=converted,
    )

@tracer.start_as_current_span("documents_service/get_document_by_pointer")
def get_document_by_pointer(
    self, shard: int, pointer: int, needle_length: int, maximum_context_length: int
) -> InfiniGramDocumentResponse:
    """Fetch one document through the processor's pointer lookup.

    All four arguments are forwarded unchanged to
    ``self.infini_gram_processor.get_document_by_pointer``. The fields of
    the document it returns are copied into an
    ``InfiniGramDocumentResponse`` that also carries the processor's
    ``index``.
    """
    processor = self.infini_gram_processor
    found = processor.get_document_by_pointer(
        shard=shard,
        pointer=pointer,
        needle_length=needle_length,
        maximum_context_length=maximum_context_length,
    )

    return InfiniGramDocumentResponse(
        index=processor.index,
        document_index=found.document_index,
        document_length=found.document_length,
        display_length=found.display_length,
        needle_offset=found.needle_offset,
        metadata=found.metadata,
        token_ids=found.token_ids,
        text=found.text,
    )

@tracer.start_as_current_span("documents_service/get_multiple_documents_by_pointer")
def get_multiple_documents_by_pointer(
    self,
    document_requests: Iterable[GetDocumentByPointerRequest],
) -> InfiniGramDocumentsResponse:
    """Resolve a batch of pointer requests into a single documents response.

    The requests are passed as-is to
    ``self.infini_gram_processor.get_documents_by_pointers``. Every
    document it yields is copied field by field into a ``Document``, and
    the resulting list is returned together with the processor's ``index``.
    """
    processor = self.infini_gram_processor
    fetched = processor.get_documents_by_pointers(
        document_requests=document_requests,
    )

    # Build the response documents before reading the index, matching the
    # order in which the processor is consulted.
    converted = [
        Document(
            document_index=item.document_index,
            document_length=item.document_length,
            display_length=item.display_length,
            needle_offset=item.needle_offset,
            metadata=item.metadata,
            token_ids=item.token_ids,
            text=item.text,
        )
        for item in fetched
    ]

    return InfiniGramDocumentsResponse(
        index=processor.index,
        documents=converted,
    )
@tracer.start_as_current_span("documents_service/get_document_by_index")
def get_document_by_index(
self, document_index: int, maximum_context_length: int
Expand Down
103 changes: 103 additions & 0 deletions api/src/infinigram/infinigram_router.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,112 @@
from fastapi import APIRouter
from infini_gram_processor.index_mappings import AvailableInfiniGramIndexId
from infini_gram_processor.models import (
InfiniGramFindRequest,
FindResponse,
InfiniGramFindCnfRequest,
FindCnfResponse,
InfiniGramCountRequest,
CountResponse,
InfiniGramCountCnfRequest,
CountCnfResponse,
InfiniGramProbRequest,
ProbResponse,
InfiniGramNtdRequest,
NtdResponse,
InfiniGramInfgramProbRequest,
InfgramProbResponse,
InfiniGramInfgramNtdRequest,
InfgramNtdResponse,
)
from src.infinigram.infini_gram_dependency import InfiniGramProcessorDependency

infinigram_router = APIRouter()


@infinigram_router.get(path="/indexes")
def get_available_indexes() -> list[AvailableInfiniGramIndexId]:
    # Materialise every member of AvailableInfiniGramIndexId, in iteration
    # order, as a plain list.
    return list(AvailableInfiniGramIndexId)


@infinigram_router.post(
    path="/{index}/find",
    description="Find the locations of an n-gram in the index",
)
def find(
    infini_gram_processor: InfiniGramProcessorDependency,
    body: InfiniGramFindRequest,
) -> FindResponse:
    """Pass the request's ``query`` to the processor's ``find`` and return its result."""
    find_result = infini_gram_processor.find(query=body.query)
    return find_result


@infinigram_router.post(
    path="/{index}/find_cnf",
    description="Find the locations of a CNF query in the index",
)
def find_cnf(
    infini_gram_processor: InfiniGramProcessorDependency,
    body: InfiniGramFindCnfRequest,
) -> FindCnfResponse:
    """Pass the request's query and CNF limits to the processor's ``find_cnf``.

    ``query``, ``max_clause_freq`` and ``max_diff_tokens`` are read from the
    body and forwarded unchanged; the processor's result is returned as-is.
    """
    find_cnf_result = infini_gram_processor.find_cnf(
        query=body.query,
        max_clause_freq=body.max_clause_freq,
        max_diff_tokens=body.max_diff_tokens,
    )
    return find_cnf_result


@infinigram_router.post(
    path="/{index}/count",
    description="Count the number of times an n-gram appears in the index",
)
def count(
    infini_gram_processor: InfiniGramProcessorDependency,
    body: InfiniGramCountRequest,
) -> CountResponse:
    """Pass the request's ``query`` to the processor's ``count`` and return its result."""
    count_result = infini_gram_processor.count(query=body.query)
    return count_result


@infinigram_router.post(
    path="/{index}/count_cnf",
    description="Count the number of times a CNF query appears in the index",
)
def count_cnf(
    infini_gram_processor: InfiniGramProcessorDependency,
    body: InfiniGramCountCnfRequest,
) -> CountCnfResponse:
    """Pass the request's query and CNF limits to the processor's ``count_cnf``.

    ``query``, ``max_clause_freq`` and ``max_diff_tokens`` are read from the
    body and forwarded unchanged; the processor's result is returned as-is.
    """
    count_cnf_result = infini_gram_processor.count_cnf(
        query=body.query,
        max_clause_freq=body.max_clause_freq,
        max_diff_tokens=body.max_diff_tokens,
    )
    return count_cnf_result


@infinigram_router.post(
    path="/{index}/prob",
    description="Compute the n-gram probability of the last token conditioned on all previous tokens",
)
def prob(
    infini_gram_processor: InfiniGramProcessorDependency,
    body: InfiniGramProbRequest,
) -> ProbResponse:
    """Pass the request's ``query`` to the processor's ``prob`` and return its result."""
    prob_result = infini_gram_processor.prob(query=body.query)
    return prob_result


@infinigram_router.post(
    path="/{index}/ntd",
    description="Compute the distribution of next token conditioned on all tokens in the query according to the n-gram model",
)
def ntd(
    infini_gram_processor: InfiniGramProcessorDependency,
    body: InfiniGramNtdRequest,
) -> NtdResponse:
    """Pass the request's ``query`` and ``max_support`` to the processor's ``ntd``.

    Both values are forwarded unchanged and the processor's result is
    returned as-is.
    """
    ntd_result = infini_gram_processor.ntd(
        query=body.query,
        max_support=body.max_support,
    )
    return ntd_result


@infinigram_router.post(
    path="/{index}/infgram_prob",
    description="Compute the infinity-gram probability of the last token conditioned on all previous tokens",
)
def infgram_prob(
    infini_gram_processor: InfiniGramProcessorDependency,
    body: InfiniGramInfgramProbRequest,
) -> InfgramProbResponse:
    """Pass the request's ``query`` to the processor's ``infgram_prob`` and return its result."""
    infgram_prob_result = infini_gram_processor.infgram_prob(query=body.query)
    return infgram_prob_result


@infinigram_router.post(
    path="/{index}/infgram_ntd",
    description="Compute the distribution of next token conditioned on all tokens in the query according to the infinity-gram model",
)
def infgram_ntd(
    infini_gram_processor: InfiniGramProcessorDependency,
    body: InfiniGramInfgramNtdRequest,
) -> InfgramNtdResponse:
    """Pass the request's ``query`` and ``max_support`` to the processor's ``infgram_ntd``.

    Both values are forwarded unchanged and the processor's result is
    returned as-is.
    """
    infgram_ntd_result = infini_gram_processor.infgram_ntd(
        query=body.query,
        max_support=body.max_support,
    )
    return infgram_ntd_result
8 changes: 4 additions & 4 deletions attribution_worker/get_documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
AttributionDocument,
AttributionSpan,
Document,
GetDocumentByPointerRequest,
GetDocumentByPointerGroupedRequest,
SpanRankingMethod,
)
from infini_gram_processor.processor import InfiniGramProcessor
Expand Down Expand Up @@ -100,15 +100,15 @@ def get_document_requests(
input_token_ids: list[int],
maximum_documents_per_span: int,
maximum_context_length: int,
) -> list[GetDocumentByPointerRequest]:
document_request_by_span: list[GetDocumentByPointerRequest] = []
) -> list[GetDocumentByPointerGroupedRequest]:
document_request_by_span: list[GetDocumentByPointerGroupedRequest] = []
for span in spans:
docs = span["docs"]
if len(docs) > maximum_documents_per_span:
random.seed(42) # For reproducibility
docs = random.sample(docs, maximum_documents_per_span)
document_request_by_span.append(
GetDocumentByPointerRequest(
GetDocumentByPointerGroupedRequest(
docs=docs,
span_ids=input_token_ids[span["l"] : span["r"]],
needle_length=span["length"],
Expand Down
2 changes: 1 addition & 1 deletion attribution_worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ async def attribution_job(
)

documents_by_span = await asyncio.to_thread(
infini_gram_index.get_documents_by_pointers,
infini_gram_index.get_documents_by_pointers_grouped,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed this function to make room for a regular (non-grouped) version of get_documents_by_pointers.

document_request_by_span=document_request_by_span,
)

Expand Down
2 changes: 1 addition & 1 deletion packages/infini-gram-processor/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ requires-python = ">=3.12"
dependencies = [
"opentelemetry-api==1.30.0",
"opentelemetry-sdk==1.30.0",
"infini-gram",
"infini-gram==2.5.2",
"transformers==4.49.0",
]

Expand Down
Loading