169169 running_prefix = prefixes[iblock - 0x1 + 0x1 ]
170170 inspected_block = signed (typeof (iblock))(iblock) - 0x2
171171 while inspected_block >= 0x0
172-
173172 # Opportunistic: a previous block finished everything
174- if flags[inspected_block + 0x1 ] == ACC_FLAG_A
173+ if UnsafeAtomics. load (pointer (flags, inspected_block + 0x1 ), UnsafeAtomics. monotonic) == ACC_FLAG_A
174+ UnsafeAtomics. fence (UnsafeAtomics. acquire) # (fence before reading from v)
175175 # Previous blocks (except last) always have filled values in v, so index is inbounds
176176 running_prefix = op (running_prefix, v[(inspected_block + 0x1 ) * block_size * 0x2 ])
177177 break
@@ -194,11 +194,17 @@ end
194194 end
195195
196196 # Set flag for "aggregate of all prefixes up to this block finished"
197- @synchronize () # This is needed so that the flag is not set before copying into v, but
198- # there should be better memory fences to guarantee ordering without
199- # thread synchronization...
197+ # There are two synchronization concerns here:
198+ # 1. Within a group we want to ensure that all writes to `v` have occurred before setting the flag.
199+ # 2. Between groups we need to use a fence and atomic load/store to ensure that memory operations are not re-ordered
200+ @synchronize () # within-block
201+ # Note: This fence is needed to ensure that the flag is not set before copying into v.
202+ # See https://doc.rust-lang.org/std/sync/atomic/fn.fence.html
203+ # for more details.
204+ # We use the happens-before relation between stores to `v` and the store to `flags`.
205+ UnsafeAtomics. fence (UnsafeAtomics. release)
200206 if ithread == 0x0
201- flags[ iblock + 0x1 ] = ACC_FLAG_A
207+ UnsafeAtomics . store! ( pointer ( flags, iblock + 0x1 ), convert ( eltype (flags), ACC_FLAG_A), UnsafeAtomics . monotonic)
202208 end
203209end
204210
@@ -285,7 +291,7 @@ function accumulate_1d!(
285291 end
286292
287293 if isnothing (temp_flags)
288- flags = similar (v, Int8 , num_blocks)
294+ flags = similar (v, UInt8 , num_blocks)
289295 else
290296 @argcheck eltype (temp_flags) <: Integer
291297 @argcheck length (temp_flags) >= num_blocks
0 commit comments