Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions src/common/connectionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,33 @@ export interface ConnectionState {
connectedAtlasCluster?: AtlasClusterConnectionInfo;
}

export interface ConnectionStateConnected extends ConnectionState {
tag: "connected";
serviceProvider: NodeDriverServiceProvider;
export class ConnectionStateConnected implements ConnectionState {
public tag = "connected" as const;

constructor(
public serviceProvider: NodeDriverServiceProvider,
public connectionStringAuthType?: ConnectionStringAuthType,
public connectedAtlasCluster?: AtlasClusterConnectionInfo
) {}

private _isSearchSupported?: boolean;

public async isSearchSupported(): Promise<boolean> {
if (this._isSearchSupported === undefined) {
try {
const dummyDatabase = `search-index-test-db-${Date.now()}`;
const dummyCollection = `search-index-test-coll-${Date.now()}`;
// If a cluster supports search indexes, the call below will succeed
// with a cursor otherwise will throw an Error
await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection);
this._isSearchSupported = true;
} catch {
this._isSearchSupported = false;
}
}

return this._isSearchSupported;
}
}

export interface ConnectionStateConnecting extends ConnectionState {
Expand Down Expand Up @@ -199,12 +223,10 @@ export class MCPConnectionManager extends ConnectionManager {
});
}

