Skip to content

Commit dee15a7

Browse files
committed
Vectorized code
1 parent d54194a commit dee15a7

File tree

2 files changed

+21
-29
lines changed

2 files changed

+21
-29
lines changed

src/StochasticBlockModelVariants.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module StochasticBlockModelVariants
22

33
using Graphs: AbstractGraph, neighbors
4-
using LinearAlgebra: Symmetric, dot
4+
using LinearAlgebra: dot, mul!
55
using ProgressMeter: @showprogress
66
using Random: AbstractRNG, default_rng
77
using SimpleWeightedGraphs: SimpleWeightedGraph

src/csbm_inference.jl

Lines changed: 20 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -101,40 +101,35 @@ function update_amp!(
101101
ûᵗ⁺¹, v̂ᵗ⁺¹, χ₊eᵗ⁺¹ = next_storage.û, next_storage.v̂, next_storage.χ₊e
102102
(; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊) = temp_storage
103103

104+
ûₜ_sum = sum(ûᵗ)
105+
ûₜ_sum2 = sum(abs2, ûᵗ)
106+
104107
# AMP estimation of v
105-
σᵥ_no_comm =/ N) * sum(ûᵗ[i]^2 for i in 1:N)
106-
for α in 1:P
107-
v̂_no_comm[α] = (
108-
sqrt/ N) * sum(B[α, i] * ûᵗ[i] for i in 1:N) -
109-
/ N) * sum((1 - ûᵗ[i]^2) * v̂ᵗ[α] for i in 1:N)
110-
)
111-
v̂ᵗ⁺¹[α] = v̂_no_comm[α] / (1 + σᵥ_no_comm)
112-
end
108+
σᵥ_no_comm =/ N) * ûₜ_sum2
109+
mul!(v̂_no_comm, B, ûᵗ)
110+
v̂_no_comm .*= sqrt/ N)
111+
v̂_no_comm .-=/ N) .* v̂ᵗ .* (N - ûₜ_sum2)
112+
v̂ᵗ⁺¹ .= v̂_no_comm ./ (1 + σᵥ_no_comm)
113113
σᵥ = 1 / (1 + σᵥ_no_comm)
114114

115115
# BP estimation of u
116-
for i in 1:N
117-
û_no_feat[i] = (
118-
sqrt/ N) * sum(B[α, i] * v̂ᵗ⁺¹[α] for α in 1:P) -/ (N / P)) * σᵥ * ûᵗ[i]
119-
)
120-
end
116+
mul!(û_no_feat, B', v̂ᵗ⁺¹)
117+
û_no_feat .*= sqrt/ N)
118+
û_no_feat .-=/ (N / P)) .* σᵥ .* ûᵗ
121119

122120
# Estimation of the field h
123-
h₊ = (1 / N) * sum(cᵢ * (1 + ûᵗ[i]) / 2 + cₒ * (1 - ûᵗ[i]) / 2 for i in 1:N)
124-
h₋ = (1 / N) * sum(cₒ * (1 + ûᵗ[i]) / 2 + cᵢ * (1 - ûᵗ[i]) / 2 for i in 1:N)
125-
for i in 1:N
126-
h̃₊[i] = -h₊ + log(one(R) / 2) + û_no_feat[i]
127-
h̃₋[i] = -h₋ + log(one(R) / 2) - û_no_feat[i]
128-
end
121+
h₊ = (1 / 2N) * (cᵢ * (N + ûₜ_sum) + cₒ * (N - ûₜ_sum))
122+
h₋ = (1 / 2N) * (cₒ * (N + ûₜ_sum) + cᵢ * (N - ûₜ_sum))
123+
h̃₊ .= -h₊ .+ log(one(R) / 2) .+ û_no_feat
124+
h̃₋ .= -h₋ .+ log(one(R) / 2) .- û_no_feat
129125

130126
# BP update of the messages
131127
for i in 1:N, j in neighbors(g, i)
132128
s_ij = h̃₊[i] - h̃₋[i]
133129
for k in neighbors(g, i)
134130
if k != j
135-
num = (cₒ + 2λ * sqrt(d) * χ₊eᵗ[k, i])
136-
den = (cᵢ - 2λ * sqrt(d) * χ₊eᵗ[k, i])
137-
s_ij += log(num / den)
131+
common = 2λ * sqrt(d) * χ₊eᵗ[k, i]
132+
s_ij += log((cₒ + common) / (cᵢ - common))
138133
end
139134
end
140135
χ₊eᵗ⁺¹[i, j] = sigmoid(s_ij)
@@ -144,17 +139,14 @@ function update_amp!(
144139
for i in 1:N
145140
s_i = h̃₊[i] - h̃₋[i]
146141
for k in neighbors(g, i)
147-
num = (cₒ + 2λ * sqrt(d) * χ₊eᵗ[k, i])
148-
den = (cᵢ - 2λ * sqrt(d) * χ₊eᵗ[k, i])
149-
s_i += log(num / den)
142+
common = 2λ * sqrt(d) * χ₊eᵗ[k, i]
143+
s_i += log((cₒ + common) / (cᵢ - common))
150144
end
151145
χ₊[i] = sigmoid(s_i)
152146
end
153147

154148
# BP estimation of u
155-
for i in 1:N
156-
ûᵗ⁺¹[i] = 2χ₊[i] - 1
157-
end
149+
ûᵗ⁺¹ .= 2 .* χ₊ .- 1
158150

159151
return nothing
160152
end

0 commit comments

Comments
 (0)