
Conversation

@alonsoC1s

Fixes #191 (scalar indexing error when using the Metal GPU backend).
It turns out that adding the Metal extension was not enough: the act! functions that generate the kernels assumed CuArray. The current implementation still has several issues; for instance, it fails with strange errors when both CUDA and Metal are loaded. Some gradient tests are also currently broken.

I'm opening this as a draft because I think it would be a good idea to overhaul the GPU backend selection code, and I wanted to get your thoughts on it @mcabbott.

Some gradient tests that fail with Metal:

@testset "from TensorTrace" begin
    # These can all be handled using TensorOperations

    triv1(x) = @tullio A[i,j] := 2 * x[i,j]
    @test gradtest(triv1, (2,3))
end
...
r392 = randn(3,9,2);
con6(x) = @tullio C[n,i,m,j] := x[i,j,k] * r392[k,m,n]
@test gradtest(con6, (9,2,3))
...
dm = ForwardDiff.gradient(m -> sum(f8(m,v2)), m4)
@test dm ≈ _gradient(sumf8, m4, v2)[1]  # avx: OK with 0.8, broken with 0.9
dm .- _gradient(sumf8, m4, v2)[1]       # at exactly one element
dv = ForwardDiff.gradient(v -> sum(f8(m4,v)), v2)
@test dv ≈ _gradient(sumf8, m4, v2)[2]
...
ΔA = Tracker.gradient((A,B) -> sum(mul(A, B)), A, B)[1]
@test ΔA ≈ ones(3,500) * B'
@test mtl(ΔA) ≈ Tracker.gradient((A,B) -> sum(mul(A, B)), mtl(A), mtl(B))[1]
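
For reference, a minimal standalone version of that last check would look roughly like this (the array sizes and the Float32 element type are my guesses, chosen to match the ones(3,500) * B' comparison and Metal's lack of Float64 support):

using Metal, Tracker, Tullio
mul(A, B) = @tullio C[i,k] := A[i,j] * B[j,k]
A = rand(Float32, 3, 40); B = rand(Float32, 40, 500)
ΔA = Tracker.gradient((A, B) -> sum(mul(A, B)), A, B)[1]               # ≈ ones(3,500) * B'
ΔA_mtl = Tracker.gradient((A, B) -> sum(mul(A, B)), mtl(A), mtl(B))[1]
mtl(ΔA) ≈ ΔA_mtl   # currently false: no error is thrown, the values are just wrong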

I have no clue why the gradients fail

@mcabbott
Owner

mcabbott commented Oct 4, 2025

Thanks, this looks like the right minimal extension to treat Metal like CUDA.

Are these failing tests giving wrong values, or errors from MtlArrays? I would have guessed it should all be the same as CuArray; Tracker doesn't seem to load anything GPU-related: https://github.com/FluxML/Tracker.jl/blob/master/Project.toml

@mcabbott mcabbott added the GPU label Oct 4, 2025
@alonsoC1s
Author

I think it would be a good idea to refine the "backend selection" bit in make_many_actors. If I understand correctly, ROC, oneAPI, etc. won't work with Tullio@main. Would it help to dispatch on GPUArray in general instead?

They are wrong values: the gradient computation does not error out, it just gives erroneous values. I'm also surprised it's not working. Could it be that some dispatch in the gradient computation code also assumes CuArray directly?

@mcabbott
Owner

mcabbott commented Oct 4, 2025

Ok.

Yes, everything about how Tullio interacts with other packages is a bit of a hack. Ideally it would dispatch on AbstractGPUArray as you say. But I'm not sure whether there's a generic replacement for MetalBackend() here -- there may well be, I just haven't looked at KernelAbstractions in ages:

function $act!(::Type{<:MtlArray}, $(args...), $KEEP=nothing, $FINAL=true) where {$TYP}
    $info2
    mtl_kern! = $kernel(MetalBackend())
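
For instance, a generic version might dispatch on AbstractGPUArray and take the backend from the array itself via KernelAbstractions. A rough untested sketch of the idea, written as a plain function rather than the generated code (act_gpu! and make_kernel are hypothetical stand-ins for $act! and $kernel):

using GPUArraysCore: AbstractGPUArray
using KernelAbstractions: get_backend

function act_gpu!(Z::AbstractGPUArray, args...)
    kern! = make_kernel(get_backend(Z))    # MetalBackend(), CUDABackend(), ROCBackend(), ...
    kern!(Z, args...; ndrange = size(Z))
    return Z
end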

Re tests, this should probably be rebased on top of #193.

Can any CI test Metal right now? Maybe GitHub does in fact support this. For CUDA, the Buildkite jobs run this on Julia's GPU servers.

@alonsoC1s
Author

I agree, I think the extension would still be necessary. But obviously that's just intuition; the codebase is very new to me. What we could do there is use KernelAbstractions.get_backend (not sure if that's the actual function name) to get the backend that corresponds to the array. That works for CPU arrays as well, so it could simplify the condition checks.
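
Something like this quick (untested) check is what I have in mind; get_backend falls back to CPU() for ordinary Arrays:

using KernelAbstractions, Metal

KernelAbstractions.get_backend(rand(Float32, 4))        # CPU()
KernelAbstractions.get_backend(mtl(rand(Float32, 4)))   # MetalBackend()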

Not sure how #193 interacts with the gradient tests, but I can give it a go and see what happens. As for CI, I think I read somewhere that there are some GitHub macOS machines; not sure whether they have Metal GPUs, though.

@alonsoC1s
Author

Another issue with the kernel launching that I just came across: since the backend selection checks isdefined(store.mod, :GPUArray), the package calling Tullio has to load CUDA or Metal itself. When they are loaded by a dependency instead of by the module directly, Tullio doesn't realize a GPU is available, tries to use the CPU version, and the scalar indexing issue is triggered again.
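
A contrived sketch of the situation (module names are made up; the point is only that the isdefined check looks at the calling module's namespace):

module Wrapper
using Metal                        # Metal is loaded here, not in the caller
end

module MyCode
using ..Wrapper
using Tullio

isdefined(@__MODULE__, :Metal)     # false, so Tullio generates only the CPU path
f(x) = @tullio y[i] := 2 * x[i]    # scalar indexing error when x is an MtlArray
end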