return this.changeState("connection-success", {
tag: "connected",
connectedAtlasCluster: settings.atlas,
serviceProvider: await serviceProvider,
connectionStringAuthType,
});
return this.changeState(
"connection-success",
new ConnectionStateConnected(await serviceProvider, connectionStringAuthType, settings.atlas)
);
} catch (error: unknown) {
const errorReason = error instanceof Error ? error.message : `${error as string}`;
this.changeState("connection-error", {
Expand Down Expand Up @@ -270,11 +292,14 @@ export class MCPConnectionManager extends ConnectionManager {
this.currentConnectionState.tag === "connecting" &&
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
) {
this.changeState("connection-success", {
...this.currentConnectionState,
tag: "connected",
serviceProvider: await this.currentConnectionState.serviceProvider,
});
this.changeState(
"connection-success",
new ConnectionStateConnected(
await this.currentConnectionState.serviceProvider,
this.currentConnectionState.connectionStringAuthType,
this.currentConnectionState.connectedAtlasCluster
)
);
}

this.logger.info({
Expand Down
22 changes: 9 additions & 13 deletions src/common/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ export class Session extends EventEmitter<SessionEvents> {
return this.connectionManager.currentConnectionState.tag === "connected";
}

get isConnectedToMongot(): Promise<boolean> {
const state = this.connectionManager.currentConnectionState;
if (state.tag === "connected") {
return state.isSearchSupported();
}

return Promise.resolve(false);
}

get serviceProvider(): NodeDriverServiceProvider {
if (this.isConnectedToMongoDB) {
const state = this.connectionManager.currentConnectionState as ConnectionStateConnected;
Expand All @@ -153,17 +162,4 @@ export class Session extends EventEmitter<SessionEvents> {
get connectedAtlasCluster(): AtlasClusterConnectionInfo | undefined {
return this.connectionManager.currentConnectionState.connectedAtlasCluster;
}

async isSearchIndexSupported(): Promise<boolean> {
try {
const dummyDatabase = `search-index-test-db-${Date.now()}`;
const dummyCollection = `search-index-test-coll-${Date.now()}`;
// If a cluster supports search indexes, the call below will succeed
// with a cursor otherwise will throw an Error
await this.serviceProvider.getSearchIndexes(dummyDatabase, dummyCollection);
return true;
} catch {
return false;
}
}
}
2 changes: 1 addition & 1 deletion src/resources/common/debug.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export class DebugResource extends ReactiveResource<

switch (this.current.tag) {
case "connected": {
const searchIndexesSupported = await this.session.isSearchIndexSupported();
const searchIndexesSupported = await this.session.isConnectedToMongot;
result += `The user is connected to the MongoDB cluster${searchIndexesSupported ? " with support for search indexes" : " without any support for search indexes"}.`;
break;
}
Expand Down
132 changes: 123 additions & 9 deletions src/tools/mongodb/create/createIndex.ts
Original file line number Diff line number Diff line change
@@ -1,33 +1,147 @@
import { z } from "zod";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js";
import type { ToolArgs, OperationType } from "../../tool.js";
import type { ToolCategory } from "../../tool.js";
import { type ToolArgs, type OperationType, FeatureFlags } from "../../tool.js";
import type { IndexDirection } from "mongodb";

const vectorSearchIndexDefinition = z.object({
type: z.literal("vectorSearch"),
fields: z
.array(
z.discriminatedUnion("type", [
z
.object({
type: z.literal("filter"),
path: z
.string()
.describe(
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields"
),
})
.strict()
.describe("Definition for a field that will be used for pre-filtering results."),
z
.object({
type: z.literal("vector"),
path: z
.string()
.describe(
"Name of the field to index. For nested fields, use dot notation to specify path to embedded fields"
),
numDimensions: z
.number()
.min(1)
.max(8192)
.describe(
"Number of vector dimensions that MongoDB Vector Search enforces at index-time and query-time"
),
similarity: z
.enum(["cosine", "euclidean", "dotProduct"])
.default("cosine")
.describe(
"Vector similarity function to use to search for top K-nearest neighbors. You can set this field only for vector-type fields."
),
quantization: z
.enum(["none", "scalar", "binary"])
.optional()
.default("none")
.describe(
"Type of automatic vector quantization for your vectors. Use this setting only if your embeddings are float or double vectors."
),
})
.strict()
.describe("Definition for a field that contains vector embeddings."),
])
)
.nonempty()
.refine((fields) => fields.some((f) => f.type === "vector"), {
message: "At least one vector field must be defined",
})
.describe(
"Definitions for the vector and filter fields to index, one definition per document. You must specify `vector` for fields that contain vector embeddings and `filter` for additional fields to filter on. At least one vector-type field definition is required."
),
});

export class CreateIndexTool extends MongoDBToolBase {
public name = "create-index";
protected description = "Create an index for a collection";
protected argsShape = {
...DbOperationArgs,
keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"),
name: z.string().optional().describe("The name of the index"),
definition: z
.array(
z.discriminatedUnion("type", [
z.object({
type: z.literal("classic"),
keys: z.object({}).catchall(z.custom<IndexDirection>()).describe("The index definition"),
}),
...(this.isFeatureFlagEnabled(FeatureFlags.VectorSearch) ? [vectorSearchIndexDefinition] : []),
])
)
.describe(
"The index definition. Use 'classic' for standard indexes and 'vectorSearch' for vector search indexes"
),
};

public operationType: OperationType = "create";

protected async execute({
database,
collection,
keys,
name,
definition: definitions,
}: ToolArgs<typeof this.argsShape>): Promise<CallToolResult> {
const provider = await this.ensureConnected();
const indexes = await provider.createIndexes(database, collection, [
{
key: keys,
name,
},
]);
let indexes: string[] = [];
const definition = definitions[0];
if (!definition) {
throw new Error("Index definition not provided. Expected one of the following: `classic`, `vectorSearch`");
}

switch (definition.type) {
case "classic":
indexes = await provider.createIndexes(database, collection, [
{
key: definition.keys,
name,
},
]);
break;
case "vectorSearch":
{
const isVectorSearchSupported = await this.session.isConnectedToMongot;
if (!isVectorSearchSupported) {
// TODO: remove hacky casts once we merge the local dev tools
const isLocalAtlasAvailable =
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are checking if a tool exists in a few places already in other places, maybe we can extract this to a function and refactor?

(this.server?.tools.filter((t) => t.category === ("atlas-local" as unknown as ToolCategory))
.length ?? 0) > 0;

const CTA = isLocalAtlasAvailable ? "`atlas-local` tools" : "Atlas CLI";
return {
content: [
{
text: `The connected MongoDB deployment does not support vector search indexes. Either connect to a MongoDB Atlas cluster or use the ${CTA} to create and manage a local Atlas deployment.`,
type: "text",
},
],
isError: true,
};
}

indexes = await provider.createSearchIndexes(database, collection, [
{
name,
definition: {
fields: definition.fields,
},
type: "vectorSearch",
},
]);
}

break;
}

return {
content: [
Expand Down
2 changes: 1 addition & 1 deletion src/tools/mongodb/mongodbTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export const DbOperationArgs = {
};

export abstract class MongoDBToolBase extends ToolBase {
private server?: Server;
protected server?: Server;
public category: ToolCategory = "mongodb";

protected async ensureConnected(): Promise<NodeDriverServiceProvider> {
Expand Down
14 changes: 14 additions & 0 deletions src/tools/tool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ export type ToolCallbackArgs<Args extends ZodRawShape> = Parameters<ToolCallback

export type ToolExecutionContext<Args extends ZodRawShape = ZodRawShape> = Parameters<ToolCallback<Args>>[1];

export const enum FeatureFlags {
VectorSearch = "vectorSearch",
}

/**
* The type of operation the tool performs. This is used when evaluating if a tool is allowed to run based on
* the config's `disabledTools` and `readOnly` settings.
Expand Down Expand Up @@ -314,6 +318,16 @@ export abstract class ToolBase {

this.telemetry.emitEvents([event]);
}

// TODO: Move this to a separate file
protected isFeatureFlagEnabled(flag: FeatureFlags): boolean {
switch (flag) {
case FeatureFlags.VectorSearch:
return this.config.voyageApiKey !== "";
default:
return false;
}
}
}

/**
Expand Down
Loading
Loading