"""
Visualization utilities for word embeddings using UMAP dimensionality reduction.

This module provides a tool to project high-dimensional word embeddings into 2D or 3D space for visualization. Using the UMAP algorithm, it reduces the dimensionality of the data while aiming to preserve its neighborhood (semantic) structure, allowing users to visually explore how words relate to one another in the vector space.

Core functionalities include:

* Projecting high-dimensional embeddings into 2D or 3D coordinates.
* Generating interactive scatter plots with Plotly for intuitive exploration.
* Coloring points by cluster when the input data provides a ``cluster`` column.
* Customizing UMAP parameters (e.g., number of neighbors, distance metric) to refine the projection.
* Exporting visualizations as standalone HTML files for easy sharing or offline viewing.

This module is designed to help linguists and digital humanities researchers interpret word embeddings by identifying semantic clusters and patterns through visual inspection.
"""
import plotly.express as px
from umap import UMAP
import pandas as pd
import numpy as np
from typing import Literal, cast
from ..lib_types import ndarray, DataFrame, Figure
def plot_embeddings(
    embeddings: DataFrame,
    n_dimensions: Literal[2, 3] = 2,
    labels: bool = False,
    show: bool = True,
    path: str | None = None,
    n_neighbors: int = 15,
    projection_metric: str = "cosine",
    min_dist: float = 0.1,
    approximate: int | None = None,
) -> Figure:
    """
    Generates a 2D or 3D visualization of word embeddings using UMAP.

    Every numeric column of ``embeddings`` except an optional ``cluster``
    column is treated as an embedding dimension, and the DataFrame index is
    used as the word label. If a ``cluster`` column is present, its values are
    used as categorical labels to color the points; otherwise every point is
    labelled ``"Unclustered"``.

    :param embeddings: DataFrame containing word embeddings, indexed by word.
    :type embeddings: DataFrame
    :param n_dimensions: Output dimensionality (2 or 3).
    :type n_dimensions: Literal[2, 3]
    :param labels: If True, display word labels on points.
    :type labels: bool
    :param show: If True, display the plot.
    :type show: bool
    :param path: If provided, save the figure as an HTML file at this path.
    :type path: str | None
    :param n_neighbors: Number of neighbors used by UMAP.
    :type n_neighbors: int
    :param projection_metric: Distance metric used by UMAP (e.g. "cosine").
    :type projection_metric: str
    :param min_dist: Minimum distance between embedded points.
    :type min_dist: float
    :param approximate: If set, randomly subsamples at most this many rows
        (with a fixed seed) before projecting, and enables UMAP's
        ``low_memory`` mode, for faster computation.
    :type approximate: int | None
    :return: Plotly figure with the projected embeddings.
    :rtype: Figure
    :raises ValueError: If ``n_dimensions`` is not 2 or 3, or if the
        embeddings have fewer numeric feature columns than ``n_dimensions``.
    """
    if n_dimensions not in (2, 3):
        raise ValueError("Only 2D or 3D plots are supported (n_dimensions=2 or 3)")

    if approximate:
        # DataFrame.sample raises when asked for more rows than exist (without
        # replacement), so cap the request at the number of available rows.
        embeddings = embeddings.sample(
            min(approximate, len(embeddings)), random_state=42
        )

    # The "cluster" column holds labels, not coordinates: keep it out of the
    # feature matrix even when the labels are numeric (e.g. KMeans output),
    # otherwise it would be projected by UMAP as an extra embedding dimension.
    feature_columns: DataFrame = embeddings.drop(
        columns="cluster", errors="ignore"
    ).select_dtypes(include=np.number)
    embeddings_array: ndarray = feature_columns.to_numpy()

    n_features: int = embeddings_array.shape[1]
    if n_features < n_dimensions:
        raise ValueError(
            f"You cannot plot in {n_dimensions} dimensions because your embeddings only have {n_features} dimensions."
        )

    visual_reducer: UMAP = UMAP(
        n_components=n_dimensions,
        n_neighbors=n_neighbors,
        min_dist=min_dist,
        metric=projection_metric,
        low_memory=bool(approximate),
        random_state=42,  # fixed seed so repeated calls give the same layout
    )
    umap_visual_embeddings: ndarray = cast(
        ndarray, visual_reducer.fit_transform(embeddings_array)
    )

    columns: list[str] = [f"Semantic Axis {i + 1}" for i in range(n_dimensions)]
    umap_df: DataFrame = pd.DataFrame(
        umap_visual_embeddings, index=embeddings.index, columns=columns
    )
    if "cluster" in embeddings.columns:
        # Cast to str so Plotly treats clusters as discrete categories (using
        # color_discrete_sequence) instead of a continuous colorscale; use the
        # raw values to avoid index re-alignment on duplicate word labels.
        umap_df["Cluster"] = embeddings["cluster"].astype(str).to_numpy()
    else:
        umap_df["Cluster"] = "Unclustered"
    umap_df["Word"] = umap_df.index

    color_scale: list[str] = px.colors.qualitative.Plotly
    scatter: Figure
    if n_dimensions == 2:
        scatter = px.scatter(
            umap_df,
            x=columns[0],
            y=columns[1],
            color="Cluster",
            text="Word" if labels else None,
            hover_data=["Word"] + columns + ["Cluster"],
            title="2D Projection of word embeddings",
            color_discrete_sequence=color_scale,
        )
        scatter.update_traces(textfont_size=10, textposition="top center")
    else:
        scatter = px.scatter_3d(
            umap_df,
            x=columns[0],
            y=columns[1],
            z=columns[2],
            color="Cluster",
            text="Word" if labels else None,
            hover_data=["Word"] + columns + ["Cluster"],
            title="3D Projection of word embeddings",
            color_discrete_sequence=color_scale,
        )
        scatter.update_traces(textfont_size=14, textposition="top center")

    if path:
        scatter.write_html(path)
    if show:
        scatter.show()
    return scatter