Skip to content

Commit 3f9b8a4

Browse files
committed
fix tullio GPU loading
1 parent 281641b commit 3f9b8a4

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

src/proj_equirect.jl

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,4 @@
11

2-
# when CUDA is loaded, we need to reload this file so the @tullio
3-
# macro calls generate a GPU version
4-
@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
5-
using KernelAbstractions, CUDAKernels, CUDA
6-
include(@__FILE__)
7-
end
8-
92
struct ProjEquiRect{T} <: CartesianProj
103

114
Ny :: Int
@@ -227,19 +220,19 @@ end
227220

228221
(*)(M::BlockDiagEquiRect{B}, f::EquiRectField) where {B<:Basis} = M * B(f)
229222

230-
function (*)(M::BlockDiagEquiRect{B}, f::F) where {B<:AzBasis, F<:EquiRectField{B}}
223+
@uses_tullio function (*)(M::BlockDiagEquiRect{B}, f::F) where {B<:AzBasis, F<:EquiRectField{B}}
231224
promote_metadata_strict(M.proj, f.proj) # ensure same projection
232225
F(@tullio(Bf[p,iₘ] := M.blocks[p,q,iₘ] * f.arr[q,iₘ]), f.proj)
233226
end
234227

235228
(*)(M::Adjoint{T,<:BlockDiagEquiRect{B}}, f::EquiRectField) where {T, B<:Basis} = M * B(f)
236229

237-
function (*)(M::Adjoint{T,<:BlockDiagEquiRect{B}}, f::F) where {T, B<:AzBasis, F<:EquiRectField{B}}
230+
@uses_tullio function (*)(M::Adjoint{T,<:BlockDiagEquiRect{B}}, f::F) where {T, B<:AzBasis, F<:EquiRectField{B}}
238231
promote_metadata_strict(M.parent.proj, f.proj) # ensure same projection
239232
F(@tullio(Bf[p,iₘ] := conj(M.parent.blocks[q,p,iₘ]) * f.arr[q,iₘ]), f.proj)
240233
end
241234

242-
function rrule(::typeof(*), M::BlockDiagEquiRect{B}, f::EquiRectField{B′}) where {B<:Basis, B′<:Basis}
235+
@uses_tullio function rrule(::typeof(*), M::BlockDiagEquiRect{B}, f::EquiRectField{B′}) where {B<:Basis, B′<:Basis}
243236
function times_pullback(Δ)
244237
BΔ, Bf = B(Δ), B(f)
245238
Zygote.ChainRules.NoTangent(), @thunk(BlockDiagEquiRect{B}(@tullio(M̄[p,q,iₘ] := Bf.arr[p,iₘ] * conj(BΔ.arr[q,iₘ])), M.proj)'), B′(M' * BΔ)
@@ -251,19 +244,19 @@ end
251244
# ## Linear Algebra: tullio accelerated (operator, operator)
252245

253246
# M₁ * M₂
254-
function (*)(M₁::BlockDiagEquiRect{B}, M₂::BlockDiagEquiRect{B}) where {B<:AzBasis}
247+
@uses_tullio function (*)(M₁::BlockDiagEquiRect{B}, M₂::BlockDiagEquiRect{B}) where {B<:AzBasis}
255248
promote_metadata_strict(M₁.proj, M₂.proj) # ensure same projection
256249
BlockDiagEquiRect{B}(@tullio(M₃[p,q,iₘ] := M₁.blocks[p,j,iₘ] * M₂.blocks[j,q,iₘ]), M₁.proj)
257250
end
258251

259252
# M₁' * M₂
260-
function (*)(M₁::Adjoint{T,<:BlockDiagEquiRect{B}}, M₂::BlockDiagEquiRect{B}) where {T, B<:AzBasis}
253+
@uses_tullio function (*)(M₁::Adjoint{T,<:BlockDiagEquiRect{B}}, M₂::BlockDiagEquiRect{B}) where {T, B<:AzBasis}
261254
promote_metadata_strict(M₁.parent.proj, M₂.proj) # ensure same projection
262255
BlockDiagEquiRect{B}(@tullio(M₃[p,q,iₘ] := conj(M₁.parent.blocks[j,p,iₘ]) * M₂.blocks[j,q,iₘ]), M₁.parent.proj)
263256
end
264257

265258
# M₁ * M₂'
266-
function (*)(M₁::BlockDiagEquiRect{B}, M₂::Adjoint{T,<:BlockDiagEquiRect{B}}) where {T, B<:AzBasis}
259+
@uses_tullio function (*)(M₁::BlockDiagEquiRect{B}, M₂::Adjoint{T,<:BlockDiagEquiRect{B}}) where {T, B<:AzBasis}
267260
promote_metadata_strict(M₁.proj, M₂.parent.proj) # ensure same projection
268261
BlockDiagEquiRect{B}(@tullio(M₃[p,q,iₘ] := M₁.blocks[p,j,iₘ] * conj(M₂.parent.blocks[q,j,iₘ])), M₁.proj)
269262
end
@@ -355,7 +348,7 @@ end
355348
LinearAlgebra.dot(a::EquiRectField, b::EquiRectField) = dot(Ł(a).arr, Ł(b).arr)
356349

357350
# needed by AD
358-
function LinearAlgebra.dot(M₁::Adjoint{T,<:BlockDiagEquiRect{B}}, M₂::BlockDiagEquiRect{B}) where {T, B<:AzBasis}
351+
@uses_tullio function LinearAlgebra.dot(M₁::Adjoint{T,<:BlockDiagEquiRect{B}}, M₂::BlockDiagEquiRect{B}) where {T, B<:AzBasis}
359352
(@tullio a[] := conj(M₁.parent.blocks[q,p,iₘ]) * M₂.blocks[p,q,iₘ])[]
360353
end
361354

@@ -500,7 +493,7 @@ end
500493

501494
end
502495

503-
function Cℓ_to_Beam(::Val{:I}, proj::ProjEquiRect{T}, CI::Cℓs; units=1, ℓmax=10_000, progress=true) where {T}
496+
@uses_tullio function Cℓ_to_Beam(::Val{:I}, proj::ProjEquiRect{T}, CI::Cℓs; units=1, ℓmax=10_000, progress=true) where {T}
504497

505498
@unpack Ω = proj
506499
Ω′ = T.(Ω)
@@ -511,7 +504,7 @@ function Cℓ_to_Beam(::Val{:I}, proj::ProjEquiRect{T}, CI::Cℓs; units=1, ℓm
511504
return Cov
512505
end
513506

514-
function Cℓ_to_Beam(::Val{:P}, proj::ProjEquiRect{T}, CI::Cℓs; units=1, ℓmax=10_000, progress=true) where {T}
507+
@uses_tullio function Cℓ_to_Beam(::Val{:P}, proj::ProjEquiRect{T}, CI::Cℓs; units=1, ℓmax=10_000, progress=true) where {T}
515508

516509
@unpack θ, Ω = proj
517510
Ω′ = T.(Ω)

src/util.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,3 +507,14 @@ end
507507

508508
real_type(T) = promote_type(real(T), Float32)
509509
@init @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" real_type(::Type{<:Unitful.Quantity{T}}) where {T} = real_type(T)
510+
511+
512+
macro uses_tullio(funcdef)
513+
quote
514+
$(esc(funcdef))
515+
@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
516+
using KernelAbstractions, CUDAKernels, CUDA
517+
$(esc(funcdef))
518+
end
519+
end
520+
end

0 commit comments

Comments
 (0)