Commit 075315d

add test/prototype/test_quantized_training.py
1 parent: fe433de

File tree: 2 files changed (+23, -7 lines)

test/prototype/test_parq.py

Lines changed: 5 additions & 1 deletion
@@ -51,7 +51,11 @@
     torch_version_at_least,
 )

-_DEVICE = torch.device(torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu")
+_DEVICE = torch.device(
+    torch.accelerator.current_accelerator().type
+    if torch.accelerator.is_available()
+    else "cpu"
+)


 class M(nn.Module):
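
For context, a minimal sketch of what the `_DEVICE` selection above does at import time, assuming a PyTorch build recent enough to provide the torch.accelerator API; only the names already visible in the diff are taken from the repo, the rest is illustrative:

    import torch

    # Same logic as the diff, unrolled for readability: prefer the current
    # accelerator backend ("cuda", "xpu", ...) and fall back to CPU.
    if torch.accelerator.is_available():
        device_type = torch.accelerator.current_accelerator().type
    else:
        device_type = "cpu"

    _DEVICE = torch.device(device_type)
    print(_DEVICE)  # e.g. device(type='cuda') on a CUDA machine, device(type='cpu') otherwise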

test/prototype/test_quantized_training.py

Lines changed: 18 additions & 6 deletions
@@ -39,7 +39,11 @@
 if common_utils.SEED is None:
     common_utils.SEED = 1234

-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + (["xpu"] if torch.xpu.is_available() else [])
+_DEVICES = (
+    ["cpu"]
+    + (["cuda"] if torch.cuda.is_available() else [])
+    + (["xpu"] if torch.xpu.is_available() else [])
+)
 _DEVICE = get_current_accelerator_device()


@@ -184,7 +188,9 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
@@ -223,7 +229,9 @@ def snr(ref, actual):

     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_bitnet_training(self, compile):
         # reference implementation
         # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -298,7 +306,7 @@ def world_size(self) -> int:
         return _FSDP_WORLD_SIZE

     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     def test_fsdp2_correctness(self):
         mp_policy = MixedPrecisionPolicy()

@@ -389,14 +397,18 @@ def _run_subtest(self, args):
         )

     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
+    @pytest.mark.skipif(
+        not torch.accelerator.is_available(), reason="GPU not available"
+    )
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )

-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(_DEVICE)
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(
+            _DEVICE
+        )
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)
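
As a rough usage sketch, not taken from this file, a device list like `_DEVICES` is typically consumed by a pytest parametrization so the same test runs on every backend present on the machine. The test name and body below are hypothetical placeholders; the list construction mirrors the diff, and torch.xpu.is_available() assumes a PyTorch build that exposes the XPU backend:

    import pytest
    import torch

    # Build the device list as in the diff: CPU always, plus CUDA/XPU when present.
    _DEVICES = (
        ["cpu"]
        + (["cuda"] if torch.cuda.is_available() else [])
        + (["xpu"] if torch.xpu.is_available() else [])
    )

    @pytest.mark.parametrize("device", _DEVICES)
    def test_add_matches_cpu(device):
        # Hypothetical check: a simple op gives the same result on every backend.
        x = torch.arange(4, dtype=torch.float32, device=device)
        torch.testing.assert_close((x + x).cpu(), torch.arange(4, dtype=torch.float32) * 2)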
