|
17 | 17 |
|
18 | 18 | import argparse |
19 | 19 | import os |
| 20 | +import re |
20 | 21 |
|
21 | 22 | import torch |
22 | 23 |
|
@@ -100,6 +101,30 @@ def validate_args(args, defaults={}): |
100 | 101 | ' to be less than pipeline model parallel size ({})'.format( |
101 | 102 | args.pipeline_model_parallel_size) |
102 | 103 |
|
# --data-path/--split (Mode 1) and --(train|valid|test)-weighted-split-paths
# (Mode 2) are two mutually exclusive ways of specifying the data.
# NOTE: the continuation strings below must keep their trailing spaces,
# otherwise adjacent literals concatenate into garbled words.
message = "Data loading Mode 1: --data-path and --split " \
          "and Mode 2: --(train|valid|test)-weighted-split-paths " \
          "are mutually exclusive i.e. cannot be set together."

if args.data_path:
    assert args.train_weighted_split_paths is None, message
    # Mode 1 is in use: clear the Mode-2 attributes so downstream code
    # can rely on them being absent.
    setattr(args, "valid_weighted_split_names", None)
    setattr(args, "valid_weighted_split_weights", None)
    setattr(args, "valid_weighted_split_splits", None)

    setattr(args, "test_weighted_split_names", None)
    setattr(args, "test_weighted_split_weights", None)
    setattr(args, "test_weighted_split_splits", None)

    # args.split defaults to None so an explicit --split can be detected
    # and rejected when Mode 2 is used; restore the historical default
    # for Mode 1 here.
    if args.split is None:
        args.split = "969, 30, 1"

if args.train_weighted_split_paths or args.valid_weighted_split_paths or \
   args.test_weighted_split_paths:
    assert args.data_path is None and args.split is None, message
103 | 128 | # Deprecated arguments |
104 | 129 | assert args.batch_size is None, '--batch-size argument is no longer ' \ |
105 | 130 | 'valid, use --micro-batch-size instead' |
@@ -863,16 +888,114 @@ def _add_validation_args(parser): |
863 | 888 | def _add_data_args(parser): |
864 | 889 | group = parser.add_argument_group(title='data and dataloader') |
865 | 890 |
|
# option 1 for data loading (mutually exclusive with option2)
group.add_argument('--data-path', nargs='*', default=None,
                   help='Path to the training dataset. Accepted format:'
                   '1) a single data path, 2) multiple datasets in the'
                   'form: dataset1-weight dataset1-path dataset2-weight '
                   'dataset2-path ...')
# NOTE(review): --split deliberately defaults to None (was '969, 30, 1')
# so validate_args can detect an explicit --split and reject it when the
# weighted-split-paths loading mode is used; validate_args restores the
# old default when --data-path is given.
group.add_argument('--split', type=str, default=None,
                   help='Comma-separated list of proportions for training,'
                   ' validation, and test split. For example the split '
                   '`90,5,5` will use 90%% of data for training, 5%% for '
                   'validation and 5%% for test.')
| 902 | + # option 2 for data loading (mutually exclusive with option1) |
| 903 | + # see https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/97/files |
| 904 | + |
# Helper argparse action to parse the --xxx-weighted-split-paths arguments.
# For each dataset group it stores four parallel attributes on `args`:
# <dest> (paths), plus <dest> with "paths" replaced by "weights"/"splits"/"names".
class parse_data_paths(argparse.Action):
    def __call__(self, parser, args, values, option_string=None):
        """Parse dataset-group strings of the form
        "GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2".

        Raises AssertionError when a group string is malformed or when
        more than one group is given for --train-weighted-split-paths.
        """
        if option_string == "--train-weighted-split-paths":
            # BUG FIX: the message continuation used to be a dangling
            # string literal on its own line (a no-op statement), which
            # truncated the assertion message.
            assert len(values) == 1, (
                'Only 1 dataset group is allowed to '
                'be passed for the argument --train-weighted-split-paths')

        # Make sure each group string is given in the correct format.
        # (Same dangling-string bug fixed here: err_message used to keep
        # only its first line.)
        err_message = (
            'Each data group should be input on the following format '
            '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2" '
            'where START < END')
        for v in values:
            # each prefix consists of several datasets separated by commas
            prefix = ":".join(v.split(":")[1:])  # remove GIVEN_NAME
            datasets = prefix.split(",")
            # check if each dataset is formatted like `WEIGHT START:END PATH`
            for d in datasets:
                assert len(d.split()) == 3, err_message
                start, end = d.split()[1].split(":")
                assert float(start) < float(end), err_message

        names = [v.split(":")[0] for v in values]
        prefixes = [":".join(v.split(":")[1:]).strip() for v in values]
        weights = [[d.split()[0] for d in p.split(",")] for p in prefixes]
        splits = [[d.split()[1] for d in p.split(",")] for p in prefixes]
        paths = [[d.split()[2] for d in p.split(",")] for p in prefixes]

        setattr(args, self.dest, paths)
        setattr(args, self.dest.replace("paths", "weights"), weights)
        setattr(args, self.dest.replace("paths", "splits"), splits)
        setattr(args, self.dest.replace("paths", "names"), names)
