109 changes: 108 additions & 1 deletion akd/structures.py
@@ -26,7 +26,6 @@
# Search and Data Models
# =============================================================================


class SearchResultItem(BaseModel):
"""Represents a single search result item with metadata."""

@@ -100,6 +99,114 @@ class ResearchData(BaseModel):
)


class PaperDataItem(BaseModel):
"""Represents a single paper data object retrieved from Semantic Scholar."""
paper_id: Optional[str] = Field(
...,
description="Semantic Scholar’s primary unique identifier for a paper.",
)
corpus_id: Optional[int] = Field(
...,
description="Semantic Scholar’s secondary unique identifier for a paper.",
)
external_ids: Optional[object] = Field(
None,
description="Valid URL to download data referenced in research. Leave None if unavailable.",
)
url: Optional[str] = Field(
...,
description="URL of the paper on the Semantic Scholar website.",
)
title: Optional[str] = Field(
...,
description="Title of the paper.",
)
abstract: Optional[str] = Field(
...,
description="The paper's abstract. Note that due to legal reasons, this may be missing even if we display an abstract on the website.",
)
venue: Optional[str] = Field(
...,
description="The name of the paper’s publication venue.",
)
publication_venue: Optional[object] = Field(
...,
description="An object that contains the following information about the journal or conference in which this paper was published: id (the venue’s unique ID), name (the venue’s name), type (the type of venue), alternate_names (an array of alternate names for the venue), and url (the venue’s website).",
)
year: Optional[int] = Field(
...,
description="The year the paper was published.",
)
reference_count: Optional[int] = Field(
...,
description="The total number of papers this paper references.",
)
citation_count: Optional[int] = Field(
...,
description="The total number of papers that references this paper.",
)
influential_citation_count: Optional[int] = Field(
...,
description="A subset of the citation count, where the cited publication has a significant impact on the citing publication.",
)
is_open_access: Optional[bool] = Field(
...,
description="Whether the paper is open access.",
)
open_access_pdf: Optional[object] = Field(
...,
description="An object that contains the following parameters: url (a link to the paper’s PDF), status, the paper's license, and a legal disclaimer.",
)
fields_of_study: Optional[list[str]] = Field(
...,
description="A list of the paper’s high-level academic categories from external sources.",
)
s2_fields_of_study: Optional[list[object]] = Field(
...,
description="An array of objects. Each object contains the following parameters: category (a field of study. The possible fields are the same as in fieldsOfStudy), and source (specifies whether the category was classified by Semantic Scholar or by an external source.",
)
publication_types: Optional[list[str]] = Field(
...,
description="The type of this publication.",
)
publication_date: Optional[str] = Field(
...,
description="The date when this paper was published, in YYYY-MM-DD format.",
)
journal: Optional[object] = Field(
...,
description="An object that contains the following parameters, if available: name (the journal name), volume (the journal’s volume number), and pages (the page number range)",
)
citation_styles: Optional[object] = Field(
...,
description="The BibTex bibliographical citation of the paper.",
)
authors: Optional[list[object]] = Field(
...,
description="List of authors corresponding to the paper.",
)
citations: Optional[list[object]] = Field(
...,
description="List of citations the paper has.",
)
references: Optional[list[object]] = Field(
...,
description="List of references used in the paper.",
)
embedding: Optional[object] = Field(
...,
description="The paper's embedding.",
)
tldr: Optional[object] = Field(
...,
description="Tldr version of the paper.",
)
doi: Optional[str] = Field(
...,
description="The DOI of the paper from the query."
)

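# A hypothetical usage sketch (editor's example, not part of this diff):
# constructing a PaperDataItem by hand. Every field above is declared with
# `...` (required even though Optional), so fields absent from the source must
# be passed explicitly as None. `model_fields` assumes pydantic v2 (v1 exposes
# `__fields__` instead); the paper values below are illustrative only.
def _example_paper_item() -> PaperDataItem:
    defaults = {name: None for name in PaperDataItem.model_fields}
    return PaperDataItem(
        **{
            **defaults,
            "paper_id": "649def34f8be52c8b66281af98ae884c09aef38b",
            "title": "Construction of the Literature Graph in Semantic Scholar",
            "year": 2018,
            "doi": "10.18653/v1/N18-3011",
            "is_open_access": True,
        }
    )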

# =============================================================================
# Extraction Schemas
# =============================================================================
157 changes: 153 additions & 4 deletions akd/tools/search.py
@@ -19,7 +19,7 @@
QueryAgentInputSchema,
QueryAgentOutputSchema,
)
from akd.structures import SearchResultItem
from akd.structures import SearchResultItem, PaperDataItem
from akd.tools._base import BaseTool, BaseToolConfig
from akd.tools.relevancy import EnhancedRelevancyChecker

@@ -502,7 +502,7 @@ async def _fetch_search_page(
Returns:
The JSON response dictionary from the API or None if an error occurs.
"""
search_url = f"{self.config.base_url}/graph/v1/paper/search"
search_url = f"{self.config.base_url}graph/v1/paper/search"
params = {
"query": query,
"offset": offset,
@@ -560,6 +560,130 @@ async def _fetch_search_page(
)

return None

def _parse_paper(
self,
item: Dict[str, Any],
doi: Optional[str],
) -> Optional[PaperDataItem]:
"""Parses a single paper from the Semantic Scholar API response."""
if (
not item
or not item.get("paperId")
):
return None
try:
paper_item = PaperDataItem(
paper_id=item.get("paperId"),
corpus_id=item.get("corpusId"),
external_ids=item.get("externalIds"),
url=item.get("url"),
title=item.get("title"),
abstract=item.get("abstract"),
venue=item.get("venue"),
publication_venue=item.get("publicationVenue"),
year=item.get("year"),
reference_count=item.get("referenceCount"),
citation_count=item.get("citationCount"),
influential_citation_count=item.get("influentialCitationCount"),
is_open_access=item.get("isOpenAccess"),
open_access_pdf=item.get("openAccessPdf"),
fields_of_study=item.get("fieldsOfStudy"),
s2_fields_of_study=item.get("s2FieldsOfStudy"),
publication_types=item.get("publicationTypes"),
publication_date=item.get("publicationDate"),
journal=item.get("journal"),
citation_styles=item.get("citationStyles"),
authors=item.get("authors"),
citations=item.get("citations"),
references=item.get("references"),
embedding=item.get("embedding"),
tldr=item.get("tldr"),
doi=doi or None,
)
if self.debug:
logger.debug(
f"Processed paper with DOI {doi}"
)
return paper_item
except Exception as e:
logger.debug(
f"Could not parse the response into a PaperDataItem for DOI {doi}: {e}",
)
return None

async def _fetch_paper_by_doi(
self,
session: aiohttp.ClientSession,
query: str,
) -> list[PaperDataItem]:
"""
Fetches a single paper from Semantic Scholar by its DOI.

Args:
session: The aiohttp session.
query: The DOI of the paper to fetch.

Returns:
A single-element list containing the parsed PaperDataItem, or an
empty list if the request or parsing fails.
"""
search_url = f"{self.config.base_url}graph/v1/paper/DOI:{query}"
params = {
"fields": ",".join(self.config.fields)
}
headers = {}
if self.config.api_key:
headers["x-api-key"] = self.config.api_key

if self.debug:
logger.debug(
f"Fetching paper details via Semantic Scholar: query='{query}'",
)
if headers:
logger.debug("Using API Key.")

try:
async with session.get(
search_url,
params=params,
headers=headers,
) as response:
response.raise_for_status() # Raise exception for 4xx or 5xx errors
data = await response.json()
if self.debug:
logger.debug(data)
logger.debug(f"API Response Status: {response.status}")
# Avoid logging full data if it's too large or sensitive
logger.debug(
f"Received {len(data.get('data', []))} items. Total: {data.get('total')}, Offset: {data.get('offset')}, Next: {data.get('next')}",
)
parsed = self._parse_paper(item=data, doi=query)
return [parsed] if parsed else []

except aiohttp.ClientResponseError as e:
logger.error(
f"HTTP Error fetching Semantic Scholar for query '{query}': {e.status} {e.message}",
)
# Log request details that caused the error
logger.error(f"Request URL: {response.url}")
logger.error(f"Request Params: {params}")
logger.error(f"Response Headers: {response.headers}")
try:
error_body = await response.text()
logger.error(
f"Response Body: {error_body[:500]}",
) # Log part of the body
except Exception as read_err:
logger.error(f"Could not read error response body: {read_err}")

except Exception as e:
logger.error(
f"Failed to fetch Semantic Scholar results for query '{query}': {e}",
)

return []
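
# Editor's sketch of the raw request the method above issues, assuming
# base_url resolves to "https://api.semanticscholar.org/" and a hypothetical
# fields list of ["title", "abstract"]; the DOI value is illustrative only.
#
#   import asyncio
#   import aiohttp
#
#   async def fetch_raw_paper(doi: str) -> dict:
#       url = f"https://api.semanticscholar.org/graph/v1/paper/DOI:{doi}"
#       async with aiohttp.ClientSession() as session:
#           async with session.get(url, params={"fields": "title,abstract"}) as resp:
#               resp.raise_for_status()
#               return await resp.json()
#
#   asyncio.run(fetch_raw_paper("10.18653/v1/N18-3011"))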

def _parse_result(
self,
@@ -577,7 +701,7 @@ def _parse_result(
return None # Skip incomplete results

external_ids = item.get("externalIds") or {}
doi = external_ids.pop("DOI")
doi = external_ids.get("DOI") # All papers do not have a DOI

# Extract author names if requested and available
authors = [
@@ -753,7 +877,32 @@ async def _process_final_results(
)

return final_results

async def doi_to_paper(
self,
params: SemanticScholarSearchToolInputSchema,
**kwargs,
) -> list[PaperDataItem]:
"""
Fetches one paper per DOI listed in params.queries.

Args:
params: Input parameters whose queries contain the DOIs to look up.

Returns:
List of PaperDataItem objects, one per successfully resolved DOI.
"""
async with aiohttp.ClientSession() as session:
tasks = [
self._fetch_paper_by_doi(
session,
query,
)
for query in params.queries
]
results_per_query = await asyncio.gather(*tasks)
results = [item for sublist in results_per_query for item in sublist]
return results
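
# Editor's usage sketch for doi_to_paper. The tool class name
# (SemanticScholarSearchTool) and its no-argument construction are assumed
# from context; `queries` is the input-schema field used above, and the DOI
# value is illustrative only.
#
#   async def lookup():
#       tool = SemanticScholarSearchTool()
#       params = SemanticScholarSearchToolInputSchema(
#           queries=["10.18653/v1/N18-3011"],
#       )
#       for paper in await tool.doi_to_paper(params):
#           print(paper.title, paper.doi)
#
#   asyncio.run(lookup())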

async def _arun(
self,
params: SemanticScholarSearchToolInputSchema,
Expand Down Expand Up @@ -791,7 +940,7 @@ async def _arun(

if self.debug:
logger.debug(
f"Running Semantic Scholar Search: "
f"Running Semantic Scholar Search: ",
f"final_max_results={final_max_results}, "
f"target_per_query={target_results_per_query}, "
f"num_queries={len(params.queries)}",