using UMAP
using Distances
using StringDistances

Advanced Usage

Algorithm

At a high level, the UMAP algorithm proceeds in the following steps:

knns_dists = knn_search(data, knn_params)
umap_graphs = fuzzy_simplicial_set(knns_dists, knn_params, src_view_params)
umap_graph = coalesce_views(umap_graphs, src_global_params)
embedding = initialize_embedding(umap_graph, tgt_params)
optimize_embedding!(embedding, umap_graph, tgt_params, opt_params)

In a typical workflow, the first step of the UMAP algorithm is to find an (approximate) k-nearest neighbor graph.

Approximate neighbors for vector data

As a simple example, we can find the 4 approximate nearest neighbors of each vector in \(\mathbb{R}^n\) (here n = 10) using the Euclidean metric:

xs = [rand(10) for _ in 1:10];
DescentNeighbors{M, K}(n_neighbors, metric::M, kwargs::K)

Parameters for finding approximate nearest neighbors using NearestNeighborDescent.

knn_params = UMAP.DescentNeighbors(4, Euclidean());
UMAP.knn_search(xs, knn_params)
([9 4 … 1 9; 7 7 … 10 8; 2 1 … 7 5; 4 6 … 5 1], [0.8050651226597185 1.1379123414569403 … 0.8050651226597185 1.115300210819546; 1.1969410289966496 1.187404792132657 … 1.115300210819546 1.2506821039945475; 1.2050594450140457 1.2050594450140457 … 1.2471916120847273 1.3019564225690332; 1.2213931852253204 1.2773290390420435 … 1.271039754146728 1.4056660014440479])

The result in this case is a tuple of two 4×10 (n_neighbors × n_points) matrices. The first holds the indices of each point's nearest neighbors and the second holds the corresponding distances.

That is, knn_search(xs, knn_params) -> (indices, distances)
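The following minimal, exact (brute-force) k-nearest-neighbor search in plain Julia makes the output layout concrete. It returns two k × n matrices whose column j lists the neighbors of point j, nearest first, excluding the point itself. This is only an illustration, not how DescentNeighbors works (it uses approximate nearest neighbor descent), and brute_force_knn is a hypothetical helper name. It needs O(n²) distance evaluations, which is exactly the cost approximate search avoids on large datasets.

```julia
using LinearAlgebra: norm

# Exact (brute-force) kNN search with the same result layout as described above:
# two k × n matrices, column j holding the neighbors of point j, nearest first.
function brute_force_knn(xs::Vector{<:AbstractVector}, k::Int)
    n = length(xs)
    @assert k < n "need at least k + 1 points"
    indices = Matrix{Int}(undef, k, n)
    distances = Matrix{Float64}(undef, k, n)
    for j in 1:n
        others = [i for i in 1:n if i != j]       # skip the point itself
        ds = [norm(xs[i] - xs[j]) for i in others]
        order = sortperm(ds)[1:k]                 # k closest, nearest first
        indices[:, j] = others[order]
        distances[:, j] = ds[order]
    end
    return indices, distances
end

xs = [rand(10) for _ in 1:10]
knn_indices, knn_dists = brute_force_knn(xs, 4)
size(knn_indices)  # (4, 10)
```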

The knn parameter struct DescentNeighbors uses NearestNeighborDescent.jl to find the approximate knns of the data. It also allows passing keyword arguments to nndescent:

knn_params_kw = UMAP.DescentNeighbors(4, Euclidean(), (max_iters=15,));
UMAP.knn_search(xs, knn_params_kw)
([9 4 … 1 9; 7 7 … 10 8; 2 1 … 7 5; 4 6 … 5 1], [0.8050651226597185 1.1379123414569403 … 0.8050651226597185 1.115300210819546; 1.1969410289966496 1.187404792132657 … 1.115300210819546 1.2506821039945475; 1.2050594450140457 1.2050594450140457 … 1.2471916120847273 1.3019564225690332; 1.2213931852253204 1.2773290390420435 … 1.271039754146728 1.4056660014440479])
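The extra argument (max_iters=15,) is a one-element NamedTuple. The trailing comma matters, because (max_iters=15) without it is just an assignment that evaluates to 15. The sketch below shows the general Julia mechanism for splatting such a NamedTuple into keyword arguments. mock_search and its keywords are hypothetical stand-ins, not nndescent's real signature. We assume, without claiming it, that DescentNeighbors forwards its kwargs in a similar way.

```julia
# Hypothetical stand-in for a function with keyword arguments (NOT nndescent's
# real signature), used only to show how a NamedTuple of options is forwarded.
mock_search(data, k, metric; max_iters = 10, precision = 1e-3) = (max_iters, precision)

kw = (max_iters = 15,)                              # trailing comma: a one-element NamedTuple
settings = mock_search(nothing, 4, nothing; kw...)  # splat the NamedTuple as keywords
settings  # (15, 0.001)
```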

Precomputed distances

Alternatively, a precomputed distance matrix can be passed in if the pairwise distances are already known. This is done by using the PrecomputedNeighbors knn parameter struct (note that n_neighbors is still required in order to later construct the fuzzy simplicial set, and for transforming new data):

distances = [0. 2 1;
             2 0 3;
             1 3 0];
PrecomputedNeighbors{M}(n_neighbors, dists_or_graphs::M)

Parameters for finding nearest neighbors from precomputed distances. dists_or_graphs can either be a matrix (pairwise distances) or a NearestNeighborDescent.ApproximateKNNGraph.

knn_params_pre = UMAP.PrecomputedNeighbors(2, distances);
UMAP.knn_search(nothing, knn_params_pre)
([3 1 1; 2 3 2], [1.0 2.0 1.0; 2.0 3.0 3.0])
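Selecting neighbors from a precomputed matrix amounts to sorting each column and skipping the zero self-distance. The hypothetical helper knn_from_distances below sketches that idea in plain Julia. It is not UMAP.jl's implementation, but on the 3×3 matrix above it reproduces the result shown.

```julia
# Pick the k nearest neighbors of each point (column) from a symmetric pairwise
# distance matrix, skipping the diagonal self-distance.
function knn_from_distances(D::AbstractMatrix, k::Int)
    n = size(D, 1)
    indices = Matrix{Int}(undef, k, n)
    dists = Matrix{eltype(D)}(undef, k, n)
    for j in 1:n
        others = [i for i in 1:n if i != j]
        nearest = others[sortperm(D[others, j])[1:k]]   # nearest first
        indices[:, j] = nearest
        dists[:, j] = D[nearest, j]
    end
    return indices, dists
end

distances = [0. 2 1;
             2 0 3;
             1 3 0]
knn_from_distances(distances, 2)  # ([3 1 1; 2 3 2], [1.0 2.0 1.0; 2.0 3.0 3.0])
```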

Multiple views

One key feature of UMAP is the ability to combine multiple, heterogeneous views of the same dataset. For the knn search step, this is set up by passing a NamedTuple of data views and a corresponding NamedTuple of knn parameter structs. knn_search is then applied to each (data, knn_params) pair and returns a NamedTuple of (indices, distances) tuples with the same names as the input.

For example, in addition to the vector data xs we might also have string data:

xs_str = [join(rand('A':'Z', 10), "") for _ in 1:10];
knn_params_str = UMAP.DescentNeighbors(3, RatcliffObershelp());
data_views = (view_1=xs, 
              view_2=xs_str);
knn_params_views = (view_1=knn_params, 
                    view_2=knn_params_str);
UMAP.knn_search(data_views, knn_params_views)
(view_1 = ([9 4 … 1 9; 7 7 … 10 8; 2 1 … 7 5; 4 6 … 5 1], [0.8050651226597185 1.1379123414569403 … 0.8050651226597185 1.115300210819546; 1.1969410289966496 1.187404792132657 … 1.115300210819546 1.2506821039945475; 1.2050594450140457 1.2050594450140457 … 1.2471916120847273 1.3019564225690332; 1.2213931852253204 1.2773290390420435 … 1.271039754146728 1.4056660014440479]), view_2 = ([6 7 … 6 4; 4 8 … 5 6; 7 3 … 7 7], [0.6 0.8 … 0.7 0.7; 0.7 0.8 … 0.7 0.8; 0.7 0.8 … 0.7 0.8]))

This is fairly flexible, with one major restriction: every view must contain the same number of data points. Views can even mix DescentNeighbors and PrecomputedNeighbors.
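Pairing data views with their parameter structs by name behaves like map over NamedTuples with the same field names. Julia supports this directly, and the result is a NamedTuple with those names. In this sketch, toy_search is a hypothetical stand-in for a real kNN search that only reports the view size and the requested neighbor count. It is not a claim about how knn_search is written.

```julia
# Hypothetical per-view "search": reports the view size and requested k
# instead of actually searching for neighbors.
toy_search(data, k) = (n_points = length(data), n_neighbors = k)

views  = (view_1 = [rand(10) for _ in 1:10],
          view_2 = [string(c) for c in 'A':'J'])   # same number of points per view
params = (view_1 = 4, view_2 = 3)

# map over NamedTuples with identical names pairs fields by name and returns
# a NamedTuple with the same names
results = map(toy_search, views, params)
# (view_1 = (n_points = 10, n_neighbors = 4), view_2 = (n_points = 10, n_neighbors = 3))
```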

Fuzzy Simplicial Sets

Once we have one or more sets of kNNs for our data (one for each view), we can construct a global fuzzy simplicial set. This is done via the function

fuzzy_simplicial_set(...) -> umap_graph::SparseMatrixCSC

A global fuzzy simplicial set is constructed for each view of the data, with the construction parameterized by the SourceViewParams struct. If there is more than one view, fuzzy_simplicial_set returns one set per view (as a named tuple). These are then combined by coalesce_views into a single fuzzy simplicial set, represented as a weighted, undirected graph.
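The sketch below illustrates what happens inside one view, using the smooth kNN step described in the UMAP paper for a single point. Distances are shifted by rho, the distance to the nearest neighbor, and scaled by a sigma chosen by bisection so that the memberships sum to log2(k). This assumes local_connectivity = 1 and bandwidth = 1 and simplifies the paper's handling of both. smooth_knn_memberships is a hypothetical name, and UMAP.jl's code may differ in detail.

```julia
# Smooth-kNN membership strengths for ONE point, given the distances to its k
# nearest neighbors (simplified: local_connectivity = 1, bandwidth = 1).
function smooth_knn_memberships(dists::AbstractVector{<:Real}; n_iter::Int = 64)
    k = length(dists)
    rho = minimum(dists)                 # distance to the nearest neighbor
    target = log2(k)                     # desired sum of memberships
    memberships(s) = [exp(-max(d - rho, 0.0) / s) for d in dists]
    lo, hi, sigma = 0.0, Inf, 1.0
    for _ in 1:n_iter                    # bisection on sigma
        if sum(memberships(sigma)) > target
            hi = sigma                   # sum too large: shrink sigma
        else
            lo = sigma                   # sum too small: grow sigma
        end
        sigma = isinf(hi) ? 2 * sigma : (lo + hi) / 2
    end
    return memberships(sigma)
end

w = smooth_knn_memberships([1.0, 1.5, 2.0, 3.0])
# the nearest neighbor gets membership 1.0 and sum(w) ≈ log2(4) = 2
```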

Fuzzy simplicial set - one view

To create a fuzzy simplicial set for our original dataset of vectors:

SourceViewParams{T}(set_operation_ratio, local_connectivity, bandwidth)

Struct for parameterizing the representation of the data in the source (original) manifold; i.e. constructing fuzzy simplicial sets of each view of the dataset.

src_view_params = UMAP.SourceViewParams(1, 1, 1);
knns_dists = UMAP.knn_search(xs, knn_params);
UMAP.fuzzy_simplicial_set(knns_dists, knn_params, src_view_params)
10×10 SparseArrays.SparseMatrixCSC{Float64, Int64} with 50 stored entries:
  ⋅        0.587095  0.169593  0.524543  …  0.343225   ⋅        1.0       0.191563
 0.587095   ⋅         ⋅        1.0          0.488591   ⋅         ⋅         ⋅ 
 0.169593   ⋅         ⋅         ⋅           0.521125  0.229551   ⋅         ⋅ 
 0.524543  1.0        ⋅         ⋅           1.0        ⋅         ⋅         ⋅ 
 0.294908   ⋅        1.0        ⋅            ⋅         ⋅        0.455068  0.345658
  ⋅        0.132978  1.0       0.504976  …  0.674144  0.151089   ⋅         ⋅ 
 0.343225  0.488591  0.521125  1.0           ⋅        1.0       0.296401   ⋅ 
  ⋅         ⋅        0.229551   ⋅           1.0        ⋅         ⋅        0.79551
 1.0        ⋅         ⋅         ⋅           0.296401   ⋅         ⋅        1.0
 0.191563   ⋅         ⋅         ⋅            ⋅        0.79551   1.0        ⋅ 
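The matrix above is symmetric because directed kNN memberships are merged with their transpose. Here is a minimal sketch of that step, assuming the mixing used in the reference UMAP implementation: set_operation_ratio = 1 gives the pure fuzzy union A + A' - A∘A', and 0 gives the pure fuzzy intersection A∘A'. combine_directions is a hypothetical name, the column orientation of A is an assumption, and UMAP.jl's code may differ.

```julia
using SparseArrays

# Merge a directed kNN membership matrix A with its transpose into an undirected
# fuzzy set: ratio = 1 gives the fuzzy union A + A' - A∘A', ratio = 0 the fuzzy
# intersection A∘A', and values in between interpolate linearly.
function combine_directions(A::SparseMatrixCSC, ratio::Real)
    At = sparse(A')                     # materialize the transpose
    both = A .* At                      # elementwise (Hadamard) product
    return ratio .* (A .+ At .- both) .+ (1 - ratio) .* both
end

# Toy directed memberships: point 1 has neighbors 2 and 3, point 3 has neighbor 1
A = sparse([2, 3, 1], [1, 1, 3], [1.0, 0.4, 0.7], 3, 3)
S = combine_directions(A, 1.0)   # symmetric; S[3, 1] == S[1, 3] ≈ 0.82
```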

Fuzzy simplicial set - multiple views

As before, multiple views can be passed to fuzzy_simplicial_set, each parameterized by its own SourceViewParams. The result is a named tuple holding one fuzzy simplicial set per view, and these still have to be combined into a single, global fuzzy simplicial set.

Using our combination of vector and string data:

knns_dists_views = UMAP.knn_search(data_views, knn_params_views);
src_view_params_views = (view_1=src_view_params, 
                         view_2=src_view_params);
fsset_views = UMAP.fuzzy_simplicial_set(knns_dists_views, knn_params_views, src_view_params_views)
(view_1 = sparse([2, 3, 4, 5, 7, 9, 10, 1, 4, 6  …  7, 10, 1, 5, 7, 10, 1, 5, 8, 9], [1, 1, 1, 1, 1, 1, 1, 2, 2, 2  …  8, 8, 9, 9, 9, 9, 10, 10, 10, 10], [0.5870951563012333, 0.1695925296557939, 0.5245434455326707, 0.29490818849910133, 0.34322483124854974, 1.0, 0.1915627005583961, 0.5870951563012333, 1.0, 0.13297836394748544  …  1.0, 0.7955099346169894, 1.0, 0.45506753304259995, 0.2964014904125186, 1.0, 0.1915627005583961, 0.34565833419311764, 0.7955099346169894, 1.0], 10, 10), view_2 = sparse([4, 6, 7, 10, 3, 8, 9, 2, 6, 7  …  5, 9, 2, 5, 6, 7, 8, 1, 4, 6], [1, 1, 1, 1, 2, 2, 2, 3, 3, 3  …  8, 8, 9, 9, 9, 9, 9, 10, 10, 10], [0.4994242052642983, 1.0, 0.9999999999999999, 0.29248618760076334, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 0.29248618760076334, 1.0, 1.0, 1.0, 1.0, 0.29248618760076334, 0.29248618760076334, 1.0, 0.29248618760076334], 10, 10))

Combining views' fuzzy simplicial sets

We need a single UMAP graph (i.e. a global fuzzy simplicial set) in order to perform optimization, so if there are multiple dataset views we must combine their sets.

The views' fuzzy sets are combined from left to right (via foldl) using fuzzy set intersection. mix_ratio weights each intersection, analogous to a weighted mean.

SourceGlobalParams{T}(mix_ratio)

Parameters for merging the fuzzy simplicial sets for each dataset view into one fuzzy simplicial set, otherwise known as the UMAP graph.

src_gbl_params = UMAP.SourceGlobalParams(0.5);
graph = UMAP.coalesce_views(fsset_views, src_gbl_params)
10×10 SparseArrays.SparseMatrixCSC{Float64, Int64} with 68 stored entries:
  ⋅        0.784542  0.527008  0.962333  …  1.0        ⋅        0.814922  0.700262
 0.784542   ⋅        0.772443  1.0          0.674864  0.79684   0.699147   ⋅ 
 0.527008  0.772443   ⋅         ⋅           0.962359  0.652674   ⋅         ⋅ 
 0.962333  1.0        ⋅         ⋅           1.0        ⋅         ⋅        0.84522
 0.625076   ⋅        0.896094   ⋅            ⋅        0.817183  1.0       0.775504
 0.570467  0.35683   1.0       0.970254  …  0.601093  0.436695  0.595365  0.452716
 1.0       0.674864  0.962359  1.0           ⋅        1.0       0.914833   ⋅ 
  ⋅        0.79684   0.652674   ⋅           1.0        ⋅        0.404673  0.981639
 0.814922  0.699147   ⋅         ⋅           0.914833  0.404673   ⋅        1.0
 0.700262   ⋅         ⋅        0.84522       ⋅        0.981639  1.0        ⋅ 
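Because the views are folded left to right, a left fold of a weighted pairwise operation can weight the views unequally depending on their order. The sketch below shows this with foldl and a toy pairwise combiner, a weighted geometric mean. toy_combine is hypothetical and is not UMAP.jl's intersection operation, so only the folding behaviour carries over, not the numbers.

```julia
# Toy pairwise combiner (hypothetical, NOT UMAP.jl's intersection): a weighted
# geometric mean of two membership strengths, with weight w on the right operand.
toy_combine(x, y; w = 0.5) = x^(1 - w) * y^w

memberships = [0.2, 0.5, 0.9]   # one edge's strength in three views

# foldl applies the combiner left to right: toy_combine(toy_combine(0.2, 0.5), 0.9)
combined = foldl(toy_combine, memberships)

# With pairwise weight 0.5, the last view ends up with weight 1/2 and the
# first two with 1/4 each, so the order of the views matters
expected = 0.2^0.25 * 0.5^0.25 * 0.9^0.5
combined ≈ expected  # true
```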

Initialize and optimize target embedding

Recall that the goal of UMAP is to embed points from one manifold into another. By default, the target manifold is d-dimensional Euclidean space.

The target embedding is parameterized by a target metric (SqEuclidean by default, although in principle any metric could be used) and by a membership function that computes fuzzy set membership on the new manifold (the analog of the membership strengths stored in the UMAP graph).

MembershipFnParams{T}(min_dist, spread)

A smooth approximation of the membership strength of a 1-simplex between two points x and y can be given by the following, with dissimilarity function dist and constants a, b: ϕ(x, y, dist, a, b) = (1 + a*(dist(x, y))^b)^(-1)

The approximation parameters a, b are chosen by non-linear least squares fitting of the following function ψ:

ψ(x, y, dist, min_dist, spread) = dist(x, y) ≤ min_dist ? 1 : exp(-(dist(x, y) - min_dist)/spread)
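The fitting step can be sketched in plain Julia. This version replaces a proper nonlinear least-squares solver with a crude exhaustive grid search over (a, b) that minimizes the sum of squared residuals between ϕ and ψ. The sampled distance range [0, 3·spread], the grid, and the function names are all assumptions. The fitted values are not claimed to match those UMAP.jl computes.

```julia
# ψ: the curve being approximated (from the text), as a function of d = dist(x, y).
psi(d, min_dist, spread) = d <= min_dist ? 1.0 : exp(-(d - min_dist) / spread)
# ϕ as a function of the dissimilarity d = dist(x, y).
phi(d, a, b) = 1 / (1 + a * d^b)

# Crude least-squares fit of (a, b) by exhaustive grid search, standing in for
# a proper nonlinear least-squares solver.
function fit_ab_grid(min_dist, spread; n_points = 100)
    ds = range(0, 3 * spread; length = n_points)   # assumed sampling range
    ys = [psi(d, min_dist, spread) for d in ds]
    sse(a, b) = sum((phi(d, a, b) - y)^2 for (d, y) in zip(ds, ys))
    grid = [i / 50 for i in 1:250]                  # 0.02, 0.04, ..., 5.0
    best = (Inf, 0.0, 0.0)
    for a in grid, b in grid
        s = sse(a, b)
        if s < best[1]
            best = (s, a, b)
        end
    end
    return best[2], best[3], sse
end

# min_dist = 1 and spread = 1, as in the MembershipFnParams(1., 1.) used later
a, b, sse = fit_ab_grid(1.0, 1.0)
```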

TargetParams{M, D, I, P}(manifold::M, metric::D, init::I, memb_params::P)

Parameters for controlling the target embedding, e.g. the manifold, distance metric, initialization method.

Initialize target embedding

The default target space is d-dimensional Euclidean space, with the squared Euclidean distance metric. Two initialization methods are provided: random and spectral layout.
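A spectral initialization is typically built from eigenvectors of the graph Laplacian. The sketch below follows the usual recipe with the symmetric normalized Laplacian L = I - D^(-1/2) W D^(-1/2): it drops the trivial eigenvector and uses the next d eigenvectors as coordinates. spectral_layout is a hypothetical name, and UMAP.jl's spectral layout may differ in normalization, scaling, and eigensolver. This dense version is only practical for small graphs with no isolated points.

```julia
using LinearAlgebra

# Spectral layout sketch: place the nodes of a weighted, undirected graph W using
# eigenvectors of the symmetric normalized Laplacian L = I - D^(-1/2) W D^(-1/2).
function spectral_layout(W::AbstractMatrix, d::Int)
    deg = vec(sum(W; dims = 2))                 # node degrees (must be > 0)
    Dinv_sqrt = Diagonal(1 ./ sqrt.(deg))
    L = I - Dinv_sqrt * W * Dinv_sqrt
    vecs = eigen(Symmetric(L)).vectors          # columns sorted by ascending eigenvalue
    coords = vecs[:, 2:d+1]                     # skip the trivial first eigenvector
    return [coords[i, :] for i in 1:size(W, 1)]
end

# Usage: a ring of 6 nodes embedded in 2 dimensions
n = 6
W = zeros(n, n)
for i in 1:n
    j = mod1(i + 1, n)
    W[i, j] = W[j, i] = 1.0
end
emb = spectral_layout(W, 2)
```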

Uniform random vectors in R^2

tgt_params = UMAP.TargetParams(UMAP._EuclideanManifold{2}(), SqEuclidean(), UMAP.UniformInitialization(), UMAP.MembershipFnParams(1., 1.));
umap_graph = UMAP.fuzzy_simplicial_set(knns_dists, knn_params, src_view_params);
xs_embed = UMAP.initialize_embedding(umap_graph, tgt_params)
10-element Vector{Vector{Float64}}:
 [-3.0255214443566247, -4.087284846667034]
 [3.7651828407682544, -0.800758962910475]
 [3.522351502198502, 1.7502173216668062]
 [2.9884133532046224, 5.3380140403531335]
 [3.768566155722139, 1.8326983313719243]
 [-6.494281122048424, 1.8862327889807524]
 [-8.742119168484537, 9.538045349814727]
 [2.28903150693494, 9.61799032191642]
 [8.519614286932896, 9.036748113625535]
 [6.051612258653968, -0.512747889497863]

Optimize target embedding

The embedding is optimized by minimizing the fuzzy set cross entropy loss between the two fuzzy set representations of the data (the UMAP graph and the memberships in the embedding). This optimization process is parameterized by OptimizationParams:
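Take two fuzzy sets over the same pairs, with membership strengths w (from the UMAP graph) and v (from the embedding). The fuzzy set cross entropy from the UMAP paper is Σ [w log(w/v) + (1 - w) log((1 - w)/(1 - v))]. Below is a direct, dense sketch of that formula. The clamping constant ϵ and the function name are assumptions, and UMAP.cross_entropy may treat zero entries or constant terms differently, so the values need not match it.

```julia
# Fuzzy set cross entropy between membership strengths w (source graph) and
# v (embedding), summed over every entry of the two matrices.
function fuzzy_cross_entropy(w::AbstractMatrix, v::AbstractMatrix; ϵ = 1e-12)
    @assert size(w) == size(v)
    total = 0.0
    for (wi, vi) in zip(w, v)
        vi = clamp(vi, ϵ, 1 - ϵ)     # keep the logs finite when v is 0 or 1
        # the first term vanishes when w = 0 and the second when w = 1 (0 * log(0) = 0)
        wi > 0 && (total += wi * log(wi / vi))
        wi < 1 && (total += (1 - wi) * log((1 - wi) / (1 - vi)))
    end
    return total
end

w = [0.0 0.8; 0.8 0.0]
fuzzy_cross_entropy(w, w)                 # ≈ 0: identical fuzzy sets
fuzzy_cross_entropy(w, fill(0.5, 2, 2))   # > 0
```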

OptimizationParams(n_epochs, learning_rate, repulsion_strength, neg_sample_rate)

Parameters for controlling the optimization process.
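Optimization takes stochastic gradient steps on this loss: edges of the UMAP graph pull their endpoints together, and randomly drawn negative samples push points apart. The sketch shows one attractive and one repulsive update for ϕ = (1 + a·d^b)^(-1), where d is the squared Euclidean distance. The gradient expressions and the ±4 clipping follow the reference UMAP implementation. Reading lr as learning_rate and γ as repulsion_strength is our interpretation. attract! and repel! are hypothetical names, only the head point is moved, and UMAP.jl's update rule may differ.

```julia
using LinearAlgebra: norm

# Attractive step for a graph edge (i, j): gradient ascent on log ϕ, where
# ϕ = 1 / (1 + a * d^b) and d is the squared Euclidean distance.
function attract!(xi, xj, a, b, lr)
    d = sum(abs2, xi .- xj)
    d > 0 || return xi                             # coincident points: no update
    coeff = -2 * a * b * d^(b - 1) / (1 + a * d^b)
    @. xi += lr * clamp(coeff * (xi - xj), -4, 4)  # clipped gradient step
    return xi
end

# Repulsive step against a negative sample k: gradient ascent on log(1 - ϕ),
# scaled by the repulsion strength γ (the 0.001 avoids division by zero).
function repel!(xi, xk, a, b, lr, γ)
    d = sum(abs2, xi .- xk)
    coeff = 2 * γ * b / ((0.001 + d) * (1 + a * d^b))
    @. xi += lr * clamp(coeff * (xi - xk), -4, 4)
    return xi
end

xi = [1.0, 0.0]
attract!(xi, [0.0, 0.0], 1.0, 1.0, 0.1)     # moves xi toward the origin
norm(xi)
repel!(xi, [0.0, 0.0], 1.0, 1.0, 0.1, 1.0)  # pushes xi away again
norm(xi)
```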

Optimize one epoch

opt_params = UMAP.OptimizationParams(1, 1., 1., 5);
# deepcopy rather than copy: copy(xs_embed) is shallow, so both embeddings would
# share the same inner point vectors and in-place updates would affect both
xs_embed_opt = deepcopy(xs_embed);
UMAP.optimize_embedding!(xs_embed_opt, umap_graph, tgt_params, opt_params)
10-element Vector{Vector{Float64}}:
 [-1.947728854100013, -2.8075136096361253]
 [2.357546490587466, -1.0415375342231918]
 [2.6593730028542732, 1.5545382193695583]
 [1.5607909296986362, 4.522105763031607]
 [6.793232926875479, 1.4087963137827095]
 [-5.56327532065652, 2.3039503581161718]
 [-6.9692112091546585, 8.721785551019172]
 [1.858229737302015, 8.884024662525338]
 [7.398732791906423, 7.494070673669046]
 [3.5929586615799227, 6.3174681125898475]
# compare the loss of the initial and the optimized embedding
UMAP.cross_entropy(umap_graph, xs_embed, tgt_params)
UMAP.cross_entropy(umap_graph, xs_embed_opt, tgt_params)
323.5911074748375