diff --git a/packages/loro-websocket/src/server/simple-server.ts b/packages/loro-websocket/src/server/simple-server.ts index ebdf7dd..767142a 100644 --- a/packages/loro-websocket/src/server/simple-server.ts +++ b/packages/loro-websocket/src/server/simple-server.ts @@ -1,6 +1,7 @@ import { WebSocketServer, WebSocket } from "ws"; import { randomBytes } from "node:crypto"; import type { RawData } from "ws"; +import type { IncomingMessage } from "http"; // no direct CRDT imports here; handled by CrdtDoc implementations import { encode, @@ -47,6 +48,13 @@ export interface SimpleServerConfig { crdtType: CrdtType, auth: Uint8Array ) => Promise; + /** + * Optional handshake auth: called during WS HTTP upgrade. + * Return true to accept, false to reject. + */ + handshakeAuth?: ( + req: IncomingMessage + ) => boolean | Promise; } interface RoomDocument { @@ -86,12 +94,28 @@ export class SimpleServer { start(): Promise { return new Promise(resolve => { - const options: { port: number; host?: string } = { + const options: { port: number; host?: string; verifyClient?: any } = { port: this.config.port, }; if (this.config.host) { options.host = this.config.host; } + if (this.config.handshakeAuth) { + options.verifyClient = ( + info: { origin: string; secure: boolean; req: IncomingMessage }, + cb: (res: boolean, code?: number, message?: string) => void + ) => { + Promise.resolve(this.config.handshakeAuth!(info.req)) + .then(allowed => { + if (allowed) cb(true); + else cb(false, 401, "Unauthorized"); + }) + .catch(err => { + console.error("Handshake auth error", err); + cb(false, 500, "Internal Server Error"); + }); + }; + } this.wss = new WebSocketServer(options); this.wss.on("connection", ws => { diff --git a/packages/loro-websocket/tests/handshake-auth.test.ts b/packages/loro-websocket/tests/handshake-auth.test.ts new file mode 100644 index 0000000..910857d --- /dev/null +++ b/packages/loro-websocket/tests/handshake-auth.test.ts @@ -0,0 +1,70 @@ +import { describe, it, expect, beforeAll, afterAll } from "vitest"; +import { WebSocket } from "ws"; +import getPort from "get-port"; +import { SimpleServer } from "../src/server/simple-server"; + +// Make WebSocket available globally for the client +Object.defineProperty(globalThis, "WebSocket", { + value: WebSocket, + configurable: true, + writable: true, +}); + +describe("Handshake Auth", () => { + let server: SimpleServer; + let port: number; + + beforeAll(async () => { + port = await getPort(); + server = new SimpleServer({ + port, + handshakeAuth: req => { + const cookie = req.headers.cookie; + return cookie === "session=valid"; + }, + }); + await server.start(); + }); + + afterAll(async () => { + await server.stop(); + }, 10000); + + it("should accept connection with valid cookie", async () => { + const ws = new WebSocket(`ws://localhost:${port}`, { + headers: { + Cookie: "session=valid", + }, + }); + + await new Promise((resolve, reject) => { + ws.onopen = () => resolve(); + ws.onerror = err => reject(err); + }); + ws.close(); + }); + + it("should reject connection with invalid cookie", async () => { + const ws = new WebSocket(`ws://localhost:${port}`, { + headers: { + Cookie: "session=invalid", + }, + }); + + await new Promise((resolve, reject) => { + ws.onopen = () => reject(new Error("Should have failed")); + ws.onerror = err => { + resolve(); + }; + }); + }); + + it("should reject connection with missing cookie", async () => { + const ws = new WebSocket(`ws://localhost:${port}`); + + await new Promise((resolve, reject) => { + ws.onopen = () => reject(new Error("Should have failed")); + ws.onerror = () => resolve(); + }); + }); +}); diff --git a/rust/Cargo.lock b/rust/Cargo.lock index d418930..6f25d15 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -314,6 +314,16 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "time", + "version_check", +] + [[package]] name = "cpufeatures" version = "0.2.17" @@ -390,6 +400,15 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" +[[package]] +name = "deranged" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ececcb659e7ba858fb4f10388c250a7252eb0a27373f1a72b8748afdd248e587" +dependencies = [ + "powerfmt", +] + [[package]] name = "derive_arbitrary" version = "1.4.2" @@ -1003,6 +1022,7 @@ name = "loro-websocket-server" version = "0.1.0" dependencies = [ "clap", + "cookie", "futures-util", "loro", "loro-protocol", @@ -1124,6 +1144,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -1263,6 +1289,12 @@ dependencies = [ "serde", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -1723,6 +1755,37 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "time" +version = "0.3.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e7d9e3bb61134e77bde20dd4825b97c010155709965fedf0f49bb138e52a9d" +dependencies = [ + "deranged", + "itoa", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40868e7c1d2f0b8d73e4a8c7f0ff63af4f6d19be117e90bd73eb1d62cf831c6b" + +[[package]] +name = "time-macros" +version = "0.2.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30cfb0125f12d9c277f35663a0a33f8c30190f4e4574868a330595412d34ebf3" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tokio" version = "1.47.1" diff --git a/rust/loro-websocket-server/Cargo.toml b/rust/loro-websocket-server/Cargo.toml index 297c313..6f60c92 100644 --- a/rust/loro-websocket-server/Cargo.toml +++ b/rust/loro-websocket-server/Cargo.toml @@ -20,6 +20,7 @@ tokio-tungstenite = "0.27" futures-util = { version = "0.3", default-features = false, features = ["sink"] } loro = "1" tracing = "0.1" +cookie = "0.18.1" [dev-dependencies] loro-websocket-client = { version = "0.1.0", path = "../loro-websocket-client" } diff --git a/rust/loro-websocket-server/src/lib.rs b/rust/loro-websocket-server/src/lib.rs index a56aff7..00be5ba 100644 --- a/rust/loro-websocket-server/src/lib.rs +++ b/rust/loro-websocket-server/src/lib.rs @@ -104,11 +104,28 @@ type LoadFuture = type SaveFuture = Pin> + Send + 'static>>; type LoadFn = Arc LoadFuture + Send + Sync>; type SaveFn = Arc) -> SaveFuture + Send + Sync>; + +/// Arguments provided to `authenticate`. +pub struct AuthArgs { + pub room: String, + pub crdt: CrdtType, + pub auth: Vec, + pub conn_id: u64, +} + type AuthFuture = Pin, String>> + Send + 'static>>; -type AuthFn = Arc) -> AuthFuture + Send + Sync>; +type AuthFn = Arc AuthFuture + Send + Sync>; + +/// Arguments provided to `handshake_auth`. +pub struct HandshakeAuthArgs<'a> { + pub workspace: &'a str, + pub token: Option<&'a str>, + pub request: &'a tungstenite::handshake::server::Request, + pub conn_id: u64, +} -type HandshakeAuthFn = dyn Fn(&str, Option<&str>) -> bool + Send + Sync; +type HandshakeAuthFn = dyn Fn(HandshakeAuthArgs) -> bool + Send + Sync; #[derive(Clone)] pub struct ServerConfig { @@ -122,6 +139,8 @@ pub struct ServerConfig { /// Parameters: /// - `workspace_id`: extracted from request path `/{workspace}` (empty if missing) /// - `token`: `token` query parameter if present + /// - `request`: the full HTTP request (headers, uri, etc) + /// - `conn_id`: the connection id /// /// Return true to accept, false to reject with 401. pub handshake_auth: Option>, @@ -884,12 +903,17 @@ async fn handle_conn( where DocCtx: Clone + Send + Sync + 'static, { + + // Generate a connection id + let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed); + // Capture config outside of non-async closure let handshake_auth = registry.config.handshake_auth.clone(); let workspace_holder: Arc>> = Arc::new(std::sync::Mutex::new(None)); let workspace_holder_c = workspace_holder.clone(); + let ws = accept_hdr_async( stream, move |req: &tungstenite::handshake::server::Request, @@ -925,7 +949,12 @@ where None }); - let allowed = (check)(workspace_id, token); + let allowed = (check)(HandshakeAuthArgs { + workspace: workspace_id, + token, + request: req, + conn_id, + }); if !allowed { warn!(workspace=%workspace_id, token=?token, "handshake auth denied"); // Build a 401 Unauthorized response @@ -971,7 +1000,6 @@ where } }); - let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed); let mut joined_rooms: HashSet = HashSet::new(); while let Some(msg) = stream.next().await { @@ -1001,7 +1029,14 @@ where let mut permission = h.config.default_permission; if let Some(auth_fn) = &h.config.authenticate { let room_str = room.room.clone(); - match (auth_fn)(room_str, room.crdt, auth.clone()).await { + match (auth_fn)(AuthArgs { + room: room_str, + crdt: room.crdt, + auth: auth.clone(), + conn_id, + }) + .await + { Ok(Some(p)) => { permission = p; } diff --git a/rust/loro-websocket-server/tests/e2e.rs b/rust/loro-websocket-server/tests/e2e.rs index d230f60..4646204 100644 --- a/rust/loro-websocket-server/tests/e2e.rs +++ b/rust/loro-websocket-server/tests/e2e.rs @@ -14,7 +14,7 @@ async fn e2e_sync_two_clients_docupdate_roundtrip() { let addr = listener.local_addr().unwrap(); let server_task = tokio::spawn(async move { let cfg: Cfg = server::ServerConfig { - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; server::serve_incoming_with_config(listener, cfg) @@ -65,7 +65,7 @@ async fn workspaces_are_isolated() { let addr = listener.local_addr().unwrap(); let server_task = tokio::spawn(async move { let cfg: Cfg = server::ServerConfig { - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; server::serve_incoming_with_config(listener, cfg) @@ -104,7 +104,7 @@ async fn e2e_sync_two_clients_loro_adaptor_roundtrip() { let addr = listener.local_addr().unwrap(); let server_task = tokio::spawn(async move { let cfg: Cfg = server::ServerConfig { - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; server::serve_incoming_with_config(listener, cfg) @@ -154,7 +154,7 @@ async fn e2e_sync_two_clients_elo_adaptor_roundtrip() { let addr = listener.local_addr().unwrap(); let server_task = tokio::spawn(async move { let cfg: Cfg = server::ServerConfig { - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; server::serve_incoming_with_config(listener, cfg) diff --git a/rust/loro-websocket-server/tests/elo_accept_broadcast.rs b/rust/loro-websocket-server/tests/elo_accept_broadcast.rs index 7e2d3b8..c0605fa 100644 --- a/rust/loro-websocket-server/tests/elo_accept_broadcast.rs +++ b/rust/loro-websocket-server/tests/elo_accept_broadcast.rs @@ -11,7 +11,7 @@ async fn elo_accepts_join_and_broadcasts_updates() { let addr = listener.local_addr().unwrap(); let server_task = tokio::spawn(async move { let cfg: server::ServerConfig<()> = server::ServerConfig { - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; server::serve_incoming_with_config(listener, cfg) diff --git a/rust/loro-websocket-server/tests/elo_fragment_reassembly.rs b/rust/loro-websocket-server/tests/elo_fragment_reassembly.rs index a9ef6a4..e86c21d 100644 --- a/rust/loro-websocket-server/tests/elo_fragment_reassembly.rs +++ b/rust/loro-websocket-server/tests/elo_fragment_reassembly.rs @@ -19,7 +19,7 @@ async fn elo_fragment_reassembly_broadcasts_original_frames() { let addr = listener.local_addr().unwrap(); let server_task = tokio::spawn(async move { let cfg: server::ServerConfig<()> = server::ServerConfig { - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; server::serve_incoming_with_config(listener, cfg) diff --git a/rust/loro-websocket-server/tests/handshake_auth.rs b/rust/loro-websocket-server/tests/handshake_auth.rs index 6206180..fde5939 100644 --- a/rust/loro-websocket-server/tests/handshake_auth.rs +++ b/rust/loro-websocket-server/tests/handshake_auth.rs @@ -9,7 +9,7 @@ async fn handshake_rejects_invalid_token_with_401() { let addr = listener.local_addr().unwrap(); let server_task = tokio::spawn(async move { let cfg: server::ServerConfig<()> = server::ServerConfig { - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; server::serve_incoming_with_config(listener, cfg) diff --git a/rust/loro-websocket-server/tests/handshake_cookies.rs b/rust/loro-websocket-server/tests/handshake_cookies.rs new file mode 100644 index 0000000..a66af00 --- /dev/null +++ b/rust/loro-websocket-server/tests/handshake_cookies.rs @@ -0,0 +1,77 @@ +use loro_websocket_server as server; +use std::sync::Arc; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::http::HeaderValue; + +#[tokio::test(flavor = "current_thread")] +async fn handshake_auth_can_read_cookies() { + // Start server requiring cookie "session=valid" + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server_task = tokio::spawn(async move { + let cfg: server::ServerConfig<()> = server::ServerConfig { + handshake_auth: Some(Arc::new(|args| { + if let Some(header) = args.request.headers().get("Cookie") { + if let Ok(s) = header.to_str() { + for cookie in cookie::Cookie::split_parse(s) { + if let Ok(c) = cookie { + if c.name() == "session" && c.value() == "valid" { + return true; + } + } + } + } + } + false + })), + ..Default::default() + }; + server::serve_incoming_with_config(listener, cfg) + .await + .unwrap(); + }); + + let url = format!("ws://{}/ws1", addr); + + // 1. Test valid cookie + { + let mut req = url.clone().into_client_request().unwrap(); + req.headers_mut().insert( + "Cookie", + HeaderValue::from_static("session=valid; other=stuff"), + ); + match tokio_tungstenite::connect_async(req).await { + Ok(_) => {} // success + Err(e) => panic!("valid cookie should be accepted: {}", e), + } + } + + // 2. Test missing cookie + { + let req = url.clone().into_client_request().unwrap(); + // no cookie header + match tokio_tungstenite::connect_async(req).await { + Ok(_) => panic!("missing cookie should be rejected"), + Err(tokio_tungstenite::tungstenite::Error::Http(resp)) => { + assert_eq!(resp.status(), 401); + } + Err(e) => panic!("unexpected error for missing cookie: {}", e), + } + } + + // 3. Test invalid cookie value + { + let mut req = url.clone().into_client_request().unwrap(); + req.headers_mut() + .insert("Cookie", HeaderValue::from_static("session=invalid")); + match tokio_tungstenite::connect_async(req).await { + Ok(_) => panic!("invalid cookie should be rejected"), + Err(tokio_tungstenite::tungstenite::Error::Http(resp)) => { + assert_eq!(resp.status(), 401); + } + Err(e) => panic!("unexpected error for invalid cookie: {}", e), + } + } + + server_task.abort(); +} diff --git a/rust/loro-websocket-server/tests/join_denied.rs b/rust/loro-websocket-server/tests/join_denied.rs index 34450a3..bab1e09 100644 --- a/rust/loro-websocket-server/tests/join_denied.rs +++ b/rust/loro-websocket-server/tests/join_denied.rs @@ -11,9 +11,9 @@ async fn join_denied_returns_error() { // Server with auth that always denies let cfg: server::ServerConfig<()> = server::ServerConfig { - authenticate: Some(Arc::new(|_room, _crdt, _auth| Box::pin(async { Ok(None) }))), + authenticate: Some(Arc::new(|_args| Box::pin(async { Ok(None) }))), default_permission: Permission::Write, - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; let server_task = tokio::spawn(async move { diff --git a/rust/loro-websocket-server/tests/join_snapshot_load.rs b/rust/loro-websocket-server/tests/join_snapshot_load.rs index 7a2ac79..163648b 100644 --- a/rust/loro-websocket-server/tests/join_snapshot_load.rs +++ b/rust/loro-websocket-server/tests/join_snapshot_load.rs @@ -24,7 +24,7 @@ async fn join_sends_snapshot_from_loader() { }) }) })), - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; let server_task = tokio::spawn(async move { diff --git a/rust/loro-websocket-server/tests/readonly_receive.rs b/rust/loro-websocket-server/tests/readonly_receive.rs index 62106a2..e03789e 100644 --- a/rust/loro-websocket-server/tests/readonly_receive.rs +++ b/rust/loro-websocket-server/tests/readonly_receive.rs @@ -12,11 +12,11 @@ async fn readonly_receives_updates_writer_sends() { let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); let cfg: server::ServerConfig<()> = server::ServerConfig { - authenticate: Some(Arc::new(|_room, _crdt, auth| { + authenticate: Some(Arc::new(|args| { Box::pin(async move { - if auth == b"writer" { + if args.auth == b"writer" { Ok(Some(Permission::Write)) - } else if auth == b"reader" { + } else if args.auth == b"reader" { Ok(Some(Permission::Read)) } else { Ok(None) @@ -24,7 +24,7 @@ async fn readonly_receives_updates_writer_sends() { }) })), default_permission: Permission::Write, - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; let server_task = tokio::spawn(async move { diff --git a/rust/loro-websocket-server/tests/reject_update_without_join.rs b/rust/loro-websocket-server/tests/reject_update_without_join.rs index 852322f..1bdf417 100644 --- a/rust/loro-websocket-server/tests/reject_update_without_join.rs +++ b/rust/loro-websocket-server/tests/reject_update_without_join.rs @@ -9,7 +9,7 @@ async fn reject_update_without_join() { let addr = listener.local_addr().unwrap(); let server_task = tokio::spawn(async move { let cfg: server::ServerConfig<()> = server::ServerConfig { - handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))), + handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))), ..Default::default() }; server::serve_incoming_with_config(listener, cfg)