Commit 8e8c16d

Fix HF warnings (attention_mask, pad_token_id) and add DEBUG mode for fast eval.
files changed: eval.py, metrics/knowmem.py, metrics/privleak.py, metrics/verbmem.py
1 parent 47abb2a commit 8e8c16d

File tree: 4 files changed, +54 −20 lines changed
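For context, the two Hugging Face warnings this commit addresses come from calling `model.generate` with bare `input_ids`: Transformers then has to infer the attention mask and, when no pad token is defined (as with Llama-style tokenizers), falls back to `eos_token_id` and logs warnings about both. A minimal standalone sketch of the before/after pattern, using a small public checkpoint purely for illustration (it is not part of this repo):

# Sketch only: 'gpt2' is an illustrative checkpoint, not the tokenizer/model used by MUSE.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

# Warning-prone: only input_ids are passed, so the attention mask is guessed and
# pad_token_id silently falls back to eos_token_id, each with a logged warning.
ids = tokenizer('Question: 2 + 2 =', return_tensors='pt').input_ids
_ = model.generate(ids, max_new_tokens=8, do_sample=False)

# Quiet: pass the attention mask explicitly and resolve pad_token_id up front.
enc = tokenizer('Question: 2 + 2 =', return_tensors='pt')
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
_ = model.generate(
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    max_new_tokens=8,
    do_sample=False,
    pad_token_id=pad_id,
)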

MUSE/eval.py

Lines changed: 25 additions & 6 deletions
@@ -1,4 +1,4 @@
-from metrics.verbmem import eval as eval_verbmem
+from metrics.verbmem import eval as eval_verbmem
 from metrics.privleak import eval as eval_privleak
 from metrics.knowmem import eval as eval_knowmem
 from utils import load_model, load_tokenizer, write_csv, read_json, write_json
@@ -29,6 +29,7 @@ def eval_model(
     knowmem_retain_qa_file: str | None = None,
     knowmem_retain_qa_icl_file: str | None = None,
     temp_dir: str | None = None,
+    DEBUG: bool = False,
 ) -> Dict[str, float]:
     # Argument sanity check
     if not metrics:
@@ -50,10 +51,13 @@ def eval_model(
 
     out = {}
     model = model.to('cuda')
+    debug_subset_len = 3 if DEBUG else None
 
     # 1. verbmem_f
     if 'verbmem_f' in metrics:
         data = read_json(verbmem_forget_file)
+        if DEBUG:
+            data = data[:debug_subset_len]
         agg, log = eval_verbmem(
             prompts=[d['prompt'] for d in data],
             gts=[d['gt'] for d in data],
@@ -67,10 +71,17 @@ def eval_model(
 
     # 2. privleak
     if 'privleak' in metrics:
+        forget_data = read_json(privleak_forget_file)
+        retain_data = read_json(privleak_retain_file)
+        holdout_data = read_json(privleak_holdout_file)
+        if DEBUG:
+            forget_data = forget_data[:debug_subset_len]
+            retain_data = retain_data[:debug_subset_len]
+            holdout_data = holdout_data[:debug_subset_len]
         auc, log = eval_privleak(
-            forget_data=read_json(privleak_forget_file),
-            retain_data=read_json(privleak_retain_file),
-            holdout_data=read_json(privleak_holdout_file),
+            forget_data=forget_data,
+            retain_data=retain_data,
+            holdout_data=holdout_data,
             model=model, tokenizer=tokenizer
         )
         if temp_dir is not None:
@@ -82,6 +93,9 @@ def eval_model(
     if 'knowmem_f' in metrics:
         qa = read_json(knowmem_forget_qa_file)
         icl = read_json(knowmem_forget_qa_icl_file)
+        if DEBUG:
+            qa = qa[:debug_subset_len]
+            icl = icl[:debug_subset_len]
         agg, log = eval_knowmem(
             questions=[d['question'] for d in qa],
             answers=[d['answer'] for d in qa],
@@ -99,6 +113,9 @@ def eval_model(
     if 'knowmem_r' in metrics:
         qa = read_json(knowmem_retain_qa_file)
         icl = read_json(knowmem_retain_qa_icl_file)
+        if DEBUG:
+            qa = qa[:debug_subset_len]
+            icl = icl[:debug_subset_len]
         agg, log = eval_knowmem(
             questions=[d['question'] for d in qa],
             answers=[d['answer'] for d in qa],
@@ -122,7 +139,8 @@ def load_then_eval_models(
     tokenizer_dir: str = LLAMA_DIR,
     out_file: str | None = None,
     metrics: List[str] = SUPPORTED_METRICS,
-    temp_dir: str = "temp"
+    temp_dir: str = "temp",
+    DEBUG: bool = False,
 ) -> DataFrame:
     print(out_file)
     # Argument sanity check
@@ -140,7 +158,8 @@ def load_then_eval_models(
         tokenizer = load_tokenizer(tokenizer_dir)
         res = eval_model(
             model, tokenizer, metrics, corpus,
-            temp_dir=os.path.join(temp_dir, name)
+            temp_dir=os.path.join(temp_dir, name),
+            DEBUG=DEBUG
         )
         out.append({'name': name} | res)
         if out_file is not None: write_csv(out, out_file)
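A sketch of how the new DEBUG flag might be exercised from a driver script run inside the MUSE directory. Only the trailing parameters of `load_then_eval_models` appear in this diff, so the leading argument names below (`model_dirs`, `names`, `corpus`) are assumptions, not the repo's actual signature:

# Hypothetical smoke-test run: with DEBUG=True each eval set is truncated to 3 examples,
# so the whole pipeline can be checked end to end in minutes rather than hours.
from eval import load_then_eval_models

df = load_then_eval_models(
    model_dirs=['ckpt/news/unlearned'],   # assumed parameter name
    names=['unlearned'],                  # assumed parameter name
    corpus='news',                        # assumed parameter name
    metrics=['verbmem_f', 'knowmem_f'],
    out_file='out/debug.csv',
    temp_dir='temp',
    DEBUG=True,
)
print(df)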

MUSE/metrics/knowmem.py

Lines changed: 10 additions & 5 deletions
@@ -28,18 +28,23 @@ def eval(
     for question, answer in tzip(questions, answers):
         prompt = general_prompt + f"Question: {question}\nAnswer: "
 
-        # Encode the `prompt` into `input_ids`
-        input_ids = tokenizer(
+        # Encode the `prompt` into `input_ids` and `attention_mask`
+        inputs = tokenizer(
             prompt,
             return_tensors='pt',
-            add_special_tokens=True).input_ids
+            add_special_tokens=True
+        )
+        input_ids = inputs.input_ids
+        attention_mask = inputs.attention_mask
 
         # Use the `model` to generate the continuation of the `input_ids`.
         output_ids = model.generate(
-            input_ids.to(model.device),
+            input_ids=input_ids.to(model.device),
+            attention_mask=attention_mask.to(model.device),
             max_new_tokens=max_new_tokens,
             do_sample=False,
-            pad_token_id=tokenizer.pad_token_id)
+            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+        )
         output_ids = output_ids[:, len(input_ids[0]):]
 
         output = tokenizer.batch_decode(
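An alternative to repeating the inline `pad_token_id` fallback at every `generate` call would be to resolve the pad token once when the tokenizer is loaded. A sketch, assuming the repo's `load_tokenizer` in utils.py wraps `AutoTokenizer` (its actual body is not shown in this diff):

# Hypothetical variant of utils.load_tokenizer: Llama-style tokenizers ship without
# a pad token, so set one once here and generate() never needs the inline fallback.
from transformers import AutoTokenizer

def load_tokenizer(tokenizer_dir: str):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

With that in place, `pad_token_id=tokenizer.pad_token_id` would suffice at the call sites in knowmem.py and verbmem.py.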

MUSE/metrics/privleak.py

Lines changed: 10 additions & 4 deletions
@@ -11,15 +11,21 @@
 
 
 def compute_ppl(text: str, model, tokenizer, device='cuda'):
-    input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
-    input_ids = input_ids.to(device)
+    # Tokenize with attention_mask and padding
+    inputs = tokenizer(
+        text,
+        return_tensors='pt',
+        add_special_tokens=True
+    )
+    input_ids = inputs['input_ids'].to(device)
+    attention_mask = inputs['attention_mask'].to(device)
     with torch.no_grad():
-        outputs = model(input_ids, labels=input_ids)
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
         loss, logits = outputs[:2]
 
     probabilities = torch.nn.functional.log_softmax(logits, dim=-1)
-    all_prob = []
     input_ids_processed = input_ids[0][1:]
+    all_prob = []
     for i, token_id in enumerate(input_ids_processed):
         probability = probabilities[0, i, token_id].item()
         all_prob.append(probability)
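Since `loss` here is the mean negative log-likelihood over the predicted tokens and `all_prob` holds the log-probabilities of those same tokens, perplexity can be recovered from either quantity. A minimal sketch of that relationship (the helper names are illustrative, not part of privleak.py):

import torch

# ppl = exp(loss) = exp(-mean(all_prob)), since loss is the mean NLL over the
# tokens whose per-token log-probabilities are collected in all_prob.
def ppl_from_loss(loss: torch.Tensor) -> float:
    return torch.exp(loss).item()

def ppl_from_log_probs(all_prob: list[float]) -> float:
    return float(torch.exp(-torch.tensor(all_prob).mean()))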

MUSE/metrics/verbmem.py

Lines changed: 9 additions & 5 deletions
@@ -11,22 +11,26 @@ def eval(
 ):
     logger = RougeEvalLogger()
     for prompt, gt in tzip(prompts, gts):
-        # Encode the `prompt` into `input_ids`
-        input_ids = tokenizer(
+        # Encode the `prompt` into `input_ids` and `attention_mask`
+        inputs = tokenizer(
             prompt,
             return_tensors='pt',
             add_special_tokens=True
-        ).input_ids
+        )
+        input_ids = inputs.input_ids
+        attention_mask = inputs.attention_mask
         assert len(input_ids) == 1
 
         gt_ids = tokenizer(gt, return_tensors='pt', add_special_tokens=True).input_ids[:, :max_new_tokens]
 
         # Use the `model` to generate the continuation of the `input_ids`.
         output_ids = model.generate(
-            input_ids.to(model.device),
+            input_ids=input_ids.to(model.device),
+            attention_mask=attention_mask.to(model.device),
             max_new_tokens=max_new_tokens,
             do_sample=False,
-            pad_token_id=tokenizer.pad_token_id)
+            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+        )
         output_ids = output_ids[:, len(input_ids[0]):]
         output = tokenizer.batch_decode(
             output_ids,
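The `RougeEvalLogger` interface is not part of this diff, but the comparison it records amounts to scoring the greedy continuation against the ground-truth span. A hedged illustration with the rouge_score package, purely to show the kind of score involved (which statistic MUSE actually aggregates is not visible here):

# Illustration only: verbmem's real logging goes through RougeEvalLogger (not shown in this commit).
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)

def rouge_l_f1(output: str, gt: str) -> float:
    # score(target, prediction) returns a dict of Score(precision, recall, fmeasure)
    return scorer.score(gt, output)['rougeL'].fmeasure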
