diff --git a/flux.hpp b/flux.hpp index 95927f8b..bea12a11 100644 --- a/flux.hpp +++ b/flux.hpp @@ -1266,7 +1266,8 @@ namespace Flux { set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data()); } y = to_backend(y); - + float current_timestep = ggml_get_f32_1d(timesteps, 0); + LOG_DEBUG("current_timestep %f", current_timestep); timesteps = to_backend(timesteps); if (flux_params.guidance_embed || flux_params.is_chroma) { guidance = to_backend(guidance); @@ -1275,6 +1276,30 @@ namespace Flux { ref_latents[i] = to_backend(ref_latents[i]); } + // get use_yarn, use_ntk and use_dype from env for now (TODO: add args) + // Env value could be one of yarn, dy_yarn, ntk or dy_ntk, (anything else means disabled) + const char* env_value = getenv("FLUX_ROPE"); + bool use_yarn = false; + bool use_dype = false; + bool use_ntk = false; + if (env_value != nullptr) { + if (strcmp(env_value, "YARN") == 0) { + LOG_DEBUG("Using YARN RoPE"); + use_yarn = true; + } else if (strcmp(env_value, "DY_YARN") == 0) { + LOG_DEBUG("Using DY YARN RoPE"); + use_yarn = true; + use_dype = true; + } else if (strcmp(env_value, "NTK") == 0) { + LOG_DEBUG("Using NTK RoPE"); + use_ntk = true; + } else if (strcmp(env_value, "DY_NTK") == 0) { + LOG_DEBUG("Using DY NTK RoPE"); + use_ntk = true; + use_dype = true; + } + } + pe_vec = Rope::gen_flux_pe(x->ne[1], x->ne[0], flux_params.patch_size, @@ -1283,7 +1308,11 @@ namespace Flux { ref_latents, increase_ref_index, flux_params.theta, - flux_params.axes_dim); + flux_params.axes_dim, + use_yarn, + use_dype, + use_ntk, + current_timestep); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); diff --git a/rope.hpp b/rope.hpp index bd1dfad5..7d91d6b5 100644 --- a/rope.hpp +++ b/rope.hpp @@ -71,6 +71,115 @@ namespace Rope { return result; } + float find_correction_factor(float num_rotations, int dim, float base, float max_position_embeddings) { + return (dim * std::log(max_position_embeddings / (num_rotations * 2 * 3.14159265358979323846))) / (2 * std::log(base)); + } + + std::pair find_correction_range(float low_ratio, float high_ratio, int dim, float base, float ori_max_pe_len) { + float low = std::floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len)); + float high = std::ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len)); + return {std::max(0, static_cast(low)), std::min(dim / 2, static_cast(high))}; + } + + std::vector linear_ramp_mask(int min, int max, int dim) { + if (min == max) { + max += 0.001f; // Prevent singularity + } + std::vector ramp(dim); + for (int i = 0; i < dim; ++i) { + ramp[i] = std::max(0.0f, std::min(1.0f, static_cast(i - min) / (max - min))); + } + return ramp; + } + + __STATIC_INLINE__ std::vector> rope_ext( + const std::vector& pos, + int dim, + float theta = 10000.0f, + bool use_real = false, + float linear_factor = 1.0f, + float ntk_factor = 1.0f, + bool repeat_interleave_real = true, + bool yarn = false, + int max_pe_len = -1, + int ori_max_pe_len = 64, + bool dype = false, + float current_timestep = 1.0f) { + assert(dim % 2 == 0); + int half_dim = dim / 2; + + // Compute frequencies + std::vector freqs_base(half_dim); + std::vector freqs_linear(half_dim); + std::vector freqs_ntk(half_dim); + std::vector freqs(half_dim); + + if (yarn && max_pe_len > ori_max_pe_len) { + float beta_0 = 1.25f; + float beta_1 = 0.75f; + float gamma_0 = 16.0f; + float gamma_1 = 2.0f; + + float scale = std::max(1.0f, static_cast(max_pe_len) / ori_max_pe_len); + // d,t,s + float new_base = theta * std::pow(scale, half_dim / (half_dim - 1)); + for (int i = 0; i < half_dim; ++i) { + float exponent = static_cast(i) / half_dim; + freqs_base[i] = 1.0f / std::pow(theta, exponent); + freqs_linear[i] = 1.0f / (scale * std::pow(theta, exponent)); + freqs_ntk[i] = 1.0f / std::pow(new_base, exponent); + } + + if (dype) { + beta_0 = std::pow(beta_0, 2.0f * current_timestep * current_timestep); + beta_1 = std::pow(beta_1, 2.0f * current_timestep * current_timestep); + gamma_0 = std::pow(gamma_0, 2.0f * current_timestep * current_timestep); + gamma_1 = std::pow(gamma_1, 2.0f * current_timestep * current_timestep); + } + + // Apply correction range and linear ramp mask + auto [low, high] = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len); + auto mask = linear_ramp_mask(low, high, half_dim); + for (int i = 0; i < half_dim; ++i) { + freqs[i] = freqs_linear[i] * mask[i] + freqs_ntk[i] * (1.0f - mask[i]); + } + + // Apply gamma correction + auto [low_gamma, high_gamma] = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len); + auto mask_gamma = linear_ramp_mask(low_gamma, high_gamma, half_dim); + for (int i = 0; i < half_dim; ++i) { + freqs[i] = freqs[i] * mask_gamma[i] + freqs_base[i] * (1.0f - mask_gamma[i]); + } + } else { + float theta_ntk = theta * ntk_factor; + for (int i = 0; i < half_dim; ++i) { + float exponent = static_cast(i) / half_dim; + freqs[i] = 1.0f / std::pow(theta_ntk, exponent) / linear_factor; + } + } + + // Outer product of pos and freqs + std::vector> freqs_outer(pos.size(), std::vector(half_dim)); + for (size_t i = 0; i < pos.size(); ++i) { + for (int j = 0; j < half_dim; ++j) { + freqs_outer[i][j] = pos[i] * freqs[j]; + } + } + + std::vector> result; + result.resize(pos.size(), std::vector(half_dim * 4)); + for (size_t i = 0; i < pos.size(); ++i) { + for (int j = 0; j < half_dim; ++j) { + result[i][4 * j] = std::cos(freqs_outer[i][j]); // cos + result[i][4 * j + 1] = -std::sin(freqs_outer[i][j]); // -sin + result[i][4 * j + 2] = std::sin(freqs_outer[i][j]); // sin + result[i][4 * j + 3] = std::cos(freqs_outer[i][j]); // cos + } + } + + return result; + } + // Generate IDs for image patches and text __STATIC_INLINE__ std::vector> gen_txt_ids(int bs, int context_len) { return std::vector>(bs * context_len, std::vector(3, 0.0)); @@ -151,6 +260,53 @@ namespace Rope { return flatten(emb); } + std::vector embed_nd_ext( + const std::vector>& ids, + int bs, + float theta, + const std::vector& axes_dim, + bool yarn = false, + std::vector max_pe_len = {}, + std::vector ori_max_pe_len = {64, 64, 64}, + bool dype = false, + float current_timestep = 1.0f, + std::vector ntk_factors = {}) { + std::vector> trans_ids = transpose(ids); + size_t pos_len = ids.size() / bs; + int num_axes = axes_dim.size(); + + if (ntk_factors.size() == 0) { + ntk_factors = std::vector(num_axes, 1.0f); + } + if (max_pe_len.size() == 0) { + max_pe_len = std::vector(num_axes, -1); + } + + int emb_dim = 0; + for (int d : axes_dim) { + emb_dim += d; + } + + std::vector> emb(bs * pos_len, std::vector(emb_dim * 2, 0.0f)); + int offset = 0; + + for (int i = 0; i < num_axes; ++i) { + std::vector> rope_emb = rope_ext( + trans_ids[i], axes_dim[i], theta, false, 1.0f, ntk_factors[i], true, yarn, max_pe_len[i], ori_max_pe_len[i], dype, current_timestep); + + for (int b = 0; b < bs; ++b) { + for (size_t j = 0; j < pos_len; ++j) { + for (size_t k = 0; k < rope_emb[j].size(); ++k) { + emb[b * pos_len + j][offset + k] = rope_emb[j][k]; + } + } + } + offset += static_cast(axes_dim[i] * 2); + } + + return flatten(emb); + } + __STATIC_INLINE__ std::vector> gen_refs_ids(int patch_size, int bs, const std::vector& ref_latents, @@ -210,9 +366,63 @@ namespace Rope { const std::vector& ref_latents, bool increase_ref_index, int theta, - const std::vector& axes_dim) { + const std::vector& axes_dim, + bool use_yarn = false, + bool use_dype = false, + bool use_ntk = false, + float current_timestep = 1.0f) { + int base_resolution = 1024; + int base_patches_H = -1; + int base_patches_W = -1; + + // set it via environment variable for now (TODO: arg) + // could be either a single integer, or WxH + const char* env_base_resolution = getenv("FLUX_DYPE_BASE_RESOLUTION"); + if (env_base_resolution != nullptr) { + if (strchr(env_base_resolution, 'x') != nullptr) { + const char* x_pos = strchr(env_base_resolution, 'x'); + base_patches_H = atoi(x_pos + 1) / 16; + base_patches_W = atoi(env_base_resolution) / 16; + } else { + base_resolution = atoi(env_base_resolution); + } + } + // preserve aspect ratio of the input image + // base_patches_W = k*w, base_patches_H = k*h, base_patches_W*base_patches_H = base_resolution^2 + // => k = base_resolution / sqrt(w*h) + if (base_patches_H == -1) + base_patches_H = (base_resolution * h * sqrt(1.0f / (w * h))) / 16; + if (base_patches_W == -1) + base_patches_W = (base_resolution * w * sqrt(1.0f / (w * h))) / 16; + + // First dim is ref image, should not need any weird rope modifications since the max pos should stay very low. 1024 is a lot + std::vector base_patches = {1024, base_patches_H, base_patches_W}; std::vector> ids = gen_flux_ids(h, w, patch_size, bs, context_len, ref_latents, increase_ref_index); - return embed_nd(ids, bs, theta, axes_dim); + std::vector max_pos_vec = {}; + std::vector ntk_factor_vec = {}; + for (int i = 0; i < axes_dim.size(); i++) { + float max_pos_f = 0.0f; + for (const auto& row : ids) { + float val = row[i]; + if (val > max_pos_f) { + max_pos_f = val; + } + } + int max_pos = static_cast(max_pos_f) + 1; + max_pos_vec.push_back(max_pos); + float ntk_factor = 1.0f; + if (use_ntk) { + float base_ntk = pow((float)max_pos / base_patches[i], (float)axes_dim[i] / (axes_dim[i] - 2)); + ntk_factor = use_dype ? pow(base_ntk, 2.0f * current_timestep * current_timestep) : base_ntk; + ntk_factor = std::max(1.0f, ntk_factor); + } + ntk_factor_vec.push_back(ntk_factor); + } + if (use_yarn || use_ntk) { + return embed_nd_ext(ids, bs, theta, axes_dim, use_yarn, max_pos_vec, base_patches, use_dype, current_timestep, ntk_factor_vec); + } else { + return embed_nd(ids, bs, theta, axes_dim); + } } __STATIC_INLINE__ std::vector> gen_qwen_image_ids(int h,