Skip to content

Commit a4ff2e4

Browse files
committed
fix: tidy the code
Signed-off-by: yangxuan <xuan.yang@zilliz.com>
1 parent 1425194 commit a4ff2e4

File tree

7 files changed

+203
-1464
lines changed

7 files changed

+203
-1464
lines changed

tests/benchmark/README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ tests/benchmark/
1818
├── conftest.py # Mock gRPC stubs & shared fixtures
1919
├── mock_responses.py # Fake protobuf response builders
2020
├── test_search_bench.py # Search timing benchmarks
21-
├── test_query_bench.py # Query timing benchmarks
22-
├── test_hybrid_bench.py # Hybrid search timing benchmarks
2321
└── scripts/
2422
├── profile_cpu.sh # CPU profiling wrapper
2523
└── profile_memory.sh # Memory profiling wrapper
@@ -41,7 +39,7 @@ pip install -r requirements.txt
4139
pytest tests/benchmark/ --benchmark-only
4240

4341
# Run specific benchmark
44-
pytest tests/benchmark/test_search_bench.py::TestSearchBench::test_search_float32 --benchmark-only
42+
pytest tests/benchmark/test_search_bench.py::TestSearchBench::test_search_float32_varying_output_fields --benchmark-only
4543

4644
# Save baseline for comparison
4745
pytest tests/benchmark/ --benchmark-only --benchmark-save=baseline

tests/benchmark/conftest.py

Lines changed: 55 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,105 +1,85 @@
11
from unittest.mock import MagicMock, patch
2+
23
import pytest
4+
from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusClient, StructFieldSchema
5+
from pymilvus.grpc_gen import common_pb2, milvus_pb2, schema_pb2
36

4-
from pymilvus import MilvusClient
57
from . import mock_responses
6-
from pymilvus.grpc_gen import common_pb2, milvus_pb2
78

89

9-
@pytest.fixture
10-
def mock_search_stub():
11-
def _mock_search(request, timeout=None, metadata=None):
12-
return mock_responses.create_search_results(
13-
num_queries=1,
14-
top_k=10,
15-
output_fields=["id", "age", "score", "name"]
16-
)
17-
return _mock_search
10+
def setup_search_mock(client, mock_fn):
11+
client._get_connection()._stub.Search = MagicMock(side_effect=mock_fn)
1812

1913

20-
@pytest.fixture
21-
def mock_query_stub():
22-
def _mock_query(request, timeout=None, metadata=None):
23-
return mock_responses.create_query_results(
24-
num_rows=100,
25-
output_fields=["id", "age", "score", "name", "active", "metadata"]
26-
)
27-
return _mock_query
14+
def setup_query_mock(client, mock_fn):
15+
client._get_connection()._stub.Query = MagicMock(side_effect=mock_fn)
16+
17+
18+
def setup_hybrid_search_mock(client, mock_fn):
19+
client._get_connection()._stub.HybridSearch = MagicMock(side_effect=mock_fn)
20+
21+
22+
def get_default_test_schema() -> CollectionSchema:
23+
schema = MilvusClient.create_schema()
24+
schema.add_field(field_name='id', datatype=DataType.INT64, is_primary=True)
25+
schema.add_field(field_name='embedding', datatype=DataType.FLOAT_VECTOR, dim=128)
26+
schema.add_field(field_name='name', datatype=DataType.VARCHAR, max_length=100)
27+
schema.add_field(field_name='bool_field', datatype=DataType.BOOL)
28+
schema.add_field(field_name='int8_field', datatype=DataType.INT8)
29+
schema.add_field(field_name='int16_field', datatype=DataType.INT16)
30+
schema.add_field(field_name='int32_field', datatype=DataType.INT32)
31+
schema.add_field(field_name='age', datatype=DataType.INT32)
32+
schema.add_field(field_name='float_field', datatype=DataType.FLOAT)
33+
schema.add_field(field_name='score', datatype=DataType.FLOAT)
34+
schema.add_field(field_name='double_field', datatype=DataType.DOUBLE)
35+
schema.add_field(field_name='varchar_field', datatype=DataType.VARCHAR, max_length=100)
36+
schema.add_field(field_name='json_field', datatype=DataType.JSON)
37+
schema.add_field(field_name='array_field', datatype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=10)
38+
schema.add_field(field_name='geometry_field', datatype=DataType.GEOMETRY)
39+
schema.add_field(field_name='timestamptz_field', datatype=DataType.TIMESTAMPTZ)
40+
schema.add_field(field_name='binary_vector', datatype=DataType.BINARY_VECTOR, dim=128)
41+
schema.add_field(field_name='float16_vector', datatype=DataType.FLOAT16_VECTOR, dim=128)
42+
schema.add_field(field_name='bfloat16_vector', datatype=DataType.BFLOAT16_VECTOR, dim=128)
43+
schema.add_field(field_name='sparse_vector', datatype=DataType.SPARSE_FLOAT_VECTOR)
44+
schema.add_field(field_name='int8_vector', datatype=DataType.INT8_VECTOR, dim=128)
45+
46+
struct_schema = StructFieldSchema()
47+
struct_schema.add_field('struct_int', DataType.INT32)
48+
struct_schema.add_field('struct_str', DataType.VARCHAR, max_length=100)
49+
schema.add_field(field_name='struct_array_field', datatype=DataType.ARRAY, element_type=DataType.STRUCT, struct_schema=struct_schema, max_capacity=10)
50+
return schema
2851

