Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 70 additions & 45 deletions immu_django/sql/getters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
from typing import Union

from immu_django.sql.alter import _TableField
from immu_django.sql.models import SQLERROR, SQLModel

Expand Down Expand Up @@ -49,7 +51,9 @@ def _make_order_str(self, order_by: str = None) -> str:
return order_by_str


def _where(self, key: str, org_value: str | int | float, where_str: list[str]):
def _where(self, key: str, org_value: Union[str, int, float], where_str: list[str]):


if len(where_str) == 0:
where_str.append('WHERE ')
else:
Expand All @@ -59,39 +63,48 @@ def _where(self, key: str, org_value: str | int | float, where_str: list[str]):
value = f"'{org_value}'"
else:
value = org_value


origin_key = key
keys = key.split('__', 1)
if len(keys) > 1:
key = keys[0]
match keys[-1]:
case 'not':
return where_str.append(f"{key} <> {value}")
case 'in':
in_str = [f"'{org}'" for org in org_value]
return where_str.append(f'{key} IN ({", ".join(in_str)})')
case 'not_in':
in_str = [f"'{org}'" for org in org_value]
return where_str.append(f'{key} NOT IN ({", ".join(in_str)})')
case 'gt':
return where_str.append(f"{key} > {value}")
case 'gte':
return where_str.append(f"{key} >= {value}")
case 'lt':
return where_str.append(f"{key} < {value}")
case 'lte':
return where_str.append(f"{key} <= {value}")
case 'startswith':
return where_str.append(f"{key} LIKE '^{org_value}'")
case 'endswith':
return where_str.append(f"{key} LIKE '{org_value}$'")
case 'contains':
return where_str.append(f"{key} LIKE '.*{org_value}.*'")
case 'not_contains':
return where_str.append(f"{key} NOT LIKE '.*{org_value}.*'")
case 'regex':
return where_str.append(f"{key} LIKE '{org_value}'")
case _:
raise ValueError('__*** in key not allowed')
if keys[-1] == 'not':
return where_str.append(f"{key} <> {value}")
elif keys[-1] == 'in':
in_str = [f"'{org}'" for org in org_value]
return where_str.append(f'{key} IN ({", ".join(in_str)})')
elif keys[-1] == 'not_in':
in_str = [f"'{org}'" for org in org_value]
return where_str.append(f'{key} NOT IN ({", ".join(in_str)})')
elif keys[-1] == 'gt':
return where_str.append(f"{key} > {value}")
elif keys[-1] == 'gte':
return where_str.append(f"{key} >= {value}")
elif keys[-1] == 'lt':
return where_str.append(f"{key} < {value}")
elif keys[-1] == 'lte':
return where_str.append(f"{key} <= {value}")
elif keys[-1] == 'startswith':
return where_str.append(f"{key} LIKE '^{org_value}'")
elif keys[-1] == 'endswith':
return where_str.append(f"{key} LIKE '{org_value}$'")
elif keys[-1] == 'contains':
return where_str.append(f"{key} LIKE '.*{org_value}.*'")
elif keys[-1] == 'not_contains':
return where_str.append(f"{key} NOT LIKE '.*{org_value}.*'")
elif keys[-1] == 'regex':
return where_str.append(f"{key} LIKE '{org_value}'")
elif keys[-1].endswith('fg'):
"""
这一段代码是因为原本的代码不能处理ImmuForeignKey加上的
"""
foreign_table = keys[-1].split('__')[-2] # 外键表的名称
foreign_key = f"{key}_id"
where_str.insert(0, f"JOIN {foreign_table} ON thisisaforeignkeytablenameoccupy.{origin_key} = {foreign_table}.id ")
# print("WHERE_STR:", where_str)
return where_str.append(f"{foreign_table}.id = {value}")
else:
raise ValueError('__*** in key not allowed')

return where_str.append(f"{key} = {value}")

Expand Down Expand Up @@ -149,13 +162,18 @@ def _make_query(
limit: int = 1_000,
offset: int = 0,
order_by: str = None) -> list[tuple]:
# print("VALUES:\n", values, '\n')
query_str = f'SELECT * FROM {self.table_name} ' \
f'{self._make_time_travel_str(time_travel)} ' \
f'{self._make_where_str(values)} ' \
f'{self._make_order_str(order_by)}' \
f'{self._make_offset_str(limit, offset)}'

f'{self._make_offset_str(limit, offset)}'.replace('thisisaforeignkeytablenameoccupy', self.table_name)



values = self.immu_client.sqlQuery(query_str)

# print("FINAL RESULT:", values)

return values

Expand All @@ -176,18 +194,20 @@ def _get_json_value(self, item: dict, field: str, value: str):

def _get_fg_field(self, fg_fields: dict, field: str, value: str):
fg = str(field).split('__')
unique_id = f"{fg[0]}__{fg[2]}"

if fg[2] not in fg_fields:
fg_fields[fg[2]] = {}
fg_fields[fg[2]]['name'] = fg[0]
fg_fields[fg[2]]['values'] = {}
if unique_id not in fg_fields:
fg_fields[unique_id] = {}
fg_fields[unique_id]['name'] = fg[0]
fg_fields[unique_id]['values'] = {}


fg_fields[fg[2]]['values'][fg[1]] = value
fg_fields[unique_id]['values'][fg[1]] = value


def _get_fg_objs(self, item: dict, fg_fields: dict):
for key, value in fg_fields.items():
table_name = key
table_name = key.split('__')[1]
name = value['name']

getter = GetWhere(self.db, table_name, self.immu_client)
Expand All @@ -204,12 +224,12 @@ def get(
order_by: str = None,
time_travel: dict = None,
limit: int = 1_000, offset: int = 0,
**kwargs) -> list[dict] | dict:
**kwargs) -> Union[list[dict], dict]:
items = []
itens_count = 0

values = self._make_query(kwargs, time_travel, limit, offset, order_by)


# print('values', values)
for value in values:
if itens_count >= size_limit:
break
Expand All @@ -219,13 +239,18 @@ def get(

try:
for field, value in zip(self.table_fields_names, value):
# print("FILED, VALUE:", field, value)
if str(field).startswith('__json__'):
self._get_json_value(item, field, value)
elif str(field).endswith('__fg'):
self._get_fg_field(fg_fields, field, value)
if value:
self._get_fg_field(fg_fields, field, value)
else:
# print('虽然是fg,但是是none')
field = field.split('__')[0]
item[field] = None
else:
item[field] = value

self._get_fg_objs(item, fg_fields)

items.append(item)
Expand All @@ -235,7 +260,7 @@ def get(
itens_count += 1

if len(items) <= 0:
raise Exception('Cant find any itens')
raise Exception('Cant find any items')

if size_limit <= 1:
obj = SQLModel(
Expand Down