Я играю с flux.jl
, и у меня возникают проблемы с обновлением параметров пользовательской функции.
Функция определена ниже как objective
:
using Distributions
using Flux.Tracker: gradient, param, Params
using Flux.Optimise: Descent, ADAM, update!
D = 2
num_samples = 100
function log_density(params)
mu, log_sigma = params
d1 = Normal(0, 1.35)
d2 = Normal(0, exp(log_sigma))
d1_density = logpdf(d1, log_sigma)
d2_density = logpdf(d2, mu)
return d1_density + d2_density
end
function J(log_std)
H = 0.5 * D * (1.0 + log(2 * pi)) + sum(log_std)
return H
end
function objective(mu, log_std; D=2)
samples = rand(Normal(), num_samples, D) .* sqrt.(log_std) .+ mu
log_px = mapslices(log_density, samples; dims=2)
elbo = J(log_std) + mean(log_px)
return -elbo
end
И я пытаюсь сделать одно обновление следующим образом:
mu = param(reshape([-1, -1], 1, :))
sigma = param(reshape([5, 5], 1, :))
grads = gradient(() -> objective(mu, sigma), Params([mu, sigma]))
opt = Descent(0.001)
for p in (mu, sigma)
update!(opt, p, grads[p])
end
Выдает ошибку:
ERROR: Can't differentiate `setindex!`
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] setindex!(::TrackedArray{…,Array{Float64,2}}, ::Flux.Tracker.TrackedReal{Float64}, ::CartesianIndex{2}) at /Users/vasya/.julia/packages/Flux/T3PhK/src/tracker/lib/array.jl:63
[3] macro expansion at ./broadcast.jl:838 [inlined]
[4] macro expansion at ./simdloop.jl:73 [inlined]
[5] copyto! at ./broadcast.jl:837 [inlined]
[6] copyto! at ./broadcast.jl:792 [inlined]
[7] materialize! at ./broadcast.jl:751 [inlined]
[8] update!(::Descent, ::TrackedArray{…,Array{Float64,2}}, ::TrackedArray{…,Array{Float64,2}}) at /Users/vasya/.julia/packages/Flux/T3PhK/src/optimise/optimisers.jl:22
[9] top-level scope at ./REPL[23]:2 [inlined]
[10] top-level scope at ./none:0
Я также попытался заменить grads[p]
на grads[p].data
. Это не приводит к ошибке, но не обновляет параметры!
Сведения об окружающей среде:
- Юлия Версия 1.0.2
- Флюс v0.7.0
- Дистрибутивы v0.16.4
Обсуждение в Slack выявило правильное использование функций update!
. Код ниже делает ссылки на модули явными и создает обновленные параметры (для Flux v0.7.0):
using Distributions
using Flux
D = 2
num_samples = 100
function log_density(params)
mu, log_sigma = params
d1 = Normal(0, 1.35)
d2 = Normal(0, exp(log_sigma))
d1_density = logpdf(d1, log_sigma)
d2_density = logpdf(d2, mu)
return d1_density + d2_density
end
function J(log_std)
H = 0.5 * D * (1.0 + log(2 * pi)) + sum(log_std)
return H
end
function objective(mu, log_std; D=2)
samples = rand(Normal(), num_samples, D) .* sqrt.(log_std) .+ mu
log_px = mapslices(log_density, samples; dims=2)
elbo = J(log_std) + mean(log_px)
return -elbo
end
mu = Flux.Tracker.param(reshape([-1, -1], 1, :))
sigma = Flux.Tracker.param(reshape([5, 5], 1, :))
grads = Flux.Tracker.gradient(() -> objective(mu, sigma), Flux.Tracker.Params([mu, sigma]))
println(mu, sigma)
opt = Flux.Optimise.Descent(0.01)
for p in (mu, sigma)
Flux.Tracker.update!(p, Flux.Optimise.update!(opt, p, Flux.data(grads[p])))
end
println(mu, sigma)
Это печатает:
[-1.0 -1.0] (tracked)[5.0 5.0] (tracked)
[-198.742 -459.423] (tracked)[31.0583 225.657] (tracked)