| 947 | + |
| 948 | + |
# Mutually exclusive with --data-path/--split (enforced in validate_args).
# Help strings fixed: adjacent literals previously lacked separating spaces,
# rendering e.g. "datasetsAccepted format" in --help output.
group.add_argument('--train-weighted-split-paths', nargs='*', default=None,
                   help='Weights, splits and paths to groups of datasets. '
                   'Accepted format: ONE dataset group could be '
                   'submitted in the following form between double quotes '
                   '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2" '
                   'e.g.: "NAME_ABC: 0.6 0:0.6 A, 0.3 0:1 B, 0.1 0:1 C" '
                   'WEIGHT is used to up and down sample each dataset A,B,C in the group. '
                   'START:END indicates the split portion of the dataset',
                   action=parse_data_paths)

group.add_argument('--valid-weighted-split-paths', nargs='*', default=None,
                   help='Weights, splits and paths to groups of datasets. '
                   'Accepted format: one or many dataset groups could be '
                   'submitted in the following form each between double quotes '
                   '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2" '
                   'e.g.: "NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C" '
                   '"NAME_CDE: 0.6 0.6:0.8 C, 0.3 0:1 D, 0.1 0:1 E" '
                   'validation will be run on each of those groups independently',
                   action=parse_data_paths)

group.add_argument('--test-weighted-split-paths', nargs='*', default=None,
                   help='Weights, splits and paths to groups of datasets. '
                   'Accepted format: one or many dataset groups could be '
                   'submitted in the following form each between double quotes '
                   '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2" '
                   'e.g.: "NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C" '
                   '"NAME_CDE: 0.6 0.6:0.8 C, 0.3 0:1 D, 0.1 0:1 E" '
                   'test will be run on each of those groups independently',
                   action=parse_data_paths)
| 978 | + |
# Helper argparse action to load a --*-weighted-split-paths value from a
# file instead of the command line (useful when the argument is very long).
# The file must contain exactly one line of double-quoted group strings,
# e.g.: "GROUP_A: 0.6 0:1 A" "GROUP_B: 0.4 0:1 B"
class parse_data_paths_path(argparse.Action):
    def __call__(self, parser, args, values, option_string=None):
        # Only the three *-weighted-split-paths-path options may use this action.
        expected_option_strings = ["--train-weighted-split-paths-path", "--valid-weighted-split-paths-path", "--test-weighted-split-paths-path"]
        assert option_string in expected_option_strings, f"Expected {option_string} to be in {expected_option_strings}"

        with open(values, "r") as fi:
            lines = fi.readlines()
            assert len(lines) == 1, f"Got multiple lines {len(lines)} instead of 1 expected"
            # The single line must start with a double quote and end with `"\n`.
            assert lines[0][-2:] == "\"\n" and lines[0][0] == "\"", f"Invalid input format, got {lines}"
            # Split `"A" "B"` (quotes stripped at both ends) into ['A', 'B'].
            values = lines[0][1:-2].split("\" \"")
            # Derive the corresponding --*-weighted-split-paths dest/option
            # names and delegate the actual parsing to parse_data_paths.
            weighted_split_paths_dest = re.sub(r"_path$", "", self.dest)
            weighted_split_paths_option = re.sub(r"-path$", "", self.option_strings[0])
            setattr(args, weighted_split_paths_dest, values)
            parse_data_paths(option_strings=[weighted_split_paths_option], dest=weighted_split_paths_dest)(parser, args, values, option_string=weighted_split_paths_option)
| 993 | + |
# option 2-bis: load x-weighted-split-paths from a file in case this
# argument is very long. Missing help= text added; stray spacing around
# the action= keyword fixed.
group.add_argument('--train-weighted-split-paths-path', type=str,
                   action=parse_data_paths_path, default=None,
                   help='Path to a file holding the --train-weighted-split-paths '
                   'value on a single double-quoted line.')
group.add_argument('--valid-weighted-split-paths-path', type=str,
                   action=parse_data_paths_path, default=None,
                   help='Path to a file holding the --valid-weighted-split-paths '
                   'value on a single double-quoted line.')
group.add_argument('--test-weighted-split-paths-path', type=str,
                   action=parse_data_paths_path, default=None,
                   help='Path to a file holding the --test-weighted-split-paths '
                   'value on a single double-quoted line.')
876 | 999 | group.add_argument('--vocab-file', type=str, default=None, |
877 | 1000 | help='Path to the vocab file.') |
878 | 1001 | group.add_argument('--merge-file', type=str, default=None, |
@@ -903,6 +1026,8 @@ def _add_data_args(parser): |
903 | 1026 | help='Warm up mmap files.') |
group.add_argument('--num-workers', type=int, default=2,
                   help="Dataloader number of workers.")
# Separate worker count for validation so its dataloader parallelism can
# be tuned independently of the training dataloader.
group.add_argument('--valid-num-workers', type=int, default=2,
                   help="Dataloader number of workers for validation.")
906 | 1031 | group.add_argument('--tokenizer-type', type=str, |
907 | 1032 | default=None, |
908 | 1033 | choices=['BertWordPieceLowerCase', |
|
0 commit comments