diff --git a/crates/rust-mcp-transport/src/client_sse.rs b/crates/rust-mcp-transport/src/client_sse.rs index e7acb41..9a7183e 100644 --- a/crates/rust-mcp-transport/src/client_sse.rs +++ b/crates/rust-mcp-transport/src/client_sse.rs @@ -206,6 +206,8 @@ where let max_retries = self.max_retries; let retry_delay = self.retry_delay; + let custom_headers = self.custom_headers.clone(); + let read_stream = SseStream { sse_client, sse_url, @@ -218,7 +220,7 @@ where let cancellation_token_sse = cancellation_token.clone(); let sse_task_handle = tokio::spawn(async move { read_stream - .run(endpoint_event_tx, cancellation_token_sse) + .run(endpoint_event_tx, cancellation_token_sse, &custom_headers) .await; }); let mut sse_task_lock = self.sse_task.write().await; diff --git a/crates/rust-mcp-transport/src/utils/sse_stream.rs b/crates/rust-mcp-transport/src/utils/sse_stream.rs index aac29c8..e71ff18 100644 --- a/crates/rust-mcp-transport/src/utils/sse_stream.rs +++ b/crates/rust-mcp-transport/src/utils/sse_stream.rs @@ -1,4 +1,5 @@ use bytes::{Bytes, BytesMut}; +use reqwest::header::{HeaderMap, HeaderValue, ACCEPT}; use reqwest::Client; use std::time::Duration; use tokio::sync::{mpsc, oneshot}; @@ -39,11 +40,15 @@ impl SseStream { &self, mut endpoint_event_tx: Option>>, cancellation_token: CancellationToken, + custom_headers: &Option, ) { let mut retry_count = 0; let mut buffer = BytesMut::with_capacity(BUFFER_CAPACITY); let mut endpoint_event_received = false; + let mut request_headers: HeaderMap = custom_headers.to_owned().unwrap_or_default(); + request_headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream")); + // Main loop for reconnection attempts loop { // Check for cancellation before attempting connection @@ -56,7 +61,7 @@ impl SseStream { let response = match self .sse_client .get(&self.sse_url) - .header("Accept", "text/event-stream") + .headers(request_headers.clone()) .send() .await { @@ -86,7 +91,18 @@ impl SseStream { chunk = stream.next() => { match chunk { Some(chunk) => chunk, - None => break, // Stream ended, break from inner loop to reconnect + None => { + if retry_count >= self.max_retries { + tracing::error!("Max retries ({}) reached, giving up",self.max_retries); + if let Some(tx) = endpoint_event_tx.take() { + let _ = tx.send(None); + } + return; + } + retry_count += 1; + time::sleep(self.retry_delay).await; + break; // Stream ended, break from inner loop to reconnect + } } } // Wait for cancellation @@ -177,4 +193,81 @@ impl SseStream { } #[cfg(test)] -mod tests {} +mod tests { + use super::*; + use crate::utils::CancellationTokenSource; + use reqwest::header::{HeaderMap, HeaderValue}; + use tokio::time::Duration; + use wiremock::matchers::{header, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + #[tokio::test] + async fn test_sse_client_sends_custom_headers_on_connection() { + // Start WireMock server + let mock_server = MockServer::builder().start().await; + + // Create WireMock stub with connection close + Mock::given(method("GET")) + .and(path("/sse")) + .and(header("Accept", "text/event-stream")) + .and(header("X-Custom-Header", "CustomValue")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string("event: endpoint\ndata: mock-endpoint\n\n") + .append_header("Content-Type", "text/event-stream") + .append_header("Connection", "close"), // Ensure connection closes + ) + .expect(1) // Expect exactly one request + .mount(&mock_server) + .await; + + // Create custom headers + let mut custom_headers = HeaderMap::new(); + custom_headers.insert("X-Custom-Header", HeaderValue::from_static("CustomValue")); + + // Create channel and SseStream + let (read_tx, _read_rx) = mpsc::channel::(64); + let sse = SseStream { + sse_client: reqwest::Client::new(), + sse_url: format!("{}/sse", mock_server.uri()), + max_retries: 0, // to receive one request only + retry_delay: Duration::from_millis(100), + read_tx, + }; + + // Create cancellation token and endpoint channel + let (cancellation_source, cancellation_token) = CancellationTokenSource::new(); + let (endpoint_event_tx, endpoint_event_rx) = oneshot::channel::>(); + + // Spawn the run method + let sse_task = tokio::spawn({ + async move { + sse.run( + Some(endpoint_event_tx), + cancellation_token, + &Some(custom_headers), + ) + .await; + } + }); + + // Wait for the endpoint event or timeout + let event_result = + tokio::time::timeout(Duration::from_millis(500), endpoint_event_rx).await; + + // Cancel the task to ensure loop exits + let _ = cancellation_source.cancel(); + + // Wait for the task to complete with a timeout + match tokio::time::timeout(Duration::from_secs(1), sse_task).await { + Ok(result) => result.unwrap(), + Err(_) => panic!("Test timed out after 1 second"), + } + + // Verify the endpoint event was received + match event_result { + Ok(Ok(Some(event))) => assert_eq!(event, "mock-endpoint", "Expected endpoint event"), + _ => panic!("Did not receive expected endpoint event"), + } + } +}