Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion packages/loro-websocket/src/server/simple-server.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { WebSocketServer, WebSocket } from "ws";
import { randomBytes } from "node:crypto";
import type { RawData } from "ws";
import type { IncomingMessage } from "http";
// no direct CRDT imports here; handled by CrdtDoc implementations
import {
encode,
Expand Down Expand Up @@ -47,6 +48,13 @@ export interface SimpleServerConfig {
crdtType: CrdtType,
auth: Uint8Array
) => Promise<Permission | null>;
/**
* Optional handshake auth: called during WS HTTP upgrade.
* Return true to accept, false to reject.
*/
handshakeAuth?: (
req: IncomingMessage
) => boolean | Promise<boolean>;
}

interface RoomDocument {
Expand Down Expand Up @@ -86,12 +94,28 @@ export class SimpleServer {

start(): Promise<void> {
return new Promise(resolve => {
const options: { port: number; host?: string } = {
const options: { port: number; host?: string; verifyClient?: any } = {
port: this.config.port,
};
if (this.config.host) {
options.host = this.config.host;
}
if (this.config.handshakeAuth) {
options.verifyClient = (
info: { origin: string; secure: boolean; req: IncomingMessage },
cb: (res: boolean, code?: number, message?: string) => void
) => {
Promise.resolve(this.config.handshakeAuth!(info.req))
.then(allowed => {
if (allowed) cb(true);
else cb(false, 401, "Unauthorized");
})
.catch(err => {
console.error("Handshake auth error", err);
cb(false, 500, "Internal Server Error");
});
};
}
this.wss = new WebSocketServer(options);

this.wss.on("connection", ws => {
Expand Down
70 changes: 70 additions & 0 deletions packages/loro-websocket/tests/handshake-auth.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import { describe, it, expect, beforeAll, afterAll } from "vitest";
import { WebSocket } from "ws";
import getPort from "get-port";
import { SimpleServer } from "../src/server/simple-server";

// Make WebSocket available globally for the client
Object.defineProperty(globalThis, "WebSocket", {
value: WebSocket,
configurable: true,
writable: true,
});

describe("Handshake Auth", () => {
let server: SimpleServer;
let port: number;

beforeAll(async () => {
port = await getPort();
server = new SimpleServer({
port,
handshakeAuth: req => {
const cookie = req.headers.cookie;
return cookie === "session=valid";
},
});
await server.start();
});

afterAll(async () => {
await server.stop();
}, 10000);

it("should accept connection with valid cookie", async () => {
const ws = new WebSocket(`ws://localhost:${port}`, {
headers: {
Cookie: "session=valid",
},
});

await new Promise<void>((resolve, reject) => {
ws.onopen = () => resolve();
ws.onerror = err => reject(err);
});
ws.close();
});

it("should reject connection with invalid cookie", async () => {
const ws = new WebSocket(`ws://localhost:${port}`, {
headers: {
Cookie: "session=invalid",
},
});

await new Promise<void>((resolve, reject) => {
ws.onopen = () => reject(new Error("Should have failed"));
ws.onerror = err => {
resolve();
};
});
});

it("should reject connection with missing cookie", async () => {
const ws = new WebSocket(`ws://localhost:${port}`);

await new Promise<void>((resolve, reject) => {
ws.onopen = () => reject(new Error("Should have failed"));
ws.onerror = () => resolve();
});
});
});
63 changes: 63 additions & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions rust/loro-websocket-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ tokio-tungstenite = "0.27"
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
loro = "1"
tracing = "0.1"
cookie = "0.18.1"

[dev-dependencies]
loro-websocket-client = { version = "0.1.0", path = "../loro-websocket-client" }
Expand Down
45 changes: 40 additions & 5 deletions rust/loro-websocket-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,28 @@ type LoadFuture<DocCtx> =
type SaveFuture = Pin<Box<dyn Future<Output = Result<(), String>> + Send + 'static>>;
type LoadFn<DocCtx> = Arc<dyn Fn(LoadDocArgs) -> LoadFuture<DocCtx> + Send + Sync>;
type SaveFn<DocCtx> = Arc<dyn Fn(SaveDocArgs<DocCtx>) -> SaveFuture + Send + Sync>;

/// Arguments provided to `authenticate`.
pub struct AuthArgs {
pub room: String,
pub crdt: CrdtType,
pub auth: Vec<u8>,
pub conn_id: u64,
}

type AuthFuture =
Pin<Box<dyn Future<Output = Result<Option<Permission>, String>> + Send + 'static>>;
type AuthFn = Arc<dyn Fn(String, CrdtType, Vec<u8>) -> AuthFuture + Send + Sync>;
type AuthFn = Arc<dyn Fn(AuthArgs) -> AuthFuture + Send + Sync>;

/// Arguments provided to `handshake_auth`.
pub struct HandshakeAuthArgs<'a> {
pub workspace: &'a str,
pub token: Option<&'a str>,
pub request: &'a tungstenite::handshake::server::Request,
pub conn_id: u64,
}

type HandshakeAuthFn = dyn Fn(&str, Option<&str>) -> bool + Send + Sync;
type HandshakeAuthFn = dyn Fn(HandshakeAuthArgs) -> bool + Send + Sync;

#[derive(Clone)]
pub struct ServerConfig<DocCtx = ()> {
Expand All @@ -122,6 +139,8 @@ pub struct ServerConfig<DocCtx = ()> {
/// Parameters:
/// - `workspace_id`: extracted from request path `/{workspace}` (empty if missing)
/// - `token`: `token` query parameter if present
/// - `request`: the full HTTP request (headers, uri, etc)
/// - `conn_id`: the connection id
///
/// Return true to accept, false to reject with 401.
pub handshake_auth: Option<Arc<HandshakeAuthFn>>,
Expand Down Expand Up @@ -884,12 +903,17 @@ async fn handle_conn<DocCtx>(
where
DocCtx: Clone + Send + Sync + 'static,
{

// Generate a connection id
let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);

// Capture config outside of non-async closure
let handshake_auth = registry.config.handshake_auth.clone();
let workspace_holder: Arc<std::sync::Mutex<Option<String>>> =
Arc::new(std::sync::Mutex::new(None));
let workspace_holder_c = workspace_holder.clone();


let ws = accept_hdr_async(
stream,
move |req: &tungstenite::handshake::server::Request,
Expand Down Expand Up @@ -925,7 +949,12 @@ where
None
});

let allowed = (check)(workspace_id, token);
let allowed = (check)(HandshakeAuthArgs {
workspace: workspace_id,
token,
request: req,
conn_id,
});
if !allowed {
warn!(workspace=%workspace_id, token=?token, "handshake auth denied");
// Build a 401 Unauthorized response
Expand Down Expand Up @@ -971,7 +1000,6 @@ where
}
});

let conn_id = NEXT_ID.fetch_add(1, Ordering::Relaxed);
let mut joined_rooms: HashSet<RoomKey> = HashSet::new();

while let Some(msg) = stream.next().await {
Expand Down Expand Up @@ -1001,7 +1029,14 @@ where
let mut permission = h.config.default_permission;
if let Some(auth_fn) = &h.config.authenticate {
let room_str = room.room.clone();
match (auth_fn)(room_str, room.crdt, auth.clone()).await {
match (auth_fn)(AuthArgs {
room: room_str,
crdt: room.crdt,
auth: auth.clone(),
conn_id,
})
.await
{
Ok(Some(p)) => {
permission = p;
}
Expand Down
8 changes: 4 additions & 4 deletions rust/loro-websocket-server/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async fn e2e_sync_two_clients_docupdate_roundtrip() {
let addr = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let cfg: Cfg = server::ServerConfig {
handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))),
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
..Default::default()
};
server::serve_incoming_with_config(listener, cfg)
Expand Down Expand Up @@ -65,7 +65,7 @@ async fn workspaces_are_isolated() {
let addr = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let cfg: Cfg = server::ServerConfig {
handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))),
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
..Default::default()
};
server::serve_incoming_with_config(listener, cfg)
Expand Down Expand Up @@ -104,7 +104,7 @@ async fn e2e_sync_two_clients_loro_adaptor_roundtrip() {
let addr = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let cfg: Cfg = server::ServerConfig {
handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))),
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
..Default::default()
};
server::serve_incoming_with_config(listener, cfg)
Expand Down Expand Up @@ -154,7 +154,7 @@ async fn e2e_sync_two_clients_elo_adaptor_roundtrip() {
let addr = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let cfg: Cfg = server::ServerConfig {
handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))),
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
..Default::default()
};
server::serve_incoming_with_config(listener, cfg)
Expand Down
2 changes: 1 addition & 1 deletion rust/loro-websocket-server/tests/elo_accept_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ async fn elo_accepts_join_and_broadcasts_updates() {
let addr = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let cfg: server::ServerConfig<()> = server::ServerConfig {
handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))),
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
..Default::default()
};
server::serve_incoming_with_config(listener, cfg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async fn elo_fragment_reassembly_broadcasts_original_frames() {
let addr = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let cfg: server::ServerConfig<()> = server::ServerConfig {
handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))),
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
..Default::default()
};
server::serve_incoming_with_config(listener, cfg)
Expand Down
2 changes: 1 addition & 1 deletion rust/loro-websocket-server/tests/handshake_auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ async fn handshake_rejects_invalid_token_with_401() {
let addr = listener.local_addr().unwrap();
let server_task = tokio::spawn(async move {
let cfg: server::ServerConfig<()> = server::ServerConfig {
handshake_auth: Some(Arc::new(|_ws, token| token == Some("secret"))),
handshake_auth: Some(Arc::new(|args| args.token == Some("secret"))),
..Default::default()
};
server::serve_incoming_with_config(listener, cfg)
Expand Down
Loading