Skip to content

Commit ba6a714

Browse files
Support function edit sdk
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
1 parent afff26c commit ba6a714

File tree

7 files changed

+370
-11
lines changed

7 files changed

+370
-11
lines changed

examples/function_edit.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from pymilvus import (
2+
MilvusClient,
3+
Function, DataType, FunctionType,
4+
)
5+
6+
collection_name = "text_embedding"
7+
8+
milvus_client = MilvusClient("http://localhost:19530")
9+
10+
has_collection = milvus_client.has_collection(collection_name, timeout=5)
11+
if has_collection:
12+
milvus_client.drop_collection(collection_name)
13+
14+
schema = milvus_client.create_schema()
15+
schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False)
16+
schema.add_field("document", DataType.VARCHAR, max_length=9000)
17+
schema.add_field("dense", DataType.FLOAT_VECTOR, dim=1536)
18+
19+
text_embedding_function = Function(
20+
name="openai",
21+
function_type=FunctionType.TEXTEMBEDDING,
22+
input_field_names=["document"],
23+
output_field_names="dense",
24+
params={
25+
"provider": "openai",
26+
"model_name": "text-embedding-3-small",
27+
}
28+
)
29+
30+
schema.add_function(text_embedding_function)
31+
32+
index_params = milvus_client.prepare_index_params()
33+
index_params.add_index(
34+
field_name="dense",
35+
index_name="dense_index",
36+
index_type="AUTOINDEX",
37+
metric_type="IP",
38+
)
39+
40+
ret = milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong")
41+
42+
ret = milvus_client.describe_collection(collection_name)
43+
print(ret["functions"][0])
44+
45+
text_embedding_function.params["user"] = "user123"
46+
47+
milvus_client.alter_collection_function(collection_name, "openai", text_embedding_function)
48+
49+
ret = milvus_client.describe_collection(collection_name)
50+
print(ret["functions"][0])
51+
52+
milvus_client.drop_collection_function(collection_name, "openai")
53+
54+
ret = milvus_client.describe_collection(collection_name)
55+
print(ret["functions"])
56+
57+
text_embedding_function.params["user"] = "user1234"
58+
59+
milvus_client.add_collection_function(collection_name, text_embedding_function)
60+
61+
ret = milvus_client.describe_collection(collection_name)
62+
print(ret["functions"][0])

pymilvus/client/async_grpc_handler.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,60 @@ async def add_collection_field(
13101310
)
13111311
check_status(status)
13121312

