
Commit d5f7ac3
Author: Samuel Burbulla
1 parent: 17e89ab

[Cleanup] Add new section to CHANGELOG

4 files changed (+23, -19 lines)


CHANGELOG.md

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,10 @@
 # CHANGELOG
 
-## 0.1
+## 0.2.0
+
+- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
+
+## 0.1.0
 
 - Move all content of `__init__.py` files to sub-modules.
 - Add `Trainer` class to replace `operator.fit` method.
@@ -24,7 +28,6 @@
 - Add `benchmarks` infrastructure.
 - An `Operator` now takes a `device` argument.
 - Add `QuantileScaler` class.
-- Add `Attention` base class, `MultiHeadAttention`, and `ScaledDotProductAttention` classes.
 
 ## 0.0.0 (2024-02-22)

src/continuiti/networks/attention.py

Lines changed: 4 additions & 4 deletions
@@ -12,10 +12,10 @@
 class Attention(nn.Module):
     """Base class for various attention implementations.
 
-    Attention assigns different parts of an input varying importance without set kernels. The importance of different
-    components is designated using "soft" weights. These weights are assigned according to specific algorithms (e.g.
+    Attention assigns different parts of an input varying importance without set
+    kernels. The importance of different components is designated using "soft"
+    weights. These weights are assigned according to specific algorithms (e.g.
     scaled-dot-product attention).
-
     """
 
     def __init__(self):
@@ -26,7 +26,7 @@ def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
-        value: torch,
+        value: torch.Tensor,
         attn_mask: torch.Tensor = None,
     ) -> torch.Tensor:
         """Calculates the attention scores.

src/continuiti/networks/multi_head_attention.py

Lines changed: 6 additions & 4 deletions
@@ -14,9 +14,11 @@
 class MultiHeadAttention(Attention):
     r"""Multi-Head Attention module.
 
-    Module as described in the paper [Attention is All you Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
-    with optional bias for the projections. This implementation allows to use attention implementations other than the
-    standard scaled dot product attention implemented by the MultiheadAttention PyTorch module.
+    Module as described in the paper [Attention is All you
+    Need](https://proceedings.neurips.cc/paper_files/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf)
+    with optional bias for the projections. This implementation allows to use
+    attention implementations other than the standard scaled dot product
+    attention implemented by the MultiheadAttention PyTorch module.
 
     $$MultiHead(Q,K,V)=Concat(head_1,\dots,head_n)W^O + b^O$$
 
@@ -67,7 +69,7 @@ def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
-        value: torch,
+        value: torch.Tensor,
         attn_mask: torch.Tensor = None,
     ) -> torch.Tensor:
         r"""Compute the attention scores.

src/continuiti/networks/scaled_dot_product_attention.py

Lines changed: 8 additions & 9 deletions
@@ -12,10 +12,12 @@
 class ScaledDotProductAttention(Attention):
     """Scaled dot product attention module.
 
-    This module is a wrapper for the torch implementation of the scaled dot product attention mechanism as described in
-    the paper "Attention Is All You Need" by Vaswani et al. (2017). This attention mechanism computes the attention
-    weights based on the dot product of the query and key matrices, scaled by the square root of the dimension of the
-    key vectors. The weights are then applied to the value vectors to obtain the final output.
+    This module is a wrapper for the torch implementation of the scaled dot
+    product attention mechanism as described in the paper "Attention Is All You
+    Need" by Vaswani et al. (2017). This attention mechanism computes the
+    attention weights based on the dot product of the query and key matrices,
+    scaled by the square root of the dimension of the key vectors. The weights
+    are then applied to the value vectors to obtain the final output.
     """
 
     def __init__(self, dropout_p: float = 0.0):
@@ -26,13 +28,10 @@ def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
-        value: torch,
+        value: torch.Tensor,
         attn_mask: torch.Tensor = None,
     ) -> torch.Tensor:
-        if self.training:
-            dropout_p = self.dropout_p
-        else:
-            dropout_p = 0.0
+        dropout_p = self.dropout_p if self.training else 0.0
         return scaled_dot_product_attention(
             query=query,
             key=key,
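Based only on the signatures touched in this diff (`__init__(dropout_p)` and `forward(query, key, value, attn_mask)`), usage would look roughly like the sketch below. The import path is inferred from the file location and the tensor shapes are illustrative; both may differ from the package's public API.

```python
import torch

from continuiti.networks.scaled_dot_product_attention import ScaledDotProductAttention  # path assumed

attn = ScaledDotProductAttention(dropout_p=0.1)

# (batch, heads, sequence length, embedding dimension) -- shapes are illustrative.
query = torch.randn(2, 4, 8, 16)
key = torch.randn(2, 4, 8, 16)
value = torch.randn(2, 4, 8, 16)

attn.eval()                                    # per the diff, dropout_p is forced to 0.0 outside training
out = attn(query, key, value)                  # -> (2, 4, 8, 16)

attn.train()                                   # dropout with p=0.1 is applied to the attention weights
out = attn(query, key, value, attn_mask=None)
```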
