Skip to content

Commit d7075c9

Browse files
committed
Removes dropout-related fields from Flash_fwd_params
Cleans up the parameter structure by removing unused dropout probability fields, scaling factors, random state management, and rotary interleaving flag. Moves softcap field to improve struct organization and readability.
1 parent a5b1b49 commit d7075c9

File tree

1 file changed

+1
-19
lines changed

1 file changed

+1
-19
lines changed

csrc/src/flash.h

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
9393
// The scaling factors for the kernel.
9494
float scale_softmax;
9595
float scale_softmax_log2;
96+
float softcap;
9697

9798
// array of length b+1 holding starting offset of each sequence.
9899
int * __restrict__ cu_seqlens_q;
@@ -128,32 +129,13 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
128129
index_t block_table_batch_stride;
129130
int page_block_size;
130131

131-
// The dropout probability (probability of keeping an activation).
132-
float p_dropout;
133-
// uint32_t p_dropout_in_uint;
134-
// uint16_t p_dropout_in_uint16_t;
135-
uint8_t p_dropout_in_uint8_t;
136-
137-
// Scale factor of 1 / (1 - p_dropout).
138-
float rp_dropout;
139-
float scale_softmax_rp_dropout;
140-
float softcap;
141-
142-
// Random state.
143-
at::PhiloxCudaState philox_args;
144-
145-
// Pointer to the RNG seed (idx 0) and offset (idx 1).
146-
uint64_t * rng_state;
147-
148132
bool is_bf16;
149133
bool is_causal;
150134

151135
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
152136
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
153137
bool is_seqlens_k_cumulative;
154138

155-
bool is_rotary_interleaved;
156-
157139
int num_splits; // For split-KV version
158140

159141
bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].

0 commit comments

Comments
 (0)