2 changes: 2 additions & 0 deletions apis/python/src/tiledb/vector_search/embeddings/__init__.py
@@ -2,6 +2,7 @@
from .image_resnetv2_embedding import ImageResNetV2Embedding
from .langchain_embedding import LangChainEmbedding
from .object_embedding import ObjectEmbedding
from .ollama_embedding import OllamaEmbedding
from .random_embedding import RandomEmbedding
from .sentence_transformers_embedding import SentenceTransformersEmbedding
from .soma_geneptw_embedding import SomaGenePTwEmbedding
@@ -18,4 +19,5 @@
"LangChainEmbedding",
"SomaScGPTEmbedding",
"SomaSCVIEmbedding",
"OllamaEmbedding",
]
49 changes: 49 additions & 0 deletions apis/python/src/tiledb/vector_search/embeddings/ollama_embedding.py
@@ -0,0 +1,49 @@
from typing import Dict, Optional, OrderedDict, Sequence, Union

import numpy as np

# from tiledb.vector_search.embeddings import ObjectEmbedding


class OllamaEmbedding:
"""
Embedding functions from Ollama.

This attempts to import the embedding_class from the ollama module.
"""

def __init__(
self,
dimensions: int,
        embedding_class: str = "embed",  # name of the ollama function to call (not a class)
embedding_kwargs: Optional[Dict] = None,
):
self.dim_num = dimensions
self.embedding_class = embedding_class
self.embedding_kwargs = embedding_kwargs

def init_kwargs(self) -> Dict:
return {
"dimensions": self.dim_num,
"embedding_class": self.embedding_class,
"embedding_kwargs": self.embedding_kwargs,
}

def dimensions(self) -> int:
return self.dim_num

def vector_type(self) -> np.dtype:
return np.float32

    def load(self) -> None:
        import importlib
        from functools import partial

        try:
            embeddings_module = importlib.import_module("ollama")
            embedding_method = getattr(embeddings_module, self.embedding_class)
            # Bind the ollama function (e.g. ollama.embed) with the configured
            # kwargs (e.g. {"model": ...}) so that embed() only has to pass `input`.
            self.embedding = partial(embedding_method, **(self.embedding_kwargs or {}))
        except ImportError as e:
            print(e)

def embed(self, objects: Union[str, Sequence[str]]) -> np.ndarray:
return np.array(self.embedding(input=objects).embeddings, dtype=np.float32)
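
For reference, a minimal usage sketch of the new class (hedged: the model name "nomic-embed-text" and its 768-dimensional output are assumptions for illustration, not part of this change; it also assumes the `ollama` package is installed and a local Ollama server has the model pulled):

from tiledb.vector_search.embeddings import OllamaEmbedding

embedder = OllamaEmbedding(
    dimensions=768,  # must match the output dimensionality of the chosen model
    embedding_class="embed",  # ollama function used to compute embeddings
    embedding_kwargs={"model": "nomic-embed-text"},  # hypothetical model choice
)
embedder.load()
vectors = embedder.embed(["hello world", "tiledb vector search"])
print(vectors.shape, vectors.dtype)  # expected: (2, 768) float32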