Predict hand pose from EMG signals using ML

Meta has been doing a lot of work trying to figure out how to replace keyboards and controllers with hand gestures (when the user is wearing a wrist band).

Back when I was a young hopper (AKA a PhD student), I interviewed with their Reality Labs team. The technical interview at the time was to take a dataset of EMG signals from a participant and predict that participant's hand pose (also, in nearly pure Python, i.e. no scikit-learn and no PyTorch. Woof).

Well now they’ve open-sourced the EMG dataset from that project, so I am going to save some soul out there some time and show them how to do it.

from collections.abc import KeysView
from dataclasses import dataclass
from pathlib import Path
from typing import Any, ClassVar

import h5py
import matplotlib.pyplot as plt
import mne
import numpy as np
import pooch
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

Create a helper class to read the data

@dataclass
class Emg2PoseSessionData:
    """A read-only interface to a single emg2pose session file stored in
    HDF5 format.

    ``self.timeseries`` is a `h5py.Dataset` instance with a compound data type
    as in a numpy structured array. It must contain the fields ``emg``,
    ``joint_angles`` and ``time`` (their presence is checked when the file is
    opened).
    The sampling rate of EMG is 2kHz, each EMG device has 16 electrode
    channels, and the signal has been high-pass filtered. Therefore, the
    ``emg`` field is a 2D array of shape ``(T, 16)`` and ``time`` is a 1D
    array of length ``T``.

    The HDF5 file stays open for the lifetime of the object. Release it with
    ``close()`` or by using the object as a context manager::

        with Emg2PoseSessionData(hdf5_path=path) as session:
            emg = session["emg"]

    NOTE: Only the metadata is loaded into memory while the timeseries is
    accessed directly from disk. When wrapping this interface within a
    PyTorch Dataset, use multiple dataloading workers to mask the disk seek
    and read latencies."""

    HDF5_GROUP: ClassVar[str] = "emg2pose"
    # timeseries keys
    TIMESERIES: ClassVar[str] = "timeseries"
    EMG: ClassVar[str] = "emg"
    JOINT_ANGLES: ClassVar[str] = "joint_angles"
    TIMESTAMPS: ClassVar[str] = "time"
    # metadata keys
    SESSION_NAME: ClassVar[str] = "session"
    SIDE: ClassVar[str] = "side"
    STAGE: ClassVar[str] = "stage"
    START_TIME: ClassVar[str] = "start"
    END_TIME: ClassVar[str] = "end"
    NUM_CHANNELS: ClassVar[str] = "num_channels"
    DATASET_NAME: ClassVar[str] = "dataset"
    USER: ClassVar[str] = "user"
    SAMPLE_RATE: ClassVar[str] = "sample_rate"

    hdf5_path: Path

    def __post_init__(self) -> None:
        """Open the HDF5 file, validate its layout and cache the metadata."""
        self._file = h5py.File(self.hdf5_path, "r")
        try:
            emg2pose_group: h5py.Group = self._file[self.HDF5_GROUP]

            # ``timeseries`` is a HDF5 compound Dataset
            self.timeseries: h5py.Dataset = emg2pose_group[self.TIMESERIES]
            dtype_fields = self.timeseries.dtype.fields
            assert dtype_fields is not None, "timeseries is not a compound dataset"
            for required in (self.EMG, self.JOINT_ANGLES, self.TIMESTAMPS):
                assert required in dtype_fields, f"Missing field: {required}"

            # Load the metadata entirely into memory as it's rather small
            self.metadata: dict[str, Any] = dict(emg2pose_group.attrs.items())
        except BaseException:
            # Do not leak the open file handle when the file turns out to be
            # malformed (missing group/dataset, unexpected dtype, ...).
            self._file.close()
            raise

    def close(self) -> None:
        """Close the underlying HDF5 file."""
        self._file.close()

    def __enter__(self) -> "Emg2PoseSessionData":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()

    def __len__(self) -> int:
        """Number of samples in ``timeseries``."""
        return len(self.timeseries)

    def __getitem__(self, key: slice) -> np.ndarray:
        """Index ``timeseries`` with ``key``.

        The key is forwarded unchanged to the dataset, so besides slices a
        field name (e.g. ``session["emg"]``) can be used as well."""
        return self.timeseries[key]

    def slice(self, start_t: float = -np.inf, end_t: float = np.inf) -> np.ndarray:
        """Load and return a contiguous slice of the timeseries windowed
        by the provided start and end timestamps.

        Args:
            start_t (float): The start time of the window to grab
                (in absolute unix time). Defaults to selecting from the
                beginning of the session. (default: ``-np.inf``).
            end_t (float): The end time of the window to grab
                (in absolute unix time). Defaults to selecting until the
                end of the session. (default: ``np.inf``)

        Returns:
            The samples whose timestamp ``t`` satisfies
            ``start_t <= t < end_t``.
        """
        # ``searchsorted`` yields the first index whose timestamp is >= the
        # query, which makes the window closed on the left and open on the
        # right.
        start_idx, end_idx = self.timestamps.searchsorted([start_t, end_t])
        return self[start_idx:end_idx]

    @property
    def fields(self) -> KeysView[str]:
        """The names of the fields in ``timeseries``."""
        return self.timeseries.dtype.fields.keys()

    @property
    def timestamps(self) -> np.ndarray:
        """EMG timestamps.

        NOTE: This reads the entire sequence of timestamps from the underlying
        HDF5 file and therefore incurs disk latency. Avoid this in the critical
        path."""
        emg_timestamps = self.timeseries[self.TIMESTAMPS]
        assert (np.diff(emg_timestamps) >= 0).all(), "Not monotonic"
        return emg_timestamps

    @property
    def session_name(self) -> str:
        """Unique name of the session."""
        return self.metadata[self.SESSION_NAME]

    @property
    def user(self) -> str:
        """Unique ID of the user this session corresponds to."""
        return self.metadata[self.USER]

    def __str__(self) -> str:
        return f"{self.__class__.__name__} {self.session_name} ({len(self)} samples)"

