@@ -1,4 +1,4 @@
-from metrics.verbmem import eval as eval_verbmem
+from metrics.verbmem import eval as eval_verbmem
 from metrics.privleak import eval as eval_privleak
 from metrics.knowmem import eval as eval_knowmem
 from utils import load_model, load_tokenizer, write_csv, read_json, write_json
@@ -29,6 +29,7 @@ def eval_model(
     knowmem_retain_qa_file: str | None = None,
     knowmem_retain_qa_icl_file: str | None = None,
     temp_dir: str | None = None,
+    DEBUG: bool = False,
 ) -> Dict[str, float]:
     # Argument sanity check
     if not metrics:
@@ -50,10 +51,13 @@ def eval_model(

     out = {}
     model = model.to('cuda')
+    debug_subset_len = 3 if DEBUG else None

     # 1. verbmem_f
     if 'verbmem_f' in metrics:
         data = read_json(verbmem_forget_file)
+        if DEBUG:
+            data = data[:debug_subset_len]
         agg, log = eval_verbmem(
             prompts=[d['prompt'] for d in data],
             gts=[d['gt'] for d in data],
@@ -67,10 +71,17 @@ def eval_model(

     # 2. privleak
     if 'privleak' in metrics:
+        forget_data = read_json(privleak_forget_file)
+        retain_data = read_json(privleak_retain_file)
+        holdout_data = read_json(privleak_holdout_file)
+        if DEBUG:
+            forget_data = forget_data[:debug_subset_len]
+            retain_data = retain_data[:debug_subset_len]
+            holdout_data = holdout_data[:debug_subset_len]
         auc, log = eval_privleak(
-            forget_data=read_json(privleak_forget_file),
-            retain_data=read_json(privleak_retain_file),
-            holdout_data=read_json(privleak_holdout_file),
+            forget_data=forget_data,
+            retain_data=retain_data,
+            holdout_data=holdout_data,
             model=model, tokenizer=tokenizer
         )
         if temp_dir is not None:
@@ -82,6 +93,9 @@ def eval_model(
     if 'knowmem_f' in metrics:
         qa = read_json(knowmem_forget_qa_file)
         icl = read_json(knowmem_forget_qa_icl_file)
+        if DEBUG:
+            qa = qa[:debug_subset_len]
+            icl = icl[:debug_subset_len]
         agg, log = eval_knowmem(
             questions=[d['question'] for d in qa],
             answers=[d['answer'] for d in qa],
@@ -99,6 +113,9 @@ def eval_model(
     if 'knowmem_r' in metrics:
         qa = read_json(knowmem_retain_qa_file)
         icl = read_json(knowmem_retain_qa_icl_file)
+        if DEBUG:
+            qa = qa[:debug_subset_len]
+            icl = icl[:debug_subset_len]
         agg, log = eval_knowmem(
             questions=[d['question'] for d in qa],
             answers=[d['answer'] for d in qa],
@@ -122,7 +139,8 @@ def load_then_eval_models(
     tokenizer_dir: str = LLAMA_DIR,
     out_file: str | None = None,
     metrics: List[str] = SUPPORTED_METRICS,
-    temp_dir: str = "temp"
+    temp_dir: str = "temp",
+    DEBUG: bool = False,
 ) -> DataFrame:
     print(out_file)
     # Argument sanity check
@@ -140,7 +158,8 @@ def load_then_eval_models(
         tokenizer = load_tokenizer(tokenizer_dir)
         res = eval_model(
             model, tokenizer, metrics, corpus,
-            temp_dir=os.path.join(temp_dir, name)
+            temp_dir=os.path.join(temp_dir, name),
+            DEBUG=DEBUG
         )
         out.append({'name': name} | res)
         if out_file is not None: write_csv(out, out_file)
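
Taken together, the diff threads a DEBUG flag from load_then_eval_models down into eval_model, so every metric is scored on only the first debug_subset_len (3) records of its input files, which keeps a smoke-test run cheap. A minimal usage sketch, assuming the module is importable as eval and that load_then_eval_models also takes model_dirs, names, and corpus arguments as the loop body above suggests; the paths and corpus name are placeholders, not values from this commit:

    from eval import load_then_eval_models  # assumed module name for the file in this diff

    # Hypothetical quick sanity check: with DEBUG=True each metric sees only 3 examples.
    df = load_then_eval_models(
        model_dirs=["./ckpt/unlearned"],  # placeholder checkpoint directory
        names=["unlearned"],
        corpus="news",                    # placeholder corpus name
        out_file="out/debug_eval.csv",
        DEBUG=True,                       # slices each metric's data to debug_subset_len = 3
    )
    print(df)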