
Commit b1d4a5e

KernelIntrinsics API (#635)

Co-authored-by: Valentin Churavy <v.churavy@gmail.com>
1 parent 0ece57c commit b1d4a5e

7 files changed: +656 -91 lines changed

examples/histogram.jl
Lines changed: 12 additions & 12 deletions
@@ -1,6 +1,8 @@
 # INCLUDE ROCM
 using KernelAbstractions, Test
 using KernelAbstractions: @atomic, @atomicswap, @atomicreplace
+import KernelAbstractions.KernelIntrinsics as KI
+
 include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) # Load backend
 
 # Function to use as a baseline for CPU metrics
@@ -12,16 +14,15 @@ function create_histogram(input)
     return histogram_output
 end
 
-# This is a 1D histogram kernel where the histogramming happens on shmem
-@kernel unsafe_indices = true function histogram_kernel!(histogram_output, input)
-    gid = @index(Group, Linear)
-    lid = @index(Local, Linear)
+# This is a 1D histogram kernel where the histogramming happens on static shmem
+function histogram_kernel!(histogram_output, input, ::Val{gs}) where {gs}
+    gid = KI.get_group_id().x
+    lid = KI.get_local_id().x
 
-    @uniform gs = prod(@groupsize())
     tid = (gid - 1) * gs + lid
-    @uniform N = length(histogram_output)
+    N = length(histogram_output)
 
-    shared_histogram = @localmem eltype(input) (gs)
+    shared_histogram = KI.localmemory(eltype(input), gs)
 
     # This will go through all input elements and assign them to a location in
     # shmem. Note that if there is not enough shmem, we create different shmem
@@ -32,7 +33,7 @@ end
 
         # Setting shared_histogram to 0
        @inbounds shared_histogram[lid] = 0
-        @synchronize()
+        KI.barrier()
 
        max_element = min_element + gs
        if max_element > N
@@ -46,21 +47,20 @@ end
            @atomic shared_histogram[bin] += 1
        end
 
-        @synchronize()
+        KI.barrier()
 
        if ((lid + min_element - 1) <= N)
            @atomic histogram_output[lid + min_element - 1] += shared_histogram[lid]
        end
 
    end
-
+    return
 end
 
 function histogram!(histogram_output, input, groupsize = 256)
    backend = get_backend(histogram_output)
    # Need static block size
-    kernel! = histogram_kernel!(backend, (groupsize,))
-    kernel!(histogram_output, input, ndrange = size(input))
+    KI.@kernel backend workgroupsize = groupsize numworkgroups = cld(length(input), groupsize) histogram_kernel!(histogram_output, input, Val(groupsize))
    return
 end
 
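
Note: the hunks above show both halves of the new API in one place — intrinsics (KI.get_group_id, KI.get_local_id, KI.localmemory, KI.barrier) inside a plain Julia function, and a KI.@kernel launch with explicit workgroupsize/numworkgroups. A minimal sketch of the same pattern, reduced to a vector add; vadd_kernel! and vadd! are hypothetical names, and the KI signatures are assumed from this diff only:

using KernelAbstractions
import KernelAbstractions.KernelIntrinsics as KI

# Hypothetical kernel: one work-item per element.
function vadd_kernel!(c, a, b)
    i = KI.get_global_id().x        # 1-based linear global index (per this diff)
    if i <= length(c)               # guard the ragged final workgroup
        @inbounds c[i] = a[i] + b[i]
    end
    return
end

function vadd!(c, a, b, groupsize = 256)
    backend = get_backend(c)
    # Same launch shape as histogram! above: fixed workgroup size,
    # enough groups to cover the array.
    KI.@kernel backend workgroupsize = groupsize numworkgroups = cld(length(c), groupsize) vadd_kernel!(c, a, b)
    return
end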

examples/performant_matmul.jl
Lines changed: 33 additions & 31 deletions
@@ -1,78 +1,79 @@
 using KernelAbstractions
+import KernelAbstractions.KernelIntrinsics as KI
+
 using StaticArrays
 using Test
 using Random
+
 include(joinpath(dirname(pathof(KernelAbstractions)), "../examples/utils.jl")) # Load backend
 
 # We use a TILE_DIM of 16 as a safe value since while
 # most backends support up to 1024 threads per group,
 # Metal sometimes supports fewer.
 const TILE_DIM = 16
 
-@kernel unsafe_indices = true function coalesced_matmul_kernel!(
-        output, @Const(input1), @Const(input2), N, R, M,
-        ::Val{BANK} = Val(1),
-    ) where {BANK}
-    gi, gj = @index(Group, NTuple)
-    i, j = @index(Local, NTuple)
-
-    TILE_DIM = @uniform @groupsize()[1]
+function coalesced_matmul_kernel!(
+        output, input1, input2, N, R, M,
+        ::Val{TDIM}, ::Val{BANK} = Val(1)
+    ) where {TDIM, BANK}
+    gi, gj, _ = KI.get_group_id()
+    i, j, _ = KI.get_local_id()
 
     # +1 to avoid bank conflicts on shared memory
-    tile1 = @localmem eltype(output) (TILE_DIM + BANK, TILE_DIM)
-    tile2 = @localmem eltype(output) (TILE_DIM + BANK, TILE_DIM)
+    tile1 = KI.localmemory(eltype(output), (TDIM + BANK, TDIM))
+    tile2 = KI.localmemory(eltype(output), (TDIM + BANK, TDIM))
 
-    # private variable for tile output
-    outval = @private eltype(output) 1
-    @inbounds outval[1] = -zero(eltype(output))
+    # variable for tile output
+    outval = -zero(eltype(output))
 
-    @uniform N = size(output, 1)
+    N = size(output, 1)
     # number of tiles depends on inner dimension
-    @uniform NUM_TILES = div(R + TILE_DIM - 1, TILE_DIM)
+    NUM_TILES = div(R + TDIM - 1, TDIM)
 
     # loop over all tiles needed for this calculation
     for t in 0:(NUM_TILES - 1)
        # Can't use @index(Global), because we use a smaller ndrange
-        I = (gi - 1) * TILE_DIM + i
-        J = (gj - 1) * TILE_DIM + j
+        I = (gi - 1) * TDIM + i
+        J = (gj - 1) * TDIM + j
 
        # load inputs into tiles, with bounds checking for non-square matrices
-        if I <= N && t * TILE_DIM + j <= R
-            @inbounds tile1[i, j] = input1[I, t * TILE_DIM + j]
+        if I <= N && t * TDIM + j <= R
+            @inbounds tile1[i, j] = input1[I, t * TDIM + j]
        else
            @inbounds tile1[i, j] = 0.0
        end
        if t * TILE_DIM + i <= R && J <= M
-            @inbounds tile2[i, j] = input2[t * TILE_DIM + i, J]
+            @inbounds tile2[i, j] = input2[t * TDIM + i, J]
        else
            @inbounds tile2[i, j] = 0.0
        end
 
        # wait for all tiles to be loaded
-        @synchronize
+        KI.barrier()
 
        # get global values again
-        I = (gi - 1) * TILE_DIM + i
-        J = (gj - 1) * TILE_DIM + j
+        I = (gi - 1) * TDIM + i
+        J = (gj - 1) * TDIM + j
 
        # calculate value of spot in output, use temporary value to allow for vectorization
        out = zero(eltype(output))
-        @simd for k in 1:TILE_DIM
+        @simd for k in 1:TDIM
            @inbounds out += tile1[i, k] * tile2[k, j]
        end
-        outval[1] += out
+        outval += out
 
-        @synchronize
+        KI.barrier()
    end
 
    # get global indices again
-    I = (gi - 1) * TILE_DIM + i
-    J = (gj - 1) * TILE_DIM + j
+    I = (gi - 1) * TDIM + i
+    J = (gj - 1) * TDIM + j
 
    # save if inbounds
    if I <= N && J <= M
-        @inbounds output[I, J] = outval[1]
+        @inbounds output[I, J] = outval
    end
+    return
 end
 
 N = 1024
@@ -82,9 +83,10 @@ A = rand!(allocate(backend, Float32, N, R))
 B = rand!(allocate(backend, Float32, R, M))
 C = KernelAbstractions.zeros(backend, Float32, N, M)
 
-kern = coalesced_matmul_kernel!(backend, (TILE_DIM, TILE_DIM))
+workgroupsize = (TILE_DIM, TILE_DIM)
+numworkgroups = (cld(size(C, 1), TILE_DIM), cld(size(C, 2), TILE_DIM))
 
-kern(C, A, B, N, R, M, ndrange = size(C))
+KI.@kernel backend workgroupsize numworkgroups coalesced_matmul_kernel!(C, A, B, N, R, M, Val(TILE_DIM))
 KernelAbstractions.synchronize(backend)
 
 @test isapprox(A * B, C)
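
Note: the tile size moves from @groupsize() to a Val argument because KI.localmemory takes its shape as a compile-time constant. A sketch of that pattern in isolation, assuming the KI API exactly as used above; transpose_tile_kernel! is a hypothetical name and the matrix extents are assumed to be multiples of TDIM:

import KernelAbstractions.KernelIntrinsics as KI

function transpose_tile_kernel!(out, inp, ::Val{TDIM}) where {TDIM}
    gi, gj, _ = KI.get_group_id()
    i, j, _ = KI.get_local_id()
    tile = KI.localmemory(eltype(out), (TDIM, TDIM))   # shape fixed at compile time via Val
    # stage one input tile in local memory
    @inbounds tile[i, j] = inp[(gi - 1) * TDIM + i, (gj - 1) * TDIM + j]
    KI.barrier()                                       # make all tile writes visible to the group
    # read tile[j, i] so the global store stays coalesced
    @inbounds out[(gj - 1) * TDIM + i, (gi - 1) * TDIM + j] = tile[j, i]
    return
end

# launch, as above: one workgroup per tile
# KI.@kernel backend workgroupsize = (TDIM, TDIM) numworkgroups = cld.(size(inp), TDIM) transpose_tile_kernel!(out, inp, Val(TDIM))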

src/KernelAbstractions.jl
Lines changed: 51 additions & 32 deletions
@@ -194,6 +194,15 @@ function unsafe_free! end
 
 unsafe_free!(::AbstractArray) = return
 
+"""
+Abstract type for all KernelAbstractions backends.
+"""
+abstract type Backend end
+
+include("intrinsics.jl")
+import .KernelIntrinsics as KI
+export KernelIntrinsics
+
 ###
 # Kernel language
 # - @localmem
@@ -360,6 +369,25 @@ macro context()
     return esc(:(__ctx__))
 end
 
+# Defined to keep cpu support for `__print`
+@generated function KI._print(items...)
+    str = ""
+    args = []
+
+    for i in 1:length(items)
+        item = :(items[$i])
+        T = items[i]
+        if T <: Val
+            item = QuoteNode(T.parameters[1])
+        end
+        push!(args, item)
+    end
+
+    return quote
+        print($(args...))
+    end
+end
+
 """
     @print(items...)
 
@@ -460,13 +488,27 @@ end
 # Internal kernel functions
 ###
 
-function __index_Local_Linear end
-function __index_Group_Linear end
-function __index_Global_Linear end
+@inline function __index_Local_Linear(ctx)
+    return KI.get_local_id().x
+end
+
+@inline function __index_Group_Linear(ctx)
+    return KI.get_group_id().x
+end
 
-function __index_Local_Cartesian end
-function __index_Group_Cartesian end
-function __index_Global_Cartesian end
+@inline function __index_Global_Linear(ctx)
+    return KI.get_global_id().x
+end
+
+@inline function __index_Local_Cartesian(ctx)
+    return @inbounds workitems(__iterspace(ctx))[KI.get_local_id().x]
+end
+@inline function __index_Group_Cartesian(ctx)
+    return @inbounds blocks(__iterspace(ctx))[KI.get_group_id().x]
+end
+@inline function __index_Global_Cartesian(ctx)
+    return @inbounds expand(__iterspace(ctx), KI.get_group_id().x, KI.get_local_id().x)
+end
 
 @inline __index_Local_NTuple(ctx, I...) = Tuple(__index_Local_Cartesian(ctx, I...))
 @inline __index_Group_NTuple(ctx, I...) = Tuple(__index_Group_Cartesian(ctx, I...))
@@ -482,11 +524,6 @@ constify(arg) = adapt(ConstAdaptor(), arg)
 # Backend hierarchy
 ###
 
-"""
-
-Abstract type for all KernelAbstractions backends.
-"""
-abstract type Backend end
 
 """
 Abstract type for all GPU based KernelAbstractions backends.
@@ -796,29 +833,11 @@ include("macros.jl")
 ###
 
 function Scratchpad end
-function SharedMemory end
-
-function __synchronize()
-    error("@synchronize used outside kernel or not captured")
-end
-
-@generated function __print(items...)
-    str = ""
-    args = []
+SharedMemory(t::Type{T}, dims::Val{Dims}, id::Val{Id}) where {T, Dims, Id} = KI.localmemory(t, dims)
 
-    for i in 1:length(items)
-        item = :(items[$i])
-        T = items[i]
-        if T <: Val
-            item = QuoteNode(T.parameters[1])
-        end
-        push!(args, item)
-    end
+__synchronize() = KI.barrier()
 
-    return quote
-        print($(args...))
-    end
-end
+__print(args...) = KI._print(args...)
 
 # Utils
 __size(args::Tuple) = Tuple{args...}
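
Note: with the __index_* definitions above delegating to the intrinsics, @index is now a thin layer over KernelIntrinsics and existing @kernel code keeps working unchanged. A sketch of the correspondence (hypothetical kernels, not part of this commit): per the new __index_Global_Linear, @index(Global, Linear) yields KI.get_global_id().x, while in the intrinsics style the bounds check that @kernel's ndrange machinery used to supply becomes the kernel's own responsibility.

using KernelAbstractions
import KernelAbstractions.KernelIntrinsics as KI

# Classic style: launched as scale_ka!(backend)(a, s, ndrange = length(a));
# out-of-range work-items are masked by the ndrange machinery.
@kernel function scale_ka!(a, s)
    i = @index(Global, Linear)      # now implemented as KI.get_global_id().x
    @inbounds a[i] *= s
end

# Intrinsics style: explicit guard replaces the ndrange mask.
function scale_ki!(a, s)
    i = KI.get_global_id().x
    if i <= length(a)
        @inbounds a[i] *= s
    end
    return
end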
