Skip to content

Commit 9394057

Browse files
committed
Allowing multiple load and apply statements
1 parent a03bc04 commit 9394057

File tree

3 files changed

+73
-42
lines changed

3 files changed

+73
-42
lines changed

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
# image tag 8.0-RC2-pre is the one matching the 8.0 GA release
33
x-client-libs-stack-image: &client-libs-stack-image
4-
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.2}"
4+
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_STACK_IMAGE_TAG:-8.4-RC1-pre.2}"
55

66
x-client-libs-image: &client-libs-image
77
image: "redislabs/client-libs-test:${CLIENT_LIBS_TEST_IMAGE_TAG:-8.4-RC1-pre.2}"

redis/commands/search/hybrid_query.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -228,18 +228,21 @@ def __init__(self) -> None:
228228
"""
229229
Create a new hybrid post processing configuration object.
230230
"""
231-
self._load_fields = []
232-
self._groupby = []
233-
self._apply = []
231+
self._load_statements = []
232+
self._apply_statements = []
233+
self._groupby_statements = []
234234
self._sortby_fields = []
235235
self._filter = None
236236
self._limit = None
237237

238238
def load(self, *fields: str) -> Self:
239239
"""
240-
Add load parameters to the query.
240+
Add load statement parameters to the query.
241241
"""
242-
self._load_fields = fields
242+
if fields:
243+
fields_str = " ".join(fields)
244+
fields_list = fields_str.split(" ")
245+
self._load_statements.extend(("LOAD", len(fields_list), *fields_list))
243246
return self
244247

245248
def group_by(self, fields: List[str], *reducers: Reducer) -> Self:
@@ -262,7 +265,7 @@ def group_by(self, fields: List[str], *reducers: Reducer) -> Self:
262265
if reducer._alias is not None:
263266
ret.extend(("AS", reducer._alias))
264267

265-
self._groupby.extend(ret)
268+
self._groupby_statements.extend(ret)
266269
return self
267270

268271
def apply(self, **kwexpr) -> Self:
@@ -274,11 +277,14 @@ def apply(self, **kwexpr) -> Self:
274277
the alias for the projection, and the value is the projection
275278
expression itself, for example `apply(square_root="sqrt(@foo)")`.
276279
"""
280+
apply_args = []
277281
for alias, expr in kwexpr.items():
278282
ret = ["APPLY", expr]
279283
if alias is not None:
280284
ret.extend(("AS", alias))
281-
self._apply.extend(ret)
285+
apply_args.extend(ret)
286+
287+
self._apply_statements.extend(apply_args)
282288

283289
return self
284290

@@ -310,14 +316,12 @@ def limit(self, offset: int, num: int) -> Self:
310316

311317
def build_args(self) -> List[str]:
312318
args = []
313-
if self._load_fields:
314-
fields_str = " ".join(self._load_fields)
315-
fields = fields_str.split(" ")
316-
args.extend(("LOAD", len(fields), *fields))
317-
if self._groupby:
318-
args.extend(self._groupby)
319-
if self._apply:
320-
args.extend(self._apply)
319+
if self._load_statements:
320+
args.extend(self._load_statements)
321+
if self._groupby_statements:
322+
args.extend(self._groupby_statements)
323+
if self._apply_statements:
324+
args.extend(self._apply_statements)
321325
if self._sortby_fields:
322326
sortby_args = []
323327
for f in self._sortby_fields:

tests/test_search.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4095,32 +4095,6 @@ def compare_list_of_dicts(actual, expected):
40954095
f"All expected:{expected}"
40964096
)
40974097

4098-
@pytest.mark.redismod
4099-
@skip_if_server_version_lt("8.3.224")
4100-
def test_review_feedback_hybrid_search(self, client):
4101-
# Create index and add data
4102-
self._create_hybrid_search_index(client)
4103-
self._add_data_for_hybrid_search(client, items_sets=5)
4104-
4105-
# set search query
4106-
search_query = HybridSearchQuery("@color:{red} @color:{green}")
4107-
search_query.scorer("TFIDF")
4108-
4109-
vsim_query = HybridVsimQuery(
4110-
vector_field_name="@embedding",
4111-
vector_data=np.array([-100, -200, -200, -300], dtype=np.float32).tobytes(),
4112-
)
4113-
4114-
hybrid_query = HybridQuery(search_query, vsim_query)
4115-
4116-
res = client.ft().hybrid_search(query=hybrid_query)
4117-
if is_resp2_connection(client):
4118-
assert len(res.results) > 0
4119-
assert res.warnings == []
4120-
else:
4121-
assert len(res["results"]) > 0
4122-
assert res["warnings"] == []
4123-
41244098
@pytest.mark.redismod
41254099
@skip_if_server_version_lt("8.3.224")
41264100
def test_basic_hybrid_search(self, client):
@@ -5289,3 +5263,56 @@ def test_hybrid_search_query_with_cursor(self, client):
52895263
assert len(vsim_res_from_cursor.rows) == 5
52905264
else:
52915265
assert len(vsim_res_from_cursor[0]["results"]) == 5
5266+
5267+
@pytest.mark.redismod
5268+
@skip_if_server_version_lt("8.3.224")
5269+
def test_hybrid_search_query_with_multiple_loads_and_applies(self, client):
5270+
# Create index and add data
5271+
self._create_hybrid_search_index(client)
5272+
self._add_data_for_hybrid_search(client, items_sets=1)
5273+
5274+
# set search query
5275+
search_query = HybridSearchQuery("@color:{red|green}")
5276+
5277+
vsim_query = HybridVsimQuery(
5278+
vector_field_name="@embedding",
5279+
vector_data=np.array([1, 2, 7, 6], dtype=np.float32).tobytes(),
5280+
)
5281+
5282+
hybrid_query = HybridQuery(search_query, vsim_query)
5283+
5284+
postprocessing_config = HybridPostProcessingConfig()
5285+
postprocessing_config.load("@color", "@price")
5286+
postprocessing_config.load("@description")
5287+
postprocessing_config.apply(discount_10_percents="@price - (@price * 0.1)")
5288+
postprocessing_config.apply(
5289+
additional_discount="@discount_10_percents - (@discount_10_percents * 0.1)"
5290+
)
5291+
postprocessing_config.filter(HybridFilter('@price=="15"'))
5292+
postprocessing_config.load("@description")
5293+
postprocessing_config.sort_by(
5294+
SortbyField("@discount_10_percents", asc=False),
5295+
SortbyField("@color", asc=True),
5296+
)
5297+
postprocessing_config.limit(0, 5)
5298+
5299+
res = client.ft().hybrid_search(
5300+
query=hybrid_query, post_processing=postprocessing_config, timeout=10
5301+
)
5302+
print(res)
5303+
if is_resp2_connection(client):
5304+
assert len(res.results) == 2
5305+
for item in res.results:
5306+
assert item["color"] is not None
5307+
assert item["price"] is not None
5308+
assert item["description"] is not None
5309+
assert item["discount_10_percents"] is not None
5310+
assert item["additional_discount"] is not None
5311+
else:
5312+
assert len(res["results"]) == 2
5313+
for item in res["results"]:
5314+
assert item["color"] is not None
5315+
assert item["price"] is not None
5316+
assert item["description"] is not None
5317+
assert item["discount_10_percents"] is not None
5318+
assert item["additional_discount"] is not None

0 commit comments

Comments
 (0)