Note
Go to the end to download the full example code.
Predict hand pose from EMG signals using ML¶
Meta has been doing a lot of work trying to figure out how to replace keyboards and controllers with hand gestures (when the user is wearing a wrist band).
Back when I was a young hopper (AKA a PhD student), I interviewed with their Reality Labs team. The technical interview at the time was to take a dataset of EMG signals from a participant and predict the participant's hand pose (in nearly pure Python, too, i.e. no scikit-learn and no PyTorch. Woof).
Well, now they’ve open-sourced the EMG dataset from that project, so I am going to save some soul out there some time and show them how to do it.
from collections.abc import KeysView
from dataclasses import dataclass
from pathlib import Path
from typing import Any, ClassVar
import h5py
import matplotlib.pyplot as plt
import mne
import numpy as np
import pooch
import seaborn as sns
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
Create a helper class to read the data¶
@dataclass
class Emg2PoseSessionData:
    """A read-only interface to a single emg2pose session file stored in
    HDF5 format.

    ``self.timeseries`` is a `h5py.Dataset` instance with a compound data type
    (as in a numpy structured array). It must contain the fields ``emg``,
    ``joint_angles`` and ``time``; this is checked when the file is opened.

    The sampling rate of EMG is 2kHz, each EMG device has 16 electrode
    channels, and the signal has been high-pass filtered. Therefore, the
    ``emg`` field is a 2D array of shape ``(T, 16)`` and ``time`` is a 1D
    array of length ``T``.

    The underlying file stays open until ``__exit__`` runs, so prefer using
    the instance as a context manager (``with Emg2PoseSessionData(...) as s``).

    NOTE: Only the metadata is loaded into memory up front while the
    timeseries is accessed directly from disk. When wrapping this interface
    within a PyTorch Dataset, use multiple dataloading workers to mask the
    disk seek and read latencies."""

    HDF5_GROUP: ClassVar[str] = "emg2pose"

    # timeseries keys
    TIMESERIES: ClassVar[str] = "timeseries"
    EMG: ClassVar[str] = "emg"
    JOINT_ANGLES: ClassVar[str] = "joint_angles"
    TIMESTAMPS: ClassVar[str] = "time"

    # metadata keys
    SESSION_NAME: ClassVar[str] = "session"
    SIDE: ClassVar[str] = "side"
    STAGE: ClassVar[str] = "stage"
    START_TIME: ClassVar[str] = "start"
    END_TIME: ClassVar[str] = "end"
    NUM_CHANNELS: ClassVar[str] = "num_channels"
    DATASET_NAME: ClassVar[str] = "dataset"
    USER: ClassVar[str] = "user"
    SAMPLE_RATE: ClassVar[str] = "sample_rate"

    hdf5_path: Path

    def __post_init__(self) -> None:
        """Open the HDF5 file, validate its layout and cache the metadata.

        Raises:
            AssertionError: If ``timeseries`` is not a compound dataset or
                lacks one of the required fields.
            KeyError: If the expected group or dataset is missing (raised by
                the h5py lookup).
        """
        self._file = h5py.File(self.hdf5_path, "r")
        try:
            emg2pose_group: h5py.Group = self._file[self.HDF5_GROUP]

            # ``timeseries`` is a HDF5 compound Dataset
            self.timeseries: h5py.Dataset = emg2pose_group[self.TIMESERIES]
            # Asserts are kept (rather than raising another exception type)
            # so that the failure type seen by existing callers is unchanged.
            assert self.timeseries.dtype.fields is not None
            assert self.EMG in self.timeseries.dtype.fields
            assert self.JOINT_ANGLES in self.timeseries.dtype.fields
            assert self.TIMESTAMPS in self.timeseries.dtype.fields

            # Load the metadata entirely into memory as it's rather small
            self.metadata: dict[str, Any] = dict(emg2pose_group.attrs.items())
        except BaseException:
            # Do not leave the file handle open when the file turns out to
            # be unusable; the original error is re-raised untouched.
            self._file.close()
            raise

    def __enter__(self) -> "Emg2PoseSessionData":
        """Return the session itself so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the underlying HDF5 file; exceptions are not suppressed."""
        self._file.close()

    def __len__(self) -> int:
        """Number of rows in ``timeseries``."""
        return len(self.timeseries)

    def __getitem__(self, key: "slice | str") -> np.ndarray:
        """Index ``timeseries`` directly.

        Args:
            key: Forwarded unchanged to ``self.timeseries[key]``. A slice
                selects a range of rows; a field name such as ``"emg"``
                selects that field.
        """
        return self.timeseries[key]

    def slice(self, start_t: float = -np.inf, end_t: float = np.inf) -> np.ndarray:
        """Load and return a contiguous slice of the timeseries windowed
        by the provided start and end timestamps.

        The window is half-open: rows with ``start_t <= time < end_t`` are
        returned.

        Args:
            start_t (float): The start time of the window to grab
                (in absolute unix time). Defaults to selecting from the
                beginning of the session. (default: ``-np.inf``).
            end_t (float): The end time of the window to grab
                (in absolute unix time). Defaults to selecting until the
                end of the session. (default: ``np.inf``)
        """
        start_idx, end_idx = self.timestamps.searchsorted([start_t, end_t])
        return self[start_idx:end_idx]

    @property
    def fields(self) -> KeysView[str]:
        """The names of the fields in ``timeseries``."""
        fields: KeysView[str] = self.timeseries.dtype.fields.keys()
        return fields

    @property
    def timestamps(self) -> np.ndarray:
        """EMG timestamps.

        NOTE: This reads the entire sequence of timestamps from the underlying
        HDF5 file and therefore incurs disk latency. Avoid this in the critical
        path.

        Raises:
            AssertionError: If the timestamps are not non-decreasing.
        """
        emg_timestamps = self.timeseries[self.TIMESTAMPS]
        assert (np.diff(emg_timestamps) >= 0).all(), "Not monotonic"
        return emg_timestamps

    @property
    def session_name(self) -> str:
        """Unique name of the session."""
        return self.metadata[self.SESSION_NAME]

    @property
    def user(self) -> str:
        """Unique ID of the user this session corresponds to."""
        return self.metadata[self.USER]

    def __str__(self) -> str:
        return f"{self.__class__.__name__} {self.session_name} ({len(self)} samples)"
