UMAP Loss Function
This page describes the fuzzy set cross entropy loss function used in UMAP and how it is optimized during the embedding process.
Fuzzy Set Cross Entropy
Given two fuzzy simplicial sets, we can consider their 1-skeleta as fuzzy graphs (i.e., sets of edges, where each edge has a probability of existing in the graph). The two sets (of edges) can then be compared by computing the fuzzy set cross entropy.
For a set $A$ and membership functions $\mu: A \to [0, 1]$ and $\nu: A \to [0, 1]$, the set cross entropy is:
\[C(A, \mu, \nu) = \sum_{a \in A} \left[ \mu(a) \log\frac{\mu(a)}{\nu(a)} + (1 - \mu(a)) \log\frac{1 - \mu(a)}{1 - \nu(a)} \right]\]
In code:
```julia
function cross_entropy(A::Set, μ, ν)
    loss = 0.0
    for a ∈ A
        loss += μ(a) * log(μ(a) / ν(a)) + (1 - μ(a)) * log((1 - μ(a)) / (1 - ν(a)))
    end
    return loss
end
```
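As a quick usage sketch, the membership functions can be supplied as closures over precomputed values; the edge set and membership strengths below are made-up toy values, not the output of any UMAP computation:
```julia
# Toy fuzzy graph: three edges with illustrative membership strengths.
edges = Set([(1, 2), (2, 3), (1, 3)])
μ_vals = Dict((1, 2) => 0.9, (2, 3) => 0.6, (1, 3) => 0.2)  # "true" (high-dimensional) memberships
ν_vals = Dict((1, 2) => 0.8, (2, 3) => 0.5, (1, 3) => 0.3)  # embedding (low-dimensional) memberships
μ(a) = μ_vals[a]
ν(a) = ν_vals[a]
cross_entropy(edges, μ, ν)  # ≥ 0, and 0 only when μ and ν agree on every edge
```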
Generalization to $\ell$-Skeleta
The loss function can be generalized to $\ell$-skeleta by taking a weighted sum of the set cross entropies of the fuzzy sets of $i$-simplices. That is,
\[C_\ell(X, Y) = \sum_{i=1}^{\ell} \lambda_i \cdot C(X_i, Y_i)\]
where $X_i$ and $Y_i$ denote the fuzzy sets of $i$-simplices of $X$ and $Y$ (together with their membership functions), and the $\lambda_i$ are non-negative weights.
In code:
```julia
function cross_entropy(𝐀::Vector{<:Set}, μ, ν, λ = ones(length(𝐀)))
    loss = 0.0
    for (i, A) in enumerate(𝐀)
        loss += λ[i] * cross_entropy(A, μ, ν)  # weight the i-simplices by λ[i]
    end
    return loss
end
```
Simplified Loss for Optimization
During optimization, the terms $\mu(a)\log\mu(a)$ and $(1 - \mu(a))\log(1 - \mu(a))$ are constant with respect to the embedding, so we can drop them and minimize only the remaining terms (the result differs from $C$ by an additive constant):
\[C(A, \mu, \nu) = -\sum_{a \in A} \left[ \mu(a) \log\nu(a) + (1 - \mu(a)) \log(1 - \nu(a)) \right]\]
In code:
```julia
function cross_entropy(A::Set, μ, ν)
    loss = 0.0
    for a ∈ A
        loss += μ(a) * log(ν(a)) + (1 - μ(a)) * log(1 - ν(a))
    end
    return -loss
end
```
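To see why dropping the fixed terms is harmless, note that the full and simplified losses differ exactly by $\sum_{a}[\mu(a)\log\mu(a) + (1-\mu(a))\log(1-\mu(a))]$, which does not depend on the embedding. A small numerical check with made-up membership values:
```julia
# Illustrative membership values only; μ_ex is fixed, ν_ex depends on the embedding.
μ_ex = Dict((1, 2) => 0.9, (2, 3) => 0.6, (1, 3) => 0.2)
ν_ex = Dict((1, 2) => 0.8, (2, 3) => 0.5, (1, 3) => 0.3)
full(a)  = μ_ex[a] * log(μ_ex[a] / ν_ex[a]) + (1 - μ_ex[a]) * log((1 - μ_ex[a]) / (1 - ν_ex[a]))
simp(a)  = -(μ_ex[a] * log(ν_ex[a]) + (1 - μ_ex[a]) * log(1 - ν_ex[a]))
fixed(a) = μ_ex[a] * log(μ_ex[a]) + (1 - μ_ex[a]) * log(1 - μ_ex[a])  # depends only on μ
sum(full(a) - simp(a) for a in keys(μ_ex)) ≈ sum(fixed(a) for a in keys(μ_ex))  # true
```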
Stochastic Sampling
Instead of calculating the loss over the entire set (if the set consists of the 1-simplices, a full pass has time complexity $\mathcal{O}(n^2)$ in the number of points $n$), we can sample elements with probability $\mu(a)$ and update according to the value $\nu(a)$. This takes care of the $\mu(a) \log\nu(a)$ term.
For the $(1 - \mu(a))\log(1 - \nu(a))$ term, negative samples are drawn: vertices are sampled uniformly and the corresponding simplices are treated as having $\mu(a) = 0$. The sampling distribution implied by the loss is
\[P(x_i) = \frac{\sum_{a \in A \mid d_0(a) = x_i}(1 - \mu(a))}{\sum_{b \in A \mid d_0(b) \neq x_i}(1 - \mu(b))}\]
which is approximately uniform for sufficiently large datasets.
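A minimal sketch of this sampling scheme, separated from the gradient updates (the `(a, b, p)` edge representation and the function name are assumptions for illustration; the full optimization loop below uses the same idea):
```julia
# Collect one epoch's worth of stochastic updates: each 1-simplex (a, b) with
# membership strength p is kept with probability p (positive/attractive sample);
# vertices drawn uniformly act as negative/repulsive samples with assumed μ = 0.
function sample_updates(edges, n_points, n_neg_samples)
    positives = Tuple{Int,Int}[]
    negatives = Tuple{Int,Int}[]
    for (a, b, p) in edges
        rand() ≤ p || continue       # handles the μ(a)·log ν(a) term
        push!(positives, (a, b))
        for _ in 1:n_neg_samples     # handles the (1 - μ(a))·log(1 - ν(a)) term
            c = rand(1:n_points)     # approximately the distribution P(xᵢ) above
            c == a || push!(negatives, (a, c))
        end
    end
    return positives, negatives
end
```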
Differentiable Membership Function
To optimize this loss with gradient descent, $\nu$ must be differentiable with respect to the embedding coordinates. A smooth approximation for the membership strength of a 1-simplex between two points $x, y$ can be given by the following, with dissimilarity function $\sigma$ and constants $a$, $b$:
\[\phi(x, y) = \left(1 + a \cdot \sigma(x, y)^{2b}\right)^{-1}\]
In code:
```julia
ϕ(x, y, σ, a, b) = (1 + a*(σ(x, y))^(2b))^(-1)
```
The approximation parameters $a$, $b$ are chosen by non-linear least squares fitting of the following function $\psi$:
\[\psi(x, y) = \begin{cases} 1 & \text{if } \sigma(x, y) \leq \text{min\_dist} \\ e^{-(\sigma(x, y) - \text{min\_dist})} & \text{otherwise} \end{cases}\]
In code:
```julia
ψ(x, y, σ, min_dist) = σ(x, y) ≤ min_dist ? 1.0 : exp(-(σ(x, y) - min_dist))
```
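As a sketch of how $a$ and $b$ might be obtained in practice (assuming LsqFit.jl for the non-linear least squares; the distance grid and the `min_dist` value below are illustrative choices):
```julia
using LsqFit  # assumed choice of non-linear least squares package

min_dist = 0.1                                      # illustrative hyperparameter value
ds = collect(range(0.01, 3.0; length = 300))        # grid of embedding distances d = σ(x, y)
ψ_target = [d ≤ min_dist ? 1.0 : exp(-(d - min_dist)) for d in ds]
ϕ_model(d, p) = @. 1 / (1 + p[1] * d^(2 * p[2]))    # ϕ as a function of distance, p = (a, b)
fit = curve_fit(ϕ_model, ds, ψ_target, [1.0, 1.0])  # initial guess a = b = 1
a, b = fit.param
```
The fitted pair $(a, b)$ is then held fixed during optimization of the embedding.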
Optimization Algorithm
Optimizing the embedding is accomplished by stochastic gradient descent, where:
- `fs_set` is the set of $\ell$-simplices (typically 1-simplices)
- `Y_emb` is the target embedding of the points that make up the vertices of `fs_set`
- $\sigma$ is a differentiable distance measure between points in `Y_emb`
- $\phi$ is the differentiable approximation to the fuzzy set membership function for the simplices in the target embedding
The algorithm proceeds as follows:
```julia
using Zygote: gradient  # automatic differentiation for the log-ϕ gradients below

function optimize_embedding(fs_set, Y_emb, σ, ϕ, n_epochs, n_neg_samples)
    η = 1.0  # learning rate
    # ϕ is assumed to have the fitted parameters a, b baked in, i.e. ϕ(x, y, σ)
    ∇logϕ(x, y) = gradient((_x, _y) -> log(ϕ(_x, _y, σ)), x, y)
    ∇log1_ϕ(x, y) = gradient((_x, _y) -> log(1 - ϕ(_x, _y, σ)), x, y)
    for e in 1:n_epochs
        for (a, b, p) in fs_set  # iterate over 1-simplices (a, b) with membership strength p
            if rand() ≤ p  # sample with probability p = μ(a)
                # Attractive force (positive sample): gradient ascent on log ϕ
                ∂a, _ = η .* ∇logϕ(Y_emb[a], Y_emb[b])
                Y_emb[a] += ∂a
                # Repulsive forces (negative samples): gradient ascent on log(1 - ϕ)
                for _ in 1:n_neg_samples
                    c = rand(eachindex(Y_emb))  # uniformly sampled negative example
                    c == a && continue          # skip degenerate self-pairs
                    ∂a, _ = η .* ∇log1_ϕ(Y_emb[a], Y_emb[c])
                    Y_emb[a] += ∂a
                end
            end
        end
        η = 1 - e/n_epochs  # linear learning rate decay
    end
    return Y_emb
end
```
The algorithm iterates over edges in the fuzzy simplicial set, sampling each edge with probability equal to its membership strength $\mu(a)$. For each sampled edge:
- Attractive force: take a gradient ascent step on $\log\phi(x, y)$ (i.e., gradient descent on $-\log\phi$) to pull the endpoints of the sampled 1-simplex together
- Repulsive forces: sample `n_neg_samples` random points and take gradient ascent steps on $\log(1 - \phi(x, y))$ to push unconnected points apart
The learning rate $\eta$ decays linearly from 1 to 0 over the course of training.
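Putting the pieces together, a hypothetical end-to-end call could look like the following; the toy fuzzy set, the Euclidean $\sigma$, and the fixed values of $a$ and $b$ are all illustrative assumptions rather than outputs of a real UMAP pipeline:
```julia
using LinearAlgebra: norm

σ(x, y) = norm(x - y)                       # Euclidean distance in the embedding space
a, b = 1.6, 0.9                             # e.g. values fitted as sketched above
ϕ(x, y, σ) = (1 + a * σ(x, y)^(2b))^(-1)    # membership function with a, b held fixed

# Toy fuzzy set of 1-simplices: (vertex, vertex, membership strength).
fs_set = [(1, 2, 0.9), (2, 3, 0.6), (1, 3, 0.2)]
Y_emb = [randn(2) for _ in 1:3]             # random 2-D initialization of three points

Y_emb = optimize_embedding(fs_set, Y_emb, σ, ϕ, 200, 5)
```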