Skip to content

Commit 1b17354

Browse files
authored
Merge pull request #44 from JuliaGPU/vc/unsafe_atomics
Use UnsafeAtomics to fix race in accumulate
2 parents 2da8696 + 25d035b commit 1b17354

File tree

5 files changed

+18
-55
lines changed

5 files changed

+18
-55
lines changed

Project.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,19 @@ ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
88
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
99
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1010
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
11+
UnsafeAtomics = "013be700-e6cd-48c3-b4a1-df204f14c38f"
1112

1213
[weakdeps]
13-
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
1414
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
1515

1616
[extensions]
17-
AcceleratedKernelsMetalExt = "Metal"
1817
AcceleratedKernelsoneAPIExt = "oneAPI"
1918

2019
[compat]
2120
ArgCheck = "2"
2221
GPUArraysCore = "0.2.0"
2322
KernelAbstractions = "0.9.34"
2423
Markdown = "1"
25-
Metal = "1"
26-
oneAPI = "1, 2"
24+
UnsafeAtomics = "0.3.0"
2725
julia = "1.10"
26+
oneAPI = "1, 2"

ext/AcceleratedKernelsMetalExt.jl

Lines changed: 0 additions & 41 deletions
This file was deleted.

src/AcceleratedKernels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ module AcceleratedKernels
1414
using ArgCheck: @argcheck
1515
using GPUArraysCore: AbstractGPUArray, @allowscalar
1616
using KernelAbstractions
17+
import UnsafeAtomics
1718

1819

1920
# Exposed functions from upstream packages

src/accumulate/accumulate.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,7 @@ function _accumulate_impl!(
160160
dims::Union{Nothing, Int}=nothing,
161161
inclusive::Bool=true,
162162

163-
# FIXME: Switch back to `DecoupledLookback()` as the default algorithm
164-
# once https://github.com/JuliaGPU/AcceleratedKernels.jl/pull/44 is merged.
165-
alg::AccumulateAlgorithm=ScanPrefixes(),
163+
alg::AccumulateAlgorithm=DecoupledLookback(),
166164

167165
# CPU settings
168166
max_tasks::Int=Threads.nthreads(),

src/accumulate/accumulate_1d_gpu.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ end
169169
running_prefix = prefixes[iblock - 0x1 + 0x1]
170170
inspected_block = signed(typeof(iblock))(iblock) - 0x2
171171
while inspected_block >= 0x0
172-
173172
# Opportunistic: a previous block finished everything
174-
if flags[inspected_block + 0x1] == ACC_FLAG_A
173+
if UnsafeAtomics.load(pointer(flags, inspected_block + 0x1), UnsafeAtomics.monotonic) == ACC_FLAG_A
174+
UnsafeAtomics.fence(UnsafeAtomics.acquire) # (fence before reading from v)
175175
# Previous blocks (except last) always have filled values in v, so index is inbounds
176176
running_prefix = op(running_prefix, v[(inspected_block + 0x1) * block_size * 0x2])
177177
break
@@ -194,11 +194,17 @@ end
194194
end
195195

196196
# Set flag for "aggregate of all prefixes up to this block finished"
197-
@synchronize() # This is needed so that the flag is not set before copying into v, but
198-
# there should be better memory fences to guarantee ordering without
199-
# thread synchronization...
197+
# There are two synchronization concerns here:
198+
# 1. Within a group we want to ensure that all writes to `v` have occurred before setting the flag.
199+
# 2. Between groups we need to use a fence and atomic load/store on `flags` to ensure that memory operations are not reordered.
200+
@synchronize() # within-block
201+
# Note: This fence is needed to ensure that the flag is not set before copying into v.
202+
# See https://doc.rust-lang.org/std/sync/atomic/fn.fence.html
203+
# for more details.
204+
# We use the happens-before relation between stores to `v` and the store to `flags`.
205+
UnsafeAtomics.fence(UnsafeAtomics.release)
200206
if ithread == 0x0
201-
flags[iblock + 0x1] = ACC_FLAG_A
207+
UnsafeAtomics.store!(pointer(flags, iblock + 0x1), convert(eltype(flags), ACC_FLAG_A), UnsafeAtomics.monotonic)
202208
end
203209
end
204210

@@ -285,7 +291,7 @@ function accumulate_1d!(
285291
end
286292

287293
if isnothing(temp_flags)
288-
flags = similar(v, Int8, num_blocks)
294+
flags = similar(v, UInt8, num_blocks)
289295
else
290296
@argcheck eltype(temp_flags) <: Integer
291297
@argcheck length(temp_flags) >= num_blocks

0 commit comments

Comments
 (0)