From f9861b9748448337ae542b2f7119c40a6dcb3f3e Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 25 May 2025 10:55:13 -0300 Subject: [PATCH 1/3] feat: ensure unnecessary dependencies are excluded in build unless related feature is enabled --- Cargo.lock | 2 - Cargo.toml | 2 +- crates/rust-mcp-sdk/Cargo.toml | 14 ++-- crates/rust-mcp-transport/Cargo.toml | 17 ++--- crates/rust-mcp-transport/src/error.rs | 4 +- crates/rust-mcp-transport/src/lib.rs | 5 +- crates/rust-mcp-transport/src/utils.rs | 69 ++++--------------- .../src/utils/cancellation_token.rs | 1 + .../src/utils/http_utils.rs | 50 ++++++++++++++ 9 files changed, 89 insertions(+), 75 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1802dcb..bc142a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1650,7 +1650,6 @@ name = "rust-mcp-transport" version = "0.3.1" dependencies = [ "async-trait", - "axum", "bytes", "futures", "reqwest", @@ -1661,7 +1660,6 @@ dependencies = [ "tokio", "tokio-stream", "tracing", - "uuid", "wiremock", ] diff --git a/Cargo.toml b/Cargo.toml index 3ba93c8..dd51f2e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ members = [ [workspace.dependencies] # Workspace member crates -rust-mcp-transport = { version = "0.3.1", path = "crates/rust-mcp-transport" } +rust-mcp-transport = { version = "0.3.1", path = "crates/rust-mcp-transport", default-features = false } rust-mcp-sdk = { path = "crates/rust-mcp-sdk", default-features = false } rust-mcp-macros = { version = "0.3.0", path = "crates/rust-mcp-macros" } diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml index 46be6b2..0b901bd 100644 --- a/crates/rust-mcp-sdk/Cargo.toml +++ b/crates/rust-mcp-sdk/Cargo.toml @@ -12,7 +12,7 @@ edition = "2021" [dependencies] rust-mcp-schema = { workspace = true } -rust-mcp-transport = { workspace = true } +rust-mcp-transport = { workspace = true, default-features = false, optional = true } rust-mcp-macros = { workspace = true, optional = true } tokio.workspace = true @@ -47,9 +47,15 @@ default = [ "hyper-server", "ssl", ] # All features enabled by default -server = [] # Server feature -client = [] # Client feature -hyper-server = ["axum", "axum-server", "uuid", "tokio-stream"] +server = ["rust-mcp-transport/stdio"] # Server feature +client = ["rust-mcp-transport/stdio", "rust-mcp-transport/sse"] # Client feature +hyper-server = [ + "axum", + "axum-server", + "uuid", + "tokio-stream", + "rust-mcp-transport/sse", +] ssl = ["axum-server/tls-rustls"] macros = ["rust-mcp-macros"] diff --git a/crates/rust-mcp-transport/Cargo.toml b/crates/rust-mcp-transport/Cargo.toml index 0630568..071fe65 100644 --- a/crates/rust-mcp-transport/Cargo.toml +++ b/crates/rust-mcp-transport/Cargo.toml @@ -14,17 +14,15 @@ edition = "2021" rust-mcp-schema = { workspace = true } async-trait = { workspace = true } -tokio = { workspace = true } futures = { workspace = true } thiserror = { workspace = true } serde_json = { workspace = true } serde = { workspace = true } -axum = { workspace = true } -uuid = { workspace = true, features = ["v4"] } tokio-stream = { workspace = true } -reqwest = { workspace = true, features = ["stream"] } bytes = { workspace = true } tracing = { workspace = true } +tokio = { workspace = true } +reqwest = { workspace = true, features = ["stream"], optional = true } [dev-dependencies] wiremock = "0.5" @@ -34,10 +32,9 @@ futures = { workspace = true } workspace = true -# ### FEATURES ################################################################# -# [features] - -# default = ["stdio", "sse"] # Default features +### FEATURES ################################################################# +[features] +default = ["stdio", "sse"] # Default features -# stdio = [] -# sse = [] +stdio = [] +sse = ["reqwest"] diff --git a/crates/rust-mcp-transport/src/error.rs b/crates/rust-mcp-transport/src/error.rs index 55ae29e..e9dfa13 100644 --- a/crates/rust-mcp-transport/src/error.rs +++ b/crates/rust-mcp-transport/src/error.rs @@ -1,12 +1,11 @@ use rust_mcp_schema::{schema_utils::SdkError, RpcError}; use thiserror::Error; +use crate::utils::CancellationError; use core::fmt; use std::any::Any; use tokio::sync::broadcast; -use crate::utils::CancellationError; - /// A wrapper around a broadcast send error. This structure allows for generic error handling /// by boxing the underlying error into a type-erased form. #[derive(Debug)] @@ -99,6 +98,7 @@ pub enum TransportError { FromString(String), #[error("{0}")] OneshotRecvError(#[from] tokio::sync::oneshot::error::RecvError), + #[cfg(feature = "sse")] #[error("{0}")] SendMessageError(#[from] reqwest::Error), #[error("Http Error: {0}")] diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs index 31d810d..c8779c2 100644 --- a/crates/rust-mcp-transport/src/lib.rs +++ b/crates/rust-mcp-transport/src/lib.rs @@ -1,18 +1,21 @@ // Copyright (c) 2025 mcp-rust-stack // Licensed under the MIT License. See LICENSE file for details. // Modifications to this file must be documented with a description of the changes made. - +#[cfg(feature = "sse")] mod client_sse; pub mod error; mod mcp_stream; mod message_dispatcher; +#[cfg(feature = "sse")] mod sse; mod stdio; mod transport; mod utils; +#[cfg(feature = "sse")] pub use client_sse::*; pub use message_dispatcher::*; +#[cfg(feature = "sse")] pub use sse::*; pub use stdio::*; pub use transport::*; diff --git a/crates/rust-mcp-transport/src/utils.rs b/crates/rust-mcp-transport/src/utils.rs index 4e8a2d7..06ee174 100644 --- a/crates/rust-mcp-transport/src/utils.rs +++ b/crates/rust-mcp-transport/src/utils.rs @@ -1,22 +1,30 @@ mod cancellation_token; +#[cfg(feature = "sse")] mod http_utils; +#[cfg(feature = "sse")] mod readable_channel; +#[cfg(feature = "sse")] mod sse_stream; +#[cfg(feature = "sse")] mod writable_channel; pub(crate) use cancellation_token::*; +#[cfg(feature = "sse")] pub(crate) use http_utils::*; +#[cfg(feature = "sse")] pub(crate) use readable_channel::*; +#[cfg(feature = "sse")] pub(crate) use sse_stream::*; +#[cfg(feature = "sse")] pub(crate) use writable_channel::*; use rust_mcp_schema::schema_utils::SdkError; use tokio::time::{timeout, Duration}; -use crate::{ - error::{TransportError, TransportResult}, - SessionId, -}; +use crate::error::{TransportError, TransportResult}; + +#[cfg(feature = "sse")] +use crate::SessionId; pub async fn await_timeout(operation: F, timeout_duration: Duration) -> TransportResult where @@ -29,21 +37,6 @@ where } } -pub fn extract_origin(url: &str) -> Option { - // Remove the fragment first (everything after '#') - let url = url.split('#').next()?; // Keep only part before `#` - - // Split scheme and the rest - let (scheme, rest) = url.split_once("://")?; - - // Get host and optionally the port (before first '/') - let end = rest.find('/').unwrap_or(rest.len()); - let host_port = &rest[..end]; - - // Reconstruct origin - Some(format!("{}://{}", scheme, host_port)) -} - /// Adds a session ID as a query parameter to a given endpoint URL. /// /// # Arguments @@ -53,6 +46,7 @@ pub fn extract_origin(url: &str) -> Option { /// # Returns /// A String containing the endpoint with the session ID added as a query parameter /// +#[cfg(feature = "sse")] pub(crate) fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) -> String { // Handle empty endpoint let base = if endpoint.is_empty() { "/" } else { endpoint }; @@ -84,45 +78,10 @@ pub(crate) fn endpoint_with_session_id(endpoint: &str, session_id: &SessionId) - } } +#[cfg(feature = "sse")] #[cfg(test)] mod tests { use super::*; - - #[test] - fn test_extract_origin_with_path() { - let url = "https://example.com:8080/some/path"; - assert_eq!( - extract_origin(url), - Some("https://example.com:8080".to_string()) - ); - } - - #[test] - fn test_extract_origin_without_path() { - let url = "https://example.com"; - assert_eq!(extract_origin(url), Some("https://example.com".to_string())); - } - - #[test] - fn test_extract_origin_with_fragment() { - let url = "https://example.com:8080/path#section"; - assert_eq!( - extract_origin(url), - Some("https://example.com:8080".to_string()) - ); - } - - #[test] - fn test_extract_origin_invalid_url() { - let url = "example.com/path"; - assert_eq!(extract_origin(url), None); - } - - #[test] - fn test_extract_origin_empty_string() { - assert_eq!(extract_origin(""), None); - } - #[test] fn test_endpoint_with_session_id() { let session_id: SessionId = "AAA".to_string(); diff --git a/crates/rust-mcp-transport/src/utils/cancellation_token.rs b/crates/rust-mcp-transport/src/utils/cancellation_token.rs index 84f5b78..b039f13 100644 --- a/crates/rust-mcp-transport/src/utils/cancellation_token.rs +++ b/crates/rust-mcp-transport/src/utils/cancellation_token.rs @@ -74,6 +74,7 @@ impl CancellationToken { /// /// # Returns /// * `bool` - True if cancellation is requested, false otherwise + #[allow(unused)] pub fn is_cancelled(&self) -> bool { *self.receiver.borrow() } diff --git a/crates/rust-mcp-transport/src/utils/http_utils.rs b/crates/rust-mcp-transport/src/utils/http_utils.rs index dc0237a..f8403e7 100644 --- a/crates/rust-mcp-transport/src/utils/http_utils.rs +++ b/crates/rust-mcp-transport/src/utils/http_utils.rs @@ -34,6 +34,21 @@ pub async fn http_post( Ok(()) } +pub fn extract_origin(url: &str) -> Option { + // Remove the fragment first (everything after '#') + let url = url.split('#').next()?; // Keep only part before `#` + + // Split scheme and the rest + let (scheme, rest) = url.split_once("://")?; + + // Get host and optionally the port (before first '/') + let end = rest.find('/').unwrap_or(rest.len()); + let host_port = &rest[..end]; + + // Reconstruct origin + Some(format!("{}://{}", scheme, host_port)) +} + #[cfg(test)] mod tests { use super::*; @@ -147,4 +162,39 @@ mod tests { // Assert the result is an error (likely a connection error) assert!(result.is_err()); } + + #[test] + fn test_extract_origin_with_path() { + let url = "https://example.com:8080/some/path"; + assert_eq!( + extract_origin(url), + Some("https://example.com:8080".to_string()) + ); + } + + #[test] + fn test_extract_origin_without_path() { + let url = "https://example.com"; + assert_eq!(extract_origin(url), Some("https://example.com".to_string())); + } + + #[test] + fn test_extract_origin_with_fragment() { + let url = "https://example.com:8080/path#section"; + assert_eq!( + extract_origin(url), + Some("https://example.com:8080".to_string()) + ); + } + + #[test] + fn test_extract_origin_invalid_url() { + let url = "example.com/path"; + assert_eq!(extract_origin(url), None); + } + + #[test] + fn test_extract_origin_empty_string() { + assert_eq!(extract_origin(""), None); + } } From aa97ce86f48c7c57c50a9fc433c46d3520d6841f Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 25 May 2025 11:01:17 -0300 Subject: [PATCH 2/3] chore: test-modules --- Cargo.lock | 29 +- .../rust-mcp-sdk/tests/common/test_server.rs | 188 ++++---- crates/rust-mcp-sdk/tests/test_server_sse.rs | 412 +++++++++--------- 3 files changed, 318 insertions(+), 311 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bc142a4..916be96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.23" +version = "1.2.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f4ac86a9e5bc1e2b3449ab9d7d3a6a405e3d1bb28d7b9be8614f55846ae3766" +checksum = "16595d3be041c03b09d08d0858631facccee9221e579704070e6e9e4915d3bc7" dependencies = [ "jobserver", "libc", @@ -884,11 +884,10 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.5" +version = "0.27.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" +checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d" dependencies = [ - "futures-util", "http 1.3.1", "hyper 1.6.0", "hyper-util", @@ -1214,13 +1213,13 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1614,9 +1613,9 @@ dependencies = [ [[package]] name = "rust-mcp-schema" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49212f1da431236217031807377e6296db06a270224698c426afa94e5dacd8e7" +checksum = "40fc3768cfcc6756ebc6c91f5d52abcfacac20cb953010483d52524ed7f08eaf" dependencies = [ "serde", "serde_json", @@ -2112,9 +2111,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.45.0" +version = "1.45.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2513ca694ef9ede0fb23fe71a4ee4107cb102b9dc1930f6d0fd77aae068ae165" +checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" dependencies = [ "backtrace", "bytes", @@ -2311,11 +2310,13 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" dependencies = [ "getrandom 0.3.3", + "js-sys", + "wasm-bindgen", ] [[package]] diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index dcb4e1b..8bb2e48 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -1,118 +1,120 @@ -use async_trait::async_trait; -use tokio_stream::StreamExt; +pub mod test_server_common { + use async_trait::async_trait; + use tokio_stream::StreamExt; -use rust_mcp_schema::{ - Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, - LATEST_PROTOCOL_VERSION, -}; -use rust_mcp_sdk::{ - mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, - McpServer, SessionId, -}; -use std::sync::RwLock; -use std::time::Duration; -use tokio::time::timeout; + use rust_mcp_schema::{ + Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools, + LATEST_PROTOCOL_VERSION, + }; + use rust_mcp_sdk::{ + mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler}, + McpServer, SessionId, + }; + use std::sync::RwLock; + use std::time::Duration; + use tokio::time::timeout; -pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; -pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; + pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#; + pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#; -pub fn test_server_details() -> InitializeResult { - InitializeResult { - // server name and version - server_info: Implementation { - name: "Test MCP Server".to_string(), - version: "0.1.0".to_string(), - }, - capabilities: ServerCapabilities { - // indicates that server support mcp tools - tools: Some(ServerCapabilitiesTools { list_changed: None }), - ..Default::default() // Using default values for other fields - }, - meta: None, - instructions: Some("server instructions...".to_string()), - protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + pub fn test_server_details() -> InitializeResult { + InitializeResult { + // server name and version + server_info: Implementation { + name: "Test MCP Server".to_string(), + version: "0.1.0".to_string(), + }, + capabilities: ServerCapabilities { + // indicates that server support mcp tools + tools: Some(ServerCapabilitiesTools { list_changed: None }), + ..Default::default() // Using default values for other fields + }, + meta: None, + instructions: Some("server instructions...".to_string()), + protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + } } -} -pub struct TestServerHandler; + pub struct TestServerHandler; -#[async_trait] -impl ServerHandler for TestServerHandler { - async fn on_server_started(&self, runtime: &dyn McpServer) { - let _ = runtime - .stderr_message("Server started successfully".into()) - .await; + #[async_trait] + impl ServerHandler for TestServerHandler { + async fn on_server_started(&self, runtime: &dyn McpServer) { + let _ = runtime + .stderr_message("Server started successfully".into()) + .await; + } } -} -pub fn create_test_server(options: HyperServerOptions) -> HyperServer { - hyper_server::create_server(test_server_details(), TestServerHandler {}, options) -} + pub fn create_test_server(options: HyperServerOptions) -> HyperServer { + hyper_server::create_server(test_server_details(), TestServerHandler {}, options) + } -// Tests the session ID generator, ensuring it returns a sequence of predefined session IDs. -pub struct TestIdGenerator { - constant_ids: Vec, - generated: RwLock, -} + // Tests the session ID generator, ensuring it returns a sequence of predefined session IDs. + pub struct TestIdGenerator { + constant_ids: Vec, + generated: RwLock, + } -impl TestIdGenerator { - pub fn new(constant_ids: Vec) -> Self { - TestIdGenerator { - constant_ids, - generated: RwLock::new(0), + impl TestIdGenerator { + pub fn new(constant_ids: Vec) -> Self { + TestIdGenerator { + constant_ids, + generated: RwLock::new(0), + } } } -} -impl IdGenerator for TestIdGenerator { - fn generate(&self) -> SessionId { - let mut lock = self.generated.write().unwrap(); - *lock += 1; - if *lock > self.constant_ids.len() { - *lock = 1; + impl IdGenerator for TestIdGenerator { + fn generate(&self) -> SessionId { + let mut lock = self.generated.write().unwrap(); + *lock += 1; + if *lock > self.constant_ids.len() { + *lock = 1; + } + self.constant_ids[*lock - 1].to_owned() } - self.constant_ids[*lock - 1].to_owned() } -} -pub async fn collect_sse_lines( - response: reqwest::Response, - line_count: usize, - read_timeout: Duration, -) -> Result, Box> { - let mut collected_lines = Vec::new(); - let mut stream = response.bytes_stream(); + pub async fn collect_sse_lines( + response: reqwest::Response, + line_count: usize, + read_timeout: Duration, + ) -> Result, Box> { + let mut collected_lines = Vec::new(); + let mut stream = response.bytes_stream(); - let result = timeout(read_timeout, async { - while let Some(chunk) = stream.next().await { - let chunk = chunk.map_err(|e| Box::new(e) as Box)?; - let chunk_str = String::from_utf8_lossy(&chunk); + let result = timeout(read_timeout, async { + while let Some(chunk) = stream.next().await { + let chunk = chunk.map_err(|e| Box::new(e) as Box)?; + let chunk_str = String::from_utf8_lossy(&chunk); - // Split the chunk into lines - let lines: Vec<&str> = chunk_str.lines().collect(); + // Split the chunk into lines + let lines: Vec<&str> = chunk_str.lines().collect(); - // Add each line to the collected_lines vector - for line in lines { - collected_lines.push(line.to_string()); + // Add each line to the collected_lines vector + for line in lines { + collected_lines.push(line.to_string()); - // Check if we have collected 5 lines - if collected_lines.len() >= line_count { - return Ok(collected_lines); + // Check if we have collected 5 lines + if collected_lines.len() >= line_count { + return Ok(collected_lines); + } } } - } - // If the stream ends before collecting 5 lines, return what we have - Ok(collected_lines) - }) - .await; + // If the stream ends before collecting 5 lines, return what we have + Ok(collected_lines) + }) + .await; - // Handle timeout or stream result - match result { - Ok(Ok(lines)) => Ok(lines), - Ok(Err(e)) => Err(e), - Err(_) => Err(Box::new(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "Timed out waiting for 5 lines", - ))), + // Handle timeout or stream result + match result { + Ok(Ok(lines)) => Ok(lines), + Ok(Err(e)) => Err(e), + Err(_) => Err(Box::new(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "Timed out waiting for 5 lines", + ))), + } } } diff --git a/crates/rust-mcp-sdk/tests/test_server_sse.rs b/crates/rust-mcp-sdk/tests/test_server_sse.rs index ba7df51..6705d29 100644 --- a/crates/rust-mcp-sdk/tests/test_server_sse.rs +++ b/crates/rust-mcp-sdk/tests/test_server_sse.rs @@ -1,211 +1,215 @@ -use std::{sync::Arc, time::Duration}; - -use common::{ - collect_sse_lines, create_test_server, sse_data, sse_event, TestIdGenerator, INITIALIZE_REQUEST, -}; -use reqwest::Client; -use rust_mcp_schema::{ - schema_utils::{ResultFromServer, ServerMessage}, - ServerResult, -}; -use rust_mcp_sdk::mcp_server::HyperServerOptions; -use tokio::time::sleep; - #[path = "common/common.rs"] pub mod common; - -#[tokio::test] -async fn tets_sse_endpoint_event_default() { - let server_options = HyperServerOptions { - port: 8081, - session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ - "AAA-BBB-CCC".to_string() - ]))), - ..Default::default() - }; - - let base_url = format!("http://{}:{}", server_options.host, server_options.port); - - let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); - - let server = create_test_server(server_options); - let handle = server.server_handle(); - let server_task = tokio::spawn(async move { - server.start().await.unwrap(); - eprintln!("Server 1 is down"); - }); - - sleep(Duration::from_millis(750)).await; - - let client = Client::new(); - println!("connecting to : {}", server_endpoint); - // Act: Connect to the SSE endpoint and read the event stream - let response = client - .get(server_endpoint) - .header("Accept", "text/event-stream") - .send() - .await - .expect("Failed to connect to SSE endpoint"); - - assert_eq!( - response.headers().get("content-type").map(|v| v.as_bytes()), - Some(b"text/event-stream" as &[u8]), - "Response content-type should be text/event-stream" - ); - - let lines = collect_sse_lines(response, 2, Duration::from_secs(5)) - .await - .unwrap(); - - assert_eq!(sse_event(&lines[0]), "endpoint"); - assert_eq!(sse_data(&lines[1]), "/messages?sessionId=AAA-BBB-CCC"); - - let message_endpoint = format!("{}{}", base_url, sse_data(&lines[1])); - let res = client - .post(message_endpoint) - .header("Content-Type", "application/json") - .body(INITIALIZE_REQUEST.to_string()) - .send() - .await - .unwrap(); - assert!(res.status().is_success()); - handle.graceful_shutdown(Some(Duration::from_millis(1))); - server_task.await.unwrap(); -} - -#[tokio::test] -async fn tets_sse_message_endpoint_query_hash() { - let server_options = HyperServerOptions { - port: 8082, - custom_messages_endpoint: Some( - "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(), - ), - session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ - "AAA-BBB-CCC".to_string() - ]))), - ..Default::default() +mod tets_mest { + use std::{sync::Arc, time::Duration}; + + use crate::common::{ + sse_data, sse_event, + test_server_common::{ + collect_sse_lines, create_test_server, TestIdGenerator, INITIALIZE_REQUEST, + }, }; - - let base_url = format!("http://{}:{}", server_options.host, server_options.port); - - let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); - - let server = create_test_server(server_options); - let handle = server.server_handle(); - - let server_task = tokio::spawn(async move { - server.start().await.unwrap(); - eprintln!("Server 2 is down"); - }); - - sleep(Duration::from_millis(750)).await; - - let client = Client::new(); - println!("connecting to : {}", server_endpoint); - // Act: Connect to the SSE endpoint and read the event stream - let response = client - .get(server_endpoint) - .header("Accept", "text/event-stream") - .send() - .await - .expect("Failed to connect to SSE endpoint"); - - assert_eq!( - response.headers().get("content-type").map(|v| v.as_bytes()), - Some(b"text/event-stream" as &[u8]), - "Response content-type should be text/event-stream" - ); - - let lines = collect_sse_lines(response, 2, Duration::from_secs(5)) - .await - .unwrap(); - - assert_eq!(sse_event(&lines[0]), "endpoint"); - assert_eq!( - sse_data(&lines[1]), - "/custom-msg-endpoint?something=true&otherthing=false&sessionId=AAA-BBB-CCC#section-59" - ); - - let message_endpoint = format!("{}{}", base_url, sse_data(&lines[1])); - let res = client - .post(message_endpoint) - .header("Content-Type", "application/json") - .body(INITIALIZE_REQUEST.to_string()) - .send() - .await - .unwrap(); - assert!(res.status().is_success()); - handle.graceful_shutdown(Some(Duration::from_millis(1))); - server_task.await.unwrap(); -} - -#[tokio::test] -async fn tets_sse_custom_message_endpoint() { - let server_options = HyperServerOptions { - port: 8083, - custom_messages_endpoint: Some( - "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(), - ), - session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ - "AAA-BBB-CCC".to_string() - ]))), - ..Default::default() + use reqwest::Client; + use rust_mcp_schema::{ + schema_utils::{ResultFromServer, ServerMessage}, + ServerResult, }; - - let base_url = format!("http://{}:{}", server_options.host, server_options.port); - - let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); - - let server = create_test_server(server_options); - let handle = server.server_handle(); - - let server_task = tokio::spawn(async move { - server.start().await.unwrap(); - eprintln!("Server 3 is down"); - }); - - sleep(Duration::from_millis(750)).await; - - let client = Client::new(); - println!("connecting to : {}", server_endpoint); - // Act: Connect to the SSE endpoint and read the event stream - let response = client - .get(server_endpoint) - .header("Accept", "text/event-stream") - .send() - .await - .expect("Failed to connect to SSE endpoint"); - - assert_eq!( - response.headers().get("content-type").map(|v| v.as_bytes()), - Some(b"text/event-stream" as &[u8]), - "Response content-type should be text/event-stream" - ); - - let message_endpoint = format!( - "{}{}", - base_url, - "/custom-msg-endpoint?something=true&otherthing=false&sessionId=AAA-BBB-CCC#section-59" - ); - let res = client - .post(message_endpoint) - .header("Content-Type", "application/json") - .body(INITIALIZE_REQUEST.to_string()) - .send() - .await - .unwrap(); - assert!(res.status().is_success()); - - let lines = collect_sse_lines(response, 5, Duration::from_secs(5)) - .await - .unwrap(); - - let init_response = sse_data(&lines[3]); - let result = serde_json::from_str::(&init_response).unwrap(); - - assert!(matches!(result, ServerMessage::Response(response) + use rust_mcp_sdk::mcp_server::HyperServerOptions; + use tokio::time::sleep; + + #[tokio::test] + async fn tets_sse_endpoint_event_default() { + let server_options = HyperServerOptions { + port: 8081, + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + ..Default::default() + }; + + let base_url = format!("http://{}:{}", server_options.host, server_options.port); + + let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); + + let server = create_test_server(server_options); + let handle = server.server_handle(); + let server_task = tokio::spawn(async move { + server.start().await.unwrap(); + eprintln!("Server 1 is down"); + }); + + sleep(Duration::from_millis(750)).await; + + let client = Client::new(); + println!("connecting to : {}", server_endpoint); + // Act: Connect to the SSE endpoint and read the event stream + let response = client + .get(server_endpoint) + .header("Accept", "text/event-stream") + .send() + .await + .expect("Failed to connect to SSE endpoint"); + + assert_eq!( + response.headers().get("content-type").map(|v| v.as_bytes()), + Some(b"text/event-stream" as &[u8]), + "Response content-type should be text/event-stream" + ); + + let lines = collect_sse_lines(response, 2, Duration::from_secs(5)) + .await + .unwrap(); + + assert_eq!(sse_event(&lines[0]), "endpoint"); + assert_eq!(sse_data(&lines[1]), "/messages?sessionId=AAA-BBB-CCC"); + + let message_endpoint = format!("{}{}", base_url, sse_data(&lines[1])); + let res = client + .post(message_endpoint) + .header("Content-Type", "application/json") + .body(INITIALIZE_REQUEST.to_string()) + .send() + .await + .unwrap(); + assert!(res.status().is_success()); + handle.graceful_shutdown(Some(Duration::from_millis(1))); + server_task.await.unwrap(); + } + + #[tokio::test] + async fn tets_sse_message_endpoint_query_hash() { + let server_options = HyperServerOptions { + port: 8082, + custom_messages_endpoint: Some( + "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(), + ), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + ..Default::default() + }; + + let base_url = format!("http://{}:{}", server_options.host, server_options.port); + + let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); + + let server = create_test_server(server_options); + let handle = server.server_handle(); + + let server_task = tokio::spawn(async move { + server.start().await.unwrap(); + eprintln!("Server 2 is down"); + }); + + sleep(Duration::from_millis(750)).await; + + let client = Client::new(); + println!("connecting to : {}", server_endpoint); + // Act: Connect to the SSE endpoint and read the event stream + let response = client + .get(server_endpoint) + .header("Accept", "text/event-stream") + .send() + .await + .expect("Failed to connect to SSE endpoint"); + + assert_eq!( + response.headers().get("content-type").map(|v| v.as_bytes()), + Some(b"text/event-stream" as &[u8]), + "Response content-type should be text/event-stream" + ); + + let lines = collect_sse_lines(response, 2, Duration::from_secs(5)) + .await + .unwrap(); + + assert_eq!(sse_event(&lines[0]), "endpoint"); + assert_eq!( + sse_data(&lines[1]), + "/custom-msg-endpoint?something=true&otherthing=false&sessionId=AAA-BBB-CCC#section-59" + ); + + let message_endpoint = format!("{}{}", base_url, sse_data(&lines[1])); + let res = client + .post(message_endpoint) + .header("Content-Type", "application/json") + .body(INITIALIZE_REQUEST.to_string()) + .send() + .await + .unwrap(); + assert!(res.status().is_success()); + handle.graceful_shutdown(Some(Duration::from_millis(1))); + server_task.await.unwrap(); + } + + #[tokio::test] + async fn tets_sse_custom_message_endpoint() { + let server_options = HyperServerOptions { + port: 8083, + custom_messages_endpoint: Some( + "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(), + ), + session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![ + "AAA-BBB-CCC".to_string() + ]))), + ..Default::default() + }; + + let base_url = format!("http://{}:{}", server_options.host, server_options.port); + + let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint()); + + let server = create_test_server(server_options); + let handle = server.server_handle(); + + let server_task = tokio::spawn(async move { + server.start().await.unwrap(); + eprintln!("Server 3 is down"); + }); + + sleep(Duration::from_millis(750)).await; + + let client = Client::new(); + println!("connecting to : {}", server_endpoint); + // Act: Connect to the SSE endpoint and read the event stream + let response = client + .get(server_endpoint) + .header("Accept", "text/event-stream") + .send() + .await + .expect("Failed to connect to SSE endpoint"); + + assert_eq!( + response.headers().get("content-type").map(|v| v.as_bytes()), + Some(b"text/event-stream" as &[u8]), + "Response content-type should be text/event-stream" + ); + + let message_endpoint = format!( + "{}{}", + base_url, + "/custom-msg-endpoint?something=true&otherthing=false&sessionId=AAA-BBB-CCC#section-59" + ); + let res = client + .post(message_endpoint) + .header("Content-Type", "application/json") + .body(INITIALIZE_REQUEST.to_string()) + .send() + .await + .unwrap(); + assert!(res.status().is_success()); + + let lines = collect_sse_lines(response, 5, Duration::from_secs(5)) + .await + .unwrap(); + + let init_response = sse_data(&lines[3]); + let result = serde_json::from_str::(&init_response).unwrap(); + + assert!(matches!(result, ServerMessage::Response(response) if matches!(&response.result, ResultFromServer::ServerResult(server_result) if matches!(server_result, ServerResult::InitializeResult(init_result) if init_result.server_info.name == "Test MCP Server")))); - handle.graceful_shutdown(Some(Duration::from_millis(1))); - server_task.await.unwrap(); + handle.graceful_shutdown(Some(Duration::from_millis(1))); + server_task.await.unwrap(); + } } From cd01ebfdac989fdb7d7be107a8d36aad50abb3d2 Mon Sep 17 00:00:00 2001 From: Ali Hashemi Date: Sun, 25 May 2025 11:09:37 -0300 Subject: [PATCH 3/3] chore: cfg sse tests --- crates/rust-mcp-sdk/tests/common/test_server.rs | 1 + crates/rust-mcp-sdk/tests/test_server_sse.rs | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs index 8bb2e48..c62bb83 100644 --- a/crates/rust-mcp-sdk/tests/common/test_server.rs +++ b/crates/rust-mcp-sdk/tests/common/test_server.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "hyper-server")] pub mod test_server_common { use async_trait::async_trait; use tokio_stream::StreamExt; diff --git a/crates/rust-mcp-sdk/tests/test_server_sse.rs b/crates/rust-mcp-sdk/tests/test_server_sse.rs index 6705d29..5d053c7 100644 --- a/crates/rust-mcp-sdk/tests/test_server_sse.rs +++ b/crates/rust-mcp-sdk/tests/test_server_sse.rs @@ -1,6 +1,7 @@ #[path = "common/common.rs"] pub mod common; -mod tets_mest { +#[cfg(feature = "hyper-server")] +mod tets_server_sse { use std::{sync::Arc, time::Duration}; use crate::common::{