
Commit 5447475

Matrix multiplication thread per output cell
1 parent 680437f commit 5447475

1 file changed: 146 additions & 0 deletions

### Dumb matrix multiplication
### Use one GPU thread for each element of the output matrix

from gpu.host import DeviceContext, HostBuffer
from gpu import thread_idx, block_idx, block_dim
import random
from layout import Layout, LayoutTensor
from memory import UnsafePointer, memcpy
from python import Python, PythonObject
from testing import assert_true

alias ROWS_A = 64
alias COLS_A = 16
alias ROWS_B = 16
alias COLS_B = 8
alias ROWS_C = ROWS_A
alias COLS_C = COLS_B

alias MATRIX_MIN_ELEM = -5.0
alias MATRIX_MAX_ELEM = 5.0

alias dtype = DType.float32
# Number of threads per block
alias THREADS = (5, 5)
# Total number of blocks in the grid (ceiling division so the grid covers all of C)
alias BLOCKS = (
    (COLS_C + THREADS[0] - 1) // THREADS[0],
    (ROWS_C + THREADS[1] - 1) // THREADS[1],
)

alias layout_a = Layout.row_major(ROWS_A, COLS_A)
alias layout_b = Layout.row_major(ROWS_B, COLS_B)
alias layout_c = Layout.row_major(ROWS_C, COLS_C)


alias MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
alias MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
alias MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]


fn matmul_thread_per_output_cell[
    a: Layout, b: Layout, c: Layout
](A: MatrixA, B: MatrixB, C: MatrixC):
    var i = block_idx.y * block_dim.y + thread_idx.y  # Row
    var j = block_idx.x * block_dim.x + thread_idx.x  # Column

    # Threads that fall outside C do nothing; each remaining thread computes
    # one output cell as the dot product of a row of A and a column of B.
    if i < ROWS_C and j < COLS_C:
        for k in range(ROWS_B):
            C[i, j] += A[i, k] * B[k, j]


# Initialize the buffer with random values in the range
# MATRIX_MIN_ELEM to MATRIX_MAX_ELEM
fn fill_buffer(buffer: HostBuffer[dtype]):
    # Seed the random number generator
    random.seed()
    for i in range(len(buffer)):
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]


fn main():
    try:
        ctx = DeviceContext()

        # Zero-initialized device buffers for A, B, and C
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        # Fill the input buffers with random values from the host
        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)

        # Launch one GPU thread per output cell of C
        ctx.enqueue_function[
            matmul_thread_per_output_cell[layout_a, layout_b, layout_c]
        ](
            matrix_a,
            matrix_b,
            matrix_c,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        ctx.synchronize()

        # Copy the result back to the host and verify against NumPy
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        print("Error:", e)


fn assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    a_rows, a_cols, a_buff = buff_a_with_dims
    matrix_a = reshape(to_ndarray(a_buff), a_rows, a_cols)

    b_rows, b_cols, b_buff = buff_b_with_dims
    matrix_b = reshape(to_ndarray(b_buff), b_rows, b_cols)

    c_rows, c_cols, c_buff = buff_c_with_dims
    matrix_c = reshape(to_ndarray(c_buff), c_rows, c_cols)

    np = Python.import_module("numpy")
    assert_true(np.allclose(np.matmul(matrix_a, matrix_b), matrix_c))
    print("Assertion was successful")


fn to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    # Allocate a NumPy array and copy the host buffer into it
    np = Python.import_module("numpy")
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray


fn reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    return ndarray.reshape(rows, cols)


fn ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    # Get the raw data pointer of the NumPy array from its array interface
    return ndarray.__array_interface__["data"][0].unsafe_get_as_pointer[dtype]()
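
To get a feel for the launch configuration above: with THREADS = (5, 5) and C of shape 64 x 8, the ceiling division in BLOCKS gives a 2 x 13 grid, i.e. 650 threads for 512 output cells, and the bounds check in the kernel idles the surplus threads. A minimal standalone Mojo sketch (not part of the commit) that reproduces just this arithmetic:

# Standalone sketch: launch-configuration arithmetic only, no GPU code.
fn main():
    alias THREADS = (5, 5)
    alias ROWS_C = 64
    alias COLS_C = 8
    # Same ceiling division used for BLOCKS in the commit
    alias GRID_X = (COLS_C + THREADS[0] - 1) // THREADS[0]  # 2
    alias GRID_Y = (ROWS_C + THREADS[1] - 1) // THREADS[1]  # 13
    print("grid:", GRID_X, "x", GRID_Y)
    print("threads launched:", GRID_X * GRID_Y * THREADS[0] * THREADS[1])  # 650
    print("output cells:", ROWS_C * COLS_C)  # 512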
