diff --git a/examples/finetune_bigcode_model.slurm b/examples/finetune_bigcode_model.slurm new file mode 100644 index 0000000000..603b7e24fa --- /dev/null +++ b/examples/finetune_bigcode_model.slurm @@ -0,0 +1,144 @@ +#!/bin/bash +#SBATCH --job-name=starcoderpy +#SBATCH --nodes=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --gres=gpu:8 +#SBATCH --partition=production-cluster +#SBATCH --output=/fsx/leandro/logs/starcoderpy/bcs-%x-%j.out + +set -x -e +source /admin/home/leandro/.bashrc + +conda activate megatron + +echo "START TIME: $(date)" + +# File Path setup +SCRIPT_REPO=/fsx/leandro/git/Megatron-LM-BC +pushd $SCRIPT_REPO + +LOG_PATH=$SCRIPT_REPO/main_log.txt + +# Training setup +GPUS_PER_NODE=8 +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +NNODES=$SLURM_NNODES +NODE_RANK=$SLURM_PROCID +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +# File path setup +STARCODER_PATH=/fsx/boomcode/starcoder/ +CHECKPOINT_PATH=/fsx/boomcode/starcoderpy/$SLURM_JOB_ID +TOKENIZER_FILE=/fsx/boomcode/tokenizer-starcoder/tokenizer.json +WEIGHTS_TRAIN=/fsx/boomcode/datamix_python/train_data_paths.txt.tmp +WEIGHTS_VALID=/fsx/boomcode/datamix_python/valid_data_paths.txt.tmp +DATA_PATH=/fsx/boomcode/tokenized/python/ +mkdir -p $CHECKPOINT_PATH/tensorboard + +GPT_ARGS="\ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 4 \ + --sequence-parallel \ + --num-layers 40 \ + --hidden-size 6144 \ + --num-attention-heads 48 \ + --attention-head-type multiquery \ + --init-method-std 0.01275 \ + --seq-length 8192 \ + --max-position-embeddings 8192 \ + --attention-dropout 0.1 \ + --hidden-dropout 0.1 \ + --micro-batch-size 1 \ + --global-batch-size 512 \ + --lr 0.00005 \ + --min-lr 0.000005 \ + --train-iters 258500 \ + --lr-decay-iters 8500 \ + --lr-decay-style cosine \ + --lr-warmup-iters 500 \ + --weight-decay .1 \ + --adam-beta2 .95 \ + --clip-grad 1.0 \ + --bf16 \ + --use-flash-attn \ + --fim-rate 0.5 \ + --log-interval 10 \ + --save-interval 2500 \ + --eval-interval 100 \ + --eval-iters 10 \ + --valid-num-workers 0 \ + --override-opt_param-scheduler \ + --no-load-optim \ + --no-load-rng \ + --finetune \ +" + +# --dataloader-type cyclic\ +TENSORBOARD_ARGS="--tensorboard-dir ${CHECKPOINT_PATH}/tensorboard" + +CMD=" \ + $SCRIPT_REPO/pretrain_gpt.py \ + $GPT_ARGS \ + --tokenizer-type TokenizerFromFile \ + --tokenizer-file $TOKENIZER_FILE \ + --save $CHECKPOINT_PATH \ + --load $STARCODER_PATH \ + --train-weighted-split-paths-path $WEIGHTS_TRAIN \ + --valid-weighted-split-paths-path $WEIGHTS_VALID \ + --structured-logs \ + --structured-logs-dir $CHECKPOINT_PATH/logs \ + $TENSORBOARD_ARGS \ + --wandb-entity-name lvwerra \ + --wandb-project-name starcoder-py \ + " + +# --data-path $DATA_PATH\gpt2-preprocessed_content_document + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +echo $CMD + +# hide duplicated errors using this hack - will be properly fixed in pt-1.12 +# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + +export CUDA_HOME=/usr/local/cuda-11.6 + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" diff --git a/examples/finetune_starcoderplus.slurm b/examples/finetune_starcoderplus.slurm new file mode 100644 index 0000000000..e99c04dde9 --- /dev/null +++ b/examples/finetune_starcoderplus.slurm @@ -0,0 +1,141 @@ +#!/bin/bash +#SBATCH --job-name=starcoderplus +#SBATCH --nodes=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --exclusive +#SBATCH --gres=gpu:8 +#SBATCH --partition=production-cluster +#SBATCH --output=/fsx/leandro/logs/starcoderplus/bcs-%x-%j.out + +set -x -e +source /admin/home/leandro/.bashrc + +conda activate megatron + +echo "START TIME: $(date)" + +# File Path setup +SCRIPT_REPO=/fsx/leandro/git/Megatron-LM-BC +pushd $SCRIPT_REPO + +LOG_PATH=$SCRIPT_REPO/main_log.txt + +# Training setup +GPUS_PER_NODE=8 +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +NNODES=$SLURM_NNODES +NODE_RANK=$SLURM_PROCID +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +# File path setup +STARCODER_PATH=/fsx/boomcode/starcoder/ +CHECKPOINT_PATH=/fsx/boomcode/starcoderplus/$SLURM_JOB_ID +TOKENIZER_FILE=/fsx/boomcode/tokenizer-starcoder/tokenizer.json +WEIGHTS_TRAIN=/fsx/boomcode/datamix/train_data_paths.txt.tmp +WEIGHTS_VALID=/fsx/boomcode/datamix/valid_data_paths.txt.tmp + +mkdir -p $CHECKPOINT_PATH/tensorboard + +GPT_ARGS="\ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 4 \ + --sequence-parallel \ + --num-layers 40 \ + --hidden-size 6144 \ + --num-attention-heads 48 \ + --attention-head-type multiquery \ + --init-method-std 0.01275 \ + --seq-length 8192 \ + --max-position-embeddings 8192 \ + --attention-dropout 0.1 \ + --hidden-dropout 0.1 \ + --micro-batch-size 1 \ + --global-batch-size 512 \ + --lr 0.0001 \ + --min-lr 0.00001 \ + --train-iters 400000 \ + --lr-decay-iters 150000 \ + --lr-decay-style cosine \ + --lr-warmup-iters 1000 \ + --weight-decay .1 \ + --adam-beta2 .95 \ + --clip-grad 1.0 \ + --bf16 \ + --use-flash-attn \ + --fim-rate 0.5 \ + --log-interval 10 \ + --save-interval 2500 \ + --eval-interval 2500 \ + --eval-iters 2 \ + --valid-num-workers 0 \ + --override-opt_param-scheduler \ + --no-load-optim \ + --no-load-rng \ + --finetune \ +" + +TENSORBOARD_ARGS="--tensorboard-dir ${CHECKPOINT_PATH}/tensorboard" + +CMD=" \ + $SCRIPT_REPO/pretrain_gpt.py \ + $GPT_ARGS \ + --tokenizer-type TokenizerFromFile \ + --tokenizer-file $TOKENIZER_FILE \ + --save $CHECKPOINT_PATH \ + --load $STARCODER_PATH \ + --train-weighted-split-paths-path $WEIGHTS_TRAIN \ + --valid-weighted-split-paths-path $WEIGHTS_VALID \ + --structured-logs \ + --structured-logs-dir $CHECKPOINT_PATH/logs \ + $TENSORBOARD_ARGS \ + --wandb-entity-name lvwerra \ + --wandb-project-name starcoder-plus \ + " + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +echo $CMD + +# hide duplicated errors using this hack - will be properly fixed in pt-1.12 +# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + +export CUDA_HOME=/usr/local/cuda-11.6 + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" diff --git a/examples/pretrain_bigcode_1b.slurm b/examples/pretrain_bigcode_1b.slurm new file mode 100644 index 0000000000..c9b850211f --- /dev/null +++ b/examples/pretrain_bigcode_1b.slurm @@ -0,0 +1,142 @@ +#!/bin/bash +#SBATCH --job-name=1b-starcoder +#SBATCH --nodes=16 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! +#SBATCH --cpus-per-task=38 +#SBATCH --gres=gpu:8 +#SBATCH --partition=production-cluster +#SBATCH --output=/fsx/bigcode/bigcode-training/logs/1b/%x-%j.out + +set -x -e +source /admin/home/loubna/.bashrc + +conda activate megatron + +echo "START TIME: $(date)" + +# File Path setup +SCRIPT_REPO=/fsx/loubna/code/Megatron-LM +pushd $SCRIPT_REPO + +LOG_PATH=$SCRIPT_REPO/main_log.txt + +# Training setup +GPUS_PER_NODE=8 +# so processes know who to talk to +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +NNODES=$SLURM_NNODES +NODE_RANK=$SLURM_PROCID +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +# File path setup +CHECKPOINT_PATH=/fsx/bigcode/experiments/pretraining/1b # Adjust: Directory to store the checkpoints +# Starcoder tokenizer and data paths in /fsx/bigcode +TOKENIZER_FILE=/fsx/loubna/starcoder-tokenizer/15b/tokenizer.json +WEIGHTS_TRAIN=/fsx/bigcode/bigcode-training/code/bigcode-data-mix/data/train_data_paths.txt.tmp +WEIGHTS_VALID=/fsx/bigcode/bigcode-training/code/bigcode-data-mix/data/valid_data_paths.txt.tmp + +mkdir -p $CHECKPOINT_PATH/tensorboard + +GPT_ARGS="\ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 24 \ + --hidden-size 2048 \ + --num-attention-heads 16 \ + --attention-head-type multiquery \ + --init-method-std 0.02209 \ + --seq-length 8192 \ + --max-position-embeddings 8192 \ + --attention-dropout 0.1 \ + --hidden-dropout 0.1 \ + --micro-batch-size 1 \ + --global-batch-size 128 \ + --lr 0.0004 \ + --min-lr 0.00004 \ + --train-iters 1000000 \ + --lr-decay-iters 1000000 \ + --lr-decay-style cosine \ + --lr-warmup-iters 2000 \ + --weight-decay .1 \ + --adam-beta2 .95 \ + --clip-grad 1.0 \ + --bf16 \ + --use-flash-attn \ + --fim-rate 0.5 \ + --log-interval 10 \ + --save-interval 10000 \ + --eval-interval 10000 \ + --eval-iters 2 \ + --valid-num-workers 0 \ +" + +TENSORBOARD_ARGS="--tensorboard-dir /fsx/bigcode/experiments/pretraining/1b/tensorboard" + +CMD=" \ + /fsx/loubna/code/Megatron-LM/pretrain_gpt.py \ + $GPT_ARGS \ + --tokenizer-type TokenizerFromFile \ + --tokenizer-file $TOKENIZER_FILE \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --train-weighted-split-paths-path $WEIGHTS_TRAIN \ + --valid-weighted-split-paths-path $WEIGHTS_VALID \ + --structured-logs \ + --structured-logs-dir $CHECKPOINT_PATH/logs \ + $TENSORBOARD_ARGS \ + --wandb-entity-name loubnabnl \ + --wandb-project-name bigcode-pretraining \ + " + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +echo $CMD + +# hide duplicated errors using this hack - will be properly fixed in pt-1.12 +# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + +export CUDA_HOME=/usr/local/cuda-11.6 +# This is needed for torch1.12.1 otherwise it doesn't link correctly, not sur what the issue was. +#export PATH="/usr/local/cuda-11.6/bin:$PATH" +#export LD_LIBRARY_PATH="/usr/local/cuda-11.6/lib64:$LD_LIBRARY_PATH" +#export LD_PRELOAD=$CUDA_HOME/lib/libnccl.so +#export LD_LIBRARY_PATH=$CUDA_HOME/efa/lib:$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" diff --git a/examples/pretrain_bigcode_3b.slurm b/examples/pretrain_bigcode_3b.slurm new file mode 100644 index 0000000000..1d411664e5 --- /dev/null +++ b/examples/pretrain_bigcode_3b.slurm @@ -0,0 +1,143 @@ +#!/bin/bash +#SBATCH --job-name=3b-bigcode +#SBATCH --nodes=32 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! +#SBATCH --cpus-per-task=40 +#SBATCH --gres=gpu:8 +#SBATCH --partition=production-cluster +#SBATCH --output=/fsx/bigcode/bigcode-training/logs/3b/%x-%j.out + +set -x -e +source /admin/home/loubna/.bashrc + +conda activate megatron + +echo "START TIME: $(date)" + +# File Path setup +SCRIPT_REPO=/fsx/loubna/code/Megatron-LM +pushd $SCRIPT_REPO + +LOG_PATH=$SCRIPT_REPO/main_log.txt + +# Training setup +GPUS_PER_NODE=8 +# so processes know who to talk to +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +NNODES=$SLURM_NNODES +NODE_RANK=$SLURM_PROCID +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +# File path setup +CHECKPOINT_PATH=/fsx/bigcode/experiments/pretraining/3b # Adjust: Directory to store the checkpoints +# Starcoder tokenizer and data paths in /fsx/bigcode +TOKENIZER_FILE=/fsx/bigcode/bigcode-training/tokenizer-starcoder/tokenizer.json +WEIGHTS_TRAIN=/fsx/bigcode/bigcode-training/code/bigcode-data-mix/data/train_data_paths.txt.tmp +WEIGHTS_VALID=/fsx/bigcode/bigcode-training/code/bigcode-data-mix/data/valid_data_paths.txt.tmp + +mkdir -p $CHECKPOINT_PATH/tensorboard + +GPT_ARGS="\ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --num-layers 36 \ + --hidden-size 2816 \ + --num-attention-heads 22 \ + --attention-head-type multiquery \ + --init-method-std 0.01884 \ + --seq-length 8192 \ + --max-position-embeddings 8192 \ + --attention-dropout 0.1 \ + --hidden-dropout 0.1 \ + --micro-batch-size 1 \ + --global-batch-size 256 \ + --lr 0.0005 \ + --min-lr 0.00005 \ + --train-iters 500000 \ + --lr-decay-iters 500000 \ + --lr-decay-style cosine \ + --lr-warmup-iters 2000 \ + --weight-decay .1 \ + --adam-beta2 .95 \ + --clip-grad 1.0 \ + --bf16 \ + --use-flash-attn \ + --fim-rate 0.5 \ + --log-interval 10 \ + --save-interval 5000 \ + --eval-interval 5000 \ + --eval-iters 2 \ + --use-distributed-optimizer \ + --valid-num-workers 0 \ +" + +TENSORBOARD_ARGS="--tensorboard-dir ${CHECKPOINT_PATH}/tensorboard" + +CMD=" \ + /fsx/loubna/code/Megatron-LM/pretrain_gpt.py \ + $GPT_ARGS \ + --tokenizer-type TokenizerFromFile \ + --tokenizer-file $TOKENIZER_FILE \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --train-weighted-split-paths-path $WEIGHTS_TRAIN \ + --valid-weighted-split-paths-path $WEIGHTS_VALID \ + --structured-logs \ + --structured-logs-dir $CHECKPOINT_PATH/logs \ + $TENSORBOARD_ARGS \ + --wandb-entity-name loubnabnl \ + --wandb-project-name bigcode-3b \ + " + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +echo $CMD + +# hide duplicated errors using this hack - will be properly fixed in pt-1.12 +# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + +export CUDA_HOME=/usr/local/cuda-11.6 +# This is needed for torch1.12.1 otherwise it doesn't link correctly, not sur what the issue was. +#export PATH="/usr/local/cuda-11.6/bin:$PATH" +#export LD_LIBRARY_PATH="/usr/local/cuda-11.6/lib64:$LD_LIBRARY_PATH" +#export LD_PRELOAD=$CUDA_HOME/lib/libnccl.so +#export LD_LIBRARY_PATH=$CUDA_HOME/efa/lib:$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" diff --git a/examples/pretrain_bigcode_7b.slurm b/examples/pretrain_bigcode_7b.slurm new file mode 100644 index 0000000000..536ccc0e80 --- /dev/null +++ b/examples/pretrain_bigcode_7b.slurm @@ -0,0 +1,143 @@ +#!/bin/bash +#SBATCH --job-name=7b-starcoder +#SBATCH --nodes=64 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! +#SBATCH --cpus-per-task=38 +#SBATCH --gres=gpu:8 +#SBATCH --partition=production-cluster +#SBATCH --output=/fsx/bigcode/bigcode-training/logs/7b/%x-%j.out + +set -x -e +source /admin/home/loubna/.bashrc + +conda activate megatron + +echo "START TIME: $(date)" + +# File Path setup +SCRIPT_REPO=/fsx/loubna/code/Megatron-LM +pushd $SCRIPT_REPO + +LOG_PATH=$SCRIPT_REPO/main_log.txt + +# Training setup +GPUS_PER_NODE=8 +# so processes know who to talk to +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +NNODES=$SLURM_NNODES +NODE_RANK=$SLURM_PROCID +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +# File path setup +CHECKPOINT_PATH=/fsx/bigcode/experiments/pretraining/7b-starcoder +# Starcoder tokenizer and data paths in /fsx/bigcode +TOKENIZER_FILE=/fsx/bigcode/bigcode-training/tokenizer-starcoder/tokenizer.json +WEIGHTS_TRAIN=/fsx/bigcode/bigcode-training/code/bigcode-data-mix/data/train_data_paths.txt.tmp +WEIGHTS_VALID=/fsx/bigcode/bigcode-training/code/bigcode-data-mix/data/valid_data_paths.txt.tmp + +mkdir -p $CHECKPOINT_PATH/tensorboard + +GPT_ARGS="\ + --tensor-model-parallel-size 2 \ + --pipeline-model-parallel-size 1 \ + --num-layers 42 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --attention-head-type multiquery \ + --init-method-std 0.015625 \ + --seq-length 8192 \ + --max-position-embeddings 8192 \ + --attention-dropout 0.1 \ + --hidden-dropout 0.1 \ + --micro-batch-size 1 \ + --global-batch-size 512 \ + --lr 0.0003 \ + --min-lr 0.00003 \ + --train-iters 250000 \ + --lr-decay-iters 250000 \ + --lr-decay-style cosine \ + --lr-warmup-iters 2000 \ + --weight-decay .1 \ + --adam-beta2 .95 \ + --clip-grad 1.0 \ + --bf16 \ + --use-flash-attn \ + --fim-rate 0.5 \ + --log-interval 10 \ + --save-interval 2500 \ + --eval-interval 2500 \ + --eval-iters 2 \ + --use-distributed-optimizer \ + --valid-num-workers 0 \ +" + +TENSORBOARD_ARGS="--tensorboard-dir ${CHECKPOINT_PATH}/tensorboard" + +CMD=" \ + /fsx/loubna/code/Megatron-LM/pretrain_gpt.py \ + $GPT_ARGS \ + --tokenizer-type TokenizerFromFile \ + --tokenizer-file $TOKENIZER_FILE \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --train-weighted-split-paths-path $WEIGHTS_TRAIN \ + --valid-weighted-split-paths-path $WEIGHTS_VALID \ + --structured-logs \ + --structured-logs-dir $CHECKPOINT_PATH/logs \ + $TENSORBOARD_ARGS \ + --wandb-entity-name loubnabnl \ + --wandb-project-name bigcode-pretraining \ + " + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +echo $CMD + +# hide duplicated errors using this hack - will be properly fixed in pt-1.12 +# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + +export CUDA_HOME=/usr/local/cuda-11.6 +# This is needed for torch1.12.1 otherwise it doesn't link correctly, not sur what the issue was. +#export PATH="/usr/local/cuda-11.6/bin:$PATH" +#export LD_LIBRARY_PATH="/usr/local/cuda-11.6/lib64:$LD_LIBRARY_PATH" +#export LD_PRELOAD=$CUDA_HOME/lib/libnccl.so +#export LD_LIBRARY_PATH=$CUDA_HOME/efa/lib:$CUDA_HOME/lib:$CUDA_HOME/lib64:$LD_LIBRARY_PATH + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" diff --git a/examples/pretrain_bigcode_model.slurm b/examples/pretrain_bigcode_model.slurm new file mode 100644 index 0000000000..b9f9f19cdd --- /dev/null +++ b/examples/pretrain_bigcode_model.slurm @@ -0,0 +1,138 @@ +#!/bin/bash +#SBATCH --job-name=bigcode-training +#SBATCH --nodes=64 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --gres=gpu:8 +#SBATCH --exclusive +#SBATCH --partition=production-cluster +#SBATCH --output=/fsx/bigcode/bigcode-training/logs/run-%x-%j.out + +set -x -e +source /admin/home/loubna/.bashrc + +conda activate megatron + +echo "START TIME: $(date)" + +# File Path setup +SCRIPT_REPO=/fsx/loubna/code/Megatron-LM +pushd $SCRIPT_REPO + +LOG_PATH=$SCRIPT_REPO/main_log.txt + +# Training setup +GPUS_PER_NODE=8 +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +NNODES=$SLURM_NNODES +NODE_RANK=$SLURM_PROCID +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +# File path setup +CHECKPOINT_PATH=/fsx/bigcode/experiments/pretraining/6672 +TOKENIZER_FILE=/fsx/loubna/data/tokenizer/tokenizer-the-stack-march-sample-v3-no-prefix-spaces/tokenizer.json +WEIGHTS_TRAIN=/fsx/loubna/code/bigcode-data-mix/data/train_data_paths.txt.tmp +WEIGHTS_VALID=/fsx/loubna/code/bigcode-data-mix/data/valid_data_paths.txt.tmp + +mkdir -p $CHECKPOINT_PATH/tensorboard + +GPT_ARGS="\ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 4 \ + --sequence-parallel \ + --num-layers 40 \ + --hidden-size 6144 \ + --num-attention-heads 48 \ + --attention-head-type multiquery \ + --init-method-std 0.01275 \ + --seq-length 8192 \ + --max-position-embeddings 8192 \ + --attention-dropout 0.1 \ + --hidden-dropout 0.1 \ + --micro-batch-size 1 \ + --global-batch-size 512 \ + --lr 0.0003 \ + --min-lr 0.00003 \ + --train-iters 250000 \ + --lr-decay-iters 250000 \ + --lr-decay-style cosine \ + --lr-warmup-iters 2000 \ + --weight-decay .1 \ + --adam-beta2 .95 \ + --clip-grad 1.0 \ + --bf16 \ + --use-flash-attn \ + --fim-rate 0.5 \ + --log-interval 10 \ + --save-interval 2500 \ + --eval-interval 2500 \ + --eval-iters 2 \ + --use-distributed-optimizer \ + --valid-num-workers 0 \ +" + +TENSORBOARD_ARGS="--tensorboard-dir ${CHECKPOINT_PATH}/tensorboard" + +CMD=" \ + /fsx/loubna/code/Megatron-LM/pretrain_gpt.py \ + $GPT_ARGS \ + --tokenizer-type TokenizerFromFile \ + --tokenizer-file $TOKENIZER_FILE \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --train-weighted-split-paths-path $WEIGHTS_TRAIN \ + --valid-weighted-split-paths-path $WEIGHTS_VALID \ + --structured-logs \ + --structured-logs-dir $CHECKPOINT_PATH/logs \ + $TENSORBOARD_ARGS \ + --wandb-entity-name loubnabnl \ + --wandb-project-name bigcode-pretraining \ + " + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +echo $CMD + +# hide duplicated errors using this hack - will be properly fixed in pt-1.12 +# export TORCHELASTIC_ERROR_FILE=/tmp/torch-elastic-error.json + +# force crashing on nccl issues like hanging broadcast +export NCCL_ASYNC_ERROR_HANDLING=1 +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=COLL +# export NCCL_SOCKET_NTHREADS=1 +# export NCCL_NSOCKS_PERTHREAD=1 +# export CUDA_LAUNCH_BLOCKING=1 + +# AWS specific +export NCCL_PROTO=simple +export RDMAV_FORK_SAFE=1 +export FI_EFA_FORK_SAFE=1 +export FI_EFA_USE_DEVICE_RDMA=1 +export FI_PROVIDER=efa +export FI_LOG_LEVEL=1 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=ens + +export CUDA_HOME=/usr/local/cuda-11.6 + +# srun error handling: +# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks +# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code +SRUN_ARGS=" \ + --wait=60 \ + --kill-on-bad-exit=1 \ + " + +# py-spy top -s -i -n -- $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD +clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --node_rank \$SLURM_PROCID --role \$SLURMD_NODENAME: $CMD" 2>&1 | tee $LOG_PATH + +echo "END TIME: $(date)" diff --git a/examples/pretrain_gpt_1B_santacoder.sh b/examples/pretrain_gpt_1B_santacoder.sh new file mode 100644 index 0000000000..d14f538153 --- /dev/null +++ b/examples/pretrain_gpt_1B_santacoder.sh @@ -0,0 +1,62 @@ +#! /bin/bash + +set -u # stop on unset variables + +# Runs the SantaCoder 1B model + +GPUS_PER_NODE=8 +MASTER_ADDR=${MASTER_NODE} # Adjust +MASTER_PORT=6000 +NNODES=12 # Adjust +# NODE_RANK=0 # Adjust +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +CHECKPOINT_PATH=/my/experiment/path # Adjust: Directory to store the checkpoints +DATA_PATH=/preprocessed/data/path # Adjust: Prefix of the preprocessed dataset. +TOKENIZER_FILE=/tokenizer/path # Adjust + +GPT_ARGS="\ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --recompute-activations \ + --num-layers 24 \ + --hidden-size 2048 \ + --num-attention-heads 16 \ + --attention-head-type multiquery \ + --init-method-std 0.022 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --attention-dropout 0.1 \ + --hidden-dropout 0.1 \ + --micro-batch-size 2 \ + --global-batch-size 192 \ + --lr 0.0002 \ + --train-iters 300000 \ + --lr-decay-iters 600000 \ + --lr-decay-style cosine \ + --lr-warmup-fraction 0.02 \ + --weight-decay .1 \ + --adam-beta2 .95 \ + --clip-grad 1.0 \ + --fp16 \ + --log-interval 10 \ + --save-interval 4000 \ + --eval-interval 200 \ + --eval-iters 10 \ + --initial-loss-scale 65536 \ + --fim-rate 0.5 \ +" + +TENSORBOARD_ARGS="--tensorboard-dir ${CHECKPOINT_PATH}/tensorboard" + +torchrun $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + $GPT_ARGS \ + --tokenizer-type TokenizerFromFileWithFIM \ + --tokenizer-file $TOKENIZER_FILE \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + $TENSORBOARD_ARGS diff --git a/examples/pretrain_gpt_multilingual.sh b/examples/pretrain_gpt_multilingual.sh new file mode 100644 index 0000000000..5edebe770d --- /dev/null +++ b/examples/pretrain_gpt_multilingual.sh @@ -0,0 +1,65 @@ +#! /bin/bash + +# Runs the "345M" parameter model + +GPUS_PER_NODE=8 +MASTER_ADDR=localhost +MASTER_PORT=6000 +NNODES=1 # Adjust +NODE_RANK=0 # Adjust +WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) + +DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" + +# paths to multilingual preprocessed datasets +DATA_PATH_EN=_text_document +DATA_PATH_AR=_text_document +DATA_PATH_KR=_text_document +DATA_PATH_JP=_text_document + +CHECKPOINT_PATH= + + +torchrun $DISTRIBUTED_ARGS \ + pretrain_gpt.py \ + --num-layers 24 \ + --hidden-size 1024 \ + --num-attention-heads 16 \ + --micro-batch-size 4 \ + --global-batch-size 8 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-iters 1000 \ + --lr-decay-iters 320000 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --train-weighted-split-paths "TRAIN: 0.3 0:0.6 $DATA_EN 1 0:0.6 $DATA_AR 1 0:0.6 $DATA_KR 1 0:0.6 $DATA_JP" \ + --valid-weighted-split-paths \ + "VALID_EN: 1 0.6:0.8 $DATA_EN" \ + "VALID_AR: 1 0.6:0.8 $DATA_AR" \ + "VALID_JP: 1 0.6:0.8 $DATA_KR" \ + "VALID_KR: 1 0.6:0.8 $DATA_JP" \ + "VALID_EN-AR-JP-KR_BALANCED: 1 0.6:0.8 $DATA_EN, 1 0.6:0.8 $DATA_AR, 1 0.6:0.8 $DATA_JP, 1 0.6:0.8 $DATA_KR" \ + --test-weighted-split-paths \ + "TEST_EN: 1 0.8:1 $DATA_EN" \ + "TEST_AR: 1 0.8:1 $DATA_AR" \ + "TEST_JP: 1 0.8:1 $DATA_JP" \ + "TEST_KR: 1 0.8:1 $DATA_KR" \ + "TEST_EN-AR-JP-KR_BALANCED: 1 0.8:1 $DATA_EN, 1 0.8:1 $DATA_AR, 1 0.8:1 $DATA_JP, 1 0.8:1 $DATA_KR" \ + --vocab-file gpt2-vocab.json \ + --merge-file gpt2-merges.txt \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + --lr 0.00015 \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --weight-decay 1e-2 \ + --clip-grad 1.0 \ + --lr-warmup-fraction .01 \ + --checkpoint-activations \ + --log-interval 100 \ + --save-interval 10000 \ + --eval-interval 1000 \ + --eval-iters 10 \ + --fp16 diff --git a/megatron/arguments.py b/megatron/arguments.py index 84a007c026..01040e5023 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -5,6 +5,8 @@ import argparse import json import os +import re + import torch import types @@ -12,6 +14,10 @@ from tools.retro.utils import get_args_path as get_retro_args_path +import megatron +from megatron.model.enums import PositionEmbeddingType + + def parse_args(extra_args_provider=None, ignore_unknown_args=False): """Parse all arguments.""" parser = argparse.ArgumentParser(description='Megatron-LM Arguments', @@ -90,6 +96,30 @@ def validate_args(args, defaults={}): ' to be less than pipeline model parallel size ({})'.format( args.pipeline_model_parallel_size) + # --data-path and --train-weighted-splits-paths + message = "Data loading Mode 1: --data-path and --split "\ + "and Mode 2: --(train|valid|test)-weighted-split-paths"\ + "are mutually exclusive i.e. cannot be set together." + + if args.data_path: + assert args.train_weighted_split_paths is None, message + setattr(args, "valid_weighted_split_names", None) + setattr(args, "valid_weighted_split_weights", None) + setattr(args, "valid_weighted_split_splits", None) + + setattr(args, "test_weighted_split_names", None) + setattr(args, "test_weighted_split_weights", None) + setattr(args, "test_weighted_split_splits", None) + + # args.split default value in the args is None it is set here in order + # to check that it does not to overlap with the 2nd mode of data loading + if args.split is None: + args.split = "969, 30, 1" + + if args.train_weighted_split_paths or args.valid_weighted_split_paths or \ + args.test_weighted_split_paths: + assert args.data_path is None and args.split is None, message + # Deprecated arguments assert args.batch_size is None, '--batch-size argument is no longer ' \ 'valid, use --micro-batch-size instead' @@ -255,6 +285,8 @@ def validate_args(args, defaults={}): # we keep it a multiple of 64, which means the actual tensor size # will be a multiple of 64 / tp_size args.ffn_hidden_size = int((4 * args.hidden_size * 2 / 3) / 64) * 64 + assert not args.glu_activation, \ + "--glu-activation is not compatible with --swiglu. Use only one of the two." if args.kv_channels is None: assert args.hidden_size % args.num_attention_heads == 0 @@ -267,10 +299,21 @@ def validate_args(args, defaults={}): assert args.encoder_seq_length is not None args.seq_length = args.encoder_seq_length - if args.seq_length is not None: - assert args.max_position_embeddings >= args.seq_length - if args.decoder_seq_length is not None: - assert args.max_position_embeddings >= args.decoder_seq_length + # NOTE: this was before integrating alibi + # if args.seq_length is not None: + # assert args.max_position_embeddings >= args.seq_length + # if args.decoder_seq_length is not None: + # assert args.max_position_embeddings >= args.decoder_seq_length + + if args.position_embedding_type == PositionEmbeddingType.absolute or args.position_embedding_type == PositionEmbeddingType.alibi: + assert args.max_position_embeddings is not None + if args.seq_length is not None: + assert args.max_position_embeddings >= args.seq_length + if args.decoder_seq_length is not None: + assert args.max_position_embeddings >= args.decoder_seq_length + else: + assert args.max_position_embeddings is None + if args.lr is not None: assert args.min_lr <= args.lr if args.save is not None: @@ -282,6 +325,10 @@ def validate_args(args, defaults={}): assert args.fp16 or args.bf16, \ 'residual connection in fp32 only supported when using fp16 or bf16.' + # Activation function + if args.glu_activation is not None and args.bias_gelu_fusion: + raise ValueError("if glu-activation is used, please set --no-bias-gelu-fusion") + if args.weight_decay_incr_style == 'constant': assert args.start_weight_decay is None assert args.end_weight_decay is None @@ -316,6 +363,15 @@ def validate_args(args, defaults={}): 'distributed recompute activations are supported for pytorch ' \ 'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \ 'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR) + + # Weights and Biases + if args.wandb_entity_name or args.wandb_project_name: + assert args.wandb_entity_name and args.wandb_project_name, \ + "Both entity and project name must be set in order to report to wandb" + + # Local-rank from environment variable (if using torchrun) + if args.local_rank is None and "LOCAL_RANK" in os.environ: + args.local_rank = int(os.environ["LOCAL_RANK"]) # Tranformer-Engine/FP8 related checking if args.fp8_e4m3 or args.fp8_hybrid: @@ -341,6 +397,10 @@ def validate_args(args, defaults={}): if args.sequence_parallel: args.async_tensor_model_parallel_allreduce = False + if args.use_flash_attn: + assert not args.reset_attention_mask, \ + "Flash Attention doesn't support arbitrary attention masks. Please turn off reset-attention-mask" + if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1": if args.sequence_parallel: raise RuntimeError( @@ -512,6 +572,10 @@ def _add_network_size_args(parser): 'attention. This is set to ' ' args.hidden_size // args.num_attention_heads ' 'if not provided.') + group.add_argument('--attention-head-type', type=str, default=None, + choices=['multihead', 'multiquery'], + help='Type of attention heads. `multihead` is the standard multi-head attention.' + '`multiquery` shares the values and keys across attention heads') group.add_argument('--max-position-embeddings', type=int, default=None, help='Maximum number of position embeddings to use. ' 'This is the size of position embedding.') @@ -519,6 +583,8 @@ def _add_network_size_args(parser): help='Use rotary positional embeddings or not') group.add_argument('--rotary-percent', type=float, default=1.0, help='Percent of rotary dimension to use, default 100%') + group.add_argument('--rotary-theta', type=int, default=10000, + help='Theta/frequency value for rotary positional embeddings') group.add_argument('--no-position-embedding', action='store_false', help='Disable position embedding.', @@ -549,6 +615,15 @@ def _add_network_size_args(parser): group.add_argument('--bert-no-binary-head', action='store_false', help='Disable BERT binary head.', dest='bert_binary_head') + group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x], + choices=list(PositionEmbeddingType), + default=PositionEmbeddingType.absolute, + help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.' + ) + group.add_argument('--glu-activation', type=str, + choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(), + help='GLU activations to use.' + ) group.add_argument('--num-experts', type=int, default=None, help='Number of Experts in Switch Transformer (None means no Switch)') group.add_argument('--untie-embeddings-and-output-weights', action='store_true', @@ -617,6 +692,18 @@ def _add_logging_args(parser): group.add_argument('--log-world-size-to-tensorboard', action='store_true', help='Enable world size logging to tensorboard.') + + group.add_argument('--wandb-entity-name', type=str, default=None, + help="Name of wandb entity for reporting") + group.add_argument('--wandb-project-name', type=str, default=None, + help="Name of wandb project") + group.add_argument('--transformer-timers', action='store_true', + help="If set, activate the timers within the transformer layers." + "Only for debugging, as this slows down the model.") + group.add_argument('--structured-logs', action="store_true", + help='Add timestamp and worker name to stdout and stderr.') + group.add_argument('--structured-logs-dir', type=str, default=None, + help='Directory to save the logs.') return parser @@ -978,6 +1065,9 @@ def _add_distributed_args(parser): 'affects the encoder embedding.)') group.add_argument('--use-distributed-optimizer', action='store_true', help='Use distributed optimizer.') + group.add_argument('--distributed-timeout', default=600, type=float, + help='Timeout for distributed operations, in seconds. ' + 'Should be at least as high as the dataset preprocessing ans checkpoint saving times.') return parser @@ -998,6 +1088,7 @@ def _add_validation_args(parser): def _add_data_args(parser): group = parser.add_argument_group(title='data and dataloader') + # option 1 for data loading (mutually exclusive with option2) group.add_argument('--data-path', nargs='*', default=None, help='Path to the training dataset. Accepted format:' '1) a single data path, 2) multiple datasets in the' @@ -1006,11 +1097,108 @@ def _add_data_args(parser): 'single dataset used for all three: train, valid ' 'and test. It is exclusive to the other ' '--*-data-path args') - group.add_argument('--split', type=str, default='969, 30, 1', + group.add_argument('--split', type=str, default=None, help='Comma-separated list of proportions for training,' ' validation, and test split. For example the split ' '`90,5,5` will use 90%% of data for training, 5%% for ' 'validation and 5%% for test.') + # option 2 for data loading (mutually exclusive with option1) + # see https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/97/files + + # helper class to parse the --xxx-weighted-split-paths + # note here two args are set: extra valid dataset paths and names + class parse_data_paths(argparse.Action): + def __call__(self, parser, args, values, option_string=None): + + if option_string == "--train-weighted-split-paths": + assert len(values) == 1, 'Only 1 dataset group is allowed to' + 'be passed for the argument --train-weighted-split-paths' + + # make sure string given in the correct format + err_message = 'Each data group should be input on the following format' + '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2"' + 'where START < END' + for v in values: + # each prefix consists several datasets separated by commas + prefix = ":".join(v.split(":")[1:]) # remove GIVEN_NAME + datasets = prefix.split(",") + # check if each dataset is formatted like `WEIGHT START:END PATH` + for d in datasets: + assert len(d.split()) == 3, err_message + start, end = d.split()[1].split(":") + assert float(start) < float(end), err_message + + names = [v.split(":")[0] for v in values] + + prefixes = [":".join(v.split(":")[1:]).strip() for v in values] + weights = [[d.split()[0] for d in p.split(",")] for p in prefixes] + splits = [[d.split()[1] for d in p.split(",")] for p in prefixes] + paths = [[d.split()[2] for d in p.split(",")] for p in prefixes] + + # # to keep consistency with Option 1 of data loading (through --data-path) + # # paths will contain strings on the following form + # # "WEIGHTS1 PATH1 WEIGHTS2 PATH2 WEIGHTS3 PATH3" for each dataset group + # # while data will be parsed in additional arguments below + # paths_option1_style = [] + # for p, w in zip(paths, weights): + # paths_option1_style.append(" ".join([f"{w_i} {p_i}" for p_i, w_i in zip(p,w)])) + # setattr(args, self.dest, paths_option1_style) + setattr(args, self.dest, paths) + setattr(args, self.dest.replace("paths", "weights"), weights) + setattr(args, self.dest.replace("paths", "splits"), splits) + setattr(args, self.dest.replace("paths","names"), names) + + + group.add_argument('--train-weighted-split-paths', nargs='*', default=None, + help='Weights, splits and paths to groups of datasets' + 'Accepted format: ONE dataset groups could be' + 'submitted in the following form between double quotes' + '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2"' + 'e.g.: "NAME_ABC: 0.6 0:0.6 A, 0.3 0:1 B, 0.1 0:1 C" ' + 'WEIGHT is used to up and down sample each dataset A,B,C in the group' + 'START:END indicates the split portion of the dataset', + action=parse_data_paths) + + group.add_argument('--valid-weighted-split-paths', nargs='*', default=None, + help='Weights, splits and paths to groups of datasets' + 'Accepted format: one or many dataset groups could be' + 'submitted in the following form each between double quotes' + '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2"' + 'e.g.: "NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C" ' + '"NAME_CDE: 0.6 0.6:0.8 C, 0.3 0:1 D, 0.1 0:1 E" ' + 'validation will be run on each of those groups independently', + action=parse_data_paths) + + group.add_argument('--test-weighted-split-paths', nargs='*', default=None, + help='Weights, splits and paths to groups of datasets' + 'Accepted format: one or many dataset groups could be' + 'submitted in the following form each between double quotes' + '"GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2"' + 'e.g.: "NAME_ABC: 0.6 0.6:0.8 A, 0.3 0:1 B, 0.1 0:1 C" ' + '"NAME_CDE: 0.6 0.6:0.8 C, 0.3 0:1 D, 0.1 0:1 E" ' + 'test will be run on each of those groups independently', + action=parse_data_paths) + + class parse_data_paths_path(argparse.Action): + def __call__(self, parser, args, values, option_string=None): + expected_option_strings = ["--train-weighted-split-paths-path", "--valid-weighted-split-paths-path", "--test-weighted-split-paths-path"] + assert option_string in expected_option_strings, f"Expected {option_string} to be in {expected_option_strings}" + + with open(values, "r") as fi: + lines = fi.readlines() + assert len(lines) == 1, f"Got multiple lines {len(lines)} instead of 1 expected" + assert lines[0][-2:] == "\"\n" and lines[0][0] == "\"", f"Invalid input format, got {lines}" + values = lines[0][1:-2].split("\" \"") + weighted_split_paths_dest = re.sub(r"_path$", "", self.dest) + weighted_split_paths_option = re.sub(r"-path$", "", self.option_strings[0]) + setattr(args, weighted_split_paths_dest, values) + parse_data_paths(option_strings=[weighted_split_paths_option], dest=weighted_split_paths_dest)(parser, args, values, option_string=weighted_split_paths_option) + + # option 2-bis: load x-weighted-split-paths from a file in case this argument is very long + group.add_argument('--train-weighted-split-paths-path', type=str, action=parse_data_paths_path ,default=None) + group.add_argument('--valid-weighted-split-paths-path', type=str, action=parse_data_paths_path, default=None) + group.add_argument('--test-weighted-split-paths-path', type=str, action=parse_data_paths_path, default=None) + group.add_argument('--train-data-path', nargs='*', default=None, help='Path to the training dataset. Accepted format:' '1) a single data path, 2) multiple datasets in the' @@ -1033,6 +1221,8 @@ def _add_data_args(parser): help='Path to the vocab file.') group.add_argument('--merge-file', type=str, default=None, help='Path to the BPE merge file.') + group.add_argument('--tokenizer-file', type=str, default=None, + help='Path to the tokenizer.json file. Used for the TokenizerFromFile[...] tokenizers') group.add_argument('--vocab-extra-ids', type=int, default=0, help='Number of additional vocabulary tokens. ' 'They are used for span masking in the T5 model') @@ -1057,11 +1247,16 @@ def _add_data_args(parser): help='Warm up mmap files.') group.add_argument('--num-workers', type=int, default=2, help="Dataloader number of workers.") + group.add_argument('--valid-num-workers', type=int, default=2, + help="Dataloader number of workers for validation.") group.add_argument('--tokenizer-type', type=str, default=None, choices=['BertWordPieceLowerCase', 'BertWordPieceCase', 'GPT2BPETokenizer', + 'GPT2BPETokenizerWithFIM', + 'TokenizerFromFile', + 'TokenizerFromFileWithFIM', 'SentencePieceTokenizer', 'GPTSentencePieceTokenizer', 'NullTokenizer'], @@ -1078,6 +1273,17 @@ def _add_data_args(parser): 'end-of-document token.') group.add_argument('--eod-mask-loss', action='store_true', help='Mask loss for the end of document tokens.') + group.add_argument('--fim-rate', type=float, default=0., + help='Probability to convert a training sample into a "Fill-in-the-Middle" format. Must be between 0 and 1.') + group.add_argument('--fim-spm-rate', type=float, default=0.5, + help='Probability that the a FIM sample uses the SPM format over the PSM format. ' + 'At 1, exclusively train with SPM. At 0, exclusively train with PSM') + group.add_argument('--fim-split-sample', type=str, default=None, + help='String around which to split the sample for FIM. If None (default), FIM is applied on the sample-level') + group.add_argument('--fragment-fim-rate', type=float, default=0.5, + help='Rate of FIM on each fragment when fim_split_sample is not None.') + group.add_argument('--sanity-check-dataloader-interval', type=int, default=None, + help='Optional interval to print dataloader samples.') return parser diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 41b0535704..5a30619cd8 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -14,6 +14,7 @@ from .global_vars import get_args from .utils import (unwrap_model, print_rank_0) +from megatron.model.enums import PositionEmbeddingType _CHECKPOINT_VERSION = None @@ -52,8 +53,16 @@ def _compare(arg_name, old_arg_name=None): _compare('hidden_size') _compare('num_attention_heads') _compare('add_position_embedding') - if args.vocab_file: + try: + _compare('position_embedding_type') + except AttributeError as e: + print_rank_0(f" Warning, trying to load an old checkpoint: {e}") + assert args.position_embedding_type == PositionEmbeddingType.absolute, \ + f"Checkpoint uses PositionEmbeddingType.absolute, but input argument value was: {args.position_embedding_type}" + # with alibi we can change `max_position_embeddings` + if args.position_embedding_type != PositionEmbeddingType.alibi: _compare('max_position_embeddings') + if args.vocab_file or args.tokenizer_file: _compare('make_vocab_size_divisible_by') _compare('padded_vocab_size') _compare('tokenizer_type') @@ -131,7 +140,7 @@ def find_checkpoint_rank_0(checkpoints_path, iteration, release=False): if os.path.isfile(filename): return filename - return None, None + return None def get_checkpoint_tracker_filename(checkpoints_path): @@ -224,7 +233,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler): checkpoint_name = get_checkpoint_name(args.save, iteration) # Save distributed optimizer's custom parameter state. - if args.use_distributed_optimizer: + if args.use_distributed_optimizer and not args.no_save_optim: optim_checkpoint_name = \ get_distributed_optimizer_checkpoint_name(checkpoint_name) ensure_directory_exists(optim_checkpoint_name) @@ -350,28 +359,29 @@ def fix_query_key_value_ordering(model, checkpoint_version): print_rank_0(" succesfully fixed query-key-values ordering for" " checkpoint version {}".format(checkpoint_version)) - -def _load_base_checkpoint(load_dir, rank0=False): +def _load_base_checkpoint(load_dir, rank0=False, iteration=None, release=None): """ Load the base state_dict from the given directory If rank0 is true, just loads rank 0 checkpoint, ignoring arguments. + If rank0 is true or no_load_optim is true, we do not care about the optimizer, only the model checkpoint. """ # Read the tracker file and set the iteration. - tracker_filename = get_checkpoint_tracker_filename(load_dir) + if iteration is None and release is None: + tracker_filename = get_checkpoint_tracker_filename(load_dir) - # If no tracker file, return nothing - if not os.path.isfile(tracker_filename): - if not rank0: - print_rank_0('WARNING: could not find the metadata file {} '.format( - tracker_filename)) - print_rank_0(' will not load any checkpoints and will start from ' - 'random') - return None, False + # If no tracker file, return nothing + if not os.path.isfile(tracker_filename): + if not rank0: + print_rank_0('WARNING: could not find the metadata file {} '.format( + tracker_filename)) + print_rank_0(' will not load any checkpoints and will start from ' + 'random') + return None, False - # Otherwise, read the tracker file and either set the iteration or - # mark it as a release checkpoint. - iteration, release = read_metadata(tracker_filename) + # Otherwise, read the tracker file and either set the iteration or + # mark it as a release checkpoint. + iteration, release = read_metadata(tracker_filename) # Checkpoint. if rank0: @@ -475,6 +485,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('apply_layernorm_1p', force=True) _set_arg('tokenizer_type') _set_arg('padded_vocab_size') + _set_arg('attention_head_type') if checkpoint_version < 3.0: _set_arg('tensor_model_parallel_size', 'model_parallel_size') @@ -486,7 +497,7 @@ def _set_arg(arg_name, old_arg_name=None, force=False): return args, checkpoint_args -def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True): +def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True, iteration=None): """Load a model checkpoint and return the iteration. strict (bool): whether to strictly enforce that the keys in :attr:`state_dict` of the checkpoint match the names of @@ -497,7 +508,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri model = unwrap_model(model) - state_dict, release = _load_base_checkpoint(load_dir, rank0=False) + state_dict, release = _load_base_checkpoint(load_dir, rank0=False, iteration=iteration) # Checkpoint not loaded. if state_dict is None: @@ -526,7 +537,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri except KeyError: print_rank_0('A metadata file exists but unable to load ' 'iteration from checkpoint {}, exiting'.format( - checkpoint_name)) + model_checkpoint_name)) sys.exit() # Check arguments. @@ -584,7 +595,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri print_rank_0('Unable to load optimizer from checkpoint {}. ' 'Specify --no-load-optim or --finetune to prevent ' 'attempting to load the optimizer state, ' - 'exiting ...'.format(checkpoint_name)) + 'exiting ...'.format(model_checkpoint_name)) sys.exit() else: if (args.fp16 or args.bf16) and optimizer is not None: @@ -623,14 +634,14 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri print_rank_0('Unable to load rng state from checkpoint {}. ' 'Specify --no-load-rng or --finetune to prevent ' 'attempting to load the rng state, ' - 'exiting ...'.format(checkpoint_name)) + 'exiting ...'.format(model_checkpoint_name)) sys.exit() # Some utilities want to load a checkpoint without distributed being initialized if torch.distributed.is_initialized(): torch.distributed.barrier() - print_rank_0(f' successfully loaded checkpoint from {args.load} ' + print_rank_0(f' successfully loaded checkpoint from {load_dir} ' f'at iteration {iteration}') return iteration @@ -661,7 +672,7 @@ def load_biencoder_checkpoint(model, only_query_model=False, print('global rank {} is loading checkpoint {}'.format( torch.distributed.get_rank(), checkpoint_name)) - state_dict = torch.load(model_checkpoint_name, map_location='cpu') + state_dict = torch.load(checkpoint_name, map_location='cpu') ret_state_dict = state_dict['model'] if only_query_model: diff --git a/megatron/data/data_samplers.py b/megatron/data/data_samplers.py index 8dec2c1922..c3ebd87f6d 100644 --- a/megatron/data/data_samplers.py +++ b/megatron/data/data_samplers.py @@ -11,7 +11,7 @@ from megatron.core import mpu -def build_pretraining_data_loader(dataset, consumed_samples): +def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None): """Buld dataloader given an input dataset.""" if dataset is None: @@ -39,11 +39,54 @@ def build_pretraining_data_loader(dataset, consumed_samples): raise Exception('{} dataloader type is not supported.'.format( args.dataloader_type)) + num_workers = args.num_workers if num_workers is None else num_workers # Torch dataloader. - return torch.utils.data.DataLoader(dataset, + dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler, - num_workers=args.num_workers, + num_workers=num_workers, pin_memory=True) + + if args.sanity_check_dataloader_interval is not None: + from transformers import AutoTokenizer + from megatron import is_last_rank + + NUM_BATCHES = 10 + sanity_check_dataloader_interval = args.sanity_check_dataloader_interval + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_file.split("tokenizer.json")[0]) + + if is_last_rank(): + check_step = -1 + + with open("sanity_check.txt", "w") as f: + f.write("") + for i, batch in enumerate(dataloader): + check_step += 1 + if i % sanity_check_dataloader_interval == 0: + with open("sanity_check.txt", "a") as f: + f.write("\n\n") + f.write("*" * 40) + f.write(f"Sanity check {check_step}") + f.write("*" * 40) + print(batch) + + import joblib + + joblib.dump(batch, f"sanity_check_{check_step}.pkl") + texts = tokenizer.batch_decode( + batch["text"], skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + + for j, text in enumerate(texts): + print(f"\n\n>>Batch {i} || Sample {j}<<\n") + print(text[:1000]) + with open("sanity_check.txt", "a", encoding='utf-8') as f: + f.write(f"\n\n>>Batch {i} || Sample {j}<<\n") + f.write(text) + + if i // sanity_check_dataloader_interval == NUM_BATCHES - 1: + break + assert False + return dataloader class MegatronPretrainingSampler: diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py index 2f6f3e2fe9..45f51f6a4f 100644 --- a/megatron/data/dataset_utils.py +++ b/megatron/data/dataset_utils.py @@ -41,8 +41,7 @@ DSET_TYPES = [DSET_TYPE_BERT, DSET_TYPE_ICT, DSET_TYPE_T5] -def get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples): +def analyze_data_prefix(data_prefix): # The data prefix should be in the format of: # weight-1, data-prefix-1, weight-2, data-prefix-2, .. @@ -59,8 +58,14 @@ def get_datasets_weights_and_num_samples(data_prefix, weight_sum += weight assert weight_sum > 0.0 weights = [weight / weight_sum for weight in weights] + return prefixes, weights + - # Add 0.5% (the 1.005 factor) so in case the bleding dataset does +def get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples): + + prefixes, weights = analyze_data_prefix(data_prefix) + # Add 0.5% (the 1.005 factor) so in case the blending dataset does # not uniformly distribute the number of samples, we still have # samples left to feed to the network. if isinstance(train_valid_test_num_samples, list): @@ -614,6 +619,22 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup): return indexed_dataset +def get_split_by_range_(range_string, size): + """ Get dataset splits based on a range: + range_string is in the form START%:END% for e.g. 0.2:0.8 + outputs an array of two values [start_index, end_index] + """ + # some checks that range is given in the correct form + splits = [float(i) for i in range_string.split(":")] + assert len(splits) == 2, "splits should be passed as start:end" + assert splits[0] <= 1 and splits[1] <= 1 + splits_sum = sum(splits) + assert splits_sum > 0.0 + splits_index = [round(s * float(size)) for s in splits] + assert len(splits_index) == 2 + return splits_index + + def get_train_valid_test_split_(splits_string, size): """ Get dataset splits from comma or '/' separated string list.""" diff --git a/megatron/data/gpt_dataset.py b/megatron/data/gpt_dataset.py index 602e511678..16a29d0d60 100644 --- a/megatron/data/gpt_dataset.py +++ b/megatron/data/gpt_dataset.py @@ -2,18 +2,20 @@ """GPT style dataset.""" +import itertools import os import time import numpy as np import torch -from megatron import print_rank_0 +from megatron import print_rank_0, get_args, get_tokenizer from megatron.core import mpu from megatron.data.blendable_dataset import BlendableDataset from megatron.data.dataset_utils import get_datasets_weights_and_num_samples -from megatron.data.dataset_utils import get_train_valid_test_split_ +from megatron.data.dataset_utils import get_train_valid_test_split_, get_split_by_range_ from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset +from megatron.tokenizer.tokenizer import FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, @@ -25,21 +27,20 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, return_doc_ids=False): """Build train, valid, and test datasets.""" - if data_prefix: + # Single dataset. + if data_prefix and len(data_prefix) == 1: print_rank_0("Single data path provided for train, valid & test") - - # Single dataset. - if len(data_prefix) == 1: - return _build_train_valid_test_datasets(data_prefix[0], - data_impl, splits_string, - train_valid_test_num_samples, - seq_length, seed, skip_warmup) - - # Blending dataset. - # Parse the values. + all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup) + # Blending dataset. + elif data_prefix: + print_rank_0("Blending dataset for train, valid & test") output = get_datasets_weights_and_num_samples(data_prefix, - train_valid_test_num_samples) + train_valid_test_num_samples) prefixes, weights, datasets_train_valid_test_num_samples = output + train_num_samples, valid_num_samples, test_num_samples = map( sum, zip(*datasets_train_valid_test_num_samples) @@ -61,21 +62,12 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, valid_datasets.append(valid_ds) if test_ds: test_datasets.append(test_ds) - - # Blend. - blending_train_dataset = None - if train_datasets: - blending_train_dataset = BlendableDataset(train_datasets, weights, train_num_samples) - blending_valid_dataset = None - if valid_datasets: - blending_valid_dataset = BlendableDataset(valid_datasets, weights, valid_num_samples) - blending_test_dataset = None - if test_datasets: - blending_test_dataset = BlendableDataset(test_datasets, weights, test_num_samples) - - return (blending_train_dataset, blending_valid_dataset, - blending_test_dataset) - + all_train_datasets = BlendableDataset(train_datasets, weights, train_num_samples) \ + if train_datasets else None + all_valid_datasets = BlendableDataset(valid_datasets, weights, valid_num_samples) \ + if valid_datasets else None + all_test_datasets = BlendableDataset(test_datasets, weights, test_num_samples) \ + if test_datasets else None else: print_rank_0("Separate data paths provided for train, valid & test. Split string will be ignored.") @@ -98,6 +90,107 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, return (train_dataset, valid_dataset, test_dataset) + return all_train_datasets, all_valid_datasets, all_test_datasets + + + +def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, train_valid_test): + ''' + Build a single dataset group corresponding to Option 2 of data loading see arguments.py + a dataset group is passed on the following form + GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2 + or alternatively + GIVEN_NAME PATH1 # for a single dataset to be used fully + ''' + + assert train_valid_test in ["train","valid","test"] + index = ["train","valid","test"].index(train_valid_test) + + # Single dataset. + if len(paths) == 1: + dataset = _build_single_datasets(paths[0], + splits[0], + data_impl, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, + dataset_group_name, train_valid_test) + return dataset + # Blending dataset. + else: + + data_prefix = [] + # data_prefix is on the shape: + # ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"] + for w,p in zip(weights, paths): + data_prefix += [w,p] + + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + datasets = [] + for i in range(len(prefixes)): + ds = _build_single_datasets(prefixes[i], + splits[i], + data_impl, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, skip_warmup, + dataset_group_name, train_valid_test) + + # ds can be none if the dataset is so small that not a single document + # is present in the split. + assert ds is not None, \ + f"Got an empty split when trying to create dataset: {prefixes[i], splits[i]}" + datasets.append(ds) + all_datasets = BlendableDataset(datasets, weights, train_valid_test_num_samples[index]) + + return all_datasets + +def _build_single_datasets(data_prefix, range_string, data_impl, train_valid_test_num_samples, + seq_length, seed, skip_warmup, dataset_group_name, train_valid_test): + """Build a single dataset""" + + assert train_valid_test in ["train","valid","test"] + index = ["train","valid","test"].index(train_valid_test) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + # this corresponds to option2 for data loading on the form + # WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3 + # splits here is an array of size 2 [start_index, end_index] + splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents) + + # Print stats about the splits. + print_rank_0(' > dataset split:') + + print_rank_0(' {}:'.format(dataset_group_name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[0], splits[1], + splits[1] - splits[0])) + + def build_dataset(name): + dataset = None + if splits[1] > splits[0]: + documents = np.arange(start=splits[0], stop=splits[1], + step=1, dtype=np.int32) + dataset = GPTDataset(name, data_prefix, + documents, indexed_dataset, + train_valid_test_num_samples[index], + seq_length, seed) + return dataset + + dataset = build_dataset(dataset_group_name) + + return dataset + def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, train_valid_test_num_samples, @@ -111,6 +204,7 @@ def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, skip_warmup) total_num_of_documents = indexed_dataset.sizes.shape[0] + # splits here is an array of size 4 [train_start_index, valid_start_index, test_start_index, test_end_index] splits = get_train_valid_test_split_(splits_string, total_num_of_documents) # Print stats about the splits. @@ -237,7 +331,20 @@ def __init__(self, name, data_prefix, documents, indexed_dataset, _build_index_mappings(self.name, data_prefix, documents, self.indexed_dataset.sizes, num_samples, seq_length, seed) + + self.args = get_args() + self.tokenizer = get_tokenizer() + self.np_rng = np.random.RandomState(seed=seed) # rng state for FIM + + self.fim_rate = self.args.fim_rate + self.fim_spm_rate = self.args.fim_spm_rate + self.fragment_fim_rate = self.args.fragment_fim_rate + self.fim_split_sample = self.tokenizer.vocab[self.args.fim_split_sample] if self.args.fim_split_sample is not None else None + try: + self.suffix_tok_id, self.prefix_tok_id, self.middle_tok_id, self.pad_tok_id = (self.tokenizer.special_tokens[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]) + except KeyError: + self.suffix_tok_id, self.prefix_tok_id, self.middle_tok_id, self.pad_tok_id = (self.tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD]) def __len__(self): # -1 is due to data structure used to retieve the index: @@ -274,12 +381,99 @@ def __getitem__(self, idx): self.doc_idx[doc_index_l], length=offset_l + 1)) sample = np.concatenate(sample_list) - + + # Code from: https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py#L109 + # TODO(Hailey): can merge the code below this line with code above this line. + # TODO(Hailey), cont: above already iterates through loop, so just add the permuting in there? + sample = np.array(sample, dtype=np.int64) + sample_len = sample.shape[0] + # # print(sample, sample.shape) + # # do FIM here, if enabled + # TODO: Do we handle the following point from FIM paper? + # To transform data in the character space for context-level FIM, the tokenized documents have to be decoded back into strings before FIM augmentation. Depending on the vocabulary, some care has to be given to ensure decoding does not introduce any spurious characters into training. For example, utf-8 characters are encoded as multiple tokens with a BPE vocabulary; they can result in fragments from chunking and fail to decode. To prevent unforeseen errors midway through training, we encourage checking for these fragments at the beginning or end of a context and removing them. + eod = self.tokenizer.eod + segment_breaks = np.argwhere(sample == eod) # split sample by document + + if self.fim_rate == 0: + return sample.astype(np.int64) + + def fim_permute_sequence(sequence, rate): + return permute( + sequence, + self.np_rng, + rate, + self.fim_spm_rate, + self.tokenizer, + truncate_or_pad=False, + suffix_tok_id=self.suffix_tok_id, + prefix_tok_id=self.prefix_tok_id, + middle_tok_id=self.middle_tok_id, + pad_tok_id=self.pad_tok_id, + ) + + def fim_split_and_permute_sequence(sequence): + """ + If self.fim_split_sample is not None, split the sequence. + Then apply FIM on the fragments, or the whole sequence if self.fim_split_sample is None. + """ + if self.fim_split_sample is None: + return fim_permute_sequence(sequence, self.fim_rate) + # fim_split_sample is set: split the sample on this token and permute each fragment separately. + # Typically, if each sample is a repository, then we split again on the file level. + # Each fragment is a file, and we permute the files. + fragment_breaks = np.argwhere(sequence == self.fim_split_sample) + if fragment_breaks.shape == (0, 1): + # no split token in this sample + return fim_permute_sequence(sequence, self.fim_rate) + if not self.np_rng.binomial(1, self.fim_rate): + # don't do FIM preproc + return sequence + # Do FIM on each fragment + curr_start_position = 0 + new_samples = [] + for loc in np.nditer(fragment_breaks): + if loc - curr_start_position > 0: + permuted = fim_permute_sequence(sequence[curr_start_position:loc], self.fragment_fim_rate) + new_samples += [permuted, [self.fim_split_sample]] + curr_start_position = loc + 1 # Jump over the split token + # Permute the segment after the last split token + permuted = fim_permute_sequence(sequence[curr_start_position:], self.fragment_fim_rate) + new_samples.append(permuted) + return np.concatenate(new_samples) + + if segment_breaks.shape != (0, 1): # then there is an EOD token in this example + curr_start_position = 0 + new_samples = [] + for loc in np.nditer(segment_breaks): + # Only permute non-empty segments. + if loc - curr_start_position > 0: + # permute {prefix, suffix, middle} or {suffix, prefix, middle} + permuted = fim_split_and_permute_sequence(sample[curr_start_position:loc]) + new_samples += [permuted, [eod]] + + curr_start_position = loc + 1 # jump over the EOD token + # Permute the segment after the last EOD + permuted = fim_split_and_permute_sequence(sample[curr_start_position:]) + new_samples.append(permuted) + + sample = np.concatenate(new_samples) + else: + sample = fim_split_and_permute_sequence(sample) + + # Truncate or pad sequence to max-length + diff = sample.shape[0] - sample_len + if diff > 0: # too long + sample = sample[:sample_len] + elif diff < 0: # too short + sample = np.concatenate([sample, np.full((-1 * diff), self.pad_tok_id)]) + + assert sample.shape[0] == sample_len + # end FIM-specific code if self.return_doc_ids: # for retro preprocessing - return {'text': np.array(sample, dtype=np.int64), + return {'text': sample, 'doc_ids': np.array(doc_ids, dtype=np.int64)} else: - return {'text': np.array(sample, dtype=np.int64)} + return {'text': sample} def _build_index_mappings(name, data_prefix, documents, sizes, @@ -292,6 +486,7 @@ def _build_index_mappings(name, data_prefix, documents, sizes, """ # Number of tokens in each epoch and number of required epochs. tokens_per_epoch = _num_tokens(documents, sizes) + print_rank_0(f' > Tokens per epoch: {tokens_per_epoch}') num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) # rng state @@ -335,7 +530,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes, assert last_epoch_num_samples >= 0, \ 'last epoch number of samples should be non-negative.' num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length - assert last_epoch_num_samples < (num_samples_per_epoch + 1), \ + # For very small datasets, `last_epoch_num_samples` can be equal to + # (num_samples_per_epoch + 1). + # TODO: check that this is not problematic indeed + assert last_epoch_num_samples <= (num_samples_per_epoch + 1), \ 'last epoch number of samples exceeded max value.' # If we have less than 80% of the samples for the last epoch, # seperate out the epoch and treat it differently. @@ -522,3 +720,68 @@ def _build_shuffle_idx(num_samples, total_size, np_rng): np_rng.shuffle(shuffle_idx_last) return np.concatenate((shuffle_idx_first, shuffle_idx_last)) + + +# From https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py#L339 +def permute(sample, np_rng, fim_rate, fim_spm_rate, tokenizer, truncate_or_pad=True, + suffix_tok_id=None, prefix_tok_id=None, middle_tok_id=None, pad_tok_id=None): + """ + Take in a sample (np array w/ size (0,chunklength)) and perform a FIM transformation on it. + Maintain the same sample length (if transform creates a few extra tokens, drop them). + """ + if np_rng.binomial(1, fim_rate): # sample bernoulli dist + + contents = tokenizer.detokenize(sample) + + try: + # A boundary can be =0 (prefix will be empty) + # a boundary can be =len(contents) (suffix will be empty) + # The two boundaries can be equal (middle will be empty) + boundaries = list(np_rng.randint(low=0, high=len(contents) + 1, size=2)) + boundaries.sort() + except ValueError as e: + print(len(contents), contents) + print(e) + raise e + + prefix = contents[:boundaries[0]] + middle = contents[boundaries[0]:boundaries[1]] + suffix = contents[boundaries[1]:] + + prefix = np.array([*tokenizer.tokenize(prefix)], dtype=np.int64) + middle = np.array([*tokenizer.tokenize(middle)], dtype=np.int64) + suffix = np.array([*tokenizer.tokenize(suffix)], dtype=np.int64) + + # here we truncate each given segment to fit the same length as it was before + # A consequence is that we never reach the end of a file? + # we should rather truncate at the context-level + if truncate_or_pad: + # need to make same length as the input. Take the 3 sentinel tokens into account + new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3 + diff = new_length - sample.shape[0] + if diff > 0: # too long + if suffix.shape[0] <= diff: # if there's no space to truncate the suffix: stop and report it. atm i should have stopped this from happening + return sample, np_rng + suffix = suffix[:suffix.shape[0] - diff] + elif diff < 0: # too short + suffix = np.concatenate([suffix, np.full((-1 * diff), pad_tok_id)]) + + if np_rng.binomial(1, fim_spm_rate): + # SPM (variant 2 from FIM paper) + new_sample = np.concatenate([ + [prefix_tok_id, suffix_tok_id], suffix, + [middle_tok_id], prefix, middle + ]) + else: + # PSM + new_sample = np.concatenate([ + [prefix_tok_id], prefix, + [suffix_tok_id], suffix, + [middle_tok_id], middle + ]) + + else: + # don't do FIM preproc + new_sample = sample + + return new_sample diff --git a/megatron/data/indexed_dataset.py b/megatron/data/indexed_dataset.py index 4286e69b45..831c7de6c3 100644 --- a/megatron/data/indexed_dataset.py +++ b/megatron/data/indexed_dataset.py @@ -96,7 +96,7 @@ def write_longs(f, a): 4: np.int32, 5: np.int64, 6: np.float32, - 7: np.double, + 7: np.float64, 8: np.uint16 } @@ -268,8 +268,8 @@ class IndexedDatasetBuilder(object): np.int16: 2, np.int32: 4, np.int64: 8, - np.float: 4, - np.double: 8 + np.float32: 4, + np.float64: 8 } def __init__(self, out_file, dtype=np.int32): diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py index dcbf24cb3f..74eb94fb69 100644 --- a/megatron/fused_kernels/__init__.py +++ b/megatron/fused_kernels/__init__.py @@ -1,10 +1,6 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - import os -import pathlib -import subprocess +import torch -from torch.utils import cpp_extension # Setting this param to a list has a problem of generating different # compilation commands (with diferent order of architectures) and @@ -15,81 +11,11 @@ def load(args): - - # Check if cuda 11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( - cpp_extension.CUDA_HOME) - if int(bare_metal_major) >= 11: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_80,code=sm_80') - if int(bare_metal_minor) >= 7: - cc_flag.append('-gencode') - cc_flag.append('arch=compute_90,code=sm_90') - - # Build path - srcpath = pathlib.Path(__file__).parent.absolute() - buildpath = srcpath / 'build' - _create_build_dir(buildpath) - - # Helper function to build the kernels. - def _cpp_extention_load_helper(name, sources, extra_cuda_flags): - return cpp_extension.load( - name=name, - sources=sources, - build_directory=buildpath, - extra_cflags=['-O3',], - extra_cuda_cflags=['-O3', - '-gencode', 'arch=compute_70,code=sm_70', - '--use_fast_math'] + extra_cuda_flags + cc_flag, - verbose=(args.rank == 0) - ) - - # ============== - # Fused softmax. - # ============== - - if args.masked_softmax_fusion: - extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '--expt-relaxed-constexpr', - '--expt-extended-lambda'] - - # Upper triangular softmax. - sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', - srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] - scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( - "scaled_upper_triang_masked_softmax_cuda", - sources, extra_cuda_flags) - - # Masked softmax. - sources=[srcpath / 'scaled_masked_softmax.cpp', - srcpath / 'scaled_masked_softmax_cuda.cu'] - scaled_masked_softmax_cuda = _cpp_extention_load_helper( - "scaled_masked_softmax_cuda", sources, extra_cuda_flags) - - # Softmax - sources=[srcpath / 'scaled_softmax.cpp', - srcpath / 'scaled_softmax_cuda.cu'] - scaled_softmax_cuda = _cpp_extention_load_helper( - "scaled_softmax_cuda", sources, extra_cuda_flags) - - -def _get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], - universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - release = output[release_idx].split(".") - bare_metal_major = release[0] - bare_metal_minor = release[1][0] - - return raw_output, bare_metal_major, bare_metal_minor - - -def _create_build_dir(buildpath): - try: - os.mkdir(buildpath) - except OSError: - if not os.path.isdir(buildpath): - print(f"Creation of the build directory {buildpath} failed") + if torch.version.hip is None: + print("running on CUDA devices") + from megatron.fused_kernels.cuda import load as load_kernels + else: + print("running on ROCm devices") + from megatron.fused_kernels.rocm import load as load_kernels + + load_kernels(args) diff --git a/megatron/fused_kernels/cuda/__init__.py b/megatron/fused_kernels/cuda/__init__.py new file mode 100644 index 0000000000..9bddf7233b --- /dev/null +++ b/megatron/fused_kernels/cuda/__init__.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +import pathlib +import subprocess + +from torch.utils import cpp_extension +from megatron.fused_kernels.utils import _create_build_dir + + +def load(args): + + # Check if cuda 11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version( + cpp_extension.CUDA_HOME) + if int(bare_metal_major) >= 11: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_80,code=sm_80') + if int(bare_metal_minor) >= 7: + cc_flag.append('-gencode') + cc_flag.append('arch=compute_90,code=sm_90') + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / 'build' + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3',], + extra_cuda_cflags=['-O3', + '-gencode', 'arch=compute_70,code=sm_70', + '--use_fast_math'] + extra_cuda_flags + cc_flag, + verbose=(args.rank == 0) + ) + + # ============== + # Fused softmax. + # ============== + + if args.masked_softmax_fusion: + extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + '--expt-relaxed-constexpr', + '--expt-extended-lambda'] + + # Upper triangular softmax. + sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', + srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] + scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_upper_triang_masked_softmax_cuda", + sources, extra_cuda_flags) + + # Masked softmax. + sources=[srcpath / 'scaled_masked_softmax.cpp', + srcpath / 'scaled_masked_softmax_cuda.cu'] + scaled_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_masked_softmax_cuda", sources, extra_cuda_flags) + + # Softmax + sources=[srcpath / 'scaled_softmax.cpp', + srcpath / 'scaled_softmax_cuda.cu'] + scaled_softmax_cuda = _cpp_extention_load_helper( + "scaled_softmax_cuda", sources, extra_cuda_flags) + + +def _get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + release = output[release_idx].split(".") + bare_metal_major = release[0] + bare_metal_minor = release[1][0] + + return raw_output, bare_metal_major, bare_metal_minor diff --git a/megatron/fused_kernels/compat.h b/megatron/fused_kernels/cuda/compat.h similarity index 100% rename from megatron/fused_kernels/compat.h rename to megatron/fused_kernels/cuda/compat.h diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/cuda/scaled_masked_softmax.cpp similarity index 100% rename from megatron/fused_kernels/scaled_masked_softmax.cpp rename to megatron/fused_kernels/cuda/scaled_masked_softmax.cpp diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/cuda/scaled_masked_softmax.h similarity index 96% rename from megatron/fused_kernels/scaled_masked_softmax.h rename to megatron/fused_kernels/cuda/scaled_masked_softmax.h index 21ebbd5228..ef4f698411 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.h +++ b/megatron/fused_kernels/cuda/scaled_masked_softmax.h @@ -440,7 +440,7 @@ void dispatch_scaled_softmax_forward( int batches, int attn_heads) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 ); + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); if (key_seq_len == 0) { return; } else { @@ -516,6 +516,10 @@ void dispatch_scaled_softmax_forward( scaled_softmax_warp_forward <<>>(dst, src, scale, batch_count, key_seq_len); break; + case 13: // 8192 + scaled_softmax_warp_forward + <<>>(dst, src, scale, batch_count, key_seq_len); + break; default: break; } @@ -534,7 +538,7 @@ void dispatch_scaled_masked_softmax_forward( int attn_heads, int pad_batches) { - TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096 ); + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); if (key_seq_len == 0) { return; } else { @@ -610,6 +614,10 @@ void dispatch_scaled_masked_softmax_forward( scaled_masked_softmax_warp_forward <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; + case 13: // 8192 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; default: break; } @@ -627,7 +635,7 @@ void dispatch_scaled_masked_softmax_backward( int batches, int attn_heads) { - TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 4096 ); + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 ); if (key_seq_len == 0) { return; } else { @@ -702,6 +710,10 @@ void dispatch_scaled_masked_softmax_backward( scaled_masked_softmax_warp_backward <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; + case 13: // 8192 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; default: break; diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/cuda/scaled_masked_softmax_cuda.cu similarity index 98% rename from megatron/fused_kernels/scaled_masked_softmax_cuda.cu rename to megatron/fused_kernels/cuda/scaled_masked_softmax_cuda.cu index a8be57c052..3906a9dcc1 100644 --- a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu +++ b/megatron/fused_kernels/cuda/scaled_masked_softmax_cuda.cu @@ -30,7 +30,7 @@ torch::Tensor fwd_cuda( const int attn_heads = input.size(1); const int query_seq_len = input.size(2); const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); + TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); TORCH_INTERNAL_ASSERT(query_seq_len > 1); TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); TORCH_INTERNAL_ASSERT(mask.size(1) == 1); diff --git a/megatron/fused_kernels/scaled_softmax.cpp b/megatron/fused_kernels/cuda/scaled_softmax.cpp similarity index 100% rename from megatron/fused_kernels/scaled_softmax.cpp rename to megatron/fused_kernels/cuda/scaled_softmax.cpp diff --git a/megatron/fused_kernels/scaled_softmax_cuda.cu b/megatron/fused_kernels/cuda/scaled_softmax_cuda.cu similarity index 98% rename from megatron/fused_kernels/scaled_softmax_cuda.cu rename to megatron/fused_kernels/cuda/scaled_softmax_cuda.cu index ecc6eb06e8..39c94ce108 100644 --- a/megatron/fused_kernels/scaled_softmax_cuda.cu +++ b/megatron/fused_kernels/cuda/scaled_softmax_cuda.cu @@ -23,7 +23,7 @@ torch::Tensor fwd_cuda( const int attn_heads = input.size(1); const int query_seq_len = input.size(2); const int key_seq_len = input.size(3); - TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); + TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); TORCH_INTERNAL_ASSERT(query_seq_len > 1); // Output diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/megatron/fused_kernels/cuda/scaled_upper_triang_masked_softmax.cpp similarity index 100% rename from megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp rename to megatron/fused_kernels/cuda/scaled_upper_triang_masked_softmax.cpp diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/cuda/scaled_upper_triang_masked_softmax.h similarity index 100% rename from megatron/fused_kernels/scaled_upper_triang_masked_softmax.h rename to megatron/fused_kernels/cuda/scaled_upper_triang_masked_softmax.h diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/cuda/scaled_upper_triang_masked_softmax_cuda.cu similarity index 100% rename from megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu rename to megatron/fused_kernels/cuda/scaled_upper_triang_masked_softmax_cuda.cu diff --git a/megatron/fused_kernels/type_shim.h b/megatron/fused_kernels/cuda/type_shim.h similarity index 100% rename from megatron/fused_kernels/type_shim.h rename to megatron/fused_kernels/cuda/type_shim.h diff --git a/megatron/fused_kernels/rocm/__init__.py b/megatron/fused_kernels/rocm/__init__.py new file mode 100644 index 0000000000..f71a47e961 --- /dev/null +++ b/megatron/fused_kernels/rocm/__init__.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pathlib +from torch.utils import cpp_extension +from megatron.fused_kernels.utils import _create_build_dir + + +def load(args): + + # Build path + srcpath = pathlib.Path(__file__).parent.absolute() + buildpath = srcpath / 'build' + _create_build_dir(buildpath) + + # Helper function to build the kernels. + def _cpp_extention_load_helper(name, sources, extra_cuda_flags, extra_include_paths): + return cpp_extension.load( + name=name, + sources=sources, + build_directory=buildpath, + extra_cflags=['-O3'], + extra_cuda_cflags=['-O3'] + extra_cuda_flags, + extra_include_paths=extra_include_paths, + verbose=(args.rank == 0) + ) + + # ============== + # Fused softmax. + # ============== + + extra_include_paths=[os.path.abspath(srcpath)] + + if args.masked_softmax_fusion: + extra_cuda_flags = ['-D__HIP_NO_HALF_OPERATORS__=1', '-D__HIP_NO_HALF_CONVERSIONS__=1'] + + # Upper triangular softmax. + sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', + srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu'] + scaled_upper_triang_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_upper_triang_masked_softmax_cuda", + sources, extra_cuda_flags, extra_include_paths) + + # Masked softmax. + sources=[srcpath / 'scaled_masked_softmax.cpp', + srcpath / 'scaled_masked_softmax_cuda.cu'] + scaled_masked_softmax_cuda = _cpp_extention_load_helper( + "scaled_masked_softmax_cuda", sources, extra_cuda_flags, extra_include_paths) diff --git a/megatron/fused_kernels/rocm/compat.h b/megatron/fused_kernels/rocm/compat.h new file mode 100644 index 0000000000..92e7eb7723 --- /dev/null +++ b/megatron/fused_kernels/rocm/compat.h @@ -0,0 +1,31 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + + + +#ifndef TORCH_CHECK +#define TORCH_CHECK AT_CHECK +#endif + +#ifdef VERSION_GE_1_3 +#define DATA_PTR data_ptr +#else +#define DATA_PTR data +#endif diff --git a/megatron/fused_kernels/rocm/scaled_masked_softmax.cpp b/megatron/fused_kernels/rocm/scaled_masked_softmax.cpp new file mode 100644 index 0000000000..1852aee6fd --- /dev/null +++ b/megatron/fused_kernels/rocm/scaled_masked_softmax.cpp @@ -0,0 +1,97 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +int get_batch_per_block_cuda( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads); + +torch::Tensor fwd( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) { + AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM(mask.dim() == 4, "expected 4D tensor"); + + return fwd_cuda(input, mask, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +int get_batch_per_block( + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) { + return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads); +} + +} // end namespace scaled_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + + m.def("backward", + &multihead_attn::fused_softmax::scaled_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); + + m.def("get_batch_per_block", + &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block, + "Return Batch per block size." + ); +} diff --git a/megatron/fused_kernels/rocm/scaled_masked_softmax.h b/megatron/fused_kernels/rocm/scaled_masked_softmax.h new file mode 100644 index 0000000000..835ffe55d4 --- /dev/null +++ b/megatron/fused_kernels/rocm/scaled_masked_softmax.h @@ -0,0 +1,522 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Explicit masking + */ +template +__global__ void scaled_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const acc_t scale, + int micro_batch_size, + int element_count, + int pad_batches) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))+ threadIdx.y) * WARP_BATCH; + int pad_first_batch = 0; + if (pad_batches != 1) { // bert style + pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH; + } else { // gpt2 style + pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + } + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + uint8_t temp_mask[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + int itr_idx = i*element_count+it*WARP_SIZE; + copy_vector(temp_data, src + itr_idx); + copy_vector(temp_mask, mask + itr_idx); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (temp_mask[element] != 1) { + elements[i][it + element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -10000.0; + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = elements[i][it + element] / sum[i]; + } + copy_vector(dst + i * element_count + it * WARP_SIZE, out); + } else { + break; + } + } + } +} + +template +__global__ void scaled_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = 4; + + // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, ) + // gridDim/blockIdx = (seq_len, attn_heads, batches) + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count + it * WARP_SIZE, out); + } + } + } +} +} // end of anonymous namespace + +int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + + return batches_per_block; +} + +template +void dispatch_scaled_masked_softmax_forward( + output_t *dst, + const input_t *src, + const uint8_t *mask, + const input_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads, + int pad_batches) +{ + TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 8192 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0); + dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 1: // 2 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 2: // 4 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 3: // 8 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 4: // 16 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 5: // 32 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 6: // 64 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 7: // 128 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 8: // 256 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 9: // 512 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 10: // 1024 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 11: // 2048 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 12: // 4096 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + case 13: // 8192 + scaled_masked_softmax_warp_forward + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int query_seq_len, + int key_seq_len, + int batches, + int attn_heads) +{ + TORCH_INTERNAL_ASSERT( key_seq_len >= 0 && key_seq_len <= 8192 ); + if (key_seq_len == 0) { + return; + } else { + int log2_elements = log2_ceil(key_seq_len); + const int next_power_of_two = 1 << log2_elements; + int batch_count = batches * attn_heads * query_seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = batch_count/batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 1: // 2 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 2: // 4 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 3: // 8 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 4: // 16 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 5: // 32 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 6: // 64 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 7: // 128 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 8: // 256 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 9: // 512 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 10: // 1024 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 11: // 2048 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 12: // 4096 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + case 13: // 8192 + scaled_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + break; + + default: + break; + } + } +} diff --git a/megatron/fused_kernels/rocm/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/rocm/scaled_masked_softmax_cuda.cu new file mode 100644 index 0000000000..3b88b9c605 --- /dev/null +++ b/megatron/fused_kernels/rocm/scaled_masked_softmax_cuda.cu @@ -0,0 +1,119 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#ifndef __HIP_PLATFORM_HCC__ +#include +#endif +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 8192); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax.cpp b/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax.cpp new file mode 100644 index 0000000000..ea283588db --- /dev/null +++ b/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax.cpp @@ -0,0 +1,72 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor); + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor); + +torch::Tensor fwd(torch::Tensor const& input, float scale_factor) { + AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); + AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || + (input.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return fwd_cuda(input, scale_factor); +} + +torch::Tensor bwd( + torch::Tensor const& output_grads, + torch::Tensor const& softmax_results, + float scale_factor) { + + AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor"); + AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor"); + + AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) || + (output_grads.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) || + (softmax_results.scalar_type() == at::ScalarType::BFloat16), + "Only fp16 and bf16 are supported"); + + return bwd_cuda(output_grads, softmax_results, scale_factor); +} + +} // end namespace scaled_upper_triang_masked_softmax +} // end namespace fused_softmax +} // end namespace multihead_attn + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd, + "Self Multihead Attention scaled, time masked softmax -- Forward."); + m.def("backward", + &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd, + "Self Multihead Attention scaled, time masked softmax -- Backward."); +} diff --git a/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax.h new file mode 100644 index 0000000000..d4b913d7c0 --- /dev/null +++ b/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax.h @@ -0,0 +1,529 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace { + +template +__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) { *dst = *src; } + +template <> +__device__ __inline__ void copy_vector(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } + +template +__device__ __inline__ void copy_zero_vector(Datatype *dst); + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *dst = 0.0; } + +template <> +__device__ __inline__ void copy_zero_vector(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); } + + +int log2_ceil(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) ++log2_value; + return log2_value; +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + +template +__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) +{ +#if CUDA_VERSION >= 9000 + return __shfl_xor_sync(mask, value, laneMask, width); +#else + return __shfl_xor(value, laneMask, width); +#endif +} + +template class ReduceOp> +__device__ __forceinline__ void warp_reduce(acc_t* sum) { + ReduceOp r; + #pragma unroll + for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); + sum[i] = r(sum[i], b); + } + } +} + +/* + * Extended softmax (from native aten pytorch) with following additional features + * 1) input scaling + * 2) Implicit time (diagonal masking) + */ +template +__global__ void scaled_upper_triang_masked_softmax_warp_forward( + output_t *dst, + const input_t *src, + const acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_forward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + + // load data from global memory + acc_t elements[WARP_BATCH][WARP_ITERATIONS]; + input_t temp_data[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < batch_element_count) { + copy_vector(temp_data, src + i*element_count*stride + it*WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if ((element_index + element) < batch_element_count) { + elements[i][it+element] = (acc_t)temp_data[element] * scale; + } else { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } else { + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + elements[i][it + element] = -std::numeric_limits::infinity(); + } + } + } + } + + // compute max_value + acc_t max_value[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + max_value[i] = elements[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it]; + } + } + warp_reduce(max_value); + + acc_t sum[WARP_BATCH] { 0.0f }; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + if (it < warp_iteration_limit) { + elements[i][it] = std::exp((elements[i][it] - max_value[i])); + sum[i] += elements[i][it]; + } + } + } + warp_reduce(sum); + + // store result + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + + if (element_index < local_seq) { + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < local_seq) { + out[element] = elements[i][it + element] / sum[i]; + } else { + out[element] = 0; + } + } + copy_vector(dst + i * element_count * stride + it * WARP_SIZE, out); + } else if (element_index < element_count) { + copy_zero_vector(dst + i * element_count * stride + it * WARP_SIZE); + } else { + break; + } + } + } +} + +template +__global__ void scaled_upper_triang_masked_softmax_warp_backward( + output_t *gradInput, + input_t *grad, + const input_t *output, + acc_t scale, + int micro_batch_size, + int stride, + int element_count) +{ + // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and + // warp_size of method warp_softmax_backward_kernel. + constexpr int next_power_of_two = 1 << log2_elements; + constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + constexpr int ELEMENTS_PER_LDG_STG = 4; + + int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; + int local_seq = blockIdx.x + 1; + + // micro_batch_size might not be a multiple of WARP_BATCH. Check how + // many batches have to computed within this WARP. + int local_batches = micro_batch_size - first_batch; + if (local_batches > WARP_BATCH) + local_batches = WARP_BATCH; + + // there might be multiple batches per warp. compute the index within the batch + int local_idx = threadIdx.x; + + // the first element to process by the current thread + int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f }; + input_t temp_grad[ELEMENTS_PER_LDG_STG]; + input_t temp_output[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : local_seq; + + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < batch_element_count) { + copy_vector(temp_grad, grad + i * element_count * stride + it * WARP_SIZE); + copy_vector(temp_output, output + i * element_count * stride + it * WARP_SIZE); + + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + output_reg[i][it + element] = (acc_t)temp_output[element]; + } + } + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + if (element_index + element < batch_element_count) { + grad_reg[i][it + element] = (acc_t)temp_grad[element] * output_reg[i][it + element]; + } + } + } + } + } + + acc_t sum[WARP_BATCH]; + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; + #pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce(sum); + + // store result + #pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) + break; + #pragma unroll + for (int it = 0; it < WARP_ITERATIONS; it+=ELEMENTS_PER_LDG_STG) { + int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE; + if (element_index < element_count) { + // compute gradients + output_t out[ELEMENTS_PER_LDG_STG]; + #pragma unroll + for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) { + out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i])); + } + copy_vector(gradInput + i * element_count * stride + it * WARP_SIZE, out); + } + } + } +} + +} // end of anonymous namespace + +template +void dispatch_scaled_upper_triang_masked_softmax_forward( + output_t *dst, + const input_t *src, + const input_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 8192 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_forward + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} + +template +void dispatch_scaled_upper_triang_masked_softmax_backward( + output_t *grad_input, + input_t *grad, + const input_t *output, + const acc_t scale, + int softmax_elements, + int softmax_elements_stride, + int attn_batches) +{ + TORCH_INTERNAL_ASSERT( softmax_elements >= 0 && softmax_elements <= 8192 ); + if (softmax_elements == 0) { + return; + } else { + int log2_elements = log2_ceil(softmax_elements); + const int next_power_of_two = 1 << log2_elements; + int seq_len = softmax_elements; + int batch_count = attn_batches * seq_len; + + // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward. + int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; + + // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward. + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + // use 128 threads per block to maximimize gpu utilization + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0); + + int blocks_per_seq = attn_batches / batches_per_block; + dim3 blocks(seq_len, blocks_per_seq, 1); + dim3 threads(warp_size, warps_per_block, 1); + // Launch code would be more elegant if C++ supported FOR CONSTEXPR + switch (log2_elements) { + case 0: // 1 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 1: // 2 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 2: // 4 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 3: // 8 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 4: // 16 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 5: // 32 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 6: // 64 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 7: // 128 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 8: // 256 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 9: // 512 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 10: // 1024 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 11: // 2048 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 12: // 4096 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + case 13: // 8192 + scaled_upper_triang_masked_softmax_warp_backward + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + break; + default: + break; + } + } +} diff --git a/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax_cuda.cu new file mode 100644 index 0000000000..4aa9a702a5 --- /dev/null +++ b/megatron/fused_kernels/rocm/scaled_upper_triang_masked_softmax_cuda.cu @@ -0,0 +1,100 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#ifndef __HIP_PLATFORM_HCC__ +#include +#endif +#include +#include +#include "scaled_upper_triang_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 8192); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + return softmax_results; +} + + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/fused_kernels/rocm/type_shim.h b/megatron/fused_kernels/rocm/type_shim.h new file mode 100644 index 0000000000..6437dcc7c7 --- /dev/null +++ b/megatron/fused_kernels/rocm/type_shim.h @@ -0,0 +1,91 @@ +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include +#include "compat.h" + + +#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Half: \ + { \ + using scalar_t = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } + + + +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ + switch(TYPEIN) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_in = float; \ + switch(TYPEOUT) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_out = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \ + } \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_in = at::Half; \ + using scalar_t_out = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_in = at::BFloat16; \ + using scalar_t_out = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \ + } + diff --git a/megatron/fused_kernels/tests/test_fused_kernels.py b/megatron/fused_kernels/tests/test_fused_kernels.py index 74024c5020..ce7ba67c76 100644 --- a/megatron/fused_kernels/tests/test_fused_kernels.py +++ b/megatron/fused_kernels/tests/test_fused_kernels.py @@ -22,11 +22,11 @@ def test_load_fused_kernels(): raise e def test_fused_softmax(): - bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + bert = BertModel.from_pretrained("bert-base-cased", max_position_embeddings=8192, ignore_mismatched_sizes=True).cuda().half() tokenizer = BertTokenizer.from_pretrained("bert-base-cased") test_text = ( "Hello. How are you? I am fine thank you and you? yes Good. " - "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + "hi hi hi hi hi hi hi hi hi hi hi hi hi" * 256 # 32 * 256 ) tokens = tokenizer( @@ -120,11 +120,11 @@ def test_fused_softmax(): def test_fused_upper_triangle_mask_softmax(): - gpt = GPT2Model.from_pretrained("gpt2").cuda().half() + gpt = GPT2Model.from_pretrained("gpt2", max_position_embeddings=8192, ignore_mismatched_sizes=True).cuda().half() tokenizer = GPT2Tokenizer.from_pretrained("gpt2") test_text = ( "Hello. How are you? I am fine thank you and you? yes Good. " - "hi hi hi hi hi hi hi" # 24 + "hi hi hi hi hi hi hi" * 256 # 24 * 256 ) tokens = tokenizer( @@ -220,11 +220,11 @@ def test_fused_upper_triangle_mask_softmax(): def test_layer_norm(): - bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + bert = BertModel.from_pretrained("bert-base-cased", max_position_embeddings=8192, ignore_mismatched_sizes=True).cuda().half() tokenizer = BertTokenizer.from_pretrained("bert-base-cased") test_text = ( "Hello. How are you? I am fine thank you and you? yes Good. " - "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + "hi hi hi hi hi hi hi hi hi hi hi hi hi" * 256 # 32 ) tokens = tokenizer( @@ -362,6 +362,12 @@ def test_allmasked_softmax_backward(): assert error < 1e-3 +class DummyArgs: + rank: int = 0 + masked_softmax_fusion: bool = True + gradient_accumulation_fusion: bool = True + + if __name__ == "__main__": try: from transformers import BertTokenizer, GPT2Tokenizer @@ -382,6 +388,10 @@ def test_allmasked_softmax_backward(): test_masked_softmax_backward() test_allmasked_softmax_forward() test_allmasked_softmax_backward() + + from megatron.fused_kernels import load + load(DummyArgs()) + test_load_fused_kernels() test_fused_softmax() test_fused_upper_triangle_mask_softmax() diff --git a/megatron/fused_kernels/utils.py b/megatron/fused_kernels/utils.py new file mode 100644 index 0000000000..2425aae9c2 --- /dev/null +++ b/megatron/fused_kernels/utils.py @@ -0,0 +1,9 @@ +import os + + +def _create_build_dir(buildpath): + try: + os.mkdir(buildpath) + except OSError: + if not os.path.isdir(buildpath): + print(f"Creation of the build directory {buildpath} failed") diff --git a/megatron/global_vars.py b/megatron/global_vars.py index e3831167fd..4e0118e10e 100644 --- a/megatron/global_vars.py +++ b/megatron/global_vars.py @@ -80,7 +80,7 @@ def _set_signal_handler(): -def set_global_variables(args): +def set_global_variables(args, build_tokenizer=True): """Set args, tokenizer, tensorboard-writer, adlr-autoresume, and timers.""" assert args is not None @@ -89,7 +89,8 @@ def set_global_variables(args): set_args(args) _build_num_microbatches_calculator(args) - _ = _build_tokenizer(args) + if build_tokenizer: + _ = _build_tokenizer(args) _set_tensorboard_writer(args) _set_adlr_autoresume(args) _set_timers(args) diff --git a/megatron/initialize.py b/megatron/initialize.py index fdb312068c..e387c4ee78 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -2,14 +2,22 @@ """Megatron initialization.""" +import logging +import logging.config import random import os +import sys import time import numpy as np import torch from datetime import timedelta +try: + import wandb +except ModuleNotFoundError: + print('Wandb import failed', flush=True) + from megatron import fused_kernels from megatron import get_adlr_autoresume from megatron import get_args @@ -54,6 +62,7 @@ def finish_mpu_init(): args = get_args() # Pytorch distributed. _initialize_distributed() + _configure_logging() # Random seeds for reproducibility. if args.rank == 0: @@ -85,6 +94,54 @@ def finish_mpu_init(): return None +def _configure_logging(): + args=get_args() + if not args.structured_logs: + return + rank = torch.distributed.get_rank() + + logging_config = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": f"%(asctime)s [Rank {rank}]: %(message)s", + "use_colors": True, + } + }, + "handlers": { + "default": { + "level": "INFO", + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stdout", + } + }, + "loggers": {"default": {"level": "DEBUG", "handlers": ["default"]}}, + "root": {"handlers": ["default"], "level": "INFO"}, + } + if args.structured_logs_dir is not None: + log_dir=args.structured_logs_dir + os.makedirs(log_dir, exist_ok=True) + logging_config["handlers"]["file"] = { + "level": "INFO", + "formatter": "default", + "class": "logging.FileHandler", + "filename": os.path.join(log_dir, f"logs_rank_{rank}.txt"), + } + logging_config["root"]["handlers"].append("file") + logging_config["loggers"]["default"]["handlers"].append("file") + logging.config.dictConfig(logging_config) + + # Add these methods so that stdout can be redirected to logging. + logging.write = lambda msg: logging.info(msg) if msg != '\n' else None + logging.flush = lambda : None + + sys.stdout=logging + sys.stderr=logging + + + def _compile_dependencies(): args = get_args() @@ -101,6 +158,15 @@ def _compile_dependencies(): print('>>> done with dataset index builder. Compilation time: {:.3f} ' 'seconds'.format(time.time() - start_time), flush=True) + try: + # Skip the rest if the kernels are unnecessary or already available (ex. from apex) + if args.use_flash_attn or args.masked_softmax_fusion: + import scaled_upper_triang_masked_softmax_cuda + import scaled_masked_softmax_cuda + return + except ImportError: + pass + # ================== # Load fused kernels # ================== @@ -112,7 +178,7 @@ def _compile_dependencies(): args.micro_batch_size # Constraints on sequence length and attn_batch_size to enable warp based # optimization and upper triangular optimization (for causal mask) - custom_kernel_constraint = seq_len > 16 and seq_len <=4096 and \ + custom_kernel_constraint = seq_len > 16 and seq_len <=8192 and \ seq_len % 4 == 0 and attn_batch_size % 4 == 0 # Print a warning. if not ((args.fp16 or args.bf16) and @@ -170,11 +236,13 @@ def _initialize_distributed(): else: args.local_rank = device torch.cuda.set_device(device) - # Call the init process - torch.distributed.init_process_group( - backend=args.distributed_backend, - world_size=args.world_size, rank=args.rank, - timeout=timedelta(minutes=args.distributed_timeout_minutes)) + # Include this torch.distributed.init_process_group() code in the `else` branch because + # we do not want to reinitialize if torch.distributed.is_initialized() returns True + # Call the init process + torch.distributed.init_process_group( + backend=args.distributed_backend, + world_size=args.world_size, rank=args.rank, + timeout=timedelta(seconds=args.distributed_timeout)) # Set the tensor model-parallel, pipeline model-parallel, and # data-parallel communicators. @@ -228,6 +296,20 @@ def write_args_to_tensorboard(): writer.add_text(arg, str(getattr(args, arg)), global_step=args.iteration) +def init_wandb(): + args = get_args() + if args.rank == (args.world_size - 1): + if not (args.wandb_entity_name and args.wandb_project_name): + print('> Skipping wandb init ...', flush=True) + return + wandb.init( + name=os.path.basename(args.save), + entity=args.wandb_entity_name, + project=args.wandb_project_name, + group="mini_cluster", + config=args + ) + def set_jit_fusion_options(): """Set PyTorch JIT layer fusion options.""" diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py index f6dd7ddc4e..94d0576545 100644 --- a/megatron/model/bert_model.py +++ b/megatron/model/bert_model.py @@ -9,7 +9,7 @@ from megatron.model.enums import AttnMaskType from megatron.model.language_model import parallel_lm_logits from megatron.model.language_model import get_language_model -from megatron.model import LayerNorm +from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm from megatron.model.utils import openai_gelu, erf_gelu from megatron.model.utils import get_linear_layer from megatron.model.utils import init_method_normal diff --git a/megatron/model/enums.py b/megatron/model/enums.py index bc4e4aa29a..caca365c56 100644 --- a/megatron/model/enums.py +++ b/megatron/model/enums.py @@ -15,7 +15,14 @@ class AttnType(enum.Enum): class AttnMaskType(enum.Enum): padding = 1 - causal = 2 + causal = 2 # Overrides `attention_mask` to be a lower triangular matrix + prefix = 3 + custom = 4 # Forces one to pass an `attention_mask` that's 1 if we need to mask. Tensor that can be broadcast to [micro_batch_size, n_head, seq_length, seq_length] + +class PositionEmbeddingType(enum.Enum): + rotary = 1 # NOTE: this one is not used so far, however for future compatibility the enum left as is + absolute = 2 + alibi = 3 # For backward compatibility with old model checkpoints from megatron.core.enums import ModelType diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index ed29262acd..d230f81b4e 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -155,12 +155,12 @@ def is_kernel_available(self, mask, b, np, sq, sk): if ( self.scaled_masked_softmax_fusion # user want to fuse and self.input_in_float16 # input must be fp16 - and 16 < sk <= 4096 # sk must be 16 ~ 2048 + and 16 < sk <= 8192 # sk must be 16 ~ 8192 and sq % 4 == 0 # sq must be divisor of 4 and sk % 4 == 0 # sk must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4 ): - if 0 <= sk <= 4096: + if 0 <= sk <= 8192: batch_per_block = self.get_batch_per_block(sq, sk, b, np) if self.attn_mask_type == AttnMaskType.causal: diff --git a/megatron/model/glu_activations.py b/megatron/model/glu_activations.py new file mode 100644 index 0000000000..4fa821d3f1 --- /dev/null +++ b/megatron/model/glu_activations.py @@ -0,0 +1,61 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +# NOTE: logging funcionality commented for now as +# it is not implemented in this version so far + +#from megatron import logging +#from megatron.model.utils import log_debug_usage + +#logger = logging.get_logger(__name__) + +class _GLUBaseModule(nn.Module): + def __init__(self, activation_fn): + super().__init__() + self.activation_fn = activation_fn + + def forward(self, x): + # dim=-1 breaks in jit for pt<1.10 + x1, x2 = x.chunk(2, dim=(x.ndim - 1)) + return x1 * self.activation_fn(x2) + + +class LiGLU(_GLUBaseModule): + def __init__(self): + super().__init__(nn.Identity()) + + +class GEGLU(_GLUBaseModule): + def __init__(self): + super().__init__(F.gelu) + + +class ReGLU(_GLUBaseModule): + def __init__(self): + super().__init__(F.relu) + + +class SwiGLU(_GLUBaseModule): + def __init__(self): + super().__init__(F.silu) + + +#liglu = log_debug_usage(logger, "Using GLU activation: LiGLU.")(torch.jit.script(LiGLU())) +#geglu = log_debug_usage(logger, "Using GLU activation: GELU.")(torch.jit.script(GEGLU())) +#reglu = log_debug_usage(logger, "Using GLU activation: ReGLU.")(torch.jit.script(ReGLU())) +#swiglu = log_debug_usage(logger, "Using GLU activation: SwiGLU.")(torch.jit.script(SwiGLU())) + +liglu = torch.jit.script(LiGLU()) +geglu = torch.jit.script(GEGLU()) +reglu = torch.jit.script(ReGLU()) +swiglu = torch.jit.script(SwiGLU()) + + +GLU_ACTIVATIONS = { + "geglu": geglu, + "liglu": liglu, + "reglu": reglu, + "swiglu": swiglu, +} diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index a9be43722b..8483f12541 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -8,7 +8,7 @@ from megatron.core import tensor_parallel from .module import MegatronModule -from .enums import AttnMaskType +from megatron.model.enums import AttnMaskType from .language_model import parallel_lm_logits from .language_model import get_language_model from .utils import init_method_normal diff --git a/megatron/model/language_model.py b/megatron/model/language_model.py index 61f2501bcb..fc2e8fe348 100644 --- a/megatron/model/language_model.py +++ b/megatron/model/language_model.py @@ -9,7 +9,7 @@ from megatron.core import mpu, tensor_parallel from megatron.core.enums import ModelType -from .enums import AttnMaskType, LayerType +from .enums import AttnMaskType, LayerType, PositionEmbeddingType from .module import MegatronModule from .rotary_pos_embedding import apply_rotary_pos_emb, RotaryEmbedding from .transformer import ParallelTransformer @@ -125,8 +125,6 @@ class Embedding(MegatronModule): Arguments: hidden_size: hidden size vocab_size: vocabulary size - max_sequence_length: maximum size of sequence. This - is used for positional embedding embedding_dropout_prob: dropout probability for embeddings init_method: weight initialization method num_tokentypes: size of the token-type embeddings. 0 value @@ -136,7 +134,6 @@ class Embedding(MegatronModule): def __init__(self, hidden_size, vocab_size, - max_sequence_length, embedding_dropout_prob, init_method, num_tokentypes=0): @@ -160,13 +157,18 @@ def __init__(self, # Position embedding (serial). self.add_position_embedding = args.add_position_embedding - if self.add_position_embedding: + self.position_embedding_type = args.position_embedding_type + if self.add_position_embedding and self.position_embedding_type == PositionEmbeddingType.absolute: + max_position_embeddings = args.max_position_embeddings + assert max_position_embeddings is not None self.position_embeddings = torch.nn.Embedding( - max_sequence_length, self.hidden_size) + max_position_embeddings, self.hidden_size) self._position_embeddings_key = 'position_embeddings' # Initialize the position embeddings. if args.perform_initialization: self.init_method(self.position_embeddings.weight) + else: + self.position_embeddings = None # Token type embedding. # Add this as an optional field that can be added through @@ -191,7 +193,7 @@ def zero_parameters(self): """Zero out all parameters in embedding.""" self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True - if self.add_position_embedding: + if self.add_position_embedding and self.position_embeddings is not None: self.position_embeddings.weight.data.fill_(0) self.position_embeddings.weight.shared = True if self.num_tokentypes > 0: @@ -218,10 +220,12 @@ def add_tokentype_embeddings(self, num_tokentypes): def forward(self, input_ids, position_ids, tokentype_ids=None): # Embeddings. words_embeddings = self.word_embeddings(input_ids) - if self.add_position_embedding: + if self.add_position_embedding and self.position_embedding_type == PositionEmbeddingType.absolute: + assert self.position_embeddings is not None position_embeddings = self.position_embeddings(position_ids) embeddings = words_embeddings + position_embeddings else: + assert self.position_embeddings is None embeddings = words_embeddings if tokentype_ids is not None: @@ -254,7 +258,7 @@ def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False): state_dict_[self._word_embeddings_key] \ = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars) - if self.add_position_embedding: + if self.add_position_embedding and self.position_embedding_type == PositionEmbeddingType.absolute: state_dict_[self._position_embeddings_key] \ = self.position_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars) @@ -281,7 +285,7 @@ def load_state_dict(self, state_dict, strict=True): self.word_embeddings.load_state_dict(state_dict_, strict=strict) # Position embedding. - if self.add_position_embedding: + if self.add_position_embedding and self.position_embedding_type == PositionEmbeddingType.absolute: if self._position_embeddings_key in state_dict: state_dict_ = state_dict[self._position_embeddings_key] else: @@ -359,7 +363,6 @@ def __init__(self, if self.pre_process: self.embedding = Embedding(self.hidden_size, args.padded_vocab_size, - args.max_position_embeddings, args.hidden_dropout, self.init_method, self.num_tokentypes) @@ -379,7 +382,7 @@ def __init__(self, # partial rotary embeddings, which is better than full rotary # Wang and Komatsuzaki et al # https://github.com/kingoflolz/mesh-transformer-jax/ - self.rotary_pos_emb = RotaryEmbedding(rotary_dim) + self.rotary_pos_emb = RotaryEmbedding(rotary_dim, args.rotary_theta) # Encoder (usually set to True, False if part of an encoder-decoder # architecture and in encoder-only stage). diff --git a/megatron/model/rotary_pos_embedding.py b/megatron/model/rotary_pos_embedding.py index 80c74d62d4..e7f6450513 100644 --- a/megatron/model/rotary_pos_embedding.py +++ b/megatron/model/rotary_pos_embedding.py @@ -12,9 +12,9 @@ __all__ = ['RotaryEmbedding', 'apply_rotary_pos_emb'] class RotaryEmbedding(nn.Module): - def __init__(self, dim): + def __init__(self, dim, theta): super().__init__() - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) if importlib.util.find_spec('einops') is None: raise RuntimeError("einops is required for Rotary Embedding") diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 4d744e7a25..66dd08d412 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -7,17 +7,27 @@ import torch import torch.nn.functional as F from typing import Optional +from packaging.version import Version from megatron import get_timers, get_args, get_retro_args, core, get_num_microbatches +from megatron.utils import print_rank_0 from .module import MegatronModule from megatron.core import mpu, tensor_parallel from megatron.core.enums import ModelType from megatron.model import LayerNorm -from megatron.model.enums import AttnMaskType, LayerType, AttnType +from megatron.model.enums import AttnMaskType, LayerType, AttnType, PositionEmbeddingType from megatron.model.fused_softmax import FusedScaleMaskSoftmax from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.rotary_pos_embedding import apply_rotary_pos_emb -from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu +from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_linear_layer + +from .glu_activations import GLU_ACTIVATIONS + +# flags required to enable jit fusion kernels +torch._C._jit_set_profiling_mode(False) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_override_can_fuse_on_cpu(True) +torch._C._jit_override_can_fuse_on_gpu(True) try: from einops import rearrange @@ -25,9 +35,15 @@ rearrange = None try: - from flash_attn.flash_attn_interface import flash_attn_unpadded_func + import flash_attn as _flash_attn + if Version(getattr(_flash_attn, "__version__", "1")) >= Version("2"): + from flash_attn.flash_attn_interface import flash_attn_func + FLASH_VERSION = 2 + else: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func + FLASH_VERSION = 1 except ImportError: - flash_attn_unpadded_func = None + FLASH_VERSION = None """ We use the following notation throughout this file: @@ -84,7 +100,8 @@ class ParallelMLP(MegatronModule): MLP will take the input with h hidden state, project it to 4*h hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. + state back into h hidden dimension. At the end, dropout is also + applied. """ def __init__(self, init_method, output_layer_init_method): @@ -96,7 +113,8 @@ def __init__(self, init_method, output_layer_init_method): # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear( args.hidden_size, - args.ffn_hidden_size * 2 if args.swiglu else args.ffn_hidden_size, + # GLU is a special activation that divides the dimension by a factor 2. + args.ffn_hidden_size * 2 if (args.swiglu or args.glu_activation) else args.ffn_hidden_size, bias=self.add_bias, gather_output=False, init_method=init_method, @@ -108,7 +126,9 @@ def __init__(self, init_method, output_layer_init_method): self.activation_func = None self.swiglu = args.swiglu - if args.openai_gelu: + if args.glu_activation: + self.activation_func = GLU_ACTIVATIONS[args.glu_activation] + elif args.openai_gelu: self.activation_func = openai_gelu elif args.onnx_safe: self.activation_func = erf_gelu @@ -250,11 +270,12 @@ def __init__(self, layer_number, self.attention_dropout = torch.nn.Dropout(args.attention_dropout) def forward(self, query_layer, key_layer, - value_layer, attention_mask): + value_layer, attention_mask, alibi): # =================================== # Raw attention scores. [b, np, s, s] # =================================== + np = query_layer.size(2) # [b, np, sq, sk] output_size = (query_layer.size(1), @@ -269,17 +290,37 @@ def forward(self, query_layer, key_layer, key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( - (output_size[0]*output_size[1], output_size[2], output_size[3]), - query_layer.dtype, "mpu") + if alibi is None: + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( + (output_size[0]*output_size[1], output_size[2], output_size[3]), + query_layer.dtype, "mpu") + else: + # alibi: (batch_size * num_attention_heads, 1, max_seq_len) + matmul_input_buffer = alibi[:output_size[0]*output_size[1], :, :output_size[3]] # Raw attention scores. [b * np, sq, sk] - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, alpha=(1.0/self.norm_factor)) + if alibi is None: + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + else: + if not hasattr(self, "logged_alibi"): + print("Using Alibi.") + self.logged_alibi = True + + if self.apply_query_key_layer_scaling: + beta = 1.0 / self.layer_number + else: + beta = 1.0 + + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=beta, alpha=(1.0 / self.norm_factor)) # change view to [b, np, sq, sk] attention_scores = matmul_result.view(*output_size) @@ -309,7 +350,7 @@ def forward(self, query_layer, key_layer, # context layer shape: [b, np, sq, hn] output_size = (value_layer.size(1), - value_layer.size(2), + np, query_layer.size(0), value_layer.size(3)) @@ -338,6 +379,127 @@ def forward(self, query_layer, key_layer, return context_layer +class MultiQueryCoreAttention(CoreAttention): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, query_layer, key_layer, value_layer, attention_mask, alibi): + # =================================== + # Raw attention scores. [b, np, s, s] + # =================================== + sq = query_layer.size(0) + bs = query_layer.size(1) + np = query_layer.size(2) + + sk = key_layer.size(0) + # Only one head for key and values + assert key_layer.size(2) == 1 and value_layer.size(2) == 1 + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [b, np * sq, hn] + query_layer = query_layer.permute([1, 2, 0, 3]).reshape(bs, np * sq, -1) + # [sk, b, 1, hn] -> [b, hn, sk] + key_layer = key_layer.squeeze(2).permute(1, 2, 0) + # [sk, b, 1, hn] -> [sk, b * np, hn] + # key_layer = key_layer.expand(output_size[3], output_size[0], np, -1) + # key_layer = key_layer.reshape(output_size[3], output_size[0] * np, -1) + + if alibi is None: + # preallocting input tensor: [b, np * sq, sk] + matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor( + (bs, np * sq, sk), + query_layer.dtype, "mpu") + else: + # alibi: (batch_size * num_attention_heads, 1, max_seq_len) + # TODO: ideally, alibi would have the shape: (1, num_heads * sq, sk) + matmul_input_buffer = alibi[:bs * np, :, :sk].view(bs, np, sk) + matmul_input_buffer = matmul_input_buffer.unsqueeze(2).expand(bs, np, sq, sk).reshape(bs, np * sq, sk) # [b, np * sq, sk] + + if alibi is None: + # Raw attention scores. [b, np * sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, # [b, np * sq, hn] + key_layer, # [b, hn, sk] + beta=0.0, alpha=(1.0/self.norm_factor)) + else: + if not hasattr(self, "logged_alibi"): + print("Using Alibi.") + self.logged_alibi = True + + if self.apply_query_key_layer_scaling: + beta = 1.0 / self.layer_number + else: + beta = 1.0 + + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, + key_layer, + beta=beta, alpha=(1.0 / self.norm_factor)) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(bs, np, sq, sk) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + attention_probs = self.scale_mask_softmax(attention_scores, + attention_mask) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with tensor_parallel.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(attention_probs) + else: + attention_probs = self.attention_dropout(attention_probs) + + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. + # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), + np, + query_layer.size(0), + value_layer.size(3)) + + # [sk, b, 1, hn] -> [b, sk, hn] + value_layer = value_layer.squeeze(2).transpose(0, 1) + + # change view [b, np * sq, sk] + attention_probs = attention_probs.view(bs, np * sq, -1) + + # matmul: [b, np * sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(bs, np, sq, -1) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + class FlashSelfAttention(torch.nn.Module): """Implement the scaled dot product attention with softmax. Arguments @@ -351,7 +513,7 @@ class FlashSelfAttention(torch.nn.Module): def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): super().__init__() - assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, ' + assert FLASH_VERSION is not None, ('Please install FlashAttention first, ' 'e.g., with pip install flash-attn') assert rearrange is not None, 'Please install einops first, e.g., with pip install einops' self.causal = causal @@ -364,10 +526,31 @@ def forward(self, q, k, v): --------- q, k, v: The tensor containing the query, key, and value. (B, S, H, D) """ - assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q,k,v))) assert all((i.is_cuda for i in (q,k,v))) + if FLASH_VERSION==1: + return self._forward_v1(q,k,v) + + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + + if self.training: + # during training q,k,v always have same seqlen + assert seqlen_k == seqlen_q + is_causal = self.causal + dropout_p = self.dropout_p + else: + # turn off FA causal mask after first inference autoregressive iteration + # only on first autoregressive step q,k,v have same seqlen + is_causal = self.causal and (seqlen_q == seqlen_k) + dropout_p = 0 + + output = flash_attn_func(q, k, v, dropout_p,softmax_scale=self.softmax_scale, causal=is_causal) + + return output + + + def _forward_v1(self, q, k, v): batch_size, seqlen_q = q.shape[0], q.shape[1] seqlen_k = k.shape[1] @@ -381,17 +564,18 @@ def forward(self, q, k, v): is_causal = self.causal cu_seqlens_k = cu_seqlens_q + dropout_p = self.dropout_p else: # turn off FA causal mask after first inference autoregressive iteration # only on first autoregressive step q,k,v have same seqlen - is_causal = seqlen_q == seqlen_k + is_causal = self.causal and (seqlen_q == seqlen_k) cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=q.device) - self.dropout_p = 0 + dropout_p = 0 output = flash_attn_unpadded_func( q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, - self.dropout_p, + dropout_p, softmax_scale=self.softmax_scale, causal=is_causal ) @@ -416,21 +600,10 @@ def __init__(self, init_method, self.attention_type = attention_type self.attn_mask_type = attn_mask_type self.params_dtype = args.params_dtype + self.attention_head_type = args.attention_head_type self.sequence_parallel = args.sequence_parallel - self.use_flash_attn = args.use_flash_attn \ - and attention_type == AttnType.self_attn \ - and self.attn_mask_type == AttnMaskType.causal - if self.use_flash_attn: - if flash_attn_unpadded_func is None: - raise ImportError('FlashAttention is not installed, please install with ' - 'pip install flash-attn') - assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' - 'self-attention for now') - assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' - 'supports causal mask for now') - if rearrange is None: - raise ImportError('einops is not installed, please install with pip install einops') + self.use_flash_attn = args.use_flash_attn projection_size = args.kv_channels * args.num_attention_heads @@ -442,7 +615,7 @@ def __init__(self, init_method, args.num_attention_heads, world_size) # Strided linear layer. - if attention_type == AttnType.self_attn: + if attention_type == AttnType.self_attn and self.attention_head_type == 'multihead': self.query_key_value = tensor_parallel.ColumnParallelLinear( args.hidden_size, 3 * projection_size, @@ -451,7 +624,24 @@ def __init__(self, init_method, init_method=init_method, async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, **_args_to_kwargs()) - else: + elif attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery': + # TODO: Find a way to merge the query and key-value computations? + self.query = tensor_parallel.ColumnParallelLinear( + args.hidden_size, + projection_size, + gather_output=False, + init_method=init_method, + async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, + **_args_to_kwargs()) + # In MultiQuery attention, keys and values are shared across heads + # Use args.kv_channels instead of projection_size + # No `.fork()` so the rng tracker is shared across tensor-parallel processes. + # with mpu.get_cuda_rng_tracker(): + self.key_value = get_linear_layer( + args.hidden_size, + 2 * args.kv_channels, + init_method=init_method) + elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multihead': assert attention_type == AttnType.cross_attn self.query = tensor_parallel.ColumnParallelLinear( args.hidden_size, @@ -462,7 +652,6 @@ def __init__(self, init_method, async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, **_args_to_kwargs()) - self.key_value = tensor_parallel.ColumnParallelLinear( args.hidden_size, 2 * projection_size, @@ -471,12 +660,34 @@ def __init__(self, init_method, init_method=init_method, async_tensor_model_parallel_allreduce=args.async_tensor_model_parallel_allreduce, **_args_to_kwargs()) + elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multiquery': + raise NotImplementedError("Multiquery attention not implemented for cross-attention.") + else: + raise ValueError(f"Invalid attention arguments: {attention_type}, {self.attention_head_type}") - self.core_attention = CoreAttention(self.layer_number, - self.attn_mask_type) + if self.attention_head_type == 'multihead': + self.core_attention = CoreAttention(self.layer_number, + self.attn_mask_type) + else: + self.core_attention = MultiQueryCoreAttention(self.layer_number, self.attn_mask_type) self.checkpoint_core_attention = args.recompute_granularity == 'selective' - + if self.use_flash_attn: + if FLASH_VERSION is None: + raise ImportError('FlashAttention is not installed, please install with ' + 'pip install flash-attn') + assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports ' + 'self-attention for now') + assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only ' + 'supports causal mask for now') + assert args.position_embedding_type != PositionEmbeddingType.alibi, \ + ('FlashAttention does not support alibi positional embeddings yet') + if rearrange is None: + raise ImportError('einops is not installed, please install with pip install einops') + + if self.checkpoint_core_attention: + print_rank_0(" Warning, using selective recomputation with flash-attn: this is already handled in the " + "flash-attn library and has no effect.") self.core_attention_flash = FlashSelfAttention( causal=True, attention_dropout=args.attention_dropout ) @@ -492,16 +703,16 @@ def __init__(self, init_method, **_args_to_kwargs()) def _checkpointed_attention_forward(self, query_layer, key_layer, - value_layer, attention_mask, - rotary_pos_emb=None): + value_layer, attention_mask, alibi, rotary_pos_emb=None): """Forward method with activation checkpointing.""" def custom_forward(*inputs): query_layer = inputs[0] key_layer = inputs[1] value_layer = inputs[2] attention_mask = inputs[3] + alibi = inputs[4] output_ = self.core_attention(query_layer, key_layer, - value_layer, attention_mask) + value_layer, attention_mask, alibi) return output_ q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \ @@ -510,7 +721,7 @@ def custom_forward(*inputs): hidden_states = tensor_parallel.checkpoint( custom_forward, False, query_layer, key_layer, value_layer, attention_mask, - q_pos_emb, k_pos_emb) + alibi, q_pos_emb, k_pos_emb) return hidden_states @@ -518,16 +729,15 @@ def _allocate_memory(self, inference_max_sequence_len, batch_size): return torch.empty( inference_max_sequence_len, batch_size, - self.num_attention_heads_per_partition, + self.num_attention_heads_per_partition if self.attention_head_type == "multihead" else 1, self.hidden_size_per_attention_head, dtype=self.params_dtype, device=torch.cuda.current_device()) + def forward(self, hidden_states, attention_mask, - encoder_output=None, inference_params=None, - rotary_pos_emb=None): + encoder_output=None, inference_params=None, alibi=None, rotary_pos_emb=None): # hidden_states: [sq, b, h] - # ================================================= # Pre-allocate memory for key-values for inference. # ================================================= @@ -551,7 +761,7 @@ def forward(self, hidden_states, attention_mask, # Query, Key, and Value # ===================== - if self.attention_type == AttnType.self_attn: + if self.attention_type == AttnType.self_attn and self.attention_head_type == 'multihead': # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] mixed_x_layer, _ = self.query_key_value(hidden_states) @@ -565,6 +775,49 @@ def forward(self, hidden_states, attention_mask, (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3) + elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery': + kv_input=hidden_states + # Attention heads [sq, b, h] --> [sq, b, (2 * hn)] + mixed_kv_layer = self.key_value(kv_input) + + # Reduce the KV gradients in the tensor-parallel direction. + # This is different from multi-head attention which reduces the KV input, + # because the sum over attn heads happens in the attn weight gradient instead of the KV layer: + # A [b, n * sq, sk] = Q [b, n * sq, hn] x K^T [b, hn, sk] + # G_K [b, sk, hn] = G_A [b, sk, n * sq] x Q [b, n * sq, hn] + # = sum_p (G_Ap [b, sk, np * sq] x Q_p [b, np * sq, hn]) + if get_args().sequence_parallel: + # We switch to the tensor parallel regime here instead of at the KV input + # so that the KV layer is done in parallel instead of just duplicated. + mixed_kv_layer = tensor_parallel.gather_from_sequence_parallel_region(mixed_kv_layer, tensor_parallel_output_grad=True) + else: + mixed_kv_layer = tensor_parallel.copy_to_tensor_model_parallel_region(mixed_kv_layer) + + # [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn] + # new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + # (self.num_attention_heads_per_partition, + # 2 * self.hidden_size_per_attention_head) + # mixed_kv_layer = mixed_kv_layer.unsqueeze(2).expand(*new_tensor_shape) + + # [sq, b, (2 * hn)] --> [sq, b, 1, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + \ + (1, + 2 * self.hidden_size_per_attention_head) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + + # [sq, b, np, 2 * hn] --> 2 [sq, b, np, hn] + (key_layer, + value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2) + + # Attention head [sq, b, h] --> [sq, b, np * hn] + query_layer, _ = self.query(hidden_states) + # [sq, b, np * hn] --> [sq, b, np, hn] + new_tensor_shape = query_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + query_layer = query_layer.view(*new_tensor_shape) + + # [sq, b, np, hn] -> [b, np * sq, hn] else: # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] mixed_kv_layer, _ = self.key_value(encoder_output) @@ -640,7 +893,6 @@ def forward(self, hidden_states, attention_mask, # ================================== # core attention computation # ================================== - # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: q_pos_emb, k_pos_emb = rotary_pos_emb @@ -651,27 +903,35 @@ def forward(self, hidden_states, attention_mask, # otherwise, only relative positional embedding takes effect # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) - if not self.use_flash_attn: - if self.checkpoint_core_attention: - context_layer = self._checkpointed_attention_forward( - query_layer, key_layer, value_layer, attention_mask) - else: - context_layer = self.core_attention( - query_layer, key_layer, value_layer, attention_mask) - else: + if self.use_flash_attn: + if self.attention_head_type == "multiquery": + sq, b, np, hn = query_layer.size() + # Expand kv to be compatible with flash-attn implementation + # [sq, b, 1, hn] -> [sq, b, np, hn] + # TODO: This should be skippable for flash 2, but getting illegal memory access. + key_layer = key_layer.expand((sq, b, np, hn)) + value_layer = value_layer.expand((sq, b, np, hn)) q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous() for x in (query_layer, key_layer, value_layer)] - if not self.sequence_parallel: + if self.sequence_parallel: + context_layer = self.core_attention_flash(q, k, v) + else: with tensor_parallel.get_cuda_rng_tracker().fork(): context_layer = self.core_attention_flash(q, k, v) - else: - context_layer = self.core_attention_flash(q, k, v) context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous() + else: + if self.checkpoint_core_attention: + context_layer = self._checkpointed_attention_forward( + query_layer, key_layer, value_layer, attention_mask, alibi) + else: + context_layer = self.core_attention( + query_layer, key_layer, value_layer, attention_mask, alibi) + + # ================= # Output. [sq, b, h] # ================= - output, bias = self.dense(context_layer) return output, bias @@ -783,13 +1043,23 @@ def __init__(self, init_method, output_layer_init_method, else: self.mlp = ParallelMLP(init_method, output_layer_init_method) - # Set bias+dropout+add fusion grad_enable execution handler. + # Set bias+dropout+add fusion grad_enable execution handler. TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) self.bias_dropout_add_exec_handler = \ nullcontext if use_nvfuser else torch.enable_grad + # Alibi + if args.position_embedding_type == PositionEmbeddingType.alibi: + self.alibi = self._build_alibi_tensor(args.seq_length, args.num_attention_heads, args.micro_batch_size).to(torch.cuda.current_device()) + if args.params_dtype == torch.float16: + self.alibi = self.alibi.to(torch.float16) + elif args.params_dtype == torch.bfloat16: + self.alibi = self.alibi.to(torch.bfloat16) + else: + self.alibi = None + if args.retro_add_retriever: retro_args = get_retro_args() self.retro_num_neighbors = args.retro_num_neighbors @@ -1036,6 +1306,7 @@ def forward(self, hidden_states, attention_mask, layernorm_output, attention_mask, inference_params=inference_params, + alibi=self.alibi, rotary_pos_emb=rotary_pos_emb) # Residual connection. @@ -1149,6 +1420,35 @@ def forward(self, hidden_states, attention_mask, else: return output + @staticmethod + def _build_alibi_tensor(max_seq_len, num_attention_heads, batch_size): + # Based on https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742 + """Returns tensor shaped (batch_size * num_attention_heads, 1, max_seq_len)""" + + def get_slopes(n): + def get_slopes_power_of_2(n): + start = (2 ** (-2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2 * closest_power_of_2)[0::2][ + :n - closest_power_of_2] + + slopes = torch.Tensor(get_slopes(num_attention_heads)) + alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0).expand( + num_attention_heads, -1, -1) + + #Select the part of the tensor that corresponds to our tensor parallel index. + tp_world_size = mpu.get_tensor_model_parallel_world_size() + tp_index = mpu.get_tensor_model_parallel_rank() + alibi = alibi.reshape((tp_world_size, -1, *alibi.shape[1:]))[tp_index] + + alibi = alibi.repeat(batch_size, 1, 1) + return alibi class NoopTransformerLayer(MegatronModule): """A single 'no-op' transformer layer. @@ -1542,6 +1842,10 @@ def forward(self, hidden_states, attention_mask, inference_params=None, rotary_pos_emb=None): # hidden_states: [s, b, h] + timers = get_timers() + args = get_args() + + if args.transformer_timers: timers("Transformer forward").start() # Checks. if inference_params: @@ -1646,4 +1950,6 @@ def forward(self, hidden_states, attention_mask, if self.post_process and self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) + if args.transformer_timers: timers("Transformer forward").stop() + return hidden_states diff --git a/megatron/model/vision/esvit_swin_backbone.py b/megatron/model/vision/esvit_swin_backbone.py index 70aee3db42..221ccf331f 100644 --- a/megatron/model/vision/esvit_swin_backbone.py +++ b/megatron/model/vision/esvit_swin_backbone.py @@ -17,7 +17,7 @@ from torch.nn.init import trunc_normal_ from megatron.model.transformer import DropPath from megatron import get_args -from megatron.model import LayerNorm +from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm import numpy as np from math import sqrt diff --git a/megatron/model/vision/mit_backbone.py b/megatron/model/vision/mit_backbone.py index c67ca2c62b..c68f10e764 100644 --- a/megatron/model/vision/mit_backbone.py +++ b/megatron/model/vision/mit_backbone.py @@ -12,7 +12,7 @@ from functools import partial from torch.nn.init import trunc_normal_ from megatron.model.transformer import DropPath -from megatron.model import LayerNorm +from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm class Mlp(nn.Module): diff --git a/megatron/optimizer/distrib_optimizer.py b/megatron/optimizer/distrib_optimizer.py index 96786394ae..1182b37a9d 100644 --- a/megatron/optimizer/distrib_optimizer.py +++ b/megatron/optimizer/distrib_optimizer.py @@ -831,6 +831,16 @@ def reduce_model_grads(self, args, timers): self.allreduce_embedding_grads(args) timers('embedding-grads-all-reduce').stop() + # All-reduce key-value grads if needed. + if ( + args.attention_head_type == "multiquery" + and mpu.get_tensor_model_parallel_world_size() > 1 + and args.sequence_parallel + ): + timers('backward-key-value-all-reduce').start() + self.allreduce_key_value_grads(args) + timers('backward-key-value-all-reduce').stop() + # Reduce-scatter setup. timers('grads-reduce-scatter', log_level=1).start( barrier=args.barrier_with_L1_time) diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py index cc89c95ca2..4a58acfdac 100644 --- a/megatron/optimizer/optimizer.py +++ b/megatron/optimizer/optimizer.py @@ -251,7 +251,31 @@ def allreduce_embedding_grads(self, args): """All-reduce both word and position embeddings.""" self.allreduce_word_embedding_grads(args) self.allreduce_position_embedding_grads(args) - + + def allreduce_key_value_grads(self, args): + """ + Reduce the gradients for the key_value weights and biases for multi-query attention + with sequence parallelism. + Coalesce the bias grads to avoid too many small reductions, + but not the weight grads since it could cause memory issues. + """ + grads=[] + for model_module in self.models: + unwrapped_model = unwrap_model( + model_module, (torchDDP, LocalDDP, Float16Module)) + for layer in unwrapped_model.language_model.encoder.layers: + kv_weight = layer.self_attention.key_value.weight + grad = kv_weight.main_grad if args.DDP_impl == 'local' else kv_weight.grad + torch.distributed.all_reduce(grad, group=mpu.get_tensor_model_parallel_group()) + kv_bias = layer.self_attention.key_value.bias + grads.append(kv_bias.main_grad if args.DDP_impl == 'local' else kv_bias.grad) + if len(grads)>0: + coalesced = _flatten_dense_tensors(grads) + torch.distributed.all_reduce( + coalesced, group=mpu.get_tensor_model_parallel_group()) + for buf, synced in zip(grads, _unflatten_dense_tensors( + coalesced, grads)): + buf.copy_(synced) def allreduce_layernorm_grads(self, args): """All-reduce layernorm grads (for sequence parallelism).""" @@ -299,6 +323,17 @@ def reduce_model_grads(self, args, timers): self.allreduce_embedding_grads(args) timers('embedding-grads-all-reduce').stop() + # All-reduce key-value grads if needed. + if ( + args.attention_head_type == "multiquery" + and mpu.get_tensor_model_parallel_world_size() > 1 + and args.sequence_parallel + ): + timers('backward-key-value-all-reduce').start() + self.allreduce_key_value_grads(args) + timers('backward-key-value-all-reduce').stop() + + class MixedPrecisionOptimizer(MegatronOptimizer): """Base class for both the float-16 and the distributed optimizer. diff --git a/megatron/text_generation/tokenization.py b/megatron/text_generation/tokenization.py index accead319a..727ead28dc 100644 --- a/megatron/text_generation/tokenization.py +++ b/megatron/text_generation/tokenization.py @@ -35,6 +35,8 @@ def detokenize_generations(tokens_gpu_tensor, word = tokenizer.decoder[token] elif args.tokenizer_type == 'NullTokenizer': word = str(token) + elif args.tokenizer_type in ['TokenizerFromFile', 'TokenizerFromFileWithFIM']: + word = tokenizer.detokenize([token]) else: word = tokenizer.tokenizer.decoder[token] word = bytearray( diff --git a/megatron/tokenizer/gpt2_tokenization.py b/megatron/tokenizer/gpt2_tokenization.py index 3f37e44908..ff89504351 100644 --- a/megatron/tokenizer/gpt2_tokenization.py +++ b/megatron/tokenizer/gpt2_tokenization.py @@ -281,7 +281,7 @@ def encode(self, text): return self.convert_tokens_to_ids(self.tokenize(text)) def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) + text = ''.join(self.convert_ids_to_tokens(tokens)) text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) return text diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 79dab75a04..31e97c135c 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -5,10 +5,18 @@ from abc import ABC from abc import abstractmethod +from transformers import PreTrainedTokenizerFast from .bert_tokenization import FullTokenizer as FullBertTokenizer from .gpt2_tokenization import GPT2Tokenizer +FIM_PREFIX = "" +FIM_MIDDLE = "" +FIM_SUFFIX = "" +FIM_PAD = "" +EOD = "<|endoftext|>" + + def build_tokenizer(args): """Initialize tokenizer.""" if args.rank == 0: @@ -30,6 +38,16 @@ def build_tokenizer(args): assert args.vocab_file is not None assert args.merge_file is not None tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file) + elif args.tokenizer_type == 'GPT2BPETokenizerWithFIM': + assert args.vocab_file is not None + assert args.merge_file is not None + tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file, special_tokens=[FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD]) + elif args.tokenizer_type == "TokenizerFromFile": + assert args.tokenizer_file is not None + tokenizer = _HFTokenizer(args.tokenizer_file, special_tokens=[EOD]) + elif args.tokenizer_type == "TokenizerFromFileWithFIM": + assert args.tokenizer_file is not None + tokenizer = _HFTokenizer(args.tokenizer_file, special_tokens=[EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD]) elif args.tokenizer_type == 'SentencePieceTokenizer': assert args.tokenizer_model is not None tokenizer = _SentencePieceTokenizer(args.tokenizer_model, vocab_extra_ids=args.vocab_extra_ids) @@ -44,6 +62,8 @@ def build_tokenizer(args): 'implemented.'.format(args.tokenizer_type)) # Add vocab size. + # TODO: For most tokenizers, vocab_size does not take special_tokens into account. + # Might cause an issue if vocab_size + len(special_tokens) exceeds padded_vocab_size? args.padded_vocab_size = _vocab_size_with_padding(tokenizer.vocab_size, args) @@ -258,13 +278,15 @@ def additional_special_tokens(self, value): class _GPT2BPETokenizer(AbstractTokenizer): """Original GPT2 BPE tokenizer.""" - def __init__(self, vocab_file, merge_file): + def __init__(self, vocab_file, merge_file, special_tokens=None): name = 'GPT2 BPE' super().__init__(name) + special_tokens = special_tokens if special_tokens is not None else [] self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace', - special_tokens=[], max_len=None) + special_tokens=special_tokens, max_len=None) self.eod_id = self.tokenizer.encoder['<|endoftext|>'] + self.special_tokens = self.tokenizer.special_tokens @property def vocab_size(self): @@ -289,6 +311,46 @@ def eod(self): return self.eod_id +class _HFTokenizer(AbstractTokenizer): + """HF Tokenizer.""" + + def __init__(self, tokenizer_file, special_tokens=None): + name = 'HF Tokenizer' + super().__init__(name) + + special_tokens = special_tokens if special_tokens is not None else [] + self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file, errors='replace', max_len=None) + self.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens}) + self.eod_id = self.tokenizer.vocab[EOD] + # Token->id mapping for additional special-tokens + self.special_tokens = { + tok: self.tokenizer.vocab[tok] for tok in special_tokens + } + self._inv_vocab = {v: k for k, v in self.tokenizer.vocab.items()} + + @property + def vocab_size(self): + return len(self.tokenizer) + + @property + def vocab(self): + return self.tokenizer.vocab + + @property + def inv_vocab(self): + return self._inv_vocab + + def tokenize(self, text): + return self.tokenizer.encode(text) + + def detokenize(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def eod(self): + return self.eod_id + + class _SentencePieceTokenizer(AbstractTokenizer): """SentencePieceTokenizer-Megatron wrapper""" diff --git a/megatron/training.py b/megatron/training.py index 14bca152f0..35e47c20d6 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -6,6 +6,11 @@ import math import sys import time + +try: + import wandb +except ModuleNotFoundError: + print('Wandb import failed', flush=True) # The earliest we can measure the start time. _TRAIN_START_TIME = time.time() import torch @@ -28,18 +33,19 @@ from megatron.model import GPTModel from megatron.core.enums import ModelType from megatron.optimizer import get_megatron_optimizer -from megatron.initialize import initialize_megatron +from megatron.initialize import init_wandb, initialize_megatron from megatron.initialize import write_args_to_tensorboard from megatron.initialize import set_jit_fusion_options from megatron.optimizer_param_scheduler import OptimizerParamScheduler from megatron.model import DistributedDataParallel as LocalDDP -from megatron.utils import check_adlr_autoresume_termination +from megatron.utils import check_adlr_autoresume_termination, get_tflops from megatron.utils import unwrap_model from megatron.data.data_samplers import build_pretraining_data_loader from megatron.utils import calc_params_l2_norm from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.utils import report_memory from megatron.model.vision.knn_monitor import compute_feature_bank +from megatron.data.dataset_utils import analyze_data_prefix def print_datetime(string): @@ -156,11 +162,13 @@ def pretrain(train_valid_test_dataset_provider, print_datetime('after training is done') if args.do_valid: - prefix = 'the end of training for val data' - evaluate_and_print_results(prefix, forward_step_func, - valid_data_iterator, model, - iteration, process_non_loss_data_func, - False) + names = args.valid_weighted_split_names + names = names if names is not None else ['valid'] * len(valid_data_iterator) + for iterator, name in zip(valid_data_iterator, names): + prefix = 'the end of training for val data' + evaluate_and_print_results(prefix, forward_step_func, + iterator, model, + iteration, process_non_loss_data_func, False, data_group_name=name) if args.save and iteration != 0: save_checkpoint(iteration, model, optimizer, opt_param_scheduler) @@ -168,10 +176,12 @@ def pretrain(train_valid_test_dataset_provider, if args.do_test: # Run on test data. prefix = 'the end of training for test data' - evaluate_and_print_results(prefix, forward_step_func, - test_data_iterator, model, - 0, process_non_loss_data_func, - True) + names = args.test_weighted_split_names + names = names if names is not None else ['test'] * len(test_data_iterator) + for iterator, name in zip(test_data_iterator, names): + evaluate_and_print_results(prefix, forward_step_func, + iterator, model, + 0, process_non_loss_data_func, True, data_group_name=name) def update_train_iters(args): @@ -554,6 +564,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, total_iterations = total_loss_dict[advanced_iters_key] + \ total_loss_dict[skipped_iters_key] + mem_stats = None # Tensorboard values. # Timer requires all the ranks to call. @@ -611,14 +622,22 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, mem_stats["allocation.all.current"], iteration, ) + if iteration % args.log_interval == 0: elapsed_time = timers('interval-time').elapsed(barrier=True) elapsed_time_per_iteration = elapsed_time / total_iterations + + num_gpus = args.data_parallel_size * args.tensor_model_parallel_size * args.pipeline_model_parallel_size + tokens_per_sec_per_gpu = (args.seq_length * batch_size) / num_gpus / elapsed_time_per_iteration + + tflops = get_tflops(batch_size, elapsed_time_per_iteration) if writer: if args.log_timers_to_tensorboard: writer.add_scalar('iteration-time', elapsed_time_per_iteration, iteration) + writer.add_scalar('TFLOPs per gpu (estimated)', + tflops, iteration) log_string = ' iteration {:8d}/{:8d} |'.format( iteration, args.train_iters) log_string += ' consumed samples: {:12d} |'.format( @@ -646,6 +665,10 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, total_loss_dict[skipped_iters_key]) log_string += ' number of nan iterations: {:3d} |'.format( total_loss_dict[nan_iters_key]) + log_string += ' TFLOPs: {:.2f} |'.format(tflops) + log_string += ' tokens-per-second-per-gpu: {:.2f} |'.format(tokens_per_sec_per_gpu) + if args.log_memory_to_tensorboard and mem_stats is not None: + log_string += ' mem-reserved (GB): {:.2f} |'.format(mem_stats["reserved_bytes.all.current"]*1e-9) total_loss_dict[advanced_iters_key] = 0 total_loss_dict[skipped_iters_key] = 0 total_loss_dict[nan_iters_key] = 0 @@ -656,6 +679,18 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration, report_memory_flag = False timers.log(timers_to_log, normalizer=args.log_interval) + # Weights and biases reporting + if (iteration % args.log_interval == 0) and is_last_rank() and args.wandb_project_name: + metrics = { + 'learning-rate': learning_rate, + 'samples': args.consumed_train_samples, + 'loss-scale': loss_scale, + 'grad-norm': grad_norm, + 'tflops': tflops, + 'tokens-per-second-per-gpu': tokens_per_sec_per_gpu, + **loss_dict + } + wandb.log(metrics, step=iteration) return report_memory_flag @@ -679,6 +714,9 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, # Write args to tensorboard write_args_to_tensorboard() + # Init Weights and Biases + init_wandb() + # Turn on training mode which enables dropout. for model_module in model: model_module.train() @@ -727,10 +765,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler, if args.eval_interval and iteration % args.eval_interval == 0 and \ args.do_valid: prefix = 'iteration {}'.format(iteration) - evaluate_and_print_results(prefix, forward_step_func, - valid_data_iterator, model, - iteration, process_non_loss_data_func, - False) + names = args.valid_weighted_split_names + names = names if names is not None else ['valid'] * len(valid_data_iterator) + for iterator, name in zip(valid_data_iterator, names): + evaluate_and_print_results(prefix, forward_step_func, + iterator, model, + iteration, process_non_loss_data_func, False, data_group_name=name) # Checkpointing saved_checkpoint = False @@ -845,31 +885,42 @@ def evaluate(forward_step_func, def evaluate_and_print_results(prefix, forward_step_func, data_iterator, model, iteration, process_non_loss_data_func, - verbose=False): + verbose=False, data_group_name=None): """Helper function to evaluate and dump results on screen.""" args = get_args() writer = get_tensorboard_writer() + ds_name = data_group_name + # print corresponding dataset name (used for multiple validation datasets) + tf_plot_prefix = f"lm-loss-validation/{ds_name}" if ds_name else "lm-loss-validation" total_loss_dict, collected_non_loss_data = evaluate( forward_step_func, data_iterator, model, process_non_loss_data_func, verbose) - string = ' validation loss at {} | '.format(prefix) + string = '{} loss at {} | '.format(ds_name, prefix) if ds_name is not None\ + else 'validation loss at {} | '.format(prefix) for key in total_loss_dict: string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item()) ppl = math.exp(min(20, total_loss_dict[key].item())) string += '{} PPL: {:.6E} | '.format(key, ppl) if writer: - writer.add_scalar('{} validation'.format(key), + writer.add_scalar(f'{tf_plot_prefix}/{key} validation', total_loss_dict[key].item(), iteration) - writer.add_scalar('{} validation vs samples'.format(key), + writer.add_scalar(f'{tf_plot_prefix}/{key} validation vs samples', total_loss_dict[key].item(), args.consumed_train_samples) if args.log_validation_ppl_to_tensorboard: - writer.add_scalar('{} validation ppl'.format(key), ppl, + writer.add_scalar(f'{tf_plot_prefix}/{key} validation ppl', ppl, iteration) - writer.add_scalar('{} validation ppl vs samples'.format(key), + writer.add_scalar(f'{tf_plot_prefix}/{key} validation ppl vs samples', ppl, args.consumed_train_samples) + + # Weights and biases reporting + if is_last_rank() and args.wandb_project_name: + metrics = { + f'{tf_plot_prefix}/{key} validation': total_loss_dict[key].item() for key in total_loss_dict + } + wandb.log(metrics, step=iteration) if process_non_loss_data_func is not None and writer and is_last_rank(): process_non_loss_data_func(collected_non_loss_data, iteration, writer) @@ -891,7 +942,7 @@ def build_train_valid_test_data_loaders( """XXX""" args = get_args() - (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) + (train_dataloader, valid_dataloaders, test_dataloaders) = (None, None, None) print_rank_0('> building train, validation, and test datasets ...') @@ -927,21 +978,47 @@ def build_train_valid_test_data_loaders( # Build the datasets. train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider( train_val_test_num_samples) + # if dataloading option is not 2 convert to list to allow + # same interface for multiple data groups + # for validation and testing in option 2 + if type(train_ds) != list and train_ds is not None: + train_ds = [train_ds] + if type(valid_ds) != list and valid_ds is not None: + valid_ds = [valid_ds] + if type(test_ds) != list and test_ds is not None: + test_ds = [test_ds] # Build dataloders. + assert len(train_ds) == 1, "only one training dataset group is allowed" + + # train_dataloader is a single item while valid_dataloaders + # and test_dataloaders are arrays train_dataloader = build_pretraining_data_loader( - train_ds, args.consumed_train_samples) - valid_dataloader = build_pretraining_data_loader( - valid_ds, args.consumed_valid_samples) - test_dataloader = build_pretraining_data_loader(test_ds, 0) + train_ds[0], args.consumed_train_samples) + + # We collapse None and empty list as both should mean we don't run validation + # args.consumed_valid_samples accumulates the sum of valid steps for every dataset, which are all equal + # + # XXX: we get a deadlock in the dataloader on multi-dataset eval, after the first dataset, + # possibly due to this bug in pytorch https://github.com/pytorch/pytorch/pull/25158. Using + # num_workers=0 to work around it - the training can't use that since it impacts throughput + # by a few percent + valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds), num_workers=args.valid_num_workers) + for d in valid_ds] \ + if valid_ds is not None else [] + # We collapse None and empty list as both should mean we don't run test + test_dataloaders = [build_pretraining_data_loader(d, 0) for d in test_ds] \ + if test_ds is not None else [] # Flags to know if we need to do training/validation/testing. do_train = train_dataloader is not None and args.train_iters > 0 - do_valid = valid_dataloader is not None and args.eval_iters > 0 - do_test = test_dataloader is not None and args.eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. - flags = torch.cuda.LongTensor( - [int(do_train), int(do_valid), int(do_test)]) + flags = torch.cuda.LongTensor([ + int(do_train), + len(valid_dataloaders) if args.eval_iters > 0 else 0, # eval_iters == 0 is equivalent to having no validation + len(test_dataloaders) if args.eval_iters > 0 else 0, # eval_iters == 0 is equivalent to having no test + ]) else: flags = torch.cuda.LongTensor([0, 0, 0]) @@ -950,10 +1027,14 @@ def build_train_valid_test_data_loaders( mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group()) args.do_train = flags[0].item() - args.do_valid = flags[1].item() - args.do_test = flags[2].item() + args.num_valid_ds = flags[1].item() + args.num_test_ds = flags[2].item() + assert args.num_test_ds >= 0 + assert args.num_valid_ds >= 0 + args.do_valid = args.num_valid_ds > 0 + args.do_test = args.num_test_ds > 0 - return train_dataloader, valid_dataloader, test_dataloader + return train_dataloader, valid_dataloaders, test_dataloaders def build_train_valid_test_data_iterators( @@ -962,7 +1043,7 @@ def build_train_valid_test_data_iterators( args = get_args() # Build loaders. - train_dataloader, valid_dataloader, test_dataloader = \ + train_dataloader, valid_dataloaders, test_dataloaders = \ build_train_valid_test_data_loaders( build_train_valid_test_datasets_provider) @@ -976,16 +1057,18 @@ def build_train_valid_test_data_iterators( else: train_data_iterator = None - if valid_dataloader is not None: - valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(valid_dataloader)) + if valid_dataloaders is not None: + valid_data_iterators = [iter(vdl) if dl_type == 'single' \ + else iter(cyclic_iter(valid_dataloaders)) + for vdl in valid_dataloaders] else: - valid_data_iterator = None + valid_data_iterators = [None] * args.num_valid_ds - if test_dataloader is not None: - test_data_iterator = iter(test_dataloader) if dl_type == 'single' \ - else iter(cyclic_iter(test_dataloader)) + if test_dataloaders is not None: + test_data_iterators = [iter(tdl) if dl_type == 'single' \ + else iter(cyclic_iter(test_dataloaders)) + for tdl in test_dataloaders] else: - test_data_iterator = None + test_data_iterators = [None] * args.num_test_ds - return train_data_iterator, valid_data_iterator, test_data_iterator + return train_data_iterator, valid_data_iterators, test_data_iterators diff --git a/megatron/utils.py b/megatron/utils.py index 008f89fa80..412cc12bde 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -6,6 +6,7 @@ import torch from torch.nn.parallel import DistributedDataParallel as torchDDP +from torch.distributed import BarrierOptions, GroupMember from apex.multi_tensor_apply import multi_tensor_applier import amp_C @@ -211,3 +212,49 @@ def print_rank_last(message): print(message, flush=True) else: print(message, flush=True) + + +def get_tflops(batch_size, elapsed_time_per_iteration): + """Get tflop/s/GPU from global-batch-size and elapsed-time""" + args = get_args() + seq_len = args.seq_length + hidden_size = args.hidden_size + num_layers = args.num_layers + vocab_size = args.padded_vocab_size + + # Compute throughput. + samples_per_sec = batch_size / elapsed_time_per_iteration + tokens_per_sec = samples_per_sec * seq_len + + # General TFLOPs formula (borrowed from Equation 3 in Section 5.1 of + # https://arxiv.org/pdf/2104.04473.pdf). + # The factor of 4 is when used with activation check-pointing, + # otherwise it will be 3, but for 200B model, activation check-pointing will always be on. + checkpoint_activations_factor = 4 if args.recompute_granularity == 'full' else 3 + coefficient_h_squared = 24 + # GLU activations double the hidden states in the upscaling feed-forward in each transformer layer + # This leads to 16bsh^2 instead of 8bsh^2 per first feed-forward layer in MLP, thus we increase the coefficient_h_squared by 8. + # Refer to https://github.com/bigscience-workshop/Megatron-DeepSpeed/pull/283#issue-1260805063 for more details. + if args.glu_activation : + coefficient_h_squared += 8 + + # In MultiQuery attention, keys and values are shared across heads + # qkv projection: 6Bsh^2 -> 2Bsh^2 + 4Bshd_kv + # The formula in https://arxiv.org/pdf/2104.04473.pdf becomes: + # 4 * (20 Bsh^2 + 4Bshd_kv + 4Bs^2h) = 4*20*Bsh^2 (1 + (d_kv+s)/5h) + if args.attention_head_type == 'multiquery': + coefficient_h_squared -= 4 # We substract 4 because of shared kv projection + + # Feed-forward and projections + flops_per_iteration = (coefficient_h_squared * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) + # Attention-matrix computation + flops_per_iteration += (4 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (seq_len / hidden_size) + # LM-head + flops_per_iteration += (6 * batch_size * seq_len * num_layers * (hidden_size**2)) * (vocab_size / (num_layers * hidden_size)) + + if args.attention_head_type == 'multiquery': + d_kv = args.kv_channels + flops_per_iteration += (4 * checkpoint_activations_factor * batch_size * seq_len * num_layers * (hidden_size**2)) * (d_kv / hidden_size) # TODO: maybe tp_size factor missing here + + tflops = flops_per_iteration / (elapsed_time_per_iteration * args.world_size * (10**12)) + return tflops diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 88913e48aa..7b73239271 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -10,7 +10,7 @@ from megatron import get_tokenizer from megatron.core import tensor_parallel from megatron.core.enums import ModelType -from megatron.data.gpt_dataset import build_train_valid_test_datasets +from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group from megatron.model import GPTModel from megatron.training import pretrain from megatron.utils import get_ltor_masks_and_position_ids @@ -91,20 +91,54 @@ def forward_step(data_iterator, model): def train_valid_test_datasets_provider(train_val_test_num_samples): """Build train, valid, and test datasets.""" args = get_args() + train_ds, valid_ds, test_ds = None, None, None print_rank_0('> building train, validation, and test datasets ' 'for GPT ...') - train_ds, valid_ds, test_ds = build_train_valid_test_datasets( - data_prefix=args.data_path, - data_impl=args.data_impl, - splits_string=args.split, - train_valid_test_num_samples=train_val_test_num_samples, - seq_length=args.seq_length, - seed=args.seed, - skip_warmup=(not args.mmap_warmup), - train_data_prefix=args.train_data_path, - valid_data_prefix=args.valid_data_path, - test_data_prefix=args.test_data_path) + # Option 1 of data loading using --data-path + if args.data_path: + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.seq_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_data_prefix=args.train_data_path, + valid_data_prefix=args.valid_data_path, + test_data_prefix=args.test_data_path) + # Option 2 of data loading using --(train|valid|test)-weighted-split-paths + elif args.train_weighted_split_paths: + assigned_train_valid_test = [] + if args.train_weighted_split_paths is not None: + train_ds = [] + assigned_train_valid_test.append("train") + if args.valid_weighted_split_paths is not None: + valid_ds = [] + assigned_train_valid_test.append("valid") + if args.test_weighted_split_paths is not None: + test_ds = [] + assigned_train_valid_test.append("test") + + for s in assigned_train_valid_test: + data_groups = zip(eval(f"args.{s}_weighted_split_paths"), + eval(f"args.{s}_weighted_split_weights"), + eval(f"args.{s}_weighted_split_splits"), + eval(f"args.{s}_weighted_split_names")) + for paths, weights, splits, name in data_groups: + d = build_dataset_group(name, paths, weights, splits, + args.data_impl, + train_val_test_num_samples, + args.seq_length, args.seed, + (not args.mmap_warmup), + train_valid_test=s) + assert d is not None, \ + f"Got an empty split when trying to create dataset: {paths, weights, splits, name}" + eval(f"{s}_ds").append(d) + else: + raise NotImplementedError("No dataloading argument passed") + print_rank_0("> finished creating GPT datasets ...") return train_ds, valid_ds, test_ds diff --git a/tasks/vision/segmentation/seg_heads.py b/tasks/vision/segmentation/seg_heads.py index 61b16cdcbd..64c067323b 100644 --- a/tasks/vision/segmentation/seg_heads.py +++ b/tasks/vision/segmentation/seg_heads.py @@ -5,7 +5,7 @@ import apex import torch.nn.functional as F from megatron import get_args -from megatron.model import LayerNorm +from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm from megatron.model.module import MegatronModule from megatron.model.vision.utils import resize diff --git a/tools/checkpoint_loader_megatron.py b/tools/checkpoint_loader_megatron.py index 1cd4937152..18e3ddfebe 100644 --- a/tools/checkpoint_loader_megatron.py +++ b/tools/checkpoint_loader_megatron.py @@ -53,6 +53,9 @@ def _load_checkpoint(queue, args): '--no-initialization', '--load', args.load_dir ] + if args.use_distributed_optimizer: + sys.argv.append("--use-distributed-optimizer") + margs = parse_args() margs, checkpoint_args = load_args_from_checkpoint(margs) @@ -87,6 +90,7 @@ def check_for_arg(arg_name, default=None): check_for_arg('bert_binary_head') check_for_arg('disable_bias_linear', False) check_for_arg('params_dtype') + check_for_arg('attention_head_type') check_for_arg('swiglu', False) # Determine how to make our models @@ -148,7 +152,7 @@ def get_models(count, dtype): models[vp_rank].append(model_[vp_rank]) return models - set_global_variables(margs) + set_global_variables(margs, build_tokenizer=False) mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) @@ -177,6 +181,7 @@ def get_models(count, dtype): # metadata md = types.SimpleNamespace() md.model_type = args.model_type + md.attention_head_type = margs.attention_head_type md.num_layers = margs.num_layers md.hidden_size = margs.hidden_size md.seq_length = margs.seq_length @@ -242,22 +247,35 @@ def queue_put(name, msg): if md.linear_bias: message["dense bias"] = layer.self_attention.dense.bias.data message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data + if md.attention_head_type == "multiquery": + # MQA: kv is shared across tp-ranks + message["kv weight"] = layer.self_attention.key_value.weight.data + if md.linear_bias: + message["kv bias"] = layer.self_attention.key_value.bias.data # Grab all parallel tensors for this layer qkv_weight = [] qkv_bias = [] + q_weight = [] + q_bias = [] dense_weight = [] mlp_l0_weight = [] mlp_l0_bias = [] mlp_l1_weight = [] for tp_rank, model in enumerate(models): layer = model.language_model.encoder.layers[layer_num] - qkv_weight.append(layer.self_attention.query_key_value.weight.data) + if md.attention_head_type == "multihead": + qkv_weight.append(layer.self_attention.query_key_value.weight.data) + if md.linear_bias: + qkv_bias.append(layer.self_attention.query_key_value.bias.data) + elif md.attention_head_type == "multiquery": + q_weight.append(layer.self_attention.query.weight.data) + if md.linear_bias: + q_bias.append(layer.self_attention.query.bias.data) dense_weight.append(layer.self_attention.dense.weight.data) mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data) mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data) if md.linear_bias: - qkv_bias.append(layer.self_attention.query_key_value.bias.data) mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data) # Handle gated linear units @@ -271,11 +289,18 @@ def queue_put(name, msg): message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) # simple concat of the rest - message["qkv weight"] = torch.cat(qkv_weight, dim=0) + if md.attention_head_type == "multihead": + message["qkv weight"] = torch.cat(qkv_weight, dim=0) + if md.linear_bias: + message["qkv bias"] = torch.cat(qkv_bias, dim=0) + elif md.attention_head_type == "multiquery": + message["q weight"] = torch.cat(q_weight, dim=0) + if md.linear_bias: + message["q bias"] = torch.cat(q_bias, dim=0) + message["dense weight"] = torch.cat(dense_weight, dim=1) message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) if md.linear_bias: - message["qkv bias"] = torch.cat(qkv_bias, dim=0) if md.swiglu: for tp_rank in range(tp_size): mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) diff --git a/tools/checkpoint_saver_megatron.py b/tools/checkpoint_saver_megatron.py index 0ff8c55b1f..47f1b6c666 100644 --- a/tools/checkpoint_saver_megatron.py +++ b/tools/checkpoint_saver_megatron.py @@ -96,6 +96,7 @@ def check_message(msg): '--seq-length', str(md.seq_length), '--num-attention-heads', str(md.num_attention_heads), '--max-position-embeddings', str(md.max_position_embeddings), + '--attention-head-type', str(md.attention_head_type), '--tokenizer-type', str(md.tokenizer_type), '--tensor-model-parallel-size', str(args.target_tensor_parallel_size), '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size), @@ -163,7 +164,7 @@ def check_message(msg): validate_args(margs) - set_global_variables(margs) + set_global_variables(margs, build_tokenizer=False) # margs = megatron args margs = get_args() @@ -268,9 +269,20 @@ def get_models(count, dtype, pre_process, post_process): if md.linear_bias: dense_bias = msg.pop("dense bias") mlp_l1_bias = msg.pop("mlp l1 bias") + if md.attention_head_type == "multiquery": + kv_weight = msg.pop("kv weight") + if md.linear_bias: + kv_bias = msg.pop("kv bias") # Split up the parallel tensors - qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0) + if md.attention_head_type == "multihead": + qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0) + if md.linear_bias: + qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0) + elif md.attention_head_type == "multiquery": + q_weight = torch.chunk(msg.pop("q weight"), args.target_tensor_parallel_size, dim=0) + if md.linear_bias: + q_bias = torch.chunk(msg.pop("q bias"), args.target_tensor_parallel_size, dim=0) dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1) mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1) @@ -283,7 +295,6 @@ def get_models(count, dtype, pre_process, post_process): mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0) if md.linear_bias: - qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0) if md.swiglu: mlp_l0_bias_W = torch.chunk(msg.pop("mlp l0 bias W"), args.target_tensor_parallel_size, dim=0) mlp_l0_bias_V = torch.chunk(msg.pop("mlp l0 bias V"), args.target_tensor_parallel_size, dim=0) @@ -296,14 +307,23 @@ def get_models(count, dtype, pre_process, post_process): l = models[tp_rank].language_model.encoder.layers[layer] l.input_layernorm.weight.data.copy_(input_layernorm_weight) l.input_layernorm.bias.data.copy_(input_layernorm_bias) - l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank]) + if md.attention_head_type == "multihead": + l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank]) + if md.linear_bias: + l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank]) + elif md.attention_head_type == "multiquery": + # MQA: key-value are shared across tp-ranks + l.self_attention.key_value.weight.data.copy_(kv_weight) + l.self_attention.query.weight.data.copy_(q_weight[tp_rank]) + if md.linear_bias: + l.self_attention.key_value.bias.data.copy_(kv_bias) + l.self_attention.query.bias.data.copy_(q_bias[tp_rank]) l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank]) l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight) l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias) l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank]) l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank]) if md.linear_bias: - l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank]) l.self_attention.dense.bias.data.copy_(dense_bias) l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank]) l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias) diff --git a/tools/checkpoint_util.py b/tools/checkpoint_util.py index 628ce47c62..94f2fec3d8 100644 --- a/tools/checkpoint_util.py +++ b/tools/checkpoint_util.py @@ -124,6 +124,9 @@ def main(): parser.add_argument('--no-checking', action='store_false', help='Do not perform checking on the name and ordering of weights', dest='checking') + + parser.add_argument('--use-distributed-optimizer', action='store_true', + help='Loaded checkpoint uses distributed optimizer.') known_args, _ = parser.parse_known_args() loader = load_plugin('loader', known_args.loader) diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 35781a78e7..7e0406124f 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -18,6 +18,7 @@ except ImportError: nltk_available = False +from datasets import load_dataset from megatron.tokenizer import build_tokenizer from megatron.data import indexed_dataset @@ -63,9 +64,8 @@ def initializer(self): else: Encoder.splitter = IdentitySplitter() - - def encode(self, json_line): - data = json.loads(json_line) + + def _encode_data(self, data): ids = {} for key in self.args.json_keys: text = data[key] @@ -77,13 +77,25 @@ def encode(self, json_line): if len(doc_ids) > 0 and self.args.append_eod: doc_ids[-1].append(Encoder.tokenizer.eod) ids[key] = doc_ids + return ids + + def encode(self, json_line): + data = json.loads(json_line) + ids = self._encode_data(data) return ids, len(json_line) + + def encode_hf(self, sample): + ids = self._encode_data(sample) + return ids, 1 + def get_args(): parser = argparse.ArgumentParser() group = parser.add_argument_group(title='input data') group.add_argument('--input', type=str, required=True, help='Path to input JSON') + group.add_argument('--subset', type=str, default=None, + help='Subset argument when loading input data from a HuggingFace dataset') group.add_argument('--json-keys', nargs='+', default=['text'], help='space separate listed of keys to extract from json') group.add_argument('--split-sentences', action='store_true', @@ -95,12 +107,15 @@ def get_args(): group.add_argument('--tokenizer-type', type=str, required=True, choices=['BertWordPieceLowerCase','BertWordPieceCase', 'GPT2BPETokenizer', 'SentencePieceTokenizer', - 'GPTSentencePieceTokenizer', 'NullTokenizer'], + 'GPTSentencePieceTokenizer', 'NullTokenizer', + 'TokenizerFromFile'], help='What type of tokenizer to use.') group.add_argument('--vocab-file', type=str, default=None, help='Path to the vocab file') group.add_argument('--merge-file', type=str, default=None, help='Path to the BPE merge file (if necessary).') + group.add_argument('--tokenizer-file', type=str, default=None, + help='Path to the tokenizer file') group.add_argument('--append-eod', action='store_true', help='Append an token to the end of a document.') group.add_argument('--lang', type=str, default='english', @@ -143,8 +158,6 @@ def main(): args = get_args() startup_start = time.time() - print("Opening", args.input) - fin = open(args.input, 'r', encoding='utf-8') if nltk_available and args.split_sentences: nltk.download("punkt", quiet=True) @@ -152,8 +165,22 @@ def main(): encoder = Encoder(args) tokenizer = build_tokenizer(args) pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) - encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size) - #encoded_docs = map(encoder.encode, fin) + print("Opening", args.input) + + if args.input.endswith(".jsonl"): + print("Input is a jsonl file") + assert args.subset is None, f"subset argument set to: {args.subset}, but loading a jsonl file." + fin = open(args.input, 'r', encoding='utf-8') + encoded_docs = pool.imap(encoder.encode, fin, args.chunk_size) + #encoded_docs = map(encoder.encode, fin) + else: + # NOTE: this is not recommended for datasets larger than 40-50GB, as iterating through a dataset can be slow. + # Somehow, it seems faster to first dump the dataset to a jsonl file: ds.to_json() and then process the jsonl file. + # NOTE: this will be even slower if the dataset has large objects in other columns. + # In this case, it is recommended to dump as json only the required key: ds = ds.remove_columns(...) then to_json() + print("Input is not a jsonl file, will try to load from HF datasets") + ds = load_dataset(args.input, use_auth_token=True, streaming=True, split="train", data_dir=args.subset) + encoded_docs = pool.imap(encoder.encode_hf, ds, args.chunk_size) level = "document" if args.split_sentences: diff --git a/tools/text_generation_benchmark.py b/tools/text_generation_benchmark.py new file mode 100644 index 0000000000..ee458f377c --- /dev/null +++ b/tools/text_generation_benchmark.py @@ -0,0 +1,163 @@ + +"""Sample Generate GPT""" +import os +import sys +import re +sys.path.append(os.path.abspath(os.path.join( + os.getcwd(), + "Megatron-LM", +))) +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import mpu +from megatron.checkpointing import load_checkpoint +from megatron.initialize import initialize_megatron +from megatron.model import GPTModel +from megatron.training import get_model +from megatron.text_generation import generate_and_post_process +import torch +from human_eval.data import write_jsonl, read_problems +from tqdm import tqdm + + +GENERATE_NUM = 0 + +# End on unindented code +# EOF_STRINGS = ["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif"] + + +BATCH_SIZE = 512 +TOKENS_TO_GENERATE = 128 +PROMPT_LENGTH = 128 +NUM_BATCHES = 8 + + +# NUM_SAMPLES_PER_TASK = 5 +# # Number of human-eval tasks +# NUM_TASKS = 200 + +def send_do_generate(): + choice = torch.cuda.LongTensor([GENERATE_NUM]) + torch.distributed.broadcast(choice, 0) + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + model = GPTModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) + + return model + +def get_batches(prompts, batch_size): + for start_idx in tqdm(range(0, len(prompts), batch_size)): + actual_batch_size = min(batch_size, len(prompts) - start_idx) + yield prompts[start_idx: start_idx + actual_batch_size] + + +def unbatch(d: dict): + return [dict(zip(d.keys(), t)) for t in zip(*d.values())] + + +# Use fixed-length prompts +def load_evaluation_data(args): + # HumanEval data + # problems = read_problems() + + # batches = get_batches( + # [ + # problems[task_id]["prompt"] + # for task_id in problems + # for _ in range(5) + # ], + # BATCH_SIZE + # ) + # return batches + + prompt = " ".join(["one"] * PROMPT_LENGTH) + prompts = [prompt] * (BATCH_SIZE * NUM_BATCHES) + + batches = get_batches(prompts, BATCH_SIZE) + return batches + + +if __name__ == "__main__": + # Initialize Megatron + initialize_megatron(extra_args_provider=None, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer', + 'no_load_rng': True, + 'no_load_optim': True}) + + args = get_args() + timers = get_timers() + + if args.num_layers_per_virtual_pipeline_stage is not None: + print("Interleaved pipeline schedule is not yet supported for text generation.") + exit() + # Setup model and load checkpoint + model = get_model(model_provider, wrap_with_ddp=False) + + if args.load is not None: + iteration = load_checkpoint(model, None, None, iteration=None) + else: + iteration = None + + assert len(model) == 1 + model = model[0] + + def generate(prompts): + response, response_seg, response_logprobs, tokens = \ + generate_and_post_process( + model, + prompts=prompts, + tokens_to_generate=TOKENS_TO_GENERATE, + return_output_log_probs=True, + use_eod_token_for_early_termination=False) + + assert all([r.startswith(p) for r, p in zip(response, prompts)]) + result = { + "response": response, + "response_seg": response_seg, + "raw_completion": [r[len(p):] for r, p in zip(response, prompts)] + } + # The "completion" field contains the string that is actually going to be evaluated by the HumanEval script + # result["completion"] = [post_process_completion(c) for c in result["raw_completion"]] + # Return a list of dicts + return unbatch(result) + + # if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: + # server = MegatronServer(model) + # server.run("0.0.0.0") + + # while True: + # choice = torch.cuda.LongTensor(1) + # torch.distributed.broadcast(choice, 0) + # if choice[0].item() == 0: + # generate_and_post_process(model) + + + # Evaluation data iterator + batches = load_evaluation_data(args) + + timers('generate').start() + # Generate + samples = [ + generate_dict + for batch in batches + for generate_dict in generate(batch) + ] + timers('generate').stop() + + elapsed = timers.timers['generate'].elapsed(reset=False) + num_tokens = TOKENS_TO_GENERATE * NUM_BATCHES * BATCH_SIZE + print(f"{elapsed * 1000 / (num_tokens)} ms per token") + timers.log(['generate']) + if args.transformer_timers: + timers.log(["Transformer forward"]) + print("DONE") + + # Write results to file + # if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: + # write_jsonl(args.output_file.format(iteration), samples) +