1313+
@retry_on_rpc_failure()
1314+
async def drop_collection_function(
1315+
self,
1316+
collection_name: str,
1317+
function_name: str,
1318+
timeout: Optional[float] = None,
1319+
**kwargs,
1320+
):
1321+
await self.ensure_channel_ready()
1322+
check_pass_param(collection_name=collection_name, timeout=timeout)
1323+
request = Prepare.drop_collection_function_request(collection_name, function_name)
1324+
1325+
status = await self._async_stub.DropCollectionFunction(
1326+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
1327+
)
1328+
check_status(status)
1329+
1330+
@retry_on_rpc_failure()
1331+
async def add_collection_function(
1332+
self,
1333+
collection_name: str,
1334+
function: Function,
1335+
timeout: Optional[float] = None,
1336+
**kwargs,
1337+
):
1338+
await self.ensure_channel_ready()
1339+
check_pass_param(collection_name=collection_name, timeout=timeout)
1340+
request = Prepare.add_collection_function_request(collection_name, function)
1341+
1342+
status = await self._async_stub.AddCollectionFunction(
1343+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
1344+
)
1345+
check_status(status)
1346+
1347+
@retry_on_rpc_failure()
1348+
async def alter_collection_function(
1349+
self,
1350+
collection_name: str,
1351+
function_name: str,
1352+
function: Function,
1353+
timeout: Optional[float] = None,
1354+
**kwargs,
1355+
):
1356+
await self.ensure_channel_ready()
1357+
check_pass_param(collection_name=collection_name, timeout=timeout)
1358+
request = Prepare.alter_collection_function_request(
1359+
collection_name, function_name, function
1360+
)
1361+
1362+
status = await self._async_stub.AlterCollectionFunction(
1363+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
1364+
)
1365+
check_status(status)
1366+
13131367
@retry_on_rpc_failure()
13141368
async def list_indexes(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
13151369
await self.ensure_channel_ready()

pymilvus/client/grpc_handler.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,57 @@ def add_collection_field(
346346
)
347347
check_status(status)
348348

349+
@retry_on_rpc_failure()
350+
def drop_collection_function(
351+
self,
352+
collection_name: str,
353+
function_name: str,
354+
timeout: Optional[float] = None,
355+
**kwargs,
356+
):
357+
check_pass_param(collection_name=collection_name, timeout=timeout)
358+
request = Prepare.drop_collection_function_request(collection_name, function_name)
359+
360+
status = self._stub.DropCollectionFunction(
361+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
362+
)
363+
check_status(status)
364+
365+
@retry_on_rpc_failure()
366+
def add_collection_function(
367+
self,
368+
collection_name: str,
369+
function: Function,
370+
timeout: Optional[float] = None,
371+
**kwargs,
372+
):
373+
check_pass_param(collection_name=collection_name, timeout=timeout)
374+
request = Prepare.add_collection_function_request(collection_name, function)
375+
376+
status = self._stub.AddCollectionFunction(
377+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
378+
)
379+
check_status(status)
380+
381+
@retry_on_rpc_failure()
382+
def alter_collection_function(
383+
self,
384+
collection_name: str,
385+
function_name: str,
386+
function: Function,
387+
timeout: Optional[float] = None,
388+
**kwargs,
389+
):
390+
check_pass_param(collection_name=collection_name, timeout=timeout)
391+
request = Prepare.alter_collection_function_request(
392+
collection_name, function_name, function
393+
)
394+
395+
status = self._stub.AlterCollectionFunction(
396+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
397+
)
398+
check_status(status)
399+
349400
@retry_on_rpc_failure()
350401
def alter_collection_properties(
351402
self, collection_name: str, properties: List, timeout: Optional[float] = None, **kwargs

pymilvus/client/prepare.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -241,16 +241,7 @@ def get_schema_from_collection_schema(
241241
schema.struct_array_fields.append(struct_schema)
242242

243243
for f in fields.functions:
244-
function_schema = schema_types.FunctionSchema(
245-
name=f.name,
246-
description=f.description,
247-
type=f.type,
248-
input_field_names=f.input_field_names,
249-
output_field_names=f.output_field_names,
250-
)
251-
for k, v in f.params.items():
252-
kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
253-
function_schema.params.append(kv_pair)
244+
function_schema = cls.convert_function_to_function_schema(f)
254245
schema.functions.append(function_schema)
255246

256247
return schema
@@ -363,6 +354,34 @@ def get_schema(
363354
def drop_collection_request(cls, collection_name: str) -> milvus_types.DropCollectionRequest:
364355
return milvus_types.DropCollectionRequest(collection_name=collection_name)
365356

357+
@classmethod
358+
def drop_collection_function_request(
359+
cls, collection_name: str, function_name: str
360+
) -> milvus_types.DropCollectionFunctionRequest:
361+
return milvus_types.DropCollectionFunctionRequest(
362+
collection_name=collection_name, function_name=function_name
363+
)
364+
365+
@classmethod
366+
def add_collection_function_request(
367+
cls, collection_name: str, f: Function
368+
) -> milvus_types.AddCollectionFunctionRequest:
369+
function_schema = cls.convert_function_to_function_schema(f)
370+
return milvus_types.AddCollectionFunctionRequest(
371+
collection_name=collection_name, functionSchema=function_schema
372+
)
373+
374+
@classmethod
375+
def alter_collection_function_request(
376+
cls, collection_name: str, function_name: str, f: Function
377+
) -> milvus_types.AlterCollectionFunctionRequest:
378+
function_schema = cls.convert_function_to_function_schema(f)
379+
return milvus_types.AlterCollectionFunctionRequest(
380+
collection_name=collection_name,
381+
function_name=function_name,
382+
functionSchema=function_schema,
383+
)
384+
366385
@classmethod
367386
def add_collection_field_request(
368387
cls,
@@ -2424,3 +2443,17 @@ def update_replicate_configuration_request(
24242443
return milvus_types.UpdateReplicateConfigurationRequest(
24252444
replicate_configuration=replicate_configuration
24262445
)
2446+
2447+
@staticmethod
2448+
def convert_function_to_function_schema(f: Function) -> schema_types.FunctionSchema:
2449+
function_schema = schema_types.FunctionSchema(
2450+
name=f.name,
2451+
description=f.description,
2452+
type=f.type,
2453+
input_field_names=f.input_field_names,
2454+
output_field_names=f.output_field_names,
2455+
)
2456+
for k, v in f.params.items():
2457+
kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
2458+
function_schema.params.append(kv_pair)
2459+
return function_schema

pymilvus/client/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,3 +536,4 @@ def validate_iso_timestamp(s: str) -> bool:
536536
return False
537537
else:
538538
return True
539+

pymilvus/milvus_client/async_milvus_client.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
PrimaryKeyException,
1616
)
1717
from pymilvus.orm import utility
18-
from pymilvus.orm.collection import CollectionSchema
18+
from pymilvus.orm.collection import CollectionSchema, Function
1919
from pymilvus.orm.connections import connections
2020
from pymilvus.orm.schema import FieldSchema
2121
from pymilvus.orm.types import DataType
@@ -683,6 +683,85 @@ async def add_collection_field(
683683
**kwargs,
684684
)
685685

686+
async def add_collection_function(
687+
self, collection_name: str, function: Function, timeout: Optional[float] = None, **kwargs
688+
):
689+
"""Add a new function to the collection.
690+
691+
Args:
692+
collection_name(``string``): The name of collection.
693+
function(``Function``): The function schema.
694+
timeout (``float``, optional): A duration of time in seconds to allow for the RPC.
695+
If timeout is set to None, the client keeps waiting until the server
696+
responds or an error occurs.
697+
**kwargs (``dict``): Optional field params
698+
699+
Raises:
700+
MilvusException: If anything goes wrong
701+
"""
702+
conn = self._get_connection()
703+
await conn.add_collection_function(
704+
collection_name,
705+
function,
706+
timeout=timeout,
707+
**kwargs,
708+
)
709+
710+
async def alter_collection_function(
711+
self,
712+
collection_name: str,
713+
function_name: str,
714+
function: Function,
715+
timeout: Optional[float] = None,
716+
**kwargs,
717+
):
718+
"""Alter a function in the collection.
719+
720+
Args:
721+
collection_name(``string``): The name of collection.
722+
function_name(``string``): The function name that needs to be modified
723+
function(``Function``): The function schema.
724+
timeout (``float``, optional): A duration of time in seconds to allow for the RPC.
725+
If timeout is set to None, the client keeps waiting until the server
726+
responds or an error occurs.
727+
**kwargs (``dict``): Optional field params
728+
729+
Raises:
730+
MilvusException: If anything goes wrong
731+
"""
732+
conn = self._get_connection()
733+
await conn.alter_collection_function(
734+
collection_name,
735+
function_name,
736+
function,
737+
timeout=timeout,
738+
**kwargs,
739+
)
740+
741+
async def drop_collection_function(
742+
self, collection_name: str, function_name: str, timeout: Optional[float] = None, **kwargs
743+
):
744+
"""Drop a function from the collection.
745+
746+
Args:
747+
collection_name(``string``): The name of collection.
748+
function_name(``string``): The function name that needs to be dropped
749+
timeout (``float``, optional): A duration of time in seconds to allow for the RPC.
750+
If timeout is set to None, the client keeps waiting until the server
751+
responds or an error occurs.
752+
**kwargs (``dict``): Optional field params
753+
754+
Raises:
755+
MilvusException: If anything goes wrong
756+
"""
757+
conn = self._get_connection()
758+
await conn.drop_collection_function(
759+
collection_name,
760+
function_name,
761+
timeout=timeout,
762+
**kwargs,
763+
)
764+
686765
@classmethod
687766
def create_schema(cls, **kwargs):
688767
kwargs["check_fields"] = False # do not check fields for now

0 commit comments

Comments
 (0)