Skip to content

Commit 9008fbe

Browse files
committed
fused
1 parent 6a77fd0 commit 9008fbe

File tree

4 files changed

+5
-7
lines changed

4 files changed

+5
-7
lines changed

megatron/fused_kernels/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@ def load(args):
         return
 
     if torch.version.hip is None:
-        if torch.distributed.get_rank() == 0:
-            print("running on CUDA devices")
+        print("running on CUDA devices")
         from megatron.fused_kernels.cuda import load as load_kernels
     else:
-        if torch.distributed.get_rank() == 0:
-            print("running on ROCm devices")
+        print("running on ROCm devices")
         from megatron.fused_kernels.rocm import load as load_kernels
 
     load_kernels(args)

megatron/initialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
 except ModuleNotFoundError:
     print('Wandb import failed', flush=True)
 
-import megatron.fused_kernels as fused_kernels
+from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer

tools/checkpoint_loader_megatron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
 
 import torch
 
-import megatron.fused_kernels as fused_kernels
+from megatron import fused_kernels
 
 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron loader')

tools/checkpoint_saver_megatron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
 
 import torch
 
-import megatron.fused_kernels as fused_kernels
+from megatron import fused_kernels
 
 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron saver')

0 commit comments

Comments (0)