using CairoMakie
begin
    # Make DataDeps (used by MLDatasets to fetch MNIST) accept the download without an
    # interactive prompt, so the script can run unattended. Set it before loading the
    # package so it is in place regardless of when the variable is first read.
    # (Julia stores ENV values as strings, so `true` becomes "true".)
    ENV["DATADEPS_ALWAYS_ACCEPT"] = true
    using MLDatasets
end
using UMAP
using Distances

# UMAP.jl on MNIST Digits

# Flatten every training image into a single column, giving one sample per column.
# NOTE(review): assumes traintensor returns (pixels..., sample) with 60000 samples last.
mnist_x = let images = MNIST.traintensor(Float64)
    reshape(images, :, 60000)
end;
mnist_y = MNIST.trainlabels();
# Embed all training samples; `result.embedding` is indexed per sample when plotting.
result = UMAP.fit(mnist_x);
begin
    # Scatter the 2-D UMAP embedding of the first training samples, one series per digit.
    f = Figure()
    axis = Axis(f[1, 1], title="UMAP.jl - MNIST")
    # `let` keeps the helper variables local so this cell only defines `f` and `axis`.
    let n_plot = 5000,
        # Makie's default cycle has only 7 colors, so 10 digit classes would repeat
        # colors (0/7, 1/8, 2/9); `:tab10` provides 10 distinct categorical colors.
        digit_colors = to_colormap(:tab10)

        # Slice once instead of re-slicing inside the loop for every digit.
        embedding = result.embedding[1:n_plot]
        labels = mnist_y[1:n_plot]
        for d in 0:9
            points = embedding[labels .== d]
            scatter!(axis, getindex.(points, 1), getindex.(points, 2);
                     label=string(d), markersize=5, color=digit_colors[d + 1])
        end
    end
    Legend(f[1, 2], axis, "Digit", framevisible=false)
    f
end