We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4af4cbf commit ac44a84Copy full SHA for ac44a84
pyvene/models/qwen2/modelings_intervenable_qwen2.py
@@ -36,15 +36,15 @@
36
"mlp_output": ("hidden_size",),
37
"mlp_input": ("hidden_size",),
38
"attention_value_output": ("hidden_size",),
39
- "head_attention_value_output": ("head_dim",),
+ "head_attention_value_output": ("hidden_size/num_attention_heads",),
40
"attention_output": ("hidden_size",),
41
"attention_input": ("hidden_size",),
42
"query_output": ("hidden_size",),
43
"key_output": ("hidden_size",),
44
"value_output": ("hidden_size",),
45
- "head_query_output": ("head_dim",),
46
- "head_key_output": ("head_dim",),
47
- "head_value_output": ("head_dim",),
+ "head_query_output": ("hidden_size/num_attention_heads",),
+ "head_key_output": ("hidden_size/num_attention_heads",),
+ "head_value_output": ("hidden_size/num_attention_heads",),
48
}
49
50
"""qwen2 model with LM head"""
@@ -74,4 +74,4 @@ def create_qwen2(
74
torch_dtype=dtype,
75
)
76
print("loaded model")
77
- return config, tokenizer, model
+ return config, tokenizer, model
0 commit comments