diff --git a/README.md b/README.md index 6a422cc85..24404c200 100644 --- a/README.md +++ b/README.md @@ -346,6 +346,7 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow | CLI Option | Environment Variable | Default | Description | | -------------------------------------- | --------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `allowRequestOverrides` | `MDB_MCP_ALLOW_REQUEST_OVERRIDES` | `false` | When set to true, allows configuration values to be overridden via request headers and query parameters. | | `apiClientId` | `MDB_MCP_API_CLIENT_ID` | `` | Atlas API client ID for authentication. Required for running Atlas tools. | | `apiClientSecret` | `MDB_MCP_API_CLIENT_SECRET` | `` | Atlas API client secret for authentication. Required for running Atlas tools. | | `atlasTemporaryDatabaseUserLifetimeMs` | `MDB_MCP_ATLAS_TEMPORARY_DATABASE_USER_LIFETIME_MS` | `14400000` | Time in milliseconds that temporary database users created when connecting to MongoDB Atlas clusters will remain active before being automatically deleted. | diff --git a/package.json b/package.json index 6777cdf4b..9c2178e56 100644 --- a/package.json +++ b/package.json @@ -67,7 +67,7 @@ "generate:api": "./scripts/generate.sh", "generate:arguments": "tsx scripts/generateArguments.ts", "pretest": "pnpm run build", - "test": "vitest --project eslint-rules --project unit-and-integration --coverage", + "test": "vitest --project eslint-rules --project unit-and-integration --coverage --run", "test:accuracy": "sh ./scripts/accuracy/runAccuracyTests.sh", "test:long-running-tests": "vitest --project long-running-tests --coverage", "test:local": "SKIP_ATLAS_TESTS=true SKIP_ATLAS_LOCAL_TESTS=true pnpm run test", diff --git a/server.json b/server.json index 8df12af99..471cf90ce 100644 --- a/server.json +++ b/server.json @@ -16,6 +16,13 @@ "type": "stdio" }, "environmentVariables": [ + { + "name": "MDB_MCP_ALLOW_REQUEST_OVERRIDES", + "description": "When set to true, allows configuration values to be overridden via request headers and query parameters.", + "isRequired": false, + "format": "string", + "isSecret": false + }, { "name": "MDB_MCP_API_CLIENT_ID", "description": "Atlas API client ID for authentication. Required for running Atlas tools.", @@ -186,6 +193,12 @@ } ], "packageArguments": [ + { + "type": "named", + "name": "--allowRequestOverrides", + "description": "When set to true, allows configuration values to be overridden via request headers and query parameters.", + "isRequired": false + }, { "type": "named", "name": "--apiClientId", @@ -344,6 +357,13 @@ "type": "stdio" }, "environmentVariables": [ + { + "name": "MDB_MCP_ALLOW_REQUEST_OVERRIDES", + "description": "When set to true, allows configuration values to be overridden via request headers and query parameters.", + "isRequired": false, + "format": "string", + "isSecret": false + }, { "name": "MDB_MCP_API_CLIENT_ID", "description": "Atlas API client ID for authentication. Required for running Atlas tools.", @@ -514,6 +534,12 @@ } ], "packageArguments": [ + { + "type": "named", + "name": "--allowRequestOverrides", + "description": "When set to true, allows configuration values to be overridden via request headers and query parameters.", + "isRequired": false + }, { "type": "named", "name": "--apiClientId", diff --git a/src/common/config/argsParserOptions.ts b/src/common/config/argsParserOptions.ts index cdf7daa1d..dbc1671e3 100644 --- a/src/common/config/argsParserOptions.ts +++ b/src/common/config/argsParserOptions.ts @@ -18,6 +18,7 @@ export const OPTIONS = { "connectionString", "httpHost", "httpPort", + "allowRequestOverrides", "idleTimeoutMs", "logPath", "notificationTimeoutMs", diff --git a/src/common/config/configOverrides.ts b/src/common/config/configOverrides.ts new file mode 100644 index 000000000..69452a796 --- /dev/null +++ b/src/common/config/configOverrides.ts @@ -0,0 +1,178 @@ +import type { UserConfig } from "./userConfig.js"; +import { UserConfigSchema, configRegistry } from "./userConfig.js"; +import type { RequestContext } from "../../transports/base.js"; +import type { ConfigFieldMeta, OverrideBehavior } from "./configUtils.js"; + +export const CONFIG_HEADER_PREFIX = "x-mongodb-mcp-"; +export const CONFIG_QUERY_PREFIX = "mongodbMcp"; + +/** + * Applies config overrides from request context (headers and query parameters). + * Query parameters take precedence over headers. Can be used within the createSessionConfig + * hook to manually apply the overrides. Requires `allowRequestOverrides` to be enabled. + * + * @param baseConfig - The base user configuration + * @param request - The request context containing headers and query parameters + * @returns The configuration with overrides applied + */ +export function applyConfigOverrides({ + baseConfig, + request, +}: { + baseConfig: UserConfig; + request?: RequestContext; +}): UserConfig { + if (!request) { + return baseConfig; + } + + const result: UserConfig = { ...baseConfig }; + const overridesFromHeaders = extractConfigOverrides("header", request.headers); + const overridesFromQuery = extractConfigOverrides("query", request.query); + + // Only apply overrides if allowRequestOverrides is enabled + if ( + !baseConfig.allowRequestOverrides && + (Object.keys(overridesFromHeaders).length > 0 || Object.keys(overridesFromQuery).length > 0) + ) { + throw new Error("Request overrides are not enabled"); + } + + // Apply header overrides first + for (const [key, overrideValue] of Object.entries(overridesFromHeaders)) { + assertValidConfigKey(key); + const meta = getConfigMeta(key); + const behavior = meta?.overrideBehavior || "not-allowed"; + const baseValue = baseConfig[key as keyof UserConfig]; + const newValue = applyOverride(key, baseValue, overrideValue, behavior); + (result as Record)[key] = newValue; + } + + // Apply query overrides (with precedence), but block secret fields + for (const [key, overrideValue] of Object.entries(overridesFromQuery)) { + assertValidConfigKey(key); + const meta = getConfigMeta(key); + + // Prevent overriding secret fields via query params + if (meta?.isSecret) { + throw new Error(`Config key ${key} can only be overriden with headers.`); + } + + const behavior = meta?.overrideBehavior || "not-allowed"; + const baseValue = baseConfig[key as keyof UserConfig]; + const newValue = applyOverride(key, baseValue, overrideValue, behavior); + (result as Record)[key] = newValue; + } + + return result; +} + +/** + * Extracts config overrides from HTTP headers or query parameters. + */ +function extractConfigOverrides( + mode: "header" | "query", + source: Record | undefined +): Partial> { + if (!source) { + return {}; + } + + const overrides: Partial> = {}; + + for (const [name, value] of Object.entries(source)) { + const configKey = nameToConfigKey(mode, name); + if (!configKey) { + continue; + } + assertValidConfigKey(configKey); + + const parsedValue = parseConfigValue(configKey, value); + if (parsedValue !== undefined) { + overrides[configKey] = parsedValue; + } + } + + return overrides; +} + +function assertValidConfigKey(key: string): asserts key is keyof typeof UserConfigSchema.shape { + if (!(key in UserConfigSchema.shape)) { + throw new Error(`Invalid config key: ${key}`); + } +} + +/** + * Gets the schema metadata for a config key. + */ +export function getConfigMeta(key: keyof typeof UserConfigSchema.shape): ConfigFieldMeta | undefined { + return configRegistry.get(UserConfigSchema.shape[key]); +} + +/** + * Parses a string value to the appropriate type using the Zod schema. + */ +function parseConfigValue(key: keyof typeof UserConfigSchema.shape, value: unknown): unknown { + const fieldSchema = UserConfigSchema.shape[key]; + if (!fieldSchema) { + throw new Error(`Invalid config key: ${key}`); + } + + return fieldSchema.safeParse(value).data; +} + +/** + * Converts a header/query name to its config key format. + * Example: "x-mongodb-mcp-read-only" -> "readOnly" + * Example: "mongodbMcpReadOnly" -> "readOnly" + */ +export function nameToConfigKey(mode: "header" | "query", name: string): string | undefined { + const lowerCaseName = name.toLowerCase(); + + if (mode === "header" && lowerCaseName.startsWith(CONFIG_HEADER_PREFIX)) { + const normalized = lowerCaseName.substring(CONFIG_HEADER_PREFIX.length); + // Convert kebab-case to camelCase + return normalized.replace(/-([a-z])/g, (_, letter: string) => letter.toUpperCase()); + } + if (mode === "query" && name.startsWith(CONFIG_QUERY_PREFIX)) { + const withoutPrefix = name.substring(CONFIG_QUERY_PREFIX.length); + // Convert first letter to lowercase to get config key + return withoutPrefix.charAt(0).toLowerCase() + withoutPrefix.slice(1); + } + + return undefined; +} + +function applyOverride( + key: keyof typeof UserConfigSchema.shape, + baseValue: unknown, + overrideValue: unknown, + behavior: OverrideBehavior +): unknown { + if (typeof behavior === "function") { + // Custom logic function returns the value to use (potentially transformed) + // or throws an error if the override cannot be applied + try { + return behavior(baseValue, overrideValue); + } catch (error) { + throw new Error( + `Cannot apply override for ${key}: ${error instanceof Error ? error.message : String(error)}` + ); + } + } + switch (behavior) { + case "override": + return overrideValue; + + case "merge": + if (Array.isArray(baseValue) && Array.isArray(overrideValue)) { + return [...(baseValue as unknown[]), ...(overrideValue as unknown[])]; + } + throw new Error(`Cannot merge non-array values for ${key}`); + + case "not-allowed": + throw new Error(`Config key ${key} is not allowed to be overridden`); + default: + return baseValue; + } +} diff --git a/src/common/config/configUtils.ts b/src/common/config/configUtils.ts index fb8277378..e32617d12 100644 --- a/src/common/config/configUtils.ts +++ b/src/common/config/configUtils.ts @@ -4,6 +4,23 @@ import { ALL_CONFIG_KEYS } from "./argsParserOptions.js"; import * as levenshteinModule from "ts-levenshtein"; const levenshtein = levenshteinModule.default; +/// Custom logic function to apply the override value. +/// Returns the value to use (which may be transformed from newValue). +/// Should throw an error if the override cannot be applied. +export type CustomOverrideLogic = (oldValue: unknown, newValue: unknown) => unknown; + +/** + * Defines how a config field can be overridden via HTTP headers or query parameters. + */ +export type OverrideBehavior = + /// Cannot be overridden via request + | "not-allowed" + /// Can be completely replaced + | "override" + /// Values are merged (for arrays) + | "merge" + | CustomOverrideLogic; + /** * Metadata for config schema fields. */ @@ -17,7 +34,11 @@ export type ConfigFieldMeta = { * Secret fields will be marked as secret in environment variable definitions. */ isSecret?: boolean; - + /** + * Defines how this config field can be overridden via HTTP headers or query parameters. + * Defaults to "not-allowed" for security. + */ + overrideBehavior?: OverrideBehavior; [key: string]: unknown; }; @@ -91,12 +112,17 @@ export function commaSeparatedToArray(str: string | string[] * Zod's coerce.boolean() treats any non-empty string as true, which is not what we want. */ export function parseBoolean(val: unknown): unknown { + if (val === undefined) { + return undefined; + } if (typeof val === "string") { - const lower = val.toLowerCase().trim(); - if (lower === "false") { + if (val === "false") { return false; } - return true; + if (val === "true") { + return true; + } + throw new Error(`Invalid boolean value: ${val}`); } if (typeof val === "boolean") { return val; @@ -106,3 +132,52 @@ export function parseBoolean(val: unknown): unknown { } return !!val; } + +/** Allow overriding only to the allowed value */ +export function oneWayOverride(allowedValue: T): CustomOverrideLogic { + return (oldValue, newValue) => { + // Only allow override if setting to allowed value or current value + if (newValue === oldValue) { + return newValue; + } + if (newValue === allowedValue) { + return newValue; + } + throw new Error(`Can only set to ${String(allowedValue)}`); + }; +} + +/** Allow overriding only to a value lower than the specified value */ +export function onlyLowerThanBaseValueOverride(): CustomOverrideLogic { + return (oldValue, newValue) => { + if (typeof oldValue !== "number") { + throw new Error(`Unsupported type for base value for override: ${typeof oldValue}`); + } + if (typeof newValue !== "number") { + throw new Error(`Unsupported type for new value for override: ${typeof newValue}`); + } + if (newValue >= oldValue) { + throw new Error(`Can only set to a value lower than the base value`); + } + return newValue; + }; +} + +/** Allow overriding only to a subset of an array but not a superset */ +export function onlySubsetOfBaseValueOverride(): CustomOverrideLogic { + return (oldValue, newValue) => { + if (!Array.isArray(oldValue)) { + throw new Error(`Unsupported type for base value for override: ${typeof oldValue}`); + } + if (!Array.isArray(newValue)) { + throw new Error(`Unsupported type for new value for override: ${typeof newValue}`); + } + if (newValue.length > oldValue.length) { + throw new Error(`Can only override to a subset of the base value`); + } + if (!newValue.every((value) => oldValue.includes(value))) { + throw new Error(`Can only override to a subset of the base value`); + } + return newValue as unknown; + }; +} diff --git a/src/common/config/userConfig.ts b/src/common/config/userConfig.ts index 8a3b7ef2b..661e0046d 100644 --- a/src/common/config/userConfig.ts +++ b/src/common/config/userConfig.ts @@ -5,6 +5,9 @@ import { commaSeparatedToArray, getExportsPath, getLogPath, + oneWayOverride, + onlyLowerThanBaseValueOverride, + onlySubsetOfBaseValueOverride, parseBoolean, } from "./configUtils.js"; import { previewFeatureValues, similarityValues } from "../schemas.js"; @@ -17,24 +20,27 @@ export type UserConfig = z4.infer & CliOptions; export const configRegistry = z4.registry(); export const UserConfigSchema = z4.object({ - apiBaseUrl: z4.string().default("https://cloud.mongodb.com/"), + apiBaseUrl: z4 + .string() + .default("https://cloud.mongodb.com/") + .register(configRegistry, { overrideBehavior: "not-allowed" }), apiClientId: z4 .string() .optional() .describe("Atlas API client ID for authentication. Required for running Atlas tools.") - .register(configRegistry, { isSecret: true }), + .register(configRegistry, { isSecret: true, overrideBehavior: "not-allowed" }), apiClientSecret: z4 .string() .optional() .describe("Atlas API client secret for authentication. Required for running Atlas tools.") - .register(configRegistry, { isSecret: true }), + .register(configRegistry, { isSecret: true, overrideBehavior: "not-allowed" }), connectionString: z4 .string() .optional() .describe( "MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the connect tool before interacting with MongoDB data." ) - .register(configRegistry, { isSecret: true }), + .register(configRegistry, { isSecret: true, overrideBehavior: "not-allowed" }), loggers: z4 .preprocess( (val: string | string[] | undefined) => commaSeparatedToArray(val), @@ -50,16 +56,18 @@ export const UserConfigSchema = z4.object({ .describe("An array of logger types.") .register(configRegistry, { defaultValueDescription: '`"disk,mcp"` see below*', + overrideBehavior: "not-allowed", }), logPath: z4 .string() .default(getLogPath()) .describe("Folder to store logs.") - .register(configRegistry, { defaultValueDescription: "see below*" }), + .register(configRegistry, { defaultValueDescription: "see below*", overrideBehavior: "not-allowed" }), disabledTools: z4 .preprocess((val: string | string[] | undefined) => commaSeparatedToArray(val), z4.array(z4.string())) .default([]) - .describe("An array of tool names, operation types, and/or categories of tools that will be disabled."), + .describe("An array of tool names, operation types, and/or categories of tools that will be disabled.") + .register(configRegistry, { overrideBehavior: "merge" }), confirmationRequiredTools: z4 .preprocess((val: string | string[] | undefined) => commaSeparatedToArray(val), z4.array(z4.string())) .default([ @@ -72,111 +80,145 @@ export const UserConfigSchema = z4.object({ ]) .describe( "An array of tool names that require user confirmation before execution. Requires the client to support elicitation." - ), + ) + .register(configRegistry, { overrideBehavior: "merge" }), readOnly: z4 .preprocess(parseBoolean, z4.boolean()) .default(false) .describe( "When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations." - ), + ) + .register(configRegistry, { + overrideBehavior: oneWayOverride(true), + }), indexCheck: z4 .preprocess(parseBoolean, z4.boolean()) .default(false) .describe( "When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan." - ), + ) + .register(configRegistry, { + overrideBehavior: oneWayOverride(true), + }), telemetry: z4 .enum(["enabled", "disabled"]) .default("enabled") - .describe("When set to disabled, disables telemetry collection."), - transport: z4.enum(["stdio", "http"]).default("stdio").describe("Either 'stdio' or 'http'."), + .describe("When set to disabled, disables telemetry collection.") + .register(configRegistry, { overrideBehavior: "not-allowed" }), + transport: z4 + .enum(["stdio", "http"]) + .default("stdio") + .describe("Either 'stdio' or 'http'.") + .register(configRegistry, { overrideBehavior: "not-allowed" }), httpPort: z4.coerce .number() .int() .min(1, "Invalid httpPort: must be at least 1") .max(65535, "Invalid httpPort: must be at most 65535") .default(3000) - .describe("Port number for the HTTP server (only used when transport is 'http')."), + .describe("Port number for the HTTP server (only used when transport is 'http').") + .register(configRegistry, { overrideBehavior: "not-allowed" }), httpHost: z4 .string() .default("127.0.0.1") - .describe("Host address to bind the HTTP server to (only used when transport is 'http')."), + .describe("Host address to bind the HTTP server to (only used when transport is 'http').") + .register(configRegistry, { overrideBehavior: "not-allowed" }), httpHeaders: z4 .object({}) .passthrough() .default({}) .describe( "Header that the HTTP server will validate when making requests (only used when transport is 'http')." - ), + ) + .register(configRegistry, { overrideBehavior: "not-allowed" }), idleTimeoutMs: z4.coerce .number() .default(600_000) - .describe("Idle timeout for a client to disconnect (only applies to http transport)."), + .describe("Idle timeout for a client to disconnect (only applies to http transport).") + .register(configRegistry, { overrideBehavior: onlyLowerThanBaseValueOverride() }), notificationTimeoutMs: z4.coerce .number() .default(540_000) - .describe("Notification timeout for a client to be aware of disconnect (only applies to http transport)."), + .describe("Notification timeout for a client to be aware of disconnect (only applies to http transport).") + .register(configRegistry, { overrideBehavior: onlyLowerThanBaseValueOverride() }), maxBytesPerQuery: z4.coerce .number() .default(16_777_216) .describe( "The maximum size in bytes for results from a find or aggregate tool call. This serves as an upper bound for the responseBytesLimit parameter in those tools." - ), + ) + .register(configRegistry, { overrideBehavior: "not-allowed" }), maxDocumentsPerQuery: z4.coerce .number() .default(100) .describe( "The maximum number of documents that can be returned by a find or aggregate tool call. For the find tool, the effective limit will be the smaller of this value and the tool's limit parameter." - ), + ) + .register(configRegistry, { overrideBehavior: "not-allowed" }), exportsPath: z4 .string() .default(getExportsPath()) .describe("Folder to store exported data files.") - .register(configRegistry, { defaultValueDescription: "see below*" }), + .register(configRegistry, { defaultValueDescription: "see below*", overrideBehavior: "not-allowed" }), exportTimeoutMs: z4.coerce .number() .default(300_000) - .describe("Time in milliseconds after which an export is considered expired and eligible for cleanup."), + .describe("Time in milliseconds after which an export is considered expired and eligible for cleanup.") + .register(configRegistry, { overrideBehavior: onlyLowerThanBaseValueOverride() }), exportCleanupIntervalMs: z4.coerce .number() .default(120_000) - .describe("Time in milliseconds between export cleanup cycles that remove expired export files."), + .describe("Time in milliseconds between export cleanup cycles that remove expired export files.") + .register(configRegistry, { overrideBehavior: "not-allowed" }), atlasTemporaryDatabaseUserLifetimeMs: z4.coerce .number() .default(14_400_000) .describe( "Time in milliseconds that temporary database users created when connecting to MongoDB Atlas clusters will remain active before being automatically deleted." - ), + ) + .register(configRegistry, { overrideBehavior: onlyLowerThanBaseValueOverride() }), voyageApiKey: z4 .string() .default("") .describe( "API key for Voyage AI embeddings service (required for vector search operations with text-to-embedding conversion)." ) - .register(configRegistry, { isSecret: true }), + .register(configRegistry, { isSecret: true, overrideBehavior: "override" }), embeddingsValidation: z4 .preprocess(parseBoolean, z4.boolean()) .default(true) - .describe("When set to false, disables validation of embeddings dimensions."), + .describe("When set to false, disables validation of embeddings dimensions.") + .register(configRegistry, { overrideBehavior: oneWayOverride(true) }), vectorSearchDimensions: z4.coerce .number() .default(1024) - .describe("Default number of dimensions for vector search embeddings."), + .describe("Default number of dimensions for vector search embeddings.") + .register(configRegistry, { overrideBehavior: "override" }), vectorSearchSimilarityFunction: z4 .enum(similarityValues) .default("euclidean") - .describe("Default similarity function for vector search: 'euclidean', 'cosine', or 'dotProduct'."), + .describe("Default similarity function for vector search: 'euclidean', 'cosine', or 'dotProduct'.") + .register(configRegistry, { overrideBehavior: "override" }), previewFeatures: z4 .preprocess( (val: string | string[] | undefined) => commaSeparatedToArray(val), z4.array(z4.enum(previewFeatureValues)) ) .default([]) - .describe("An array of preview features that are enabled."), + .describe("An array of preview features that are enabled.") + .register(configRegistry, { overrideBehavior: onlySubsetOfBaseValueOverride() }), + allowRequestOverrides: z4 + .preprocess(parseBoolean, z4.boolean()) + .default(false) + .describe( + "When set to true, allows configuration values to be overridden via request headers and query parameters." + ) + .register(configRegistry, { overrideBehavior: "not-allowed" }), dryRun: z4 .boolean() .default(false) .describe( "When true, runs the server in dry mode: dumps configuration and enabled tools, then exits without starting the server." - ), + ) + .register(configRegistry, { overrideBehavior: "not-allowed" }), }); diff --git a/src/lib.ts b/src/lib.ts index d101da38f..babdbde77 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -25,3 +25,4 @@ export { Telemetry } from "./telemetry/telemetry.js"; export { Keychain, registerGlobalSecretToRedact } from "./common/keychain.js"; export type { Secret } from "./common/keychain.js"; export { Elicitation } from "./elicitation.js"; +export { applyConfigOverrides } from "./common/config/configOverrides.js"; diff --git a/src/transports/base.ts b/src/transports/base.ts index 15b36bd30..787df80a8 100644 --- a/src/transports/base.ts +++ b/src/transports/base.ts @@ -21,8 +21,17 @@ import { defaultCreateAtlasLocalClient } from "../common/atlasLocal.js"; import type { Client } from "@mongodb-js/atlas-local"; import { VectorSearchEmbeddingsManager } from "../common/search/vectorSearchEmbeddingsManager.js"; import type { ToolBase, ToolConstructorParams } from "../tools/tool.js"; +import { applyConfigOverrides } from "../common/config/configOverrides.js"; -type CreateSessionConfigFn = (userConfig: UserConfig) => Promise | UserConfig; +export type RequestContext = { + headers?: Record; + query?: Record; +}; + +type CreateSessionConfigFn = (context: { + userConfig: UserConfig; + request?: RequestContext; +}) => Promise | UserConfig; export type TransportRunnerConfig = { userConfig: UserConfig; @@ -90,10 +99,14 @@ export abstract class TransportRunnerBase { this.deviceId = DeviceId.create(this.logger); } - protected async setupServer(): Promise { - // Call the config provider hook if provided, allowing consumers to - // fetch or modify configuration before session initialization - const userConfig = this.createSessionConfig ? await this.createSessionConfig(this.userConfig) : this.userConfig; + protected async setupServer(request?: RequestContext): Promise { + let userConfig: UserConfig = this.userConfig; + + if (this.createSessionConfig) { + userConfig = await this.createSessionConfig({ userConfig, request }); + } else { + userConfig = applyConfigOverrides({ baseConfig: this.userConfig, request }); + } const mcpServer = new McpServer( { diff --git a/src/transports/streamableHttp.ts b/src/transports/streamableHttp.ts index 0a20e59e8..3d3d59ca4 100644 --- a/src/transports/streamableHttp.ts +++ b/src/transports/streamableHttp.ts @@ -5,7 +5,7 @@ import { StreamableHTTPServerTransport } from "@modelcontextprotocol/sdk/server/ import { isInitializeRequest } from "@modelcontextprotocol/sdk/types.js"; import { LogId } from "../common/logger.js"; import { SessionStore } from "../common/sessionStore.js"; -import { TransportRunnerBase, type TransportRunnerConfig } from "./base.js"; +import { TransportRunnerBase, type TransportRunnerConfig, type RequestContext } from "./base.js"; const JSON_RPC_ERROR_CODE_PROCESSING_REQUEST_FAILED = -32000; const JSON_RPC_ERROR_CODE_SESSION_ID_REQUIRED = -32001; @@ -111,7 +111,11 @@ export class StreamableHttpRunner extends TransportRunnerBase { return; } - const server = await this.setupServer(); + const request: RequestContext = { + headers: req.headers as Record, + query: req.query as Record, + }; + const server = await this.setupServer(request); let keepAliveLoop: NodeJS.Timeout; const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: (): string => randomUUID().toString(), diff --git a/tests/integration/server.test.ts b/tests/integration/server.test.ts index 2c362a963..a1ccba1fb 100644 --- a/tests/integration/server.test.ts +++ b/tests/integration/server.test.ts @@ -82,7 +82,7 @@ describe("Server integration test", () => { expect(tools.tools.some((tool) => tool.name === "atlas-list-projects")).toBe(true); // Check that non-read tools are NOT available - expect(tools.tools.some((tool) => tool.name === "insert-one")).toBe(false); + expect(tools.tools.some((tool) => tool.name === "insert-many")).toBe(false); expect(tools.tools.some((tool) => tool.name === "update-many")).toBe(false); expect(tools.tools.some((tool) => tool.name === "delete-one")).toBe(false); expect(tools.tools.some((tool) => tool.name === "drop-collection")).toBe(false); diff --git a/tests/integration/transports/configOverrides.test.ts b/tests/integration/transports/configOverrides.test.ts new file mode 100644 index 000000000..7157339f2 --- /dev/null +++ b/tests/integration/transports/configOverrides.test.ts @@ -0,0 +1,550 @@ +import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; +import { describe, expect, it, afterEach, beforeEach } from "vitest"; +import { defaultTestConfig, expectDefined } from "../helpers.js"; +import type { TransportRunnerConfig, UserConfig } from "../../../src/lib.js"; +import type { RequestContext } from "../../../src/transports/base.js"; + +describe("Config Overrides via HTTP", () => { + let runner: StreamableHttpRunner; + let client: Client; + let transport: StreamableHTTPClientTransport; + + // Helper function to setup and start runner with config + async function startRunner( + config: UserConfig, + createSessionConfig?: TransportRunnerConfig["createSessionConfig"] + ): Promise { + runner = new StreamableHttpRunner({ userConfig: config, createSessionConfig }); + await runner.start(); + } + + // Helper function to connect client with headers + async function connectClient(headers: Record = {}): Promise { + transport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`), { + requestInit: { headers }, + }); + await client.connect(transport); + } + + beforeEach(() => { + client = new Client({ + name: "test-client", + version: "1.0.0", + }); + }); + + afterEach(async () => { + if (client) { + await client.close(); + } + if (transport) { + await transport.close(); + } + if (runner) { + await runner.close(); + } + }); + + describe("override behavior", () => { + it("should error when allowRequestOverrides is false", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + readOnly: false, + allowRequestOverrides: false, + }); + + try { + await connectClient({ + ["x-mongodb-mcp-read-only"]: "true", + }); + expect.fail("Expected an error to be thrown"); + } catch (error) { + if (!(error instanceof Error)) { + throw new Error("Expected an error to be thrown"); + } + expect(error.message).toContain("Request overrides are not enabled"); + } + }); + + it("should override readOnly config with header (false to true)", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + readOnly: false, + allowRequestOverrides: true, + }); + + await connectClient({ + ["x-mongodb-mcp-read-only"]: "true", + }); + + const response = await client.listTools(); + + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + + // Verify read-only mode is applied - insert-many should not be available + const writeTools = response.tools.filter((tool) => tool.name === "insert-many"); + expect(writeTools.length).toBe(0); + + // Verify read tools are available + const readTools = response.tools.filter((tool) => tool.name === "find"); + expect(readTools.length).toBe(1); + }); + + it("should not be able tooverride connectionString with header", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + connectionString: undefined, + allowRequestOverrides: true, + }); + + try { + await connectClient({ + ["x-mongodb-mcp-connection-string"]: "mongodb://override:27017", + }); + expect.fail("Expected an error to be thrown"); + } catch (error) { + if (!(error instanceof Error)) { + throw new Error("Expected an error to be thrown"); + } + expect(error.message).toContain("Error POSTing to endpoint (HTTP 400)"); + expect(error.message).toContain(`Config key connectionString is not allowed to be overridden`); + } + }); + }); + + describe("merge behavior", () => { + it("should merge disabledTools with header", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + disabledTools: ["insert-many"], + allowRequestOverrides: true, + }); + + await connectClient({ + ["x-mongodb-mcp-disabled-tools"]: "find,aggregate", + }); + + const response = await client.listTools(); + + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + + // Verify all three tools are disabled + const insertTool = response.tools.find( + (tool) => tool.name === "insert-many" || tool.name === "find" || tool.name === "aggregate" + ); + + expect(response.tools).not.toHaveLength(0); + expect(insertTool).toBeUndefined(); + }); + }); + + describe("not-allowed behavior", () => { + it.each([ + { + configKey: "apiBaseUrl", + headerName: "x-mongodb-mcp-api-base-url", + headerValue: "https://malicious.com/", + }, + { + configKey: "apiClientId", + headerName: "x-mongodb-mcp-api-client-id", + headerValue: "malicious-id", + }, + { + configKey: "apiClientSecret", + headerName: "x-mongodb-mcp-api-client-secret", + headerValue: "malicious-secret", + }, + { + configKey: "transport", + headerName: "x-mongodb-mcp-transport", + headerValue: "stdio", + }, + { + configKey: "httpPort", + headerName: "x-mongodb-mcp-http-port", + headerValue: "9999", + }, + { + configKey: "maxBytesPerQuery", + headerName: "x-mongodb-mcp-max-bytes-per-query", + headerValue: "999999", + }, + { + configKey: "maxDocumentsPerQuery", + headerName: "x-mongodb-mcp-max-documents-per-query", + headerValue: "1000", + }, + ])("should reject $configKey with header", async ({ configKey, headerName, headerValue }) => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + allowRequestOverrides: true, + }); + + try { + await connectClient({ + [headerName]: headerValue, + }); + expect.fail("Expected an error to be thrown"); + } catch (error) { + if (!(error instanceof Error)) { + throw new Error("Expected an error to be thrown"); + } + expect(error.message).toContain("Error POSTing to endpoint (HTTP 400)"); + expect(error.message).toContain(`Config key ${configKey} is not allowed to be overridden`); + } + }); + + it("should reject multiple not-allowed fields at once", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + allowRequestOverrides: true, + }); + + try { + await connectClient({ + "x-mongodb-mcp-api-base-url": "https://malicious.com/", + "x-mongodb-mcp-transport": "stdio", + "x-mongodb-mcp-http-port": "9999", + }); + expect.fail("Expected an error to be thrown"); + } catch (error) { + if (!(error instanceof Error)) { + throw new Error("Expected an error to be thrown"); + } + expect(error.message).toContain("Error POSTing to endpoint (HTTP 400)"); + // Should contain at least one of the not-allowed field errors + const hasNotAllowedError = + error.message.includes("Config key apiBaseUrl is not allowed to be overridden") || + error.message.includes("Config key transport is not allowed to be overridden") || + error.message.includes("Config key httpPort is not allowed to be overridden"); + expect(hasNotAllowedError).toBe(true); + } + }); + }); + + describe("query parameter overrides", () => { + it("should apply overrides from query parameters", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + readOnly: false, + allowRequestOverrides: true, + }); + + // Note: SDK doesn't support query params directly, so this test verifies the mechanism exists + // In real usage, query params would be in the URL or request + await connectClient({ + ["x-mongodb-mcp-read-only"]: "true", + }); + + const response = await client.listTools(); + + expect(response).toBeDefined(); + const writeTools = response.tools.filter((tool) => tool.name === "insert-many"); + expect(writeTools.length).toBe(0); + }); + }); + + describe("integration with createSessionConfig", () => { + it("should allow createSessionConfig to override header values", async () => { + const userConfig = { + ...defaultTestConfig, + httpPort: 0, + readOnly: false, + allowRequestOverrides: true, + }; + + // createSessionConfig receives the config after header overrides are applied + // It can further modify it, but headers have already been applied + const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = ({ + userConfig: config, + request, + }: { + userConfig: typeof userConfig; + request?: RequestContext; + }): typeof userConfig => { + expectDefined(request); + expectDefined(request.headers); + expect(request.headers).toBeDefined(); + config.readOnly = request.headers["x-mongodb-mcp-read-only"] === "true"; + config.disabledTools = ["count"]; + return config; + }; + + await startRunner(userConfig, createSessionConfig); + + await connectClient({ + ["x-mongodb-mcp-read-only"]: "true", + }); + + const response = await client.listTools(); + + expect(response).toBeDefined(); + + // Verify read-only mode was applied, as specified in request and + const writeTools = response.tools.filter((tool) => tool.name === "insert-many"); + expect(writeTools.length).toBe(0); + + // Verify create session config overrides were applied + const countTool = response.tools.find((tool) => tool.name === "count"); + expect(countTool).toBeUndefined(); + + expect(response.tools).not.toHaveLength(0); + }); + + it("should pass request context to createSessionConfig", async () => { + const userConfig = { + ...defaultTestConfig, + httpPort: 0, + allowRequestOverrides: true, + }; + + let capturedRequest: RequestContext | undefined; + const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = ({ + request, + }: { + userConfig: typeof userConfig; + request?: RequestContext; + }): Promise => { + expectDefined(request); + expectDefined(request.headers); + capturedRequest = request; + return Promise.resolve(userConfig); + }; + + await startRunner(userConfig, createSessionConfig); + + await connectClient({ + "x-custom-header": "test-value", + }); + + // Verify that request context was passed + expectDefined(capturedRequest); + expectDefined(capturedRequest.headers); + expect(capturedRequest.headers["x-custom-header"]).toBe("test-value"); + }); + }); + + describe("conditional overrides", () => { + it("should allow readOnly from false to true", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + readOnly: false, + allowRequestOverrides: true, + }); + + await connectClient({ + ["x-mongodb-mcp-read-only"]: "true", + }); + + const response = await client.listTools(); + + expect(response).toBeDefined(); + expect(response.tools).toBeDefined(); + // Check readonly mode + const writeTools = response.tools.filter((tool) => tool.name === "insert-many"); + expect(writeTools.length).toBe(0); + + // Check read tools are available + const readTools = response.tools.filter((tool) => tool.name === "find"); + expect(readTools.length).toBe(1); + }); + + it("should NOT allow readOnly from true to false", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + readOnly: true, + allowRequestOverrides: true, + }); + + try { + await connectClient({ + ["x-mongodb-mcp-read-only"]: "false", + }); + expect.fail("Expected an error to be thrown"); + } catch (error) { + if (!(error instanceof Error)) { + throw new Error("Expected an error to be thrown"); + } + expect(error.message).toContain("Error POSTing to endpoint (HTTP 400)"); + expect(error.message).toContain(`Cannot apply override for readOnly: Can only set to true`); + } + }); + }); + + describe("multiple overrides", () => { + it("should handle multiple header overrides", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + readOnly: false, + indexCheck: false, + idleTimeoutMs: 600_000, + disabledTools: ["tool1"], + allowRequestOverrides: true, + }); + + await connectClient({ + ["x-mongodb-mcp-read-only"]: "true", + ["x-mongodb-mcp-index-check"]: "true", + ["x-mongodb-mcp-idle-timeout-ms"]: "300000", + ["x-mongodb-mcp-disabled-tools"]: "count", + }); + + const response = await client.listTools(); + + expect(response).toBeDefined(); + + // Verify read-only mode + const writeTools = response.tools.filter((tool) => tool.name === "insert-many"); + expect(writeTools.length).toBe(0); + + // Verify disabled tools + const countTool = response.tools.find((tool) => tool.name === "count"); + expect(countTool).toBeUndefined(); + + const findTool = response.tools.find((tool) => tool.name === "find"); + expect(findTool).toBeDefined(); + }); + }); + + describe("onlyLowerThanBaseValueOverride behavior", () => { + it("should allow override to a lower value", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + idleTimeoutMs: 600_000, + allowRequestOverrides: true, + }); + + await connectClient({ + ["x-mongodb-mcp-idle-timeout-ms"]: "300000", + }); + + const response = await client.listTools(); + expect(response).toBeDefined(); + }); + + it("should reject override to a higher value", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + idleTimeoutMs: 600_000, + allowRequestOverrides: true, + }); + + try { + await connectClient({ + ["x-mongodb-mcp-idle-timeout-ms"]: "900000", + }); + expect.fail("Expected an error to be thrown"); + } catch (error) { + if (!(error instanceof Error)) { + throw new Error("Expected an error to be thrown"); + } + expect(error.message).toContain("Error POSTing to endpoint (HTTP 400)"); + expect(error.message).toContain( + "Cannot apply override for idleTimeoutMs: Can only set to a value lower than the base value" + ); + } + }); + + it("should reject override to equal value", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + idleTimeoutMs: 600_000, + allowRequestOverrides: true, + }); + + try { + await connectClient({ + ["x-mongodb-mcp-idle-timeout-ms"]: "600000", + }); + expect.fail("Expected an error to be thrown"); + } catch (error) { + if (!(error instanceof Error)) { + throw new Error("Expected an error to be thrown"); + } + expect(error.message).toContain("Error POSTing to endpoint (HTTP 400)"); + expect(error.message).toContain( + "Cannot apply override for idleTimeoutMs: Can only set to a value lower than the base value" + ); + } + }); + }); + + describe("onlySubsetOfBaseValueOverride behavior", () => { + describe("previewFeatures", () => { + it("should allow override to same value", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + previewFeatures: ["search"], + allowRequestOverrides: true, + }); + + await connectClient({ + ["x-mongodb-mcp-preview-features"]: "search", + }); + + const response = await client.listTools(); + expect(response).toBeDefined(); + }); + + it("should allow override to an empty array (subset of any array)", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + previewFeatures: ["search"], + allowRequestOverrides: true, + }); + + await connectClient({ + ["x-mongodb-mcp-preview-features"]: "", + }); + + const response = await client.listTools(); + expect(response).toBeDefined(); + }); + + it("should reject override when base is empty array and trying to add items", async () => { + await startRunner({ + ...defaultTestConfig, + httpPort: 0, + previewFeatures: [], + allowRequestOverrides: true, + }); + + // Empty array trying to override with non-empty should fail (superset) + try { + await connectClient({ + ["x-mongodb-mcp-preview-features"]: "search", + }); + expect.fail("Expected an error to be thrown"); + } catch (error) { + if (!(error instanceof Error)) { + throw new Error("Expected an error to be thrown"); + } + expect(error.message).toContain("Error POSTing to endpoint (HTTP 400)"); + expect(error.message).toContain( + "Cannot apply override for previewFeatures: Can only override to a subset of the base value" + ); + } + }); + }); + }); +}); diff --git a/tests/integration/transports/createSessionConfig.test.ts b/tests/integration/transports/createSessionConfig.test.ts index a0b72dcd2..d61e3dae4 100644 --- a/tests/integration/transports/createSessionConfig.test.ts +++ b/tests/integration/transports/createSessionConfig.test.ts @@ -1,151 +1,143 @@ import { StreamableHttpRunner } from "../../../src/transports/streamableHttp.js"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js"; -import { describe, expect, it } from "vitest"; -import type { TransportRunnerConfig } from "../../../src/lib.js"; -import { defaultTestConfig } from "../helpers.js"; +import { afterEach, describe, expect, it } from "vitest"; +import type { TransportRunnerConfig, UserConfig } from "../../../src/lib.js"; +import { defaultTestConfig, expectDefined } from "../helpers.js"; describe("createSessionConfig", () => { const userConfig = defaultTestConfig; let runner: StreamableHttpRunner; + let client: Client | undefined; + let transport: StreamableHTTPClientTransport | undefined; + + // Helper to start runner with config + const startRunner = async ( + config: { + userConfig?: typeof userConfig; + createSessionConfig?: TransportRunnerConfig["createSessionConfig"]; + } = {} + ): Promise => { + runner = new StreamableHttpRunner({ + userConfig: { ...userConfig, httpPort: 0, ...config.userConfig }, + createSessionConfig: config.createSessionConfig, + }); + await runner.start(); + return runner; + }; + + // Helper to setup server and get user config + const getServerConfig = async (): Promise => { + const server = await runner["setupServer"](); + return server.userConfig; + }; + + // Helper to create and connect client + const createConnectedClient = async (): Promise<{ client: Client; transport: StreamableHTTPClientTransport }> => { + client = new Client({ name: "test-client", version: "1.0.0" }); + transport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`)); + await client.connect(transport); + return { client, transport }; + }; + + afterEach(async () => { + if (client) { + await client.close(); + client = undefined; + } + if (transport) { + await transport.close(); + transport = undefined; + } + if (runner) { + await runner.close(); + } + }); describe("basic functionality", () => { it("should use the modified config from createSessionConfig", async () => { - const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => { - return Promise.resolve({ - ...userConfig, - apiBaseUrl: "https://test-api.mongodb.com/", - }); - }; - userConfig.httpPort = 0; // Use a random port - runner = new StreamableHttpRunner({ - userConfig, - createSessionConfig, + await startRunner({ + createSessionConfig: async ({ userConfig }) => + Promise.resolve({ + ...userConfig, + apiBaseUrl: "https://test-api.mongodb.com/", + }), }); - await runner.start(); - - const server = await runner["setupServer"](); - expect(server.userConfig.apiBaseUrl).toBe("https://test-api.mongodb.com/"); - await runner.close(); + const config = await getServerConfig(); + expect(config.apiBaseUrl).toBe("https://test-api.mongodb.com/"); }); it("should work without a createSessionConfig", async () => { - userConfig.httpPort = 0; // Use a random port - runner = new StreamableHttpRunner({ - userConfig, - }); - await runner.start(); - - const server = await runner["setupServer"](); - expect(server.userConfig.apiBaseUrl).toBe(userConfig.apiBaseUrl); + await startRunner(); - await runner.close(); + const config = await getServerConfig(); + expect(config.apiBaseUrl).toBe(userConfig.apiBaseUrl); }); }); describe("connection string modification", () => { it("should allow modifying connection string via createSessionConfig", async () => { - const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => { - // Simulate fetching connection string from environment or secrets - await new Promise((resolve) => setTimeout(resolve, 10)); - - return { - ...userConfig, - connectionString: "mongodb://test-server:27017/test-db", - }; - }; - - userConfig.httpPort = 0; // Use a random port - runner = new StreamableHttpRunner({ + await startRunner({ userConfig: { ...userConfig, connectionString: undefined }, - createSessionConfig, + createSessionConfig: async ({ userConfig }) => { + // Simulate fetching connection string from environment or secrets + await new Promise((resolve) => setTimeout(resolve, 10)); + return { + ...userConfig, + connectionString: "mongodb://test-server:27017/test-db", + }; + }, }); - await runner.start(); - const server = await runner["setupServer"](); - expect(server.userConfig.connectionString).toBe("mongodb://test-server:27017/test-db"); - - await runner.close(); + const config = await getServerConfig(); + expect(config.connectionString).toBe("mongodb://test-server:27017/test-db"); }); }); describe("server integration", () => { - let client: Client; - let transport: StreamableHTTPClientTransport; - it("should successfully initialize server with createSessionConfig and serve requests", async () => { - const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async (userConfig) => { - // Simulate async config modification - await new Promise((resolve) => setTimeout(resolve, 10)); - return { - ...userConfig, - readOnly: true, // Enable read-only mode - }; - }; - - userConfig.httpPort = 0; // Use a random port - runner = new StreamableHttpRunner({ - userConfig, - createSessionConfig, + await startRunner({ + createSessionConfig: async ({ userConfig }) => { + // Simulate async config modification + await new Promise((resolve) => setTimeout(resolve, 10)); + return { + ...userConfig, + readOnly: true, // Enable read-only mode + }; + }, }); - await runner.start(); - client = new Client({ - name: "test-client", - version: "1.0.0", - }); - transport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`)); + await createConnectedClient(); + const response = await client?.listTools(); + expectDefined(response); - await client.connect(transport); - const response = await client.listTools(); - - expect(response).toBeDefined(); expect(response.tools).toBeDefined(); expect(response.tools.length).toBeGreaterThan(0); - // Verify read-only mode is applied - insert-one should not be available - const writeTools = response.tools.filter((tool) => tool.name === "insert-one"); + // Verify read-only mode is applied - insert-many should not be available + const writeTools = response.tools.filter((tool) => tool.name === "insert-many"); expect(writeTools.length).toBe(0); // Verify read tools are available const readTools = response.tools.filter((tool) => tool.name === "find"); expect(readTools.length).toBe(1); - - await client.close(); - await transport.close(); - await runner.close(); }); }); describe("error handling", () => { it("should propagate errors from configProvider on client connection", async () => { - const createSessionConfig: TransportRunnerConfig["createSessionConfig"] = async () => { - return Promise.reject(new Error("Failed to fetch config")); - }; - - userConfig.httpPort = 0; // Use a random port - runner = new StreamableHttpRunner({ - userConfig, - createSessionConfig, + await startRunner({ + createSessionConfig: async () => { + return Promise.reject(new Error("Failed to fetch config")); + }, }); - // Start succeeds because setupServer is only called when a client connects - await runner.start(); - // Error should occur when a client tries to connect - const testClient = new Client({ - name: "test-client", - version: "1.0.0", - }); - const testTransport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`)); - - await expect(testClient.connect(testTransport)).rejects.toThrow(); - - await testClient.close(); - await testTransport.close(); + client = new Client({ name: "test-client", version: "1.0.0" }); + transport = new StreamableHTTPClientTransport(new URL(`${runner.serverAddress}/mcp`)); - await runner.close(); + await expect(client.connect(transport)).rejects.toThrow(); }); }); }); diff --git a/tests/tsconfig.json b/tests/tsconfig.json new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/common/config.test.ts b/tests/unit/common/config.test.ts index fcf5d1437..c5f537d2d 100644 --- a/tests/unit/common/config.test.ts +++ b/tests/unit/common/config.test.ts @@ -1,7 +1,12 @@ import { describe, it, expect, vi, beforeEach, afterEach, type MockedFunction } from "vitest"; import { type UserConfig, UserConfigSchema } from "../../../src/common/config/userConfig.js"; import { type CreateUserConfigHelpers, createUserConfig } from "../../../src/common/config/createUserConfig.js"; -import { getLogPath, getExportsPath } from "../../../src/common/config/configUtils.js"; +import { + getLogPath, + getExportsPath, + onlyLowerThanBaseValueOverride, + onlySubsetOfBaseValueOverride, +} from "../../../src/common/config/configUtils.js"; import { Keychain } from "../../../src/common/keychain.js"; import type { Secret } from "../../../src/common/keychain.js"; @@ -60,6 +65,7 @@ const expectedDefaults = { embeddingsValidation: true, previewFeatures: [], dryRun: false, + allowRequestOverrides: false, }; describe("config", () => { @@ -976,3 +982,88 @@ describe("keychain management", () => { }); } }); + +describe("custom override logic functions", () => { + describe("onlyLowerThanBaseValueOverride", () => { + it("should allow override to a lower value", () => { + const customLogic = onlyLowerThanBaseValueOverride(); + const result = customLogic(100, 50); + expect(result).toBe(50); + }); + + it("should reject override to a higher value", () => { + const customLogic = onlyLowerThanBaseValueOverride(); + expect(() => customLogic(100, 150)).toThrow("Can only set to a value lower than the base value"); + }); + + it("should reject override to equal value", () => { + const customLogic = onlyLowerThanBaseValueOverride(); + expect(() => customLogic(100, 100)).toThrow("Can only set to a value lower than the base value"); + }); + + it("should throw error if base value is not a number", () => { + const customLogic = onlyLowerThanBaseValueOverride(); + expect(() => customLogic("not a number", 50)).toThrow("Unsupported type for base value for override"); + }); + + it("should throw error if new value is not a number", () => { + const customLogic = onlyLowerThanBaseValueOverride(); + expect(() => customLogic(100, "not a number")).toThrow("Unsupported type for new value for override"); + }); + }); + + describe("onlySubsetOfBaseValueOverride", () => { + it("should allow override to a subset", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + const result = customLogic(["a", "b", "c"], ["a", "b"]); + expect(result).toEqual(["a", "b"]); + }); + + it("should allow override to an empty array", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + const result = customLogic(["a", "b", "c"], []); + expect(result).toEqual([]); + }); + + it("should allow override with same array", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + const result = customLogic(["a", "b"], ["a", "b"]); + expect(result).toEqual(["a", "b"]); + }); + + it("should reject override to a superset", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + expect(() => customLogic(["a", "b"], ["a", "b", "c"])).toThrow( + "Can only override to a subset of the base value" + ); + }); + + it("should reject override with items not in base value", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + expect(() => customLogic(["a", "b"], ["c"])).toThrow("Can only override to a subset of the base value"); + }); + + it("should reject override when base is empty and new is not", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + expect(() => customLogic([], ["a"])).toThrow("Can only override to a subset of the base value"); + }); + + it("should allow override when both arrays are empty", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + const result = customLogic([], []); + expect(result).toEqual([]); + }); + + it("should throw error if base value is not an array", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + expect(() => customLogic("not an array", ["a"])).toThrow("Unsupported type for base value for override"); + }); + + it("should throw error if new value is not an array", () => { + const customLogic = onlySubsetOfBaseValueOverride(); + expect(() => customLogic(["a", "b"], "not an array")).toThrow( + "Unsupported type for new value for override" + ); + }); + }); +}); diff --git a/tests/unit/common/config/configOverrides.test.ts b/tests/unit/common/config/configOverrides.test.ts new file mode 100644 index 000000000..01a4256cd --- /dev/null +++ b/tests/unit/common/config/configOverrides.test.ts @@ -0,0 +1,459 @@ +import { describe, it, expect } from "vitest"; +import { applyConfigOverrides, getConfigMeta, nameToConfigKey } from "../../../../src/common/config/configOverrides.js"; +import { UserConfigSchema, type UserConfig } from "../../../../src/common/config/userConfig.js"; +import type { RequestContext } from "../../../../src/transports/base.js"; + +describe("configOverrides", () => { + const baseConfig: Partial = { + readOnly: false, + indexCheck: false, + idleTimeoutMs: 600_000, + notificationTimeoutMs: 540_000, + disabledTools: ["tool1"], + confirmationRequiredTools: ["drop-database"], + connectionString: "mongodb://localhost:27017", + vectorSearchDimensions: 1024, + vectorSearchSimilarityFunction: "euclidean", + embeddingsValidation: false, + previewFeatures: [], + loggers: ["disk", "mcp"], + exportTimeoutMs: 300_000, + exportCleanupIntervalMs: 120_000, + atlasTemporaryDatabaseUserLifetimeMs: 14_400_000, + allowRequestOverrides: true, + }; + + describe("helper functions", () => { + describe("nameToConfigKey", () => { + it("should convert header name to config key", () => { + expect(nameToConfigKey("header", "x-mongodb-mcp-read-only")).toBe("readOnly"); + expect(nameToConfigKey("header", "x-mongodb-mcp-idle-timeout-ms")).toBe("idleTimeoutMs"); + expect(nameToConfigKey("header", "x-mongodb-mcp-connection-string")).toBe("connectionString"); + }); + + it("should convert query parameter name to config key", () => { + expect(nameToConfigKey("query", "mongodbMcpReadOnly")).toBe("readOnly"); + expect(nameToConfigKey("query", "mongodbMcpIdleTimeoutMs")).toBe("idleTimeoutMs"); + expect(nameToConfigKey("query", "mongodbMcpConnectionString")).toBe("connectionString"); + }); + + it("should not mix up header and query parameter names", () => { + expect(nameToConfigKey("header", "mongodbMcpReadOnly")).toBeUndefined(); + expect(nameToConfigKey("query", "x-mongodb-mcp-read-only")).toBeUndefined(); + }); + + it("should return undefined for non-mcp names", () => { + expect(nameToConfigKey("header", "content-type")).toBeUndefined(); + expect(nameToConfigKey("header", "authorization")).toBeUndefined(); + expect(nameToConfigKey("query", "content")).toBeUndefined(); + }); + }); + + it("should get override behavior for config keys", () => { + expect(getConfigMeta("readOnly")?.overrideBehavior).toEqual(expect.any(Function)); + expect(getConfigMeta("disabledTools")?.overrideBehavior).toBe("merge"); + expect(getConfigMeta("apiBaseUrl")?.overrideBehavior).toBe("not-allowed"); + expect(getConfigMeta("maxBytesPerQuery")?.overrideBehavior).toBe("not-allowed"); + }); + }); + + describe("applyConfigOverrides", () => { + it("should return base config when request is undefined", () => { + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig }); + expect(result).toEqual(baseConfig); + }); + + describe("boolean edge cases", () => { + it("should parse correctly for true value", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-read-only": "true", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.readOnly).toBe(true); + }); + + it("should parse correctly for false value", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-read-only": "false", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.readOnly).toBe(false); + }); + + for (const value of ["True", "False", "TRUE", "FALSE", "0", "1", ""]) { + it(`should throw an error for ${value}`, () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-read-only": value, + }, + }; + expect(() => applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request })).toThrow( + `Invalid boolean value: ${value}` + ); + }); + } + }); + + it("should return base config when request has no headers or query", () => { + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request: {} }); + expect(result).toEqual(baseConfig); + }); + + describe("allowRequestOverrides", () => { + it("should not apply overrides when allowRequestOverrides is false", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-read-only": "true", + "x-mongodb-mcp-idle-timeout-ms": "300000", + }, + }; + const configWithOverridesDisabled = { + ...baseConfig, + allowRequestOverrides: false, + } as UserConfig; + expect(() => applyConfigOverrides({ baseConfig: configWithOverridesDisabled, request })).to.throw( + "Request overrides are not enabled" + ); + }); + + it("should apply overrides when allowRequestOverrides is true", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-read-only": "true", + "x-mongodb-mcp-idle-timeout-ms": "300000", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + // Config should be overridden + expect(result.readOnly).toBe(true); + expect(result.idleTimeoutMs).toBe(300000); + }); + + it("should not apply overrides by default when allowRequestOverrides is not set", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-read-only": "true", + }, + }; + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const { allowRequestOverrides, ...configWithoutOverridesFlag } = baseConfig; + expect(() => + applyConfigOverrides({ baseConfig: configWithoutOverridesFlag as UserConfig, request }) + ).to.throw("Request overrides are not enabled"); + }); + }); + + describe("override behavior", () => { + it("should override boolean values with override behavior", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-read-only": "true", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.readOnly).toBe(true); + }); + + it("should override string values with override behavior", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-vector-search-similarity-function": "cosine", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.vectorSearchSimilarityFunction).toBe("cosine"); + }); + }); + + describe("merge behavior", () => { + it("should merge array values", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-disabled-tools": "tool2,tool3", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.disabledTools).toEqual(["tool1", "tool2", "tool3"]); + }); + + it("should merge multiple array fields", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-disabled-tools": "tool2", + "x-mongodb-mcp-confirmation-required-tools": "drop-collection", + "x-mongodb-mcp-preview-features": "feature1", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.disabledTools).toEqual(["tool1", "tool2"]); + expect(result.confirmationRequiredTools).toEqual(["drop-database", "drop-collection"]); + // previewFeatures has enum validation - "feature1" isn't a valid value, so it gets rejected + expect(result.previewFeatures).toEqual([]); + }); + + it("should not be able to merge loggers", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-loggers": "stderr", + }, + }; + expect(() => applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request })).toThrow( + "Config key loggers is not allowed to be overridden" + ); + }); + }); + + describe("not-allowed behavior", () => { + it("should have some not-allowed fields", () => { + expect( + Object.keys(UserConfigSchema.shape).filter( + (key) => + getConfigMeta(key as keyof typeof UserConfigSchema.shape)?.overrideBehavior === + "not-allowed" + ) + ).toEqual([ + "apiBaseUrl", + "apiClientId", + "apiClientSecret", + "connectionString", + "loggers", + "logPath", + "telemetry", + "transport", + "httpPort", + "httpHost", + "httpHeaders", + "maxBytesPerQuery", + "maxDocumentsPerQuery", + "exportsPath", + "exportCleanupIntervalMs", + "allowRequestOverrides", + "dryRun", + ]); + }); + + it("should throw an error for not-allowed fields", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-api-base-url": "https://malicious.com/", + "x-mongodb-mcp-max-bytes-per-query": "999999", + "x-mongodb-mcp-max-documents-per-query": "1000", + "x-mongodb-mcp-transport": "stdio", + "x-mongodb-mcp-http-port": "9999", + }, + }; + expect(() => applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request })).toThrow( + "Config key apiBaseUrl is not allowed to be overridden" + ); + }); + }); + + describe("secret fields", () => { + it("should allow overriding secret fields with headers if they have override behavior", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-voyage-api-key": "test", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.voyageApiKey).toBe("test"); + }); + + it("should not allow overriding secret fields via query params", () => { + const request: RequestContext = { + query: { + mongodbMcpVoyageApiKey: "test", + }, + }; + expect(() => applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request })).toThrow( + "Config key voyageApiKey can only be overriden with headers" + ); + }); + }); + + describe("custom overrides", () => { + it("should have certain config keys to be conditionally overridden", () => { + expect( + Object.keys(UserConfigSchema.shape) + .map((key) => [ + key, + getConfigMeta(key as keyof typeof UserConfigSchema.shape)?.overrideBehavior, + ]) + .filter(([, behavior]) => typeof behavior === "function") + .map(([key]) => key) + ).toEqual([ + "readOnly", + "indexCheck", + "idleTimeoutMs", + "notificationTimeoutMs", + "exportTimeoutMs", + "atlasTemporaryDatabaseUserLifetimeMs", + "embeddingsValidation", + "previewFeatures", + ]); + }); + + it("should allow readOnly override from false to true", () => { + const request: RequestContext = { headers: { "x-mongodb-mcp-read-only": "true" } }; + const result = applyConfigOverrides({ + baseConfig: { ...baseConfig, readOnly: false } as UserConfig, + request, + }); + expect(result.readOnly).toBe(true); + }); + + it("should throw when trying to override readOnly from true to false", () => { + const request: RequestContext = { headers: { "x-mongodb-mcp-read-only": "false" } }; + expect(() => + applyConfigOverrides({ baseConfig: { ...baseConfig, readOnly: true } as UserConfig, request }) + ).toThrow("Cannot apply override for readOnly: Can only set to true"); + }); + + it("should allow indexCheck override from false to true", () => { + const request: RequestContext = { headers: { "x-mongodb-mcp-index-check": "true" } }; + const result = applyConfigOverrides({ + baseConfig: { ...baseConfig, indexCheck: false } as UserConfig, + request, + }); + expect(result.indexCheck).toBe(true); + }); + + it("should throw when trying to override indexCheck from true to false", () => { + const request: RequestContext = { headers: { "x-mongodb-mcp-index-check": "false" } }; + expect(() => + applyConfigOverrides({ baseConfig: { ...baseConfig, indexCheck: true } as UserConfig, request }) + ).toThrow("Cannot apply override for indexCheck: Can only set to true"); + }); + + it("should allow disableEmbeddingsValidation override from true to false", () => { + const request: RequestContext = { headers: { "x-mongodb-mcp-embeddings-validation": "true" } }; + const result = applyConfigOverrides({ + baseConfig: { ...baseConfig, embeddingsValidation: true } as UserConfig, + request, + }); + expect(result.embeddingsValidation).toBe(true); + }); + + it("should throw when trying to override embeddingsValidation from false to true", () => { + const request: RequestContext = { headers: { "x-mongodb-mcp-embeddings-validation": "false" } }; + expect(() => + applyConfigOverrides({ + baseConfig: { ...baseConfig, embeddingsValidation: true } as UserConfig, + request, + }) + ).toThrow("Cannot apply override for embeddingsValidation: Can only set to true"); + }); + }); + + describe("query parameter overrides", () => { + it("should apply overrides from query parameters", () => { + const request: RequestContext = { + query: { + mongodbMcpReadOnly: "true", + mongodbMcpIdleTimeoutMs: "400000", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.readOnly).toBe(true); + expect(result.idleTimeoutMs).toBe(400000); + }); + + it("should merge arrays from query parameters", () => { + const request: RequestContext = { + query: { + mongodbMcpDisabledTools: "tool2,tool3", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.disabledTools).toEqual(["tool1", "tool2", "tool3"]); + }); + }); + + describe("precedence", () => { + it("should give query parameters precedence over headers", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-idle-timeout-ms": "300000", + }, + query: { + mongodbMcpIdleTimeoutMs: "500000", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.idleTimeoutMs).toBe(500000); + }); + + it("should merge arrays from both headers and query", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-disabled-tools": "tool2", + }, + query: { + mongodbMcpDisabledTools: "tool3", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + // Query takes precedence over headers, but base + query result + expect(result.disabledTools).toEqual(["tool1", "tool3"]); + }); + }); + + describe("edge cases", () => { + it("should handle invalid numeric values gracefully", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-idle-timeout-ms": "not-a-number", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.idleTimeoutMs).toBe(baseConfig.idleTimeoutMs); + }); + + it("should handle empty string values for arrays", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-disabled-tools": "", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + // Empty string gets filtered out by commaSeparatedToArray, resulting in [] + // Merging [] with ["tool1"] gives ["tool1"] + expect(result.disabledTools).toEqual(["tool1"]); + }); + + it("should trim whitespace in array values", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-disabled-tools": " tool2 , tool3 ", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.disabledTools).toEqual(["tool1", "tool2", "tool3"]); + }); + + it("should handle case-insensitive header names", () => { + const request: RequestContext = { + headers: { + "X-MongoDB-MCP-Read-Only": "true", + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.readOnly).toBe(true); + }); + + it("should handle array values sent as multiple headers", () => { + const request: RequestContext = { + headers: { + "x-mongodb-mcp-disabled-tools": ["tool2", "tool3"], + }, + }; + const result = applyConfigOverrides({ baseConfig: baseConfig as UserConfig, request }); + expect(result.disabledTools).toEqual(["tool1", "tool2", "tool3"]); + }); + }); + }); +});