diff --git a/libs/providers/langchain-google-common/src/embeddings.ts b/libs/providers/langchain-google-common/src/embeddings.ts index de152c3fc8b7..8eb7fdcb05f8 100644 --- a/libs/providers/langchain-google-common/src/embeddings.ts +++ b/libs/providers/langchain-google-common/src/embeddings.ts @@ -22,6 +22,7 @@ import { AIStudioEmbeddingsRequest, GeminiPartText, VertexEmbeddingsRequest, + GoogleEmbeddingsTaskType, } from "./types.js"; class EmbeddingsConnection< @@ -139,6 +140,10 @@ export abstract class BaseGoogleEmbeddings dimensions?: number; + taskType?: GoogleEmbeddingsTaskType; + + title?: string; + private connection: EmbeddingsConnection< BaseGoogleEmbeddingsOptions, AuthOptions @@ -149,6 +154,8 @@ export abstract class BaseGoogleEmbeddings this.model = fields.model; this.dimensions = fields.dimensions ?? fields.outputDimensionality; + this.taskType = fields.taskType; + this.title = fields.title; this.connection = new EmbeddingsConnection( { ...fields, ...this }, @@ -245,9 +252,18 @@ export abstract class BaseGoogleEmbeddings // TODO: Make this configurable const chunkSize = 1; const instanceChunks: VertexEmbeddingsInstance[][] = chunkArray( - documents.map((document) => ({ - content: document, - })), + documents.map((document) => { + const instance: VertexEmbeddingsInstance = { + content: document, + }; + if (this.taskType) { + instance.taskType = this.taskType; + } + if (this.title) { + instance.title = this.title; + } + return instance; + }), chunkSize ); const parameters: VertexEmbeddingsParameters = this.buildParameters(); diff --git a/libs/providers/langchain-google-common/src/types.ts b/libs/providers/langchain-google-common/src/types.ts index 51f2a071e91e..6c7344d64009 100644 --- a/libs/providers/langchain-google-common/src/types.ts +++ b/libs/providers/langchain-google-common/src/types.ts @@ -906,6 +906,25 @@ export interface BaseGoogleEmbeddingsParams * An alias for "dimensions" */ outputDimensionality?: number; + + /** + * The intended downstream application to help the model produce better quality embeddings. + * Available task types: + * - RETRIEVAL_QUERY: Specifies the given text is a query in a search/retrieval setting. + * - RETRIEVAL_DOCUMENT: Specifies the given text is a document in a search/retrieval setting. + * - SEMANTIC_SIMILARITY: Specifies the given text will be used for Semantic Textual Similarity (STS). + * - CLASSIFICATION: Specifies that the embeddings will be used for classification. + * - CLUSTERING: Specifies that the embeddings will be used for clustering. + * - QUESTION_ANSWERING: Specifies that the query embedding will be used for question answering. + * - FACT_VERIFICATION: Specifies that the query embedding will be used for fact verification. + * - CODE_RETRIEVAL_QUERY: Specifies that the query embedding will be used for code retrieval. + */ + taskType?: GoogleEmbeddingsTaskType; + + /** + * An optional title for the text. Only applicable when taskType is RETRIEVAL_DOCUMENT. + */ + title?: string; } /** diff --git a/libs/providers/langchain-google-vertexai/src/tests/embeddings.int.test.ts b/libs/providers/langchain-google-vertexai/src/tests/embeddings.int.test.ts index 81846632b411..e04cad25f8b0 100644 --- a/libs/providers/langchain-google-vertexai/src/tests/embeddings.int.test.ts +++ b/libs/providers/langchain-google-vertexai/src/tests/embeddings.int.test.ts @@ -82,5 +82,31 @@ describe.each(testModels)( expect(typeof res[0]).toBe("number"); expect(res.length).toEqual(testDimensions); }); + + test("taskType", async () => { + const embeddings = new VertexAIEmbeddings({ + model: modelName, + location, + onFailedAttempt, + taskType: "RETRIEVAL_QUERY", + }); + const res = await embeddings.embedQuery("What is the capital of France?"); + expect(typeof res[0]).toBe("number"); + expect(res.length).toEqual(defaultOutputDimensions); + }); + + test("taskType with dimensions", async () => { + const testDimensions: number = 512; + const embeddings = new VertexAIEmbeddings({ + model: modelName, + location, + onFailedAttempt, + taskType: "SEMANTIC_SIMILARITY", + dimensions: testDimensions, + }); + const res = await embeddings.embedQuery("Hello world"); + expect(typeof res[0]).toBe("number"); + expect(res.length).toEqual(testDimensions); + }); } );