
Commit 6a77fd0

fix
1 parent 5432115 commit 6a77fd0

File tree

4 files changed: +16 −9 lines changed


megatron/fused_kernels/__init__.py

Lines changed: 9 additions & 2 deletions
@@ -1,10 +1,17 @@
+import torch
+
+
 def load(args):
     if args.use_kernels_from_apex:
         return

-    if args.device == "cuda":
+    if torch.version.hip is None:
+        if torch.distributed.get_rank() == 0:
+            print("running on CUDA devices")
         from megatron.fused_kernels.cuda import load as load_kernels
-    elif args.device == "rocm":
+    else:
+        if torch.distributed.get_rank() == 0:
+            print("running on ROCm devices")
         from megatron.fused_kernels.rocm import load as load_kernels

     load_kernels(args)
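
With this change, load() no longer consults a device argument: it inspects torch.version.hip, which is None on CUDA builds of PyTorch and holds the HIP version string on ROCm builds, and picks the matching kernel loader. A minimal standalone sketch of that detection pattern (backend_name is an illustrative helper, not part of this commit):

import torch


def backend_name():
    # torch.version.hip is None on CUDA builds of PyTorch and is set to
    # the HIP version string on ROCm builds.
    if torch.version.hip is None:
        return "cuda"
    return "rocm"


if __name__ == "__main__":
    print(f"this PyTorch build targets {backend_name()}")

Note that the commit guards its prints with torch.distributed.get_rank() == 0 so only one process logs; the sketch omits that because it assumes a single process.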

megatron/initialize.py

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@
 except ModuleNotFoundError:
     print('Wandb import failed', flush=True)

-from megatron.fused_kernels import cuda
+import megatron.fused_kernels as fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
@@ -198,11 +198,11 @@ def _compile_dependencies():
     if torch.distributed.get_rank() == 0:
         start_time = time.time()
         print('> compiling and loading fused kernels ...', flush=True)
-        cuda.load(args)
+        fused_kernels.load(args)
         torch.distributed.barrier()
     else:
         torch.distributed.barrier()
-        cuda.load(args)
+        fused_kernels.load(args)
     # Simple barrier to make sure all ranks have passed the
     # compilation phase successfully before moving on to the
     # rest of the program. We think this might ensure that
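
The call site in _compile_dependencies() keeps Megatron's usual pattern for JIT-building extensions under torch.distributed: rank 0 compiles first (warming the build cache) while the other ranks wait at a barrier, then the remaining ranks load the already-built kernels. A minimal sketch of that pattern, assuming torch.distributed is already initialized and with build_or_load standing in for fused_kernels.load:

import torch


def compile_on_rank0_first(build_or_load):
    # Rank 0 builds the fused kernels first so the JIT build cache is
    # warm; all other ranks wait, then load from the cached build.
    if torch.distributed.get_rank() == 0:
        build_or_load()
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        build_or_load()
    # Final barrier so no rank moves on until every rank has loaded.
    torch.distributed.barrier()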

tools/checkpoint_loader_megatron.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@

 import torch

-from megatron.fused_kernels import cuda
+import megatron.fused_kernels as fused_kernels

 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron loader')
@@ -133,7 +133,7 @@ def get_models(count, dtype, pre_process, post_process):
     set_global_variables(margs)
     mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
     mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
-    cuda.load(margs)
+    fused_kernels.load(margs)

     # Get true (non-padded) vocab size
     if args.true_vocab_size is not None:

tools/checkpoint_saver_megatron.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@

 import torch

-from megatron.fused_kernels import cuda
+import megatron.fused_kernels as fused_kernels

 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron saver')
@@ -161,7 +161,7 @@ def get_models(count, dtype, pre_process, post_process):
     mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
     mpu.initialize.set_tensor_model_parallel_rank(0)
     mpu.initialize.set_pipeline_model_parallel_rank(0)
-    cuda.load(margs)
+    fused_kernels.load(margs)

     # Embeddings
     #-----------
