From 09c0ae5a99ee02a93440551ec68bc7015beb177f Mon Sep 17 00:00:00 2001 From: nalalana <3392284556@qq.com> Date: Thu, 15 May 2025 11:30:30 +0800 Subject: [PATCH] Fixed bug where SQL cannot filter ImmuForeignKey condition --- immu_django/sql/getters.py | 115 ++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 45 deletions(-) diff --git a/immu_django/sql/getters.py b/immu_django/sql/getters.py index ebfa950..13887f1 100644 --- a/immu_django/sql/getters.py +++ b/immu_django/sql/getters.py @@ -1,4 +1,6 @@ import json +from typing import Union + from immu_django.sql.alter import _TableField from immu_django.sql.models import SQLERROR, SQLModel @@ -49,7 +51,9 @@ def _make_order_str(self, order_by: str = None) -> str: return order_by_str - def _where(self, key: str, org_value: str | int | float, where_str: list[str]): + def _where(self, key: str, org_value: Union[str, int, float], where_str: list[str]): + + if len(where_str) == 0: where_str.append('WHERE ') else: @@ -59,39 +63,48 @@ def _where(self, key: str, org_value: str | int | float, where_str: list[str]): value = f"'{org_value}'" else: value = org_value - + + origin_key = key keys = key.split('__', 1) if len(keys) > 1: key = keys[0] - match keys[-1]: - case 'not': - return where_str.append(f"{key} <> {value}") - case 'in': - in_str = [f"'{org}'" for org in org_value] - return where_str.append(f'{key} IN ({", ".join(in_str)})') - case 'not_in': - in_str = [f"'{org}'" for org in org_value] - return where_str.append(f'{key} NOT IN ({", ".join(in_str)})') - case 'gt': - return where_str.append(f"{key} > {value}") - case 'gte': - return where_str.append(f"{key} >= {value}") - case 'lt': - return where_str.append(f"{key} < {value}") - case 'lte': - return where_str.append(f"{key} <= {value}") - case 'startswith': - return where_str.append(f"{key} LIKE '^{org_value}'") - case 'endswith': - return where_str.append(f"{key} LIKE '{org_value}$'") - case 'contains': - return where_str.append(f"{key} LIKE '.*{org_value}.*'") - case 'not_contains': - return where_str.append(f"{key} NOT LIKE '.*{org_value}.*'") - case 'regex': - return where_str.append(f"{key} LIKE '{org_value}'") - case _: - raise ValueError('__*** in key not allowed') + if keys[-1] == 'not': + return where_str.append(f"{key} <> {value}") + elif keys[-1] == 'in': + in_str = [f"'{org}'" for org in org_value] + return where_str.append(f'{key} IN ({", ".join(in_str)})') + elif keys[-1] == 'not_in': + in_str = [f"'{org}'" for org in org_value] + return where_str.append(f'{key} NOT IN ({", ".join(in_str)})') + elif keys[-1] == 'gt': + return where_str.append(f"{key} > {value}") + elif keys[-1] == 'gte': + return where_str.append(f"{key} >= {value}") + elif keys[-1] == 'lt': + return where_str.append(f"{key} < {value}") + elif keys[-1] == 'lte': + return where_str.append(f"{key} <= {value}") + elif keys[-1] == 'startswith': + return where_str.append(f"{key} LIKE '^{org_value}'") + elif keys[-1] == 'endswith': + return where_str.append(f"{key} LIKE '{org_value}$'") + elif keys[-1] == 'contains': + return where_str.append(f"{key} LIKE '.*{org_value}.*'") + elif keys[-1] == 'not_contains': + return where_str.append(f"{key} NOT LIKE '.*{org_value}.*'") + elif keys[-1] == 'regex': + return where_str.append(f"{key} LIKE '{org_value}'") + elif keys[-1].endswith('fg'): + """ + 这一段代码是因为原本的代码不能处理ImmuForeignKey加上的 + """ + foreign_table = keys[-1].split('__')[-2] # 外键表的名称 + foreign_key = f"{key}_id" + where_str.insert(0, f"JOIN {foreign_table} ON thisisaforeignkeytablenameoccupy.{origin_key} = {foreign_table}.id ") + # print("WHERE_STR:", where_str) + return where_str.append(f"{foreign_table}.id = {value}") + else: + raise ValueError('__*** in key not allowed') return where_str.append(f"{key} = {value}") @@ -149,13 +162,18 @@ def _make_query( limit: int = 1_000, offset: int = 0, order_by: str = None) -> list[tuple]: + # print("VALUES:\n", values, '\n') query_str = f'SELECT * FROM {self.table_name} ' \ f'{self._make_time_travel_str(time_travel)} ' \ f'{self._make_where_str(values)} ' \ f'{self._make_order_str(order_by)}' \ - f'{self._make_offset_str(limit, offset)}' - + f'{self._make_offset_str(limit, offset)}'.replace('thisisaforeignkeytablenameoccupy', self.table_name) + + + values = self.immu_client.sqlQuery(query_str) + + # print("FINAL RESULT:", values) return values @@ -176,18 +194,20 @@ def _get_json_value(self, item: dict, field: str, value: str): def _get_fg_field(self, fg_fields: dict, field: str, value: str): fg = str(field).split('__') + unique_id = f"{fg[0]}__{fg[2]}" - if fg[2] not in fg_fields: - fg_fields[fg[2]] = {} - fg_fields[fg[2]]['name'] = fg[0] - fg_fields[fg[2]]['values'] = {} + if unique_id not in fg_fields: + fg_fields[unique_id] = {} + fg_fields[unique_id]['name'] = fg[0] + fg_fields[unique_id]['values'] = {} + - fg_fields[fg[2]]['values'][fg[1]] = value + fg_fields[unique_id]['values'][fg[1]] = value def _get_fg_objs(self, item: dict, fg_fields: dict): for key, value in fg_fields.items(): - table_name = key + table_name = key.split('__')[1] name = value['name'] getter = GetWhere(self.db, table_name, self.immu_client) @@ -204,12 +224,12 @@ def get( order_by: str = None, time_travel: dict = None, limit: int = 1_000, offset: int = 0, - **kwargs) -> list[dict] | dict: + **kwargs) -> Union[list[dict], dict]: items = [] itens_count = 0 - values = self._make_query(kwargs, time_travel, limit, offset, order_by) - + + # print('values', values) for value in values: if itens_count >= size_limit: break @@ -219,13 +239,18 @@ def get( try: for field, value in zip(self.table_fields_names, value): + # print("FILED, VALUE:", field, value) if str(field).startswith('__json__'): self._get_json_value(item, field, value) elif str(field).endswith('__fg'): - self._get_fg_field(fg_fields, field, value) + if value: + self._get_fg_field(fg_fields, field, value) + else: + # print('虽然是fg,但是是none') + field = field.split('__')[0] + item[field] = None else: item[field] = value - self._get_fg_objs(item, fg_fields) items.append(item) @@ -235,7 +260,7 @@ def get( itens_count += 1 if len(items) <= 0: - raise Exception('Cant find any itens') + raise Exception('Cant find any items') if size_limit <= 1: obj = SQLModel(