Source code for dstk.modules.count_models

"""
This module offers functionality to transform and reduce high-dimensional text data represented as matrices, enabling more effective downstream analysis and modeling.

Key features include:

* Scaling input matrices to zero mean and unit variance using standardization.
* Generating low-dimensional word embeddings from co-occurrence matrices using dimensionality reduction techniques:

  * Truncated Singular Value Decomposition (SVD)
  * Principal Component Analysis (PCA)

These techniques help distill semantic information from sparse and high-dimensional co-occurrence data, facilitating tasks such as clustering, visualization, and feature extraction in natural language processing pipelines.

All functions return results as Pandas DataFrames for seamless integration with data workflows.
"""

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA, TruncatedSVD
import pandas as pd

from ..lib_types import ndarray, DataFrame

def scale_matrix(matrix: DataFrame, **kwargs) -> DataFrame:
    """
    Standardize the columns of a matrix with scikit-learn's StandardScaler.

    A new ``StandardScaler`` is built from ``kwargs``, fitted on ``matrix`` and used to
    transform it. With the scaler's default settings each column (feature) ends up with
    a mean of 0 and a standard deviation of 1. The row index and column labels of the
    input are carried over to the result.

    :param matrix: The input data to scale.
    :type matrix: DataFrame

    :param kwargs: Additional keyword arguments forwarded unchanged to sklearn's StandardScaler. For more information check: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html

    :returns: The scaled matrix, labelled with the same index and columns as ``matrix``.
    :rtype: DataFrame
    """

    # Fit and transform in a single pass; the scaler itself is not kept afterwards.
    standardized: ndarray = StandardScaler(**kwargs).fit_transform(matrix)

    # Re-attach the labels of the input to the transformed values.
    return pd.DataFrame(
        data=standardized,
        index=matrix.index,
        columns=matrix.columns,
    )
def svd_embeddings(matrix: DataFrame, n_components: int = 100, **kwargs) -> DataFrame:
    """
    Generates word embeddings using truncated Singular Value Decomposition (SVD).

    A ``TruncatedSVD`` model is fitted on ``matrix`` and the reduced representation of
    each row is returned. Output columns are named ``dim_0``, ``dim_1``, ... and the
    row index of ``matrix`` is preserved.

    :param matrix: A co-occurrence matrix from which embeddings will be generated.
    :type matrix: DataFrame

    :param n_components: The number of dimensions to reduce the word embeddings to. Defaults to 100.
    :type n_components: int

    :param kwargs: Additional keyword arguments forwarded to sklearn's TruncatedSVD. ``n_components`` is already supplied through the explicit parameter above. For more information check: https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.TruncatedSVD.html

    :returns: A DataFrame of word embeddings generated by SVD.
    :rtype: DataFrame
    """

    reducer: TruncatedSVD = TruncatedSVD(n_components=n_components, **kwargs)
    vectors: ndarray = reducer.fit_transform(matrix)

    # Label columns from the width actually produced by the model rather than
    # from the requested n_components.
    _, width = vectors.shape
    labels: list[str] = [f"dim_{num}" for num in range(width)]

    return pd.DataFrame(vectors, index=matrix.index, columns=labels)
def pca_embeddings(matrix: DataFrame, n_components: int | float = 100, **kwargs) -> DataFrame:
    """
    Generates word embeddings using Principal Component Analysis (PCA).

    A ``PCA`` model is fitted on ``matrix`` and the projection of each row onto the
    retained components is returned. Output columns are named ``dim_0``, ``dim_1``, ...
    and the row index of ``matrix`` is preserved.

    :param matrix: A co-occurrence matrix from which embeddings will be generated.
    :type matrix: DataFrame

    :param n_components: If an integer, the number of dimensions to reduce the word embeddings to. If a float between 0 and 1, specifies the proportion of variance to preserve. Defaults to 100.
    :type n_components: int or float

    :param kwargs: Additional keyword arguments forwarded to sklearn's PCA. ``n_components`` is already supplied through the explicit parameter above. For more information check: https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html

    :returns: A DataFrame of word embeddings generated by PCA.
    :rtype: DataFrame
    """

    projected: ndarray = PCA(n_components=n_components, **kwargs).fit_transform(matrix)

    # The number of output columns is read from the result because a float
    # n_components does not state the final dimensionality up front.
    n_dims: int = projected.shape[1]
    dim_names: list[str] = [f"dim_{num}" for num in range(n_dims)]

    return pd.DataFrame(
        projected,
        index=matrix.index,
        columns=dim_names,
    )