Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing NaN sampling for von Mises Fisher distribution #1918

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

williamjsdavis
Copy link

Summary of contribution

I have made a change to the sampling algorithm of von Mises Fisher distributions to avoid NaN outputs.

The problem

Sampling from certain von Mises Fisher distributions will result in NaN. For example:

using Distributions
using LinearAlgebra: norm

κ = 1.0
μ = [1.0,1e-8]
μ = μ ./ norm(μ)
d = VonMisesFisher(μ, κ)
rand(d)

...returns:

2-element Vector{Float64}:
 NaN
 NaN

The error happens because of operations performed in the function _vmf_householder_vec:

function _vmf_householder_vec::Vector{Float64})
# assuming μ is a unit-vector (which it should be)
# can compute v in a single pass over μ
p = length(μ)
v = similar(μ)
v[1] = μ[1] - 1.0
s = sqrt(-2*v[1])
v[1] /= s
@inbounds for i in 2:p
v[i] = μ[i] / s
end
return v
end

The problem arises when μ is a vector exactly in the direction of the first dimension, or very close to the first dimension. (For example, μ = [1.0, 0.0], μ = [1.0, 0.0, 0.0], or μ = [1.0, 0.0, 1e-8], etc...) In these cases, variable s will become zero, leading vector v to become filled with NaNs.

This issue was previously raised in #1423

Proposed solution

I propose a solution, where I have added a small epsilon value to the denominators in _vmf_householder_vec. This approach is common in machine learning algorithms like ADAM, where dividing by zero is avoided during accumulation:

  1. Adam: A Method for Stochastic Optimization
  2. What do I mean by ε?
using Distributions
using LinearAlgebra: norm

κ = 1.0
μ = [1.0,1e-8]
μ = μ ./ norm(μ)
d = VonMisesFisher(μ, κ)
rand(d)

...returns:

2-element Vector{Float64}:
 0.8976253988301388
 0.4407591670913201

... which are not NaNs!

Discussion

I have the following discussion topics:

  1. I'm open to feedback regarding whether this PR is going in the right direction. Is a change here desired? Is my proposed solution the right way forward?
  2. If so, what additional contributions would be needed? For example, I'd assume that we'd also like to add some tests to verify my change is not causing unintended changes, and to check for future NaN outputs.

As always, feedback is appreciated

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant