Skip to content

Commit 9bc0ffb

Browse files
committed
Update: Improve evaluation flexibility and debug workflows
The primary purpose of this commit is to improve evaluation robustness and flexibility when managing model outputs and debug workflows. The primary changes are:
- Updated `eval_model` to ensure `forget_data`, `retain_data`, and `holdout_data` are initialized consistently before use.
- Replaced hardcoded paths with `os.path.join` using `MUSE_DIR` in `eval_model` for improved path handling.
- Added a `kwargs` parameter to both `eval_model` and `load_then_eval_models` to support dynamic control over file creation and loading.
- Implemented conditional logic in `eval_model` for managing `privleak` file generation based on `kwargs['create_new_files']`.
- Removed unused imports and dynamic import logic from `eval.py`, replacing `importlib` with `sys.path.append` to streamline module loading.
- Improved debug visibility in `eval_model` with additional `print` statements for key file paths and parameter values.
- Increased `debug_subset_len` from 2 to 50 in `eval_model` for broader test coverage during debug mode.
- Updated `exp.ipynb` to align with changes in model handling and evaluation behavior in `eval_model`.
1 parent e1f241b commit 9bc0ffb

File tree

4 files changed

+1298
-1966
lines changed

4 files changed

+1298
-1966
lines changed

MUSE/eval.py

Lines changed: 88 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,25 @@
99
from typing import List, Dict, Literal
1010
from pandas import DataFrame
1111

12-
import importlib.util
13-
import os
12+
import sys
13+
14+
curr_dir = os.path.dirname(__file__)
15+
PROJECT_DIR = os.path.abspath(os.path.join(curr_dir, '..', '..'))
16+
Unlearn_Simple_DIR = os.path.join(PROJECT_DIR, 'Unlearn-Simple')
17+
MUSE_DIR = os.path.join(Unlearn_Simple_DIR, 'MUSE')
18+
19+
sys.path.append(os.path.join(PROJECT_DIR, 'src'))
20+
21+
# print sys paths that are being used for importing
22+
# print("Current sys.path:")
23+
# for path in sys.path:
24+
# print(path)
25+
# sys.path.append(os.path.join(PROJECT_DIR, 'src'))
26+
import input_loss_landscape.utils as input_loss_landscape_utils
1427

15-
input_loss_landscape_utils_path = os.path.abspath(os.path.join(os.getcwd(), '..', '..', 'src', 'input_loss_landscape', 'utils.py'))
16-
spec = importlib.util.spec_from_file_location("input_loss_landscape_utils", input_loss_landscape_utils_path)
17-
input_loss_landscape_utils = importlib.util.module_from_spec(spec)
18-
spec.loader.exec_module(input_loss_landscape_utils)
28+
input_loss_landscape_eval = input_loss_landscape_utils.input_loss_landscape_eval
1929

2030
input_loss_landscape_eval = input_loss_landscape_utils.input_loss_landscape_eval
21-
print(f"Current working directory: {os.getcwd()}") # Ensure the current working directory is set correctly
2231

2332

2433

