1+ import dataclasses
12import functools
3+ from collections .abc import Mapping
24from typing import Any , Dict , List
35
46from common .db_utils import write_to_db
@@ -43,15 +45,64 @@ def post_process(table_name: str, **kwargs) -> List[Dict[str, Any]]:
4345 return []
4446
4547
48+ def _ensure_list (obj ):
49+ """
50+ Ensure the object is returned as a list.
51+ """
52+ if isinstance (obj , list ):
53+ return obj
54+ if isinstance (obj , (str , bytes , Mapping )):
55+ return [obj ]
56+ if hasattr (obj , "__iter__" ) and not hasattr (obj , "__len__" ): # 如 generator
57+ return list (obj )
58+ return [obj ]
59+
60+
61+ def _to_dict (obj : Any ) -> Dict [str , Any ]:
62+ """
63+ Convert various object types to a dictionary for DB writing.
64+ """
65+ if isinstance (obj , Mapping ):
66+ return dict (obj )
67+ if dataclasses .is_dataclass (obj ):
68+ return dataclasses .asdict (obj )
69+ if hasattr (obj , "_asdict" ): # namedtuple
70+ return obj ._asdict ()
71+ if hasattr (obj , "__dict__" ):
72+ return vars (obj )
73+ raise TypeError (f"Cannot convert { type (obj )} to dict for DB writing" )
74+
75+
76+ def proj_process (table_name : str , ** kwargs ) -> List [Dict [str , Any ]]:
77+ if "_proj" not in kwargs :
78+ return []
79+ name = kwargs .get ("_name" , table_name )
80+ raw_input = kwargs ["_proj" ]
81+ raw_results = _ensure_list (raw_input )
82+
83+ processed_results = []
84+ for result in raw_results :
85+ try :
86+ dict_result = _to_dict (result )
87+ write_to_db (name , dict_result )
88+ processed_results .append (dict_result )
89+ except Exception as e :
90+ raise ValueError (f"Failed to process item in _proj: { e } " ) from e
91+
92+ return processed_results
93+
94+
4695# ---------------- decorator ----------------
4796def export_vars (func ):
4897 @functools .wraps (func )
4998 def wrapper (* args , ** kwargs ):
5099 result = func (* args , ** kwargs )
51- # If the function returns a dict containing '_data' or 'data ', post-process it
100+ # If the function returns a dict containing '_data' or '_proj ', post-process it
52101 if isinstance (result , dict ):
53- if "_data" in result or "data" in result :
102+ if "_data" in result :
54103 return post_process (func .__name__ , ** result )
104+ if "_proj" in result :
105+ return proj_process (func .__name__ , ** result )
55106 # Otherwise return unchanged
56107 return result
57108
@@ -65,33 +116,6 @@ def capture():
65116 return {"name" : "demo" , "_data" : {"accuracy" : 0.1 , "loss" : 0.3 }}
66117
67118
68- @export_vars
69- def capture_list ():
70- """All lists via '_name' + '_data'"""
71- return {
72- "_name" : "demo" ,
73- "_data" : {
74- "accuracy" : [0.1 , 0.2 , 0.3 ],
75- "loss" : [0.1 , 0.2 , 0.3 ],
76- },
77- }
78-
79-
80- @export_vars
81- def capture_mix ():
82- """Mixed single + lists via '_name' + '_data'"""
83- return {
84- "_name" : "demo" ,
85- "_data" : {
86- "length" : 10086 , # single value
87- "accuracy" : [0.1 , 0.2 , 0.3 ], # list
88- "loss" : [0.1 , 0.2 , 0.3 ], # list
89- },
90- }
91-
92-
93119# quick test
94120if __name__ == "__main__" :
95121 print ("capture(): " , capture ())
96- print ("capture_list(): " , capture_list ())
97- print ("capture_mix(): " , capture_mix ())
0 commit comments