Commit b6c1837

Fix rope ref implementation.
Differential Revision: D88780674
Pull Request resolved: #16168
Parent: 9eaea4a

3 files changed, +47 −33 lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 14 additions & 7 deletions
@@ -1701,31 +1701,38 @@ def rope(
         input_tensor.shape[0], input_tensor.shape[1], input_tensor.shape[2], -1
     )

-    _, s, h, hd = input_tensor.shape
+    _, seq, _, hd = input_tensor.shape

     if hd % 2:
         raise ValueError("Hidden dimension must be divisible by 2")

-    if sin_tensor.shape != (s, hd // 2) or cos_tensor.shape != (s, hd // 2):
+    if (
+        sin_tensor.size(-1) * 2 != hd
+        or cos_tensor.size(-1) * 2 != hd
+        or sin_tensor.size(0) < seq
+        or cos_tensor.size(0) < seq
+    ):
         raise ValueError(
-            f"sin_tensor and cos_tensor must have shape {s, hd // 2}. Got {sin_tensor.shape} and {cos_tensor.shape}"
+            f"sin_tensor and cos_tensor must have shape <kvseq (> {seq}) x {hd // 2}>. Got {sin_tensor.shape} and {cos_tensor.shape}"
         )

     if pos is not None:
-        if pos.shape != (input_tensor.shape[1],):
+        if pos.shape != (seq,):
             raise ValueError(
                 f"pos must have shape {input_tensor.shape[1]}. Got {pos.shape}"
             )
         sin_tensor = sin_tensor[pos]
         cos_tensor = cos_tensor[pos]

+    # seq x 1 x hd
     sin_tensor = sin_tensor.unsqueeze(1)
     cos_tensor = cos_tensor.unsqueeze(1)

+    # batch x seq x num_heads x head_dim_by_two
     x0, x1 = input_tensor[..., ::2], input_tensor[..., 1::2]
-    rotated = torch.cat(
-        [x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
-    )
+    o0 = x0 * cos_tensor - x1 * sin_tensor
+    o1 = x0 * sin_tensor + x1 * cos_tensor
+    rotated = torch.cat([o0.view(-1, 1), o1.view(-1, 1)], dim=-1)
     return rotated.view(original_shape)
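Why the old reference was wrong: rope treats the last dimension as interleaved (even, odd) rotation pairs, so after rotating the even lane x0 and the odd lane x1, the results must be woven back together element by element. The old torch.cat along dim=-1 appended all of o1 after all of o0, silently switching the output to a split-half layout. Below is a minimal, self-contained sketch of the fixed behavior (a sketch, not the ExecuTorch source; shape checks and the pos path are omitted) showing the interleave-back trick:

import torch

def rope_interleaved(x, sin, cos):
    # Minimal sketch of the fixed reference (not the ExecuTorch source).
    # x: (batch, seq, heads, hd) with rotation pairs stored interleaved,
    # i.e. (x[..., 2j], x[..., 2j + 1]) is pair j. sin/cos: (seq, hd // 2).
    sin = sin.unsqueeze(1)  # (seq, 1, hd // 2): broadcasts over heads
    cos = cos.unsqueeze(1)
    x0, x1 = x[..., ::2], x[..., 1::2]
    o0 = x0 * cos - x1 * sin  # rotated even lane
    o1 = x0 * sin + x1 * cos  # rotated odd lane
    # Pair o0/o1 as the two columns of an (N, 2) tensor, then view back:
    # this restores (o0_0, o1_0, o0_1, o1_1, ...). The old cat on dim=-1
    # produced (o0_0, ..., o0_k, o1_0, ..., o1_k) instead.
    # .view(-1, 1) is safe here because o0/o1 are fresh, contiguous
    # results of the elementwise arithmetic above.
    return torch.cat([o0.view(-1, 1), o1.view(-1, 1)], dim=-1).view(x.shape)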

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 6 additions & 6 deletions
@@ -1458,7 +1458,7 @@ def test_where_Scalar(self) -> None:
             torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32),
             torch.tensor([[0.0, 0.0]], dtype=torch.float32),
             torch.tensor([[1.0, 1.0]], dtype=torch.float32),
-            torch.tensor([[[[1.0, 3.0, 2.0, 4.0]]]], dtype=torch.float32),
+            torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]], dtype=torch.float32),
         ),
         (
             "h2xhd4",
@@ -1469,7 +1469,7 @@ def test_where_Scalar(self) -> None:
             torch.tensor([[0.0, 1.0]], dtype=torch.float32),
             torch.tensor([[1.0, 0.0]], dtype=torch.float32),
             torch.tensor(
-                [[[[1.0, -4.0, 2.0, 3.0], [5, -8.0, 6.0, 7.0]]]],
+                [[[[1.0, 2.0, -4.0, 3.0], [5, 6.0, -8.0, 7.0]]]],
                 dtype=torch.float32,
             ),
         ),
@@ -1489,8 +1489,8 @@ def test_where_Scalar(self) -> None:
             torch.tensor(
                 [
                     [
-                        [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
-                        [[9.0, -12.0, 10.0, 11.0], [13.0, -16.0, 14.0, 15.0]],
+                        [[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]],
+                        [[9.0, 10.0, -12.0, 11.0], [13.0, 14.0, -16.0, 15.0]],
                     ]
                 ],
                 dtype=torch.float32,
@@ -1512,8 +1512,8 @@ def test_where_Scalar(self) -> None:
             torch.tensor(
                 [
                     [
-                        [[1.0, -4.0, 2.0, 3.0], [5.0, -8.0, 6.0, 7.0]],
-                        [[-10.0, 11.0, 9.0, 12.0], [-14.0, 15.0, 13.0, 16.0]],
+                        [[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]],
+                        [[-10.0, 9.0, 11.0, 12.0], [-14.0, 13.0, 15.0, 16.0]],
                     ]
                 ],
                 dtype=torch.float32,
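The expected tensors change accordingly: outputs are now interleaved per pair instead of grouped by lane. A quick hand-check of the "h2xhd4" case, assuming (the input lines are unchanged and not shown in this diff) its input is the values 1..8 in shape 1 x 1 x 2 x 4, with cos = [1, 0] and sin = [0, 1] as shown above:

import torch

# Head 0 holds [1, 2, 3, 4]; pairs are (1, 2) and (3, 4).
# pair 0: o0 = 1*1 - 2*0 = 1,  o1 = 1*0 + 2*1 = 2
# pair 1: o0 = 3*0 - 4*1 = -4, o1 = 3*1 + 4*0 = 3
# interleaved -> [1, 2, -4, 3]; the old grouped layout gave [1, -4, 2, 3].
x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]])
sin = torch.tensor([[0.0, 1.0]])
cos = torch.tensor([[1.0, 0.0]])
out = rope_interleaved(x, sin, cos)  # sketch defined earlier
assert torch.equal(
    out, torch.tensor([[[[1.0, 2.0, -4.0, 3.0], [5.0, 6.0, -8.0, 7.0]]]])
)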

backends/cadence/generic/operators/op_rope.cpp

Lines changed: 27 additions & 20 deletions
@@ -23,17 +23,20 @@ Tensor& rope_out(
     const optional<Tensor>& pos,
     Tensor& out) {
   // Input shape is [1, seq, h, hd / 2, 2] or [1, seq, h, hd]
-  const auto kSeq = input.size(1);
-  const auto kH = input.size(2);
-  const auto kHd = input.numel() / (kSeq * kH);
-  for (int32_t s = 0; s < kSeq; ++s) {
-    for (int32_t h = 0; h < kH; ++h) {
-      for (int32_t hd_o = 0; hd_o < kHd / 2; ++hd_o) {
-        float x_0 =
-            input.const_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2];
-        float x_1 =
-            input
-                .const_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2 + 1];
+  const ssize_t seq_length = input.size(1);
+  const ssize_t num_heads = input.size(2);
+  const ssize_t head_dimension = input.numel() / (seq_length * num_heads);
+  const ssize_t head_dimension_by_two = head_dimension / 2;
+  for (int32_t s = 0; s < seq_length; ++s) {
+    for (int32_t h = 0; h < num_heads; ++h) {
+      for (int32_t hd_o = 0; hd_o < head_dimension_by_two; ++hd_o) {
+        // Process 2 elements in head dimension at a time.
+        const float x_0 = input.const_data_ptr<float>()
+            [s * num_heads * head_dimension +
+             h * head_dimension + hd_o * 2];
+        const float x_1 = input.const_data_ptr<float>()
+            [s * num_heads * head_dimension +
+             h * head_dimension + hd_o * 2 + 1];
         int64_t token_id = s;
         if (pos.has_value()) {
           if (pos->scalar_type() == ::executorch::aten::ScalarType::Int) {
@@ -42,17 +45,21 @@ Tensor& rope_out(
             token_id = pos.has_value() ? pos->const_data_ptr<int64_t>()[s] : s;
           }
         }
-        float sin =
-            sin_tensor.const_data_ptr<float>()[token_id * kHd / 2 + hd_o];
-        float cos =
-            cos_tensor.const_data_ptr<float>()[token_id * kHd / 2 + hd_o];

-        float out_0 = x_0 * cos - x_1 * sin;
-        float out_1 = x_0 * sin + x_1 * cos;
-        out.mutable_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2] =
+        const float sin = sin_tensor.const_data_ptr<
+            float>()[token_id * head_dimension_by_two + hd_o];
+        const float cos = cos_tensor.const_data_ptr<
+            float>()[token_id * head_dimension_by_two + hd_o];
+
+        const float out_0 = x_0 * cos - x_1 * sin;
+        out.mutable_data_ptr<float>()
+            [s * num_heads * head_dimension + h * head_dimension + hd_o * 2] =
             out_0;
-        out.mutable_data_ptr<float>()[s * kH * kHd + h * kHd + hd_o * 2 + 1] =
-            out_1;
+
+        const float out_1 = x_0 * sin + x_1 * cos;
+        out.mutable_data_ptr<float>()
+            [s * num_heads * head_dimension + h * head_dimension + hd_o * 2 +
+             1] = out_1;
       }
     }
   }
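The C++ kernel already wrote out_0 and out_1 to adjacent slots, so this change is a readability pass: descriptive names in place of kSeq/kH/kHd, a hoisted head_dimension_by_two, and const locals. One way to convince yourself it matches the fixed Python reference is a scalar-loop port of the kernel's indexing checked against the vectorized version; the sketch below is an assumption-laden port (batch size 1, pos omitted), not the ExecuTorch source:

import torch

def rope_scalar(x, sin, cos):
    # Scalar-loop port of op_rope.cpp's indexing (a sketch, batch == 1,
    # pos omitted): pairs are adjacent, so out_0/out_1 land at 2j, 2j + 1.
    _, seq, heads, hd = x.shape
    flat, out = x.reshape(-1), torch.empty(x.numel())
    for s in range(seq):
        for h in range(heads):
            for j in range(hd // 2):
                base = s * heads * hd + h * hd + 2 * j
                x0, x1 = flat[base], flat[base + 1]
                sn, cs = sin[s, j], cos[s, j]
                out[base] = x0 * cs - x1 * sn      # out_0 -> even slot
                out[base + 1] = x0 * sn + x1 * cs  # out_1 -> odd slot
    return out.view(x.shape)

x = torch.arange(1.0, 17.0).view(1, 2, 2, 4)  # seq=2, heads=2, hd=4
sin, cos = torch.rand(2, 2), torch.rand(2, 2)
assert torch.allclose(rope_scalar(x, sin, cos), rope_interleaved(x, sin, cos))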
