99from typing import List , Dict , Literal
1010from pandas import DataFrame
1111
12- import importlib .util
13- import os
12+ import sys
13+
14+ curr_dir = os .path .dirname (__file__ )
15+ PROJECT_DIR = os .path .abspath (os .path .join (curr_dir , '..' , '..' ))
16+ Unlearn_Simple_DIR = os .path .join (PROJECT_DIR , 'Unlearn-Simple' )
17+ MUSE_DIR = os .path .join (Unlearn_Simple_DIR , 'MUSE' )
18+
19+ sys .path .append (os .path .join (PROJECT_DIR , 'src' ))
20+
21+ # Debug aid (disabled): print the sys.path entries used for importing
22+ # print("Current sys.path:")
23+ # for path in sys.path:
24+ # print(path)
25+ # sys.path.append(os.path.join(PROJECT_DIR, 'src'))
26+ import input_loss_landscape .utils as input_loss_landscape_utils
1427
15- input_loss_landscape_utils_path = os .path .abspath (os .path .join (os .getcwd (), '..' , '..' , 'src' , 'input_loss_landscape' , 'utils.py' ))
16- spec = importlib .util .spec_from_file_location ("input_loss_landscape_utils" , input_loss_landscape_utils_path )
17- input_loss_landscape_utils = importlib .util .module_from_spec (spec )
18- spec .loader .exec_module (input_loss_landscape_utils )
28+ input_loss_landscape_eval = input_loss_landscape_utils .input_loss_landscape_eval
1929
2030input_loss_landscape_eval = input_loss_landscape_utils .input_loss_landscape_eval
21- print (f"Current working directory: { os .getcwd ()} " ) # Ensure the current working directory is set correctly
2231
2332
2433
@@ -45,6 +54,7 @@ def eval_model(
4554 knowmem_retain_qa_icl_file : str | None = None ,
4655 temp_dir : str | None = None ,
4756 DEBUG : bool = False ,
57+ kwargs : dict = {},
4858): # -> Dict[str, float]:
4959 # Argument sanity check
5060 if not metrics :
@@ -56,6 +66,7 @@ def eval_model(
5666 raise ValueError (f"Invalid corpus. `corpus` should be either 'news' or 'books'." )
5767 if corpus is not None :
5868 verbmem_forget_file = DEFAULT_DATA [corpus ]['verbmem_forget_file' ] if verbmem_forget_file is None else verbmem_forget_file
69+ print (f"{ privleak_forget_file = } , { privleak_retain_file = } , { privleak_holdout_file = } " )
5970 privleak_forget_file = DEFAULT_DATA [corpus ]['privleak_forget_file' ] if privleak_forget_file is None else privleak_forget_file
6071 privleak_retain_file = DEFAULT_DATA [corpus ]['privleak_retain_file' ] if privleak_retain_file is None else privleak_retain_file
6172 privleak_holdout_file = DEFAULT_DATA [corpus ]['privleak_holdout_file' ] if privleak_holdout_file is None else privleak_holdout_file
@@ -66,13 +77,13 @@ def eval_model(
6677
6778 out = {}
6879 model = model .to ('cuda' )
69- debug_subset_len = 2 if DEBUG else None
80+ debug_subset_len = 50 if DEBUG else None
7081 print (f"{ DEBUG = } " )
7182 plots = {}
7283
7384 # 1. verbmem_f
7485 if 'verbmem_f' in metrics :
75- data = read_json (verbmem_forget_file )
86+ data = read_json (os . path . join ( MUSE_DIR , verbmem_forget_file ) )
7687 if DEBUG :
7788 data = data [:debug_subset_len ]
7889 agg , log = eval_verbmem (
@@ -88,32 +99,57 @@ def eval_model(
8899
89100 # 2. privleak
90101 if 'privleak' in metrics :
91- forget_data = read_json (privleak_forget_file )
92- retain_data = read_json (privleak_retain_file )
93- holdout_data = read_json (privleak_holdout_file )
102+ forget_data = read_json (os . path . join ( MUSE_DIR , privleak_forget_file ) )
103+ retain_data = read_json (os . path . join ( MUSE_DIR , privleak_retain_file ) )
104+ holdout_data = read_json (os . path . join ( MUSE_DIR , privleak_holdout_file ) )
94105 if DEBUG :
95106 forget_data = forget_data [:debug_subset_len ]
96107 retain_data = retain_data [:debug_subset_len ]
97108 holdout_data = holdout_data [:debug_subset_len ]
98109
99110 privleak_output_dir = os .path .abspath (os .path .join (temp_dir , "privleak" ) if temp_dir is not None else None )
100- auc , log , privleak_plots = eval_privleak (
101- forget_data = forget_data ,
102- retain_data = retain_data ,
103- holdout_data = holdout_data ,
104- model = model , tokenizer = tokenizer ,
105- plot_dir = privleak_output_dir
106- )
107- if temp_dir is not None :
108- write_json (auc , os .path .join (temp_dir , "privleak/auc.json" ))
109- write_json (log , os .path .join (temp_dir , "privleak/log.json" ))
111+ create_new_files = kwargs .get ('create_new_files' , {})
112+ create_new_privleak_files = create_new_files .get ('privleak' , True )
113+ auc_path = os .path .join (privleak_output_dir , "auc.json" )
114+ log_path = os .path .join (privleak_output_dir , "log.json" )
115+ plots_dir = os .path .join (privleak_output_dir , "plots" )
116+
117+ if create_new_privleak_files :
118+ auc , log , privleak_plots = eval_privleak (
119+ forget_data = forget_data ,
120+ retain_data = retain_data ,
121+ holdout_data = holdout_data ,
122+ model = model , tokenizer = tokenizer ,
123+ plot_dir = privleak_output_dir
124+ )
125+ if temp_dir is not None :
126+ write_json (auc , auc_path )
127+ write_json (log , log_path )
128+ # Save each privleak plot as a PNG under plots_dir, then clear the figure
129+ os .makedirs (plots_dir , exist_ok = True )
130+ for plot_name , plot_obj in privleak_plots .items ():
131+ plot_path = os .path .join (plots_dir , f"{ plot_name } .png" )
132+ plot_obj .savefig (plot_path )
133+ plot_obj .clf ()
134+
135+ else :
136+ # Reuse earlier results: load auc/log JSON if present (else {}) and collect saved plot PNG paths
137+ auc = read_json (auc_path ) if os .path .exists (auc_path ) else {}
138+ log = read_json (log_path ) if os .path .exists (log_path ) else {}
139+ privleak_plots = {}
140+ if os .path .isdir (plots_dir ):
141+ for plot_file in os .listdir (plots_dir ):
142+ if plot_file .endswith (".png" ):
143+ privleak_plots [os .path .splitext (plot_file )[0 ]] = os .path .join (plots_dir , plot_file )
144+
145+
110146 out ['privleak' ] = (auc [privleak_auc_key ] - AUC_RETRAIN [corpus ][privleak_auc_key ]) / AUC_RETRAIN [corpus ][privleak_auc_key ] * 100
111147 plots ['privleak' ] = privleak_plots
112148
113149 # 3. knowmem_f
114150 if 'knowmem_f' in metrics :
115- qa = read_json (knowmem_forget_qa_file )
116- icl = read_json (knowmem_forget_qa_icl_file )
151+ qa = read_json (os . path . join ( MUSE_DIR , knowmem_forget_qa_file ) )
152+ icl = read_json (os . path . join ( MUSE_DIR , knowmem_forget_qa_icl_file ) )
117153 if DEBUG :
118154 qa = qa [:debug_subset_len ]
119155 icl = icl [:debug_subset_len ]
@@ -132,8 +168,8 @@ def eval_model(
132168
133169 # 4. knowmem_r
134170 if 'knowmem_r' in metrics :
135- qa = read_json (knowmem_retain_qa_file )
136- icl = read_json (knowmem_retain_qa_icl_file )
171+ qa = read_json (os . path . join ( MUSE_DIR , knowmem_retain_qa_file ) )
172+ icl = read_json (os . path . join ( MUSE_DIR , knowmem_retain_qa_icl_file ) )
137173 if DEBUG :
138174 qa = qa [:debug_subset_len ]
139175 icl = icl [:debug_subset_len ]
@@ -152,9 +188,10 @@ def eval_model(
152188
153189 # 5. loss_landscape
154190 if 'loss_landscape' in metrics :
155- forget_data = read_json (privleak_forget_file )
156- retain_data = read_json (privleak_retain_file )
157- holdout_data = read_json (privleak_holdout_file )
191+ print (f"{ os .path .abspath (privleak_forget_file )= } " )
192+ forget_data = read_json (os .path .join (MUSE_DIR , privleak_forget_file ))
193+ retain_data = read_json (os .path .join (MUSE_DIR , privleak_retain_file ))
194+ holdout_data = read_json (os .path .join (MUSE_DIR , privleak_holdout_file ))
158195 if DEBUG :
159196 forget_data = forget_data [:debug_subset_len ]
160197 retain_data = retain_data [:debug_subset_len ]
@@ -192,6 +229,7 @@ def load_then_eval_models(
192229 metrics : List [str ] = SUPPORTED_METRICS ,
193230 temp_dir : str = "temp" ,
194231 DEBUG : bool = False ,
232+ kwargs : dict = {},
195233): # -> DataFrame:
196234 print (out_file )
197235 # Argument sanity check
@@ -208,11 +246,28 @@ def load_then_eval_models(
208246 model = load_model (model_dir )
209247 tokenizer = load_tokenizer (tokenizer_dir )
210248
211- return eval_model (
249+ privleak_files = kwargs .get ('privleak_files' , {})
250+ if privleak_files :
251+ privleak_forget_file = privleak_files .get ('privleak_forget_file' , None )
252+ privleak_retain_file = privleak_files .get ('privleak_retain_file' , None )
253+ privleak_holdout_file = privleak_files .get ('privleak_holdout_file' , None )
254+
255+ return eval_model (
256+ model , tokenizer , metrics , corpus ,
257+ temp_dir = os .path .join (temp_dir , name ),
258+ DEBUG = DEBUG ,
259+ privleak_forget_file = privleak_forget_file ,
260+ privleak_retain_file = privleak_retain_file ,
261+ privleak_holdout_file = privleak_holdout_file ,
262+ kwargs = kwargs ,
263+ )
264+
265+ else :
266+ return eval_model (
212267 model , tokenizer , metrics , corpus ,
213268 temp_dir = os .path .join (temp_dir , name ),
214269 DEBUG = DEBUG
215- )
270+ )
216271 # res, plots = eval_model(
217272 # model, tokenizer, metrics, corpus,
218273 # temp_dir=os.path.join(temp_dir, name),
@@ -235,4 +290,5 @@ def load_then_eval_models(
235290 parser .add_argument ('--metrics' , type = str , nargs = '+' , default = SUPPORTED_METRICS )
236291 args = parser .parse_args ()
237292
238- load_then_eval_models (** vars (args ))
293+ load_then_eval_models (** vars (args ))
294+
0 commit comments