Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion crates/rust-mcp-transport/src/client_sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ where
let max_retries = self.max_retries;
let retry_delay = self.retry_delay;

let custom_headers = self.custom_headers.clone();

let read_stream = SseStream {
sse_client,
sse_url,
Expand All @@ -218,7 +220,7 @@ where
let cancellation_token_sse = cancellation_token.clone();
let sse_task_handle = tokio::spawn(async move {
read_stream
.run(endpoint_event_tx, cancellation_token_sse)
.run(endpoint_event_tx, cancellation_token_sse, &custom_headers)
.await;
});
let mut sse_task_lock = self.sse_task.write().await;
Expand Down
99 changes: 96 additions & 3 deletions crates/rust-mcp-transport/src/utils/sse_stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use bytes::{Bytes, BytesMut};
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT};
use reqwest::Client;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
Expand Down Expand Up @@ -39,11 +40,15 @@ impl SseStream {
&self,
mut endpoint_event_tx: Option<oneshot::Sender<Option<String>>>,
cancellation_token: CancellationToken,
custom_headers: &Option<HeaderMap>,
) {
let mut retry_count = 0;
let mut buffer = BytesMut::with_capacity(BUFFER_CAPACITY);
let mut endpoint_event_received = false;

let mut request_headers: HeaderMap = custom_headers.to_owned().unwrap_or_default();
request_headers.insert(ACCEPT, HeaderValue::from_static("text/event-stream"));

// Main loop for reconnection attempts
loop {
// Check for cancellation before attempting connection
Expand All @@ -56,7 +61,7 @@ impl SseStream {
let response = match self
.sse_client
.get(&self.sse_url)
.header("Accept", "text/event-stream")
.headers(request_headers.clone())
.send()
.await
{
Expand Down Expand Up @@ -86,7 +91,18 @@ impl SseStream {
chunk = stream.next() => {
match chunk {
Some(chunk) => chunk,
None => break, // Stream ended, break from inner loop to reconnect
None => {
if retry_count >= self.max_retries {
tracing::error!("Max retries ({}) reached, giving up",self.max_retries);
if let Some(tx) = endpoint_event_tx.take() {
let _ = tx.send(None);
}
return;
}
retry_count += 1;
time::sleep(self.retry_delay).await;
break; // Stream ended, break from inner loop to reconnect
}
}
}
// Wait for cancellation
Expand Down Expand Up @@ -177,4 +193,81 @@ impl SseStream {
}

#[cfg(test)]
mod tests {}
mod tests {
use super::*;
use crate::utils::CancellationTokenSource;
use reqwest::header::{HeaderMap, HeaderValue};
use tokio::time::Duration;
use wiremock::matchers::{header, method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};

#[tokio::test]
async fn test_sse_client_sends_custom_headers_on_connection() {
// Start WireMock server
let mock_server = MockServer::builder().start().await;

// Create WireMock stub with connection close
Mock::given(method("GET"))
.and(path("/sse"))
.and(header("Accept", "text/event-stream"))
.and(header("X-Custom-Header", "CustomValue"))
.respond_with(
ResponseTemplate::new(200)
.set_body_string("event: endpoint\ndata: mock-endpoint\n\n")
.append_header("Content-Type", "text/event-stream")
.append_header("Connection", "close"), // Ensure connection closes
)
.expect(1) // Expect exactly one request
.mount(&mock_server)
.await;

// Create custom headers
let mut custom_headers = HeaderMap::new();
custom_headers.insert("X-Custom-Header", HeaderValue::from_static("CustomValue"));

// Create channel and SseStream
let (read_tx, _read_rx) = mpsc::channel::<Bytes>(64);
let sse = SseStream {
sse_client: reqwest::Client::new(),
sse_url: format!("{}/sse", mock_server.uri()),
max_retries: 0, // to receive one request only
retry_delay: Duration::from_millis(100),
read_tx,
};

// Create cancellation token and endpoint channel
let (cancellation_source, cancellation_token) = CancellationTokenSource::new();
let (endpoint_event_tx, endpoint_event_rx) = oneshot::channel::<Option<String>>();

// Spawn the run method
let sse_task = tokio::spawn({
async move {
sse.run(
Some(endpoint_event_tx),
cancellation_token,
&Some(custom_headers),
)
.await;
}
});

// Wait for the endpoint event or timeout
let event_result =
tokio::time::timeout(Duration::from_millis(500), endpoint_event_rx).await;

// Cancel the task to ensure loop exits
let _ = cancellation_source.cancel();

// Wait for the task to complete with a timeout
match tokio::time::timeout(Duration::from_secs(1), sse_task).await {
Ok(result) => result.unwrap(),
Err(_) => panic!("Test timed out after 1 second"),
}

// Verify the endpoint event was received
match event_result {
Ok(Ok(Some(event))) => assert_eq!(event, "mock-endpoint", "Expected endpoint event"),
_ => panic!("Did not receive expected endpoint event"),
}
}
}