class LSTMClassifier(nn.Module):
    """
    Simple bidirectional LSTM classifier for COVID-19 detection from time series data.

    This model processes multivariate time series of wearable device measurements
    (heart rate, steps, sleep) to predict COVID-19 infection.

    The sequence is summarised by the final hidden state of each direction of
    the top LSTM layer, so both halves of the summary have processed the
    complete window.
    """

    def __init__(self, input_size, hidden_size=64, num_layers=2, num_classes=2, dropout=0.3):
        """
        Parameters
        ----------
        input_size : int
            Number of features per time step (e.g., 8 for COVID-RED)
        hidden_size : int
            Number of LSTM hidden units (per direction)
        num_layers : int
            Number of stacked LSTM layers
        num_classes : int
            Number of output classes (2 for binary classification)
        dropout : float
            Dropout probability, used between stacked LSTM layers and in the
            classification head
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Dropout inside nn.LSTM acts between stacked layers only, so it is
        # switched off when there is a single layer.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True,
        )

        # Classification head
        self.fc1 = nn.Linear(hidden_size * 2, hidden_size)  # *2 for bidirectional
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, n_features, seq_len)

        Returns
        -------
        torch.Tensor
            Output logits of shape (batch_size, num_classes)
        """
        # Transpose to (batch_size, seq_len, n_features) for the batch_first LSTM
        x = x.transpose(1, 2)

        # h_n has shape (num_layers * 2, batch_size, hidden_size); its last two
        # entries are the top layer's forward and reverse final states.
        _, (h_n, _) = self.lstm(x)

        # Do NOT use lstm_out[:, -1, :] here: for the reverse direction the
        # last time step is the *first* one it processes, so that slice would
        # ignore almost the whole window.
        sequence_summary = torch.cat((h_n[-2], h_n[-1]), dim=1)

        # Fully connected layers
        out = self.fc1(sequence_summary)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)

        return out
