Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ result_outputs/
results/
.cache/
backup/
sites/Demo/*
$null
*__pycache__/
.*
Expand Down
82 changes: 53 additions & 29 deletions test/common/capture_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import dataclasses
import functools
from collections.abc import Mapping
from typing import Any, Dict, List

from common.db_utils import write_to_db
Expand Down Expand Up @@ -43,15 +45,64 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
return []


def _ensure_list(obj):
"""
Ensure the object is returned as a list.
"""
if isinstance(obj, list):
return obj
if isinstance(obj, (str, bytes, Mapping)):
return [obj]
if hasattr(obj, "__iter__") and not hasattr(obj, "__len__"): # 如 generator
return list(obj)
return [obj]


def _to_dict(obj: Any) -> Dict[str, Any]:
"""
Convert various object types to a dictionary for DB writing.
"""
if isinstance(obj, Mapping):
return dict(obj)
if dataclasses.is_dataclass(obj):
return dataclasses.asdict(obj)
if hasattr(obj, "_asdict"): # namedtuple
return obj._asdict()
if hasattr(obj, "__dict__"):
return vars(obj)
raise TypeError(f"Cannot convert {type(obj)} to dict for DB writing")


def proj_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
if "_proj" not in kwargs:
return []
name = kwargs.get("_name", table_name)
raw_input = kwargs["_proj"]
raw_results = _ensure_list(raw_input)

processed_results = []
for result in raw_results:
try:
dict_result = _to_dict(result)
write_to_db(name, dict_result)
processed_results.append(dict_result)
except Exception as e:
raise ValueError(f"Failed to process item in _proj: {e}") from e

return processed_results


# ---------------- decorator ----------------
def export_vars(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
# If the function returns a dict containing '_data' or 'data', post-process it
# If the function returns a dict containing '_data' or '_proj', post-process it
if isinstance(result, dict):
if "_data" in result or "data" in result:
if "_data" in result:
return post_process(func.__name__, **result)
if "_proj" in result:
return proj_process(func.__name__, **result)
# Otherwise return unchanged
return result

Expand All @@ -65,33 +116,6 @@ def capture():
return {"name": "demo", "_data": {"accuracy": 0.1, "loss": 0.3}}


@export_vars
def capture_list():
"""All lists via '_name' + '_data'"""
return {
"_name": "demo",
"_data": {
"accuracy": [0.1, 0.2, 0.3],
"loss": [0.1, 0.2, 0.3],
},
}


@export_vars
def capture_mix():
"""Mixed single + lists via '_name' + '_data'"""
return {
"_name": "demo",
"_data": {
"length": 10086, # single value
"accuracy": [0.1, 0.2, 0.3], # list
"loss": [0.1, 0.2, 0.3], # list
},
}


# quick test
if __name__ == "__main__":
print("capture(): ", capture())
print("capture_list(): ", capture_list())
print("capture_mix(): ", capture_mix())
Loading