diff --git a/Cargo.lock b/Cargo.lock index 26b65073..18ddc867 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -295,6 +295,7 @@ dependencies = [ "octocrab", "parking_lot", "pulldown-cmark", + "rand 0.9.2", "regex", "reqwest", "secrecy", diff --git a/Cargo.toml b/Cargo.toml index c2a0d7b0..32276070 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -56,6 +56,7 @@ chrono = "0.4" # Utilities itertools = "0.14" +rand = { version = "0.9", features = ["alloc"] } # Text processing pulldown-cmark = "0.13" diff --git a/README.md b/README.md index fa5da911..b2b8e51d 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,8 @@ required. | `--app-id` | `APP_ID` | | GitHub app ID of the bors bot. | | `--private-key` | `PRIVATE_KEY` | | Private key of the GitHub app. | | `--webhook-secret` | `WEBHOOK_SECRET` | | Key used to authenticate GitHub webhooks. | +| `--client-id` | `OAUTH_CLIENT_ID` | | GitHub OAuth client ID for rollup UI (optional). | +| `--client-secret` | `OAUTH_CLIENT_SECRET`| | GitHub OAuth client secret for rollup UI (optional). | | `--db` | `DATABASE_URL` | | Database connection string. Only PostgreSQL is supported. | | `--cmd-prefix` | `CMD_PREFIX` | @bors | Prefix used to invoke bors commands in PR comments. | @@ -45,6 +47,12 @@ atomically using the GitHub API. ### GitHub app If you want to attach `bors` to a GitHub app, you should point its webhooks at `/github`. +### OAuth app +If you want to create rollups, you will need to create a GitHub OAuth app configured like so: +1. In the [developer settings](https://github.com/settings/developers), go to "OAuth Apps" and create a new application. +2. Set the Authorization callback URL to `/oauth/callback`. +3. Note the generated Client ID and Client secret, and pass them through the CLI flags or via your environment configuration. + ### How to add a repository to bors Here is a guide on how to add a repository so that this bot can be used on it: 1) Add a file named `rust-bors.toml` to the root of the main branch of the repository. The configuration struct that diff --git a/src/bin/bors.rs b/src/bin/bors.rs index 581a9042..f7d5c4a9 100644 --- a/src/bin/bors.rs +++ b/src/bin/bors.rs @@ -6,7 +6,7 @@ use std::time::Duration; use anyhow::Context; use bors::{ - BorsContext, BorsGlobalEvent, BorsProcess, CommandParser, PgDbClient, ServerState, + BorsContext, BorsGlobalEvent, BorsProcess, CommandParser, OAuthConfig, PgDbClient, ServerState, TeamApiClient, TreeState, WebhookSecret, create_app, create_bors_process, create_github_client, load_repositories, }; @@ -49,6 +49,14 @@ struct Opts { #[arg(long, env = "PRIVATE_KEY")] private_key: String, + /// GitHub OAuth client ID for rollups. + #[arg(long, env = "CLIENT_ID")] + client_id: Option, + + /// GitHub OAuth client secret for rollups. + #[arg(long, env = "CLIENT_SECRET")] + client_secret: Option, + /// Secret used to authenticate webhooks. #[arg(long, env = "WEBHOOK_SECRET")] webhook_secret: String, @@ -214,10 +222,26 @@ fn try_main(opts: Opts) -> anyhow::Result<()> { } }; + let oauth_config = match (opts.client_id.clone(), opts.client_secret.clone()) { + (Some(client_id), Some(client_secret)) => Some(OAuthConfig::new(client_id, client_secret)), + (None, None) => None, + (Some(_), None) => { + return Err(anyhow::anyhow!( + "CLIENT_ID is set but CLIENT_SECRET is missing. Both must be set or neither." + )); + } + (None, Some(_)) => { + return Err(anyhow::anyhow!( + "CLIENT_SECRET is set but CLIENT_ID is missing. Both must be set or neither." + )); + } + }; + let state = ServerState::new( repository_tx, global_tx, WebhookSecret::new(opts.webhook_secret), + oauth_config, repos, db, opts.cmd_prefix.into(), diff --git a/src/database/mod.rs b/src/database/mod.rs index a67deaf5..7b99100b 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -418,9 +418,11 @@ impl PullRequestModel { } /// Determines if this PR can be included in a rollup. - /// A PR is rollupable if it has been approved and rollup is not `RollupMode::Never` + /// A PR is rollupable if it has been approved, does not have a pending build and rollup is not `RollupMode::Never`. pub fn is_rollupable(&self) -> bool { - self.is_approved() && !matches!(self.rollup, Some(RollupMode::Never)) + self.is_approved() + && !matches!(self.rollup, Some(RollupMode::Never)) + && !matches!(self.queue_status(), QueueStatus::Pending(..)) } } diff --git a/src/github/mod.rs b/src/github/mod.rs index 6009fabc..790f3e84 100644 --- a/src/github/mod.rs +++ b/src/github/mod.rs @@ -9,6 +9,7 @@ use url::Url; pub mod api; mod error; mod labels; +mod rollup; pub mod server; mod webhook; diff --git a/src/github/rollup.rs b/src/github/rollup.rs new file mode 100644 index 00000000..638e2010 --- /dev/null +++ b/src/github/rollup.rs @@ -0,0 +1,283 @@ +use super::GithubRepoName; +use super::error::AppError; +use super::server::ServerStateRef; +use anyhow::Context; +use axum::extract::{Query, State}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Redirect}; +use octocrab::OctocrabBuilder; +use octocrab::params::repos::Reference; +use rand::{Rng, distr::Alphanumeric}; +use std::collections::HashMap; +use tracing::Instrument; + +/// Query parameters received from GitHub's OAuth callback. +/// +/// Documentation: https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps#2-users-are-redirected-back-to-your-site-by-github +#[derive(serde::Deserialize)] +pub struct OAuthCallbackQuery { + /// Temporary code from GitHub to exchange for an access token (expires in 10m). + pub code: String, + /// State passed in the initial OAuth request - contains rollup info created from the queue page. + pub state: String, +} + +#[derive(serde::Deserialize)] +pub struct OAuthState { + pub pr_nums: Vec, + pub repo_name: String, + pub repo_owner: String, +} + +pub async fn oauth_callback_handler( + Query(callback): Query, + State(state): State, +) -> Result { + let oauth_config = state.oauth.as_ref().ok_or_else(|| { + let error = + anyhow::anyhow!("OAuth not configured. Please set CLIENT_ID and CLIENT_SECRET."); + tracing::error!("{error}"); + error + })?; + + let oauth_state: OAuthState = serde_json::from_str(&callback.state) + .map_err(|_| anyhow::anyhow!("Invalid state parameter"))?; + + tracing::info!("Exchanging OAuth code for access token"); + let client = reqwest::Client::new(); + let token_response = client + .post("https://github.com/login/oauth/access_token") + .form(&[ + ("client_id", oauth_config.client_id()), + ("client_secret", oauth_config.client_secret()), + ("code", &callback.code), + ]) + .send() + .await + .context("Failed to send OAuth token exchange request to GitHub")? + .text() + .await + .context("Failed to read OAuth token response from GitHub")?; + + tracing::debug!("Extracting access token from OAuth response"); + let oauth_token_params: HashMap = + url::form_urlencoded::parse(token_response.as_bytes()) + .into_owned() + .collect(); + let access_token = oauth_token_params + .get("access_token") + .ok_or_else(|| anyhow::anyhow!("No access token in response"))?; + + tracing::info!("Retrieved OAuth access token, creating rollup"); + + let span = tracing::info_span!( + "create_rollup", + repo = %format!("{}/{}", oauth_state.repo_owner, oauth_state.repo_name), + pr_nums = ?oauth_state.pr_nums + ); + + match create_rollup(state, oauth_state, access_token) + .instrument(span) + .await + { + Ok(pr_url) => { + tracing::info!("Rollup created successfully, redirecting to: {pr_url}"); + Ok(Redirect::temporary(&pr_url).into_response()) + } + Err(error) => { + tracing::error!("Failed to create rollup: {error}"); + Ok(( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to create rollup: {error}"), + ) + .into_response()) + } + } +} + +/// Creates a rollup PR by merging multiple approved PRs into a single branch +/// in the user's fork, then opens a PR to the upstream repository. +async fn create_rollup( + state: ServerStateRef, + oauth_state: OAuthState, + access_token: &str, +) -> anyhow::Result { + let OAuthState { + repo_name, + repo_owner, + pr_nums, + } = oauth_state; + + let gh_client = OctocrabBuilder::new() + .user_access_token(access_token.to_string()) + .build()?; + let user = gh_client.current().user().await?; + let username = user.login; + + tracing::info!("User {username} is creating a rollup with PRs: {pr_nums:?}"); + + // Ensure user has a fork + match gh_client.repos(&username, &repo_name).get().await { + Ok(repo) => repo, + Err(_) => { + anyhow::bail!( + "You must have a fork of {username}/{repo_name} named {repo_name} under your account", + ); + } + }; + + // Validate PRs + let mut rollup_prs = Vec::new(); + for num in pr_nums { + match state + .db + .get_pull_request( + &GithubRepoName::new(&repo_owner, &repo_name), + (num as u64).into(), + ) + .await? + { + Some(pr) => { + if !pr.is_rollupable() { + let error = format!("PR #{num} cannot be included in rollup"); + tracing::error!("{error}"); + anyhow::bail!(error); + } + rollup_prs.push(pr); + } + None => anyhow::bail!("PR #{num} not found"), + } + } + + if rollup_prs.is_empty() { + anyhow::bail!("No pull requests are marked for rollup"); + } + + // Sort PRs by number + rollup_prs.sort_by_key(|pr| pr.number.0); + + // Fetch the first PR from GitHub to determine the target base branch + let first_pr_github = gh_client + .pulls(&repo_owner, &repo_name) + .get(rollup_prs[0].number.0) + .await?; + let base_ref = first_pr_github.base.ref_field.clone(); + + // Fetch the current SHA of the base branch - this is the commit our + // rollup branch starts from. + let base_branch_ref = gh_client + .repos(&repo_owner, &repo_name) + .get_ref(&Reference::Branch(base_ref.clone())) + .await?; + let base_sha = match base_branch_ref.object { + octocrab::models::repos::Object::Commit { sha, .. } => sha, + octocrab::models::repos::Object::Tag { sha, .. } => sha, + _ => unreachable!(), + }; + + let branch_suffix: String = rand::rng() + .sample_iter(Alphanumeric) + .take(7) + .map(char::from) + .collect(); + let branch_name = format!("rollup-{branch_suffix}"); + + // Create the branch on the user's fork + gh_client + .repos(&username, &repo_name) + .create_ref( + &octocrab::params::repos::Reference::Branch(branch_name.clone()), + base_sha, + ) + .await + .map_err(|error| { + anyhow::anyhow!("Could not create rollup branch {branch_name}: {error}",) + })?; + + let mut successes = Vec::new(); + let mut failures = Vec::new(); + + // Merge each PR's commits into the rollup branch + for pr in rollup_prs { + let pr_github = gh_client + .pulls(&repo_owner, &repo_name) + .get(pr.number.0) + .await?; + + // Skip PRs that don't target the same base branch + if pr_github.base.ref_field != base_ref { + failures.push(pr); + continue; + } + + let head_sha = pr_github.head.sha.clone(); + let merge_msg = format!( + "Rollup merge of #{} - {}, r={}\n\n{}\n\n{}", + pr.number.0, + pr_github.head.ref_field, + pr.approver().unwrap_or("unknown"), + pr.title, + &pr_github.body.unwrap_or_default() + ); + + // Merge the PR's head commit into the rollup branch + let merge_attempt = gh_client + .repos(&username, &repo_name) + .merge(&head_sha, &branch_name) + .commit_message(&merge_msg) + .send() + .await; + + match merge_attempt { + Ok(_) => { + successes.push(pr); + } + Err(error) => { + if let octocrab::Error::GitHub { source, .. } = &error { + if source.status_code == http::StatusCode::CONFLICT { + failures.push(pr); + continue; + } + + anyhow::bail!( + "Merge failed with GitHub error (status {}): {}", + source.status_code, + source.message + ); + } + + anyhow::bail!("Merge failed with unexpected error: {error}"); + } + } + } + + let mut body = "Successful merges:\n\n".to_string(); + for pr in &successes { + body.push_str(&format!(" - #{} ({})\n", pr.number.0, pr.title)); + } + + if !failures.is_empty() { + body.push_str("\nFailed merges:\n\n"); + for pr in &failures { + body.push_str(&format!(" - #{} ({})\n", pr.number.0, pr.title)); + } + } + body.push_str("\nr? @ghost\n@rustbot modify labels: rollup"); + + let title = format!("Rollup of {} pull requests", successes.len()); + + // Create the rollup PR from the user's fork branch to the base branch + let pr = gh_client + .pulls(&repo_owner, &repo_name) + .create(&title, format!("{username}:{branch_name}"), &base_ref) + .body(&body) + .send() + .await?; + let pr_url = pr + .html_url + .as_ref() + .ok_or_else(|| anyhow::anyhow!("GitHub returned PR without html_url"))? + .to_string(); + + Ok(pr_url) +} diff --git a/src/github/server.rs b/src/github/server.rs index b56b08de..28d1f938 100644 --- a/src/github/server.rs +++ b/src/github/server.rs @@ -18,6 +18,7 @@ use crate::{BorsGlobalEvent, BorsRepositoryEvent, PgDbClient, TeamApiClient}; use super::AppError; use super::GithubRepoName; +use super::rollup; use crate::utils::sort_queue::sort_queue_prs; use anyhow::Error; use axum::Router; @@ -26,6 +27,7 @@ use axum::http::StatusCode; use axum::response::{IntoResponse, Redirect, Response}; use axum::routing::{get, post}; use octocrab::Octocrab; +use secrecy::{ExposeSecret, SecretString}; use std::any::Any; use std::collections::HashMap; use std::future::Future; @@ -36,13 +38,37 @@ use tower::limit::ConcurrencyLimitLayer; use tower_http::catch_panic::CatchPanicLayer; use tracing::{Instrument, Span}; +#[derive(Clone)] +pub struct OAuthConfig { + client_id: String, + client_secret: SecretString, +} + +impl OAuthConfig { + pub fn new(client_id: String, client_secret: String) -> Self { + Self { + client_id, + client_secret: client_secret.into(), + } + } + + pub fn client_id(&self) -> &str { + &self.client_id + } + + pub fn client_secret(&self) -> &str { + self.client_secret.expose_secret() + } +} + /// Shared server state for all axum handlers. pub struct ServerState { repository_event_queue: mpsc::Sender, global_event_queue: mpsc::Sender, webhook_secret: WebhookSecret, + pub(super) oauth: Option, repositories: HashMap>, - db: Arc, + pub(super) db: Arc, cmd_prefix: CommandPrefix, } @@ -51,6 +77,7 @@ impl ServerState { repository_event_queue: mpsc::Sender, global_event_queue: mpsc::Sender, webhook_secret: WebhookSecret, + oauth: Option, repositories: HashMap>, db: Arc, cmd_prefix: CommandPrefix, @@ -59,6 +86,7 @@ impl ServerState { repository_event_queue, global_event_queue, webhook_secret, + oauth, repositories, db, cmd_prefix, @@ -83,6 +111,7 @@ pub fn create_app(state: ServerState) -> Router { .route("/queue/{repo_name}", get(queue_handler)) .route("/github", post(github_webhook_handler)) .route("/health", get(health_handler)) + .route("/oauth/callback", get(rollup::oauth_callback_handler)) .layer(ConcurrencyLimitLayer::new(100)) .layer(CatchPanicLayer::custom(handle_panic)) .with_state(Arc::new(state)) @@ -134,7 +163,7 @@ async fn help_handler(State(state): State) -> impl IntoResponse }) } -async fn queue_handler( +pub async fn queue_handler( Path(repo_name): Path, State(state): State, ) -> Result { @@ -171,7 +200,12 @@ async fn queue_handler( }); Ok(HtmlTemplate(QueueTemplate { + oauth_client_id: state + .oauth + .as_ref() + .map(|config| config.client_id().to_string()), repo_name: repo.name.name().to_string(), + repo_owner: repo.name.owner().to_string(), repo_url: format!("https://github.com/{}", repo.name), tree_state: repo.tree_state, stats: PullRequestStats { diff --git a/src/github/webhook.rs b/src/github/webhook.rs index a5d78e4f..660f4d40 100644 --- a/src/github/webhook.rs +++ b/src/github/webhook.rs @@ -1538,6 +1538,7 @@ mod tests { repository_tx, global_tx, WebhookSecret::new(TEST_WEBHOOK_SECRET.to_string()), + None, repos, db, default_cmd_prefix(), diff --git a/src/lib.rs b/src/lib.rs index 48181756..5b1e16a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,7 @@ pub use github::{ AppError, WebhookSecret, api::create_github_client, api::load_repositories, - server::{BorsProcess, ServerState, create_app, create_bors_process}, + server::{BorsProcess, OAuthConfig, ServerState, create_app, create_bors_process}, }; pub use permissions::TeamApiClient; diff --git a/src/templates.rs b/src/templates.rs index 478d5752..bf25186f 100644 --- a/src/templates.rs +++ b/src/templates.rs @@ -59,10 +59,12 @@ pub struct PullRequestStats { #[template(path = "queue.html")] pub struct QueueTemplate { pub repo_name: String, + pub repo_owner: String, pub repo_url: String, pub stats: PullRequestStats, pub prs: Vec, pub tree_state: TreeState, + pub oauth_client_id: Option, } #[derive(Template)] diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 07df1b09..0a57d5d1 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -203,6 +203,7 @@ impl BorsTester { repository_tx, global_tx.clone(), WebhookSecret::new(TEST_WEBHOOK_SECRET.to_string()), + None, repos.clone(), db.clone(), default_cmd_prefix(), diff --git a/templates/queue.html b/templates/queue.html index 35e81676..26bd5612 100644 --- a/templates/queue.html +++ b/templates/queue.html @@ -54,10 +54,6 @@ max-width: 500px; } - #rollupModalClose { - float: right; - cursor: pointer; - } {% endblock %} @@ -115,7 +111,7 @@

{% for pr in prs %} - + {{ pr.number.0 }} @@ -160,7 +156,6 @@

- ×

@@ -308,41 +303,66 @@

detachRowClick = bindRowClick(table); }); + const createRollupButton = document.getElementById("showRollupSelection") const modal = document.getElementById("rollupModal"); const modalMessage = document.getElementById("rollupModalMessage"); - const modalClose = document.getElementById("rollupModalClose"); - const modalContinue = document.getElementById("rollupModalContinue"); - - function closeModal() { - modal.style.display = "none"; - modalContinue.style.display = "none"; - } - - modalClose.addEventListener("click", closeModal); - modalContinue.addEventListener("click", closeModal); + const rollupContinueButton = document.getElementById("rollupModalContinue"); - window.addEventListener("click", function(event) { + // Handle modal outside click + window.addEventListener("click", (event) => { if (event.target === modal) { - closeModal(); + modal.style.display = "none"; + rollupContinueButton.style.display = "none"; } }); - document.getElementById("showRollupSelection").addEventListener("click", function() { - let selectedRows = table.rows({ selected: true }).nodes().toArray(); + createRollupButton.addEventListener("click", () => { + {% if oauth_client_id.is_none() %} + alert("Both CLIENT_ID and CLIENT_SECRET must be set to enabled OAuth."); + return; + {% endif %} + + const selectedRows = table.rows({ selected: true }).nodes().toArray(); let message; if (selectedRows.length === 0) { message = "No PRs selected for rollup."; - modalContinue.style.display = "none"; + rollupContinueButton.style.display = "none"; } else { message = `You've selected ${selectedRows.length} PR(s) to be included in this rollup.

A rollup is useful for shortening the queue, but jumping the queue is unfair to older PRs who have waited too long.

When creating a real rollup, see the instructions for reference.`; - modalContinue.style.display = "inline-block"; + rollupContinueButton.style.display = "inline-block"; } modalMessage.innerHTML = message; modal.style.display = "block"; }); + + rollupContinueButton.addEventListener("click", () => { + const scopes = ["public_repo", "workflow"]; + + // Gather PR numbers + let selectedRows = table.rows({ selected: true }).nodes().toArray(); + let nums = selectedRows + .map(row => { + let numberCell = row.cells[1]; + return numberCell?.dataset?.prNumber ? parseInt(numberCell.dataset.prNumber) : null; + }) + .filter(num => num !== null); + + let state = JSON.stringify({ + pr_nums: nums, + repo_name: "{{ repo_name }}", + repo_owner: "{{ repo_owner }}" + }); + + const oauthUrl = new URL("https://github.com/login/oauth/authorize"); + oauthUrl.searchParams.set("client_id", "{{ oauth_client_id.as_ref().unwrap() }}"); + oauthUrl.searchParams.set("scope", scopes.join(",")); + oauthUrl.searchParams.set("state", state); + + window.location.href = oauthUrl.toString(); + }); {% endblock %}