diff --git a/Cargo.lock b/Cargo.lock
index a80f199..060ec84 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -239,9 +239,9 @@ dependencies = [
[[package]]
name = "bitflags"
-version = "2.9.0"
+version = "2.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd"
+checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967"
[[package]]
name = "bumpalo"
@@ -257,9 +257,9 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
[[package]]
name = "cc"
-version = "1.2.22"
+version = "1.2.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "32db95edf998450acc7881c932f94cd9b05c87b4b2599e8bab064753da4acfd1"
+checksum = "5f4ac86a9e5bc1e2b3449ab9d7d3a6a405e3d1bb28d7b9be8614f55846ae3766"
dependencies = [
"jobserver",
"libc",
@@ -400,9 +400,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "errno"
-version = "0.3.11"
+version = "0.3.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
+checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18"
dependencies = [
"libc",
"windows-sys 0.59.0",
@@ -917,9 +917,9 @@ dependencies = [
[[package]]
name = "hyper-util"
-version = "0.1.11"
+version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2"
+checksum = "cf9f1e950e0d9d1d3c47184416723cf29c0d1f93bd8cccf37e4beb6b44f31710"
dependencies = [
"bytes",
"futures-channel",
@@ -984,9 +984,9 @@ checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3"
[[package]]
name = "icu_properties"
-version = "2.0.0"
+version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2549ca8c7241c82f59c80ba2a6f415d931c5b58d24fb8412caa1a1f02c49139a"
+checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b"
dependencies = [
"displaydoc",
"icu_collections",
@@ -1000,9 +1000,9 @@ dependencies = [
[[package]]
name = "icu_properties_data"
-version = "2.0.0"
+version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8197e866e47b68f8f7d95249e172903bec06004b18b2937f1095d40a0c57de04"
+checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632"
[[package]]
name = "icu_provider"
@@ -1614,9 +1614,9 @@ dependencies = [
[[package]]
name = "rust-mcp-schema"
-version = "0.4.0"
+version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "868d31d0ae0376ba45786eac9058771da06839e83bb961ac7e5997ab3910f086"
+checksum = "49212f1da431236217031807377e6296db06a270224698c426afa94e5dacd8e7"
dependencies = [
"serde",
"serde_json",
@@ -1631,6 +1631,7 @@ dependencies = [
"axum-server",
"futures",
"hyper 1.6.0",
+ "reqwest",
"rust-mcp-macros",
"rust-mcp-schema",
"rust-mcp-transport",
@@ -1748,9 +1749,9 @@ dependencies = [
[[package]]
name = "rustversion"
-version = "1.0.20"
+version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2"
+checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d"
[[package]]
name = "ryu"
@@ -2514,9 +2515,9 @@ dependencies = [
[[package]]
name = "windows-result"
-version = "0.3.3"
+version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4b895b5356fc36103d0f64dd1e94dfa7ac5633f1c9dd6e80fe9ec4adef69e09d"
+checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
dependencies = [
"windows-link",
]
diff --git a/Cargo.toml b/Cargo.toml
index 986e877..f9e897b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,7 +21,7 @@ rust-mcp-sdk = { path = "crates/rust-mcp-sdk", default-features = false }
rust-mcp-macros = { version = "0.2.1", path = "crates/rust-mcp-macros" }
# External crates
-rust-mcp-schema = { version = "0.4" }
+rust-mcp-schema = { version = "0.5" }
futures = { version = "0.3" }
tokio = { version = "1.4", features = ["full"] }
serde = { version = "1.0", features = ["derive", "serde_derive"] }
diff --git a/crates/rust-mcp-macros/Cargo.toml b/crates/rust-mcp-macros/Cargo.toml
index e63f2ae..2b272c8 100644
--- a/crates/rust-mcp-macros/Cargo.toml
+++ b/crates/rust-mcp-macros/Cargo.toml
@@ -28,3 +28,16 @@ workspace = true
[lib]
proc-macro = true
+
+
+[features]
+# defalt features
+default = ["2025_03_26"] # Default features
+
+# activates the latest MCP schema version, this will be updated once a new version of schema is published
+latest = ["2025_03_26"]
+
+# enabled mcp schema version 2025_03_26
+2025_03_26 = ["rust-mcp-schema/2025_03_26"]
+# enabled mcp schema version 2024_11_05
+2024_11_05 = ["rust-mcp-schema/2024_11_05"]
diff --git a/crates/rust-mcp-macros/README.md b/crates/rust-mcp-macros/README.md
index 5246a5b..6f6c956 100644
--- a/crates/rust-mcp-macros/README.md
+++ b/crates/rust-mcp-macros/README.md
@@ -19,6 +19,10 @@ The `mcp_tool` macro generates an implementation for the annotated struct that i
#[mcp_tool(
name = "write_file",
description = "Create a new file or completely overwrite an existing file with new content."
+ destructive_hint = false
+ idempotent_hint = false
+ open_world_hint = false
+ read_only_hint = false
)]
#[derive(rust_mcp_macros::JsonSchema)]
pub struct WriteFileTool {
@@ -60,3 +64,11 @@ fn main() {
Check out [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) , a high-performance, asynchronous toolkit for building MCP servers and clients. Focus on your app's logic while [rust-mcp-sdk](https://github.com/rust-mcp-stack/rust-mcp-sdk) takes care of the rest!
---
+
+
+**Note**: The following attributes are available only in version `2025_03_26` and later of the MCP Schema, and their values will be used in the [annotations](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5557) attribute of the *[Tool struct](https://github.com/rust-mcp-stack/rust-mcp-schema/blob/main/src/generated_schema/2025_03_26/mcp_schema.rs#L5554-L5566).
+
+- `destructive_hint`
+- `idempotent_hint`
+- `open_world_hint`
+- `read_only_hint`
diff --git a/crates/rust-mcp-macros/src/lib.rs b/crates/rust-mcp-macros/src/lib.rs
index d123b1e..3f35b03 100644
--- a/crates/rust-mcp-macros/src/lib.rs
+++ b/crates/rust-mcp-macros/src/lib.rs
@@ -19,11 +19,37 @@ use utils::{is_option, renamed_field, type_to_json_schema};
/// * `name` - An optional string representing the tool's name.
/// * `description` - An optional string describing the tool.
///
+#[cfg(feature = "2024_11_05")]
struct McpToolMacroAttributes {
name: Option,
description: Option,
}
+/// Represents the attributes for the `mcp_tool` procedural macro.
+///
+/// This struct parses and validates the `name` and `description` attributes provided
+/// to the `mcp_tool` macro. Both attributes are required and must not be empty strings.
+///
+/// # Fields
+/// * `name` - An optional string representing the tool's name.
+/// * `description` - An optional string describing the tool.
+/// * `destructive_hint` - Optional boolean for `ToolAnnotations::destructive_hint`.
+/// * `idempotent_hint` - Optional boolean for `ToolAnnotations::idempotent_hint`.
+/// * `open_world_hint` - Optional boolean for `ToolAnnotations::open_world_hint`.
+/// * `read_only_hint` - Optional boolean for `ToolAnnotations::read_only_hint`.
+/// * `title` - Optional string for `ToolAnnotations::title`.
+///
+#[cfg(feature = "2025_03_26")]
+struct McpToolMacroAttributes {
+ name: Option,
+ description: Option,
+ destructive_hint: Option,
+ idempotent_hint: Option,
+ open_world_hint: Option,
+ read_only_hint: Option,
+ title: Option,
+}
+
use syn::parse::ParseStream;
struct ExprList {
@@ -51,59 +77,102 @@ impl Parse for McpToolMacroAttributes {
fn parse(attributes: syn::parse::ParseStream) -> syn::Result {
let mut name = None;
let mut description = None;
+ let mut destructive_hint = None;
+ let mut idempotent_hint = None;
+ let mut open_world_hint = None;
+ let mut read_only_hint = None;
+ let mut title = None;
+
let meta_list: Punctuated = Punctuated::parse_terminated(attributes)?;
for meta in meta_list {
if let Meta::NameValue(meta_name_value) = meta {
let ident = meta_name_value.path.get_ident().unwrap();
let ident_str = ident.to_string();
- let value = match &meta_name_value.value {
- Expr::Lit(ExprLit {
- lit: Lit::Str(lit_str),
- ..
- }) => lit_str.value(),
-
- Expr::Macro(expr_macro) => {
- let mac = &expr_macro.mac;
- if mac.path.is_ident("concat") {
- let args: ExprList = syn::parse2(mac.tokens.clone())?;
- let mut result = String::new();
-
- for expr in args.exprs {
- if let Expr::Lit(ExprLit {
- lit: Lit::Str(lit_str),
- ..
- }) = expr
- {
- result.push_str(&lit_str.value());
+ match ident_str.as_str() {
+ "name" | "description" => {
+ let value = match &meta_name_value.value {
+ Expr::Lit(ExprLit {
+ lit: Lit::Str(lit_str),
+ ..
+ }) => lit_str.value(),
+ Expr::Macro(expr_macro) => {
+ let mac = &expr_macro.mac;
+ if mac.path.is_ident("concat") {
+ let args: ExprList = syn::parse2(mac.tokens.clone())?;
+ let mut result = String::new();
+ for expr in args.exprs {
+ if let Expr::Lit(ExprLit {
+ lit: Lit::Str(lit_str),
+ ..
+ }) = expr
+ {
+ result.push_str(&lit_str.value());
+ } else {
+ return Err(Error::new_spanned(
+ expr,
+ "Only string literals are allowed inside concat!()",
+ ));
+ }
+ }
+ result
} else {
return Err(Error::new_spanned(
- expr,
- "Only string literals are allowed inside concat!()",
+ expr_macro,
+ "Only concat!(...) is supported here",
));
}
}
-
- result
- } else {
- return Err(Error::new_spanned(
- expr_macro,
- "Only concat!(...) is supported here",
- ));
+ _ => {
+ return Err(Error::new_spanned(
+ &meta_name_value.value,
+ "Expected a string literal or concat!(...)",
+ ));
+ }
+ };
+ match ident_str.as_str() {
+ "name" => name = Some(value),
+ "description" => description = Some(value),
+ _ => {}
}
}
-
- _ => {
- return Err(Error::new_spanned(
- &meta_name_value.value,
- "Expected a string literal or concat!(...)",
- ));
+ "destructive_hint" | "idempotent_hint" | "open_world_hint"
+ | "read_only_hint" => {
+ let value = match &meta_name_value.value {
+ Expr::Lit(ExprLit {
+ lit: Lit::Bool(lit_bool),
+ ..
+ }) => lit_bool.value,
+ _ => {
+ return Err(Error::new_spanned(
+ &meta_name_value.value,
+ "Expected a boolean literal",
+ ));
+ }
+ };
+ match ident_str.as_str() {
+ "destructive_hint" => destructive_hint = Some(value),
+ "idempotent_hint" => idempotent_hint = Some(value),
+ "open_world_hint" => open_world_hint = Some(value),
+ "read_only_hint" => read_only_hint = Some(value),
+ _ => {}
+ }
+ }
+ "title" => {
+ let value = match &meta_name_value.value {
+ Expr::Lit(ExprLit {
+ lit: Lit::Str(lit_str),
+ ..
+ }) => lit_str.value(),
+ _ => {
+ return Err(Error::new_spanned(
+ &meta_name_value.value,
+ "Expected a string literal",
+ ));
+ }
+ };
+ title = Some(value);
}
- };
-
- match ident_str.as_str() {
- "name" => name = Some(value),
- "description" => description = Some(value),
_ => {}
}
}
@@ -116,7 +185,6 @@ impl Parse for McpToolMacroAttributes {
"The 'name' attribute is required and must not be empty.",
));
}
-
if description
.as_ref()
.map(|s| s.trim().is_empty())
@@ -128,7 +196,21 @@ impl Parse for McpToolMacroAttributes {
));
}
- Ok(Self { name, description })
+ #[cfg(feature = "2024_11_05")]
+ let instance = Self { name, description };
+
+ #[cfg(feature = "2025_03_26")]
+ let instance = Self {
+ name,
+ description,
+ destructive_hint,
+ idempotent_hint,
+ open_world_hint,
+ read_only_hint,
+ title,
+ };
+
+ Ok(instance)
}
}
@@ -148,7 +230,7 @@ impl Parse for McpToolMacroAttributes {
///
/// # Example
/// ```rust
-/// #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool")]
+/// #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool", idempotent_hint=true )]
/// #[derive(rust_mcp_macros::JsonSchema)]
/// struct ExampleTool {
/// field1: String,
@@ -159,6 +241,7 @@ impl Parse for McpToolMacroAttributes {
/// let tool : rust_mcp_schema::Tool = ExampleTool::tool();
/// assert_eq!(tool.name , "example_tool");
/// assert_eq!(tool.description.unwrap() , "An example tool");
+/// assert_eq!(tool.annotations.unwrap().idempotent_hint.unwrap() , true);
///
/// let schema_properties = tool.input_schema.properties.unwrap();
/// assert_eq!(schema_properties.len() , 2);
@@ -176,6 +259,62 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream {
let tool_name = macro_attributes.name.unwrap_or_default();
let tool_description = macro_attributes.description.unwrap_or_default();
+ #[cfg(feature = "2025_03_26")]
+ let some_annotations = macro_attributes.destructive_hint.is_some()
+ || macro_attributes.idempotent_hint.is_some()
+ || macro_attributes.open_world_hint.is_some()
+ || macro_attributes.read_only_hint.is_some()
+ || macro_attributes.title.is_some();
+
+ #[cfg(feature = "2025_03_26")]
+ let annotations = if some_annotations {
+ let destructive_hint = macro_attributes
+ .destructive_hint
+ .map_or(quote! {None}, |v| quote! {Some(#v)});
+
+ let idempotent_hint = macro_attributes
+ .idempotent_hint
+ .map_or(quote! {None}, |v| quote! {Some(#v)});
+ let open_world_hint = macro_attributes
+ .open_world_hint
+ .map_or(quote! {None}, |v| quote! {Some(#v)});
+ let read_only_hint = macro_attributes
+ .read_only_hint
+ .map_or(quote! {None}, |v| quote! {Some(#v)});
+ let title = macro_attributes
+ .title
+ .map_or(quote! {None}, |v| quote! {Some(#v)});
+ quote! {
+ Some(rust_mcp_schema::ToolAnnotations {
+ destructive_hint: #destructive_hint,
+ idempotent_hint: #idempotent_hint,
+ open_world_hint: #open_world_hint,
+ read_only_hint: #read_only_hint,
+ title: #title,
+ }),
+ }
+ } else {
+ quote! {None}
+ };
+
+ #[cfg(feature = "2025_03_26")]
+ let tool_token = quote! {
+ rust_mcp_schema::Tool {
+ name: #tool_name.to_string(),
+ description: Some(#tool_description.to_string()),
+ input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties),
+ annotations: #annotations
+ }
+ };
+ #[cfg(feature = "2024_11_05")]
+ let tool_token = quote! {
+ rust_mcp_schema::Tool {
+ name: #tool_name.to_string(),
+ description: Some(#tool_description.to_string()),
+ input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties),
+ }
+ };
+
let output = quote! {
impl #input_ident {
/// Returns the name of the tool as a string.
@@ -222,54 +361,7 @@ pub fn mcp_tool(attributes: TokenStream, input: TokenStream) -> TokenStream {
.collect()
});
- rust_mcp_schema::Tool {
- name: #tool_name.to_string(),
- description: Some(#tool_description.to_string()),
- input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties),
- }
- }
-
- #[deprecated(since = "0.2.0", note = "Use `tool()` instead.")]
- pub fn get_tool()-> rust_mcp_schema::Tool
- {
- let json_schema = input_ident::json_schema();
-
- let required: Vec<_> = match json_schema.get("required").and_then(|r| r.as_array()) {
- Some(arr) => arr
- .iter()
- .filter_map(|item| item.as_str().map(String::from))
- .collect(),
- None => Vec::new(), // Default to an empty vector if "required" is missing or not an array
- };
-
- let properties: Option<
- std::collections::HashMap>,
- > = json_schema
- .get("properties")
- .and_then(|v| v.as_object()) // Safely extract "properties" as an object.
- .map(|properties| {
- properties
- .iter()
- .filter_map(|(key, value)| {
- serde_json::to_value(value)
- .ok() // If serialization fails, return None.
- .and_then(|v| {
- if let serde_json::Value::Object(obj) = v {
- Some(obj)
- } else {
- None
- }
- })
- .map(|obj| (key.to_string(), obj)) // Return the (key, value) tuple
- })
- .collect()
- });
-
- rust_mcp_schema::Tool {
- name: #tool_name.to_string(),
- description: Some(#tool_description.to_string()),
- input_schema: rust_mcp_schema::ToolInputSchema::new(required, properties),
- }
+ #tool_token
}
}
// Retain the original item (struct definition)
diff --git a/crates/rust-mcp-macros/tests/macro_test.rs b/crates/rust-mcp-macros/tests/macro_test.rs
index 3a23c87..4f24c9e 100644
--- a/crates/rust-mcp-macros/tests/macro_test.rs
+++ b/crates/rust-mcp-macros/tests/macro_test.rs
@@ -31,3 +31,58 @@ fn test_rename() {
let properties = schema.get("properties").unwrap().as_object().unwrap();
assert_eq!(properties.len(), 2);
}
+
+#[test]
+#[cfg(feature = "2025_03_26")]
+fn test_mcp_tool() {
+ #[rust_mcp_macros::mcp_tool(
+ name = "example_tool",
+ description = "An example tool",
+ idempotent_hint = true,
+ destructive_hint = true,
+ open_world_hint = true,
+ read_only_hint = true
+ )]
+ #[derive(rust_mcp_macros::JsonSchema)]
+ #[allow(unused)]
+ struct ExampleTool {
+ field1: String,
+ field2: i32,
+ }
+
+ assert_eq!(ExampleTool::tool_name(), "example_tool");
+ let tool: rust_mcp_schema::Tool = ExampleTool::tool();
+ assert_eq!(tool.name, "example_tool");
+ assert_eq!(tool.description.unwrap(), "An example tool");
+ assert!(tool.annotations.as_ref().unwrap().idempotent_hint.unwrap(),);
+ assert!(tool.annotations.as_ref().unwrap().destructive_hint.unwrap(),);
+ assert!(tool.annotations.as_ref().unwrap().open_world_hint.unwrap(),);
+ assert!(tool.annotations.as_ref().unwrap().read_only_hint.unwrap(),);
+
+ let schema_properties = tool.input_schema.properties.unwrap();
+ assert_eq!(schema_properties.len(), 2);
+ assert!(schema_properties.contains_key("field1"));
+ assert!(schema_properties.contains_key("field2"));
+}
+
+#[test]
+#[cfg(feature = "2024_11_05")]
+fn test_mcp_tool() {
+ #[rust_mcp_macros::mcp_tool(name = "example_tool", description = "An example tool")]
+ #[derive(rust_mcp_macros::JsonSchema)]
+ #[allow(unused)]
+ struct ExampleTool {
+ field1: String,
+ field2: i32,
+ }
+
+ assert_eq!(ExampleTool::tool_name(), "example_tool");
+ let tool: rust_mcp_schema::Tool = ExampleTool::tool();
+ assert_eq!(tool.name, "example_tool");
+ assert_eq!(tool.description.unwrap(), "An example tool");
+
+ let schema_properties = tool.input_schema.properties.unwrap();
+ assert_eq!(schema_properties.len(), 2);
+ assert!(schema_properties.contains_key("field1"));
+ assert!(schema_properties.contains_key("field2"));
+}
diff --git a/crates/rust-mcp-sdk/Cargo.toml b/crates/rust-mcp-sdk/Cargo.toml
index ab88674..b4d7663 100644
--- a/crates/rust-mcp-sdk/Cargo.toml
+++ b/crates/rust-mcp-sdk/Cargo.toml
@@ -32,6 +32,7 @@ tracing.workspace = true
hyper = { version = "1.6.0" }
[dev-dependencies]
+reqwest = { workspace = true, features = ["stream"] }
tracing-subscriber = { workspace = true, features = [
"env-filter",
"std",
diff --git a/crates/rust-mcp-sdk/src/hyper_servers.rs b/crates/rust-mcp-sdk/src/hyper_servers.rs
index ad1e2cd..9a58b04 100644
--- a/crates/rust-mcp-sdk/src/hyper_servers.rs
+++ b/crates/rust-mcp-sdk/src/hyper_servers.rs
@@ -2,6 +2,7 @@ mod app_state;
pub mod error;
pub mod hyper_server;
pub mod hyper_server_core;
+mod middlewares;
mod routes;
mod server;
mod session_store;
diff --git a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs
index 3276802..af572dd 100644
--- a/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs
+++ b/crates/rust-mcp-sdk/src/hyper_servers/app_state.rs
@@ -18,5 +18,6 @@ pub struct AppState {
pub server_details: Arc,
pub handler: Arc,
pub ping_interval: Duration,
+ pub sse_message_endpoint: String,
pub transport_options: Arc,
}
diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs
new file mode 100644
index 0000000..612510e
--- /dev/null
+++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares.rs
@@ -0,0 +1 @@
+pub(crate) mod session_id_gen;
diff --git a/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs
new file mode 100644
index 0000000..b68b325
--- /dev/null
+++ b/crates/rust-mcp-sdk/src/hyper_servers/middlewares/session_id_gen.rs
@@ -0,0 +1,23 @@
+use std::sync::Arc;
+
+use axum::{
+ extract::{Request, State},
+ middleware::Next,
+ response::Response,
+};
+use hyper::StatusCode;
+use rust_mcp_transport::SessionId;
+
+use crate::hyper_servers::app_state::AppState;
+
+// Middleware to generate and attach a session ID
+pub async fn generate_session_id(
+ State(state): State>,
+ mut request: Request,
+ next: Next,
+) -> Result {
+ let session_id: SessionId = state.id_generator.generate();
+ request.extensions_mut().insert(session_id);
+ // Proceed to the next middleware or handler
+ Ok(next.run(request).await)
+}
diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs
index 10e2eb9..55d15b1 100644
--- a/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs
+++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/messages_routes.rs
@@ -1,6 +1,9 @@
-use crate::hyper_servers::{
- app_state::AppState,
- error::{TransportServerError, TransportServerResult},
+use crate::{
+ hyper_servers::{
+ app_state::AppState,
+ error::{TransportServerError, TransportServerResult},
+ },
+ utils::remove_query_and_hash,
};
use axum::{
extract::{Query, State},
@@ -11,10 +14,11 @@ use axum::{
use std::{collections::HashMap, sync::Arc};
use tokio::io::AsyncWriteExt;
-const SSE_MESSAGES_PATH: &str = "/messages";
-
-pub fn routes(_state: Arc) -> Router> {
- Router::new().route(SSE_MESSAGES_PATH, post(handle_messages))
+pub fn routes(state: Arc) -> Router> {
+ Router::new().route(
+ remove_query_and_hash(&state.sse_message_endpoint).as_str(),
+ post(handle_messages),
+ )
}
pub async fn handle_messages(
diff --git a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs
index 2efe3be..b6e98f0 100644
--- a/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs
+++ b/crates/rust-mcp-sdk/src/hyper_servers/routes/sse_routes.rs
@@ -1,21 +1,25 @@
use crate::{
error::McpSdkError,
- hyper_servers::{app_state::AppState, error::TransportServerResult},
+ hyper_servers::{
+ app_state::AppState, error::TransportServerResult,
+ middlewares::session_id_gen::generate_session_id,
+ },
mcp_server::{server_runtime, ServerRuntime},
mcp_traits::mcp_handler::McpServerHandler,
McpServer,
};
use axum::{
extract::State,
+ middleware,
response::{
sse::{Event, KeepAlive},
IntoResponse, Sse,
},
routing::get,
- Router,
+ Extension, Router,
};
use futures::stream::{self};
-use rust_mcp_transport::{error::TransportError, SseTransport};
+use rust_mcp_transport::{error::TransportError, SessionId, SseTransport};
use std::{convert::Infallible, sync::Arc, time::Duration};
use tokio::{
io::{duplex, AsyncBufReadExt, BufReader},
@@ -23,7 +27,6 @@ use tokio::{
};
use tokio_stream::StreamExt;
-const SSE_MESSAGES_PATH: &str = "/messages";
const CLIENT_PING_TIMEOUT: Duration = Duration::from_secs(2);
const DUPLEX_BUFFER_SIZE: usize = 8192;
@@ -37,10 +40,8 @@ const DUPLEX_BUFFER_SIZE: usize = 8192;
///
/// # Returns
/// * `Result` - The constructed SSE event, infallible
-fn initial_event(session_id: &str) -> Result {
- Ok(Event::default()
- .event("endpoint")
- .data(format!("{SSE_MESSAGES_PATH}?sessionId={session_id}")))
+fn initial_event(endpoint: &str) -> Result {
+ Ok(Event::default().event("endpoint").data(endpoint))
}
/// Configures the SSE routes for the application
@@ -53,8 +54,13 @@ fn initial_event(session_id: &str) -> Result {
///
/// # Returns
/// * `Router>` - An Axum router configured with the SSE route
-pub fn routes(_state: Arc, sse_endpoint: &str) -> Router> {
- Router::new().route(sse_endpoint, get(handle_sse))
+pub fn routes(state: Arc, sse_endpoint: &str) -> Router> {
+ Router::new()
+ .route(sse_endpoint, get(handle_sse))
+ .route_layer(middleware::from_fn_with_state(
+ state.clone(),
+ generate_session_id,
+ ))
}
/// Handles Server-Sent Events (SSE) connections
@@ -68,15 +74,17 @@ pub fn routes(_state: Arc, sse_endpoint: &str) -> Router
/// # Returns
/// * `TransportServerResult` - The SSE response stream or an error
pub async fn handle_sse(
+ Extension(session_id): Extension,
State(state): State>,
) -> TransportServerResult {
+ let messages_endpoint =
+ SseTransport::message_endpoint(&state.sse_message_endpoint, &session_id);
+
// readable stream of string to be used in transport
let (read_tx, read_rx) = duplex(DUPLEX_BUFFER_SIZE);
// writable stream to deliver message to the client
let (write_tx, write_rx) = duplex(DUPLEX_BUFFER_SIZE);
- // generate a session id, and keep it in the server state
- let session_id = state.id_generator.generate();
state
.session_store
.set(session_id.to_owned(), read_tx)
@@ -140,7 +148,7 @@ pub async fn handle_sse(
});
// Initial SSE message to inform the client about the server's endpoint
- let initial_event = stream::once(async move { initial_event(&session_id) });
+ let initial_event = stream::once(async move { initial_event(&messages_endpoint) });
// Construct SSE stream for sending MCP messages to the server
let reader = BufReader::new(write_rx);
diff --git a/crates/rust-mcp-sdk/src/hyper_servers/server.rs b/crates/rust-mcp-sdk/src/hyper_servers/server.rs
index f0770e1..94a867a 100644
--- a/crates/rust-mcp-sdk/src/hyper_servers/server.rs
+++ b/crates/rust-mcp-sdk/src/hyper_servers/server.rs
@@ -1,18 +1,20 @@
use crate::mcp_traits::mcp_handler::McpServerHandler;
#[cfg(feature = "ssl")]
use axum_server::tls_rustls::RustlsConfig;
+use axum_server::Handle;
use std::{
net::{SocketAddr, ToSocketAddrs},
path::Path,
sync::Arc,
time::Duration,
};
+use tokio::signal;
use super::{
app_state::AppState,
error::{TransportServerError, TransportServerResult},
routes::app_routes,
- InMemorySessionStore, UuidGenerator,
+ IdGenerator, InMemorySessionStore, UuidGenerator,
};
use axum::Router;
use rust_mcp_schema::InitializeResult;
@@ -20,20 +22,23 @@ use rust_mcp_transport::TransportOptions;
// Default client ping interval (12 seconds)
const DEFAULT_CLIENT_PING_INTERVAL: Duration = Duration::from_secs(12);
-
+const GRACEFUL_SHUTDOWN_TMEOUT_SECS: u64 = 30;
// Default Server-Sent Events (SSE) endpoint path
const DEFAULT_SSE_ENDPOINT: &str = "/sse";
+// Default MCP Messages endpoint path
+const DEFAULT_MESSAGES_ENDPOINT: &str = "/messages";
/// Configuration struct for the Hyper server
/// Used to configure the HyperServer instance.
pub struct HyperServerOptions {
/// Hostname or IP address the server will bind to (default: "localhost")
pub host: String,
- /// Hostname or IP address the server will bind to (default: "localhost")
+ /// Hostname or IP address the server will bind to (default: "8080")
pub port: u16,
- /// Optional custom path for the Server-Sent Events (SSE) endpoint.
- /// If `None`, the default path `/sse` will be used.
+ /// Optional custom path for the Server-Sent Events (SSE) endpoint (default: `/sse`)
pub custom_sse_endpoint: Option,
+ /// Optional custom path for the MCP messages endpoint (default: `/messages`)
+ pub custom_messages_endpoint: Option,
/// Interval between automatic ping messages sent to clients to detect disconnects
pub ping_interval: Duration,
/// Enables SSL/TLS if set to `true`
@@ -46,6 +51,8 @@ pub struct HyperServerOptions {
pub ssl_key_path: Option,
/// Shared transport configuration used by the server
pub transport_options: Arc,
+ /// Optional thread-safe session id generator to generate unique session IDs.
+ pub session_id_generator: Option>,
}
impl HyperServerOptions {
@@ -121,6 +128,12 @@ impl HyperServerOptions {
.as_deref()
.unwrap_or(DEFAULT_SSE_ENDPOINT)
}
+
+ pub fn sse_messages_endpoint(&self) -> &str {
+ self.custom_messages_endpoint
+ .as_deref()
+ .unwrap_or(DEFAULT_MESSAGES_ENDPOINT)
+ }
}
/// Default implementation for HyperServerOptions
@@ -133,11 +146,13 @@ impl Default for HyperServerOptions {
host: "127.0.0.1".to_string(),
port: 8080,
custom_sse_endpoint: None,
+ custom_messages_endpoint: None,
ping_interval: DEFAULT_CLIENT_PING_INTERVAL,
transport_options: Default::default(),
enable_ssl: false,
ssl_cert_path: None,
ssl_key_path: None,
+ session_id_generator: None,
}
}
}
@@ -147,6 +162,7 @@ pub struct HyperServer {
app: Router,
state: Arc,
options: HyperServerOptions,
+ handle: Handle,
}
impl HyperServer {
@@ -164,14 +180,18 @@ impl HyperServer {
pub(crate) fn new(
server_details: InitializeResult,
handler: Arc,
- server_options: HyperServerOptions,
+ mut server_options: HyperServerOptions,
) -> Self {
let state: Arc = Arc::new(AppState {
session_store: Arc::new(InMemorySessionStore::new()),
- id_generator: Arc::new(UuidGenerator {}),
+ id_generator: server_options
+ .session_id_generator
+ .take()
+ .map_or(Arc::new(UuidGenerator {}), |g| Arc::clone(&g)),
server_details: Arc::new(server_details),
handler,
ping_interval: server_options.ping_interval,
+ sse_message_endpoint: server_options.sse_messages_endpoint().to_owned(),
transport_options: Arc::clone(&server_options.transport_options),
});
let app = app_routes(Arc::clone(&state), &server_options);
@@ -179,6 +199,7 @@ impl HyperServer {
app,
state,
options: server_options,
+ handle: Handle::new(),
}
}
@@ -263,12 +284,25 @@ impl HyperServer {
tracing::info!("{}", self.server_info(Some(addr)).await?);
+ // Spawn a task to trigger shutdown on signal
+ let handle_clone = self.handle.clone();
+ tokio::spawn(async move {
+ shutdown_signal(handle_clone).await;
+ });
+
+ let handle_clone = self.handle.clone();
axum_server::bind_rustls(addr, config)
+ .handle(handle_clone)
.serve(self.app.into_make_service())
.await
.map_err(|err| TransportServerError::ServerStartError(err.to_string()))
}
+ /// Returns server handle that could be used for graceful shutdown
+ pub fn server_handle(&self) -> Handle {
+ self.handle.clone()
+ }
+
/// Starts the server without SSL
///
/// # Arguments
@@ -279,7 +313,15 @@ impl HyperServer {
async fn start_http(self, addr: SocketAddr) -> TransportServerResult<()> {
tracing::info!("{}", self.server_info(Some(addr)).await?);
+ // Spawn a task to trigger shutdown on signal
+ let handle_clone = self.handle.clone();
+ tokio::spawn(async move {
+ shutdown_signal(handle_clone).await;
+ });
+
+ let handle_clone = self.handle.clone();
axum_server::bind(addr)
+ .handle(handle_clone)
.serve(self.app.into_make_service())
.await
.map_err(|err| TransportServerError::ServerStartError(err.to_string()))
@@ -310,3 +352,33 @@ impl HyperServer {
}
}
}
+
+// Shutdown signal handler
+async fn shutdown_signal(handle: Handle) {
+ // Wait for a Ctrl+C or SIGTERM signal
+ let ctrl_c = async {
+ signal::ctrl_c()
+ .await
+ .expect("Failed to install Ctrl+C handler");
+ };
+
+ #[cfg(unix)]
+ let terminate = async {
+ signal::unix::signal(signal::unix::SignalKind::terminate())
+ .expect("Failed to install signal handler")
+ .recv()
+ .await;
+ };
+
+ #[cfg(not(unix))]
+ let terminate = std::future::pending::<()>();
+
+ tokio::select! {
+ _ = ctrl_c => {},
+ _ = terminate => {},
+ }
+
+ tracing::info!("Signal received, starting graceful shutdown");
+ // Trigger graceful shutdown with a timeout
+ handle.graceful_shutdown(Some(Duration::from_secs(GRACEFUL_SHUTDOWN_TMEOUT_SECS)));
+}
diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs
index da25000..b0716b8 100644
--- a/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs
+++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store.rs
@@ -3,15 +3,13 @@ use std::sync::Arc;
use async_trait::async_trait;
pub use in_memory::*;
+use rust_mcp_transport::SessionId;
use tokio::{io::DuplexStream, sync::Mutex};
use uuid::Uuid;
// Type alias for the server-side duplex stream used in sessions
pub type TxServer = DuplexStream;
-// Type alias for session identifier, represented as a String
-pub type SessionId = String;
-
/// Trait defining the interface for session storage operations
///
/// This trait provides asynchronous methods for managing session data,
@@ -39,6 +37,10 @@ pub trait SessionStore: Send + Sync {
async fn delete(&self, key: &SessionId);
/// Clears all sessions from the store
async fn clear(&self);
+
+ async fn keys(&self) -> Vec;
+
+ async fn values(&self) -> Vec>>;
}
/// Trait for generating session identifiers
diff --git a/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs b/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs
index 7c5755d..342d232 100644
--- a/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs
+++ b/crates/rust-mcp-sdk/src/hyper_servers/session_store/in_memory.rs
@@ -3,6 +3,7 @@ use super::{SessionStore, TxServer};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
+use tokio::io::DuplexStream;
use tokio::sync::Mutex;
use tokio::sync::RwLock;
@@ -54,4 +55,12 @@ impl SessionStore for InMemorySessionStore {
let mut store = self.store.write().await;
store.clear();
}
+ async fn keys(&self) -> Vec {
+ let store = self.store.read().await;
+ store.keys().cloned().collect::>()
+ }
+ async fn values(&self) -> Vec>> {
+ let store = self.store.read().await;
+ store.values().cloned().collect::>()
+ }
}
diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs
index d325305..5f22a43 100644
--- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs
+++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs
@@ -12,10 +12,10 @@ use std::sync::{Arc, RwLock};
use tokio::io::AsyncWriteExt;
use crate::error::SdkResult;
-#[cfg(feature = "hyper-server")]
-use crate::hyper_servers::SessionId;
use crate::mcp_traits::mcp_handler::McpServerHandler;
use crate::mcp_traits::mcp_server::McpServer;
+#[cfg(feature = "hyper-server")]
+use rust_mcp_transport::SessionId;
/// Struct representing the runtime core of the MCP server, handling transport and client details
pub struct ServerRuntime {
diff --git a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs
index dd9e98f..51eba77 100644
--- a/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs
+++ b/crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime.rs
@@ -12,7 +12,7 @@ use rust_mcp_transport::Transport;
use super::ServerRuntime;
#[cfg(feature = "hyper-server")]
-use crate::hyper_servers::SessionId;
+use rust_mcp_transport::SessionId;
use crate::{
error::SdkResult,
diff --git a/crates/rust-mcp-sdk/src/utils.rs b/crates/rust-mcp-sdk/src/utils.rs
index 85ad72e..13dc579 100644
--- a/crates/rust-mcp-sdk/src/utils.rs
+++ b/crates/rust-mcp-sdk/src/utils.rs
@@ -22,3 +22,48 @@ pub fn format_assertion_message(entity: &str, capability: &str, method_name: &st
entity, capability, method_name
)
}
+
+/// Removes query string and hash fragment from a URL, returning the base path.
+///
+/// # Arguments
+/// * `endpoint` - The URL or endpoint to process (e.g., "/messages?foo=bar#section1")
+///
+/// # Returns
+/// A String containing the base path without query parameters or fragment
+/// ```
+#[allow(unused)]
+pub(crate) fn remove_query_and_hash(endpoint: &str) -> String {
+ // Split off fragment (if any) and take the first part
+ let without_fragment = endpoint.split_once('#').map_or(endpoint, |(path, _)| path);
+
+ // Split off query string (if any) and take the first part
+ let without_query = without_fragment
+ .split_once('?')
+ .map_or(without_fragment, |(path, _)| path);
+
+ // Return the base path
+ if without_query.is_empty() {
+ "/".to_string()
+ } else {
+ without_query.to_string()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ #[test]
+ fn tets_remove_query_and_hash() {
+ assert_eq!(remove_query_and_hash("/messages"), "/messages");
+ assert_eq!(
+ remove_query_and_hash("/messages?foo=bar&baz=qux"),
+ "/messages"
+ );
+ assert_eq!(remove_query_and_hash("/messages#section1"), "/messages");
+ assert_eq!(
+ remove_query_and_hash("/messages?key=value#section2"),
+ "/messages"
+ );
+ assert_eq!(remove_query_and_hash("/"), "/");
+ }
+}
diff --git a/crates/rust-mcp-sdk/tests/common/common.rs b/crates/rust-mcp-sdk/tests/common/common.rs
index d896a56..6746270 100644
--- a/crates/rust-mcp-sdk/tests/common/common.rs
+++ b/crates/rust-mcp-sdk/tests/common/common.rs
@@ -1,8 +1,10 @@
+mod test_server;
use async_trait::async_trait;
use rust_mcp_schema::{
ClientCapabilities, Implementation, InitializeRequestParams, JSONRPC_VERSION,
};
use rust_mcp_sdk::mcp_client::ClientHandler;
+pub use test_server::*;
pub const NPX_SERVER_EVERYTHING: &str = "@modelcontextprotocol/server-everything";
@@ -24,3 +26,11 @@ pub struct TestClientHandler;
#[async_trait]
impl ClientHandler for TestClientHandler {}
+
+pub fn sse_event(sse_raw: &str) -> String {
+ sse_raw.replace("event: ", "")
+}
+
+pub fn sse_data(sse_raw: &str) -> String {
+ sse_raw.replace("data: ", "")
+}
diff --git a/crates/rust-mcp-sdk/tests/common/test_server.rs b/crates/rust-mcp-sdk/tests/common/test_server.rs
new file mode 100644
index 0000000..dcb4e1b
--- /dev/null
+++ b/crates/rust-mcp-sdk/tests/common/test_server.rs
@@ -0,0 +1,118 @@
+use async_trait::async_trait;
+use tokio_stream::StreamExt;
+
+use rust_mcp_schema::{
+ Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools,
+ LATEST_PROTOCOL_VERSION,
+};
+use rust_mcp_sdk::{
+ mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler},
+ McpServer, SessionId,
+};
+use std::sync::RwLock;
+use std::time::Duration;
+use tokio::time::timeout;
+
+pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#;
+pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#;
+
+pub fn test_server_details() -> InitializeResult {
+ InitializeResult {
+ // server name and version
+ server_info: Implementation {
+ name: "Test MCP Server".to_string(),
+ version: "0.1.0".to_string(),
+ },
+ capabilities: ServerCapabilities {
+ // indicates that server support mcp tools
+ tools: Some(ServerCapabilitiesTools { list_changed: None }),
+ ..Default::default() // Using default values for other fields
+ },
+ meta: None,
+ instructions: Some("server instructions...".to_string()),
+ protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
+ }
+}
+
+pub struct TestServerHandler;
+
+#[async_trait]
+impl ServerHandler for TestServerHandler {
+ async fn on_server_started(&self, runtime: &dyn McpServer) {
+ let _ = runtime
+ .stderr_message("Server started successfully".into())
+ .await;
+ }
+}
+
+pub fn create_test_server(options: HyperServerOptions) -> HyperServer {
+ hyper_server::create_server(test_server_details(), TestServerHandler {}, options)
+}
+
+// Tests the session ID generator, ensuring it returns a sequence of predefined session IDs.
+pub struct TestIdGenerator {
+ constant_ids: Vec,
+ generated: RwLock,
+}
+
+impl TestIdGenerator {
+ pub fn new(constant_ids: Vec) -> Self {
+ TestIdGenerator {
+ constant_ids,
+ generated: RwLock::new(0),
+ }
+ }
+}
+
+impl IdGenerator for TestIdGenerator {
+ fn generate(&self) -> SessionId {
+ let mut lock = self.generated.write().unwrap();
+ *lock += 1;
+ if *lock > self.constant_ids.len() {
+ *lock = 1;
+ }
+ self.constant_ids[*lock - 1].to_owned()
+ }
+}
+
+pub async fn collect_sse_lines(
+ response: reqwest::Response,
+ line_count: usize,
+ read_timeout: Duration,
+) -> Result, Box> {
+ let mut collected_lines = Vec::new();
+ let mut stream = response.bytes_stream();
+
+ let result = timeout(read_timeout, async {
+ while let Some(chunk) = stream.next().await {
+ let chunk = chunk.map_err(|e| Box::new(e) as Box)?;
+ let chunk_str = String::from_utf8_lossy(&chunk);
+
+ // Split the chunk into lines
+ let lines: Vec<&str> = chunk_str.lines().collect();
+
+ // Add each line to the collected_lines vector
+ for line in lines {
+ collected_lines.push(line.to_string());
+
+ // Check if we have collected 5 lines
+ if collected_lines.len() >= line_count {
+ return Ok(collected_lines);
+ }
+ }
+ }
+ // If the stream ends before collecting 5 lines, return what we have
+ Ok(collected_lines)
+ })
+ .await;
+
+ // Handle timeout or stream result
+ match result {
+ Ok(Ok(lines)) => Ok(lines),
+ Ok(Err(e)) => Err(e),
+ Err(_) => Err(Box::new(std::io::Error::new(
+ std::io::ErrorKind::TimedOut,
+ "Timed out waiting for 5 lines",
+ ))),
+ }
+}
diff --git a/crates/rust-mcp-sdk/tests/test_client_runtime.rs b/crates/rust-mcp-sdk/tests/test_client_runtime.rs
index c7804e5..c8b3b17 100644
--- a/crates/rust-mcp-sdk/tests/test_client_runtime.rs
+++ b/crates/rust-mcp-sdk/tests/test_client_runtime.rs
@@ -1,8 +1,7 @@
-use common::{test_client_info, TestClientHandler, NPX_SERVER_EVERYTHING};
-use rust_mcp_sdk::{mcp_client::client_runtime, McpClient, StdioTransport, TransportOptions};
-
#[cfg(unix)]
use common::UVX_SERVER_GIT;
+use common::{test_client_info, TestClientHandler, NPX_SERVER_EVERYTHING};
+use rust_mcp_sdk::{mcp_client::client_runtime, McpClient, StdioTransport, TransportOptions};
#[path = "common/common.rs"]
pub mod common;
diff --git a/crates/rust-mcp-sdk/tests/test_server_sse.rs b/crates/rust-mcp-sdk/tests/test_server_sse.rs
new file mode 100644
index 0000000..ba7df51
--- /dev/null
+++ b/crates/rust-mcp-sdk/tests/test_server_sse.rs
@@ -0,0 +1,211 @@
+use std::{sync::Arc, time::Duration};
+
+use common::{
+ collect_sse_lines, create_test_server, sse_data, sse_event, TestIdGenerator, INITIALIZE_REQUEST,
+};
+use reqwest::Client;
+use rust_mcp_schema::{
+ schema_utils::{ResultFromServer, ServerMessage},
+ ServerResult,
+};
+use rust_mcp_sdk::mcp_server::HyperServerOptions;
+use tokio::time::sleep;
+
+#[path = "common/common.rs"]
+pub mod common;
+
+#[tokio::test]
+async fn tets_sse_endpoint_event_default() {
+ let server_options = HyperServerOptions {
+ port: 8081,
+ session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![
+ "AAA-BBB-CCC".to_string()
+ ]))),
+ ..Default::default()
+ };
+
+ let base_url = format!("http://{}:{}", server_options.host, server_options.port);
+
+ let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint());
+
+ let server = create_test_server(server_options);
+ let handle = server.server_handle();
+ let server_task = tokio::spawn(async move {
+ server.start().await.unwrap();
+ eprintln!("Server 1 is down");
+ });
+
+ sleep(Duration::from_millis(750)).await;
+
+ let client = Client::new();
+ println!("connecting to : {}", server_endpoint);
+ // Act: Connect to the SSE endpoint and read the event stream
+ let response = client
+ .get(server_endpoint)
+ .header("Accept", "text/event-stream")
+ .send()
+ .await
+ .expect("Failed to connect to SSE endpoint");
+
+ assert_eq!(
+ response.headers().get("content-type").map(|v| v.as_bytes()),
+ Some(b"text/event-stream" as &[u8]),
+ "Response content-type should be text/event-stream"
+ );
+
+ let lines = collect_sse_lines(response, 2, Duration::from_secs(5))
+ .await
+ .unwrap();
+
+ assert_eq!(sse_event(&lines[0]), "endpoint");
+ assert_eq!(sse_data(&lines[1]), "/messages?sessionId=AAA-BBB-CCC");
+
+ let message_endpoint = format!("{}{}", base_url, sse_data(&lines[1]));
+ let res = client
+ .post(message_endpoint)
+ .header("Content-Type", "application/json")
+ .body(INITIALIZE_REQUEST.to_string())
+ .send()
+ .await
+ .unwrap();
+ assert!(res.status().is_success());
+ handle.graceful_shutdown(Some(Duration::from_millis(1)));
+ server_task.await.unwrap();
+}
+
+#[tokio::test]
+async fn tets_sse_message_endpoint_query_hash() {
+ let server_options = HyperServerOptions {
+ port: 8082,
+ custom_messages_endpoint: Some(
+ "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(),
+ ),
+ session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![
+ "AAA-BBB-CCC".to_string()
+ ]))),
+ ..Default::default()
+ };
+
+ let base_url = format!("http://{}:{}", server_options.host, server_options.port);
+
+ let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint());
+
+ let server = create_test_server(server_options);
+ let handle = server.server_handle();
+
+ let server_task = tokio::spawn(async move {
+ server.start().await.unwrap();
+ eprintln!("Server 2 is down");
+ });
+
+ sleep(Duration::from_millis(750)).await;
+
+ let client = Client::new();
+ println!("connecting to : {}", server_endpoint);
+ // Act: Connect to the SSE endpoint and read the event stream
+ let response = client
+ .get(server_endpoint)
+ .header("Accept", "text/event-stream")
+ .send()
+ .await
+ .expect("Failed to connect to SSE endpoint");
+
+ assert_eq!(
+ response.headers().get("content-type").map(|v| v.as_bytes()),
+ Some(b"text/event-stream" as &[u8]),
+ "Response content-type should be text/event-stream"
+ );
+
+ let lines = collect_sse_lines(response, 2, Duration::from_secs(5))
+ .await
+ .unwrap();
+
+ assert_eq!(sse_event(&lines[0]), "endpoint");
+ assert_eq!(
+ sse_data(&lines[1]),
+ "/custom-msg-endpoint?something=true&otherthing=false&sessionId=AAA-BBB-CCC#section-59"
+ );
+
+ let message_endpoint = format!("{}{}", base_url, sse_data(&lines[1]));
+ let res = client
+ .post(message_endpoint)
+ .header("Content-Type", "application/json")
+ .body(INITIALIZE_REQUEST.to_string())
+ .send()
+ .await
+ .unwrap();
+ assert!(res.status().is_success());
+ handle.graceful_shutdown(Some(Duration::from_millis(1)));
+ server_task.await.unwrap();
+}
+
+#[tokio::test]
+async fn tets_sse_custom_message_endpoint() {
+ let server_options = HyperServerOptions {
+ port: 8083,
+ custom_messages_endpoint: Some(
+ "/custom-msg-endpoint?something=true&otherthing=false#section-59".to_string(),
+ ),
+ session_id_generator: Some(Arc::new(TestIdGenerator::new(vec![
+ "AAA-BBB-CCC".to_string()
+ ]))),
+ ..Default::default()
+ };
+
+ let base_url = format!("http://{}:{}", server_options.host, server_options.port);
+
+ let server_endpoint = format!("{}{}", base_url, server_options.sse_endpoint());
+
+ let server = create_test_server(server_options);
+ let handle = server.server_handle();
+
+ let server_task = tokio::spawn(async move {
+ server.start().await.unwrap();
+ eprintln!("Server 3 is down");
+ });
+
+ sleep(Duration::from_millis(750)).await;
+
+ let client = Client::new();
+ println!("connecting to : {}", server_endpoint);
+ // Act: Connect to the SSE endpoint and read the event stream
+ let response = client
+ .get(server_endpoint)
+ .header("Accept", "text/event-stream")
+ .send()
+ .await
+ .expect("Failed to connect to SSE endpoint");
+
+ assert_eq!(
+ response.headers().get("content-type").map(|v| v.as_bytes()),
+ Some(b"text/event-stream" as &[u8]),
+ "Response content-type should be text/event-stream"
+ );
+
+ let message_endpoint = format!(
+ "{}{}",
+ base_url,
+ "/custom-msg-endpoint?something=true&otherthing=false&sessionId=AAA-BBB-CCC#section-59"
+ );
+ let res = client
+ .post(message_endpoint)
+ .header("Content-Type", "application/json")
+ .body(INITIALIZE_REQUEST.to_string())
+ .send()
+ .await
+ .unwrap();
+ assert!(res.status().is_success());
+
+ let lines = collect_sse_lines(response, 5, Duration::from_secs(5))
+ .await
+ .unwrap();
+
+ let init_response = sse_data(&lines[3]);
+ let result = serde_json::from_str::(&init_response).unwrap();
+
+ assert!(matches!(result, ServerMessage::Response(response)
+ if matches!(&response.result, ResultFromServer::ServerResult(server_result)
+ if matches!(server_result, ServerResult::InitializeResult(init_result) if init_result.server_info.name == "Test MCP Server"))));
+ handle.graceful_shutdown(Some(Duration::from_millis(1)));
+ server_task.await.unwrap();
+}
diff --git a/crates/rust-mcp-transport/src/lib.rs b/crates/rust-mcp-transport/src/lib.rs
index a2ec12b..31d810d 100644
--- a/crates/rust-mcp-transport/src/lib.rs
+++ b/crates/rust-mcp-transport/src/lib.rs
@@ -16,3 +16,6 @@ pub use message_dispatcher::*;
pub use sse::*;
pub use stdio::*;
pub use transport::*;
+
+// Type alias for session identifier, represented as a String
+pub type SessionId = String;
diff --git a/crates/rust-mcp-transport/src/sse.rs b/crates/rust-mcp-transport/src/sse.rs
index 4d8b100..554826d 100644
--- a/crates/rust-mcp-transport/src/sse.rs
+++ b/crates/rust-mcp-transport/src/sse.rs
@@ -12,8 +12,8 @@ use crate::error::{TransportError, TransportResult};
use crate::mcp_stream::MCPStream;
use crate::message_dispatcher::MessageDispatcher;
use crate::transport::Transport;
-use crate::utils::CancellationTokenSource;
-use crate::{IoStream, McpDispatch, TransportOptions};
+use crate::utils::{endpoint_with_session_id, CancellationTokenSource};
+use crate::{IoStream, McpDispatch, SessionId, TransportOptions};
pub struct SseTransport {
shutdown_source: tokio::sync::RwLock