Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ members = [

[workspace.dependencies]
# Workspace member crates
rust-mcp-transport = { version = "0.3.1", path = "crates/rust-mcp-transport" }
rust-mcp-transport = { version = "0.3.1", path = "crates/rust-mcp-transport", default-features = false }
rust-mcp-sdk = { path = "crates/rust-mcp-sdk", default-features = false }
rust-mcp-macros = { version = "0.3.0", path = "crates/rust-mcp-macros" }

Expand Down
14 changes: 10 additions & 4 deletions crates/rust-mcp-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ edition = "2021"

[dependencies]
rust-mcp-schema = { workspace = true }
rust-mcp-transport = { workspace = true }
rust-mcp-transport = { workspace = true, default-features = false, optional = true }
rust-mcp-macros = { workspace = true, optional = true }

tokio.workspace = true
Expand Down Expand Up @@ -47,9 +47,15 @@ default = [
"hyper-server",
"ssl",
] # All features enabled by default
server = [] # Server feature
client = [] # Client feature
hyper-server = ["axum", "axum-server", "uuid", "tokio-stream"]
server = ["rust-mcp-transport/stdio"] # Server feature
client = ["rust-mcp-transport/stdio", "rust-mcp-transport/sse"] # Client feature
hyper-server = [
"axum",
"axum-server",
"uuid",
"tokio-stream",
"rust-mcp-transport/sse",
]
ssl = ["axum-server/tls-rustls"]
macros = ["rust-mcp-macros"]

Expand Down
189 changes: 96 additions & 93 deletions crates/rust-mcp-sdk/tests/common/test_server.rs
Original file line number Diff line number Diff line change
@@ -1,118 +1,121 @@
use async_trait::async_trait;
use tokio_stream::StreamExt;
#[cfg(feature = "hyper-server")]
pub mod test_server_common {
use async_trait::async_trait;
use tokio_stream::StreamExt;

use rust_mcp_schema::{
Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools,
LATEST_PROTOCOL_VERSION,
};
use rust_mcp_sdk::{
mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler},
McpServer, SessionId,
};
use std::sync::RwLock;
use std::time::Duration;
use tokio::time::timeout;
use rust_mcp_schema::{
Implementation, InitializeResult, ServerCapabilities, ServerCapabilitiesTools,
LATEST_PROTOCOL_VERSION,
};
use rust_mcp_sdk::{
mcp_server::{hyper_server, HyperServer, HyperServerOptions, IdGenerator, ServerHandler},
McpServer, SessionId,
};
use std::sync::RwLock;
use std::time::Duration;
use tokio::time::timeout;

pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#;
pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#;
pub const INITIALIZE_REQUEST: &str = r#"{"jsonrpc":"2.0","id":0,"method":"initialize","params":{"protocolVersion":"2.0","capabilities":{"sampling":{},"roots":{"listChanged":true}},"clientInfo":{"name":"reqwest-test","version":"0.1.0"}}}"#;
pub const PING_REQUEST: &str = r#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#;

pub fn test_server_details() -> InitializeResult {
InitializeResult {
// server name and version
server_info: Implementation {
name: "Test MCP Server".to_string(),
version: "0.1.0".to_string(),
},
capabilities: ServerCapabilities {
// indicates that server support mcp tools
tools: Some(ServerCapabilitiesTools { list_changed: None }),
..Default::default() // Using default values for other fields
},
meta: None,
instructions: Some("server instructions...".to_string()),
protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
pub fn test_server_details() -> InitializeResult {
InitializeResult {
// server name and version
server_info: Implementation {
name: "Test MCP Server".to_string(),
version: "0.1.0".to_string(),
},
capabilities: ServerCapabilities {
// indicates that server support mcp tools
tools: Some(ServerCapabilitiesTools { list_changed: None }),
..Default::default() // Using default values for other fields
},
meta: None,
instructions: Some("server instructions...".to_string()),
protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
}
}
}

pub struct TestServerHandler;
pub struct TestServerHandler;

#[async_trait]
impl ServerHandler for TestServerHandler {
async fn on_server_started(&self, runtime: &dyn McpServer) {
let _ = runtime
.stderr_message("Server started successfully".into())
.await;
#[async_trait]
impl ServerHandler for TestServerHandler {
async fn on_server_started(&self, runtime: &dyn McpServer) {
let _ = runtime
.stderr_message("Server started successfully".into())
.await;
}
}
}

pub fn create_test_server(options: HyperServerOptions) -> HyperServer {
hyper_server::create_server(test_server_details(), TestServerHandler {}, options)
}
pub fn create_test_server(options: HyperServerOptions) -> HyperServer {
hyper_server::create_server(test_server_details(), TestServerHandler {}, options)
}

// Tests the session ID generator, ensuring it returns a sequence of predefined session IDs.
pub struct TestIdGenerator {
constant_ids: Vec<SessionId>,
generated: RwLock<usize>,
}
// Tests the session ID generator, ensuring it returns a sequence of predefined session IDs.
pub struct TestIdGenerator {
constant_ids: Vec<SessionId>,
generated: RwLock<usize>,
}

impl TestIdGenerator {
pub fn new(constant_ids: Vec<SessionId>) -> Self {
TestIdGenerator {
constant_ids,
generated: RwLock::new(0),
impl TestIdGenerator {
pub fn new(constant_ids: Vec<SessionId>) -> Self {
TestIdGenerator {
constant_ids,
generated: RwLock::new(0),
}
}
}
}

impl IdGenerator for TestIdGenerator {
fn generate(&self) -> SessionId {
let mut lock = self.generated.write().unwrap();
*lock += 1;
if *lock > self.constant_ids.len() {
*lock = 1;
impl IdGenerator for TestIdGenerator {
fn generate(&self) -> SessionId {
let mut lock = self.generated.write().unwrap();
*lock += 1;
if *lock > self.constant_ids.len() {
*lock = 1;
}
self.constant_ids[*lock - 1].to_owned()
}
self.constant_ids[*lock - 1].to_owned()
}
}

pub async fn collect_sse_lines(
response: reqwest::Response,
line_count: usize,
read_timeout: Duration,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let mut collected_lines = Vec::new();
let mut stream = response.bytes_stream();
pub async fn collect_sse_lines(
response: reqwest::Response,
line_count: usize,
read_timeout: Duration,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let mut collected_lines = Vec::new();
let mut stream = response.bytes_stream();

let result = timeout(read_timeout, async {
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
let chunk_str = String::from_utf8_lossy(&chunk);
let result = timeout(read_timeout, async {
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
let chunk_str = String::from_utf8_lossy(&chunk);

// Split the chunk into lines
let lines: Vec<&str> = chunk_str.lines().collect();
// Split the chunk into lines
let lines: Vec<&str> = chunk_str.lines().collect();

// Add each line to the collected_lines vector
for line in lines {
collected_lines.push(line.to_string());
// Add each line to the collected_lines vector
for line in lines {
collected_lines.push(line.to_string());

// Check if we have collected 5 lines
if collected_lines.len() >= line_count {
return Ok(collected_lines);
// Check if we have collected 5 lines
if collected_lines.len() >= line_count {
return Ok(collected_lines);
}
}
}
}
// If the stream ends before collecting 5 lines, return what we have
Ok(collected_lines)
})
.await;
// If the stream ends before collecting 5 lines, return what we have
Ok(collected_lines)
})
.await;

// Handle timeout or stream result
match result {
Ok(Ok(lines)) => Ok(lines),
Ok(Err(e)) => Err(e),
Err(_) => Err(Box::new(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Timed out waiting for 5 lines",
))),
// Handle timeout or stream result
match result {
Ok(Ok(lines)) => Ok(lines),
Ok(Err(e)) => Err(e),
Err(_) => Err(Box::new(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Timed out waiting for 5 lines",
))),
}
}
}
Loading