@@ -5,43 +5,48 @@ for CSPRNG.
55
66module ChaCha
77
8- using Core. Intrinsics: llvmcall
98using CUDA
9+ using SIMD
1010using StaticArrays
1111
1212# ChaCha block size is 32 * 16 bits = 64 bytes
1313const CHACHA_BLOCK_SIZE_U32 = 16
1414const CHACHA_BLOCK_SIZE = div (32 * 16 , 8 )
1515
1616@inline lrot32 (x, n) = (x << n) | (x >> (32 - n))
17- @inline lrot32 (x:: UInt32 , n:: UInt32 ) = llvmcall (
18- ("""
19- declare i32 @llvm.fshl.i32(i32, i32, i32)
20- define i32 @entry(i32, i32, i32) #0 {
21- 3:
22- %res = call i32 @llvm.fshl.i32(i32 %0, i32 %0, i32 %1)
23- ret i32 %res
24- }
25- attributes #0 = { alwaysinline }
26- """ , " entry" ), UInt32, Tuple{UInt32, UInt32}, x, n)
27-
28- @inline function _QR! (x, a, b, c, d)
29- @inbounds begin
30- x[a] += x[b]; x[d] ⊻= x[a]; x[d] = lrot32 (x[d], UInt32 (16 ))
31- x[c] += x[d]; x[b] ⊻= x[c]; x[b] = lrot32 (x[b], UInt32 (12 ))
32- x[a] += x[b]; x[d] ⊻= x[a]; x[d] = lrot32 (x[d], UInt32 (8 ))
33- x[c] += x[d]; x[b] ⊻= x[c]; x[b] = lrot32 (x[b], UInt32 (7 ))
17+ @inline lrot32 (x:: Union{Vec,UInt32} , n) = bitrotate (x, n)
18+
19+ @inline @generated function rotatevector (x:: Vec{N,T} , :: Val{M} ) where {N,T,M}
20+ rotation = circshift (0 : 3 , M)
21+ rotation = repeat (rotation, N ÷ 4 )
22+ rotation += 4 * ((0 : N- 1 ) .÷ 4 )
23+ rotation = Val (Tuple (rotation))
24+ :(shufflevector (x, $ rotation))
25+ end
26+
27+ macro _QR! (a, b, c, d)
28+ quote
29+ $ (esc (a)) += $ (esc (b)); $ (esc (d)) ⊻= $ (esc (a)); $ (esc (d)) = lrot32 ($ (esc (d)), 16 );
30+ $ (esc (c)) += $ (esc (d)); $ (esc (b)) ⊻= $ (esc (c)); $ (esc (b)) = lrot32 ($ (esc (b)), 12 );
31+ $ (esc (a)) += $ (esc (b)); $ (esc (d)) ⊻= $ (esc (a)); $ (esc (d)) = lrot32 ($ (esc (d)), 8 );
32+ $ (esc (c)) += $ (esc (d)); $ (esc (b)) ⊻= $ (esc (c)); $ (esc (b)) = lrot32 ($ (esc (b)), 7 );
33+
34+ $ (esc (a)), $ (esc (b)), $ (esc (c)), $ (esc (d))
3435 end
3536end
3637
3738@inline function store_u64! (x:: AbstractVector{UInt32} , u:: UInt64 , idx)
38- x[idx] = UInt32 (u & 0xffffffff )
39- x[idx+ 1 ] = UInt32 ((u >> 32 ) & 0xffffffff )
39+ @inbounds begin
40+ x[idx] = UInt32 (u & 0xffffffff )
41+ x[idx+ 1 ] = UInt32 ((u >> 32 ) & 0xffffffff )
42+ end
4043end
4144
4245@inline function add_u64! (x:: AbstractVector{UInt32} , u:: UInt64 , idx)
43- x[idx] += UInt32 (u & 0xffffffff )
44- x[idx+ 1 ] += UInt32 ((u >> 32 ) & 0xffffffff )
46+ @inbounds begin
47+ x[idx] += UInt32 (u & 0xffffffff )
48+ x[idx+ 1 ] += UInt32 ((u >> 32 ) & 0xffffffff )
49+ end
4550end
4651
4752#=
@@ -144,40 +149,89 @@ function chacha_blocks!(
144149 nblocks = 1 ;
145150 doublerounds = 10 ,
146151)
147- for i ∈ 1 : nblocks
148- block_start = CHACHA_BLOCK_SIZE_U32 * (i - 1 ) + 1
149- block_end = block_start + CHACHA_BLOCK_SIZE_U32 - 1
150- state = view (buffer, block_start: block_end)
151-
152- _chacha_set_initial_state! (state, key, nonce, counter, 1 )
153-
154- # Perform alternating rounds of columnar
155- # quarter-rounds and diagonal quarter-rounds
156- for i = 1 : doublerounds
157- # Columnar rounds
158- _QR! (state, 1 , 5 , 9 , 13 )
159- _QR! (state, 2 , 6 , 10 , 14 )
160- _QR! (state, 3 , 7 , 11 , 15 )
161- _QR! (state, 4 , 8 , 12 , 16 )
162-
163- # Diagonal rounds
164- _QR! (state, 1 , 6 , 11 , 16 )
165- _QR! (state, 2 , 7 , 12 , 13 )
166- _QR! (state, 3 , 8 , 9 , 14 )
167- _QR! (state, 4 , 5 , 10 , 15 )
168- end
169-
170- # Finish by adding the initial state back to
171- # the original state, so that the operations
172- # are no longer invertible
173- _chacha_add_initial_state! (state, key, nonce, counter, 1 )
152+ block_start = 1
153+
154+ # We compute as many blocks of output as possible with 512-bit
155+ # SIMD vectorization
156+ for i ∈ 1 : 4 : nblocks- 3
157+ block_start, counter = _chacha_blocks! (
158+ buffer, block_start, key, nonce, counter, doublerounds, Val (4 )
159+ )
160+ end
174161
175- counter += 1
162+ # The remaining blocks are computed with 128-bit vectorization
163+ for i ∈ 1 : (nblocks % 4 )
164+ block_start, counter = _chacha_blocks! (
165+ buffer, block_start, key, nonce, counter, doublerounds, Val (1 )
166+ )
176167 end
177168
178169 counter
179170end
180171
172+ # Compute the ChaCha block function with N * 128-bit SIMD vectorization
173+ #
174+ # Reference: https://eprint.iacr.org/2013/759.pdf
175+ @inline function _chacha_blocks! (
176+ buffer:: AbstractVector{UInt32} , block_start, key, nonce, counter, doublerounds, :: Val{N}
177+ ) where N
178+ block_end = block_start + N * CHACHA_BLOCK_SIZE_U32 - 1
179+ @inbounds state = view (buffer, block_start: block_end)
180+
181+ for i = 0 : N- 1
182+ _chacha_set_initial_state! (state, key, nonce, counter + i, i * CHACHA_BLOCK_SIZE_U32 + 1 )
183+ end
184+
185+ _chacha_rounds! (state, doublerounds, Val (N))
186+
187+ for i = 0 : N- 1
188+ _chacha_add_initial_state! (state, key, nonce, counter + i, i * CHACHA_BLOCK_SIZE_U32 + 1 )
189+ end
190+
191+ block_end + 1 , counter + N
192+ end
193+
194+
195+ @inline @generated function _chacha_rounds! (state, doublerounds, :: Val{N} ) where N
196+ # Perform alternating rounds of columnar
197+ # quarter-rounds and diagonal quarter-rounds
198+ lane = (1 , 2 , 3 , 4 )
199+ lane = repeat (1 : 4 , N)
200+ lane += 16 * ((0 : 4 * N- 1 ) .÷ 4 )
201+ lane = Tuple (lane)
202+
203+ idx0 = Vec (lane)
204+ idx1 = Vec (lane .+ 4 )
205+ idx2 = Vec (lane .+ 8 )
206+ idx3 = Vec (lane .+ 12 )
207+
208+ quote
209+ @inbounds begin
210+ v0 = vgather (state, $ idx0)
211+ v1 = vgather (state, $ idx1)
212+ v2 = vgather (state, $ idx2)
213+ v3 = vgather (state, $ idx3)
214+
215+ for i = 1 : doublerounds
216+ v0, v1, v2, v3 = @_QR! (v0, v1, v2, v3)
217+ v1 = rotatevector (v1, Val (- 1 ))
218+ v2 = rotatevector (v2, Val (- 2 ))
219+ v3 = rotatevector (v3, Val (- 3 ))
220+
221+ v0, v1, v2, v3 = @_QR! (v0, v1, v2, v3)
222+ v1 = rotatevector (v1, Val (1 ))
223+ v2 = rotatevector (v2, Val (2 ))
224+ v3 = rotatevector (v3, Val (3 ))
225+ end
226+
227+ vscatter (v0, state, $ idx0)
228+ vscatter (v1, state, $ idx1)
229+ vscatter (v2, state, $ idx2)
230+ vscatter (v3, state, $ idx3)
231+ end
232+ end
233+ end
234+
181235function chacha_blocks! (
182236 buffer:: CuArray , key, nonce:: UInt64 , counter:: UInt64 , nblocks = 1 ; doublerounds = 10
183237)
@@ -204,7 +258,7 @@ function _cuda_chacha_rounds!(state, doublerounds)
204258
205259 # Only operate on a slice of the state corresponding to
206260 # the thread block
207- state_slice = view (state, block+ 1 : block+ 16 )
261+ slice = view (state, block+ 1 : block+ 16 )
208262
209263 # Pre-compute the indices that this thread will use to
210264 # perform its diagonal rounds
@@ -219,11 +273,11 @@ function _cuda_chacha_rounds!(state, doublerounds)
219273 # Each thread in the same block runs its rounds in parallel
220274 for _ = 1 : doublerounds
221275 # Columnar rounds
222- _QR! (state_slice, i, i + 4 , i + 8 , i + 12 )
276+ @ _QR! (slice[i], slice[i + 4 ], slice[i + 8 ], slice[i + 12 ] )
223277 CUDA. threadfence_block ()
224278
225279 # Diagonal rounds
226- _QR! (state_slice, dgc1, dgc2, dgc3, dgc4)
280+ @ _QR! (slice[ dgc1], slice[ dgc2], slice[ dgc3], slice[ dgc4] )
227281 CUDA. threadfence_block ()
228282 end
229283
0 commit comments