diff --git a/examples/conformal_eeg/tuev_eeg_quickstart.ipynb b/examples/conformal_eeg/tuev_eeg_quickstart.ipynb new file mode 100644 index 000000000..9a5370a28 --- /dev/null +++ b/examples/conformal_eeg/tuev_eeg_quickstart.ipynb @@ -0,0 +1,500 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9bbdf2d5", + "metadata": {}, + "source": [ + "## 1. Environment Setup\n", + "Seed the random generators, import core dependencies, and detect the training device." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aa3af180", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on device: cpu\n" + ] + } + ], + "source": [ + "import random\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from pyhealth.datasets import TUEVDataset\n", + "from pyhealth.tasks import EEGEventsTUEV\n", + "from pyhealth.datasets.splitter import split_by_sample\n", + "from pyhealth.datasets.utils import get_dataloader\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "np.random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed_all(SEED)\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a530c574", + "metadata": {}, + "source": [ + "## 2. Load TUEV Dataset\n", + "Point to the TUEV dataset root and load the dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "687a951e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n", + "Using both train and eval subsets\n", + "Initializing tuev dataset from F:\\coding_projects\\pyhealth\\downloads\\tuev\\v2.0.1\\edf (dev mode: False)\n", + "Scanning table: train from F:\\coding_projects\\pyhealth\\downloads\\tuev\\v2.0.1\\edf\\tuev-train-pyhealth.csv\n", + "Scanning table: eval from F:\\coding_projects\\pyhealth\\downloads\\tuev\\v2.0.1\\edf\\tuev-eval-pyhealth.csv\n", + "Collecting global event dataframe...\n", + "Collected dataframe with shape: (15, 10)\n", + "Dataset: tuev\n", + "Dev mode: False\n", + "Number of patients: 10\n", + "Number of events: 15\n" + ] + } + ], + "source": [ + "dataset = TUEVDataset(\n", + " root='F:\\\\coding_projects\\\\pyhealth\\\\downloads\\\\tuev\\\\v2.0.1\\\\edf', # Update this path\n", + ")\n", + "dataset.stats()" + ] + }, + { + "cell_type": "markdown", + "id": "4ef47e0e", + "metadata": {}, + "source": [ + "## 3. Prepare PyHealth Dataset\n", + "Set the task for the dataset and convert raw samples into PyHealth format for EEG event classification." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9955d9f6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task EEG_events for tuev base dataset...\n", + "Generating samples with 1 worker(s)...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating samples for EEG_events with 1 worker: 100%|██████████| 10/10 [00:31<00:00, 3.17s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Label label vocab: {0: 0, 1: 1, 3: 2, 4: 3, 5: 4}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Processing samples: 100%|██████████| 3397/3397 [00:00<00:00, 12759.42it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generated 3397 samples for task EEG_events\n", + "Total task samples: 3397\n", + "Input schema: {'signal': 'tensor'}\n", + "Output schema: {'label': 'multiclass'}\n", + "\n", + "Sample keys: dict_keys(['patient_id', 'signal_file', 'signal', 'offending_channel', 'label'])\n", + "Signal shape: torch.Size([16, 1280])\n", + "Label: 0\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "sample_dataset = dataset.set_task(EEGEventsTUEV())\n", + "\n", + "print(f\"Total task samples: {len(sample_dataset)}\")\n", + "print(f\"Input schema: {sample_dataset.input_schema}\")\n", + "print(f\"Output schema: {sample_dataset.output_schema}\")\n", + "\n", + "# Inspect a sample\n", + "sample = sample_dataset[0]\n", + "print(f\"\\nSample keys: {sample.keys()}\")\n", + "print(f\"Signal shape: {sample['signal'].shape}\")\n", + "print(f\"Label: {sample['label']}\")" + ] + }, + { + "cell_type": "markdown", + "id": "c260d0d4", + "metadata": {}, + "source": [ + "## 4. Split Dataset\n", + "Divide the processed samples into training, validation, and test subsets before building dataloaders." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11ac1ce3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train/Val/Test sizes: 2377, 340, 680\n" + ] + } + ], + "source": [ + "BATCH_SIZE = 32\n", + "\n", + "train_ds, val_ds, test_ds = split_by_sample(sample_dataset, [0.7, 0.1, 0.2], seed=SEED)\n", + "print(f\"Train/Val/Test sizes: {len(train_ds)}, {len(val_ds)}, {len(test_ds)}\")\n", + "\n", + "train_loader = get_dataloader(train_ds, batch_size=BATCH_SIZE, shuffle=True)\n", + "val_loader = get_dataloader(val_ds, batch_size=BATCH_SIZE) if len(val_ds) else None\n", + "test_loader = get_dataloader(test_ds, batch_size=BATCH_SIZE) if len(test_ds) else None\n", + "\n", + "if len(train_loader) == 0:\n", + " raise RuntimeError(\"The training loader is empty. Increase the dataset size or adjust the split ratios.\")" + ] + }, + { + "cell_type": "markdown", + "id": "0a2fce71", + "metadata": {}, + "source": [ + "## 5. Inspect Batch Structure\n", + "Peek at the first training batch to understand feature shapes and data structure." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "dfda5f4d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch structure:\n", + " patient_id: list(len=32)\n", + " signal_file: list(len=32)\n", + " signal: Tensor(shape=(32, 16, 1280))\n", + " offending_channel: list(len=32)\n", + " label: Tensor(shape=(32,))\n" + ] + } + ], + "source": [ + "first_batch = next(iter(train_loader))\n", + "\n", + "def describe(value):\n", + " if hasattr(value, \"shape\"):\n", + " return f\"{type(value).__name__}(shape={tuple(value.shape)})\"\n", + " if isinstance(value, (list, tuple)):\n", + " return f\"{type(value).__name__}(len={len(value)})\"\n", + " return type(value).__name__\n", + "\n", + "batch_summary = {key: describe(value) for key, value in first_batch.items()}\n", + "print(\"Batch structure:\")\n", + "for key, desc in batch_summary.items():\n", + " print(f\" {key}: {desc}\")" + ] + }, + { + "cell_type": "markdown", + "id": "fb833579", + "metadata": {}, + "source": [ + "## 6. Instantiate Model\n", + "Create a simple CNN model for EEG event classification." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0f8d457c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total parameters: 2,610,662\n", + "Trainable parameters: 2,610,662\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "class SimpleEEGClassifier(nn.Module):\n", + " def __init__(self, num_classes=6):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv1d(16, 32, kernel_size=5, stride=1)\n", + " self.pool = nn.MaxPool1d(2)\n", + " self.conv2 = nn.Conv1d(32, 64, kernel_size=5)\n", + " # After conv1: 1280 -> 1276, pool: 638\n", + " # conv2: 638 -> 634, pool: 317\n", + " self.fc1 = nn.Linear(64 * 317, 128)\n", + " self.fc2 = nn.Linear(128, num_classes)\n", + " self.relu = nn.ReLU()\n", + "\n", + " def forward(self, signal):\n", + " x = self.relu(self.conv1(signal))\n", + " x = self.pool(x)\n", + " x = self.relu(self.conv2(x))\n", + " x = self.pool(x)\n", + " x = x.view(x.size(0), -1)\n", + " x = self.relu(self.fc1(x))\n", + " x = self.fc2(x)\n", + " return x\n", + "\n", + "model = SimpleEEGClassifier(num_classes=6).to(device)\n", + "\n", + "total_params = sum(p.numel() for p in model.parameters())\n", + "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "print(f\"Total parameters: {total_params:,}\")\n", + "print(f\"Trainable parameters: {trainable_params:,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "6d27db3d", + "metadata": {}, + "source": [ + "## 7. Test Forward Pass\n", + "Verify the model can process a batch and compute outputs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c85a05f4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model output shape: torch.Size([32, 6])\n", + "Sample output: tensor([0.0669, 0.0869, 0.0540, 0.6302, 0.4835, 0.2699])\n" + ] + } + ], + "source": [ + "# Move batch to device\n", + "test_batch = {key: value.to(device) if hasattr(value, 'to') else value \n", + " for key, value in first_batch.items()}\n", + "\n", + "# Forward pass\n", + "with torch.no_grad():\n", + " output = model(test_batch['signal'])\n", + "\n", + "print(\"Model output shape:\", output.shape)\n", + "print(\"Sample output:\", output[0])" + ] + }, + { + "cell_type": "markdown", + "id": "3f60d5bf", + "metadata": {}, + "source": [ + "## 8. Configure Loss and Optimizer\n", + "Define the loss function and optimizer for training." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0ff06b16", + "metadata": {}, + "outputs": [], + "source": [ + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "id": "9580fe69", + "metadata": {}, + "source": [ + "## 9. Train the Model\n", + "Launch the training loop to learn from the EEG data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2985dfcc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/5, Loss: 1.6282\n", + "Validation Loss: 0.2099, Accuracy: 92.06%\n", + "Epoch 2/5, Loss: 0.1427\n", + "Validation Loss: 0.1437, Accuracy: 96.47%\n", + "Epoch 3/5, Loss: 0.0808\n", + "Validation Loss: 0.1306, Accuracy: 95.88%\n", + "Epoch 4/5, Loss: 0.0559\n", + "Validation Loss: 0.1199, Accuracy: 96.18%\n", + "Epoch 5/5, Loss: 0.0464\n", + "Validation Loss: 0.1195, Accuracy: 95.88%\n" + ] + } + ], + "source": [ + "num_epochs = 5\n", + "\n", + "for epoch in range(num_epochs):\n", + " model.train()\n", + " running_loss = 0.0\n", + " for batch in train_loader:\n", + " signals = batch['signal'].to(device)\n", + " labels = batch['label'].to(device)\n", + " \n", + " optimizer.zero_grad()\n", + " outputs = model(signals)\n", + " loss = criterion(outputs, labels)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " running_loss += loss.item()\n", + " \n", + " print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}\")\n", + " \n", + " # Validation\n", + " if val_loader:\n", + " model.eval()\n", + " val_loss = 0.0\n", + " correct = 0\n", + " total = 0\n", + " with torch.no_grad():\n", + " for batch in val_loader:\n", + " signals = batch['signal'].to(device)\n", + " labels = batch['label'].to(device)\n", + " outputs = model(signals)\n", + " loss = criterion(outputs, labels)\n", + " val_loss += loss.item()\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + " print(f\"Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct / total:.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "id": "0706a7f0", + "metadata": {}, + "source": [ + "## 10. Evaluate on Test Set\n", + "Evaluate the trained model on the test set." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "07631718", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Loss: 0.0780, Accuracy: 97.94%\n" + ] + } + ], + "source": [ + "model.eval()\n", + "test_loss = 0.0\n", + "correct = 0\n", + "total = 0\n", + "with torch.no_grad():\n", + " for batch in test_loader:\n", + " signals = batch['signal'].to(device)\n", + " labels = batch['label'].to(device)\n", + " outputs = model(signals)\n", + " loss = criterion(outputs, labels)\n", + " test_loss += loss.item()\n", + " _, predicted = torch.max(outputs.data, 1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + "\n", + "print(f\"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {100 * correct / total:.2f}%\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv (3.10.11)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyhealth/datasets/configs/tuev.yaml b/pyhealth/datasets/configs/tuev.yaml new file mode 100644 index 000000000..b6d662fae --- /dev/null +++ b/pyhealth/datasets/configs/tuev.yaml @@ -0,0 +1,17 @@ +version: "2.0.1" +tables: + train: + file_path: "tuev-train-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "record_id" + - "signal_file" + eval: + file_path: "tuev-eval-pyhealth.csv" + patient_id: "patient_id" + timestamp: null + attributes: + - "segment_id" + - "label" + - "signal_file" \ No newline at end of file diff --git a/pyhealth/datasets/tuev.py b/pyhealth/datasets/tuev.py index 77649acae..f7fdd6670 100644 --- a/pyhealth/datasets/tuev.py +++ b/pyhealth/datasets/tuev.py @@ -1,11 +1,15 
@@ import os +import logging +import pandas as pd +from pathlib import Path -import numpy as np +from typing import Optional +from .base_dataset import BaseDataset +from pyhealth.tasks import EEGEventsTUEV -from pyhealth.datasets import BaseSignalDataset +logger = logging.getLogger(__name__) - -class TUEVDataset(BaseSignalDataset): +class TUEVDataset(BaseDataset): """Base EEG dataset for the TUH EEG Events Corpus Dataset is available at https://isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml @@ -15,18 +19,17 @@ class TUEVDataset(BaseSignalDataset): Files are named in the form of bckg_032_a_.edf in the eval partition: bckg: this file contains background annotations. 032: a reference to the eval index - a_.edf: EEG files are split into a series of files starting with a_.edf, a_1.ef, ... These represent pruned EEGs, so the original EEG is split into these segments, and uninteresting parts of the original recording were deleted. + a_.edf: EEG files are split into a series of files starting with a_.edf, a_1.ef, ... These represent pruned EEGs, so the original EEG is split into these segments, and uninteresting parts of the original recording were deleted. or in the form of 00002275_00000001.edf in the train partition: 00002275: a reference to the train index. - 0000001: indicating that this is the first file inssociated with this patient. + 0000001: indicating that this is the first file associated with this patient. Args: - dataset_name: name of the dataset. root: root directory of the raw data. *You can choose to use the path to Cassette portion or the Telemetry portion.* + dataset_name: name of the dataset. + config_path: Optional configuration file name, defaults to "tuev.yaml". dev: whether to enable dev mode (only use a small subset of the data). Default is False. - refresh_cache: whether to refresh the cache; if true, the dataset will - be processed from scratch and the cache will be updated. Default is False. 
Attributes: task: Optional[str], name of the task (e.g., "EEG_events"). @@ -41,72 +44,145 @@ class TUEVDataset(BaseSignalDataset): Examples: >>> from pyhealth.datasets import TUEVDataset + >>> from pyhealth.tasks import EEGEventsTUEV >>> dataset = TUEVDataset( ... root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/", ... ) - >>> dataset.stat() - >>> dataset.info() - """ + >>> dataset.stats() + >>> sample_dataset = dataset.set_task(EEGEventsTUEV()) + >>> sample = sample_dataset[0] + >>> print(sample['signal'].shape) # (16, 1280) - def process_EEG_data(self): - # get all file names - all_files = {} - - train_files = os.listdir(os.path.join(self.root, "train/")) - for id in train_files: - if id != ".DS_Store": - all_files["0_{}".format(id)] = [name for name in os.listdir(os.path.join(self.root, "train/", id)) if name.endswith(".edf")] - - eval_files = os.listdir(os.path.join(self.root, "eval/")) - for id in eval_files: - if id != ".DS_Store": - all_files["1_{}".format(id)] = [name for name in os.listdir(os.path.join(self.root, "eval/", id)) if name.endswith(".edf")] - - # get all patient ids - patient_ids = list(set(list(all_files.keys()))) - - if self.dev: - patient_ids = patient_ids[:20] - # print(patient_ids) - - # get patient to record maps - # - key: pid: - # - value: [{"load_from_path": None, "patient_id": None, "signal_file": None, "label_file": None, "save_to_path": None}, ...] - patients = { - pid: [] - for pid in patient_ids - } - - for pid in patient_ids: - split = "train" if pid.split("_")[0] == "0" else "eval" - id = pid.split("_")[1] - - patient_visits = all_files[pid] + For a complete example, see `examples/conformal_eeg/tuev_eeg_quickstart.ipynb`. 
+ """ + + def __init__( + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + subset: Optional[str] = 'both', + **kwargs + ) -> None: + if config_path is None: + logger.info("No config path provided, using default config") + from pathlib import Path + config_path = Path(__file__).parent / "configs" / "tuev.yaml" + + self.root = root + + if subset in ['train', 'eval']: + logger.info(f"Using subset: {subset}") + tables = [subset] + elif subset == 'both': + logger.info("Using both train and eval subsets") + tables = ["train", "eval"] + else: + raise ValueError("subset must be one of 'train', 'eval', or 'both'") + + self.prepare_metadata() + + super().__init__( + root=root, + tables=tables, + dataset_name=dataset_name or "tuev", + config_path=config_path, + **kwargs + ) + + def prepare_metadata(self) -> None: + """Build and save processed metadata CSVs for TUEV train/eval separately. + + This writes: + - /tuev-train-pyhealth.csv + - /tuev-eval-pyhealth.csv + + Train filenames look like: 00002275_00000001.edf + - subject_id = 00002275 + - record_id = 00000001 + + Eval filenames look like: bckg_032_a_.edf + - label_kind = bckg (or spsw/gped/pled/eyem/artf depending on file) + - eval_index = 032 + - segment_id = a_ / a_1 / ... 
+ """ + root = Path(self.root) + + train_rows: list[dict] = [] + eval_rows: list[dict] = [] + + for split in ("train", "eval"): + if os.path.exists(root / f"tuev-{split}-pyhealth.csv"): + continue - for visit in patient_visits: - if split == "train": - visit_id = visit.strip(".edf").split("_")[1] - else: - visit_id = visit.strip(".edf") - - patients[pid].append({ - "load_from_path": os.path.join(self.root, split, id), - "patient_id": pid, - "visit_id": visit_id, - "signal_file": visit, - "label_file": visit, - "save_to_path": self.filepath, - }) + split_dir = root / split + if not split_dir.is_dir(): + logger.warning("Split directory not found: %s", split_dir) + continue + + for subject_dir in split_dir.iterdir(): + if not subject_dir.is_dir() or subject_dir.name.startswith("."): + continue + + for edf_path in subject_dir.glob("*.edf"): + stem = edf_path.stem + + if split == "train": + parts = stem.split("_") + record_id = parts[-1] + + train_rows.append( + { + "patient_id": subject_dir.name, + "record_id": record_id, + "signal_file": str(edf_path), + } + ) + + else: + parts = stem.split("_") + label = parts[0] + segment_id = "_".join(parts[2:]) + + eval_rows.append( + { + "patient_id": subject_dir.name, + "label": label, + "segment_id": segment_id, + "signal_file": str(edf_path), + } + ) + + root.mkdir(parents=True, exist_ok=True) + + # Write train metadata + if train_rows: + train_df = pd.DataFrame(train_rows) + train_df.sort_values( + ["patient_id", "record_id"], inplace=True, na_position="last" + ) + train_df.reset_index(drop=True, inplace=True) + train_csv = root / "tuev-train-pyhealth.csv" + train_df.to_csv(train_csv, index=False) + + + # Write eval metadata + if eval_rows: + eval_df = pd.DataFrame(eval_rows) + eval_df.sort_values( + ["patient_id", "segment_id", "label"], + inplace=True, + na_position="last", + ) + eval_df.reset_index(drop=True, inplace=True) + eval_csv = root / "tuev-eval-pyhealth.csv" + eval_df.to_csv(eval_csv, index=False) + + @property + 
def default_task(self) -> EEGEventsTUEV: + """Returns the default task for the TUEV dataset: EEGEventsTUEV. - return patients - - -if __name__ == "__main__": - dataset = TUEVDataset( - root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf", - dev=True, - refresh_cache=True, - ) - dataset.stat() - dataset.info() - print(list(dataset.patients.items())[0]) + Returns: + EEGEventsTUEV: The default task instance. + """ + return EEGEventsTUEV() \ No newline at end of file diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index bcfb19f7a..dcd0a5314 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -61,7 +61,7 @@ sleep_staging_sleepedf_fn, ) from .sleep_staging_v2 import SleepStagingSleepEDF -from .temple_university_EEG_tasks import EEG_events_fn, EEG_isAbnormal_fn +from .temple_university_EEG_tasks import EEGEventsTUEV from .variant_classification import ( MutationPathogenicityPrediction, VariantClassificationClinVar, diff --git a/pyhealth/tasks/temple_university_EEG_tasks.py b/pyhealth/tasks/temple_university_EEG_tasks.py index dd529bb6e..6f349c3f3 100644 --- a/pyhealth/tasks/temple_university_EEG_tasks.py +++ b/pyhealth/tasks/temple_university_EEG_tasks.py @@ -1,364 +1,166 @@ +from __future__ import annotations + import os import pickle import mne import numpy as np +import os +from typing import Any, Dict, List, Tuple + +import mne +import numpy as np -def EEG_isAbnormal_fn(record): - """Processes a single patient for the abnormal EEG detection task on TUAB. +from pyhealth.tasks import BaseTask - Abnormal EEG detection aims at determining whether a EEG is abnormal. - Args: - record: a singleton list of one subject from the TUABDataset. - The (single) record is a dictionary with the following keys: - load_from_path, patient_id, visit_id, signal_file, label_file, save_to_path +class EEGEventsTUEV(BaseTask): + """Multi-class classification task for EEG event detection on TUEV. 
- Returns: - samples: a list of samples, each sample is a dict with patient_id, visit_id, record_id, - and epoch_path (the path to the saved epoch {"signal": signal, "label": label} as key. + For each EDF recording, this task: + 1) reads the EDF + 2) applies bandpass (0.1-75 Hz), notch (50 Hz), resamples to 256 Hz + 3) loads the paired .rec file (same path, .edf -> .rec) + 4) constructs 5-second event-centered windows (16 bipolar channels) + 5) returns one sample per event - Note that we define the task as a binary classification task. + Each returned sample contains: + - "signal": np.ndarray, shape (16, 256*5) + - "offending_channel": int + - "label": int Examples: - >>> from pyhealth.datasets import TUABDataset - >>> isabnormal = TUABDataset( - ... root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/", download=True, + >>> from pyhealth.datasets import TUEVDataset + >>> from pyhealth.tasks import EEGEventsTUEV + >>> dataset = TUEVDataset( + ... root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/", ... 
) - >>> from pyhealth.tasks import EEG_isabnormal_fn - >>> EEG_abnormal_ds = isabnormal.set_task(EEG_isAbnormal_fn) - >>> EEG_abnormal_ds.samples[0] - { - 'patient_id': 'aaaaamye', - 'visit_id': 's001', - 'record_id': '1', - 'epoch_path': '/home/zhenlin4/.cache/pyhealth/datasets/832afe6e6e8a5c9ea5505b47e7af8125/10-1/1/0.pkl', - 'label': 1 - } - """ - - samples = [] - for visit in record: - root, pid, visit_id, signal, label, save_path = ( - visit["load_from_path"], - visit["patient_id"], - visit["visit_id"], - visit["signal_file"], - visit["label_file"], - visit["save_to_path"], - ) + >>> sample_dataset = dataset.set_task(EEGEventsTUEV()) + >>> sample = sample_dataset[0] + >>> print(sample['label']) - raw = mne.io.read_raw_edf(os.path.join(root, signal), preload=True) - raw.resample(200) - ch_name = raw.ch_names - raw_data = raw.get_data() - channeled_data = raw_data.copy()[:16] - try: - channeled_data[0] = ( - raw_data[ch_name.index("EEG FP1-REF")] - - raw_data[ch_name.index("EEG F7-REF")] - ) - channeled_data[1] = ( - raw_data[ch_name.index("EEG F7-REF")] - - raw_data[ch_name.index("EEG T3-REF")] - ) - channeled_data[2] = ( - raw_data[ch_name.index("EEG T3-REF")] - - raw_data[ch_name.index("EEG T5-REF")] - ) - channeled_data[3] = ( - raw_data[ch_name.index("EEG T5-REF")] - - raw_data[ch_name.index("EEG O1-REF")] - ) - channeled_data[4] = ( - raw_data[ch_name.index("EEG FP2-REF")] - - raw_data[ch_name.index("EEG F8-REF")] - ) - channeled_data[5] = ( - raw_data[ch_name.index("EEG F8-REF")] - - raw_data[ch_name.index("EEG T4-REF")] - ) - channeled_data[6] = ( - raw_data[ch_name.index("EEG T4-REF")] - - raw_data[ch_name.index("EEG T6-REF")] - ) - channeled_data[7] = ( - raw_data[ch_name.index("EEG T6-REF")] - - raw_data[ch_name.index("EEG O2-REF")] - ) - channeled_data[8] = ( - raw_data[ch_name.index("EEG FP1-REF")] - - raw_data[ch_name.index("EEG F3-REF")] - ) - channeled_data[9] = ( - raw_data[ch_name.index("EEG F3-REF")] - - raw_data[ch_name.index("EEG C3-REF")] - 
) - channeled_data[10] = ( - raw_data[ch_name.index("EEG C3-REF")] - - raw_data[ch_name.index("EEG P3-REF")] - ) - channeled_data[11] = ( - raw_data[ch_name.index("EEG P3-REF")] - - raw_data[ch_name.index("EEG O1-REF")] - ) - channeled_data[12] = ( - raw_data[ch_name.index("EEG FP2-REF")] - - raw_data[ch_name.index("EEG F4-REF")] - ) - channeled_data[13] = ( - raw_data[ch_name.index("EEG F4-REF")] - - raw_data[ch_name.index("EEG C4-REF")] - ) - channeled_data[14] = ( - raw_data[ch_name.index("EEG C4-REF")] - - raw_data[ch_name.index("EEG P4-REF")] - ) - channeled_data[15] = ( - raw_data[ch_name.index("EEG P4-REF")] - - raw_data[ch_name.index("EEG O2-REF")] - ) - except: - with open("tuab-process-error-files.txt", "a") as f: - f.write(os.path.join(root, signal) + "\n") - continue - - # get the label - data_field = pid.split("_")[0] - if data_field == "0" or data_field == "2": - label = 1 - else: - label = 0 - - # load data - for i in range(channeled_data.shape[1] // 2000): - dump_path = os.path.join( - save_path, pid + "_" + visit_id + "_" + str(i) + ".pkl" - ) - pickle.dump( - {"signal": channeled_data[:, i * 2000 : (i + 1) * 2000], "label": label}, - open(dump_path, "wb"), - ) + For a complete example, see `examples/conformal_eeg/tuev_eeg_quickstart.ipynb`. + """ - samples.append( - { - "patient_id": pid, - "visit_id": visit_id, - "record_id": i, - "epoch_path": dump_path, - "label": label, - } - ) + task_name: str = "EEG_events" + input_schema: Dict[str, str] = {"signal": "tensor"} + output_schema: Dict[str, str] = {"label": "multiclass"} - return samples + def __init__(self) -> None: + super().__init__() -def EEG_events_fn(record): - """Processes a single patient for the EEG events task on TUEV. 
+ @staticmethod + def BuildEvents( + signals: np.ndarray, times: np.ndarray, EventData: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + # Ensure 2D in case a .rec has only one row + EventData = np.atleast_2d(EventData) - This task aims at annotating of EEG segments as one of six classes: (1) spike and sharp wave (SPSW), (2) generalized periodic epileptiform discharges (GPED), (3) periodic lateralized epileptiform discharges (PLED), (4) eye movement (EYEM), (5) artifact (ARTF) and (6) background (BCKG). + numEvents, _ = EventData.shape + fs = 256.0 + numChan, _ = signals.shape - Args: - record: a singleton list of one subject from the TUEVDataset. - The (single) record is a dictionary with the following keys: - load_from_path, patient_id, visit_id, signal_file, label_file, save_to_path + features = np.zeros([numEvents, numChan, int(fs) * 5]) + offending_channel = np.zeros([numEvents, 1]) + labels = np.zeros([numEvents, 1]) - Returns: - samples: a list of samples, each sample is a dict with patient_id, visit_id, record_id, label, offending_channel, - and epoch_path (the path to the saved epoch {"signal": signal, "label": label} as key. + offset = signals.shape[1] + signals = np.concatenate([signals, signals, signals], axis=1) - Note that we define the task as a multiclass classification task. + for i in range(numEvents): + chan = int(EventData[i, 0]) + start = np.where((times) >= EventData[i, 1])[0][0] + end = np.where((times) >= EventData[i, 2])[0][0] - Examples: - >>> from pyhealth.datasets import TUEVDataset - >>> EEGevents = TUEVDataset( - ... root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/", download=True, - ... 
) - >>> from pyhealth.tasks import EEG_events_fn - >>> EEG_events_ds = EEGevents.set_task(EEG_events_fn) - >>> EEG_events_ds.samples[0] - { - 'patient_id': '0_00002265', - 'visit_id': '00000001', - 'record_id': 0, - 'epoch_path': '/Users/liyanjing/.cache/pyhealth/datasets/d8f3cb92cc444d481444d3414fb5240c/0_00002265_00000001_0.pkl', - 'label': 6, - 'offending_channel': array([4.]) - } - """ - - samples = [] - for visit in record: - root, pid, visit_id, signal, label, save_path = ( - visit["load_from_path"], - visit["patient_id"], - visit["visit_id"], - visit["signal_file"], - visit["label_file"], - visit["save_to_path"], - ) + features[i, :] = signals[ + :, offset + start - 2 * int(fs) : offset + end + 2 * int(fs) + ] + offending_channel[i, :] = int(chan) + labels[i, :] = int(EventData[i, 3]) - - # load data - try: - [signals, times, event, Rawdata] = readEDF( - os.path.join(root, signal) - ) # event is the .rec file in the form of an array - signals = convert_signals(signals, Rawdata) - except (ValueError, KeyError): - print("something funky happened in " + os.path.join(root, signal)) - continue - signals, offending_channels, labels = BuildEvents(signals, times, event) - - for idx, (signal, offending_channel, label) in enumerate( - zip(signals, offending_channels, labels) - ): - dump_path = os.path.join( - save_path, pid + "_" + visit_id + "_" + str(idx) + ".pkl" - ) + return features, offending_channel, labels - pickle.dump( - {"signal": signal, "label": int(label[0])}, - open(dump_path, "wb"), - ) - - samples.append( - { - "patient_id": pid, - "visit_id": visit_id, - "record_id": idx, - "epoch_path": dump_path, - "label": int(label[0]), - "offending_channel": offending_channel, - } + @staticmethod + def convert_signals(signals: np.ndarray, Rawdata: mne.io.BaseRaw) -> np.ndarray: + signal_names = { + k: v + for (k, v) in zip( + Rawdata.info["ch_names"], list(range(len(Rawdata.info["ch_names"]))) ) + } - return samples - - -def BuildEvents(signals, times, 
EventData): - [numEvents, z] = EventData.shape # numEvents is equal to # of rows of the .rec file - fs = 250.0 - [numChan, numPoints] = signals.shape - - features = np.zeros([numEvents, numChan, int(fs) * 5]) - offending_channel = np.zeros([numEvents, 1]) # channel that had the detected thing - labels = np.zeros([numEvents, 1]) - offset = signals.shape[1] - signals = np.concatenate([signals, signals, signals], axis=1) - for i in range(numEvents): # for each event - chan = int(EventData[i, 0]) # chan is channel - start = np.where((times) >= EventData[i, 1])[0][0] - end = np.where((times) >= EventData[i, 2])[0][0] - features[i, :] = signals[ - :, offset + start - 2 * int(fs) : offset + end + 2 * int(fs) - ] - offending_channel[i, :] = int(chan) - labels[i, :] = int(EventData[i, 3]) - return [features, offending_channel, labels] - - -def convert_signals(signals, Rawdata): - signal_names = { - k: v - for (k, v) in zip( - Rawdata.info["ch_names"], list(range(len(Rawdata.info["ch_names"]))) - ) - } - new_signals = np.vstack( - ( - signals[signal_names["EEG FP1-REF"]] - - signals[signal_names["EEG F7-REF"]], # 0 - ( - signals[signal_names["EEG F7-REF"]] - - signals[signal_names["EEG T3-REF"]] - ), # 1 - ( - signals[signal_names["EEG T3-REF"]] - - signals[signal_names["EEG T5-REF"]] - ), # 2 - ( - signals[signal_names["EEG T5-REF"]] - - signals[signal_names["EEG O1-REF"]] - ), # 3 - ( - signals[signal_names["EEG FP2-REF"]] - - signals[signal_names["EEG F8-REF"]] - ), # 4 - ( - signals[signal_names["EEG F8-REF"]] - - signals[signal_names["EEG T4-REF"]] - ), # 5 - ( - signals[signal_names["EEG T4-REF"]] - - signals[signal_names["EEG T6-REF"]] - ), # 6 - ( - signals[signal_names["EEG T6-REF"]] - - signals[signal_names["EEG O2-REF"]] - ), # 7 - ( - signals[signal_names["EEG FP1-REF"]] - - signals[signal_names["EEG F3-REF"]] - ), # 14 - ( - signals[signal_names["EEG F3-REF"]] - - signals[signal_names["EEG C3-REF"]] - ), # 15 - ( - signals[signal_names["EEG C3-REF"]] - - 
signals[signal_names["EEG P3-REF"]] - ), # 16 - ( - signals[signal_names["EEG P3-REF"]] - - signals[signal_names["EEG O1-REF"]] - ), # 17 - ( - signals[signal_names["EEG FP2-REF"]] - - signals[signal_names["EEG F4-REF"]] - ), # 18 + new_signals = np.vstack( ( - signals[signal_names["EEG F4-REF"]] - - signals[signal_names["EEG C4-REF"]] - ), # 19 - ( - signals[signal_names["EEG C4-REF"]] - - signals[signal_names["EEG P4-REF"]] - ), # 20 - (signals[signal_names["EEG P4-REF"]] - signals[signal_names["EEG O2-REF"]]), + signals[signal_names["EEG FP1-REF"]] - signals[signal_names["EEG F7-REF"]], + signals[signal_names["EEG F7-REF"]] - signals[signal_names["EEG T3-REF"]], + signals[signal_names["EEG T3-REF"]] - signals[signal_names["EEG T5-REF"]], + signals[signal_names["EEG T5-REF"]] - signals[signal_names["EEG O1-REF"]], + signals[signal_names["EEG FP2-REF"]] - signals[signal_names["EEG F8-REF"]], + signals[signal_names["EEG F8-REF"]] - signals[signal_names["EEG T4-REF"]], + signals[signal_names["EEG T4-REF"]] - signals[signal_names["EEG T6-REF"]], + signals[signal_names["EEG T6-REF"]] - signals[signal_names["EEG O2-REF"]], + signals[signal_names["EEG FP1-REF"]] - signals[signal_names["EEG F3-REF"]], + signals[signal_names["EEG F3-REF"]] - signals[signal_names["EEG C3-REF"]], + signals[signal_names["EEG C3-REF"]] - signals[signal_names["EEG P3-REF"]], + signals[signal_names["EEG P3-REF"]] - signals[signal_names["EEG O1-REF"]], + signals[signal_names["EEG FP2-REF"]] - signals[signal_names["EEG F4-REF"]], + signals[signal_names["EEG F4-REF"]] - signals[signal_names["EEG C4-REF"]], + signals[signal_names["EEG C4-REF"]] - signals[signal_names["EEG P4-REF"]], + signals[signal_names["EEG P4-REF"]] - signals[signal_names["EEG O2-REF"]], + ) ) - ) # 21 - return new_signals + return new_signals + @staticmethod + def readEDF(fileName: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, mne.io.BaseRaw]: + Rawdata = mne.io.read_raw_edf(fileName, preload=True, verbose="error") -def 
readEDF(fileName): - Rawdata = mne.io.read_raw_edf(fileName) - signals, times = Rawdata[:] - RecFile = fileName[0:-3] + "rec" - eventData = np.genfromtxt(RecFile, delimiter=",") - Rawdata.close() - return [signals, times, eventData, Rawdata] + Rawdata.filter(l_freq=0.1, h_freq=75.0, verbose="error") + Rawdata.notch_filter(50.0, verbose="error") + Rawdata.resample(256, n_jobs=5, verbose="error") -if __name__ == "__main__": - from pyhealth.datasets import TUABDataset, TUEVDataset - - # dataset = TUABDataset( - # root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/", - # dev=True, - # refresh_cache=True, - # ) - # EEG_abnormal_ds = dataset.set_task(EEG_isAbnormal_fn) - # print(EEG_abnormal_ds.samples[0]) - # print(EEG_abnormal_ds.input_info) - - dataset = TUEVDataset( - root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/", - dev=True, - refresh_cache=True, - ) - EEG_events_ds = dataset.set_task(EEG_events_fn) - print(EEG_events_ds.samples[0]) - print(EEG_events_ds.input_info) + _, times = Rawdata[:] + signals = Rawdata.get_data(units="uV") + RecFile = fileName[0:-3] + "rec" + + eventData = np.genfromtxt(RecFile, delimiter=",") + + Rawdata.close() + return signals, times, eventData, Rawdata + + + def __call__(self, patient: Any) -> List[Dict[str, Any]]: + """Processes one patient. Creates one sample per event in the .rec file. + + Expected patient events to include a `signal_file` attribute that points to an .edf file. 
+ """ + pid = patient.patient_id + events = patient.get_events() + + samples: List[Dict[str, Any]] = [] + + for event in events: + edf_path = event.signal_file + + signals, times, rec, raw = self.readEDF(edf_path) + signals = self.convert_signals(signals, raw) + feats, offending_channels, labels = self.BuildEvents(signals, times, rec) + + for idx, (signal, offending_channel, label) in enumerate( + zip(feats, offending_channels, labels) + ): + samples.append( + { + "patient_id": pid, + "signal_file": edf_path, + "signal": signal, + "offending_channel": int(offending_channel.squeeze()), + "label": int(label.squeeze())-1, + } + ) + + return samples diff --git a/tests/core/test_tuev.py b/tests/core/test_tuev.py new file mode 100644 index 000000000..20dd11ce3 --- /dev/null +++ b/tests/core/test_tuev.py @@ -0,0 +1,184 @@ +import os +import tempfile +import unittest +from dataclasses import dataclass +from pathlib import Path +from typing import List +import pandas as pd +import numpy as np +from unittest.mock import patch + + +from pyhealth.datasets.tuev import TUEVDataset +from pyhealth.tasks.temple_university_EEG_tasks import EEGEventsTUEV + + +@dataclass +class _DummyEvent: + signal_file: str + + +class _DummyPatient: + def __init__(self, patient_id: str, events: List[_DummyEvent]): + self.patient_id = patient_id + self._events = events + + def get_events(self) -> List[_DummyEvent]: + return self._events + + +class TestTUEVDataset(unittest.TestCase): + def _touch(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(b"") + + def test_prepare_metadata_creates_expected_csvs(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + + # Minimal filesystem layout the metadata builder expects + train_subject = root / "train" / "subject-train" + eval_subject = root / "eval" / "subject-eval" + + train_edf = train_subject / "00002275_00000001.edf" + eval_edf = eval_subject / "bckg_032_a_.edf" + self._touch(train_edf) + 
self._touch(eval_edf) + + # Call prepare_metadata without invoking BaseDataset init + ds = TUEVDataset.__new__(TUEVDataset) + ds.root = str(root) + ds.prepare_metadata() + + train_csv = root / "tuev-train-pyhealth.csv" + eval_csv = root / "tuev-eval-pyhealth.csv" + + self.assertTrue(train_csv.exists()) + self.assertTrue(eval_csv.exists()) + + train_df = pd.read_csv(train_csv) + eval_df = pd.read_csv(eval_csv) + + self.assertEqual(len(train_df), 1) + self.assertEqual(len(eval_df), 1) + + self.assertIn("patient_id", train_df.columns) + self.assertIn("record_id", train_df.columns) + self.assertIn("signal_file", train_df.columns) + + self.assertEqual(train_df.loc[0, "patient_id"], "subject-train") + self.assertEqual(train_df.loc[0, "record_id"], 1) + self.assertTrue(str(train_df.loc[0, "signal_file"]).endswith("00002275_00000001.edf")) + + self.assertIn("patient_id", eval_df.columns) + self.assertIn("label", eval_df.columns) + self.assertIn("segment_id", eval_df.columns) + self.assertIn("signal_file", eval_df.columns) + + self.assertEqual(eval_df.loc[0, "patient_id"], "subject-eval") + self.assertEqual(eval_df.loc[0, "label"], "bckg") + self.assertEqual(eval_df.loc[0, "segment_id"], "a_") + self.assertTrue(str(eval_df.loc[0, "signal_file"]).endswith("bckg_032_a_.edf")) + + # Idempotency: should not crash when CSVs already exist + ds.prepare_metadata() + + def test_invalid_subset_raises(self): + with tempfile.TemporaryDirectory() as tmp: + with self.assertRaises(ValueError): + TUEVDataset(root=tmp, subset="nope") + + def test_default_task_returns_task_instance(self): + ds = TUEVDataset.__new__(TUEVDataset) + task = ds.default_task + self.assertIsInstance(task, EEGEventsTUEV) + + +class TestEEGEventsTUEV(unittest.TestCase): + def test_convert_signals_output_shape_and_values(self): + class _Raw: + def __init__(self, ch_names): + self.info = {"ch_names": ch_names} + + ch_names = [ + "EEG FP1-REF", + "EEG F7-REF", + "EEG T3-REF", + "EEG T5-REF", + "EEG O1-REF", + "EEG 
FP2-REF", + "EEG F8-REF", + "EEG T4-REF", + "EEG T6-REF", + "EEG O2-REF", + "EEG F3-REF", + "EEG C3-REF", + "EEG P3-REF", + "EEG F4-REF", + "EEG C4-REF", + "EEG P4-REF", + ] + raw = _Raw(ch_names) + + n = 10 + signals = np.arange(len(ch_names) * n, dtype=float).reshape(len(ch_names), n) + + out = EEGEventsTUEV.convert_signals(signals, raw) + self.assertEqual(out.shape, (16, n)) + + fp1 = ch_names.index("EEG FP1-REF") + f7 = ch_names.index("EEG F7-REF") + expected0 = signals[fp1] - signals[f7] + np.testing.assert_allclose(out[0], expected0) + + def test_BuildEvents_single_row_eventdata_and_window_length(self): + fs = 256 + num_chan = 16 + num_points = 2000 + signals = np.random.randn(num_chan, num_points) + times = np.arange(num_points) / fs + + # Single-row .rec style data: [chan, start_time, end_time, label] + # Choose start/end so that (end-start)=1s => 5s window (2s pre + 1s event + 2s post) + event = np.array([4, 5.0, 6.0, 2]) + + feats, offending, labels = EEGEventsTUEV.BuildEvents(signals, times, event) + self.assertEqual(feats.shape, (1, num_chan, fs * 5)) + self.assertEqual(offending.shape, (1, 1)) + self.assertEqual(labels.shape, (1, 1)) + self.assertEqual(int(offending.squeeze()), 4) + self.assertEqual(int(labels.squeeze()), 2) + + def test_call_returns_one_sample_per_event_and_adjusts_label(self): + task = EEGEventsTUEV() + + dummy_patient = _DummyPatient( + patient_id="patient-0", + events=[_DummyEvent(signal_file=os.path.join("C:\\", "dummy.edf"))], + ) + + feats = np.zeros((2, 16, 256 * 5), dtype=float) + offending = np.array([[3], [7]]) + labels = np.array([[1], [6]]) # will become 0 and 5 in output + + with patch.object(EEGEventsTUEV, "readEDF", return_value=(None, None, None, None)): + with patch.object(EEGEventsTUEV, "convert_signals", return_value=None): + with patch.object( + EEGEventsTUEV, + "BuildEvents", + return_value=(feats, offending, labels), + ): + samples = task(dummy_patient) + + self.assertEqual(len(samples), 2) + 
self.assertEqual(samples[0]["patient_id"], "patient-0") + self.assertIn("signal", samples[0]) + self.assertEqual(samples[0]["signal"].shape, (16, 256 * 5)) + self.assertEqual(samples[0]["offending_channel"], 3) + self.assertEqual(samples[0]["label"], 0) + self.assertEqual(samples[1]["offending_channel"], 7) + self.assertEqual(samples[1]["label"], 5) + + +if __name__ == "__main__": + unittest.main()