Commit 9016609

Support struct field for BulkWriter (#3057)

issue: #3050
Signed-off-by: yhmo <yihua.mo@zilliz.com>

1 parent afa03fb
File tree

8 files changed: +460 −18 lines
Lines changed: 326 additions & 0 deletions
@@ -0,0 +1,326 @@
import datetime
import pytz
import time
import numpy as np
from typing import List

from pymilvus import (
    MilvusClient,
    CollectionSchema, DataType,
)

from pymilvus.bulk_writer import (
    RemoteBulkWriter, LocalBulkWriter,
    BulkFileType,
    bulk_import,
    get_import_progress,
)

# minio
MINIO_ADDRESS = "0.0.0.0:9000"
MINIO_SECRET_KEY = "minioadmin"
MINIO_ACCESS_KEY = "minioadmin"

# milvus
HOST = '127.0.0.1'
PORT = '19530'

COLLECTION_NAME = "for_bulkwriter"
DIM = 16  # must be >= 8
ROW_COUNT = 10

client = MilvusClient(uri="http://localhost:19530", user="root", password="Milvus")
print(client.get_server_version())

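# The helpers below generate one value per row id. Note the two accepted input
# shapes: a binary vector can be a raw 0/1 list or a numpy array packed with
# np.packbits, and a sparse vector can be an {"indices": [...], "values": [...]}
# dict or a plain {index: value} dict; the example alternates both per row.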
def gen_float_vector(i):
    return [i / 4 for _ in range(DIM)]


def gen_binary_vector(i, to_numpy_arr: bool):
    raw_vector = [(i + k) % 2 for k in range(DIM)]
    if to_numpy_arr:
        return np.packbits(raw_vector, axis=-1)
    return raw_vector


def gen_sparse_vector(i, indices_values: bool):
    raw_vector = {}
    dim = 3
    if indices_values:
        raw_vector["indices"] = [i + k for k in range(dim)]
        raw_vector["values"] = [(i + k) / 8 for k in range(dim)]
    else:
        for k in range(dim):
            raw_vector[i + k] = (i + k) / 8
    return raw_vector

60+
schema = MilvusClient.create_schema(enable_dynamic_field=False)
61+
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
62+
schema.add_field(field_name="bool", datatype=DataType.BOOL)
63+
schema.add_field(field_name="int8", datatype=DataType.INT8)
64+
schema.add_field(field_name="int16", datatype=DataType.INT16)
65+
schema.add_field(field_name="int32", datatype=DataType.INT32)
66+
schema.add_field(field_name="int64", datatype=DataType.INT64)
67+
schema.add_field(field_name="float", datatype=DataType.FLOAT)
68+
schema.add_field(field_name="double", datatype=DataType.DOUBLE)
69+
schema.add_field(field_name="varchar", datatype=DataType.VARCHAR, max_length=100)
70+
schema.add_field(field_name="json", datatype=DataType.JSON)
71+
schema.add_field(field_name="timestamp", datatype=DataType.TIMESTAMPTZ)
72+
schema.add_field(field_name="geometry", datatype=DataType.GEOMETRY)
73+
74+
schema.add_field(field_name="array_bool", datatype=DataType.ARRAY, element_type=DataType.BOOL, max_capacity=10)
75+
schema.add_field(field_name="array_int8", datatype=DataType.ARRAY, element_type=DataType.INT8, max_capacity=10)
76+
schema.add_field(field_name="array_int16", datatype=DataType.ARRAY, element_type=DataType.INT16, max_capacity=10)
77+
schema.add_field(field_name="array_int32", datatype=DataType.ARRAY, element_type=DataType.INT32, max_capacity=10)
78+
schema.add_field(field_name="array_int64", datatype=DataType.ARRAY, element_type=DataType.INT64, max_capacity=10)
79+
schema.add_field(field_name="array_float", datatype=DataType.ARRAY, element_type=DataType.FLOAT, max_capacity=10)
80+
schema.add_field(field_name="array_double", datatype=DataType.ARRAY, element_type=DataType.DOUBLE, max_capacity=10)
81+
schema.add_field(field_name="array_varchar", datatype=DataType.ARRAY, element_type=DataType.VARCHAR,
82+
max_capacity=10, max_length=100)
83+
84+
schema.add_field(field_name="float_vector", datatype=DataType.FLOAT_VECTOR, dim=DIM)
85+
schema.add_field(field_name="sparse_vector", datatype=DataType.SPARSE_FLOAT_VECTOR)
86+
schema.add_field(field_name="binary_vector", datatype=DataType.BINARY_VECTOR, dim=DIM)
87+
88+
struct_schema = MilvusClient.create_struct_field_schema()
89+
struct_schema.add_field("struct_bool", DataType.BOOL)
90+
struct_schema.add_field("struct_int8", DataType.INT8)
91+
struct_schema.add_field("struct_int16", DataType.INT16)
92+
struct_schema.add_field("struct_int32", DataType.INT32)
93+
struct_schema.add_field("struct_int64", DataType.INT64)
94+
struct_schema.add_field("struct_float", DataType.FLOAT)
95+
struct_schema.add_field("struct_double", DataType.DOUBLE)
96+
struct_schema.add_field("struct_varchar", DataType.VARCHAR, max_length=100)
97+
struct_schema.add_field("struct_float_vec", DataType.FLOAT_VECTOR, dim=DIM)
98+
schema.add_field("struct_field", datatype=DataType.ARRAY, element_type=DataType.STRUCT,
99+
struct_schema=struct_schema, max_capacity=1000)
100+
schema.verify()
101+
return schema
102+
103+
104+
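# Vector sub-fields inside a struct are indexed individually. The index's
# field_name uses the "struct_name[sub_field_name]" addressing syntax, and the
# metric is a MAX_SIM variant, presumably because each row stores a list of
# struct elements and therefore a list of vectors rather than a single vector.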
def build_collection(schema: CollectionSchema):
    index_params = client.prepare_index_params()
    for field in schema.fields:
        if field.dtype in [DataType.FLOAT_VECTOR, DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR]:
            index_params.add_index(field_name=field.name,
                                   index_type="AUTOINDEX",
                                   metric_type="L2")
        elif field.dtype == DataType.BINARY_VECTOR:
            index_params.add_index(field_name=field.name,
                                   index_type="AUTOINDEX",
                                   metric_type="HAMMING")
        elif field.dtype == DataType.SPARSE_FLOAT_VECTOR:
            index_params.add_index(field_name=field.name,
                                   index_type="SPARSE_INVERTED_INDEX",
                                   metric_type="IP")

    for struct_field in schema.struct_fields:
        for field in struct_field.fields:
            if field.dtype == DataType.FLOAT_VECTOR:
                index_params.add_index(field_name=f"{struct_field.name}[{field.name}]",
                                       index_name=f"{struct_field.name}_{field.name}",
                                       index_type="HNSW",
                                       metric_type="MAX_SIM_COSINE")

    print(f"Drop collection: {COLLECTION_NAME}")
    client.drop_collection(collection_name=COLLECTION_NAME)
    client.create_collection(
        collection_name=COLLECTION_NAME,
        schema=schema,
        index_params=index_params,
        consistency_level="Strong",
    )
    print(f"Collection created: {COLLECTION_NAME}")
    print(client.describe_collection(collection_name=COLLECTION_NAME))

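# gen_row() supplies the struct field as a list of dicts, one dict per struct
# element; here every dict carries all sub-fields declared in the struct schema.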
def gen_row(i):
    shanghai_tz = pytz.timezone("Asia/Shanghai")
    row = {
        "id": i,
        "float_vector": gen_float_vector(i),
        "sparse_vector": gen_sparse_vector(i, i % 2 == 1),
        "binary_vector": gen_binary_vector(i, i % 2 == 1),
        "bool": True,
        "int8": i % 128,
        "int16": i % 32768,
        "int32": i,
        "int64": i,
        "float": i / 4,
        "double": i / 3,
        "varchar": f"varchar_{i}",
        "json": {"dummy": i},
        "timestamp": shanghai_tz.localize(
            datetime.datetime(2025, 1, 1, 0, 0, 0) + datetime.timedelta(days=i)
        ).isoformat(),
        "geometry": f"POINT ({i} {i})",

        "array_bool": [(i + k) % 2 == 0 for k in range(4)],
        "array_int8": [(i + k) % 128 for k in range(4)],
        "array_int16": [(i + k) % 32768 for k in range(4)],
        "array_int32": [(i + k) + 1000 for k in range(4)],
        "array_int64": [(i + k) + 100 for k in range(5)],
        "array_float": [(i + k) / 4 for k in range(5)],
        "array_double": [(i + k) / 3 for k in range(5)],
        "array_varchar": [f"element_{i + k}" for k in range(5)],

        "struct_field": [
            {
                "struct_bool": True,
                "struct_int8": i % 128,
                "struct_int16": i % 32768,
                "struct_int32": i,
                "struct_int64": i,
                "struct_float": i / 4,
                "struct_double": i / 3,
                "struct_varchar": f"aaa_{i}",
                "struct_float_vec": gen_float_vector(i)
            },
            {
                "struct_bool": False,
                "struct_int8": -(i % 128),
                "struct_int16": -(i % 32768),
                "struct_int32": -i,
                "struct_int64": -i,
                "struct_float": -i / 4,
                "struct_double": -i / 3,
                "struct_varchar": f"aaa_{i * 1000}",
                "struct_float_vec": gen_float_vector(i)
            },
        ],
    }
    return row

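# bulk_writer() drives either a LocalBulkWriter or a RemoteBulkWriter: rows go
# in one by one via append_row(), commit() flushes the remaining buffer, and
# batch_files then lists the generated data files.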
def bulk_writer(writer):
    for i in range(ROW_COUNT):
        row = gen_row(i)
        print(row)
        writer.append_row(row)
        if ((i + 1) % 1000 == 0) or (i == ROW_COUNT - 1):
            print(f"{i + 1} rows appended")

    print(f"{writer.total_row_count} rows appended")
    print(f"{writer.buffer_row_count} rows in buffer not flushed")
    writer.commit()
    batch_files = writer.batch_files
    print(f"Writer done! Output files: {batch_files}")
    return batch_files

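# remote_writer() writes the generated files straight into the MinIO bucket
# ("a-bucket" here) that backs this Milvus deployment, so bulk_import() can
# reference them without a separate upload step.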
def remote_writer(schema: CollectionSchema, file_type: BulkFileType):
    print(f"\n===================== remote writer ({file_type.name}) ====================")
    with RemoteBulkWriter(
            schema=schema,
            remote_path="bulk_data",
            connect_param=RemoteBulkWriter.S3ConnectParam(
                endpoint=MINIO_ADDRESS,
                access_key=MINIO_ACCESS_KEY,
                secret_key=MINIO_SECRET_KEY,
                bucket_name="a-bucket",
            ),
            segment_size=512 * 1024 * 1024,
            file_type=file_type,
    ) as writer:
        return bulk_writer(writer)

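# call_bulk_import() submits the generated files as an import job, reads the
# job id from resp.json()['data']['jobId'], then polls get_import_progress()
# every two seconds until the job reports Failed or Completed.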
def call_bulk_import(batch_files: List[List[str]]):
    url = f"http://{HOST}:{PORT}"

    print("\n===================== import files to milvus ====================")
    resp = bulk_import(
        url=url,
        collection_name=COLLECTION_NAME,
        files=batch_files,
    )
    print(resp.json())
    job_id = resp.json()['data']['jobId']
    print(f"Created a bulk_import job, job id: {job_id}")

    while True:
        print("Wait 2 seconds to check the bulk_import job state...")
        time.sleep(2)

        resp = get_import_progress(
            url=url,
            job_id=job_id,
        )

        state = resp.json()['data']['state']
        progress = resp.json()['data']['progress']
        if state == "Importing":
            print(f"The job {job_id} is importing... {progress}%")
            continue
        if state == "Failed":
            reason = resp.json()['data']['reason']
            print(f"The job {job_id} failed, reason: {reason}")
            break
        if state == "Completed" and progress == 100:
            print(f"The job {job_id} completed")
            break

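# local_writer() exercises the same schema through LocalBulkWriter. In
# test_file_type() below, its output only validates local file generation;
# the files actually imported come from the remote writer.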
def local_writer(schema: CollectionSchema, file_type: BulkFileType):
    print(f"\n===================== local writer ({file_type.name}) ====================")
    writer = LocalBulkWriter(
        schema=schema,
        local_path="./" + file_type.name,
        chunk_size=16 * 1024 * 1024,
        file_type=file_type,
    )
    return bulk_writer(writer)

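# verify_imported_data() re-generates each row with gen_row() and compares it
# field by field against the query result, normalizing the binary and sparse
# vector representations first.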
def verify_imported_data():
    # refresh_load() ensures the imported data is loaded
    client.refresh_load(collection_name=COLLECTION_NAME)
    res = client.query(collection_name=COLLECTION_NAME, filter="", output_fields=["count(*)"],
                       consistency_level="Strong")
    print(f'row count: {res[0]["count(*)"]}')
    results = client.query(collection_name=COLLECTION_NAME,
                           filter="id >= 0",
                           output_fields=["*"])
    print("\n===================== query results ====================")
    for item in results:
        print(item)
        id = item["id"]
        original_row = gen_row(id)
        for key in original_row.keys():
            if key not in item:
                raise Exception(f"{key} is missing in the query result")
            if key == "binary_vector":
                # the returned binary vector is wrapped in a list; this is a bug
                original_row[key] = [bytes(gen_binary_vector(id, True).tolist())]
            elif key == "sparse_vector":
                # the returned sparse vector uses the {index: value} dict format
                original_row[key] = gen_sparse_vector(id, False)
            elif key == "timestamp":
                # TODO: compare the timestamp values
                continue
            if item[key] != original_row[key]:
                raise Exception(f"value of {key} is unequal, original value: {original_row[key]}, query value: {item[key]}")
        print(f"Query result of id={id} is correct")

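# test_file_type() runs the full round trip for one file format: build the
# schema, write files locally and remotely, import the remote files into the
# collection, then verify the imported rows.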
def test_file_type(file_type: BulkFileType):
    print(f"\n########################## {file_type.name} ##################################")
    schema = build_schema()
    batch_files = local_writer(schema=schema, file_type=file_type)
    build_collection(schema)
    batch_files = remote_writer(schema=schema, file_type=file_type)
    call_bulk_import(batch_files=batch_files)
    verify_imported_data()


if __name__ == '__main__':
    file_types = [
        BulkFileType.PARQUET,
        BulkFileType.JSON,
        BulkFileType.CSV,
    ]
    for file_type in file_types:
        test_file_type(file_type)

examples/orm_deprecated/bulk_import/example_bulkwriter.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 import numpy as np
 import pandas as pd
 
-from examples.bulk_import.data_gengerator import *
+from examples.orm_deprecated.bulk_import.data_gengerator import *
 
 logging.basicConfig(level=logging.INFO)
 

examples/orm_deprecated/bulk_import/example_bulkwriter_with_nullable.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 import time
 from typing import List
 
-from examples.bulk_import.data_gengerator import *
+from examples.orm_deprecated.bulk_import.data_gengerator import *
 
 logging.basicConfig(level=logging.INFO)
 
