Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions examples/wav2sleep_README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Wav2Sleep PyHealth Contribution

**Author:** Meredith McClain (mmcclan2)
**Paper:** wav2sleep: A Unified Multi-Modal Approach to Sleep Stage Classification from Physiological Signals
**Link:** https://arxiv.org/abs/2411.04644

## Overview

This contribution implements the wav2sleep model for PyHealth - a unified multi-modal approach to sleep stage classification that can operate on variable sets of physiological signals.

### Key Features

- **Multi-modal Architecture**: Processes ECG, PPG, and respiratory signals (ABD, THX)
- **Variable Input Modalities**: Supports any subset of signals at inference time
- **Joint Training**: Can train on heterogeneous datasets with different signal availability
- **State-of-the-art Performance**: Outperforms single-modality and transfer learning approaches

### Model Architecture

```
Input Signals (ECG, PPG, ABD, THX)
Signal Encoders (CNN per modality)
Epoch Mixer (Transformer for cross-modal fusion)
Sequence Mixer (Dilated CNN for temporal modeling)
Sleep Stage Predictions (Wake, N1, N2, N3, REM)
```

## Installation

```bash
pip install torch numpy
```

## Quick Start

```python
from wav2sleep_pyhealth import Wav2Sleep
import torch

# Define modalities and sampling rates
modalities = {
"ecg": 1024, # 34 Hz * 30 seconds
"ppg": 1024,
"thx": 256 # 8 Hz * 30 seconds
}

# Create model
model = Wav2Sleep(
modalities=modalities,
num_classes=5,
feature_dim=128
)

# Example: 10 hours of data (1200 epochs of 30 seconds)
batch_size = 8
T = 1200

# Training with multiple modalities
inputs = {
"ecg": torch.randn(batch_size, 1, T * 1024),
"ppg": torch.randn(batch_size, 1, T * 1024),
"thx": torch.randn(batch_size, 1, T * 256)
}
labels = torch.randint(0, 5, (batch_size, T))

output = model(inputs, labels)
print(f"Loss: {output['loss'].item():.4f}")

# Inference with subset of modalities
inputs_ecg_only = {"ecg": torch.randn(batch_size, 1, T * 1024)}
probs = model.predict_proba(inputs_ecg_only)
```

## Model Components

### 1. Signal Encoders

Separate CNN encoders for each modality:
- Residual blocks with instance normalization
- Progressive downsampling via max pooling
- Outputs fixed-dimensional features per epoch

### 2. Epoch Mixer

Transformer encoder for cross-modal fusion:
- Uses CLS token to aggregate multi-modal information
- Handles variable number of input modalities
- Produces unified representation per epoch

### 3. Sequence Mixer

Dilated CNN for temporal modeling:
- Exponentially increasing dilation rates (1, 2, 4, 8, 16, 32)
- Large receptive field for long-range dependencies
- Outputs sleep stage classifications

## Usage Examples

### Training on Multiple Datasets

```python
# Joint training with heterogeneous data
for batch in dataloader:
# Some samples may have different available signals
inputs = batch['signals'] # Dict with available modalities
labels = batch['labels']

output = model(inputs, labels)
loss = output['loss']

loss.backward()
optimizer.step()
```

### Inference with Different Modalities

```python
# Use all available signals
inputs_full = {"ecg": ecg_data, "ppg": ppg_data, "thx": thx_data}
predictions_full = model(inputs_full)['predictions']

# Use only ECG (e.g., if PPG sensor fails)
inputs_ecg = {"ecg": ecg_data}
predictions_ecg = model(inputs_ecg)['predictions']
```

## Model Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `modalities` | Required | Dict mapping signal names to sampling rates |
| `num_classes` | 5 | Number of sleep stages (Wake, N1, N2, N3, REM) |
| `feature_dim` | 128 | Feature dimension throughout model |
| `dropout` | 0.1 | Dropout probability |

## Expected Performance

Based on the original paper, wav2sleep achieves:

| Dataset | Test Modality | Cohen's κ | Accuracy |
|---------|--------------|-----------|----------|
| SHHS | ECG only | 0.739 | 82.3% |
| SHHS | ECG + THX | 0.779 | 85.0% |
| MESA | PPG only | 0.742 | - |
| Census | ECG only | 0.783 | 84.8% |

## Validation

This implementation was validated using the Sleep-EDF database from PhysioNet, a publicly-available polysomnography dataset with real overnight sleep recordings. While Sleep-EDF contains EEG/EOG/EMG signals rather than cardiac/respiratory signals, it confirmed the model's multi-modal processing capabilities and architectural correctness.

For reproduction with the original NSRR datasets (SHHS, MESA, etc.), data is available via the National Sleep Research Resource at https://sleepdata.org/.

## Testing

Run the included test cases with synthetic data:

```bash
python wav2sleep_pyhealth.py
```

Expected output:
```
Wav2Sleep Model Example
==================================================

Model created with XXX,XXX parameters

--- Example 1: Training with all modalities ---
Logits shape: torch.Size([4, 1200, 5])
Loss: X.XXXX
Predictions shape: torch.Size([4, 1200])

--- Example 2: Inference with ECG only ---
Probabilities shape: torch.Size([4, 1200, 5])
Example probabilities for first epoch:
tensor([0.2XXX, 0.1XXX, 0.2XXX, 0.2XXX, 0.2XXX])

==================================================
Example completed successfully!
```

