Skip to content

Commit aee0bdf

Browse files
authored
fix: Add support for numpy ndarrays in Array fields (#3069) (#3070)
This commit fixes issue #3069 where numpy ndarrays were being rejected for Array fields with an error: "Expected list, got ndarray". Fixes #3069 Signed-off-by: zhuwenxing <wenxing.zhu@zilliz.com>
1 parent ce2640c commit aee0bdf

File tree

3 files changed

+9
-1
lines changed

3 files changed

+9
-1
lines changed

pymilvus/bulk_writer/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
DataType.BFLOAT16_VECTOR.name: lambda x, dim: float16_vector_validator(x, dim, True),
6363
DataType.SPARSE_FLOAT_VECTOR.name: lambda x: sparse_vector_validator(x),
6464
DataType.INT8_VECTOR.name: lambda x, dim: int8_vector_validator(x, dim),
65-
DataType.ARRAY.name: lambda x, cap: isinstance(x, list) and len(x) <= cap,
65+
DataType.ARRAY.name: lambda x, cap: (isinstance(x, (list, np.ndarray)) and len(x) <= cap),
6666
}
6767

6868
NUMPY_TYPE_CREATOR = {

pymilvus/client/entity_helper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ def convert_to_array_arr(objs: List[Any], field_info: Any):
246246

247247

248248
def convert_to_array(obj: List[Any], field_info: Any):
249+
# Convert numpy ndarray to list if needed
250+
if isinstance(obj, np.ndarray):
251+
obj = obj.tolist()
252+
249253
field_data = schema_types.ScalarField()
250254
element_type = field_info.get("element_type", None)
251255
if element_type == DataType.BOOL:

pymilvus/client/prepare.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,10 @@ def _process_struct_field(
543543
struct_sub_field_info: Two-level dict [struct_name][field_name] -> field info
544544
struct_sub_fields_data: Two-level dict [struct_name][field_name] -> FieldData
545545
"""
546+
# Convert numpy ndarray to list if needed
547+
if isinstance(values, np.ndarray):
548+
values = values.tolist()
549+
546550
if not isinstance(values, list):
547551
msg = f"Field '{field_name}': Expected list, got {type(values).__name__}"
548552
raise TypeError(msg)

0 commit comments

Comments
 (0)