33 commits
99d11e1  add basic sampling code (Jan 13, 2021)
8bb3322  Merge remote-tracking branch 'origin/main' into sample (Jan 13, 2021)
b9fd039  add prediction input / output fns (Jan 13, 2021)
cf76c6c  get sample_autoregressive working (ConnorJL, Jan 14, 2021)
c7ff6c4  truncate text tokens properly (ConnorJL, Jan 14, 2021)
4d51fd9  log model params to tensorboard (ConnorJL, Jan 14, 2021)
346871f  add vae decoding and write to jpeg (kingoflolz, Jan 14, 2021)
d13c330  unshift image outputs at decode time (kingoflolz, Jan 14, 2021)
f8a7449  dirty hack to use vae decoder params when training dalle (kingoflolz, Jan 14, 2021)
ff56d12  Move initialize_vae_weights to after lowering (leogao2, Jan 17, 2021)
4c4e0e0  fix vae checkpoint load in training (ConnorJL, Jan 17, 2021)
2c14bde  fix parameter count logging (ConnorJL, Jan 18, 2021)
130c26e  fix image vocab size (ConnorJL, Jan 18, 2021)
4652ef2  revert to separate embeddings for image and text (ConnorJL, Jan 19, 2021)
67247cf  Fix masking (leogao2, Jan 19, 2021)
a0a2828  Ignore text tokens in loss computation (leogao2, Jan 20, 2021)
ca23b85  Fix slicing for mtf (leogao2, Jan 20, 2021)
e126b79  Fix sampling (leogao2, Jan 20, 2021)
ddeb74d  Implement incremental logits mask (leogao2, Jan 20, 2021)
e7eb459  Fix typo (leogao2, Jan 20, 2021)
29f3006  revert changes to sample.py (ConnorJL, Jan 21, 2021)
b69ff72  add mask to bias op (ConnorJL, Jan 21, 2021)
7e3b6ff  update mask. (still not working :( ) (ConnorJL, Jan 21, 2021)
34d5326  mask changes (ConnorJL, Jan 31, 2021)
05dec26  fix label shifting (ConnorJL, Feb 1, 2021)
6c39bdf  add weight decay Adam (ConnorJL, Feb 1, 2021)
69d7ef9  add eval steps (ConnorJL, Feb 1, 2021)
d564853  add tests (lucidrains, Apr 4, 2021)
587418d  make sampling work with old <eos> method (lucidrains, Apr 5, 2021)
fdef8f4  add logits mask back in (Apr 5, 2021)
00dbf6b  fix sample call (Apr 5, 2021)
106d34b  fix initial positions and tests (lucidrains, Apr 5, 2021)
4265932  give image axial positional embedding to see if it helps (lucidrains, Apr 6, 2021)
33 changes: 33 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,33 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Tests
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.7, 3.8]
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python -m pip install pytest
+        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: Test with pytest
+      run: |
+        pytest -s test.py
10 changes: 5 additions & 5 deletions configs/dalle_coco.json
@@ -7,22 +7,22 @@
   },
   "train_batch_size": 128,
   "eval_batch_size": 128,
-  "predict_batch_size": 128,
+  "predict_batch_size": 16,
   "steps_per_checkpoint": 5000,
   "iterations": 1000,
   "train_steps": 100000,
   "predict_steps": 0,
   "eval_steps": 0,
   "n_channels": 3,
-  "bf_16": false,
+  "bf_16": true,
   "recompute_grad": true,
   "lr": 0.0001,
-  "model_path": "gs://neo-models/dalle_coco/",
+  "model_path": "gs://neo-models/dalle_coco_sample/",
   "mesh_shape": "data:16,model:2",
-  "layout": "batch_dim:data",
+  "layout": "batch_dim:data,embed_dim:model",
  "n_embd": 1024,
   "text_vocab_size": 50258,
-  "image_vocab_size": 512,
+  "image_vocab_size": 2048,
   "text_seq_len": 256,
   "n_layers": 12,
   "n_heads": 8,
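For orientation, the sizes the model derives from this config work out as follows. A quick sketch (`image_seq_len = 1024` is taken from the `DALLE` defaults in models.py below, not from this hunk; the `+ 1` is the EOS token, per `total_tokens` in `__init__`):

```python
text_seq_len = 256
image_seq_len = 1024  # from the DALLE defaults in models.py

total_seq_len = text_seq_len + image_seq_len            # 1280 positions per example
text_vocab_size = 50258
image_vocab_size = 2048                                 # raised from 512 in this PR
total_tokens = text_vocab_size + image_vocab_size + 1   # +1 for EOS -> 52307

print(total_seq_len, total_tokens)  # 1280 52307
```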
3 changes: 2 additions & 1 deletion src/dalle_mtf/__init__.py
@@ -1 +1,2 @@
-from .models import DALLE, DiscreteVAE
+from .models import DALLE, DiscreteVAE
+from .sample import sample_autoregressive
161 changes: 120 additions & 41 deletions src/dalle_mtf/models.py
@@ -5,7 +5,7 @@
 from collections import defaultdict
 import math
 
-from .ops import pad, exists, get_variable_dtype
+from .ops import pad, exists, get_variable_dtype, expand_tile, mask_to_bias
 from .layers import gumbel_softmax, mse_loss, norm
 
 
@@ -140,29 +140,34 @@ def forward(self, features, return_recon_loss=False, return_logits=False, hard_g
 
 class DALLE:
 
-    def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024,
+    def __init__(self, mesh, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq_len=256, image_seq_len=1024,
                  n_layers=6, n_heads=8, batch_size=32, bf_16=True, attn_mask=None, mode="train",
                  is_incremental_inference=False, context=None, loss_fn=None, params=None, eos_token_id=None,
                  activation_fn=None):
 
+        self.mesh = mesh
         self.n_embd = n_embd
         self.text_vocab_size = text_vocab_size
         self.image_vocab_size = image_vocab_size
         self.text_seq_len = text_seq_len
         self.image_seq_len = image_seq_len
-        self.total_seq_dim = text_seq_len + image_seq_len
+        self.total_seq_len = text_seq_len + image_seq_len
         self.n_layers = n_layers
         self.n_heads = n_heads
         self.attn_mask = attn_mask
+        self.logits_mask = None
         self.total_tokens = text_vocab_size + image_vocab_size + 1  # extra for EOS
         self.eos_token_id = self.total_tokens - 1 if eos_token_id is None else eos_token_id
         self.dimensions = {"embed_dim": mtf.Dimension("embed_dim", n_embd),
                            "text_vocab_dim": mtf.Dimension("vocab_dim", text_vocab_size),
                            "image_vocab_dim": mtf.Dimension("vocab_dim", image_vocab_size),
                            "final_vocab_dim": mtf.Dimension("vocab_dim", self.total_tokens),
-                           "total_seq_dim": mtf.Dimension("total_seq_dim", self.total_seq_dim),
-                           "embed_seq_dim": mtf.Dimension("embed_seq_dim", self.total_seq_dim),
-                           "memory_len_dim": mtf.Dimension("memory_len_dim", self.total_seq_dim),
+                           "text_sequence_dim": mtf.Dimension("sequence_dim", text_seq_len),
+                           "image_sequence_dim": mtf.Dimension("sequence_dim", image_seq_len),
+                           "total_seq_dim": mtf.Dimension("sequence_dim", self.total_seq_len),
+                           "text_embed_seq_dim": mtf.Dimension("text_embed_seq_dim", text_seq_len),
+                           "image_embed_seq_dim": mtf.Dimension("image_embed_seq_dim", image_seq_len),
+                           "memory_len_dim": mtf.Dimension("memory_len_dim", self.total_seq_len),
                            "heads_dim": mtf.Dimension("heads", n_heads),
                            "kv_dim": mtf.Dimension("kv_dim", n_embd // n_heads),
                            "batch_dim": mtf.Dimension("batch_dim", batch_size)}
@@ -179,13 +184,17 @@ def __init__(self, n_embd, text_vocab_size=12800, image_vocab_size=512, text_seq
         self.activation_fn = activation_fn
         if self.is_incremental_inference:
             assert self.context is not None, "must have context in incremental inference"
+            assert self.context['mode'] == 'incremental'
         if params is None:  # extra params
             params = {}
         self.params = defaultdict(lambda: None, params)
 
     def embedding(self, x, name):
         embd_dim = self.dimensions["embed_dim"]
-        vocab_dim = self.dimensions["final_vocab_dim"]
+        if "text" in name:
+            vocab_dim = self.dimensions["text_vocab_dim"]
+        else:
+            vocab_dim = self.dimensions["image_vocab_dim"]
         with tf.variable_scope(name):
             wte = mtf.get_variable(x.mesh, "wte",
                                    mtf.Shape([vocab_dim, embd_dim]),
@@ -200,16 +209,61 @@ def embedding(self, x, name):
             x = mtf.dropout(x, rate=embed_dropout, name="wte_dropout")
         return x
 
-    def positional_embedding(self, x, name):
-        with tf.variable_scope(name):
-            # Positional embedding
-            wpe = mtf.get_variable(x.mesh, "wpe",
-                                   mtf.Shape([self.dimensions["embed_seq_dim"], self.dimensions["embed_dim"]]),
-                                   initializer=tf.random_normal_initializer(stddev=0.01),
-                                   master_dtype=self.variable_dtype.master_dtype,
-                                   slice_dtype=self.variable_dtype.slice_dtype,
-                                   activation_dtype=self.variable_dtype.activation_dtype)
-            position_indices = mtf.range(x.mesh, self.dimensions["total_seq_dim"], tf.int64) if not \
-                self.is_incremental_inference else (self.context.position - 1)
-            pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
-            embed_dropout = self.params.get("embed_dropout", 0)
+    def axial_positional_embedding(self, mesh):
+        with tf.variable_scope("axial_emb"):
+            axial_dim_side = int(sqrt(self.image_seq_len))
+
+            embd_dim = self.dimensions["embed_dim"]
+            axial_dim = mtf.Dimension("axial_dim", self.image_seq_len)
+
+            dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_side, axial_dim_side))]
+
+            axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
+                                           initializer=tf.random_normal_initializer(stddev=0.01),
+                                           master_dtype=self.variable_dtype.master_dtype,
+                                           slice_dtype=self.variable_dtype.slice_dtype,
+                                           activation_dtype=self.variable_dtype.activation_dtype)
+
+            axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
+                                           initializer=tf.random_normal_initializer(stddev=0.01),
+                                           master_dtype=self.variable_dtype.master_dtype,
+                                           slice_dtype=self.variable_dtype.slice_dtype,
+                                           activation_dtype=self.variable_dtype.activation_dtype)
+
+            axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
+                                           (axial_wpe_1, axial_wpe_2))
+            wpe = (axial_wpe_1 + axial_wpe_2) / 2
+
+            wpe = mtf.reshape(wpe, [axial_dim, embd_dim])
+
+            wpe = mtf.replace_dimensions(wpe, wpe.shape[0], self.dimensions["image_embed_seq_dim"])
+            return wpe
+
+    def image_positional_embedding(self, x):
+        sequence_dim = self.dimensions["image_sequence_dim"]
+        with tf.variable_scope("image_pos_emb"):
+            # Positional embedding
+            wpe = self.axial_positional_embedding(x.mesh)
+            position_indices = mtf.range(x.mesh, sequence_dim, tf.int64) if not \
+                self.is_incremental_inference else (self.context.position - 1)
+            pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
+            embed_dropout = self.params.get("embed_dropout", 0)
+            if embed_dropout > 0 and self.mode == "train":
+                pos_emb = mtf.dropout(pos_emb, rate=embed_dropout, name="wte_dropout")
+            x += pos_emb
+        return x
+
+    def text_positional_embedding(self, x):
+        sequence_dim = self.dimensions["text_sequence_dim"]
+        with tf.variable_scope("text_pos_emb"):
+            # Positional embedding
+            wpe = mtf.get_variable(x.mesh, "wpe",
+                                   mtf.Shape([self.dimensions["text_embed_seq_dim"], self.dimensions["embed_dim"]]),
+                                   initializer=tf.random_normal_initializer(stddev=0.01),
+                                   master_dtype=self.variable_dtype.master_dtype,
+                                   slice_dtype=self.variable_dtype.slice_dtype,
+                                   activation_dtype=self.variable_dtype.activation_dtype)
+            position_indices = mtf.range(x.mesh, sequence_dim, tf.int64) if not \
+                self.is_incremental_inference else (self.context.position - 1)
+            pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
+            embed_dropout = self.params.get("embed_dropout", 0)
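The axial embedding above factorizes the `image_seq_len = 1024` positions into a 32 x 32 grid and learns one 32-row table per axis; each position's embedding is the mean of its row and column vectors, so the model stores 2 x 32 position vectors instead of 1024. (Note that `sqrt` needs to be in scope here, e.g. via `from math import sqrt`, since only `import math` appears in the imports above.) A numpy rendition of the same computation, for intuition only:

```python
import numpy as np

image_seq_len, n_embd = 1024, 1024
side = int(np.sqrt(image_seq_len))  # 32

rng = np.random.default_rng(0)
wpe_rows = rng.normal(0.0, 0.01, (side, n_embd))  # plays the role of axial_wpe_1
wpe_cols = rng.normal(0.0, 0.01, (side, n_embd))  # plays the role of axial_wpe_2

# broadcast to (side, side, n_embd), average, then flatten to (image_seq_len, n_embd)
wpe = (wpe_rows[:, None, :] + wpe_cols[None, :, :]) / 2
wpe = wpe.reshape(image_seq_len, n_embd)
print(wpe.shape)  # (1024, 1024)
```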
@@ -225,10 +279,22 @@ def get_attn_mask(self, mesh, nd, ns):
             i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
             self.attn_mask = mtf.cast(mtf.less(i, j), self.variable_dtype.activation_dtype) * -1e10
         return self.attn_mask
 
+    def set_logits_mask(self, tf_mask):
+        mask_shape = mtf.Shape([self.dimensions['total_seq_dim'], self.dimensions['final_vocab_dim']])
+        mtf_mask = mtf.import_fully_replicated(self.mesh, tf_mask, mask_shape)
+        new_shape = mtf.Shape([self.dimensions['batch_dim'], self.dimensions['total_seq_dim'], self.dimensions['final_vocab_dim']])
+        mtf_mask = mtf.broadcast(mtf_mask, new_shape)
+        self.logits_mask = mtf_mask
+
     def attention(self, x, n_state, mask, attention_type="global", name="attn"):
-        # x :: [batch, seq, n_embd]
-        batch_dim, seq_dim, embd_dim = x_shape = x.shape
+        if not self.is_incremental_inference:
+            # x :: [batch, seq, n_embd]
+            batch_dim, seq_dim, embd_dim = x_shape = x.shape
+        else:
+            batch_dim, embd_dim = x_shape = x.shape
+            seq_dim = self.dimensions['total_seq_dim']
+
         assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"
         with tf.variable_scope(name):
             # Compute attention inputs
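`set_logits_mask` imports a `[total_seq_dim, final_vocab_dim]` mask and broadcasts it over the batch; `forward` later adds it to the logits, pushing disallowed tokens to a large negative score before softmax or sampling. The mask itself is built by the caller and is not part of this diff; a plausible sketch of its intent (hypothetical construction, using the `text_vocab_size` label offset that `forward` applies below) is that text positions admit only text tokens and image positions only image tokens:

```python
import numpy as np

text_seq_len, image_seq_len = 256, 1024
text_vocab, image_vocab = 50258, 2048
total_tokens = text_vocab + image_vocab + 1  # +1 for EOS

mask = np.full((text_seq_len + image_seq_len, total_tokens), -1e9, dtype=np.float32)
mask[:text_seq_len, :text_vocab] = 0.0                          # text positions: text tokens only
mask[text_seq_len:, text_vocab:text_vocab + image_vocab] = 0.0  # image positions: image tokens only
# (the real mask may also need to admit EOS somewhere, depending on the sampling scheme)

# model.set_logits_mask(tf.constant(mask)) would then broadcast it over the batch
```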
@@ -254,25 +320,7 @@ def attention(self, x, n_state, mask, attention_type="global", name="attn"):
                 self.context.record_new_states([k, v])
 
             with tf.variable_scope("attention"):
-                if attention_type == "local":
-                    # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights.
-                    radius = self.params.get("local_attention_radius", 256)
-                    if self.is_incremental_inference:
-                        q *= one_hot
-                    a = mtf_transformer.attention.local_attention_1d(
-                        q, k, v,
-                        length_dim=k.shape[1],
-                        key_dim=self.dimensions["kv_dim"],
-                        value_dim=self.dimensions["kv_dim"],
-                        radius=radius,
-                        length_dim_num_splits=1,
-                        fully_autoregressive=True,
-                        attention_kwargs={},
-                    )
-                    if self.is_incremental_inference:
-                        a = mtf.gather(a, self.context.position - 1, seq_dim)
-
-                elif attention_type == "global":
+                if attention_type == "global":
                     if exists(mask):
                         if not self.is_incremental_inference:
                             broadcasted_mask = mtf.broadcast(mask,
@@ -347,7 +395,8 @@ def transformer(self, x, mask):
 
     def _loss(self, logits, labels):
         with tf.variable_scope("loss_final"):
-            loss_batch = self.loss_fn(logits=logits, targets=labels,
+            loss_batch = self.loss_fn(logits=mtf.slice(logits, begin=self.text_seq_len, size=self.image_seq_len, slice_dim_name="sequence_dim"),
+                                      targets=mtf.slice(labels, begin=self.text_seq_len, size=self.image_seq_len, slice_dim_name="sequence_dim"),
                                       vocab_dim=logits.shape[-1], z_loss=0.0)
 
         with tf.variable_scope("reduce_mean_final"):
@@ -392,22 +441,52 @@ def to_logits(self, x):
         with tf.variable_scope("to_logits"):
             logits = self.linear(self.layer_norm(x), self.dimensions["final_vocab_dim"], name="linear_out")
             # Go to full precision for the logits
+            if self.is_incremental_inference:
+                # add seq dim in inference mode
+                logits = expand_tile(logits, mtf.Dimension("sequence_dim", 1), axis=1)
             return mtf.cast(logits, tf.float32)
 
+    def shift_labels(self, labels):
+        labels = pad(labels, [0, 1], dim_name="sequence_dim", pad_value=self.eos_token_id)
+        labels = mtf.slice(labels, 1, self.total_seq_len, "sequence_dim")
+        return labels
+
     def forward(self, features, return_loss=True, return_logits=False):
-        inputs = features["tokens"]
-        tokens = self.positional_embedding(self.embedding(inputs, "embedding"), "positional_embedding")
+        if features.get('text_inputs') is not None:
+            text = features["text_inputs"]
+            text_emb = self.text_positional_embedding(self.embedding(text, "text_embd"))
+        else:
+            assert self.is_incremental_inference
+        image = features.get("image_inputs", None)
+        if not self.is_incremental_inference:
+            image_emb = self.image_positional_embedding(self.embedding(image, "image_embd"))
+            tokens = mtf.concat([text_emb, image_emb], concat_dim_name="sequence_dim")  # [batch, seq, n_embd]
+        else:
+            # reshape inputs if in inference mode
+            image = mtf.gather(image, self.context.position - 1, self.dimensions["image_sequence_dim"])
+            image = mtf.reshape(image, [self.dimensions["batch_dim"]])
+            tokens = self.image_positional_embedding(self.embedding(image, "image_embd"))
 
-        mask = self.get_attn_mask(tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"])
+        mask = self.get_attn_mask(tokens.mesh, self.dimensions["total_seq_dim"], self.dimensions["memory_len_dim"])
         out = self.transformer(tokens, mask=mask)
         logits = self.to_logits(out)
+        if self.is_incremental_inference:
+            logits_mask = mtf.gather(self.logits_mask, self.context.position + self.text_seq_len - 1, self.logits_mask.shape[1])
+            logits_mask = expand_tile(logits_mask, mtf.Dimension("sequence_dim", 1), axis=1)
+        else:
+            logits_mask = self.logits_mask
+        logits += mtf.cast(logits_mask, logits.dtype)
 
         if not return_loss:
             logits = mtf.cast(logits, self.variable_dtype.master_dtype)
             return logits
 
-        labels = pad(inputs, [0, 1], dim_name="total_seq_dim", pad_value=self.eos_token_id)
-        indices = mtf.range(labels.mesh, mtf.Dimension("range", labels.shape[1].size - 1), tf.int32, name="labels_indices") + 1
-        labels = mtf.gather(labels, indices, dim=labels.shape[1])
-        labels = mtf.rename_dimension(labels, "range", "total_seq_dim")
+        assert exists(image), 'when training, image must be supplied'
+        offset_image = image + self.text_vocab_size
+        labels = mtf.concat([text, offset_image], concat_dim_name="sequence_dim")
+        labels = self.shift_labels(labels)
         loss, loss_batch = self._loss(logits, labels)
         if return_logits and return_loss:
             # Cast back to checkpoint dtype
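`shift_labels` implements the usual next-token shift on the concatenated `[text, image + text_vocab_size]` target sequence: pad one EOS on the right, then drop the first position, so position i is trained to predict token i + 1 and the final image token predicts EOS. The same pad-then-slice in plain Python:

```python
eos_token_id = 99                        # toy value
labels = [10, 11, 12, 13]                # toy concatenated token ids
shifted = (labels + [eos_token_id])[1:]  # pad with EOS, drop the first token
print(shifted)                           # [11, 12, 13, 99]
```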
18 changes: 18 additions & 0 deletions src/dalle_mtf/ops.py
@@ -80,3 +80,21 @@ def get_variable_dtype(bf_16=True):
         return mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32, activation_dtype=tf.bfloat16)
     else:
         return mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32)
+
+def expand_tile(value, newdim, axis=0):
+    """Add a new axis of given size."""
+    new_shape = value.shape.dims
+    new_shape.insert(axis, newdim)
+    return mtf.broadcast(value, new_shape)  # shape.dims gets us a list which we need in order to concat
+
+def mask_to_bias(visible, dtype):
+    """Convert a boolean visibility mask to an attention bias.
+    The returned Tensor has large negative values in positions where
+    visible=False.
+    Args:
+        visible: a boolean Tensor
+        dtype: a dtype
+    Returns:
+        a Tensor with the given dtype and the same shape as "visible"
+    """
+    return mtf.cast(mtf.logical_not(visible), dtype) * -1e9
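`expand_tile` is what `to_logits` uses above to give single-step inference logits a length-1 sequence axis, and `mask_to_bias` is the standard additive-mask trick: positions that must stay invisible get -1e9 added to their attention logits, which rounds to roughly zero probability after softmax. A numpy rendition of `mask_to_bias`, for intuition:

```python
import numpy as np

def mask_to_bias_np(visible, dtype=np.float32):
    # numpy equivalent of ops.mask_to_bias: True -> 0, False -> -1e9
    return (~visible).astype(dtype) * -1e9

visible = np.tril(np.ones((4, 4), dtype=bool))  # causal visibility pattern
print(mask_to_bias_np(visible)[0])              # row 0: 0 for self, -1e9 for future positions
```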