Commit b0f3cfb

Merge branch 'multi-query-attention' into remove_hf_transformers
2 parents: d47f623 + 654d0d8
File tree: 9 files changed, +537 -92 lines

9 files changed

+537
-92
lines changed
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
#! /bin/bash

set -u  # stop on unset variables

# Runs the SantaCoder 1B model

GPUS_PER_NODE=8
MASTER_ADDR=${MASTER_NODE}  # Adjust
MASTER_PORT=6000
NNODES=12  # Adjust
# NODE_RANK=0  # Adjust
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

CHECKPOINT_PATH=/my/experiment/path  # Adjust: Directory to store the checkpoints
DATA_PATH=/preprocessed/data/path  # Adjust: Prefix of the preprocessed dataset
TOKENIZER_FILE=/tokenizer/path  # Adjust

GPT_ARGS="\
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--recompute-activations \
--num-layers 24 \
--hidden-size 2048 \
--num-attention-heads 16 \
--attention-head-type multiquery \
--init-method-std 0.022 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--attention-dropout 0.1 \
--hidden-dropout 0.1 \
--micro-batch-size 2 \
--global-batch-size 192 \
--lr 0.0002 \
--train-iters 3000 \
--lr-decay-iters 600000 \
--lr-decay-style cosine \
--lr-warmup-fraction 0.02 \
--weight-decay .1 \
--adam-beta2 .95 \
--clip-grad 1.0 \
--fp16 \
--log-interval 10 \
--save-interval 4000 \
--eval-interval 200 \
--eval-iters 10 \
--initial-loss-scale 65536 \
--fim-rate 0.5 \
"

TENSORBOARD_ARGS="--tensorboard-dir ${CHECKPOINT_PATH}/tensorboard"

torchrun $DISTRIBUTED_ARGS \
    pretrain_gpt.py \
    $GPT_ARGS \
    --tokenizer-type TokenizerFromFileWithFIM \
    --tokenizer-file $TOKENIZER_FILE \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH \
    --data-path $DATA_PATH \
    $TENSORBOARD_ARGS
Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
#! /bin/bash

# Runs the "345M" parameter model

GPUS_PER_NODE=8
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1  # Adjust
NODE_RANK=0  # Adjust
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

# paths to multilingual preprocessed datasets
DATA_PATH_EN=<Specify path and file prefix>_text_document
DATA_PATH_AR=<Specify path and file prefix>_text_document
DATA_PATH_KR=<Specify path and file prefix>_text_document
DATA_PATH_JP=<Specify path and file prefix>_text_document

CHECKPOINT_PATH=<Specify path>


torchrun $DISTRIBUTED_ARGS \
    pretrain_gpt.py \
    --num-layers 24 \
    --hidden-size 1024 \
    --num-attention-heads 16 \
    --micro-batch-size 4 \
    --global-batch-size 8 \
    --seq-length 1024 \
    --max-position-embeddings 1024 \
    --train-iters 1000 \
    --lr-decay-iters 320000 \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH \
    --train-weighted-split-paths "TRAIN: 0.3 0:0.6 $DATA_PATH_EN, 1 0:0.6 $DATA_PATH_AR, 1 0:0.6 $DATA_PATH_KR, 1 0:0.6 $DATA_PATH_JP" \
    --valid-weighted-split-paths \
        "VALID_EN: 1 0.6:0.8 $DATA_PATH_EN" \
        "VALID_AR: 1 0.6:0.8 $DATA_PATH_AR" \
        "VALID_JP: 1 0.6:0.8 $DATA_PATH_JP" \
        "VALID_KR: 1 0.6:0.8 $DATA_PATH_KR" \
        "VALID_EN-AR-JP-KR_BALANCED: 1 0.6:0.8 $DATA_PATH_EN, 1 0.6:0.8 $DATA_PATH_AR, 1 0.6:0.8 $DATA_PATH_JP, 1 0.6:0.8 $DATA_PATH_KR" \
    --test-weighted-split-paths \
        "TEST_EN: 1 0.8:1 $DATA_PATH_EN" \
        "TEST_AR: 1 0.8:1 $DATA_PATH_AR" \
        "TEST_JP: 1 0.8:1 $DATA_PATH_JP" \
        "TEST_KR: 1 0.8:1 $DATA_PATH_KR" \
        "TEST_EN-AR-JP-KR_BALANCED: 1 0.8:1 $DATA_PATH_EN, 1 0.8:1 $DATA_PATH_AR, 1 0.8:1 $DATA_PATH_JP, 1 0.8:1 $DATA_PATH_KR" \
    --vocab-file gpt2-vocab.json \
    --merge-file gpt2-merges.txt \
    --data-impl mmap \
    --distributed-backend nccl \
    --lr 0.00015 \
    --min-lr 1.0e-5 \
    --lr-decay-style cosine \
    --weight-decay 1e-2 \
    --clip-grad 1.0 \
    --lr-warmup-fraction .01 \
    --checkpoint-activations \
    --log-interval 100 \
    --save-interval 10000 \
    --eval-interval 1000 \
    --eval-iters 10 \
    --fp16

megatron/arguments.py

Lines changed: 126 additions & 1 deletion
@@ -17,6 +17,7 @@
 
 import argparse
 import os
+import re
 
 import torch
 
@@ -100,6 +101,30 @@ def validate_args(args, defaults={}):
                 ' to be less than pipeline model parallel size ({})'.format(
                     args.pipeline_model_parallel_size)
 
+    # --data-path and --(train|valid|test)-weighted-split-paths are mutually exclusive
+    message = "Data loading Mode 1: --data-path and --split " \
+              "and Mode 2: --(train|valid|test)-weighted-split-paths " \
+              "are mutually exclusive i.e. cannot be set together."
+
+    if args.data_path:
+        assert args.train_weighted_split_paths is None, message
+        setattr(args, "valid_weighted_split_names", None)
+        setattr(args, "valid_weighted_split_weights", None)
+        setattr(args, "valid_weighted_split_splits", None)
+
+        setattr(args, "test_weighted_split_names", None)
+        setattr(args, "test_weighted_split_weights", None)
+        setattr(args, "test_weighted_split_splits", None)
+
+        # args.split defaults to None; it is set here so that it can be
+        # checked for overlap with the 2nd mode of data loading below
+        if args.split is None:
+            args.split = "969, 30, 1"
+
+    if args.train_weighted_split_paths or args.valid_weighted_split_paths or \
+            args.test_weighted_split_paths:
+        assert args.data_path is None and args.split is None, message
+
     # Deprecated arguments
     assert args.batch_size is None, '--batch-size argument is no longer ' \
         'valid, use --micro-batch-size instead'
@@ -863,16 +888,114 @@ def _add_validation_args(parser):
 def _add_data_args(parser):
     group = parser.add_argument_group(title='data and dataloader')
 
+    # option 1 for data loading (mutually exclusive with option 2)
     group.add_argument('--data-path', nargs='*', default=None,
                        help='Path to the training dataset. Accepted format: '
                        '1) a single data path, 2) multiple datasets in the '
                        'form: dataset1-weight dataset1-path dataset2-weight '
                        'dataset2-path ...')
-    group.add_argument('--split', type=str, default='969, 30, 1',
+    group.add_argument('--split', type=str, default=None,
                        help='Comma-separated list of proportions for training,'
                        ' validation, and test split. For example the split '
                        '`90,5,5` will use 90%% of data for training, 5%% for '
                        'validation and 5%% for test.')
+    # option 2 for data loading (mutually exclusive with option 1)
+    # see https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/97/files
+
+    # helper class to parse the --xxx-weighted-split-paths arguments;
+    # besides the paths, it also sets the matching weights, splits and names args
+    class parse_data_paths(argparse.Action):
+        def __call__(self, parser, args, values, option_string=None):
+
+            if option_string == "--train-weighted-split-paths":
+                assert len(values) == 1, 'Only 1 dataset group is allowed to ' \
+                    'be passed for the argument --train-weighted-split-paths'
+
+            # make sure the string is given in the correct format
+            err_message = 'Each data group should be input in the following format ' \
+                '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2" ' \
+                'where START < END'
+            for v in values:
+                # each prefix consists of several datasets separated by commas
+                prefix = ":".join(v.split(":")[1:])  # remove GIVEN_NAME
+                datasets = prefix.split(",")
+                # check that each dataset is formatted like `WEIGHT START:END PATH`
+                for d in datasets:
+                    assert len(d.split()) == 3, err_message
+                    start, end = d.split()[1].split(":")
+                    assert float(start) < float(end), err_message
+
+            names = [v.split(":")[0] for v in values]
+
+            prefixes = [":".join(v.split(":")[1:]).strip() for v in values]
+            weights = [[d.split()[0] for d in p.split(",")] for p in prefixes]
+            splits = [[d.split()[1] for d in p.split(",")] for p in prefixes]
+            paths = [[d.split()[2] for d in p.split(",")] for p in prefixes]
+
+            # # to keep consistency with Option 1 of data loading (through --data-path)
+            # # paths would contain strings of the following form
+            # # "WEIGHTS1 PATH1 WEIGHTS2 PATH2 WEIGHTS3 PATH3" for each dataset group
+            # # while data will be parsed in additional arguments below
+            # paths_option1_style = []
+            # for p, w in zip(paths, weights):
+            #     paths_option1_style.append(" ".join([f"{w_i} {p_i}" for p_i, w_i in zip(p, w)]))
+            # setattr(args, self.dest, paths_option1_style)
+            setattr(args, self.dest, paths)
+            setattr(args, self.dest.replace("paths", "weights"), weights)
+            setattr(args, self.dest.replace("paths", "splits"), splits)
+            setattr(args, self.dest.replace("paths", "names"), names)
+
+    group.add_argument('--train-weighted-split-paths', nargs='*', default=None,
+                       help='Weights, splits and paths to groups of datasets. '
+                       'Accepted format: ONE dataset group may be '
+                       'submitted in the following form between double quotes: '
+                       '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2", '
+                       'e.g.: "NAME_ABC: 0.6 0:0.6 A, 0.3 0:1 B, 0.1 0:1 C". '
+                       'WEIGHT is used to up- and down-sample each dataset A, B, C in the group; '
+                       'START:END indicates the split portion of the dataset',
+                       action=parse_data_paths)
+
+    group.add_argument('--valid-weighted-split-paths', nargs='*', default=None,
+                       help='Weights, splits and paths to groups of datasets. '
+                       'Accepted format: one or many dataset groups may be '
+                       'submitted in the following form, each between double quotes: '
+                       '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2", '
+                       'e.g.: "NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C" '
+                       '"NAME_CDE: 0.6 0.6:0.8 C, 0.3 0:1 D, 0.1 0:1 E". '
+                       'Validation will be run on each of those groups independently',
+                       action=parse_data_paths)
+
+    group.add_argument('--test-weighted-split-paths', nargs='*', default=None,
+                       help='Weights, splits and paths to groups of datasets. '
+                       'Accepted format: one or many dataset groups may be '
+                       'submitted in the following form, each between double quotes: '
+                       '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2", '
+                       'e.g.: "NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C" '
+                       '"NAME_CDE: 0.6 0.6:0.8 C, 0.3 0:1 D, 0.1 0:1 E". '
+                       'Test will be run on each of those groups independently',
+                       action=parse_data_paths)
+
+    class parse_data_paths_path(argparse.Action):
+        def __call__(self, parser, args, values, option_string=None):
+            expected_option_strings = ["--train-weighted-split-paths-path", "--valid-weighted-split-paths-path", "--test-weighted-split-paths-path"]
+            assert option_string in expected_option_strings, f"Expected {option_string} to be in {expected_option_strings}"
+
+            with open(values, "r") as fi:
+                lines = fi.readlines()
+                assert len(lines) == 1, f"Got {len(lines)} lines instead of the 1 expected"
+                assert lines[0][-2:] == "\"\n" and lines[0][0] == "\"", f"Invalid input format, got {lines}"
+                values = lines[0][1:-2].split("\" \"")
+                weighted_split_paths_dest = re.sub(r"_path$", "", self.dest)
+                weighted_split_paths_option = re.sub(r"-path$", "", self.option_strings[0])
+                setattr(args, weighted_split_paths_dest, values)
+                parse_data_paths(option_strings=[weighted_split_paths_option], dest=weighted_split_paths_dest)(parser, args, values, option_string=weighted_split_paths_option)
+
+    # option 2-bis: load the x-weighted-split-paths from a file, in case the argument is very long
+    group.add_argument('--train-weighted-split-paths-path', type=str, action=parse_data_paths_path, default=None)
+    group.add_argument('--valid-weighted-split-paths-path', type=str, action=parse_data_paths_path, default=None)
+    group.add_argument('--test-weighted-split-paths-path', type=str, action=parse_data_paths_path, default=None)
+
     group.add_argument('--vocab-file', type=str, default=None,
                        help='Path to the vocab file.')
     group.add_argument('--merge-file', type=str, default=None,
@@ -903,6 +1026,8 @@ def _add_data_args(parser):
                        help='Warm up mmap files.')
     group.add_argument('--num-workers', type=int, default=2,
                        help="Dataloader number of workers.")
+    group.add_argument('--valid-num-workers', type=int, default=2,
+                       help="Dataloader number of workers for validation.")
     group.add_argument('--tokenizer-type', type=str,
                        default=None,
                        choices=['BertWordPieceLowerCase',
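
To make the accepted group-string format concrete, the short standalone sketch below mirrors the string handling of parse_data_paths outside of argparse and runs it on the example group from the help text; it is an illustration of the parsing rules, not code from this commit.

# Standalone sketch of the parse_data_paths string handling (illustration only).
def parse_group(value):
    name = value.split(":")[0]
    prefix = ":".join(value.split(":")[1:]).strip()  # drop GIVEN_NAME
    datasets = prefix.split(",")                     # "WEIGHT START:END PATH" entries
    weights = [d.split()[0] for d in datasets]
    splits = [d.split()[1] for d in datasets]
    paths = [d.split()[2] for d in datasets]
    return name, weights, splits, paths

# Example group taken from the --valid-weighted-split-paths help text.
print(parse_group("NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C"))
# -> ('NAME_ABC', ['0.6', '0.3', '0.1'], ['0.6:0.8', '0:1', '0:1'], ['A', 'B', 'C'])

The --*-weighted-split-paths-path variants expect the same group strings stored in a file: per parse_data_paths_path, the file must contain exactly one line that starts with a double quote and ends with a double quote followed by a newline, with the groups separated by a quote-space-quote sequence.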

megatron/data/data_samplers.py

Lines changed: 3 additions & 2 deletions
@@ -24,7 +24,7 @@
 from megatron import mpu
 
 
-def build_pretraining_data_loader(dataset, consumed_samples):
+def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
     """Buld dataloader given an input dataset."""
 
     if dataset is None:
@@ -52,10 +52,11 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         raise Exception('{} dataloader type is not supported.'.format(
                 args.dataloader_type))
 
+    num_workers = args.num_workers if num_workers is None else num_workers
     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
-                                       num_workers=args.num_workers,
+                                       num_workers=num_workers,
                                        pin_memory=True)
 
 class MegatronPretrainingSampler:
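
Nothing in this file decides where the override comes from; a hypothetical call site (not part of this diff) could wire the new --valid-num-workers flag to a validation loader like this:

# Hypothetical helper, assuming the standard Megatron entry points; not part of this commit.
from megatron import get_args
from megatron.data.data_samplers import build_pretraining_data_loader

def build_valid_loader(valid_dataset, consumed_samples=0):
    args = get_args()
    # Falls back to args.num_workers inside build_pretraining_data_loader when None is passed.
    return build_pretraining_data_loader(valid_dataset, consumed_samples,
                                         num_workers=args.valid_num_workers)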

megatron/data/dataset_utils.py

Lines changed: 24 additions & 3 deletions
@@ -41,8 +41,7 @@
 DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5]
 
 
-def get_datasets_weights_and_num_samples(data_prefix,
-                                         train_valid_test_num_samples):
+def analyze_data_prefix(data_prefix):
 
     # The data prefix should be in the format of:
     #   weight-1, data-prefix-1, weight-2, data-prefix-2, ..
@@ -59,10 +58,16 @@ def get_datasets_weights_and_num_samples(data_prefix,
         weight_sum += weight
     assert weight_sum > 0.0
     weights = [weight / weight_sum for weight in weights]
+    return prefixes, weights
+
+
+def get_datasets_weights_and_num_samples(data_prefix,
+                                         train_valid_test_num_samples):
 
-    # Add 0.5% (the 1.005 factor) so in case the bleding dataset does
+    # Add 0.5% (the 1.005 factor) so in case the blending dataset does
     # not uniformly distribute the number of samples, we still have
     # samples left to feed to the network.
+    prefixes, weights = analyze_data_prefix(data_prefix)
     datasets_train_valid_test_num_samples = []
     for weight in weights:
         datasets_train_valid_test_num_samples.append(
@@ -603,6 +608,22 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
     return indexed_dataset
 
 
+def get_split_by_range_(range_string, size):
+    """ Get dataset splits based on a range:
+    range_string is in the form START:END, e.g. 0.2:0.8
+    outputs an array of two values [start_index, end_index]
+    """
+    # some checks that the range is given in the correct form
+    splits = [float(i) for i in range_string.split(":")]
+    assert len(splits) == 2, "splits should be passed as start:end"
+    assert splits[0] <= 1 and splits[1] <= 1
+    splits_sum = sum(splits)
+    assert splits_sum > 0.0
+    splits_index = [round(s * float(size)) for s in splits]
+    assert len(splits_index) == 2
+    return splits_index
+
+
 def get_train_valid_test_split_(splits_string, size):
     """ Get dataset splits from comma or '/' separated string list."""
 
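As a quick sanity check on the new helper: get_split_by_range_ simply scales the two fractions in range_string by the dataset size and rounds to document indices. The condensed re-implementation below (illustration only; its checks are slightly stricter than the original's) reproduces that arithmetic for the 0.6:0.8 validation and 0.8:1 test ranges used in the multilingual example script above, assuming a dataset of 10,000 documents.

# Condensed illustration of get_split_by_range_ (not the module code itself).
def split_by_range(range_string, size):
    start, end = (float(x) for x in range_string.split(":"))
    assert 0.0 <= start < end <= 1.0, "expected START:END fractions with START < END"
    return [round(start * size), round(end * size)]

print(split_by_range("0.6:0.8", 10_000))  # -> [6000, 8000]   documents used for validation
print(split_by_range("0.8:1", 10_000))    # -> [8000, 10000]  documents held out for test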