Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions app/modules/code_provider/code_provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,21 @@ async def get_project_structure_async(self, project_id, path: Optional[str] = No
return await self.service_instance.get_project_structure_async(project_id, path)

def get_file_content(
self, repo_name, file_path, start_line, end_line, branch_name, project_id
self,
repo_name,
file_path,
start_line,
end_line,
branch_name,
project_id,
commit_id,
):
return self.service_instance.get_file_content(
repo_name, file_path, start_line, end_line, branch_name, project_id
repo_name,
file_path,
start_line,
end_line,
branch_name,
project_id,
commit_id,
)
23 changes: 18 additions & 5 deletions app/modules/code_provider/github/github_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,16 @@ def get_file_content(
end_line: int,
branch_name: str,
project_id: str,
commit_id: str,
) -> str:
logger.info(f"Attempting to access file: {file_path} in repo: {repo_name}")

try:
# Try authenticated access first
github, repo = self.get_repo(repo_name)
file_contents = repo.get_contents(file_path, ref=branch_name)
file_contents = repo.get_contents(
file_path, ref=commit_id if commit_id else branch_name
)
except Exception as private_error:
logger.info(f"Failed to access private repo: {str(private_error)}")
# If authenticated access fails, try public access
Expand Down Expand Up @@ -166,8 +169,8 @@ def get_file_content(
if (start_line == end_line == 0) or (start_line == end_line == None):
return decoded_content
# added -2 to start and end line to include the function definition/ decorator line
start = start_line - 2 if start_line - 2 > 0 else 0
selected_lines = lines[start:end_line]
# start = start_line - 2 if start_line - 2 > 0 else 0
selected_lines = lines[max(0, start_line - 1) : min(len(lines), end_line)]
return "\n".join(selected_lines)
except Exception as e:
logger.error(
Expand Down Expand Up @@ -608,7 +611,15 @@ async def get_project_structure_async(

# Start structure fetch from the specified path with depth 0
structure = await self._fetch_repo_structure_async(
repo, path or "", current_depth=0, base_path=path
repo,
path or "",
current_depth=0,
base_path=path or "",
ref=(
project.get("branch_name")
if project.get("branch_name")
else project.get("commit_id")
),
)
formatted_structure = self._format_tree_structure(structure)

Expand All @@ -632,6 +643,7 @@ async def _fetch_repo_structure_async(
path: str = "",
current_depth: int = 0,
base_path: Optional[str] = None,
ref: Optional[str] = None,
) -> Dict[str, Any]:
exclude_extensions = [
"png",
Expand Down Expand Up @@ -677,7 +689,7 @@ async def _fetch_repo_structure_async(

try:
contents = await asyncio.get_event_loop().run_in_executor(
self.executor, repo.get_contents, path
self.executor, lambda: repo.get_contents(path, ref=ref)
)

if not isinstance(contents, list):
Expand All @@ -703,6 +715,7 @@ async def _fetch_repo_structure_async(
item.path,
current_depth=current_depth,
base_path=base_path,
ref=ref,
)
tasks.append(task)
else:
Expand Down
8 changes: 6 additions & 2 deletions app/modules/code_provider/local_repo/local_repo_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Union
import pathspec

import git
import pathspec
from fastapi import HTTPException
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -39,6 +39,7 @@ def get_file_content(
end_line: int,
branch_name: str,
project_id: str,
commit_id: str,
) -> str:
logger.info(
f"Attempting to access file: {file_path} for project ID: {project_id}"
Expand All @@ -54,7 +55,10 @@ def get_file_content(
)

repo = self.get_repo(repo_path)
repo.git.checkout(branch_name)
if commit_id:
repo.git.checkout(commit_id)
else:
repo.git.checkout(branch_name)
file_full_path = os.path.join(repo_path, file_path)
with open(file_full_path, "r", encoding="utf-8") as file:
lines = file.readlines()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ async def _find_changed_functions(self, changed_files, project_id):
0,
project["branch_name"],
project_id,
project["commit_id"],
)
tags = RepoMap.get_tags_from_code(relative_file_path, file_content)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def _process_result(
end_line,
project.branch_name,
project.id,
project.commit_id,
)

docstring = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def _process_result(
end_line,
project.branch_name,
project.id,
project.commit_id,
)

docstring = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,11 @@ def _process_result(
code_content = CodeProviderService(self.sql_db).get_file_content(
project.repo_name,
relative_file_path,
start_line,
start_line - 3,
end_line,
project.branch_name,
project.id,
project.commit_id,
)

docstring = None
Expand Down
30 changes: 22 additions & 8 deletions app/modules/intelligence/tools/think_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,31 @@ async def arun(self, thought: str) -> Dict[str, Any]:
def run(self, thought: str) -> Dict[str, Any]:
"""Synchronous wrapper for arun."""
try:
loop = asyncio.get_event_loop()
# Check if we're already in an event loop
loop = asyncio.get_running_loop()
# If we're in a running loop, we need to use a different approach
import concurrent.futures

def run_in_thread():
# Create a new event loop in a separate thread
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(self.arun(thought))
finally:
new_loop.close()

with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()

except RuntimeError:
# If there is no event loop in current thread, create a new one
# No event loop running, we can create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

try:
return loop.run_until_complete(self.arun(thought))
finally:
# Clean up if we created a new loop
if not loop.is_running():
try:
return loop.run_until_complete(self.arun(thought))
finally:
loop.close()


Expand Down
2 changes: 1 addition & 1 deletion app/modules/key_management/secret_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ async def get_api_key(
logger.info(f"API key retrieved successfully for user: {user['user_id']}")
if api_key is None:
logger.info(f"No API key found for user: {user['user_id']}")
raise
raise ValueError(f"No API key found for user: {user['user_id']}")
return APIKeyResponse(api_key=api_key)
except Exception as e:
logger.error(f"Error getting API key for user {user['user_id']}: {str(e)}")
Expand Down
9 changes: 7 additions & 2 deletions app/modules/parsing/graph_construction/parsing_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,16 @@ async def parse_directory(

try:
project = await project_manager.get_project_from_db(
repo_name, repo_details.branch_name, user_id
repo_name,
repo_details.branch_name,
user_id,
commit_id=repo_details.commit_id,
)

# First check if this is a demo project that hasn't been accessed by this user yet
if not project and repo_details.repo_name in demo_repos:
existing_project = await project_manager.get_global_project_from_db(
repo_name, repo_details.branch_name
repo_name, repo_details.branch_name, repo_details.commit_id
)

new_project_id = str(uuid7())
Expand Down Expand Up @@ -226,6 +229,7 @@ async def handle_new_project(
repo_details.branch_name,
user_id,
new_project_id,
repo_details.commit_id,
repo_details.repo_path,
)
asyncio.create_task(
Expand All @@ -247,6 +251,7 @@ async def handle_new_project(
{
"repo_name": repo_details.repo_name,
"branch": repo_details.branch_name,
"commit_id": repo_details.commit_id,
"project_id": new_project_id,
},
)
Expand Down
65 changes: 53 additions & 12 deletions app/modules/parsing/graph_construction/parsing_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ async def setup_project_directory(
repo_details,
user_id,
project_id=None, # Change type to str
commit_id=None,
):
full_name = (
repo.working_tree_dir.split("/")[-1]
Expand All @@ -312,14 +313,15 @@ async def setup_project_directory(
if full_name is None:
full_name = repo_path.split("/")[-1]
project = await self.project_manager.get_project_from_db(
full_name, branch, user_id, repo_path
full_name, branch, user_id, repo_path, commit_id
)
if not project:
project_id = await self.project_manager.register_project(
full_name,
branch,
user_id,
project_id,
commit_id=commit_id,
)
if repo_path is not None:
if os.getenv("isDevelopmentMode", "false").lower() == "false":
Expand All @@ -333,22 +335,41 @@ async def setup_project_directory(
try:
current_dir = os.getcwd()
os.chdir(extracted_dir) # Change to the cloned repo directory
repo_details.git.checkout(branch)
if commit_id:
repo_details.git.checkout(commit_id)
latest_commit_sha = commit_id
else:
repo_details.git.checkout(branch)
branch_details = repo_details.head.commit
latest_commit_sha = branch_details.hexsha
except GitCommandError as e:
logger.error(f"Error checking out branch: {e}")
logger.error(
f"Error checking out {'commit' if commit_id else 'branch'}: {e}"
)
raise HTTPException(
status_code=400, detail=f"Failed to checkout branch {branch}"
status_code=400,
detail=f"Failed to checkout {'commit ' + commit_id if commit_id else 'branch ' + branch}",
)
finally:
os.chdir(current_dir) # Restore the original working directory
branch_details = repo_details.head.commit
latest_commit_sha = branch_details.hexsha
else:
extracted_dir = await self.download_and_extract_tarball(
repo, branch, os.getenv("PROJECT_PATH"), auth, repo_details, user_id
)
branch_details = repo_details.get_branch(branch)
latest_commit_sha = branch_details.commit.sha
if commit_id:
# For GitHub API, we need to download tarball for specific commit
extracted_dir = await self.download_and_extract_tarball(
repo,
commit_id,
os.getenv("PROJECT_PATH"),
auth,
repo_details,
user_id,
)
latest_commit_sha = commit_id
else:
extracted_dir = await self.download_and_extract_tarball(
repo, branch, os.getenv("PROJECT_PATH"), auth, repo_details, user_id
)
branch_details = repo_details.get_branch(branch)
latest_commit_sha = branch_details.commit.sha

repo_metadata = ParseHelper.extract_repository_metadata(repo_details)
repo_metadata["error_message"] = None
Expand Down Expand Up @@ -473,18 +494,38 @@ async def check_commit_status(self, project_id: str) -> bool:
repo_name = project.get("project_name")
branch_name = project.get("branch_name")

if not repo_name or not branch_name:
if not repo_name:
logger.error(
f"Repository name or branch name not found for project ID {project_id}"
)
return False

if not branch_name:
logger.error(
f"Branch is empty so sticking to commit and not updating it for: {project_id}"
)
return True

if len(repo_name.split("/")) < 2:
# Local repo, always parse local repos
return False

try:
github, repo = self.github_service.get_repo(repo_name)

# If current_commit_id is a specific commit (not a branch head),
# then we can assume it's not "latest" and should be reparsed
# This is because when using specific commits, we don't want to check branch head
if len(current_commit_id) == 40: # SHA1 commit hash is 40 chars
try:
# Try to verify if this is a specific commit instead of branch head
repo.get_commit(current_commit_id)
# If we successfully get a commit, assume that it was a pinned commit,
# thus it's still up to date (we're parsing a specific commit, not latest)
return True
except:
# If we can't find the commit, we should reparse
return False
branch = repo.get_branch(branch_name)
latest_commit_id = branch.commit.sha

Expand Down
3 changes: 2 additions & 1 deletion app/modules/parsing/graph_construction/parsing_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
class ParsingRequest(BaseModel):
repo_name: Optional[str] = Field(default=None)
repo_path: Optional[str] = Field(default=None)
branch_name: str
branch_name: Optional[str] = Field(default=None)
commit_id: Optional[str] = Field(default=None)

def __init__(self, **data):
super().__init__(**data)
Expand Down
9 changes: 8 additions & 1 deletion app/modules/parsing/graph_construction/parsing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,20 @@ async def parse_directory(
repo_details,
user_id,
project_id,
commit_id=repo_details.commit_id,
)
else:
(
extracted_dir,
project_id,
) = await self.parse_helper.setup_project_directory(
repo, repo_details.branch_name, auth, repo, user_id, project_id
repo,
repo_details.branch_name,
auth,
repo,
user_id,
project_id,
commit_id=repo_details.commit_id,
)

if isinstance(repo, Repo):
Expand Down
Loading