2952

3053
@pytest.fixture
31-
def mocked_milvus_client(mock_search_stub, mock_query_stub):
54+
def mocked_milvus_client():
3255
with patch('grpc.insecure_channel') as mock_channel_func, \
3356
patch('grpc.secure_channel') as mock_secure_channel_func, \
3457
patch('grpc.channel_ready_future') as mock_ready_future, \
3558
patch('pymilvus.grpc_gen.milvus_pb2_grpc.MilvusServiceStub') as mock_stub_class:
36-
59+
3760
mock_channel = MagicMock()
3861
mock_channel_func.return_value = mock_channel
3962
mock_secure_channel_func.return_value = mock_channel
40-
63+
4164
mock_future = MagicMock()
4265
mock_future.result = MagicMock(return_value=None)
4366
mock_ready_future.return_value = mock_future
44-
67+
4568
mock_stub = MagicMock()
46-
47-
69+
70+
4871
mock_connect_response = milvus_pb2.ConnectResponse()
4972
mock_connect_response.status.error_code = common_pb2.ErrorCode.Success
5073
mock_connect_response.status.code = 0
5174
mock_connect_response.identifier = 12345
5275
mock_stub.Connect = MagicMock(return_value=mock_connect_response)
53-
54-
mock_stub.Search = MagicMock(side_effect=mock_search_stub)
55-
mock_stub.Query = MagicMock(side_effect=mock_query_stub)
56-
mock_stub.HybridSearch = MagicMock(side_effect=mock_search_stub)
57-
mock_stub.DescribeCollection = MagicMock(return_value=_create_describe_collection_response())
58-
76+
77+
mock_stub.Search = MagicMock()
78+
mock_stub.Query = MagicMock()
79+
mock_stub.HybridSearch = MagicMock()
80+
5981
mock_stub_class.return_value = mock_stub
60-
61-
client = MilvusClient(uri="http://localhost:19530")
62-
63-
yield client
6482

83+
client = MilvusClient()
6584

66-
def _create_describe_collection_response():
67-
from pymilvus.grpc_gen import milvus_pb2, schema_pb2, common_pb2
68-
69-
response = milvus_pb2.DescribeCollectionResponse()
70-
response.status.error_code = common_pb2.ErrorCode.Success
71-
72-
schema = response.schema
73-
schema.name = "test_collection"
74-
75-
id_field = schema.fields.add()
76-
id_field.fieldID = 1
77-
id_field.name = "id"
78-
id_field.data_type = schema_pb2.DataType.Int64
79-
id_field.is_primary_key = True
80-
81-
embedding_field = schema.fields.add()
82-
embedding_field.fieldID = 2
83-
embedding_field.name = "embedding"
84-
embedding_field.data_type = schema_pb2.DataType.FloatVector
85-
86-
dim_param = embedding_field.type_params.add()
87-
dim_param.key = "dim"
88-
dim_param.value = "128"
89-
90-
age_field = schema.fields.add()
91-
age_field.fieldID = 3
92-
age_field.name = "age"
93-
age_field.data_type = schema_pb2.DataType.Int32
94-
95-
score_field = schema.fields.add()
96-
score_field.fieldID = 4
97-
score_field.name = "score"
98-
score_field.data_type = schema_pb2.DataType.Float
99-
100-
name_field = schema.fields.add()
101-
name_field.fieldID = 5
102-
name_field.name = "name"
103-
name_field.data_type = schema_pb2.DataType.VarChar
104-
105-
return response
85+
yield client

0 commit comments

Comments
 (0)