Use AMICA in a Scikit-Learn Pipeline

We’ll use AMICA as a preprocessing step in a scikit-learn pipeline to perform digit classification on the MNIST dataset.

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

from amica import AMICA

Load & split dataset

Download MNIST (70k samples, 28×28 flattened)

# Fetch the full MNIST dataset: 70k flattened 28x28 images (784 features).
X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)

# Restrict to digits 0-3 to keep the run time manageable.
keep = np.isin(y, ["0", "1", "2", "3"])
X, y = X[keep].copy(), y[keep].copy().astype(int)

# Hold out 1/7 of the filtered samples for testing (the classic MNIST
# train/test ratio), i.e. roughly 24.8k train / 4.1k test here.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=1 / 7.0, shuffle=True, random_state=0
)

Build scikit-learn pipeline with AMICA

# Pipeline: center pixels -> AMICA features -> standardize -> logistic regression.
pipe = Pipeline([
    ("center", StandardScaler(with_std=False)),  # remove global brightness bias
    ("amica", AMICA(n_components=60, max_iter=200, tol=.0001, random_state=0)),
    ("scale_components", StandardScaler()),      # optional but helps LR
    # NOTE: 'n_jobs' is deprecated for LogisticRegression (no effect since
    # scikit-learn 1.8, removed in 1.10) and only emitted a FutureWarning,
    # so it is left unspecified.
    ("logreg", LogisticRegression(max_iter=2000)),
])

Fit

/home/circleci/project/amica-python/src/amica/linalg.py:333: RuntimeWarning: invalid value encountered in sqrt
  Winv = (eigvecs * np.sqrt(eigvals)) @ eigvecs.T  # Inverse of the whitening matrix

/home/circleci/project/amica-python/.venv/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:1184: FutureWarning: 'n_jobs' has no effect since 1.8 and will be removed in 1.10. You provided 'n_jobs=-1', please leave it unspecified.
  warnings.warn(msg, category=FutureWarning)
Pipeline(steps=[('center', StandardScaler(with_std=False)),
                ('amica',
                 AMICA(max_iter=200, n_components=60, random_state=0,
                       tol=0.0001)),
                ('scale_components', StandardScaler()),
                ('logreg', LogisticRegression(max_iter=2000, n_jobs=-1))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


Evaluate

# Predict on the held-out test set and report per-class metrics.
y_pred = pipe.predict(X_test)

report = classification_report(
    y_test, y_pred, target_names=list(map(str, range(4)))
)
print(report)

accuracy = pipe.score(X_test, y_test)
print(f"Accuracy: {accuracy:.4f}")
              precision    recall  f1-score   support

           0       0.98      0.99      0.98       951
           1       0.98      0.99      0.98      1135
           2       0.96      0.94      0.95       988
           3       0.97      0.96      0.97      1057

    accuracy                           0.97      4131
   macro avg       0.97      0.97      0.97      4131
weighted avg       0.97      0.97      0.97      4131

Accuracy: 0.9717

Important features for the 0 digit

We can select the most important ICA features for the 0 class (those with the largest positive and negative weights) and display their associated ICA components.

Helper

def imshow_row(images, titles=None, figsize=(20, 4), suptitle=None, cmap="gray"):
    """Display a row of flattened 28x28 images side by side.

    Parameters
    ----------
    images : sequence of array-like
        Each element is reshaped to (28, 28) before display.
    titles : sequence of str, optional
        Per-image axis titles; must be at least as long as ``images``.
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    suptitle : str, optional
        Bold figure-level title.
    cmap : str, optional
        Colormap used by ``imshow``.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    # squeeze=False guarantees a 2-D axes array even when len(images) == 1;
    # without it, plt.subplots returns a bare Axes and enumerate() would fail.
    fig, axes = plt.subplots(
        1, len(images), figsize=figsize, constrained_layout=True, squeeze=False
    )
    if suptitle:
        fig.suptitle(suptitle, fontsize=18, fontweight="bold")
    for i, ax in enumerate(axes[0]):
        ax.imshow(images[i].reshape(28, 28), cmap=cmap)
        ax.axis("off")
        if titles is not None:
            ax.set_title(titles[i])
    return fig

Show sample digits of class 0

# Grab the first ten examples of digit 0 and display them in one row.
zeros = X[y == 0][:10]
imshow_row(zeros, suptitle="10 samples of digit '0'")
plt.show()
10 samples of digit '0'

Top positive / negative logistic weights

# Pull the fitted steps back out of the pipeline.
logreg = pipe.named_steps["logreg"]
amica = pipe.named_steps["amica"]

# Class-0 coefficients over the AMICA features, index order ascending by weight.
coef = logreg.coef_[0]
sorted_idx = np.argsort(coef)

top_pos = sorted_idx[:-6:-1]  # five largest weights, descending
top_neg = sorted_idx[:5]      # five most negative weights

# Visualize the components driving the class-0 decision upward.
imshow_row(
    amica.components_[top_pos],
    titles=[f"Comp {i}" for i in top_pos],
    suptitle="Top 5 positive AMICA components for class 0",
)
plt.show()
Top 5 positive AMICA components for class 0, Comp 19, Comp 30, Comp 9, Comp 3, Comp 39
# Visualize the components pushing the class-0 decision downward.
imshow_row(
    amica.components_[top_neg],
    titles=[f"Comp {i}" for i in top_neg],
    suptitle="Top 5 negative AMICA components for class 0",
)
plt.show()
Top 5 negative AMICA components for class 0, Comp 0, Comp 59, Comp 16, Comp 22, Comp 57

Total running time of the script: (1 minute 49.319 seconds)

Gallery generated by Sphinx-Gallery