Download the dataset

data_dir = Path.home() / "emg_data"
want_fpath = "emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-1_left.hdf5"
# Give the downloaded archive an explicit name so that exactly this file (and
# no other ``*.tar`` that may live in ``data_dir``) is removed afterwards.
tar_fname = "emg2pose_dataset_mini.tar"

# Only the single recording used below is extracted from the tar archive.
unpack = pooch.Untar(extract_dir=data_dir, members=[want_fpath])

# NOTE(review): assumes pooch extracts ``want_fpath`` below ``extract_dir``.
# If that does not hold, the check below is simply False and the archive is
# downloaded as before.
emg_fname = data_dir / want_fpath
if emg_fname.exists():
    # Already extracted by a previous run: skip the large (647 MB) download.
    emg_fpaths = [str(emg_fname)]
else:
    emg_fpaths = pooch.retrieve(
        url="https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_dataset_mini.tar",
        known_hash="sha256:d7400e98508ccbb2139c2d78e552867b23501f637456546fd6680f3fe7fec50d",
        fname=tar_fname,
        progressbar=True,
        path=data_dir,
        processor=unpack,
    )
    emg_fname = Path(emg_fpaths[0])
    # Delete the large tar file; only the extracted recording is needed.
    (data_dir / tar_fname).unlink(missing_ok=True)