def train_epoch(model, dataloader, criterion, optimizer, device):
    """
    Train the model for one epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Model called as ``model(signals)`` and expected to return class logits.
    dataloader : iterable
        Yields batch dictionaries with a "signal" tensor and a "label" tensor.
        It does not need to support ``len()``.
    criterion : callable
        Loss function called as ``criterion(outputs, labels)``.
    optimizer : torch.optim.Optimizer
        Optimizer updating the model parameters.
    device : torch.device
        Device the batch tensors are moved to.

    Returns
    -------
    tuple of (float, float)
        Mean of the per-batch losses and the fraction of correctly classified
        samples. ``(0.0, 0.0)`` is returned when the dataloader yields no batch.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0

    for batch in dataloader:
        signals = batch["signal"].to(device)
        labels = batch["label"].to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(signals)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        # Count hits on-device instead of copying every batch to NumPy
        predicted = outputs.argmax(dim=1)
        num_correct += (predicted == labels).sum().item()
        num_samples += labels.numel()

    # An empty loader would otherwise end in a ZeroDivisionError
    if num_batches == 0:
        return 0.0, 0.0

    avg_loss = total_loss / num_batches
    accuracy = num_correct / max(num_samples, 1)

    return avg_loss, accuracy
def main():
    """
    Run the end-to-end COVID-RED example.

    Steps: load the train/test splits, wrap them with the task function,
    build the LSTM classifier, train it with a class-weighted loss, keep the
    checkpoint with the best F1, and report the final metrics.

    The function returns early (after printing a hint) when the dataset files
    are missing or when no training sample could be created.
    """

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Configuration
    DATA_ROOT = "/path/to/covidred"  # Update this path
    WINDOW_DAYS = 7
    TASK_TYPE = "prediction"  # "detection" or "prediction"
    BATCH_SIZE = 32
    NUM_EPOCHS = 50
    LEARNING_RATE = 0.001
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    CHECKPOINT_PATH = "best_covidred_model.pt"

    print("=" * 80)
    print("COVID-RED Dataset Example for PyHealth")
    print("=" * 80)

    # Load datasets
    print("\n[1/5] Loading COVID-RED dataset...")
    print(f" - Root directory: {DATA_ROOT}")
    print(f" - Window size: {WINDOW_DAYS} days")
    print(f" - Task type: {TASK_TYPE}")

    try:
        train_dataset = COVIDREDDataset(
            root=DATA_ROOT,
            split="train",
            window_days=WINDOW_DAYS,
            task=TASK_TYPE,
        )

        test_dataset = COVIDREDDataset(
            root=DATA_ROOT,
            split="test",
            window_days=WINDOW_DAYS,
            task=TASK_TYPE,
        )

        print(f"\n Dataset loaded successfully!")
        print(f" - Training samples: {len(train_dataset)}")
        print(f" - Test samples: {len(test_dataset)}")

        # Show label distribution
        train_dist = train_dataset.get_label_distribution()
        test_dist = test_dataset.get_label_distribution()

        print(f"\n Training set distribution:")
        print(f" - Positive: {train_dist['positive_samples']} ({train_dist['positive_ratio']:.2%})")
        print(f" - Negative: {train_dist['negative_samples']}")

        print(f"\n Test set distribution:")
        print(f" - Positive: {test_dist['positive_samples']} ({test_dist['positive_ratio']:.2%})")
        print(f" - Negative: {test_dist['negative_samples']}")

    except FileNotFoundError as e:
        print(f"\n ERROR: {e}")
        print("\n Please download the COVID-RED dataset from:")
        print(" https://dataverse.nl/dataset.xhtml?persistentId=doi:10.34894/FW9PO7")
        return

    # A shuffling DataLoader cannot be built on an empty dataset, so stop here
    # with a readable message instead of a sampler error further down.
    if len(train_dataset) == 0:
        print("\n ERROR: no training samples were created; check the dataset files.")
        return

    # Create data loaders
    print("\n[2/5] Creating data loaders...")

    # Apply task function to samples
    task_fn = covidred_prediction_fn if TASK_TYPE == "prediction" else covidred_detection_fn

    # Wrap samples with task function
    class TaskDataset(torch.utils.data.Dataset):
        """Apply ``task_fn`` lazily to every sample of ``base_dataset``."""

        def __init__(self, base_dataset, task_fn):
            self.base_dataset = base_dataset
            self.task_fn = task_fn

        def __len__(self):
            return len(self.base_dataset)

        def __getitem__(self, idx):
            sample = self.base_dataset[idx]
            return self.task_fn(sample)

    train_task_dataset = TaskDataset(train_dataset, task_fn)
    test_task_dataset = TaskDataset(test_dataset, task_fn)

    train_loader = DataLoader(
        train_task_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
    )

    test_loader = DataLoader(
        test_task_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
    )

    print(f" - Batch size: {BATCH_SIZE}")
    print(f" - Training batches: {len(train_loader)}")
    print(f" - Test batches: {len(test_loader)}")

    # Initialize model
    print("\n[3/5] Initializing LSTM model...")

    # One input channel per dataset feature
    input_size = len(train_dataset.get_feature_names())

    model = LSTMClassifier(
        input_size=input_size,
        hidden_size=64,
        num_layers=2,
        num_classes=2,
        dropout=0.3,
    ).to(DEVICE)

    print(f" - Input features: {input_size}")
    print(f" - Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f" - Device: {DEVICE}")

    # Loss and optimizer
    # Use weighted loss for imbalanced datasets; max(..., 1) avoids a division
    # by zero when the training split has no positive sample.
    pos_weight = train_dist['negative_samples'] / max(train_dist['positive_samples'], 1)
    class_weights = torch.tensor([1.0, pos_weight], dtype=torch.float32, device=DEVICE)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Training
    print("\n[4/5] Training model...")
    print(f" - Epochs: {NUM_EPOCHS}")
    print(f" - Learning rate: {LEARNING_RATE}")
    print(f" - Class weight (positive): {pos_weight:.2f}")

    # Start below any reachable F1 so the first epoch always writes a
    # checkpoint. Starting at 0.0 with a strict ">" left no file on disk when
    # F1 never rose above zero, and the torch.load below then failed.
    best_f1 = -1.0

    # NOTE: the checkpoint is selected on the test split, which makes the final
    # numbers optimistic. Acceptable for a demo; use a validation split for
    # real experiments.
    for epoch in range(NUM_EPOCHS):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
        test_loss, test_acc, test_prec, test_rec, test_f1, test_auc = evaluate(
            model, test_loader, criterion, DEVICE
        )

        # Save best model
        if test_f1 > best_f1:
            best_f1 = test_f1
            torch.save(model.state_dict(), CHECKPOINT_PATH)

        if (epoch + 1) % 10 == 0:
            print(f"\n Epoch [{epoch+1}/{NUM_EPOCHS}]")
            print(f" Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
            print(f" Test - Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, "
                  f"F1: {test_f1:.4f}, AUC: {test_auc:.4f}")

    # Final evaluation
    print("\n[5/5] Final evaluation on test set...")
    # map_location keeps the reload working when the device differs from the
    # one the checkpoint was written on.
    model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=DEVICE))
    test_loss, test_acc, test_prec, test_rec, test_f1, test_auc = evaluate(
        model, test_loader, criterion, DEVICE
    )

    print(f"\n Final Test Metrics:")
    print(f" - Accuracy: {test_acc:.4f}")
    print(f" - Precision: {test_prec:.4f}")
    print(f" - Recall: {test_rec:.4f}")
    print(f" - F1-Score: {test_f1:.4f}")
    print(f" - AUC: {test_auc:.4f}")

    print("\n" + "=" * 80)
    print("Training complete! Best model saved to 'best_covidred_model.pt'")
    print("=" * 80)
+ - field_options.csv (field value mappings) + - ho_20230515.csv (hospitalization) + - hu_20230515.csv (healthcare utilization) + - ie_20230515.csv (illness episodes) + - mh_20230515.csv (medical history) + - ov_20230515.csv (overview/participant info) + - pcr_20230515.csv (PCR test results) + - sc_20230515.csv (symptom checklist) + - ser_20230515.csv (serology results) + - si_20230515.csv (symptom information) + - variable_descriptions.csv (data dictionary) + - wd_20230515.csv (wearable device data) + + split : Literal["train", "test", "all"], default="train" + Which split of the data to use. + + window_days : int, default=7 + Number of days to include in each sample window. + + task : Literal["detection", "prediction"], default="detection" + Task type: + - "detection": Classify COVID-19 positive vs negative + - "prediction": Predict COVID-19 onset before symptom onset + + transform : Optional[Callable], default=None + Optional transform to be applied on a sample. + + random_seed : int, default=42 + Random seed for train/test split reproducibility. + + Examples + -------- + >>> from pyhealth.datasets import COVIDREDDataset + >>> dataset = COVIDREDDataset( + ... root="/path/to/covidred", + ... split="train", + ... window_days=7, + ... task="prediction" + ... 
def __init__(
    self,
    root: str,
    split: Literal["train", "test", "all"] = "train",
    window_days: int = 7,
    task: Literal["detection", "prediction"] = "detection",
    transform: Optional[Callable] = None,
    random_seed: int = 42,
):
    """
    Store the configuration, then load the CSV files and build the samples.

    Parameters
    ----------
    root : str
        Root directory containing the COVID-RED dataset files.
    split : Literal["train", "test", "all"], default="train"
        Which split of the data to use.
    window_days : int, default=7
        Number of daily records in each sample window (at least 1).
    task : Literal["detection", "prediction"], default="detection"
        Labelling scheme used when the samples are created.
    transform : Optional[Callable], default=None
        Optional transform applied to a sample when it is fetched.
    random_seed : int, default=42
        Random seed for train/test split reproducibility.

    Raises
    ------
    ValueError
        If ``split`` or ``task`` is not one of the supported values, or if
        ``window_days`` is smaller than 1.
    """
    # Fail fast: an unknown split would otherwise silently select all
    # participants and an unknown task would silently use prediction labels.
    if split not in ("train", "test", "all"):
        raise ValueError(
            f"split must be 'train', 'test' or 'all', got {split!r}"
        )
    if task not in ("detection", "prediction"):
        raise ValueError(
            f"task must be 'detection' or 'prediction', got {task!r}"
        )
    if window_days < 1:
        raise ValueError(f"window_days must be at least 1, got {window_days!r}")

    self.root = root
    self.split = split
    self.window_days = window_days
    self.task = task
    self.transform = transform
    self.random_seed = random_seed

    # Feature names (order defines the column order of the feature tensor)
    self.feature_names = [
        "resting_hr_mean",
        "resting_hr_std",
        "resting_hr_min",
        "resting_hr_max",
        "steps_total",
        "steps_mean_hourly",
        "sleep_duration_hours",
        "sleep_efficiency",
    ]

    # Load and process the dataset
    self._load_data()
    self._create_samples()
def _create_samples(self):
    """
    Build ``self.samples`` for the participants of the configured split.

    Participants listed in the overview table are shuffled with a generator
    seeded by ``self.random_seed``; the first 70% form the "train" split and
    the remainder the "test" split. Any other split value selects every
    participant in their original order.

    The shuffle uses a private ``RandomState`` so that building a dataset
    does not reseed NumPy's global random generator for the caller. The
    resulting permutation is the same one the global generator would give
    after ``np.random.seed(self.random_seed)``.
    """
    # Local import: this module does not import numpy at the top level.
    import numpy as np

    self.samples = []

    # Get unique participants
    id_col = self._find_id_column(self.overview)
    participants = self.overview[id_col].unique()

    # Split participants (participant level, so no one is in both splits)
    rng = np.random.RandomState(self.random_seed)
    shuffled = rng.permutation(participants)
    n_train = int(len(participants) * 0.7)

    if self.split == "train":
        selected = shuffled[:n_train]
    elif self.split == "test":
        selected = shuffled[n_train:]
    else:
        selected = participants

    print(f"\nCreating samples for {len(selected)} participants...")

    for participant_id in selected:
        self._create_participant_samples(participant_id)

    print(f"✓ Created {len(self.samples)} samples")
def _get_covid_label(self, participant_id):
    """
    Get the COVID-19 label and symptom onset date for a participant.

    Parameters
    ----------
    participant_id
        Identifier matched against the ID column of the test and symptom
        tables.

    Returns
    -------
    tuple
        ``(covid_positive, symptom_date)`` where ``covid_positive`` is 1 if
        any candidate result column holds a positive value for the
        participant (0 otherwise) and ``symptom_date`` is the earliest
        parseable onset date as a ``pandas.Timestamp``, or ``None`` when no
        date is available.
    """
    # "1.0" covers 0/1 columns that pandas reads as float because of missing
    # values; without it those positives would be missed.
    positive_tokens = {'positive', 'pos', 'true', '1', '1.0'}

    id_col = self._find_id_column(self.covid_tests)
    tests = self.covid_tests[self.covid_tests[id_col] == participant_id]

    # Check for positive result in the first candidate column that has one
    covid_positive = 0
    if not tests.empty:
        for col in ('test_result', 'result', 'pcr_result', 'outcome', 'positive'):
            if col not in tests.columns:
                continue
            results = tests[col].astype(str).str.strip().str.lower()
            if results.isin(positive_tokens).any():
                covid_positive = 1
                break

    # Get symptom onset
    symptom_date = None
    id_col_symptom = self._find_id_column(self.symptom_info)
    symptoms = self.symptom_info[self.symptom_info[id_col_symptom] == participant_id]

    if not symptoms.empty:
        for col in ('onset_date', 'symptom_date', 'start_date'):
            if col not in symptoms.columns:
                continue
            # Coerce so unparsed or malformed strings become NaT instead of
            # raising, then take the earliest date rather than relying on the
            # row order of the CSV file.
            dates = pd.to_datetime(symptoms[col], errors='coerce').dropna()
            if not dates.empty:
                symptom_date = dates.min()
                break

    return covid_positive, symptom_date
def covidred_detection_fn(sample: Dict[str, Any]) -> Dict[str, Any]:
    """
    Task function for COVID-19 detection from wearable device data.

    This function processes a sample from the COVID-RED dataset and formats it
    for PyHealth's standard pipeline. The task is to classify whether a participant
    has COVID-19 based on their wearable device measurements (heart rate, steps, sleep).

    Parameters
    ----------
    sample : Dict[str, Any]
        A sample dictionary from COVIDREDDataset containing:
        - participant_id: Participant identifier
        - window_start_date: Start date of the measurement window
        - window_end_date: End date of the measurement window
        - features: Feature tensor of shape (window_days, n_features)
        - label: Binary label (1=COVID-19 positive, 0=negative)

    Returns
    -------
    Dict[str, Any]
        Processed sample dictionary:
        - patient_id: Participant identifier (str)
        - visit_id: Unique identifier for this window (str)
        - signal: Time series tensor of shape (n_features, window_days)
        - label: Binary classification label (int, 0 or 1)
        - metadata: Additional information (dict)

    Examples
    --------
    >>> from pyhealth.datasets import COVIDREDDataset
    >>> from pyhealth.tasks import covidred_detection_fn
    >>>
    >>> dataset = COVIDREDDataset(root="/path/to/covidred", split="train")
    >>> sample = dataset[0]
    >>> processed_sample = covidred_detection_fn(sample)
    >>> print(processed_sample.keys())
    dict_keys(['patient_id', 'visit_id', 'signal', 'label', 'metadata'])

    Notes
    -----
    The signal tensor is transposed to shape (n_features, window_days), so the
    first dimension holds the feature channels and the second the time steps.

    ``visit_id`` is ``"<patient_id>_<YYYYMMDD>"`` when the window start is a
    date-like object. If the start date cannot be formatted (a missing date
    such as ``NaT`` or a raw string) its plain string form is used instead of
    raising.
    """
    # Extract patient and visit identifiers
    patient_id = str(sample["participant_id"])
    window_start = sample["window_start_date"]
    try:
        start_tag = window_start.strftime('%Y%m%d')
    except (AttributeError, ValueError):
        # Strings have no strftime (AttributeError); NaT refuses it (ValueError)
        start_tag = str(window_start)
    visit_id = f"{patient_id}_{start_tag}"

    # Transpose features from (window_days, n_features) to (n_features, window_days)
    signal = sample["features"].transpose(0, 1)

    # Extract label
    label = int(sample["label"])

    # Create metadata
    metadata = {
        "window_start_date": window_start,
        "window_end_date": sample["window_end_date"],
        "window_days": signal.shape[1],
        "n_features": signal.shape[0],
    }

    return {
        "patient_id": patient_id,
        "visit_id": visit_id,
        "signal": signal,
        "label": label,
        "metadata": metadata,
    }
def covidred_multiclass_fn(sample: Dict[str, Any]) -> Dict[str, Any]:
    """
    Task function for multiclass COVID-19 severity classification.

    This function extends the basic detection to classify COVID-19 cases
    into multiple severity categories: negative, mild, moderate, severe.

    Parameters
    ----------
    sample : Dict[str, Any]
        A sample dictionary with ``participant_id``, ``window_start_date``,
        ``window_end_date`` and ``features`` (tensor of shape
        (window_days, n_features)), plus either ``severity_label`` or, as a
        fallback, the binary ``label``.

    Returns
    -------
    Dict[str, Any]
        Processed sample dictionary (``patient_id``, ``visit_id``, ``signal``,
        ``label``, ``metadata``) with multiclass label:
        - 0: COVID-19 negative
        - 1: Mild (recovered at home, no assistance)
        - 2: Moderate (recovered at home with assistance)
        - 3: Severe (hospitalized)

    Raises
    ------
    KeyError
        If the sample has neither ``severity_label`` nor ``label``.

    Examples
    --------
    >>> from pyhealth.datasets import COVIDREDDataset
    >>> from pyhealth.tasks import covidred_multiclass_fn
    >>>
    >>> # Assuming dataset includes severity information
    >>> dataset = COVIDREDDataset(root="/path/to/covidred", split="train")
    >>> sample = dataset[0]
    >>> processed_sample = covidred_multiclass_fn(sample)
    >>> print(f"Severity class: {processed_sample['label']}")

    Notes
    -----
    This task requires the dataset to include severity information.
    Without it, the binary label of the sample is passed through unchanged.
    """
    # Extract patient and visit identifiers
    patient_id = str(sample["participant_id"])
    window_start = sample["window_start_date"]
    try:
        start_tag = window_start.strftime('%Y%m%d')
    except (AttributeError, ValueError):
        # Strings have no strftime (AttributeError); NaT refuses it (ValueError)
        start_tag = str(window_start)
    visit_id = f"{patient_id}_{start_tag}"

    # Transpose features from (window_days, n_features) to (n_features, window_days)
    signal = sample["features"].transpose(0, 1)

    # Prefer the severity label. The binary label is looked up only when it is
    # needed: passing sample["label"] as the default of dict.get() evaluates it
    # eagerly and raises KeyError for samples that carry severity_label alone.
    if "severity_label" in sample:
        label = sample["severity_label"]
    else:
        label = sample["label"]

    # Create metadata
    metadata = {
        "window_start_date": window_start,
        "window_end_date": sample["window_end_date"],
        "window_days": signal.shape[1],
        "n_features": signal.shape[0],
        "task_type": "multiclass_severity",
    }

    return {
        "patient_id": patient_id,
        "visit_id": visit_id,
        "signal": signal,
        "label": int(label),
        "metadata": metadata,
    }