We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent cc59505 commit 5939aceCopy full SHA for 5939ace
src/diffusers/utils/torch_utils.py
@@ -18,7 +18,7 @@
18
from typing import List, Optional, Tuple, Union
19
20
from . import logging
21
-from .import_utils import is_torch_available, is_torch_version
+from .import_utils import is_torch_available, is_torch_npu_available, is_torch_version
22
23
24
if is_torch_available():
@@ -166,6 +166,8 @@ def get_torch_cuda_device_capability():
166
def get_device():
167
if torch.cuda.is_available():
168
return "cuda"
169
+ elif is_torch_npu_available():
170
+ return "npu"
171
elif hasattr(torch, "xpu") and torch.xpu.is_available():
172
return "xpu"
173
else:
0 commit comments