emg_dir = emg_fname.parent
  0%|                                               | 0.00/647M [00:00<?, ?B/s]
  1%|▍                                     | 6.61M/647M [00:00<00:14, 43.3MB/s]
  2%|▋                                     | 10.9M/647M [00:00<00:16, 37.7MB/s]
  3%|▉                                     | 16.8M/647M [00:00<00:17, 36.1MB/s]
  4%|█▎                                    | 23.4M/647M [00:00<00:15, 41.4MB/s]
  4%|█▌                                    | 27.6M/647M [00:00<00:16, 38.6MB/s]
  5%|█▉                                    | 32.2M/647M [00:00<00:15, 40.7MB/s]
  6%|██▏                                   | 36.4M/647M [00:00<00:15, 40.2MB/s]
  6%|██▍                                   | 41.9M/647M [00:01<00:18, 32.5MB/s]
  8%|██▊                                   | 48.5M/647M [00:01<00:18, 32.0MB/s]
  8%|███                                   | 51.9M/647M [00:01<00:20, 29.7MB/s]
  9%|███▎                                  | 56.9M/647M [00:01<00:25, 23.5MB/s]
  9%|███▍                                  | 59.6M/647M [00:01<00:26, 21.8MB/s]
 10%|███▊                                  | 65.3M/647M [00:02<00:29, 19.9MB/s]
 10%|███▉                                  | 67.5M/647M [00:02<00:29, 19.6MB/s]
 11%|████▎                                 | 74.1M/647M [00:02<00:21, 27.0MB/s]
 12%|████▌                                 | 77.2M/647M [00:02<00:20, 27.2MB/s]
 12%|████▋                                 | 80.2M/647M [00:02<00:21, 25.8MB/s]
 13%|████▉                                 | 84.2M/647M [00:02<00:19, 28.8MB/s]
 14%|█████▎                                | 90.5M/647M [00:03<00:18, 29.3MB/s]
 14%|█████▍                                | 93.6M/647M [00:03<00:21, 25.6MB/s]
 15%|█████▊                                | 98.9M/647M [00:03<00:19, 28.7MB/s]
 16%|██████▏                                | 102M/647M [00:03<00:24, 22.2MB/s]
 17%|██████▍                                | 107M/647M [00:03<00:24, 22.4MB/s]
 17%|██████▌                                | 110M/647M [00:04<00:24, 21.5MB/s]
 18%|██████▉                                | 116M/647M [00:04<00:27, 19.2MB/s]
 18%|███████                                | 118M/647M [00:04<00:28, 18.4MB/s]
 19%|███████▍                               | 124M/647M [00:04<00:24, 21.2MB/s]
 19%|███████▌                               | 126M/647M [00:04<00:27, 18.7MB/s]
 20%|███████▉                               | 132M/647M [00:05<00:25, 20.2MB/s]
 21%|████████                               | 134M/647M [00:05<00:28, 18.2MB/s]
 22%|████████▍                              | 141M/647M [00:05<00:19, 25.4MB/s]
 22%|████████▋                              | 144M/647M [00:05<00:22, 22.0MB/s]
 23%|█████████                              | 151M/647M [00:05<00:18, 26.2MB/s]
 25%|█████████▌                             | 159M/647M [00:06<00:14, 34.4MB/s]
 26%|██████████                             | 168M/647M [00:06<00:12, 37.0MB/s]
 27%|██████████▌                            | 174M/647M [00:06<00:13, 35.4MB/s]
 28%|██████████▋                            | 178M/647M [00:06<00:13, 34.8MB/s]
 29%|███████████                            | 185M/647M [00:06<00:13, 35.2MB/s]
 30%|███████████▋                           | 193M/647M [00:06<00:12, 36.8MB/s]
 31%|████████████                           | 200M/647M [00:07<00:11, 39.1MB/s]
 31%|████████████▎                          | 204M/647M [00:07<00:12, 36.3MB/s]
 32%|████████████▋                          | 210M/647M [00:07<00:12, 35.6MB/s]
 33%|█████████████                          | 216M/647M [00:07<00:11, 39.0MB/s]
 34%|█████████████▎                         | 220M/647M [00:07<00:13, 30.6MB/s]
 35%|█████████████▌                         | 225M/647M [00:07<00:15, 27.9MB/s]
 35%|█████████████▋                         | 228M/647M [00:08<00:16, 24.8MB/s]
 36%|██████████████                         | 233M/647M [00:08<00:15, 27.6MB/s]
 36%|██████████████▏                        | 236M/647M [00:08<00:16, 25.6MB/s]
 37%|██████████████▌                        | 242M/647M [00:08<00:12, 31.3MB/s]
 38%|██████████████▊                        | 245M/647M [00:08<00:13, 30.2MB/s]
 39%|███████████████                        | 250M/647M [00:08<00:16, 24.8MB/s]
 39%|███████████████▏                       | 253M/647M [00:09<00:16, 23.5MB/s]
 40%|███████████████▌                       | 259M/647M [00:09<00:16, 23.4MB/s]
 40%|███████████████▋                       | 261M/647M [00:09<00:16, 24.0MB/s]
 41%|████████████████                       | 267M/647M [00:09<00:15, 25.1MB/s]
 42%|████████████████▏                      | 269M/647M [00:09<00:17, 21.9MB/s]
 42%|████████████████▌                      | 275M/647M [00:10<00:15, 24.2MB/s]
 43%|████████████████▋                      | 277M/647M [00:10<00:17, 21.6MB/s]
 44%|█████████████████                      | 283M/647M [00:10<00:14, 24.6MB/s]
 44%|█████████████████▏                     | 286M/647M [00:10<00:14, 24.5MB/s]
 45%|█████████████████▋                     | 294M/647M [00:10<00:13, 26.8MB/s]
 46%|██████████████████                     | 300M/647M [00:11<00:13, 25.8MB/s]
 47%|██████████████████▏                    | 303M/647M [00:11<00:13, 24.8MB/s]
 48%|██████████████████▌                    | 309M/647M [00:11<00:11, 30.4MB/s]
 48%|██████████████████▊                    | 312M/647M [00:11<00:12, 27.3MB/s]
 49%|███████████████████▏                   | 319M/647M [00:11<00:11, 27.8MB/s]
 50%|███████████████████▌                   | 325M/647M [00:11<00:10, 30.5MB/s]
 51%|███████████████████▊                   | 328M/647M [00:12<00:12, 25.5MB/s]
 52%|████████████████████                   | 334M/647M [00:12<00:10, 28.9MB/s]
 52%|████████████████████▎                  | 337M/647M [00:12<00:11, 27.2MB/s]
 53%|████████████████████▋                  | 343M/647M [00:12<00:09, 33.5MB/s]
 53%|████████████████████▊                  | 346M/647M [00:12<00:10, 27.9MB/s]
 54%|█████████████████████▏                 | 351M/647M [00:12<00:09, 31.1MB/s]
 55%|█████████████████████▎                 | 354M/647M [00:12<00:10, 28.3MB/s]
 55%|█████████████████████▌                 | 358M/647M [00:12<00:09, 29.4MB/s]
 56%|█████████████████████▋                 | 361M/647M [00:13<00:09, 29.4MB/s]
 57%|██████████████████████▏                | 368M/647M [00:13<00:08, 31.3MB/s]
 57%|██████████████████████▎                | 371M/647M [00:13<00:10, 25.5MB/s]
 58%|██████████████████████▋                | 376M/647M [00:13<00:11, 23.0MB/s]
 58%|██████████████████████▊                | 378M/647M [00:13<00:12, 21.8MB/s]
 59%|███████████████████████▏               | 384M/647M [00:14<00:10, 24.7MB/s]
 60%|███████████████████████▎               | 387M/647M [00:14<00:10, 23.9MB/s]
 61%|███████████████████████▋               | 392M/647M [00:14<00:10, 23.7MB/s]
 61%|███████████████████████▊               | 395M/647M [00:14<00:12, 20.1MB/s]
 62%|████████████████████████▏              | 401M/647M [00:14<00:10, 24.6MB/s]
 62%|████████████████████████▎              | 403M/647M [00:14<00:10, 23.9MB/s]
 63%|████████████████████████▋              | 409M/647M [00:15<00:12, 19.8MB/s]
 64%|████████████████████████▊              | 411M/647M [00:15<00:12, 18.7MB/s]
 65%|█████████████████████████▏             | 418M/647M [00:15<00:09, 25.1MB/s]
 65%|█████████████████████████▎             | 420M/647M [00:15<00:09, 24.0MB/s]
 66%|█████████████████████████▋             | 426M/647M [00:15<00:08, 25.4MB/s]
 66%|█████████████████████████▊             | 429M/647M [00:16<00:09, 23.7MB/s]
 67%|██████████████████████████▏            | 434M/647M [00:16<00:07, 27.6MB/s]
 68%|██████████████████████████▎            | 437M/647M [00:16<00:08, 24.6MB/s]
 68%|██████████████████████████▋            | 443M/647M [00:16<00:07, 28.3MB/s]
 69%|██████████████████████████▉            | 448M/647M [00:16<00:06, 33.0MB/s]
 70%|███████████████████████████▎           | 453M/647M [00:16<00:07, 26.9MB/s]
 71%|███████████████████████████▋           | 460M/647M [00:17<00:08, 21.0MB/s]
 71%|███████████████████████████▊           | 462M/647M [00:17<00:10, 18.5MB/s]
 72%|████████████████████████████▏          | 468M/647M [00:17<00:09, 19.1MB/s]
 73%|████████████████████████████▎          | 471M/647M [00:17<00:09, 18.9MB/s]
 74%|████████████████████████████▋          | 476M/647M [00:18<00:07, 22.5MB/s]
 74%|████████████████████████████▊          | 479M/647M [00:18<00:07, 21.2MB/s]
 75%|█████████████████████████████▎         | 486M/647M [00:18<00:05, 29.8MB/s]
 76%|█████████████████████████████▍         | 489M/647M [00:18<00:05, 28.1MB/s]
 76%|█████████████████████████████▊         | 495M/647M [00:18<00:05, 27.4MB/s]
 78%|██████████████████████████████▎        | 503M/647M [00:18<00:04, 32.8MB/s]
 79%|██████████████████████████████▊        | 512M/647M [00:19<00:04, 33.2MB/s]
 80%|███████████████████████████████▏       | 518M/647M [00:19<00:03, 36.2MB/s]
 81%|███████████████████████████████▍       | 522M/647M [00:19<00:03, 32.4MB/s]
 81%|███████████████████████████████▋       | 526M/647M [00:19<00:03, 33.4MB/s]
 82%|███████████████████████████████▉       | 529M/647M [00:19<00:04, 28.9MB/s]
 83%|████████████████████████████████▎      | 535M/647M [00:19<00:03, 35.8MB/s]
 83%|████████████████████████████████▌      | 539M/647M [00:19<00:03, 33.8MB/s]
 84%|████████████████████████████████▋      | 543M/647M [00:20<00:04, 23.4MB/s]
 84%|████████████████████████████████▉      | 546M/647M [00:20<00:04, 21.6MB/s]
 85%|█████████████████████████████████▎     | 552M/647M [00:20<00:03, 24.3MB/s]
 86%|█████████████████████████████████▍     | 555M/647M [00:20<00:04, 21.9MB/s]
 87%|█████████████████████████████████▊     | 560M/647M [00:21<00:04, 21.5MB/s]
 87%|█████████████████████████████████▉     | 563M/647M [00:21<00:04, 20.8MB/s]
 88%|██████████████████████████████████▎    | 569M/647M [00:21<00:02, 26.7MB/s]
 88%|██████████████████████████████████▍    | 572M/647M [00:21<00:03, 23.6MB/s]
 89%|██████████████████████████████████▊    | 577M/647M [00:21<00:02, 27.9MB/s]
 90%|███████████████████████████████████    | 581M/647M [00:21<00:02, 31.1MB/s]
 91%|███████████████████████████████████▍   | 587M/647M [00:21<00:01, 30.6MB/s]
 92%|███████████████████████████████████▊   | 594M/647M [00:22<00:01, 30.7MB/s]
 92%|███████████████████████████████████▉   | 597M/647M [00:22<00:01, 30.7MB/s]
 93%|████████████████████████████████████▎  | 602M/647M [00:22<00:01, 23.9MB/s]
 93%|████████████████████████████████████▍  | 605M/647M [00:22<00:01, 24.1MB/s]
 95%|████████████████████████████████████▉  | 612M/647M [00:22<00:01, 34.0MB/s]
 95%|█████████████████████████████████████▏ | 616M/647M [00:22<00:00, 32.9MB/s]
 96%|█████████████████████████████████████▎ | 620M/647M [00:23<00:00, 31.5MB/s]
 96%|█████████████████████████████████████▌ | 624M/647M [00:23<00:00, 24.9MB/s]
 97%|█████████████████████████████████████▊ | 628M/647M [00:23<00:00, 26.4MB/s]
 97%|██████████████████████████████████████ | 631M/647M [00:23<00:00, 21.0MB/s]
 98%|██████████████████████████████████████▎| 636M/647M [00:23<00:00, 27.2MB/s]
 99%|██████████████████████████████████████▌| 639M/647M [00:23<00:00, 25.1MB/s]