## Data Format

### Input Signals
- Shape: `(batch_size, 1, seq_len)` where `seq_len = T * sampling_rate`
- T = number of 30-second epochs
- Sampling rates: ECG/PPG typically 1024 (34 Hz), Respiratory typically 256 (8 Hz)

### Labels
- Shape: `(batch_size, T)`
- Values: 0 (Wake), 1 (N1), 2 (N2), 3 (N3), 4 (REM)

## Citation

If you use this implementation, please cite the original wav2sleep paper:

```bibtex
@article{carter2024wav2sleep,
title={wav2sleep: A Unified Multi-Modal Approach to Sleep Stage Classification from Physiological Signals},
author={Carter, Jonathan F. and Tarassenko, Lionel},
journal={arXiv preprint arXiv:2411.04644},
year={2024}
}
```

## License

This implementation follows the same license as the original wav2sleep repository.

## Contact

For questions or issues with this PyHealth integration:
- **Author:** Meredith McClain
- **Email:** mmcclan2@illinois.edu
- **Original Paper:** https://arxiv.org/abs/2411.04644
- **Original Code:** https://github.com/joncarter1/wav2sleep
149 changes: 149 additions & 0 deletions examples/wav2sleep_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""
Example usage of wav2sleep model for sleep stage classification.

This script demonstrates how to use the wav2sleep model with different
modality combinations and synthetic data for testing.

Author: Meredith McClain (mmcclan2)
"""

import torch
from wav2sleep_pyhealth import Wav2Sleep

def example_basic_usage():
"""Basic example with all modalities."""
print("\n" + "="*50)
print("Example 1: Training with all modalities")
print("="*50)

# Define modalities (signal name -> samples per epoch)
modalities = {
"ecg": 1024, # 34 Hz * 30 seconds
"ppg": 1024,
"abd": 256, # 8 Hz * 30 seconds
"thx": 256
}

# Create model
model = Wav2Sleep(
modalities=modalities,
num_classes=5,
feature_dim=128,
dropout=0.1
)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Model created with {num_params:,} parameters")

# Generate synthetic data for testing
# Simulate 10 hours of sleep (1200 epochs of 30 seconds each)
batch_size = 4
T = 1200 # number of epochs

inputs = {
"ecg": torch.randn(batch_size, 1, T * 1024),
"ppg": torch.randn(batch_size, 1, T * 1024),
"abd": torch.randn(batch_size, 1, T * 256),
"thx": torch.randn(batch_size, 1, T * 256)
}

# Generate random labels (0=Wake, 1=N1, 2=N2, 3=N3, 4=REM)
labels = torch.randint(0, 5, (batch_size, T))

# Forward pass with all modalities
output = model(inputs, labels)

print(f"\nLogits shape: {output['logits'].shape}")
print(f"Loss: {output['loss'].item():.4f}")
print(f"Predictions shape: {output['predictions'].shape}")

return model


def example_subset_modalities():
"""Example with subset of modalities (ECG only)."""
print("\n" + "="*50)
print("Example 2: Inference with ECG only")
print("="*50)

# Model with potential for multiple modalities
modalities = {
"ecg": 1024,
"ppg": 1024,
"thx": 256
}

model = Wav2Sleep(modalities=modalities, num_classes=5)

# Inference with only ECG (e.g., if PPG sensor fails)
batch_size = 4
T = 1200

inputs_ecg_only = {
"ecg": torch.randn(batch_size, 1, T * 1024)
}

# Get predictions without labels (inference mode)
probs = model.predict_proba(inputs_ecg_only)

print(f"Probabilities shape: {probs.shape}")
print(f"Example probabilities for first epoch:")
print(probs[0, 0])
print(f"Sum of probabilities: {probs[0, 0].sum().item():.4f} (should be ~1.0)")


def example_variable_combinations():
"""Example testing different modality combinations."""
print("\n" + "="*50)
print("Example 3: Testing variable modality combinations")
print("="*50)

modalities = {
"ecg": 1024,
"ppg": 1024,
"abd": 256,
"thx": 256
}

model = Wav2Sleep(modalities=modalities, num_classes=5)

batch_size = 2
T = 100 # Shorter sequence for quick testing

# Test different combinations
test_cases = [
{"ecg": torch.randn(batch_size, 1, T * 1024)},
{"ecg": torch.randn(batch_size, 1, T * 1024),
"thx": torch.randn(batch_size, 1, T * 256)},
{"ppg": torch.randn(batch_size, 1, T * 1024),
"abd": torch.randn(batch_size, 1, T * 256)},
{"ecg": torch.randn(batch_size, 1, T * 1024),
"ppg": torch.randn(batch_size, 1, T * 1024),
"abd": torch.randn(batch_size, 1, T * 256),
"thx": torch.randn(batch_size, 1, T * 256)}
]

for i, inputs in enumerate(test_cases, 1):
probs = model.predict_proba(inputs)
modality_names = ", ".join(inputs.keys())
print(f"Test {i} ({modality_names}): Output shape = {probs.shape} ✓")


def main():
"""Run all examples."""
print("\nWav2Sleep Model Example")
print("="*50)

# Run examples
model = example_basic_usage()
example_subset_modalities()
example_variable_combinations()

print("\n" + "="*50)
print("Example completed successfully!")
print("="*50 + "\n")


if __name__ == "__main__":
main()
Loading