@@ -45,6 +54,7 @@ def eval_model(
4554
knowmem_retain_qa_icl_file: str | None = None,
4655
temp_dir: str | None = None,
4756
DEBUG: bool = False,
57+
kwargs: dict = {},
4858
): # -> Dict[str, float]:
4959
# Argument sanity check
5060
if not metrics:
@@ -56,6 +66,7 @@ def eval_model(
5666
raise ValueError(f"Invalid corpus. `corpus` should be either 'news' or 'books'.")
5767
if corpus is not None:
5868
verbmem_forget_file = DEFAULT_DATA[corpus]['verbmem_forget_file'] if verbmem_forget_file is None else verbmem_forget_file
69+
print(f"{privleak_forget_file=}, {privleak_retain_file=}, {privleak_holdout_file=}")
5970
privleak_forget_file = DEFAULT_DATA[corpus]['privleak_forget_file'] if privleak_forget_file is None else privleak_forget_file
6071
privleak_retain_file = DEFAULT_DATA[corpus]['privleak_retain_file'] if privleak_retain_file is None else privleak_retain_file
6172
privleak_holdout_file = DEFAULT_DATA[corpus]['privleak_holdout_file'] if privleak_holdout_file is None else privleak_holdout_file
@@ -66,13 +77,13 @@ def eval_model(
6677

6778
out = {}
6879
model = model.to('cuda')
69-
debug_subset_len = 2 if DEBUG else None
80+
debug_subset_len = 50 if DEBUG else None
7081
print(f"{DEBUG=}")
7182
plots = {}
7283

7384
# 1. verbmem_f
7485
if 'verbmem_f' in metrics:
75-
data = read_json(verbmem_forget_file)
86+
data = read_json(os.path.join(MUSE_DIR, verbmem_forget_file))
7687
if DEBUG:
7788
data = data[:debug_subset_len]
7889
agg, log = eval_verbmem(
@@ -88,32 +99,57 @@ def eval_model(
8899

89100
# 2. privleak
90101
if 'privleak' in metrics:
91-
forget_data = read_json(privleak_forget_file)
92-
retain_data = read_json(privleak_retain_file)
93-
holdout_data = read_json(privleak_holdout_file)
102+
forget_data = read_json(os.path.join(MUSE_DIR, privleak_forget_file))
103+
retain_data = read_json(os.path.join(MUSE_DIR, privleak_retain_file))
104+
holdout_data = read_json(os.path.join(MUSE_DIR, privleak_holdout_file))
94105
if DEBUG:
95106
forget_data = forget_data[:debug_subset_len]
96107
retain_data = retain_data[:debug_subset_len]
97108
holdout_data = holdout_data[:debug_subset_len]
98109

99110
privleak_output_dir = os.path.abspath(os.path.join(temp_dir, "privleak") if temp_dir is not None else None)
100-
auc, log, privleak_plots = eval_privleak(
101-
forget_data=forget_data,
102-
retain_data=retain_data,
103-
holdout_data=holdout_data,
104-
model=model, tokenizer=tokenizer,
105-
plot_dir=privleak_output_dir
106-
)
107-
if temp_dir is not None:
108-
write_json(auc, os.path.join(temp_dir, "privleak/auc.json"))
109-
write_json(log, os.path.join(temp_dir, "privleak/log.json"))
111+
create_new_files = kwargs.get('create_new_files', {})
112+
create_new_privleak_files = create_new_files.get('privleak', True)
113+
auc_path = os.path.join(privleak_output_dir, "auc.json")
114+
log_path = os.path.join(privleak_output_dir, "log.json")
115+
plots_dir = os.path.join(privleak_output_dir, "plots")
116+
117+
if create_new_privleak_files:
118+
auc, log, privleak_plots = eval_privleak(
119+
forget_data=forget_data,
120+
retain_data=retain_data,
121+
holdout_data=holdout_data,
122+
model=model, tokenizer=tokenizer,
123+
plot_dir=privleak_output_dir
124+
)
125+
if temp_dir is not None:
126+
write_json(auc, auc_path)
127+
write_json(log, log_path)
128+
# save plots
129+
os.makedirs(plots_dir, exist_ok=True)
130+
for plot_name, plot_obj in privleak_plots.items():
131+
plot_path = os.path.join(plots_dir, f"{plot_name}.png")
132+
plot_obj.savefig(plot_path)
133+
plot_obj.clf()
134+
135+
else:
136+
# load auc, log, privleak_plots
137+
auc = read_json(auc_path) if os.path.exists(auc_path) else {}
138+
log = read_json(log_path) if os.path.exists(log_path) else {}
139+
privleak_plots = {}
140+
if os.path.isdir(plots_dir):
141+
for plot_file in os.listdir(plots_dir):
142+
if plot_file.endswith(".png"):
143+
privleak_plots[os.path.splitext(plot_file)[0]] = os.path.join(plots_dir, plot_file)
144+
145+
110146
out['privleak'] = (auc[privleak_auc_key] - AUC_RETRAIN[corpus][privleak_auc_key]) / AUC_RETRAIN[corpus][privleak_auc_key] * 100
111147
plots['privleak'] = privleak_plots
112148

113149
# 3. knowmem_f
114150
if 'knowmem_f' in metrics:
115-
qa = read_json(knowmem_forget_qa_file)
116-
icl = read_json(knowmem_forget_qa_icl_file)
151+
qa = read_json(os.path.join(MUSE_DIR, knowmem_forget_qa_file))
152+
icl = read_json(os.path.join(MUSE_DIR, knowmem_forget_qa_icl_file))
117153
if DEBUG:
118154
qa = qa[:debug_subset_len]
119155
icl = icl[:debug_subset_len]
@@ -132,8 +168,8 @@ def eval_model(
132168

133169
# 4. knowmem_r
134170
if 'knowmem_r' in metrics:
135-
qa = read_json(knowmem_retain_qa_file)
136-
icl = read_json(knowmem_retain_qa_icl_file)
171+
qa = read_json(os.path.join(MUSE_DIR, knowmem_retain_qa_file))
172+
icl = read_json(os.path.join(MUSE_DIR, knowmem_retain_qa_icl_file))
137173
if DEBUG:
138174
qa = qa[:debug_subset_len]
139175
icl = icl[:debug_subset_len]
@@ -152,9 +188,10 @@ def eval_model(
152188

153189
# 5. loss_landscape
154190
if 'loss_landscape' in metrics:
155-
forget_data = read_json(privleak_forget_file)
156-
retain_data = read_json(privleak_retain_file)
157-
holdout_data = read_json(privleak_holdout_file)
191+
print(f"{os.path.abspath(privleak_forget_file)=}")
192+
forget_data = read_json(os.path.join(MUSE_DIR, privleak_forget_file))
193+
retain_data = read_json(os.path.join(MUSE_DIR, privleak_retain_file))
194+
holdout_data = read_json(os.path.join(MUSE_DIR, privleak_holdout_file))
158195
if DEBUG:
159196
forget_data = forget_data[:debug_subset_len]
160197
retain_data = retain_data[:debug_subset_len]
@@ -192,6 +229,7 @@ def load_then_eval_models(
192229
metrics: List[str] = SUPPORTED_METRICS,
193230
temp_dir: str = "temp",
194231
DEBUG: bool = False,
232+
kwargs: dict = {},
195233
): # -> DataFrame:
196234
print(out_file)
197235
# Argument sanity check
@@ -208,11 +246,28 @@ def load_then_eval_models(
208246
model = load_model(model_dir)
209247
tokenizer = load_tokenizer(tokenizer_dir)
210248

211-
return eval_model(
249+
privleak_files = kwargs.get('privleak_files', {})
250+
if privleak_files:
251+
privleak_forget_file = privleak_files.get('privleak_forget_file', None)
252+
privleak_retain_file = privleak_files.get('privleak_retain_file', None)
253+
privleak_holdout_file = privleak_files.get('privleak_holdout_file', None)
254+
255+
return eval_model(
256+
model, tokenizer, metrics, corpus,
257+
temp_dir=os.path.join(temp_dir, name),
258+
DEBUG=DEBUG,
259+
privleak_forget_file = privleak_forget_file,
260+
privleak_retain_file=privleak_retain_file,
261+
privleak_holdout_file=privleak_holdout_file,
262+
kwargs=kwargs,
263+
)
264+
265+
else:
266+
return eval_model(
212267
model, tokenizer, metrics, corpus,
213268
temp_dir=os.path.join(temp_dir, name),
214269
DEBUG=DEBUG
215-
)
270+
)
216271
# res, plots = eval_model(
217272
# model, tokenizer, metrics, corpus,
218273
# temp_dir=os.path.join(temp_dir, name),
@@ -235,4 +290,5 @@ def load_then_eval_models(
235290
parser.add_argument('--metrics', type=str, nargs='+', default=SUPPORTED_METRICS)
236291
args = parser.parse_args()
237292

238-
load_then_eval_models(**vars(args))
293+
load_then_eval_models(**vars(args))
294+

0 commit comments

Comments
 (0)