
Commit 0571aab

add tests
1 parent 468b120 commit 0571aab

2 files changed: 213 additions & 0 deletions

tests/networks/test_multi_head.py

Lines changed: 191 additions & 0 deletions
@@ -0,0 +1,191 @@
import pytest
import torch
import torch.nn as nn

from continuiti.networks import MultiHeadAttention, ScaledDotProductAttention


@pytest.fixture(scope="session")
def some_multi_head_attn():
    return MultiHeadAttention(
        hidden_dim=32,
        n_heads=4,
        attention=ScaledDotProductAttention(dropout_p=0.25),
        bias=True,
    )


@pytest.fixture(scope="class")
def random_qkv():
    batch_size = 3
    target_length = 5
    source_length = 7
    embedding_dim = 8

    q = torch.rand(batch_size, target_length, embedding_dim)
    k = torch.rand(batch_size, source_length, embedding_dim)
    v = torch.rand(batch_size, source_length, embedding_dim)
    return q, k, v


class TestMultiHeadAttention:
    def test_can_initialize(self, some_multi_head_attn):
        assert isinstance(some_multi_head_attn, MultiHeadAttention)

    def test_output_shape(self, some_multi_head_attn):
        batch_size = 3
        query_size = 5
        key_val_size = 7

        query = torch.rand(batch_size, query_size, some_multi_head_attn.hidden_dim)
        key = torch.rand(batch_size, key_val_size, some_multi_head_attn.hidden_dim)
        val = torch.rand(batch_size, key_val_size, some_multi_head_attn.hidden_dim)

        out = some_multi_head_attn(query, key, val)

        gt_attn = nn.MultiheadAttention(
            embed_dim=some_multi_head_attn.hidden_dim,
            num_heads=some_multi_head_attn.n_heads,
            batch_first=True,
            bias=True,
        )
        correct_out, _ = gt_attn(query, key, val)

        assert out.shape == correct_out.shape

    def test_zero_value(self, random_qkv):
        """Edge case testing for correctness."""
        q, k, v = random_qkv
        v = torch.zeros(v.shape)

        m_attn = MultiHeadAttention(q.size(-1), 4, bias=False)

        # V = 0 -> attention output == 0
        out = m_attn(q, k, v)
        assert torch.allclose(out, torch.zeros(out.shape))

    def test_gradient_flow(self, some_multi_head_attn):
        hidden_size = 32
        some_multi_head_attn.eval()  # Turn off dropout and other stochastic processes
        query = key = value = torch.rand((10, 5, hidden_size), requires_grad=True)
        output = some_multi_head_attn(
            value,
            key,
            query,
        )
        output.sum().backward()

        assert query.grad is not None, "Gradients not flowing to query"
        assert key.grad is not None, "Gradients not flowing to key"
        assert value.grad is not None, "Gradients not flowing to value"

    def test_equal_to_torch(self, random_qkv):
        q, k, v = random_qkv
        mask = torch.rand(q.size(0), q.size(1), k.size(1)) < 0.2

        heads = 2
        embedding_dim = q.size(-1)

        gt_attn = nn.MultiheadAttention(q.size(-1), heads, batch_first=True)
        attn = MultiHeadAttention(
            hidden_dim=q.size(-1),
            n_heads=heads,
            attention=ScaledDotProductAttention(dropout_p=0.0),
            bias=True,
        )

        # align in projection
        attn.key_project.weight = nn.Parameter(
            gt_attn.in_proj_weight[embedding_dim : 2 * embedding_dim, :]
        )
        attn.key_project.bias = nn.Parameter(
            gt_attn.in_proj_bias[embedding_dim : 2 * embedding_dim]
        )

        attn.value_project.weight = nn.Parameter(
            gt_attn.in_proj_weight[2 * embedding_dim :, :]
        )
        attn.value_project.bias = nn.Parameter(
            gt_attn.in_proj_bias[2 * embedding_dim :]
        )

        attn.query_project.weight = nn.Parameter(
            gt_attn.in_proj_weight[:embedding_dim, :]
        )
        attn.query_project.bias = nn.Parameter(gt_attn.in_proj_bias[:embedding_dim])

        # align out projection
        attn.out_project.weight = nn.Parameter(gt_attn.out_proj.weight)
        attn.out_project.bias = nn.Parameter(gt_attn.out_proj.bias)

        # forward pass
        out = attn(q, k, v, attn_mask=mask)

        # torch's boolean mask convention is inverted between scaled_dot_product_attention
        # (True = attend) and nn.MultiheadAttention (True = masked out), hence the logical_not.
        gt_mask = torch.repeat_interleave(mask, heads, 0).logical_not()
        ground_truth, _ = gt_attn(q, k, v, need_weights=False, attn_mask=gt_mask)

        assert torch.allclose(
            out[~torch.isnan(out)], ground_truth[~torch.isnan(ground_truth)]
        )

    def test_full_mask_identical_to_none(self, random_qkv):
        heads = 2
        q, k, v = random_qkv

        mask = torch.ones(q.size(0), q.size(1), k.size(1))

        attn = MultiHeadAttention(
            hidden_dim=q.size(-1),
            n_heads=heads,
            attention=ScaledDotProductAttention(dropout_p=0.0),
            bias=True,
        )

        # forward pass
        out_masked = attn(q, k, v, attn_mask=mask)
        out_none = attn(q, k, v)

        assert torch.allclose(out_masked, out_none)

    def test_mask_all_but_one(self, random_qkv):
        q, k, v = random_qkv
        q.requires_grad = True
        k.requires_grad = True
        v.requires_grad = True

        # Mask out the last key/value pair
        mask = torch.ones(q.size(0), q.size(1), k.size(1), dtype=torch.bool)
        mask[:, :, -1] = 0

        attn = MultiHeadAttention(
            hidden_dim=q.size(-1),
            n_heads=2,
            attention=ScaledDotProductAttention(dropout_p=0.0),
            bias=True,
        )
        out = attn(q, k, v, attn_mask=mask)

        eq = torch.sum(out)
        eq.backward()

        assert not torch.any(torch.isnan(q.grad))
        assert not torch.any(
            torch.isclose(q.grad, torch.zeros(q.shape))
        )  # all queries have a non-zero gradient

        assert not torch.any(torch.isnan(v.grad))
        unmasked_rows = v.grad[:, :-1, :]  # gradient on unmasked values is non-zero
        assert not torch.any(
            torch.isclose(unmasked_rows, torch.zeros(unmasked_rows.shape))
        )
        masked_row = v.grad[:, -1, :]  # gradient on masked value is zero
        assert torch.allclose(masked_row, torch.zeros(masked_row.shape))

        assert not torch.any(torch.isnan(k.grad))
        unmasked_rows = k.grad[:, :-1, :]  # gradient on unmasked keys is non-zero
        assert not torch.any(
            torch.isclose(unmasked_rows, torch.zeros(unmasked_rows.shape))
        )
        masked_row = k.grad[:, -1, :]  # gradient on masked key is zero
        assert torch.allclose(masked_row, torch.zeros(masked_row.shape))
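
For orientation, here is a minimal usage sketch of the interface these tests exercise. It is inferred from the test code above rather than documented API; the shapes mirror the random_qkv fixture, and the requirement that hidden_dim be divisible by n_heads is an assumption.

import torch
from continuiti.networks import MultiHeadAttention, ScaledDotProductAttention

# Same construction as the fixture; dropout_p=0.0 keeps the forward pass deterministic.
attn = MultiHeadAttention(
    hidden_dim=32,  # assumed to be divisible by n_heads
    n_heads=4,
    attention=ScaledDotProductAttention(dropout_p=0.0),
    bias=True,
)

q = torch.rand(3, 5, 32)  # (batch, target length, hidden_dim)
k = torch.rand(3, 7, 32)  # (batch, source length, hidden_dim)
v = torch.rand(3, 7, 32)

# Boolean attn_mask follows the scaled_dot_product_attention convention used in the
# tests: True means "attend" (the opposite of nn.MultiheadAttention's convention).
mask = torch.ones(3, 5, 7, dtype=torch.bool)
out = attn(q, k, v, attn_mask=mask)  # -> shape (3, 5, 32)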

tests/networks/test_scaled_dot.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
import torch
from torch.nn.functional import scaled_dot_product_attention

from continuiti.networks import ScaledDotProductAttention


def test_forward_correct():
    batch_size = 3
    query_size = 5
    key_val_size = 7
    hidden_dim = 11

    query = torch.rand(batch_size, query_size, hidden_dim)
    key = torch.rand(batch_size, key_val_size, hidden_dim)
    value = torch.rand(batch_size, key_val_size, hidden_dim)

    attn = ScaledDotProductAttention()

    out = attn(query, key, value)
    gt_out = scaled_dot_product_attention(query, key, value)

    assert torch.allclose(out, gt_out)
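
To run only the new tests locally, something like the following should work (a sketch assuming pytest is installed and the repository root is the working directory; the equivalent shell invocation would be pytest tests/networks -v).

import pytest

# Run the two new test modules with verbose output.
pytest.main([
    "tests/networks/test_multi_head.py",
    "tests/networks/test_scaled_dot.py",
    "-v",
])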

0 commit comments
