Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion datastore/providers/azurecosmosdb_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
AZCOSMOS_CONNSTR = os.environ.get("AZCOSMOS_CONNSTR")
AZCOSMOS_DATABASE_NAME = os.environ.get("AZCOSMOS_DATABASE_NAME")
AZCOSMOS_CONTAINER_NAME = os.environ.get("AZCOSMOS_CONTAINER_NAME")
AZCOSMOS_SIMILARITY = os.environ.get("AZCOSMOS_SIMILARITY", "COS")
AZCOSMOS_NUM_LISTS = os.environ.get("AZCOSMOS_NUM_LISTS", 100)
assert AZCOSMOS_API is not None
assert AZCOSMOS_CONNSTR is not None
assert AZCOSMOS_DATABASE_NAME is not None
Expand Down Expand Up @@ -201,7 +203,7 @@ def __init__(self, cosmosStore: AzureCosmosDBStoreApi):

"""
@staticmethod
async def create(num_lists, similarity) -> DataStore:
async def create(num_lists: int=AZCOSMOS_NUM_LISTS, similarity: str=AZCOSMOS_SIMILARITY) -> DataStore:

# Create underlying data store based on the API definition.
# Right now this only supports Mongo, but set up to support more.
Expand All @@ -211,6 +213,11 @@ async def create(num_lists, similarity) -> DataStore:
apiStore = MongoStoreApi(mongoClient)
else:
raise NotImplementedError
if similarity not in ["COS", "L2", "IP"]:
raise ValueError(
f"Similarity {similarity} is not supported."
"Supported similarity metrics are COS, L2, and IP."
)

await apiStore.ensure(num_lists, similarity)
store = AzureCosmosDBDataStore(apiStore)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def queries() -> List[QueryWithEmbedding]:
async def azurecosmosdb_datastore() -> DataStore:
return await AzureCosmosDBDataStore.create(num_lists=num_lists, similarity=similarity)

@pytest.mark.asyncio
async def test_invalid_similarity() -> None:
with pytest.raises(ValueError):
await AzureCosmosDBDataStore.create(num_lists=num_lists, similarity="INVALID")

@pytest.mark.asyncio
async def test_upsert(
Expand Down