Commit 3c6cb78

Merge pull request #222 from atticusg/main
fix attention in qwen model
2 parents: baa3efa + ac44a84

File tree

1 file changed: +5 −5 lines


pyvene/models/qwen2/modelings_intervenable_qwen2.py
5 additions & 5 deletions

@@ -36,15 +36,15 @@
     "mlp_output": ("hidden_size",),
     "mlp_input": ("hidden_size",),
     "attention_value_output": ("hidden_size",),
-    "head_attention_value_output": ("head_dim",),
+    "head_attention_value_output": ("hidden_size/num_attention_heads",),
     "attention_output": ("hidden_size",),
     "attention_input": ("hidden_size",),
     "query_output": ("hidden_size",),
     "key_output": ("hidden_size",),
     "value_output": ("hidden_size",),
-    "head_query_output": ("head_dim",),
-    "head_key_output": ("head_dim",),
-    "head_value_output": ("head_dim",),
+    "head_query_output": ("hidden_size/num_attention_heads",),
+    "head_key_output": ("hidden_size/num_attention_heads",),
+    "head_value_output": ("hidden_size/num_attention_heads",),
 }

 """qwen2 model with LM head"""

@@ -74,4 +74,4 @@ def create_qwen2(
         torch_dtype=dtype,
     )
     print("loaded model")
-    return config, tokenizer, model
+    return config, tokenizer, model
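The substantive change replaces the literal config attribute name "head_dim" with the derived expression "hidden_size/num_attention_heads" for the per-head components. A plausible reading (an assumption, not stated in the commit message) is that these dimension specs are resolved by looking up the named attributes on the Hugging Face model config, and some Qwen2 configs do not expose a head_dim attribute, whereas hidden_size and num_attention_heads are always present. A minimal sketch of that kind of resolution, hypothetical and not pyvene's actual implementation:

```python
# Hypothetical illustration only, not pyvene's actual code: resolve a dimension
# spec such as "hidden_size/num_attention_heads" against a model config object.
def resolve_dim(spec: str, config) -> int:
    """Split the spec on '/' and fold the named config attributes with division."""
    parts = spec.split("/")
    dim = getattr(config, parts[0])           # e.g. config.hidden_size
    for name in parts[1:]:
        dim //= getattr(config, name)         # e.g. // config.num_attention_heads
    return dim


class DummyQwen2Config:
    # Attributes that a Qwen2-7B-style config exposes; head_dim is deliberately
    # left out to mirror configs that do not define it.
    hidden_size = 3584
    num_attention_heads = 28


cfg = DummyQwen2Config()
print(resolve_dim("hidden_size", cfg))                      # 3584
print(resolve_dim("hidden_size/num_attention_heads", cfg))  # 128
# resolve_dim("head_dim", cfg) would raise AttributeError here,
# which is one reason to prefer the derived expression.
```

For a Qwen2-7B-style config (hidden_size 3584, 28 attention heads) both forms name the same 128-dimensional per-head slice; the derived form simply avoids depending on an attribute that may be absent.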
