Skip to content

Commit f7079dc

Browse files
authored
Merge pull request #56 from christiangnrd/algoselec
Tweak backend selection
2 parents 8f333b6 + 668c188 commit f7079dc

File tree

8 files changed

+13
-12
lines changed

8 files changed

+13
-12
lines changed

src/accumulate/accumulate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function _accumulate_impl!(
167167
temp_flags::Union{Nothing, AbstractArray}=nothing,
168168
)
169169
if isnothing(dims)
170-
return if use_KA_algo(v, prefer_threads)
170+
return if use_gpu_algo(backend, prefer_threads)
171171
accumulate_1d_gpu!(
172172
op, v, backend, alg;
173173
init, neutral, inclusive,

src/accumulate/accumulate_nd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function accumulate_nd!(
3535

3636
# Degenerate cases end
3737

38-
if !use_KA_algo(v, prefer_threads)
38+
if !use_gpu_algo(backend, prefer_threads)
3939
_accumulate_nd_cpu_sections!(op, v; init, dims, inclusive, max_tasks, min_elems)
4040
else
4141
# On GPUs we have two parallelisation approaches, based on which dimension has more elements:

src/foreachindex.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ function foreachindex(
130130
# GPU settings
131131
block_size=256,
132132
)
133-
if use_KA_algo(itr, prefer_threads)
133+
if use_gpu_algo(backend, prefer_threads)
134134
_forindices_gpu(f, eachindex(itr), backend; block_size)
135135
else
136136
_forindices_threads(f, eachindex(itr); max_tasks, min_elems)
@@ -232,7 +232,7 @@ function foraxes(
232232
)
233233
end
234234

235-
if use_KA_algo(itr, prefer_threads)
235+
if use_gpu_algo(backend, prefer_threads)
236236
_forindices_gpu(f, axes(itr, dims), backend; block_size)
237237
else
238238
_forindices_threads(f, axes(itr, dims); max_tasks, min_elems)

src/predicates.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ function _any_impl(
119119
# GPU settings
120120
block_size::Int=256,
121121
)
122-
if use_KA_algo(v, prefer_threads)
122+
if use_gpu_algo(backend, prefer_threads)
123123
@argcheck block_size > 0
124124

125125
# Some platforms crash when multiple threads write to the same memory location in a global
@@ -253,7 +253,7 @@ function _all_impl(
253253
# GPU settings
254254
block_size::Int=256,
255255
)
256-
if use_KA_algo(v, prefer_threads)
256+
if use_gpu_algo(backend, prefer_threads)
257257
@argcheck block_size > 0
258258

259259
# Some platforms crash when multiple threads write to the same memory location in a global

src/reduce/mapreduce_nd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ function mapreduce_nd(
114114
end
115115
dst_size = length(dst)
116116

117-
if !use_KA_algo(src, prefer_threads)
117+
if !use_gpu_algo(backend, prefer_threads)
118118
_mapreduce_nd_cpu_sections!(
119119
f, op, dst, src;
120120
init,

src/reduce/reduce.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ function _mapreduce_impl(
183183
switch_below::Int=0,
184184
)
185185
if isnothing(dims)
186-
if use_KA_algo(src, prefer_threads)
186+
if use_gpu_algo(backend, prefer_threads)
187187
mapreduce_1d_gpu(
188188
f, op, src, backend;
189189
init, neutral,

src/sort/sort.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ function _sort_impl!(
9696
# Temporary buffer, same size as `v`
9797
temp::Union{Nothing, AbstractArray}=nothing,
9898
)
99-
if use_KA_algo(v, prefer_threads)
99+
if use_gpu_algo(backend, prefer_threads)
100100
merge_sort!(
101101
v, backend;
102102
lt, by, rev, order,
@@ -207,7 +207,7 @@ function _sortperm_impl!(
207207
# Temporary buffer, same size as `v`
208208
temp::Union{Nothing, AbstractArray}=nothing,
209209
)
210-
if use_KA_algo(v, prefer_threads)
210+
if use_gpu_algo(backend, prefer_threads)
211211
merge_sortperm_lowmem!(
212212
ix, v, backend;
213213
lt, by, rev, order,

src/utils.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ function ispow2(x)
33
end
44

55
# Helper function to check whether the package cpu implementation of an algorithm should be used
6-
@inline function use_KA_algo(output_array, prefer_threads)
7-
return output_array isa AnyGPUArray || !prefer_threads
6+
const CPU_BACKEND = get_backend([])
7+
@inline function use_gpu_algo(backend, prefer_threads)
8+
return backend != CPU_BACKEND || !prefer_threads
89
end
910

1011
"""

0 commit comments

Comments
 (0)