4 files changed, +5 -7 lines changed
@@ -6,12 +6,10 @@ def load(args):
         return

     if torch.version.hip is None:
-        if torch.distributed.get_rank() == 0:
-            print("running on CUDA devices")
+        print("running on CUDA devices")
         from megatron.fused_kernels.cuda import load as load_kernels
     else:
-        if torch.distributed.get_rank() == 0:
-            print("running on ROCm devices")
+        print("running on ROCm devices")
         from megatron.fused_kernels.rocm import load as load_kernels

     load_kernels(args)
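The hunk above drops the `torch.distributed.get_rank() == 0` guard around the device banner, so the message now prints on every rank. A plausible motivation (not stated in the diff) is that `load()` may run before `torch.distributed.init_process_group()`, in which case `get_rank()` raises a RuntimeError. A minimal sketch of a guard that keeps rank-0-only printing while tolerating the uninitialized case (`rank0_print` is a hypothetical helper, not part of this change):

```python
import torch

def rank0_print(*args, **kwargs):
    # Hypothetical helper (not in the diff): print only on rank 0,
    # treating the single-process / pre-init case as rank 0.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() != 0:
            return
    print(*args, **kwargs)

rank0_print("running on CUDA devices")
```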
@@ -31,7 +31,7 @@
 except ModuleNotFoundError:
     print('Wandb import failed', flush=True)

-import megatron.fused_kernels as fused_kernels
+from megatron import fused_kernels
 from megatron import get_adlr_autoresume
 from megatron import get_args
 from megatron import get_tensorboard_writer
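This hunk, and the two below in the argument-parsing tools (the 'Megatron loader' and 'Megatron saver'), make the same one-line change: `import megatron.fused_kernels as fused_kernels` becomes `from megatron import fused_kernels`. In the common case both statements bind the same module object; the `from` form additionally resolves through the `megatron` package's namespace, so it also works when `megatron/__init__.py` re-exports or wraps the submodule. A quick stdlib illustration of the equivalence, using `os.path` as a stand-in for `megatron.fused_kernels`:

```python
# Both forms resolve to the same module object when the imported name
# is a plain submodule; the `from` form would also accept an attribute
# re-exported by the package's __init__.py.
import os.path as path_a
from os import path as path_b
assert path_a is path_b
```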
@@ -5,7 +5,7 @@

 import torch

-import megatron.fused_kernels as fused_kernels
+from megatron import fused_kernels

 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron loader')
@@ -6,7 +6,7 @@

 import torch

-import megatron.fused_kernels as fused_kernels
+from megatron import fused_kernels

 def add_arguments(parser):
     group = parser.add_argument_group(title='Megatron saver')