Skip to content

Commit 604b7f4

Browse files
jac0626 and silas.jiang authored
fix:async flush() not waiting for segments to be flushed (#3090)
The async flush() method was returning immediately after the RPC call without waiting for the segments to actually be flushed, causing num_entities to return incorrect values (usually 0) right after flush().

This commit:
- Adds an async get_flush_state() method to check flush status
- Adds an async _wait_for_flushed() method to wait for flush completion
- Updates flush() to wait for all segments to be flushed in parallel using asyncio.gather()
- Ensures async flush() behavior matches the synchronous flush()

(cherry picked from commit bdbdc84)

Signed-off-by: silas.jiang <silas.jiang@zilliz.com>
Co-authored-by: silas.jiang <silas.jiang@zilliz.com>
1 parent e40aa7f commit 604b7f4

File tree

2 files changed

+347
-0
lines changed

2 files changed

+347
-0
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1883,13 +1883,77 @@ async def transfer_replica(
18831883
)
18841884
check_status(resp)
18851885

1886+
@retry_on_rpc_failure()
async def get_flush_state(
    self,
    segment_ids: List[int],
    collection_name: str,
    flush_ts: int,
    timeout: Optional[float] = None,
    **kwargs,
) -> bool:
    """Ask the server whether the given segments have been flushed.

    Args:
        segment_ids: IDs of the segments whose flush state is queried.
        collection_name: Name of the collection the request refers to.
        flush_ts: Flush timestamp included in the request.
        timeout: Timeout handed to the ``GetFlushState`` RPC call.
        **kwargs: Forwarded to ``_api_level_md`` to build the call metadata.

    Returns:
        The ``flushed`` field of the ``GetFlushState`` response.
    """
    await self.ensure_channel_ready()

    request = Prepare.get_flush_state_request(segment_ids, collection_name, flush_ts)
    call_metadata = _api_level_md(**kwargs)
    flush_state = await self._async_stub.GetFlushState(
        request, timeout=timeout, metadata=call_metadata
    )

    # The response status is validated before its payload is trusted.
    check_status(flush_state.status)
    return flush_state.flushed
1903+
1904+
async def _wait_for_flushed(
1905+
self,
1906+
segment_ids: List[int],
1907+
collection_name: str,
1908+
flush_ts: int,
1909+
timeout: Optional[float] = None,
1910+
**kwargs,
1911+
):
1912+
"""Wait for segments to be flushed."""
1913+
flush_ret = False
1914+
start = time.time()
1915+
while not flush_ret:
1916+
flush_ret = await self.get_flush_state(
1917+
segment_ids, collection_name, flush_ts, timeout, **kwargs
1918+
)
1919+
end = time.time()
1920+
if timeout is not None and end - start > timeout:
1921+
raise MilvusException(
1922+
message=f"wait for flush timeout, collection: {collection_name}, flusht_ts: {flush_ts}"
1923+
)
1924+
1925+
if not flush_ret:
1926+
await asyncio.sleep(0.5)
1927+
18861928
@retry_on_rpc_failure()
async def flush(self, collection_names: List[str], timeout: Optional[float] = None, **kwargs):
    """Flush the given collections and wait until their segments are flushed.

    After the ``Flush`` RPC succeeds, the segments returned for every
    requested collection are awaited concurrently via ``_wait_for_flushed``,
    so the call only returns once all of them are reported as flushed.

    Args:
        collection_names: Non-empty list of collection names to flush.
        timeout: Timeout passed to the ``Flush`` RPC and, separately, to the
            wait for each collection. ``None`` disables the wait deadline.
        **kwargs: Forwarded to ``_api_level_md`` to build the metadata of the
            ``Flush`` RPC.

    Returns:
        The response of the ``Flush`` RPC.

    Raises:
        ParamError: If ``collection_names`` is ``None``, empty, or not a list.
    """
    if collection_names in (None, []) or not isinstance(collection_names, list):
        raise ParamError(message="Collection name list can not be None or empty")

    check_pass_param(timeout=timeout)
    for name in collection_names:
        check_pass_param(collection_name=name)

    # Same precondition the other RPC methods establish before using the stub;
    # done after validation so bad arguments fail without touching the channel.
    await self.ensure_channel_ready()

    req = Prepare.flush_param(collection_names)
    response = await self._async_stub.Flush(
        req, timeout=timeout, metadata=_api_level_md(**kwargs)
    )
    check_status(response.status)

    # Wait for all collections in parallel. collection_names is guaranteed
    # non-empty by the validation above, so no extra guard is needed.
    await asyncio.gather(
        *(
            self._wait_for_flushed(
                response.coll_segIDs[collection_name].data,
                collection_name,
                response.coll_flush_ts[collection_name],
                timeout=timeout,
            )
            for collection_name in collection_names
        )
    )

    return response
18941958

18951959
@retry_on_rpc_failure()

tests/test_async_flush.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
from unittest.mock import AsyncMock, MagicMock, patch
2+
3+
import pytest
4+
from pymilvus.client.async_grpc_handler import AsyncGrpcHandler
5+
from pymilvus.exceptions import MilvusException
6+
7+
8+
class TestAsyncFlush:
    """Test cases for async flush functionality."""

    # Module whose names are patched; patches must target the place of use.
    _MODULE = "pymilvus.client.async_grpc_handler"

    @staticmethod
    def _make_handler():
        """Build an ``AsyncGrpcHandler`` on a mocked channel with a mocked stub.

        Returns:
            A ``(handler, stub)`` pair, where ``stub`` is the ``AsyncMock``
            installed as ``handler._async_stub``.
        """
        channel = AsyncMock()
        channel.channel_ready = AsyncMock()
        channel.close = AsyncMock()
        channel._unary_unary_interceptors = []

        handler = AsyncGrpcHandler(channel=channel)
        # NOTE(review): presumably lets ensure_channel_ready() skip waiting on
        # the mocked channel - confirm against the handler implementation.
        handler._is_channel_ready = True

        stub = AsyncMock()
        handler._async_stub = stub
        return handler, stub

    @staticmethod
    def _make_ok_status():
        """Return a mock status object carrying success codes and no reason."""
        status = MagicMock()
        status.code = 0
        status.error_code = 0
        status.reason = ""
        return status

    @classmethod
    def _make_flush_response(cls, collections):
        """Build a mock ``Flush`` response.

        Args:
            collections: Mapping of collection name to a
                ``(segment_ids, flush_ts)`` tuple.
        """
        response = MagicMock()
        response.status = cls._make_ok_status()

        seg_ids_by_name = {}
        flush_ts_by_name = {}
        for name, (segment_ids, flush_ts) in collections.items():
            seg_ids = MagicMock()
            seg_ids.data = segment_ids
            seg_ids_by_name[name] = seg_ids
            flush_ts_by_name[name] = flush_ts

        response.coll_segIDs = seg_ids_by_name
        response.coll_flush_ts = flush_ts_by_name
        return response

    @pytest.mark.asyncio
    async def test_flush_waits_for_segments_to_be_flushed(self) -> None:
        """
        Test that async flush() waits for all segments to be flushed before returning.

        This test verifies the fix for the bug where async flush() would return
        immediately after the RPC call, without waiting for segments to be actually flushed.
        """
        handler, stub = self._make_handler()
        flush_response = self._make_flush_response({"test_collection": ([1, 2, 3], 12345)})
        stub.Flush = AsyncMock(return_value=flush_response)

        # First poll: not flushed yet; second poll: flushed.
        handler.get_flush_state = AsyncMock(side_effect=[False, True])

        # asyncio.sleep is patched so the poll interval does not slow the test.
        with patch(f"{self._MODULE}.Prepare") as mock_prepare, \
                patch(f"{self._MODULE}.check_pass_param"), \
                patch(f"{self._MODULE}.check_status"), \
                patch(f"{self._MODULE}._api_level_md", return_value={}), \
                patch(f"{self._MODULE}.asyncio.sleep", new_callable=AsyncMock):
            mock_prepare.flush_param.return_value = MagicMock()

            result = await handler.flush(["test_collection"], timeout=10)

        # Verify Flush RPC was called
        stub.Flush.assert_called_once()

        # flush() must keep polling until the state flips to flushed:
        # exactly one "not flushed" poll followed by one "flushed" poll.
        assert handler.get_flush_state.call_count == 2, \
            "get_flush_state should be polled until segments are flushed"

        # Verify the response is returned
        assert result is flush_response

    @pytest.mark.asyncio
    async def test_flush_waits_for_multiple_collections(self) -> None:
        """
        Test that async flush() waits for all collections' segments to be flushed.
        """
        handler, stub = self._make_handler()
        flush_response = self._make_flush_response(
            {
                "collection1": ([1, 2], 12345),
                "collection2": ([3, 4], 12346),
            }
        )
        stub.Flush = AsyncMock(return_value=flush_response)

        # Track which collections were checked
        checked_collections = []

        async def fake_get_flush_state(
            segment_ids, collection_name, flush_ts, timeout=None, **kwargs
        ):
            checked_collections.append(collection_name)
            return True  # Always flushed already

        handler.get_flush_state = AsyncMock(side_effect=fake_get_flush_state)

        with patch(f"{self._MODULE}.Prepare") as mock_prepare, \
                patch(f"{self._MODULE}.check_pass_param"), \
                patch(f"{self._MODULE}.check_status"), \
                patch(f"{self._MODULE}._api_level_md", return_value={}):
            mock_prepare.flush_param.return_value = MagicMock()

            await handler.flush(["collection1", "collection2"], timeout=10)

        # Verify both collections were checked, once each
        assert "collection1" in checked_collections, \
            "collection1 should be checked for flush completion"
        assert "collection2" in checked_collections, \
            "collection2 should be checked for flush completion"
        assert handler.get_flush_state.call_count == 2, \
            "get_flush_state should be called for each collection"

    @pytest.mark.asyncio
    async def test_flush_timeout(self) -> None:
        """
        Test that async flush() raises timeout exception when segments don't flush in time.
        """
        handler, stub = self._make_handler()
        flush_response = self._make_flush_response({"test_collection": ([1, 2, 3], 12345)})
        stub.Flush = AsyncMock(return_value=flush_response)

        # get_flush_state always reports "not flushed"
        handler.get_flush_state = AsyncMock(return_value=False)

        # Fake clock: every reading is 0.6s later than the previous one, so any
        # two consecutive readings are further apart than the 0.5s timeout.
        current_time = 1000.0

        def fake_time():
            nonlocal current_time
            current_time += 0.6
            return current_time

        with patch(f"{self._MODULE}.Prepare") as mock_prepare, \
                patch(f"{self._MODULE}.check_pass_param"), \
                patch(f"{self._MODULE}.check_status"), \
                patch(f"{self._MODULE}._api_level_md", return_value={}), \
                patch(f"{self._MODULE}.time.time", side_effect=fake_time):
            mock_prepare.flush_param.return_value = MagicMock()

            # Call flush with short timeout
            with pytest.raises(MilvusException) as exc_info:
                await handler.flush(["test_collection"], timeout=0.5)

        # Verify timeout exception was raised
        assert "wait for flush timeout" in str(exc_info.value).lower(), \
            "Should raise timeout exception when flush takes too long"

    @pytest.mark.asyncio
    async def test_flush_parameter_validation(self) -> None:
        """
        Test that async flush() validates parameters correctly.
        """
        # Imported locally so this test does not depend on the module header.
        from pymilvus.exceptions import ParamError

        handler, _ = self._make_handler()

        # Empty list, None, and a non-list value must all be rejected with
        # ParamError specifically, not with an arbitrary exception.
        for invalid_names in ([], None, "not_a_list"):
            with pytest.raises(ParamError):
                await handler.flush(invalid_names)  # type: ignore

    @pytest.mark.asyncio
    async def test_get_flush_state(self) -> None:
        """
        Test the get_flush_state() method.
        """
        handler, stub = self._make_handler()

        # Create mock GetFlushState response
        state_response = MagicMock()
        state_response.status = self._make_ok_status()
        state_response.flushed = True
        stub.GetFlushState = AsyncMock(return_value=state_response)

        with patch(f"{self._MODULE}.Prepare") as mock_prepare, \
                patch(f"{self._MODULE}.check_status"), \
                patch(f"{self._MODULE}._api_level_md", return_value={}):
            mock_prepare.get_flush_state_request.return_value = MagicMock()

            result = await handler.get_flush_state(
                segment_ids=[1, 2, 3],
                collection_name="test_collection",
                flush_ts=12345,
                timeout=10,
            )

        # Verify GetFlushState RPC was called
        stub.GetFlushState.assert_called_once()

        # Verify result
        assert result is True, "get_flush_state should return the flushed status"
283+

0 commit comments

Comments
 (0)