100%|██████████████████████████████████████▊| 644M/647M [00:24<00:00, 26.7MB/s]
100%|██████████████████████████████████████▉| 647M/647M [00:24<00:00, 26.8MB/s]
  0%|                                               | 0.00/647M [00:00<?, ?B/s]
100%|███████████████████████████████████████| 647M/647M [00:00<00:00, 2.71TB/s]

Load the data

# Open the session file read-only; the timeseries is read from disk on access.
data = Emg2PoseSessionData(emg_fname)

Visualize the data

We’ll let MNE-Python do the heavy lifting for us here.

# Read the EMG samples once: every ``data["emg"]`` access goes through the
# HDF5 dataset again, so keep the array around instead of re-reading it.
emg = data["emg"]
ch_names = [f"EMG{ii:02}" for ii in range(1, len(emg.T) + 1)]
ch_types = ["emg"] * len(ch_names)
sfreq = data.metadata[Emg2PoseSessionData.SAMPLE_RATE]
info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
# MNE expects data in the shape (n_channels, n_times). So we need to transpose the data
raw = mne.io.RawArray(emg.T, info)
# MNE expects the EMG data to be in Volts, so we need to rescale it.
# NOTE(review): the factor 1e-6 corresponds to µV -> V (the original comment
# said mV) - confirm the unit of the recorded data.
raw.apply_function(lambda x: x * 1e-6, picks="emg")
raw.plot(start=20, duration=20)
plot emg pose
Creating RawArray with float64 data, n_channels=16, n_times=142674
    Range : 0 ... 142673 =      0.000 ...    71.337 secs
