|
1 | 1 | from unittest.mock import MagicMock, patch |
| 2 | + |
2 | 3 | import pytest |
| 4 | +from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient, StructFieldSchema |
| 5 | +from pymilvus.grpc_gen import common_pb2, milvus_pb2, schema_pb2 |
3 | 6 |
|
4 | | -from pymilvus import MilvusClient |
5 | 7 | from . import mock_responses |
6 | | -from pymilvus.grpc_gen import common_pb2, milvus_pb2 |
7 | 8 |
|
8 | 9 |
|
9 | | -@pytest.fixture |
10 | | -def mock_search_stub(): |
11 | | - def _mock_search(request, timeout=None, metadata=None): |
12 | | - return mock_responses.create_search_results( |
13 | | - num_queries=1, |
14 | | - top_k=10, |
15 | | - output_fields=["id", "age", "score", "name"] |
16 | | - ) |
17 | | - return _mock_search |
| 10 | +def setup_search_mock(client, mock_fn): |
| 11 | + client._get_connection()._stub.Search = MagicMock(side_effect=mock_fn) |
18 | 12 |
|
19 | 13 |
|
20 | | -@pytest.fixture |
21 | | -def mock_query_stub(): |
22 | | - def _mock_query(request, timeout=None, metadata=None): |
23 | | - return mock_responses.create_query_results( |
24 | | - num_rows=100, |
25 | | - output_fields=["id", "age", "score", "name", "active", "metadata"] |
26 | | - ) |
27 | | - return _mock_query |
| 14 | +def setup_query_mock(client, mock_fn): |
| 15 | + client._get_connection()._stub.Query = MagicMock(side_effect=mock_fn) |
| 16 | + |
| 17 | + |
| 18 | +def setup_hybrid_search_mock(client, mock_fn): |
| 19 | + client._get_connection()._stub.HybridSearch = MagicMock(side_effect=mock_fn) |
| 20 | + |
| 21 | + |
| 22 | +def get_default_test_schema() -> CollectionSchema: |
| 23 | + schema = MilvusClient.create_schema() |
| 24 | + schema.add_field(field_name='id', datatype=DataType.INT64, is_primary=True) |
| 25 | + schema.add_field(field_name='embedding', datatype=DataType.FLOAT_VECTOR, dim=128) |
| 26 | + schema.add_field(field_name='name', datatype=DataType.VARCHAR, max_length=100) |
| 27 | + schema.add_field(field_name='bool_field', datatype=DataType.BOOL) |
| 28 | + schema.add_field(field_name='int8_field', datatype=DataType.INT8) |
| 29 | + schema.add_field(field_name='int16_field', datatype=DataType.INT16) |
| 30 | + schema.add_field(field_name='int32_field', datatype=DataType.INT32) |
| 31 | + schema.add_field(field_name='age', datatype=DataType.INT32) |
| 32 | + schema.add_field(field_name='float_field', datatype=DataType.FLOAT) |
| 33 | + schema.add_field(field_name='score', datatype=DataType.FLOAT) |
| 34 | + schema.add_field(field_name='double_field', datatype=DataType.DOUBLE) |
| 35 | + schema.add_field(field_name='varchar_field', datatype=DataType.VARCHAR, max_length=100) |
| 36 | + schema.add_field(field_name='json_field', datatype=DataType.JSON) |
| 37 | + schema.add_field(field_name='array_field', datatype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=10) |
| 38 | + schema.add_field(field_name='geometry_field', datatype=DataType.GEOMETRY) |
| 39 | + schema.add_field(field_name='timestamptz_field', datatype=DataType.TIMESTAMPTZ) |
| 40 | + schema.add_field(field_name='binary_vector', datatype=DataType.BINARY_VECTOR, dim=128) |
| 41 | + schema.add_field(field_name='float16_vector', datatype=DataType.FLOAT16_VECTOR, dim=128) |
| 42 | + schema.add_field(field_name='bfloat16_vector', datatype=DataType.BFLOAT16_VECTOR, dim=128) |
| 43 | + schema.add_field(field_name='sparse_vector', datatype=DataType.SPARSE_FLOAT_VECTOR) |
| 44 | + schema.add_field(field_name='int8_vector', datatype=DataType.INT8_VECTOR, dim=128) |
| 45 | + |
| 46 | + struct_schema = StructFieldSchema() |
| 47 | + struct_schema.add_field('struct_int', DataType.INT32) |
| 48 | + struct_schema.add_field('struct_str', DataType.VARCHAR, max_length=100) |
| 49 | + schema.add_field(field_name='struct_array_field', datatype=DataType.ARRAY, element_type=DataType.STRUCT, struct_schema=struct_schema, max_capacity=10) |
| 50 | + return schema |
28 | 51 |
|
29 | 52 |
|
30 | 53 | @pytest.fixture |
31 | | -def mocked_milvus_client(mock_search_stub, mock_query_stub): |
| 54 | +def mocked_milvus_client(): |
32 | 55 | with patch('grpc.insecure_channel') as mock_channel_func, \ |
33 | 56 | patch('grpc.secure_channel') as mock_secure_channel_func, \ |
34 | 57 | patch('grpc.channel_ready_future') as mock_ready_future, \ |
35 | 58 | patch('pymilvus.grpc_gen.milvus_pb2_grpc.MilvusServiceStub') as mock_stub_class: |
36 | | - |
| 59 | + |
37 | 60 | mock_channel = MagicMock() |
38 | 61 | mock_channel_func.return_value = mock_channel |
39 | 62 | mock_secure_channel_func.return_value = mock_channel |
40 | | - |
| 63 | + |
41 | 64 | mock_future = MagicMock() |
42 | 65 | mock_future.result = MagicMock(return_value=None) |
43 | 66 | mock_ready_future.return_value = mock_future |
44 | | - |
| 67 | + |
45 | 68 | mock_stub = MagicMock() |
46 | | - |
47 | | - |
| 69 | + |
| 70 | + |
48 | 71 | mock_connect_response = milvus_pb2.ConnectResponse() |
49 | 72 | mock_connect_response.status.error_code = common_pb2.ErrorCode.Success |
50 | 73 | mock_connect_response.status.code = 0 |
51 | 74 | mock_connect_response.identifier = 12345 |
52 | 75 | mock_stub.Connect = MagicMock(return_value=mock_connect_response) |
53 | | - |
54 | | - mock_stub.Search = MagicMock(side_effect=mock_search_stub) |
55 | | - mock_stub.Query = MagicMock(side_effect=mock_query_stub) |
56 | | - mock_stub.HybridSearch = MagicMock(side_effect=mock_search_stub) |
57 | | - mock_stub.DescribeCollection = MagicMock(return_value=_create_describe_collection_response()) |
58 | | - |
| 76 | + |
| 77 | + mock_stub.Search = MagicMock() |
| 78 | + mock_stub.Query = MagicMock() |
| 79 | + mock_stub.HybridSearch = MagicMock() |
| 80 | + |
59 | 81 | mock_stub_class.return_value = mock_stub |
60 | | - |
61 | | - client = MilvusClient(uri="http://localhost:19530") |
62 | | - |
63 | | - yield client |
64 | 82 |
|
| 83 | + client = MilvusClient() |
65 | 84 |
|
66 | | -def _create_describe_collection_response(): |
67 | | - from pymilvus.grpc_gen import milvus_pb2, schema_pb2, common_pb2 |
68 | | - |
69 | | - response = milvus_pb2.DescribeCollectionResponse() |
70 | | - response.status.error_code = common_pb2.ErrorCode.Success |
71 | | - |
72 | | - schema = response.schema |
73 | | - schema.name = "test_collection" |
74 | | - |
75 | | - id_field = schema.fields.add() |
76 | | - id_field.fieldID = 1 |
77 | | - id_field.name = "id" |
78 | | - id_field.data_type = schema_pb2.DataType.Int64 |
79 | | - id_field.is_primary_key = True |
80 | | - |
81 | | - embedding_field = schema.fields.add() |
82 | | - embedding_field.fieldID = 2 |
83 | | - embedding_field.name = "embedding" |
84 | | - embedding_field.data_type = schema_pb2.DataType.FloatVector |
85 | | - |
86 | | - dim_param = embedding_field.type_params.add() |
87 | | - dim_param.key = "dim" |
88 | | - dim_param.value = "128" |
89 | | - |
90 | | - age_field = schema.fields.add() |
91 | | - age_field.fieldID = 3 |
92 | | - age_field.name = "age" |
93 | | - age_field.data_type = schema_pb2.DataType.Int32 |
94 | | - |
95 | | - score_field = schema.fields.add() |
96 | | - score_field.fieldID = 4 |
97 | | - score_field.name = "score" |
98 | | - score_field.data_type = schema_pb2.DataType.Float |
99 | | - |
100 | | - name_field = schema.fields.add() |
101 | | - name_field.fieldID = 5 |
102 | | - name_field.name = "name" |
103 | | - name_field.data_type = schema_pb2.DataType.VarChar |
104 | | - |
105 | | - return response |
| 85 | + yield client |
0 commit comments