diff --git a/flexgen/dist_flex_opt.py b/flexgen/dist_flex_opt.py
index 593bb78f..aac0c086 100644
--- a/flexgen/dist_flex_opt.py
+++ b/flexgen/dist_flex_opt.py
@@ -219,7 +219,6 @@ def store_hidden(self, b, t, i, j, k):
                 hidden_val = self.hidden[t][i][j][k].pop()
             else:
                 hidden_val = self.hidden[t][i][j][k].val
-            ids = hidden_val.data.detach().cpu().numpy()
             gpu_batch_size = self.policy.gpu_batch_size
             num_gpu_batches = self.num_gpu_batches
 
@@ -478,7 +477,9 @@ def generation_loop_overlap_one_batch(self):
                 self.load_weight(b, t, i, j+1, 0)
                 self.load_cache(t, i, j+1, 0)
                 self.load_hidden(b, t, i, j, 0)
+                self.sync()
                 self.compute_layer(t, i, j, 0)
+                self.sync()
                 self.store_cache(t, i, j-1, 0)
                 self.store_hidden(b, t, i, j, 0)
                 self.sync()
@@ -515,7 +516,9 @@ def generation_loop_overlap_multi_batch(self):
                     self.load_weight(b, t, i, j + 1, k)
                     self.load_cache(t, i, j, k + 1)
                     self.load_hidden(b, t, i, j, k)
+                    self.sync()
                     self.compute_layer(t, i, j, k)
+                    self.sync()
                     self.store_cache(t, i, j, k - 1)
                     self.store_hidden(b, t, i, j, k)
                     self.sync()
@@ -674,6 +677,10 @@ def add_distributed_parser_arguments(parser):
         args.world_size = int(os.getenv('OMPI_COMM_WORLD_SIZE'))
         args.rank = int(os.getenv('OMPI_COMM_WORLD_RANK'))
         args.local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'))
+    else:
+        args.world_size = int(os.getenv('WORLD_SIZE'))
+        args.rank = int(os.getenv('RANK'))
+        args.local_rank = int(os.getenv('LOCAL_RANK'))
     initialize_distributed(args.head_ip, args.port, args.world_size,
                            args.rank, args.local_rank, args.comm_device)
 else:
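
Note on the env-var fallback in the last hunk: the new else branch reads the
WORLD_SIZE / RANK / LOCAL_RANK variables that torchrun (and the legacy
torch.distributed.launch) export for each worker, complementing the existing
OMPI_COMM_WORLD_* path used under mpirun. A minimal sketch of the resulting
resolution logic follows; the function name and the use_mpi flag are
illustrative, not part of the patch:

    import os

    def resolve_launcher_env(use_mpi: bool):
        """Return (world_size, rank, local_rank) from the launcher's env.

        Open MPI's mpirun exports OMPI_COMM_WORLD_*; torchrun exports
        WORLD_SIZE / RANK / LOCAL_RANK. os.environ[...] raises KeyError
        for a missing variable, failing fast under the wrong launcher.
        """
        if use_mpi:
            return (int(os.environ['OMPI_COMM_WORLD_SIZE']),
                    int(os.environ['OMPI_COMM_WORLD_RANK']),
                    int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']))
        return (int(os.environ['WORLD_SIZE']),
                int(os.environ['RANK']),
                int(os.environ['LOCAL_RANK']))

With the fallback in place, the script should be launchable either as, e.g.,
"mpirun -np 4 python dist_flex_opt.py ..." or as
"torchrun --nproc_per_node=4 dist_flex_opt.py ...".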