@@ -101,40 +101,35 @@ function update_amp!(
101101 ûᵗ⁺¹, v̂ᵗ⁺¹, χ₊eᵗ⁺¹ = next_storage. û, next_storage. v̂, next_storage. χ₊e
102102 (; û_no_feat, v̂_no_comm, h̃₊, h̃₋, χ₊) = temp_storage
103103
104+ ûₜ_sum = sum (ûᵗ)
105+ ûₜ_sum2 = sum (abs2, ûᵗ)
106+
104107 # AMP estimation of v
105- σᵥ_no_comm = (μ / N) * sum (ûᵗ[i]^ 2 for i in 1 : N)
106- for α in 1 : P
107- v̂_no_comm[α] = (
108- sqrt (μ / N) * sum (B[α, i] * ûᵗ[i] for i in 1 : N) -
109- (μ / N) * sum ((1 - ûᵗ[i]^ 2 ) * v̂ᵗ[α] for i in 1 : N)
110- )
111- v̂ᵗ⁺¹[α] = v̂_no_comm[α] / (1 + σᵥ_no_comm)
112- end
108+ σᵥ_no_comm = (μ / N) * ûₜ_sum2
109+ mul! (v̂_no_comm, B, ûᵗ)
110+ v̂_no_comm .*= sqrt (μ / N)
111+ v̂_no_comm .- = (μ / N) .* v̂ᵗ .* (N - ûₜ_sum2)
112+ v̂ᵗ⁺¹ .= v̂_no_comm ./ (1 + σᵥ_no_comm)
113113 σᵥ = 1 / (1 + σᵥ_no_comm)
114114
115115 # BP estimation of u
116- for i in 1 : N
117- û_no_feat[i] = (
118- sqrt (μ / N) * sum (B[α, i] * v̂ᵗ⁺¹[α] for α in 1 : P) - (μ / (N / P)) * σᵥ * ûᵗ[i]
119- )
120- end
116+ mul! (û_no_feat, B' , v̂ᵗ⁺¹)
117+ û_no_feat .*= sqrt (μ / N)
118+ û_no_feat .- = (μ / (N / P)) .* σᵥ .* ûᵗ
121119
122120 # Estimation of the field h
123- h₊ = (1 / N) * sum (cᵢ * (1 + ûᵗ[i]) / 2 + cₒ * (1 - ûᵗ[i]) / 2 for i in 1 : N)
124- h₋ = (1 / N) * sum (cₒ * (1 + ûᵗ[i]) / 2 + cᵢ * (1 - ûᵗ[i]) / 2 for i in 1 : N)
125- for i in 1 : N
126- h̃₊[i] = - h₊ + log (one (R) / 2 ) + û_no_feat[i]
127- h̃₋[i] = - h₋ + log (one (R) / 2 ) - û_no_feat[i]
128- end
121+ h₊ = (1 / 2 N) * (cᵢ * (N + ûₜ_sum) + cₒ * (N - ûₜ_sum))
122+ h₋ = (1 / 2 N) * (cₒ * (N + ûₜ_sum) + cᵢ * (N - ûₜ_sum))
123+ h̃₊ .= - h₊ .+ log (one (R) / 2 ) .+ û_no_feat
124+ h̃₋ .= - h₋ .+ log (one (R) / 2 ) .- û_no_feat
129125
130126 # BP update of the messages
131127 for i in 1 : N, j in neighbors (g, i)
132128 s_ij = h̃₊[i] - h̃₋[i]
133129 for k in neighbors (g, i)
134130 if k != j
135- num = (cₒ + 2 λ * sqrt (d) * χ₊eᵗ[k, i])
136- den = (cᵢ - 2 λ * sqrt (d) * χ₊eᵗ[k, i])
137- s_ij += log (num / den)
131+ common = 2 λ * sqrt (d) * χ₊eᵗ[k, i]
132+ s_ij += log ((cₒ + common) / (cᵢ - common))
138133 end
139134 end
140135 χ₊eᵗ⁺¹[i, j] = sigmoid (s_ij)
@@ -144,17 +139,14 @@ function update_amp!(
144139 for i in 1 : N
145140 s_i = h̃₊[i] - h̃₋[i]
146141 for k in neighbors (g, i)
147- num = (cₒ + 2 λ * sqrt (d) * χ₊eᵗ[k, i])
148- den = (cᵢ - 2 λ * sqrt (d) * χ₊eᵗ[k, i])
149- s_i += log (num / den)
142+ common = 2 λ * sqrt (d) * χ₊eᵗ[k, i]
143+ s_i += log ((cₒ + common) / (cᵢ - common))
150144 end
151145 χ₊[i] = sigmoid (s_i)
152146 end
153147
154148 # BP estimation of u
155- for i in 1 : N
156- ûᵗ⁺¹[i] = 2 χ₊[i] - 1
157- end
149+ ûᵗ⁺¹ .= 2 .* χ₊ .- 1
158150
159151 return nothing
160152end
0 commit comments