Ready.
Using matplotlib as 2D backend.

<MNEBrowseFigure size 800x800 with 4 Axes>

Use PCA and KMeans to cluster the data

We’ll use PCA to reduce the data dimensionality to 3D and then use KMeans to cluster the data.

n_components = 3
n_clusters = 5
# Fix the seed so that the cluster assignment (and therefore the colours in
# the plot below) is the same on every run; KMeans is randomly initialised.
random_state = 0

# Project the EMG channels onto the first ``n_components`` principal components ...
pca = PCA(n_components=n_components, random_state=random_state)
data_pca = pca.fit_transform(data["emg"])
# ... and group the projected samples into ``n_clusters`` clusters.
clusters = KMeans(n_clusters=n_clusters, random_state=random_state).fit_predict(data_pca)

Visualize the clusters

# Scatter the first three principal components, coloured by cluster label.
sns.set_theme(style="darkgrid")
fig = plt.figure()
ax = fig.add_subplot(projection="3d")
xs, ys, zs = (data_pca[:, col] for col in range(3))
ax.scatter(xs, ys, zs, c=clusters)
ax.set(
    xlabel="PC1",
    ylabel="PC2",
    zlabel="PC3",
    title="PCA of EMG data with KMeans clustering",
)
plt.show()
PCA of EMG data with KMeans clustering

Total running time of the script: (0 minutes 31.031 seconds)

Gallery generated by Sphinx-Gallery