Download the dataset¶
# Everything is downloaded to, and unpacked under, ~/emg_data.
data_dir = Path.home() / "emg_data"
# The single session file we want out of the archive (path inside the tar).
want_fpath = "emg2pose_dataset_mini/2022-12-06-1670313600-e3096-cv-emg-pose-train@2-recording-1_left.hdf5"
# Store the archive under a fixed name so that exactly this file (and no
# other ``*.tar`` that happens to live in ``data_dir``) is removed afterwards.
archive_name = "emg2pose_dataset_mini.tar"
unpack = pooch.Untar(
    extract_dir=data_dir,  # Relative to the path where the tar file is downloaded
    members=[want_fpath],
)
emg_fpaths = pooch.retrieve(
    url="https://fb-ctrl-oss.s3.amazonaws.com/emg2pose/emg2pose_dataset_mini.tar",
    known_hash="sha256:d7400e98508ccbb2139c2d78e552867b23501f637456546fd6680f3fe7fec50d",
    fname=archive_name,
    progressbar=True,
    path=data_dir,
    processor=unpack,
)
# With a processor, ``retrieve`` returns the extracted member paths.
emg_fname = Path(emg_fpaths[0])
emg_dir = emg_fname.parent
# Delete the large tar file; only the extracted session file is needed.
(data_dir / archive_name).unlink(missing_ok=True)
0%| | 0.00/647M [00:00<?, ?B/s]
1%|▍ | 6.61M/647M [00:00<00:14, 43.3MB/s]
2%|▋ | 10.9M/647M [00:00<00:16, 37.7MB/s]
3%|▉ | 16.8M/647M [00:00<00:17, 36.1MB/s]
4%|█▎ | 23.4M/647M [00:00<00:15, 41.4MB/s]
4%|█▌ | 27.6M/647M [00:00<00:16, 38.6MB/s]
5%|█▉ | 32.2M/647M [00:00<00:15, 40.7MB/s]
6%|██▏ | 36.4M/647M [00:00<00:15, 40.2MB/s]
6%|██▍ | 41.9M/647M [00:01<00:18, 32.5MB/s]
8%|██▊ | 48.5M/647M [00:01<00:18, 32.0MB/s]
8%|███ | 51.9M/647M [00:01<00:20, 29.7MB/s]
9%|███▎ | 56.9M/647M [00:01<00:25, 23.5MB/s]
9%|███▍ | 59.6M/647M [00:01<00:26, 21.8MB/s]
10%|███▊ | 65.3M/647M [00:02<00:29, 19.9MB/s]
10%|███▉ | 67.5M/647M [00:02<00:29, 19.6MB/s]
11%|████▎ | 74.1M/647M [00:02<00:21, 27.0MB/s]
12%|████▌ | 77.2M/647M [00:02<00:20, 27.2MB/s]
12%|████▋ | 80.2M/647M [00:02<00:21, 25.8MB/s]
13%|████▉ | 84.2M/647M [00:02<00:19, 28.8MB/s]
14%|█████▎ | 90.5M/647M [00:03<00:18, 29.3MB/s]
14%|█████▍ | 93.6M/647M [00:03<00:21, 25.6MB/s]
15%|█████▊ | 98.9M/647M [00:03<00:19, 28.7MB/s]
16%|██████▏ | 102M/647M [00:03<00:24, 22.2MB/s]
17%|██████▍ | 107M/647M [00:03<00:24, 22.4MB/s]
17%|██████▌ | 110M/647M [00:04<00:24, 21.5MB/s]
18%|██████▉ | 116M/647M [00:04<00:27, 19.2MB/s]
18%|███████ | 118M/647M [00:04<00:28, 18.4MB/s]
19%|███████▍ | 124M/647M [00:04<00:24, 21.2MB/s]
19%|███████▌ | 126M/647M [00:04<00:27, 18.7MB/s]
20%|███████▉ | 132M/647M [00:05<00:25, 20.2MB/s]
21%|████████ | 134M/647M [00:05<00:28, 18.2MB/s]
22%|████████▍ | 141M/647M [00:05<00:19, 25.4MB/s]
22%|████████▋ | 144M/647M [00:05<00:22, 22.0MB/s]
23%|█████████ | 151M/647M [00:05<00:18, 26.2MB/s]
25%|█████████▌ | 159M/647M [00:06<00:14, 34.4MB/s]
26%|██████████ | 168M/647M [00:06<00:12, 37.0MB/s]
27%|██████████▌ | 174M/647M [00:06<00:13, 35.4MB/s]
28%|██████████▋ | 178M/647M [00:06<00:13, 34.8MB/s]
29%|███████████ | 185M/647M [00:06<00:13, 35.2MB/s]
30%|███████████▋ | 193M/647M [00:06<00:12, 36.8MB/s]
31%|████████████ | 200M/647M [00:07<00:11, 39.1MB/s]
31%|████████████▎ | 204M/647M [00:07<00:12, 36.3MB/s]
32%|████████████▋ | 210M/647M [00:07<00:12, 35.6MB/s]
33%|█████████████ | 216M/647M [00:07<00:11, 39.0MB/s]
34%|█████████████▎ | 220M/647M [00:07<00:13, 30.6MB/s]
35%|█████████████▌ | 225M/647M [00:07<00:15, 27.9MB/s]
35%|█████████████▋ | 228M/647M [00:08<00:16, 24.8MB/s]
36%|██████████████ | 233M/647M [00:08<00:15, 27.6MB/s]
36%|██████████████▏ | 236M/647M [00:08<00:16, 25.6MB/s]
37%|██████████████▌ | 242M/647M [00:08<00:12, 31.3MB/s]
38%|██████████████▊ | 245M/647M [00:08<00:13, 30.2MB/s]
39%|███████████████ | 250M/647M [00:08<00:16, 24.8MB/s]
39%|███████████████▏ | 253M/647M [00:09<00:16, 23.5MB/s]
40%|███████████████▌ | 259M/647M [00:09<00:16, 23.4MB/s]
40%|███████████████▋ | 261M/647M [00:09<00:16, 24.0MB/s]
41%|████████████████ | 267M/647M [00:09<00:15, 25.1MB/s]
42%|████████████████▏ | 269M/647M [00:09<00:17, 21.9MB/s]
42%|████████████████▌ | 275M/647M [00:10<00:15, 24.2MB/s]
43%|████████████████▋ | 277M/647M [00:10<00:17, 21.6MB/s]
44%|█████████████████ | 283M/647M [00:10<00:14, 24.6MB/s]
44%|█████████████████▏ | 286M/647M [00:10<00:14, 24.5MB/s]
45%|█████████████████▋ | 294M/647M [00:10<00:13, 26.8MB/s]
46%|██████████████████ | 300M/647M [00:11<00:13, 25.8MB/s]
47%|██████████████████▏ | 303M/647M [00:11<00:13, 24.8MB/s]
48%|██████████████████▌ | 309M/647M [00:11<00:11, 30.4MB/s]
48%|██████████████████▊ | 312M/647M [00:11<00:12, 27.3MB/s]
49%|███████████████████▏ | 319M/647M [00:11<00:11, 27.8MB/s]
50%|███████████████████▌ | 325M/647M [00:11<00:10, 30.5MB/s]
51%|███████████████████▊ | 328M/647M [00:12<00:12, 25.5MB/s]
52%|████████████████████ | 334M/647M [00:12<00:10, 28.9MB/s]
52%|████████████████████▎ | 337M/647M [00:12<00:11, 27.2MB/s]
53%|████████████████████▋ | 343M/647M [00:12<00:09, 33.5MB/s]
53%|████████████████████▊ | 346M/647M [00:12<00:10, 27.9MB/s]
54%|█████████████████████▏ | 351M/647M [00:12<00:09, 31.1MB/s]
55%|█████████████████████▎ | 354M/647M [00:12<00:10, 28.3MB/s]
55%|█████████████████████▌ | 358M/647M [00:12<00:09, 29.4MB/s]
56%|█████████████████████▋ | 361M/647M [00:13<00:09, 29.4MB/s]
57%|██████████████████████▏ | 368M/647M [00:13<00:08, 31.3MB/s]
57%|██████████████████████▎ | 371M/647M [00:13<00:10, 25.5MB/s]
58%|██████████████████████▋ | 376M/647M [00:13<00:11, 23.0MB/s]
58%|██████████████████████▊ | 378M/647M [00:13<00:12, 21.8MB/s]
59%|███████████████████████▏ | 384M/647M [00:14<00:10, 24.7MB/s]
60%|███████████████████████▎ | 387M/647M [00:14<00:10, 23.9MB/s]
61%|███████████████████████▋ | 392M/647M [00:14<00:10, 23.7MB/s]
61%|███████████████████████▊ | 395M/647M [00:14<00:12, 20.1MB/s]
62%|████████████████████████▏ | 401M/647M [00:14<00:10, 24.6MB/s]
62%|████████████████████████▎ | 403M/647M [00:14<00:10, 23.9MB/s]
63%|████████████████████████▋ | 409M/647M [00:15<00:12, 19.8MB/s]
64%|████████████████████████▊ | 411M/647M [00:15<00:12, 18.7MB/s]
65%|█████████████████████████▏ | 418M/647M [00:15<00:09, 25.1MB/s]
65%|█████████████████████████▎ | 420M/647M [00:15<00:09, 24.0MB/s]
66%|█████████████████████████▋ | 426M/647M [00:15<00:08, 25.4MB/s]
66%|█████████████████████████▊ | 429M/647M [00:16<00:09, 23.7MB/s]
67%|██████████████████████████▏ | 434M/647M [00:16<00:07, 27.6MB/s]
68%|██████████████████████████▎ | 437M/647M [00:16<00:08, 24.6MB/s]
68%|██████████████████████████▋ | 443M/647M [00:16<00:07, 28.3MB/s]
69%|██████████████████████████▉ | 448M/647M [00:16<00:06, 33.0MB/s]
70%|███████████████████████████▎ | 453M/647M [00:16<00:07, 26.9MB/s]
71%|███████████████████████████▋ | 460M/647M [00:17<00:08, 21.0MB/s]
71%|███████████████████████████▊ | 462M/647M [00:17<00:10, 18.5MB/s]
72%|████████████████████████████▏ | 468M/647M [00:17<00:09, 19.1MB/s]
73%|████████████████████████████▎ | 471M/647M [00:17<00:09, 18.9MB/s]
74%|████████████████████████████▋ | 476M/647M [00:18<00:07, 22.5MB/s]
74%|████████████████████████████▊ | 479M/647M [00:18<00:07, 21.2MB/s]
75%|█████████████████████████████▎ | 486M/647M [00:18<00:05, 29.8MB/s]
76%|█████████████████████████████▍ | 489M/647M [00:18<00:05, 28.1MB/s]
76%|█████████████████████████████▊ | 495M/647M [00:18<00:05, 27.4MB/s]
78%|██████████████████████████████▎ | 503M/647M [00:18<00:04, 32.8MB/s]
79%|██████████████████████████████▊ | 512M/647M [00:19<00:04, 33.2MB/s]
80%|███████████████████████████████▏ | 518M/647M [00:19<00:03, 36.2MB/s]
81%|███████████████████████████████▍ | 522M/647M [00:19<00:03, 32.4MB/s]
81%|███████████████████████████████▋ | 526M/647M [00:19<00:03, 33.4MB/s]
82%|███████████████████████████████▉ | 529M/647M [00:19<00:04, 28.9MB/s]
83%|████████████████████████████████▎ | 535M/647M [00:19<00:03, 35.8MB/s]
83%|████████████████████████████████▌ | 539M/647M [00:19<00:03, 33.8MB/s]
84%|████████████████████████████████▋ | 543M/647M [00:20<00:04, 23.4MB/s]
84%|████████████████████████████████▉ | 546M/647M [00:20<00:04, 21.6MB/s]
85%|█████████████████████████████████▎ | 552M/647M [00:20<00:03, 24.3MB/s]
86%|█████████████████████████████████▍ | 555M/647M [00:20<00:04, 21.9MB/s]
87%|█████████████████████████████████▊ | 560M/647M [00:21<00:04, 21.5MB/s]
87%|█████████████████████████████████▉ | 563M/647M [00:21<00:04, 20.8MB/s]
88%|██████████████████████████████████▎ | 569M/647M [00:21<00:02, 26.7MB/s]
88%|██████████████████████████████████▍ | 572M/647M [00:21<00:03, 23.6MB/s]
89%|██████████████████████████████████▊ | 577M/647M [00:21<00:02, 27.9MB/s]
90%|███████████████████████████████████ | 581M/647M [00:21<00:02, 31.1MB/s]
91%|███████████████████████████████████▍ | 587M/647M [00:21<00:01, 30.6MB/s]
92%|███████████████████████████████████▊ | 594M/647M [00:22<00:01, 30.7MB/s]
92%|███████████████████████████████████▉ | 597M/647M [00:22<00:01, 30.7MB/s]
93%|████████████████████████████████████▎ | 602M/647M [00:22<00:01, 23.9MB/s]
93%|████████████████████████████████████▍ | 605M/647M [00:22<00:01, 24.1MB/s]
95%|████████████████████████████████████▉ | 612M/647M [00:22<00:01, 34.0MB/s]
95%|█████████████████████████████████████▏ | 616M/647M [00:22<00:00, 32.9MB/s]
96%|█████████████████████████████████████▎ | 620M/647M [00:23<00:00, 31.5MB/s]
96%|█████████████████████████████████████▌ | 624M/647M [00:23<00:00, 24.9MB/s]
97%|█████████████████████████████████████▊ | 628M/647M [00:23<00:00, 26.4MB/s]
97%|██████████████████████████████████████ | 631M/647M [00:23<00:00, 21.0MB/s]
98%|██████████████████████████████████████▎| 636M/647M [00:23<00:00, 27.2MB/s]
99%|██████████████████████████████████████▌| 639M/647M [00:23<00:00, 25.1MB/s]
100%|██████████████████████████████████████▊| 644M/647M [00:24<00:00, 26.7MB/s]
100%|██████████████████████████████████████▉| 647M/647M [00:24<00:00, 26.8MB/s]
0%| | 0.00/647M [00:00<?, ?B/s]
100%|███████████████████████████████████████| 647M/647M [00:00<00:00, 2.71TB/s]
Load the data¶
# Open the session file. The HDF5 handle stays open for the rest of the
# script (it is only closed by ``__exit__``), so ``data[...]`` reads from disk.
data = Emg2PoseSessionData(hdf5_path=emg_fname)
Visualize the data¶
We’ll let MNE-Python do the heavy lifting for us here.
# Read the EMG field from disk once and reuse it below, instead of pulling
# the whole array out of the HDF5 file a second time just to count channels.
emg = data["emg"]
# One name per channel (column of the EMG array): EMG01, EMG02, ...
ch_names = [f"EMG{ii:02}" for ii in range(1, emg.shape[1] + 1)]
ch_types = ["emg"] * len(ch_names)
sfreq = data.metadata[Emg2PoseSessionData.SAMPLE_RATE]
info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq)
# MNE expects data in the shape (n_channels, n_times). So we need to transpose the data
raw = mne.io.RawArray(emg.T, info)
# MNE expects the EMG data to be in Volts. The 1e-6 factor treats the stored
# values as microvolts (millivolts would need 1e-3).
# NOTE(review): confirm the dataset's units against the emg2pose documentation.
raw.apply_function(lambda x: x * 1e-6, picks="emg")
raw.plot(start=20, duration=20)

