Skip to content

Commit bd77529

Browse files
committed
feat(test): Add PostgreSQL support and optimize database write logic
1 parent 010844e commit bd77529

File tree

8 files changed

+227
-184
lines changed

8 files changed

+227
-184
lines changed

test/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ result_outputs/
55
results/
66
.cache/
77
backup/
8+
sites/Demo/*
89
$null
910
*__pycache__/
1011
.*

test/common/capture_utils.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import dataclasses
12
import functools
3+
from collections.abc import Mapping
24
from typing import Any, Dict, List
35

46
from common.db_utils import write_to_db
@@ -43,15 +45,64 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
4345
return []
4446

4547

48+
def _ensure_list(obj):
49+
"""
50+
Ensure the object is returned as a list.
51+
"""
52+
if isinstance(obj, list):
53+
return obj
54+
if isinstance(obj, (str, bytes, Mapping)):
55+
return [obj]
56+
if hasattr(obj, "__iter__") and not hasattr(obj, "__len__"): # 如 generator
57+
return list(obj)
58+
return [obj]
59+
60+
61+
def _to_dict(obj: Any) -> Dict[str, Any]:
62+
"""
63+
Convert various object types to a dictionary for DB writing.
64+
"""
65+
if isinstance(obj, Mapping):
66+
return dict(obj)
67+
if dataclasses.is_dataclass(obj):
68+
return dataclasses.asdict(obj)
69+
if hasattr(obj, "_asdict"): # namedtuple
70+
return obj._asdict()
71+
if hasattr(obj, "__dict__"):
72+
return vars(obj)
73+
raise TypeError(f"Cannot convert {type(obj)} to dict for DB writing")
74+
75+
76+
def proj_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
77+
if "_proj" not in kwargs:
78+
return []
79+
name = kwargs.get("_name", table_name)
80+
raw_input = kwargs["_proj"]
81+
raw_results = _ensure_list(raw_input)
82+
83+
processed_results = []
84+
for result in raw_results:
85+
try:
86+
dict_result = _to_dict(result)
87+
write_to_db(name, dict_result)
88+
processed_results.append(dict_result)
89+
except Exception as e:
90+
raise ValueError(f"Failed to process item in _proj: {e}") from e
91+
92+
return processed_results
93+
94+
4695
# ---------------- decorator ----------------
4796
def export_vars(func):
4897
@functools.wraps(func)
4998
def wrapper(*args, **kwargs):
5099
result = func(*args, **kwargs)
51-
# If the function returns a dict containing '_data' or 'data', post-process it
100+
# If the function returns a dict containing '_data' or '_proj', post-process it
52101
if isinstance(result, dict):
53-
if "_data" in result or "data" in result:
102+
if "_data" in result:
54103
return post_process(func.__name__, **result)
104+
if "_proj" in result:
105+
return proj_process(func.__name__, **result)
55106
# Otherwise return unchanged
56107
return result
57108

@@ -65,33 +116,6 @@ def capture():
65116
return {"name": "demo", "_data": {"accuracy": 0.1, "loss": 0.3}}
66117

67118

68-
@export_vars
69-
def capture_list():
70-
"""All lists via '_name' + '_data'"""
71-
return {
72-
"_name": "demo",
73-
"_data": {
74-
"accuracy": [0.1, 0.2, 0.3],
75-
"loss": [0.1, 0.2, 0.3],
76-
},
77-
}
78-
79-
80-
@export_vars
81-
def capture_mix():
82-
"""Mixed single + lists via '_name' + '_data'"""
83-
return {
84-
"_name": "demo",
85-
"_data": {
86-
"length": 10086, # single value
87-
"accuracy": [0.1, 0.2, 0.3], # list
88-
"loss": [0.1, 0.2, 0.3], # list
89-
},
90-
}
91-
92-
93119
# quick test
94120
if __name__ == "__main__":
95121
print("capture(): ", capture())
96-
print("capture_list(): ", capture_list())
97-
print("capture_mix(): ", capture_mix())

0 commit comments

Comments
 (0)