Commit 221d3df

account for codebook scale in bsq
1 parent 4956bf7 commit 221d3df

3 files changed: +16 -11 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.14.38"
+version = "1.14.39"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }

tests/test_readme.py

Lines changed: 7 additions & 2 deletions
@@ -201,7 +201,11 @@ def test_rfsq():
     assert torch.all(quantized == quantized_out)
 
 @pytest.mark.parametrize('spherical', (True, False))
-def test_lfq(spherical):
+@pytest.mark.parametrize('codebook_scale', (1., 0.5))
+def test_lfq(
+    spherical,
+    codebook_scale
+):
     from vector_quantize_pytorch import LFQ
 
     # you can specify either dim or codebook_size
@@ -212,7 +216,8 @@ def test_lfq(spherical):
         dim = 16,                       # this is the input feature dimension, defaults to log2(codebook_size) if not defined
         entropy_loss_weight = 0.1,      # how much weight to place on entropy loss
         diversity_gamma = 1.,           # within entropy loss, how much weight to give to diversity of codes, taken from https://arxiv.org/abs/1911.05894
-        spherical = spherical
+        spherical = spherical,
+        codebook_scale = codebook_scale
     )
 
     image_feats = torch.randn(1, 16, 32, 32)
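For reference, a hedged usage sketch of what the updated test exercises: constructing LFQ in spherical (BSQ) mode with a non-default codebook_scale. The constructor arguments and the three return values follow the repository README; the specific codebook_size value here is illustrative.

import torch
from vector_quantize_pytorch import LFQ

quantizer = LFQ(
    codebook_size = 65536,   # 2 ** 16 codes, so dim would also default to 16
    dim = 16,
    spherical = True,        # turn on BSQ
    codebook_scale = 0.5     # the newly exercised parameter
)

image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats)

assert quantized.shape == image_feats.shape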

vector_quantize_pytorch/lookup_free_quantization.py

Lines changed: 8 additions & 8 deletions
@@ -44,6 +44,9 @@ def maybe_distributed_mean(t):
 def exists(v):
     return v is not None
 
+def identity(t):
+    return t
+
 def default(*args):
     for arg in args:
         if exists(arg):
@@ -156,6 +159,7 @@ def __init__(
         # whether to use BSQ (binary spherical quantization)
 
         self.spherical = spherical
+        self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity
 
         # entropy aux loss related weights
 
@@ -220,8 +224,7 @@ def indices_to_codes(
 
         codes = self.bits_to_codes(bits)
 
-        if self.spherical:
-            codes = l2norm(codes)
+        codes = self.maybe_l2norm(codes)
 
         codes = rearrange(codes, '... c d -> ... (c d)')
 
@@ -281,8 +284,7 @@ def forward(
 
         # maybe l2norm
 
-        if self.spherical:
-            x = l2norm(x)
+        x = self.maybe_l2norm(x)
 
         # quantize by eq 3.
 
@@ -297,8 +299,7 @@ def forward(
 
         # maybe l2norm
 
-        if self.spherical:
-            quantized = l2norm(quantized)
+        quantized = self.maybe_l2norm(quantized)
 
         # use straight-through gradients (optionally with custom activation fn) if training
 
@@ -313,8 +314,7 @@ def forward(
         if self.training:
             codebook = self.codebook
 
-            if self.spherical:
-                codebook = l2norm(codebook)
+            codebook = self.maybe_l2norm(codebook)
 
         # the same as euclidean distance up to a constant
             distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
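To see the substance of the fix, here is a minimal standalone sketch (tensor shape and scale value assumed for illustration): previously, with spherical = True, the branches applied l2norm alone, so any codebook_scale other than 1. was silently dropped; the new maybe_l2norm callable rescales the normalized tensor.

import torch
import torch.nn.functional as F

def l2norm(t):
    # same role as the library helper: unit-normalize along the last dimension
    return F.normalize(t, dim = -1)

codebook_scale = 0.5                      # illustrative, matching the new test case
codes = torch.randn(4, 16)

before = l2norm(codes)                    # old behavior: norm is always 1., scale ignored
after = l2norm(codes) * codebook_scale    # new behavior: norm equals codebook_scale

print(before.norm(dim = -1))              # tensor of ones
print(after.norm(dim = -1))               # tensor of 0.5s

Folding the four repeated `if self.spherical:` branches into a single maybe_l2norm callable, set once in __init__, also keeps the call sites from drifting apart again.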