Creating RawArray with float64 data, n_channels=16, n_times=142674
Range : 0 ... 142673 = 0.000 ... 71.337 secs
Ready.
Using matplotlib as 2D backend.
<MNEBrowseFigure size 800x800 with 4 Axes>
Use PCA and KMeans to cluster the data¶
We’ll use PCA to reduce the data dimensionality to 3D and then use KMeans to cluster the data.
# Project the EMG channels onto their first three principal components.
n_components = 3
pca = PCA(n_components=n_components)
data_pca = pca.fit_transform(data["emg"])
# Seed KMeans so that cluster assignments (and therefore the colours in the
# plot) are the same every time the example is run.
clusters = KMeans(n_clusters=5, random_state=0).fit_predict(data_pca)
Visualize the clusters¶
# 3-D scatter of the PCA-projected samples, coloured by KMeans cluster.
sns.set_theme(style="darkgrid")
fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
# One coordinate array per principal component (columns of ``data_pca``).
pc1, pc2, pc3 = data_pca.T
ax.scatter(pc1, pc2, pc3, c=clusters)
ax.set(
    xlabel="PC1",
    ylabel="PC2",
    zlabel="PC3",
    title="PCA of EMG data with KMeans clustering",
)
plt.show()

Total running time of the script: (0 minutes 31.031 seconds)