     quantize_int8_rowwise,
 )
 from torchao.quantization.quant_api import quantize_
+from torchao.utils import get_current_accelerator_device
 
 if common_utils.SEED is None:
     common_utils.SEED = 1234
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + (["xpu"] if torch.xpu.is_available() else [])
+_DEVICE = get_current_accelerator_device()
 
 
 def _reset():
@@ -182,12 +184,12 @@ def test_int8_weight_only_training(self, compile, device):
         ],
     )
     @parametrize("module_swap", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
     def test_int8_mixed_precision_training(self, compile, config, module_swap):
         _reset()
         bsize = 64
         embed_dim = 64
-        device = "cuda"
+        device = _DEVICE
 
         linear = nn.Linear(embed_dim, embed_dim, device=device)
         linear_int8mp = copy.deepcopy(linear)
@@ -221,7 +223,7 @@ def snr(ref, actual):
 
     @pytest.mark.skip("Flaky on CI")
     @parametrize("compile", [False, True])
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
     def test_bitnet_training(self, compile):
         # reference implementation
         # https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
@@ -246,7 +248,7 @@ def forward(self, x):
         _reset()
         bsize = 4
         embed_dim = 32
-        device = "cuda"
+        device = _DEVICE
 
         # only use 1 matmul shape to reduce triton autotune time
         model_ref = nn.Sequential(
@@ -296,7 +298,7 @@ def world_size(self) -> int:
         return _FSDP_WORLD_SIZE
 
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
     def test_fsdp2_correctness(self):
         mp_policy = MixedPrecisionPolicy()
 
@@ -342,7 +344,7 @@ def _run_subtest(self, args):
             dropout_p=0,
         )
         torch.manual_seed(42)
-        base_model = Transformer(model_args).cuda()
+        base_model = Transformer(model_args).to(_DEVICE)
         fsdp_model = copy.deepcopy(base_model)
 
         quantize_(base_model.layers, quantize_fn)
@@ -362,7 +364,7 @@ def _run_subtest(self, args):
 
         torch.manual_seed(42 + self.rank + 1)
         for iter_idx in range(5):
-            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device=_DEVICE)
             fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
             fsdp_loss = fsdp_model(inp).sum()
             fsdp_loss.backward()
@@ -387,14 +389,14 @@ def _run_subtest(self, args):
         )
 
     @skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
     def test_precompute_bitnet_scale(self):
         from torchao.prototype.quantized_training.bitnet import (
             get_bitnet_scale,
             precompute_bitnet_scale_for_fsdp,
         )
 
-        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).cuda()
+        model = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32)).to(_DEVICE)
         model_fsdp = copy.deepcopy(model)
         quantize_(model_fsdp, bitnet_training())
         fully_shard(model_fsdp)
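
For context on the pattern these hunks adopt: hard-coded "cuda" strings become a device resolved once at import time via torchao.utils.get_current_accelerator_device, and GPU-only tests are gated on torch.accelerator.is_available(). Below is a minimal standalone sketch of that selection logic, assuming only public torch.accelerator APIs (PyTorch 2.6+); the helper name is hypothetical and only approximates what the imported torchao utility provides.

import torch

def pick_accelerator_device() -> torch.device:
    # Hypothetical stand-in for torchao.utils.get_current_accelerator_device.
    # torch.accelerator.is_available() (PyTorch 2.6+) reports whether any
    # accelerator backend (CUDA, XPU, MPS, ...) is usable in this environment.
    if torch.accelerator.is_available():
        # Returns e.g. device(type="cuda") on NVIDIA or device(type="xpu") on Intel GPUs.
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")

_DEVICE = pick_accelerator_device()
linear = torch.nn.Linear(8, 8, device=_DEVICE)  # tests then target _DEVICE instead of "cuda"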