Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions libs/providers/langchain-google-common/src/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
AIStudioEmbeddingsRequest,
GeminiPartText,
VertexEmbeddingsRequest,
GoogleEmbeddingsTaskType,
} from "./types.js";

class EmbeddingsConnection<
Expand Down Expand Up @@ -139,6 +140,10 @@ export abstract class BaseGoogleEmbeddings<AuthOptions>

dimensions?: number;

taskType?: GoogleEmbeddingsTaskType;

title?: string;

private connection: EmbeddingsConnection<
BaseGoogleEmbeddingsOptions,
AuthOptions
Expand All @@ -149,6 +154,8 @@ export abstract class BaseGoogleEmbeddings<AuthOptions>

this.model = fields.model;
this.dimensions = fields.dimensions ?? fields.outputDimensionality;
this.taskType = fields.taskType;
this.title = fields.title;

this.connection = new EmbeddingsConnection(
{ ...fields, ...this },
Expand Down Expand Up @@ -245,9 +252,18 @@ export abstract class BaseGoogleEmbeddings<AuthOptions>
// TODO: Make this configurable
const chunkSize = 1;
const instanceChunks: VertexEmbeddingsInstance[][] = chunkArray(
documents.map((document) => ({
content: document,
})),
documents.map((document) => {
const instance: VertexEmbeddingsInstance = {
content: document,
};
if (this.taskType) {
instance.taskType = this.taskType;
}
if (this.title) {
instance.title = this.title;
}
return instance;
}),
chunkSize
);
const parameters: VertexEmbeddingsParameters = this.buildParameters();
Expand Down
19 changes: 19 additions & 0 deletions libs/providers/langchain-google-common/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,25 @@ export interface BaseGoogleEmbeddingsParams<AuthOptions>
* An alias for "dimensions"
*/
outputDimensionality?: number;

/**
* The intended downstream application to help the model produce better quality embeddings.
* Available task types:
* - RETRIEVAL_QUERY: Specifies the given text is a query in a search/retrieval setting.
* - RETRIEVAL_DOCUMENT: Specifies the given text is a document in a search/retrieval setting.
* - SEMANTIC_SIMILARITY: Specifies the given text will be used for Semantic Textual Similarity (STS).
* - CLASSIFICATION: Specifies that the embeddings will be used for classification.
* - CLUSTERING: Specifies that the embeddings will be used for clustering.
* - QUESTION_ANSWERING: Specifies that the query embedding will be used for question answering.
* - FACT_VERIFICATION: Specifies that the query embedding will be used for fact verification.
* - CODE_RETRIEVAL_QUERY: Specifies that the query embedding will be used for code retrieval.
*/
taskType?: GoogleEmbeddingsTaskType;

/**
* An optional title for the text. Only applicable when taskType is RETRIEVAL_DOCUMENT.
*/
title?: string;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,31 @@ describe.each(testModels)(
expect(typeof res[0]).toBe("number");
expect(res.length).toEqual(testDimensions);
});

// Embeds a single query with taskType "RETRIEVAL_QUERY" and checks that the
// returned vector holds numbers and has the default output length.
test("taskType", async () => {
  const queryEmbedder = new VertexAIEmbeddings({
    model: modelName,
    location,
    onFailedAttempt,
    taskType: "RETRIEVAL_QUERY",
  });
  const vector = await queryEmbedder.embedQuery(
    "What is the capital of France?"
  );
  expect(typeof vector[0]).toBe("number");
  expect(vector.length).toEqual(defaultOutputDimensions);
});

// Embeds a single query with taskType "SEMANTIC_SIMILARITY" and an explicit
// `dimensions` value, then checks that the returned vector holds numbers and
// has exactly the requested length.
test("taskType with dimensions", async () => {
  const requestedDimensions = 512;
  const similarityEmbedder = new VertexAIEmbeddings({
    model: modelName,
    location,
    onFailedAttempt,
    taskType: "SEMANTIC_SIMILARITY",
    dimensions: requestedDimensions,
  });
  const vector = await similarityEmbedder.embedQuery("Hello world");
  expect(typeof vector[0]).toBe("number");
  expect(vector.length).toEqual(requestedDimensions);
});
}
);