113 changes: 112 additions & 1 deletion akd/structures.py
@@ -26,8 +26,11 @@
# Search and Data Models
# =============================================================================

class SearchItem(BaseModel):
pass

class SearchResultItem(BaseModel):
class SearchResultItem(SearchItem):
"""Represents a single search result item with metadata."""

# Required fields
@@ -100,6 +103,114 @@ class ResearchData(BaseModel):
)


class PaperDataItem(SearchItem):
"""Represents a single paper data object retrieved from Semantic Scholar."""
paper_id: Optional[str] = Field(
...,
description="Semantic Scholar’s primary unique identifier for a paper.",
)
corpus_id: Optional[int] = Field(
...,
description="Semantic Scholar’s secondary unique identifier for a paper.",
)
external_ids: Optional[object] = Field(
None,
description="Valid URL to download data referenced in research. Leave None if unavailable.",
)
url: Optional[str] = Field(
...,
description="URL of the paper on the Semantic Scholar website.",
)
title: Optional[str] = Field(
...,
description="Title of the paper.",
)
abstract: Optional[str] = Field(
...,
description="The paper's abstract. Note that due to legal reasons, this may be missing even if we display an abstract on the website.",
)
venue: Optional[str] = Field(
...,
description="The name of the paper’s publication venue.",
)
publication_venue: Optional[object] = Field(
...,
description="An object that contains the following information about the journal or conference in which this paper was published: id (the venue’s unique ID), name (the venue’s name), type (the type of venue), alternate_names (an array of alternate names for the venue), and url (the venue’s website).",
)
year: Optional[int] = Field(
...,
description="The year the paper was published.",
)
reference_count: Optional[int] = Field(
...,
description="The total number of papers this paper references.",
)
citation_count: Optional[int] = Field(
...,
description="The total number of papers that references this paper.",
)
influential_citation_count: Optional[int] = Field(
...,
description="A subset of the citation count, where the cited publication has a significant impact on the citing publication.",
)
is_open_access: Optional[bool] = Field(
...,
description="Whether the paper is open access.",
)
open_access_pdf: Optional[object] = Field(
...,
description="An object that contains the following parameters: url (a link to the paper’s PDF), status, the paper's license, and a legal disclaimer.",
)
fields_of_study: Optional[list[str]] = Field(
...,
description="A list of the paper’s high-level academic categories from external sources.",
)
s2_fields_of_study: Optional[list[object]] = Field(
...,
description="An array of objects. Each object contains the following parameters: category (a field of study. The possible fields are the same as in fieldsOfStudy), and source (specifies whether the category was classified by Semantic Scholar or by an external source.",
)
publication_types: Optional[list[str]] = Field(
...,
description="The type of this publication.",
)
publication_date: Optional[str] = Field(
...,
description="The date when this paper was published, in YYYY-MM-DD format.",
)
journal: Optional[object] = Field(
...,
description="An object that contains the following parameters, if available: name (the journal name), volume (the journal’s volume number), and pages (the page number range)",
)
citation_styles: Optional[object] = Field(
...,
description="The BibTex bibliographical citation of the paper.",
)
authors: Optional[list[object]] = Field(
...,
description="List of authors corresponding to the paper.",
)
citations: Optional[list[object]] = Field(
...,
description="List of citations the paper has.",
)
references: Optional[list[object]] = Field(
...,
description="List of references used in the paper.",
)
embedding: Optional[object] = Field(
...,
description="The paper's embedding.",
)
tldr: Optional[object] = Field(
...,
description="Tldr version of the paper.",
)
doi: Optional[str] = Field(
...,
description="The DOI of the paper from the query."
)
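
Since every field above is declared Optional but bound with Field(...), pydantic treats it as required-but-nullable: external callers must pass every keyword, even if only as None. A minimal construction sketch under that reading — pydantic v2's model_fields is assumed, and all values are illustrative:

from akd.structures import PaperDataItem

# Hypothetical response fragment shaped like a Graph API paper object.
raw = {"paperId": "abc123", "title": "An Example Paper", "year": 2021}

# Start with None for every declared field, then overlay what we have.
kwargs = {name: None for name in PaperDataItem.model_fields}
kwargs.update(
    paper_id=raw.get("paperId"),
    title=raw.get("title"),
    year=raw.get("year"),
    doi="10.1000/example",  # illustrative DOI, not a real record
)
item = PaperDataItem(**kwargs)
print(item.title, item.doi)

If callers should instead be able to omit missing fields entirely, switching the defaults from ... to None (as external_ids already does) would be the lighter-weight design.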


# =============================================================================
# Extraction Schemas
# =============================================================================
195 changes: 170 additions & 25 deletions akd/tools/search.py
@@ -19,7 +19,7 @@
QueryAgentInputSchema,
QueryAgentOutputSchema,
)
from akd.structures import SearchResultItem
from akd.structures import SearchItem, SearchResultItem, PaperDataItem
from akd.tools._base import BaseTool, BaseToolConfig
from akd.tools.relevancy import EnhancedRelevancyChecker

@@ -45,7 +45,7 @@ class SearchToolOutputSchema(OutputSchema):
"""Schema for output of a tool for searching for information,
news, references, and other content."""

results: List[SearchResultItem] = Field(
results: List[SearchItem] = Field(
...,
description="List of search result items",
)
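
Widening results from List[SearchResultItem] to List[SearchItem] lets a single output schema carry both web-style hits and the new paper records. A self-contained sketch of the pattern, using hypothetical stand-in models (the real SearchResultItem and PaperDataItem have more required fields than shown here):

from typing import List
from pydantic import BaseModel, Field

class SearchItem(BaseModel):
    pass

class WebHit(SearchItem):    # stand-in for SearchResultItem
    url: str

class PaperHit(SearchItem):  # stand-in for PaperDataItem
    doi: str

class Output(BaseModel):
    results: List[SearchItem] = Field(...)

# Pydantic v2 does not revalidate model instances by default, so subclass
# instances pass through intact and consumers can branch on the concrete type.
out = Output(results=[WebHit(url="https://example.org"), PaperHit(doi="10.1000/x")])
print([type(r).__name__ for r in out.results])  # ['WebHit', 'PaperHit']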
@@ -401,6 +401,8 @@ class SemanticScholarSearchToolConfig(BaseToolConfig):
default=(os.getenv("SEMANTIC_SCHOLAR_API_KEY") or os.getenv("S2_API_KEY")),
)
base_url: HttpUrl = Field(default="https://api.semanticscholar.org")
endpoint: str = Field(default="search",
description="API endpoint type: 'search', 'doi', or externalID such as 'ARXIV'")
max_results: int = Field(default=int(os.getenv("SEMANTIC_SCHOLAR_MAX_RESULTS", 10)))
fields: List[str] = Field(
default_factory=lambda: [
@@ -462,6 +464,7 @@ def from_params(
cls,
api_key: Optional[str] = None,
base_url: Optional[HttpUrl] = "https://api.semanticscholar.org",
endpoint: str = "search",
max_results: int = 10,
fields: Optional[List[str]] = None,
results_per_page: int = 100,
Expand All @@ -472,6 +475,7 @@ def from_params(
config_data = {
"api_key": api_key or os.getenv("SEMANTIC_SCHOLAR_API_KEY"),
"base_url": base_url or "https://api.semanticscholar.org",
"endpoint": endpoint,
"max_results": max_results,
"results_per_page": results_per_page,
"max_pages_per_query": max_pages_per_query,
@@ -502,7 +506,7 @@ async def _fetch_search_page(
Returns:
The JSON response dictionary from the API or None if an error occurs.
"""
search_url = f"{self.config.base_url}/graph/v1/paper/search"
search_url = f"{self.config.base_url}graph/v1/paper/search"
params = {
"query": query,
"offset": offset,
@@ -560,6 +564,130 @@ async def _fetch_search_page(
)

return None

def _parse_paper(
self,
item: Dict[str, Any],
doi: Optional[str],
) -> Optional[PaperDataItem]:
"""Parses a single paper from the Semantic Scholar API response."""
if (
not item
or not item.get("paperId")
):
return None
try:
paper_item = PaperDataItem(
paper_id=item.get("paperId") or None,
corpus_id=item.get("corpusId") or None,
external_ids=item.get("externalIds") or None,
url=item.get("url") or None,
title=item.get("title") or None,
abstract=item.get("abstract") or None,
venue=item.get("venue") or None,
publication_venue=item.get("publicationVenue") or None,
year=item.get("year") or None,
reference_count=item.get("referenceCount") or None,
citation_count=item.get("citationCount") or None,
influential_citation_count=item.get("influentialCitationCount") or None,
is_open_access=item.get("isOpenAccess") or None,
open_access_pdf=item.get("openAccessPdf") or None,
fields_of_study=item.get("fieldsOfStudy") or None,
s2_fields_of_study=item.get("s2FieldsOfStudy") or None,
publication_types=item.get("publicationTypes") or None,
publication_date=item.get("publicationDate") or None,
journal=item.get("journal") or None,
citation_styles=item.get("citationStyles") or None,
authors=item.get("authors") or None,
citations=item.get("citations") or None,
references=item.get("references") or None,
embedding=item.get("embedding") or None,
tldr=item.get("tldr") or None,
doi=doi or None,
)
if self.debug:
logger.debug(
f"Processed paper with DOI {doi}"
)
return paper_item
except Exception as e:
logger.debug(
f"Could not parse response to paper object for DOI {doi}: {str(e)}",
)
return None

async def _fetch_paper_by_doi(
self,
session: aiohttp.ClientSession,
query: str,
) -> List[PaperDataItem]:
"""
Fetches a single paper by its DOI from Semantic Scholar.

Args:
session: The aiohttp session.
query: The DOI of the paper to fetch.

Returns:
A single-item list containing the parsed paper, or an empty list
if an error occurs.
"""
search_url = f"{self.config.base_url}graph/v1/paper/DOI:{query}"
params = {
"fields": ",".join(self.config.fields)
}
headers = {}
if self.config.api_key:
headers["x-api-key"] = self.config.api_key

if self.debug:
logger.debug(
f"Fetching paper details via Semantic Scholar: query='{query}'",
)
if headers:
logger.debug("Using API Key.")

try:
async with session.get(
search_url,
params=params,
headers=headers,
) as response:
response.raise_for_status() # Raise exception for 4xx or 5xx errors
data = await response.json()
if self.debug:
logger.debug(f"API Response Status: {response.status}")
# Avoid logging full data if it's too large or sensitive
logger.debug(
f"Received paper '{data.get('paperId')}' for DOI '{query}'",
)
paper = self._parse_paper(item=data, doi=query)
return [paper] if paper else []

except aiohttp.ClientResponseError as e:
logger.error(
f"HTTP Error fetching Semantic Scholar for query '{query}': {e.status} {e.message}",
)
# Log request details that caused the error
logger.error(f"Request URL: {response.url}")
logger.error(f"Request Params: {params}")
logger.error(f"Response Headers: {response.headers}")
try:
error_body = await response.text()
logger.error(
f"Response Body: {error_body[:500]}",
) # Log part of the body
except Exception as read_err:
logger.error(f"Could not read error response body: {read_err}")

except Exception as e:
logger.error(
f"Failed to fetch Semantic Scholar results for query '{query}': {e}",
)

return []
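
For reference, the request this method issues can be reproduced standalone against the public single-paper endpoint; a sketch with an illustrative DOI (the URL shape mirrors the search_url built above):

import asyncio
import aiohttp

async def fetch_doi_example() -> None:
    # Mirrors _fetch_paper_by_doi: the DOI is embedded in the path behind a
    # "DOI:" prefix, and the fields parameter trims the response payload.
    url = "https://api.semanticscholar.org/graph/v1/paper/DOI:10.18653/v1/N18-3011"
    params = {"fields": "title,year,externalIds"}
    async with aiohttp.ClientSession() as session:
        async with session.get(url, params=params) as resp:
            resp.raise_for_status()
            print(await resp.json())

asyncio.run(fetch_doi_example())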

def _parse_result(
self,
@@ -791,36 +919,53 @@ async def _arun(

if self.debug:
logger.debug(
f"Running Semantic Scholar Search: "
f"Running Semantic Scholar Search: ",
f"endpoint={self.config.endpoint}, "
f"final_max_results={final_max_results}, "
f"target_per_query={target_results_per_query}, "
f"num_queries={len(params.queries)}",
)

if self.config.endpoint == "doi":
async with aiohttp.ClientSession() as session:
tasks = [
self._fetch_paper_by_doi(
session,
query,
)
for query in params.queries
]
results_per_query = await asyncio.gather(*tasks)
all_raw_results = [item for sublist in results_per_query for item in sublist]
return SemanticScholarSearchToolOutputSchema(
results=all_raw_results,
)

else:
async with aiohttp.ClientSession() as session:
tasks = [
self._fetch_search_results_paginated(
session,
query,
params.category,
target_results_per_query,
)
for query in params.queries
]
results_per_query = await asyncio.gather(*tasks)

all_raw_results = [item for sublist in results_per_query for item in sublist]

# Final processing: deduplicate across queries and trim to max_results
final_results = await self._process_final_results(
all_raw_results,
final_max_results,
)

return SemanticScholarSearchToolOutputSchema(
results=final_results,
category=params.category, # Pass through the requested category
)

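Taken together, DOI lookup is selected purely through configuration. A minimal usage sketch: SemanticScholarSearchToolConfig, from_params, and endpoint come from this diff, while the tool class name SemanticScholarSearchTool and the input schema SearchToolInputSchema are assumed from the file's naming pattern, and from_params is treated here as a config constructor:

import asyncio

from akd.tools.search import (
    SearchToolInputSchema,            # assumed input schema name
    SemanticScholarSearchTool,        # assumed tool class name
    SemanticScholarSearchToolConfig,
)

async def main() -> None:
    config = SemanticScholarSearchToolConfig.from_params(endpoint="doi")
    tool = SemanticScholarSearchTool(config=config)
    out = await tool._arun(SearchToolInputSchema(queries=["10.18653/v1/N18-3011"]))
    for paper in out.results:
        print(paper.title, paper.doi)

asyncio.run(main())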

class SimpleAgenticLitSearchToolConfig(BaseToolConfig):