
Commit fe433de

add test/prototype/test_quantized_training.py
1 parent 346a339 commit fe433de

File tree

1 file changed: +12 −10 lines


test/prototype/test_quantized_training.py

Lines changed: 12 additions & 10 deletions
@@ -34,11 +34,13 @@
     quantize_int8_rowwise,
 )
 from torchao.quantization.quant_api import quantize_
+from torchao.utils import get_current_accelerator_device
 
 if common_utils.SEED is None:
     common_utils.SEED = 1234
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + (["xpu"] if torch.xpu.is_available() else [])
+_DEVICE = get_current_accelerator_device()
 
 
 def _reset():
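
Note: the new module-level _DEVICE comes from get_current_accelerator_device, whose implementation is not part of this diff. A minimal sketch of what such a helper could look like on recent PyTorch, assuming it is a thin wrapper around the torch.accelerator API (the real torchao.utils version may differ):

# Hypothetical sketch, not the actual torchao.utils implementation:
# resolve the active accelerator (cuda, xpu, ...) and fall back to CPU.
import torch

def get_current_accelerator_device() -> torch.device:
    if torch.accelerator.is_available():
        # Returns a torch.device such as device("cuda") or device("xpu").
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")
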
@@ -182,12 +184,12 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
-        device = "cuda"
+        device = _DEVICE
 
         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
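
Note: this hunk shows the two substitutions that recur throughout the commit: torch.cuda.is_available() becomes torch.accelerator.is_available() in skip conditions, and the hard-coded "cuda" string becomes the shared _DEVICE. A small self-contained illustration of the same pattern, assuming the torch.accelerator API (PyTorch 2.6+); the test name and shapes below are made up for the example, not taken from this file:

import pytest
import torch
import torch.nn as nn

# Device-agnostic handle for the current accelerator, CPU as fallback.
_DEVICE = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else torch.device("cpu")
)

@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
def test_linear_runs_on_any_accelerator():
    # Module and input are created directly on whichever accelerator is present.
    linear = nn.Linear(8, 8, device=_DEVICE)
    out = linear(torch.randn(4, 8, device=_DEVICE))
    assert out.shape == (4, 8)
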
@@ -221,7 +223,7 @@ def snr(ref, actual):
 
     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
     def test_bitnet_training(self, compile):
         # reference implementation
         # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -246,7 +248,7 @@ def forward(self, x):
         _reset()
         bsize = 4
         embed_dim = 32
-        device = "cuda"
+        device = _DEVICE
 
         # only use 1 matmul shape to reduce triton autotune time
         model_ref = nn.Sequential(
@@ -296,7 +298,7 @@ def world_size(self) -> int:
         return _FSDP_WORLD_SIZE
 
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
     def test_fsdp2_correctness(self):
         mp_policy = MixedPrecisionPolicy()
 
@@ -342,7 +344,7 @@ def _run_subtest(self, args):
             dropout_p=0,
         )
         torch.manual_seed(42)
-        base_model = Transformer(model_args).cuda()
+        base_model = Transformer(model_args).to(_DEVICE)
         fsdp_model = copy.deepcopy(base_model)
 
         quantize_(base_model.layers, quantize_fn)
@@ -362,7 +364,7 @@ def _run_subtest(self, args):
 
         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
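
Note: the FSDP hunks replace .cuda() calls and device="cuda" keyword arguments with .to(_DEVICE) / device=_DEVICE, so the same test body runs on whichever accelerator backend is present. A minimal sketch of that placement pattern (illustrative names only, not code from this file, and again assuming the torch.accelerator API):

import torch
import torch.nn as nn

device = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else torch.device("cpu")
)

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU()).to(device)  # instead of .cuda()
inp = torch.randn(4, 16, device=device)                         # instead of device="cuda"
loss = model(inp).sum()
loss.backward()
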
@@ -387,14 +389,14 @@ def _run_subtest(self, args):
         )
 
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )
 
-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(_DEVICE)
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)
