Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/server/src/bin/generate_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ fn generate_types_content() -> String {
server::routes::auth::DevicePollStatus::decl(),
server::routes::auth::CheckTokenResponse::decl(),
services::services::git::GitBranch::decl(),
services::services::git::GitRemote::decl(),
utils::diff::Diff::decl(),
utils::diff::DiffChangeKind::decl(),
services::services::github_service::RepositoryInfo::decl(),
Expand Down
11 changes: 10 additions & 1 deletion crates/server/src/routes/projects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use ignore::WalkBuilder;
use services::services::{
file_ranker::FileRanker,
file_search_cache::{CacheError, SearchMode, SearchQuery},
git::GitBranch,
git::{GitBranch, GitRemote},
};
use utils::{path::expand_tilde, response::ApiResponse};
use uuid::Uuid;
Expand Down Expand Up @@ -44,6 +44,14 @@ pub async fn get_project_branches(
Ok(ResponseJson(ApiResponse::success(branches)))
}

pub async fn get_project_remotes(
Extension(project): Extension<Project>,
State(deployment): State<DeploymentImpl>,
) -> Result<ResponseJson<ApiResponse<Vec<GitRemote>>>, ApiError> {
let remotes = deployment.git().get_all_remotes(&project.git_repo_path)?;
Ok(ResponseJson(ApiResponse::success(remotes)))
}

pub async fn create_project(
State(deployment): State<DeploymentImpl>,
Json(payload): Json<CreateProject>,
Expand Down Expand Up @@ -478,6 +486,7 @@ pub fn router(deployment: &DeploymentImpl) -> Router<DeploymentImpl> {
get(get_project).put(update_project).delete(delete_project),
)
.route("/branches", get(get_project_branches))
.route("/remotes", get(get_project_remotes))
.route("/search", get(search_project_files))
.route("/open-editor", post(open_project_in_editor))
.layer(from_fn_with_state(
Expand Down
81 changes: 69 additions & 12 deletions crates/server/src/routes/task_attempts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ pub struct CreateGitHubPrRequest {
pub title: String,
pub body: Option<String>,
pub target_branch: Option<String>,
pub remote_name: Option<String>,
pub head_remote_name: Option<String>,
}

#[derive(Debug, Serialize)]
Expand Down Expand Up @@ -674,10 +676,17 @@ pub async fn push_task_attempt_branch(
github_service.check_token().await?;

let ws_path = ensure_worktree_path(&deployment, &task_attempt).await?;

deployment
let branch_remote = deployment
.git()
.push_to_github(&ws_path, &task_attempt.branch, &github_token)?;
.get_remote_name_from_branch_name(&ws_path, &task_attempt.branch)
.ok();

deployment.git().push_to_github(
&ws_path,
&task_attempt.branch,
branch_remote.as_deref(),
&github_token,
)?;
Ok(ResponseJson(ApiResponse::success(())))
}

Expand Down Expand Up @@ -719,12 +728,22 @@ pub async fn create_github_pr(

let workspace_path = ensure_worktree_path(&deployment, &task_attempt).await?;

let inferred_branch_remote = deployment
.git()
.get_remote_name_from_branch_name(&workspace_path, &task_attempt.branch)
.ok();
let head_remote_name = request
.head_remote_name
.clone()
.or_else(|| inferred_branch_remote.clone());

// Push the branch to GitHub first
if let Err(e) =
deployment
.git()
.push_to_github(&workspace_path, &task_attempt.branch, &github_token)
{
if let Err(e) = deployment.git().push_to_github(
&workspace_path,
&task_attempt.branch,
head_remote_name.as_deref(),
&github_token,
) {
tracing::error!("Failed to push branch to GitHub: {}", e);
let gh_e = GitHubServiceError::from(e);
if gh_e.is_api_data() {
Expand All @@ -735,6 +754,13 @@ pub async fn create_github_pr(
)));
}
}
let head_remote = head_remote_name.clone().or_else(|| {
deployment
.git()
.get_remote_name_from_branch_name(&workspace_path, &task_attempt.branch)
.ok()
});
let mut base_remote: Option<String> = None;

let norm_target_branch_name = if matches!(
deployment
Expand All @@ -746,26 +772,46 @@ pub async fn create_github_pr(
// For PR APIs, we must provide just the branch name.
let remote = deployment
.git()
.get_remote_name_from_branch_name(&workspace_path, &target_branch)?;
.get_remote_name_from_branch_name(&project.git_repo_path, &target_branch)?;
base_remote = Some(remote.clone());
let remote_prefix = format!("{}/", remote);
target_branch
.strip_prefix(&remote_prefix)
.unwrap_or(&target_branch)
.to_string()
} else {
target_branch
if let Ok(remote) = deployment
.git()
.get_remote_name_from_branch_name(&project.git_repo_path, &target_branch)
{
base_remote = Some(remote);
}
target_branch.clone()
};
let preferred_remote = request
.remote_name
.clone()
.or(base_remote.clone())
.or_else(|| head_remote.clone());
let head_repo_info = head_remote.as_ref().and_then(|remote| {
deployment
.git()
.get_github_repo_info(&project.git_repo_path, Some(remote.as_str()))
.ok()
});

// Create the PR using GitHub service
let pr_request = CreatePrRequest {
title: request.title.clone(),
body: request.body.clone(),
head_branch: task_attempt.branch.clone(),
base_branch: norm_target_branch_name.clone(),
head_repo: head_repo_info.clone(),
};
// Use GitService to get the remote URL, then create GitHubRepoInfo
let repo_info = deployment
.git()
.get_github_repo_info(&project.git_repo_path)?;
.get_github_repo_info(&project.git_repo_path, preferred_remote.as_deref())?;

match github_service.create_pr(&repo_info, &pr_request).await {
Ok(pr_info) => {
Expand Down Expand Up @@ -1339,10 +1385,21 @@ pub async fn attach_existing_pr(
return Err(ApiError::Project(ProjectError::ProjectNotFound));
};

let workspace_path = ensure_worktree_path(&deployment, &task_attempt).await?;
let head_remote = deployment
.git()
.get_remote_name_from_branch_name(&workspace_path, &task_attempt.branch)
.ok();
let base_remote = deployment
.git()
.get_remote_name_from_branch_name(&project.git_repo_path, &task_attempt.target_branch)
.ok();
let preferred_remote = base_remote.clone().or(head_remote);

let github_service = GitHubService::new(&github_token)?;
let repo_info = deployment
.git()
.get_github_repo_info(&project.git_repo_path)?;
.get_github_repo_info(&project.git_repo_path, preferred_remote.as_deref())?;

// List all PRs for branch (open, closed, and merged)
let prs = github_service
Expand Down
69 changes: 61 additions & 8 deletions crates/services/src/services/git.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ pub struct GitBranch {
pub last_commit_date: DateTime<Utc>,
}

#[derive(Debug, Clone, Serialize, TS)]
pub struct GitRemote {
pub name: String,
pub url: Option<String>,
}

#[derive(Debug, Clone)]
pub struct HeadInfo {
pub branch: String,
Expand Down Expand Up @@ -1677,12 +1683,27 @@ impl GitService {
pub fn get_github_repo_info(
&self,
repo_path: &Path,
preferred_remote: Option<&str>,
) -> Result<GitHubRepoInfo, GitServiceError> {
let repo = self.open_repo(repo_path)?;
let remote_name = self.default_remote_name(&repo);
let remote = repo.find_remote(&remote_name).map_err(|_| {
GitServiceError::InvalidRepository(format!("No '{remote_name}' remote found"))
})?;
let default_remote_name = self.default_remote_name(&repo);

let remote = if let Some(preferred) = preferred_remote {
match repo.find_remote(preferred) {
Ok(remote) => remote,
Err(_) => repo.find_remote(&default_remote_name).map_err(|_| {
GitServiceError::InvalidRepository(format!(
"No '{preferred}' remote found and default remote '{default_remote_name}' missing"
))
})?,
}
} else {
repo.find_remote(&default_remote_name).map_err(|_| {
GitServiceError::InvalidRepository(format!(
"No '{default_remote_name}' remote found"
))
})?
};

let url = remote
.url()
Expand All @@ -1692,6 +1713,22 @@ impl GitService {
})
}

pub fn get_all_remotes(&self, repo_path: &Path) -> Result<Vec<GitRemote>, GitServiceError> {
let repo = self.open_repo(repo_path)?;
let remote_names = repo.remotes()?;
let mut remotes = Vec::new();

for remote_name in remote_names.iter().flatten() {
let remote = repo.find_remote(remote_name)?;
remotes.push(GitRemote {
name: remote_name.to_string(),
url: remote.url().map(|u| u.to_string()),
});
}

Ok(remotes)
}

pub fn get_remote_name_from_branch_name(
&self,
repo_path: &Path,
Expand Down Expand Up @@ -1733,14 +1770,31 @@ impl GitService {
&self,
worktree_path: &Path,
branch_name: &str,
remote_override: Option<&str>,
github_token: &str,
) -> Result<(), GitServiceError> {
let repo = Repository::open(worktree_path)?;
self.check_worktree_clean(&repo)?;

// Get the remote
let remote_name = self.default_remote_name(&repo);
let remote = repo.find_remote(&remote_name)?;
let default_remote_name = self.default_remote_name(&repo);
let mut branch = Self::find_branch(&repo, branch_name)?;
let remote = if let Some(target_remote) = remote_override {
repo.find_remote(target_remote).map_err(|_| {
GitServiceError::InvalidRepository(format!(
"Remote '{target_remote}' not found for branch '{branch_name}'"
))
})?
} else {
self.get_remote_from_branch_ref(&repo, branch.get())
.or_else(|_| {
repo.find_remote(&default_remote_name).map_err(|_| {
GitServiceError::InvalidRepository(format!(
"Remote '{default_remote_name}' not found for branch '{branch_name}'"
))
})
})?
};
let remote_name = remote.name().unwrap_or(&default_remote_name).to_string();

let remote_url = remote
.url()
Expand All @@ -1754,7 +1808,6 @@ impl GitService {
return Err(e.into());
}

let mut branch = Self::find_branch(&repo, branch_name)?;
if !branch.get().is_remote() {
if let Some(branch_target) = branch.get().target() {
let remote_ref = format!("refs/remotes/{remote_name}/{branch_name}");
Expand Down
15 changes: 12 additions & 3 deletions crates/services/src/services/github_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ pub struct CreatePrRequest {
pub body: Option<String>,
pub head_branch: String,
pub base_branch: String,
pub head_repo: Option<GitHubRepoInfo>,
}

#[derive(Debug, Clone, Serialize, Deserialize, TS)]
Expand Down Expand Up @@ -223,9 +224,10 @@ impl GitHubService {
other => other,
})?;

// Check if the head branch exists
let head_repo = request.head_repo.as_ref().unwrap_or(repo_info);

self.client
.repos(&repo_info.owner, &repo_info.repo_name)
.repos(&head_repo.owner, &head_repo.repo_name)
.get_ref(&octocrab::params::repos::Reference::Branch(
request.head_branch.to_string(),
))
Expand All @@ -238,11 +240,18 @@ impl GitHubService {
other => other,
})?;

let head_ref =
if head_repo.owner != repo_info.owner || head_repo.repo_name != repo_info.repo_name {
format!("{}:{}", head_repo.owner, request.head_branch)
} else {
request.head_branch.clone()
};

// Create the pull request
let pr_info = self
.client
.pulls(&repo_info.owner, &repo_info.repo_name)
.create(&request.title, &request.head_branch, &request.base_branch)
.create(&request.title, &head_ref, &request.base_branch)
.body(request.body.as_deref().unwrap_or(""))
.send()
.await
Expand Down
35 changes: 34 additions & 1 deletion crates/services/tests/git_workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,44 @@ fn get_github_repo_info_parses_origin() {
let s = GitService::new();
s.set_remote(&repo_path, "origin", "https://github.com/foo/bar.git")
.unwrap();
let info = s.get_github_repo_info(&repo_path).unwrap();
let info = s.get_github_repo_info(&repo_path, None).unwrap();
assert_eq!(info.owner, "foo");
assert_eq!(info.repo_name, "bar");
}

#[test]
fn get_github_repo_info_prefers_supplied_remote() {
let td = TempDir::new().unwrap();
let repo_path = init_repo_main(&td);
let s = GitService::new();
s.set_remote(&repo_path, "origin", "https://github.com/foo/bar.git")
.unwrap();
s.set_remote(&repo_path, "upstream", "https://github.com/baz/qux.git")
.unwrap();

let info = s
.get_github_repo_info(&repo_path, Some("upstream"))
.unwrap();
assert_eq!(info.owner, "baz");
assert_eq!(info.repo_name, "qux");
}

#[test]
fn get_all_remotes_returns_all_configured_remotes() {
let td = TempDir::new().unwrap();
let repo_path = init_repo_main(&td);
let s = GitService::new();
s.set_remote(&repo_path, "origin", "https://github.com/foo/bar.git")
.unwrap();
s.set_remote(&repo_path, "fork", "https://github.com/me/bar.git")
.unwrap();

let remotes = s.get_all_remotes(&repo_path).unwrap();
let names: Vec<_> = remotes.iter().map(|remote| remote.name.as_str()).collect();
assert!(names.contains(&"origin"));
assert!(names.contains(&"fork"));
}

#[test]
fn get_branch_diffs_between_branches() {
let td = TempDir::new().unwrap();
Expand Down
Loading
Loading