Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/common/connectionErrorHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export type ConnectionErrorHandled = { errorHandled: true; result: CallToolResul

export const connectionErrorHandler: ConnectionErrorHandler = (error, { availableTools, connectionState }) => {
const connectTools = availableTools
.filter((t) => t.operationType === "connect")
.filter((t) => t.operationType === "connect" && t.isEnabled())
.sort((a, b) => a.category.localeCompare(b.category)); // Sort Atlas tools before MongoDB tools

// Find what Atlas connect tools are available and suggest when the LLM should to use each. If no Atlas tools are found, return a suggestion for the MongoDB connect tool.
Expand Down
1 change: 1 addition & 0 deletions src/common/logger.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export const LogId = {
toolExecute: mongoLogId(1_003_001),
toolExecuteFailure: mongoLogId(1_003_002),
toolDisabled: mongoLogId(1_003_003),
toolMetadataChange: mongoLogId(1_003_004),

mongodbConnectFailure: mongoLogId(1_004_001),
mongodbDisconnectFailure: mongoLogId(1_004_002),
Expand Down
104 changes: 19 additions & 85 deletions src/tools/mongodb/connect/connect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,114 +2,48 @@ import { z } from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType, ToolConstructorParams } from "../../tool.js";
import assert from "assert";
import type { Server } from "../../../server.js";
import { LogId } from "../../../common/logger.js";

const disconnectedSchema = z
.object({
connectionString: z.string().describe("MongoDB connection string (in the mongodb:// or mongodb+srv:// format)"),
})
.describe("Options for connecting to MongoDB.");

const connectedSchema = z
.object({
connectionString: z
.string()
.optional()
.describe("MongoDB connection string to switch to (in the mongodb:// or mongodb+srv:// format)"),
})
.describe(
"Options for switching the current MongoDB connection. If a connection string is not provided, the connection string from the config will be used."
);

const connectedName = "switch-connection" as const;
const disconnectedName = "connect" as const;

const connectedDescription =
"Switch to a different MongoDB connection. If the user has configured a connection string or has previously called the connect tool, a connection is already established and there's no need to call this tool unless the user has explicitly requested to switch to a new instance.";
const disconnectedDescription =
"Connect to a MongoDB instance. The config resource captures if the server is already connected to a MongoDB cluster. If the user has configured a connection string or has previously called the connect tool, a connection is already established and there's no need to call this tool unless the user has explicitly requested to switch to a new MongoDB cluster.";

export class ConnectTool extends MongoDBToolBase {
public name: typeof connectedName | typeof disconnectedName = disconnectedName;
protected description: typeof connectedDescription | typeof disconnectedDescription = disconnectedDescription;
public override name = "connect";
protected override description =
"Connect to a MongoDB instance. The config resource captures if the server is already connected to a MongoDB cluster. If the user has configured a connection string or has previously called the connect tool, a connection is already established and there's no need to call this tool unless the user has explicitly requested to switch to a new MongoDB cluster.";

// Here the default is empty just to trigger registration, but we're going to override it with the correct
// schema in the register method.
protected argsShape = {
connectionString: z.string().optional(),
protected override argsShape = {
connectionString: z.string().describe("MongoDB connection string (in the mongodb:// or mongodb+srv:// format)"),
};

public operationType: OperationType = "connect";
public override operationType: OperationType = "connect";

constructor({ session, config, telemetry, elicitation }: ToolConstructorParams) {
super({ session, config, telemetry, elicitation });
session.on("connect", () => {
this.updateMetadata();
this.disable();
});

session.on("disconnect", () => {
this.updateMetadata();
this.enable();
});
}

protected async execute({ connectionString }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
switch (this.name) {
case disconnectedName:
assert(connectionString, "Connection string is required");
break;
case connectedName:
connectionString ??= this.config.connectionString;
assert(
connectionString,
"Cannot switch to a new connection because no connection string was provided and no default connection string is configured."
);
break;
public override register(server: Server): boolean {
const registrationSuccessful = super.register(server);
/**
* When connected to mongodb we want to swap connect with
* switch-connection tool.
*/
if (registrationSuccessful && this.session.isConnectedToMongoDB) {
this.disable();
}
return registrationSuccessful;
}

protected override async execute({ connectionString }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
await this.session.connectToMongoDB({ connectionString });
this.updateMetadata();

return {
content: [{ type: "text", text: "Successfully connected to MongoDB." }],
};
}

public register(server: Server): boolean {
if (super.register(server)) {
this.updateMetadata();
return true;
}

return false;
}

private updateMetadata(): void {
let name: string;
let description: string;
let inputSchema: z.ZodObject<z.ZodRawShape>;

if (this.session.isConnectedToMongoDB) {
name = connectedName;
description = connectedDescription;
inputSchema = connectedSchema;
} else {
name = disconnectedName;
description = disconnectedDescription;
inputSchema = disconnectedSchema;
}

this.session.logger.info({
id: LogId.updateToolMetadata,
context: "tool",
message: `Updating tool metadata to ${name}`,
});

this.update?.({
name,
description,
inputSchema,
});
}
}
58 changes: 58 additions & 0 deletions src/tools/mongodb/connect/switchConnection.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import z from "zod";
import { type CallToolResult } from "@modelcontextprotocol/sdk/types.js";

import { MongoDBToolBase } from "../mongodbTool.js";
import { type ToolArgs, type OperationType, type ToolConstructorParams } from "../../tool.js";
import type { Server } from "../../../server.js";

export class SwitchConnectionTool extends MongoDBToolBase {
public override name = "switch-connection";
protected override description =
"Switch to a different MongoDB connection. If the user has configured a connection string or has previously called the connect tool, a connection is already established and there's no need to call this tool unless the user has explicitly requested to switch to a new instance.";

protected override argsShape = {
connectionString: z
.string()
.optional()
.describe(
"MongoDB connection string to switch to (in the mongodb:// or mongodb+srv:// format). If a connection string is not provided, the connection string from the config will be used."
),
};

public override operationType: OperationType = "connect";

constructor({ session, config, telemetry, elicitation }: ToolConstructorParams) {
super({ session, config, telemetry, elicitation });
session.on("connect", () => {
this.enable();
});

session.on("disconnect", () => {
this.disable();
});
}

public override register(server: Server): boolean {
const registrationSuccessful = super.register(server);
/**
* When connected to mongodb we want to swap connect with
* switch-connection tool.
*/
if (registrationSuccessful && !this.session.isConnectedToMongoDB) {
this.disable();
}
return registrationSuccessful;
}

protected override async execute({ connectionString }: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
if (typeof connectionString !== "string") {
await this.session.connectToConfiguredConnection();
} else {
await this.session.connectToMongoDB({ connectionString });
}

return {
content: [{ type: "text", text: "Successfully connected to MongoDB." }],
};
}
}
2 changes: 2 additions & 0 deletions src/tools/mongodb/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import { CreateCollectionTool } from "./create/createCollection.js";
import { LogsTool } from "./metadata/logs.js";
import { ExportTool } from "./read/export.js";
import { DropIndexTool } from "./delete/dropIndex.js";
import { SwitchConnectionTool } from "./connect/switchConnection.js";

export const MongoDbTools = [
ConnectTool,
SwitchConnectionTool,
ListCollectionsTool,
ListDatabasesTool,
CollectionIndexesTool,
Expand Down
78 changes: 36 additions & 42 deletions src/tools/tool.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { z, AnyZodObject } from "zod";
import type { z } from "zod";
import { type ZodRawShape, type ZodNever } from "zod";
import type { RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { CallToolResult, ToolAnnotations } from "@modelcontextprotocol/sdk/types.js";
Expand Down Expand Up @@ -57,6 +57,8 @@ export abstract class ToolBase {

protected abstract argsShape: ZodRawShape;

private registeredTool: RegisteredTool | undefined;

protected get annotations(): ToolAnnotations {
const annotations: ToolAnnotations = {
title: this.name,
Expand Down Expand Up @@ -168,52 +170,44 @@ export abstract class ToolBase {
}
};

server.mcpServer.tool(this.name, this.description, this.argsShape, this.annotations, callback);

// This is very similar to RegisteredTool.update, but without the bugs around the name.
// In the upstream update method, the name is captured in the closure and not updated when
// the tool name changes. This means that you only get one name update before things end up
// in a broken state.
// See https://github.com/modelcontextprotocol/typescript-sdk/issues/414 for more details.
this.update = (updates: { name?: string; description?: string; inputSchema?: AnyZodObject }): void => {
const tools = server.mcpServer["_registeredTools"] as { [toolName: string]: RegisteredTool };
const existingTool = tools[this.name];

if (!existingTool) {
this.session.logger.warning({
id: LogId.toolUpdateFailure,
context: "tool",
message: `Tool ${this.name} not found in update`,
noRedaction: true,
});
return;
}
this.registeredTool = server.mcpServer.tool(
this.name,
this.description,
this.argsShape,
this.annotations,
callback
);

existingTool.annotations = this.annotations;

if (updates.name && updates.name !== this.name) {
existingTool.annotations.title = updates.name;
delete tools[this.name];
this.name = updates.name;
tools[this.name] = existingTool;
}

if (updates.description) {
existingTool.description = updates.description;
this.description = updates.description;
}

if (updates.inputSchema) {
existingTool.inputSchema = updates.inputSchema;
}
return true;
}

server.mcpServer.sendToolListChanged();
};
public isEnabled(): boolean {
return this.registeredTool?.enabled ?? false;
}

return true;
protected disable(): void {
if (!this.registeredTool) {
this.session.logger.warning({
id: LogId.toolMetadataChange,
context: `tool - ${this.name}`,
message: "Requested disabling of tool but it was never registered",
});
return;
}
this.registeredTool.disable();
}

protected update?: (updates: { name?: string; description?: string; inputSchema?: AnyZodObject }) => void;
protected enable(): void {
if (!this.registeredTool) {
this.session.logger.warning({
id: LogId.toolMetadataChange,
context: `tool - ${this.name}`,
message: "Requested enabling of tool but it was never registered",
});
return;
}
this.registeredTool.enable();
}

// Checks if a tool is allowed to run based on the config
protected verifyAllowed(): boolean {
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/tools/mongodb/connect/connect.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ describeWithMongoDB(
[
{
name: "connectionString",
description: "MongoDB connection string to switch to (in the mongodb:// or mongodb+srv:// format)",
description:
"MongoDB connection string to switch to (in the mongodb:// or mongodb+srv:// format). If a connection string is not provided, the connection string from the config will be used.",
type: "string",
required: false,
},
Expand Down
Loading