diff --git a/examples/medlink_mimic3.ipynb b/examples/medlink_mimic3.ipynb new file mode 100644 index 000000000..9246436cf --- /dev/null +++ b/examples/medlink_mimic3.ipynb @@ -0,0 +1,681 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "1ee5347e", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:52.819501Z", + "iopub.status.busy": "2025-12-22T06:22:52.819215Z", + "iopub.status.idle": "2025-12-22T06:22:59.559734Z", + "shell.execute_reply": "2025-12-22T06:22:59.559502Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PROJECT_ROOT: /Users/saurabhatri/Downloads/PyHealth\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch is installed\n", + "pyhealth is importable, version: 1.1.4\n" + ] + } + ], + "source": [ + "import os\n", + "import sys\n", + "\n", + "# Ensure project root is on sys.path when running from examples/\n", + "PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", + "if PROJECT_ROOT not in sys.path:\n", + " sys.path.insert(0, PROJECT_ROOT)\n", + "\n", + "print(\"PROJECT_ROOT:\", PROJECT_ROOT)\n", + "\n", + "# Basic sanity check for torch and pyhealth\n", + "try:\n", + " import torch\n", + " print(\"PyTorch is installed\")\n", + "except ImportError as e:\n", + " raise RuntimeError(\n", + " \"PyTorch is not installed. 
Install it into your environment first \" \n", + " ) from e\n", + "\n", + "try:\n", + " import pyhealth\n", + " print(\"pyhealth is importable, version:\", getattr(pyhealth, \"__version__\", \"unknown\"))\n", + "except ImportError as e:\n", + " raise RuntimeError(\n", + " \"pyhealth is not importable.\"\n", + " ) from e\n", + "\n", + "# Core dataset + MedLink imports\n", + "from pyhealth.datasets import MIMIC3Dataset\n", + "from pyhealth.tasks import BaseTask\n", + "from pyhealth.models.medlink import (\n", + " BM25Okapi,\n", + " convert_to_ir_format,\n", + " filter_by_candidates,\n", + " generate_candidates,\n", + " get_bm25_hard_negatives,\n", + " get_eval_dataloader,\n", + " get_train_dataloader,\n", + " tvt_split,\n", + ")\n", + "from pyhealth.models.medlink.model import MedLink\n", + "from pyhealth.metrics import ranking_metrics_fn\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "240e358e", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.561383Z", + "iopub.status.busy": "2025-12-22T06:22:59.561144Z", + "iopub.status.idle": "2025-12-22T06:22:59.563040Z", + "shell.execute_reply": "2025-12-22T06:22:59.562832Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MIMIC-III demo root: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4\n" + ] + } + ], + "source": [ + "# Downloaded from: https://physionet.org/content/mimiciii-demo/1.4/\n", + "MIMIC3_DEMO_ROOT = \"/path/to/mimic-iii-clinical-database-demo-1.4\" # <-- adjust for real\n", + "#MIMIC3_DEMO_ROOT = \"/Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4\"\n", + "print(\"MIMIC-III demo root:\", MIMIC3_DEMO_ROOT)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "f0851481", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.564228Z", + "iopub.status.busy": "2025-12-22T06:22:59.564145Z", + "iopub.status.idle": "2025-12-22T06:22:59.579422Z", + 
"shell.execute_reply": "2025-12-22T06:22:59.579179Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No config path provided, using default config\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initializing mimic3 dataset from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4 (dev mode: False)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: patients from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/PATIENTS.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: admissions from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: icustays from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ICUSTAYS.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scanning table: diagnoses_icd from /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_ICD.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. 
Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/DIAGNOSES_ICD.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Joining with table: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv.gz\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Original path does not exist. Using alternative: /Users/saurabhatri/Downloads/mimic-iii-clinical-database-demo-1.4/ADMISSIONS.csv\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting global event dataframe...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collected dataframe with shape: (2126, 31)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset: mimic3\n", + "Dev mode: False\n", + "Number of patients: 100\n", + "Number of events: 2126\n" + ] + } + ], + "source": [ + "# Load base MIMIC-III dataset from the demo\n", + "\n", + "base_dataset = MIMIC3Dataset(\n", + " root=MIMIC3_DEMO_ROOT,\n", + " tables=[\"diagnoses_icd\"], # matches in configs/mimic3.yaml\n", + " dev=False, # True => small subset of patients\n", + ")\n", + "\n", + "base_dataset.stats()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5d18d87c", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.580593Z", + "iopub.status.busy": "2025-12-22T06:22:59.580517Z", + "iopub.status.idle": "2025-12-22T06:22:59.582093Z", + "shell.execute_reply": "2025-12-22T06:22:59.581921Z" + } + }, + "outputs": [], + "source": [ + "from pyhealth.tasks.patient_linkage_mimic3 import PatientLinkageMIMIC3Task\n", + "from datetime import datetime\n", + "from collections import defaultdict\n", + "import math\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "bce967de", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.583209Z", + "iopub.status.busy": 
"2025-12-22T06:22:59.583137Z", + "iopub.status.idle": "2025-12-22T06:22:59.666946Z", + "shell.execute_reply": "2025-12-22T06:22:59.666677Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Setting task patient_linkage_mimic3 for mimic3 base dataset...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Generating samples with 1 worker(s)...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\n", + "Generating samples for patient_linkage_mimic3 with 1 worker: 0%| | 0/100 [00:00 9\n", + "id_p 9\n", + "s_q 9\n", + "s_p 9\n" + ] + } + ], + "source": [ + "USE_BM25_HARDNEGS = False\n", + "\n", + "# optionally refine training qrels with BM25-based hard negatives\n", + "if USE_BM25_HARDNEGS:\n", + " bm25_model = BM25Okapi(corpus)\n", + " tr_qrels = get_bm25_hard_negatives(\n", + " bm25_model, corpus, tr_queries, tr_qrels\n", + " )\n", + "\n", + "#Dataloaders for training / validation / test\n", + "train_dataloader = get_train_dataloader(\n", + " corpus=corpus,\n", + " queries=tr_queries,\n", + " qrels=tr_qrels,\n", + " batch_size=32,\n", + " shuffle=True,\n", + ")\n", + "\n", + "val_dataloader = get_train_dataloader(\n", + " corpus=corpus,\n", + " queries=va_queries,\n", + " qrels=va_qrels,\n", + " batch_size=32,\n", + " shuffle=False,\n", + ")\n", + "\n", + "test_corpus_dataloader, test_queries_dataloader = get_eval_dataloader(\n", + " corpus=corpus,\n", + " queries=te_queries,\n", + " batch_size=32,\n", + ")\n", + "\n", + "batch = next(iter(train_dataloader))\n", + "for k, v in batch.items():\n", + " print(k, type(v), (len(v) if hasattr(v, \"__len__\") else None))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "eae98819", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.676421Z", + "iopub.status.busy": "2025-12-22T06:22:59.676318Z", + "iopub.status.idle": "2025-12-22T06:22:59.678451Z", + "shell.execute_reply": 
"2025-12-22T06:22:59.678264Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 9 training pairs.\n", + "dict_keys(['query_id', 'id_p', 's_q', 's_p'])\n" + ] + } + ], + "source": [ + "# Build train_loader for MedLink\n", + "\n", + "from pyhealth.models.medlink import get_train_dataloader, tvt_split\n", + "\n", + "tr_queries, va_queries, te_queries, tr_qrels, va_qrels, te_qrels = tvt_split(\n", + " queries, qrels\n", + ")\n", + "\n", + "train_loader = get_train_dataloader(\n", + " corpus=corpus,\n", + " queries=tr_queries,\n", + " qrels=tr_qrels,\n", + " batch_size=32,\n", + " shuffle=True,\n", + ")\n", + "\n", + "# quick sanity check\n", + "batch = next(iter(train_loader))\n", + "print(batch.keys())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "c877b5ba", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.679609Z", + "iopub.status.busy": "2025-12-22T06:22:59.679519Z", + "iopub.status.idle": "2025-12-22T06:22:59.706174Z", + "shell.execute_reply": "2025-12-22T06:22:59.705947Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Raw batch keys: dict_keys(['query_id', 'id_p', 's_q', 's_p'])\n", + "MedLink outputs keys: dict_keys(['loss'])\n", + "Loss: 51.74361801147461\n", + "Backward pass completed.\n" + ] + } + ], + "source": [ + "import torch\n", + "from pyhealth.models import BaseModel\n", + "from pyhealth.datasets import SampleDataset\n", + "from pyhealth.models.medlink.model import MedLink\n", + "\n", + "# normalize sequences so tokenizer sees lists, not tensors\n", + "def _normalize_seqs(obj):\n", + " \"\"\"\n", + " Convert batch field (tensor or list of tensors/lists) into\n", + " List[List[str]] as expected by Tokenizer.batch_encode_2d.\n", + " \"\"\"\n", + " if torch.is_tensor(obj):\n", + " obj = obj.tolist() # -> list[list[int]]\n", + "\n", + " seqs_out = []\n", + " for seq in obj:\n", + " if torch.is_tensor(seq):\n", + " 
seq = seq.tolist()\n", + " # at this point seq is list[int] or list[str]\n", + " seqs_out.append([str(tok) for tok in seq])\n", + " return seqs_out\n", + "\n", + "#init medlink and run a single forward/backward pass\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "# sample_dataset and train_loader must already be defined in earlier cells\n", + "model = MedLink(\n", + " dataset=sample_dataset,\n", + " feature_keys=[\"conditions\"],\n", + " embedding_dim=128,\n", + ").to(device)\n", + "\n", + "# Take one batch from the MedLink train dataloader\n", + "batch = next(iter(train_loader))\n", + "print(\"Raw batch keys:\", batch.keys())\n", + "\n", + "# Normalize the sequence fields so AdmissionPrediction/Tokenizer work\n", + "if \"s_q\" in batch:\n", + " batch[\"s_q\"] = _normalize_seqs(batch[\"s_q\"])\n", + "if \"s_p\" in batch:\n", + " batch[\"s_p\"] = _normalize_seqs(batch[\"s_p\"])\n", + "if \"s_n\" in batch and batch[\"s_n\"] is not None:\n", + " batch[\"s_n\"] = _normalize_seqs(batch[\"s_n\"])\n", + "\n", + "model.train()\n", + "outputs = model(**batch)\n", + "print(\"MedLink outputs keys:\", outputs.keys())\n", + "print(\"Loss:\", float(outputs[\"loss\"]))\n", + "\n", + "outputs[\"loss\"].backward()\n", + "print(\"Backward pass completed.\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "03113472", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.707276Z", + "iopub.status.busy": "2025-12-22T06:22:59.707206Z", + "iopub.status.idle": "2025-12-22T06:22:59.748012Z", + "shell.execute_reply": "2025-12-22T06:22:59.747776Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0: avg loss = 44.0731\n", + "epoch 1: avg loss = 67.9914\n", + "epoch 2: avg loss = 39.2616\n" + ] + } + ], + "source": [ + "#Sanity\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n", + "\n", + "for epoch in range(3):\n", + " total = 0.0\n", 
+ " n = 0\n", + " for batch in train_loader:\n", + " # normalize s_q / s_p as before\n", + " batch[\"s_q\"] = _normalize_seqs(batch[\"s_q\"])\n", + " batch[\"s_p\"] = _normalize_seqs(batch[\"s_p\"])\n", + " if \"s_n\" in batch and batch[\"s_n\"] is not None:\n", + " batch[\"s_n\"] = _normalize_seqs(batch[\"s_n\"])\n", + "\n", + " optimizer.zero_grad()\n", + " out = model(**batch)\n", + " loss = out[\"loss\"]\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " total += float(loss)\n", + " n += 1\n", + " print(f\"epoch {epoch}: avg loss = {total / max(n,1):.4f}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ed96b498", + "metadata": { + "execution": { + "iopub.execute_input": "2025-12-22T06:22:59.749296Z", + "iopub.status.busy": "2025-12-22T06:22:59.749219Z", + "iopub.status.idle": "2025-12-22T06:22:59.750807Z", + "shell.execute_reply": "2025-12-22T06:22:59.750582Z" + } + }, + "outputs": [], + "source": [ + "#Unit test script - pytest tests/core/test_medlink.py\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b77a6ee2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/test_eICU_addition.py b/examples/test_eICU_addition.py index 82c05f38e..e588e3221 100644 --- a/examples/test_eICU_addition.py +++ b/examples/test_eICU_addition.py @@ -1,15 +1,17 @@ from pyhealth.datasets import eICUDataset -from pyhealth.tasks import mortality_prediction_eicu_fn, mortality_prediction_eicu_fn2 +from pyhealth.tasks import MortalityPredictionEICU, MortalityPredictionEICU2 -base_dataset = 
eICUDataset( - root="/srv/local/data/physionet.org/files/eicu-crd/2.0", - tables=["diagnosis", "admissionDx", "treatment"], - dev=False, - refresh_cache=False, -) -sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_eicu_fn2) -sample_dataset.stat() -print(sample_dataset.available_keys) +if __name__ == "__main__": + base_dataset = eICUDataset( + root="/srv/local/data/physionet.org/files/eicu-crd/2.0", + tables=["diagnosis", "admissionDx", "treatment"], + dev=False, + refresh_cache=False, + ) + task = MortalityPredictionEICU2() + sample_dataset = base_dataset.set_task(task=task) + sample_dataset.stat() + print(sample_dataset.available_keys) # base_dataset = eICUDataset( # root="/srv/local/data/physionet.org/files/eicu-crd/2.0", @@ -17,6 +19,7 @@ # dev=True, # refresh_cache=False, # ) -# sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_eicu_fn2) +# task = MortalityPredictionEICU2() +# sample_dataset = base_dataset.set_task(task=task) # sample_dataset.stat() # print(sample_dataset.available_keys) diff --git a/pyhealth/__init__.py b/pyhealth/__init__.py index efd7e39b7..722483dcb 100755 --- a/pyhealth/__init__.py +++ b/pyhealth/__init__.py @@ -18,3 +18,4 @@ formatter = logging.Formatter("%(message)s") handler.setFormatter(formatter) logger.addHandler(handler) + diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5c3683bc1..ee606158f 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -26,4 +26,5 @@ from .transformer import Transformer, TransformerLayer from .transformers_model import TransformersModel from .vae import VAE -from .sdoh import SdohClassifier \ No newline at end of file +from .sdoh import SdohClassifier +from .medlink import MedLink diff --git a/pyhealth/models/embedding.py b/pyhealth/models/embedding.py index a3aad8244..34c1955bc 100644 --- a/pyhealth/models/embedding.py +++ b/pyhealth/models/embedding.py @@ -1,4 +1,7 @@ -from typing import Dict +from __future__ import 
annotations + +from typing import Dict, Any, Optional, Union +import os import torch import torch.nn as nn @@ -18,6 +21,94 @@ ) from .base_model import BaseModel + +def _iter_text_vectors( + path: str, + embedding_dim: int, + wanted_tokens: set[str], + encoding: str = "utf-8", +) -> Dict[str, torch.Tensor]: + """Loads word vectors from a text file (e.g., GloVe) for a subset of tokens. + + Expected format: one token per line followed by embedding_dim floats. + + This function reads the file line-by-line and only retains vectors for + tokens present in `wanted_tokens`. + """ + + if not os.path.exists(path): + raise FileNotFoundError(f"pretrained embedding file not found: {path}") + + vectors: Dict[str, torch.Tensor] = {} + with open(path, "r", encoding=encoding) as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split() + # token + embedding_dim values + if len(parts) < embedding_dim + 1: + continue + token = parts[0] + if token not in wanted_tokens: + continue + try: + vec = torch.tensor( + [float(x) for x in parts[1 : embedding_dim + 1]], + dtype=torch.float, + ) + except ValueError: + continue + vectors[token] = vec + return vectors + + +def init_embedding_with_pretrained( + embedding: nn.Embedding, + code_vocab: Dict[Any, int], + pretrained_path: str, + embedding_dim: int, + pad_token: str = "", + unk_token: str = "", + normalize: bool = False, + freeze: bool = False, +) -> int: + """Initializes an nn.Embedding from a pretrained text-vector file. + + Tokens not found in the pretrained file are left as the module's existing + random initialization. + + Returns: + int: number of tokens successfully initialized from the file. 
+ """ + + # Build wanted token set (stringified) + vocab_tokens = {str(t) for t in code_vocab.keys()} + vectors = _iter_text_vectors(pretrained_path, embedding_dim, vocab_tokens) + + loaded = 0 + with torch.no_grad(): + for tok, idx in code_vocab.items(): + tok_s = str(tok) + if tok_s in vectors: + vec = vectors[tok_s] + if normalize: + vec = vec / (vec.norm(p=2) + 1e-12) + embedding.weight[idx].copy_(vec) + loaded += 1 + + # Ensure pad row is zero + if pad_token in code_vocab: + embedding.weight[code_vocab[pad_token]].zero_() + # If embedding has a padding_idx, keep it consistent + if embedding.padding_idx is not None: + embedding.weight[embedding.padding_idx].zero_() + + if freeze: + embedding.weight.requires_grad_(False) + + return loaded + class EmbeddingModel(BaseModel): """ EmbeddingModel is responsible for creating embedding layers for different types of input data. @@ -46,7 +137,14 @@ class EmbeddingModel(BaseModel): - MultiHotProcessor: nn.Linear over multi-hot vector """ - def __init__(self, dataset: SampleDataset, embedding_dim: int = 128): + def __init__( + self, + dataset: SampleDataset, + embedding_dim: int = 128, + pretrained_emb_path: Optional[Union[str, Dict[str, str]]] = None, + freeze_pretrained: bool = False, + normalize_pretrained: bool = False, + ): super().__init__(dataset) self.embedding_dim = embedding_dim self.embedding_layers = nn.ModuleDict() @@ -81,6 +179,22 @@ def __init__(self, dataset: SampleDataset, embedding_dim: int = 128): padding_idx=0, ) + # Optional pretrained initialization (e.g., GloVe). 
+ if pretrained_emb_path is not None: + if isinstance(pretrained_emb_path, str): + path = pretrained_emb_path + else: + path = pretrained_emb_path.get(field_name) + if path: + init_embedding_with_pretrained( + self.embedding_layers[field_name], + processor.code_vocab, + path, + embedding_dim=embedding_dim, + normalize=normalize_pretrained, + freeze=freeze_pretrained, + ) + # Numeric features (including deep nested floats) -> nn.Linear over last dim elif isinstance( processor, diff --git a/pyhealth/models/medlink.py b/pyhealth/models/medlink.py new file mode 100644 index 000000000..9fd30aa92 --- /dev/null +++ b/pyhealth/models/medlink.py @@ -0,0 +1,447 @@ +from typing import Dict, List + +import torch +import torch.nn as nn +import torch.nn.functional as F +import tqdm + +from pyhealth.datasets import SampleDataset +from pyhealth.models import BaseModel +from pyhealth.models.transformer import TransformerLayer +from pyhealth.tokenizer import Tokenizer + + +def batch_to_multi_hot(label_batch: torch.Tensor, num_classes: int) -> torch.Tensor: + """ + Convert a 2D batch of label indices into a multi-hot representation. + + Parameters + ---------- + label_batch: + Long tensor of shape (batch_size, seq_len) with token indices. + num_classes: + Size of vocabulary. + + Returns + ------- + multi_hot: + Float tensor of shape (batch_size, num_classes), entries in {0,1}. + """ + # label_batch: (B, T) + batch_size, seq_len = label_batch.shape + flat = label_batch.view(-1) # (B*T,) + # Build index for scatter + row_idx = torch.arange(batch_size, device=label_batch.device).repeat_interleave(seq_len) + multi_hot = torch.zeros(batch_size, num_classes, device=label_batch.device, dtype=torch.float32) + multi_hot.index_put_((row_idx, flat), torch.ones_like(flat, dtype=torch.float32), accumulate=True) + multi_hot.clamp_max_(1.0) + return multi_hot + + +class AdmissionEncoder(nn.Module): + """ + Encodes a sequence of discrete tokens (code sequence) for MedLink. 
+ + It uses: + - a learnable embedding over the vocabulary + - a TransformerLayer backbone + - a BCE-with-logits loss over multi-hot targets + """ + + def __init__( + self, + tokenizer: Tokenizer, + embedding_dim: int, + heads: int = 2, + dropout: float = 0.5, + num_layers: int = 1, + ) -> None: + super().__init__() + self.tokenizer = tokenizer + self.vocab_size = tokenizer.get_vocabulary_size() + + self.embedding = nn.Embedding( + num_embeddings=self.vocab_size, + embedding_dim=embedding_dim, + padding_idx=tokenizer.get_padding_index(), + ) + + self.encoder = TransformerLayer( + feature_size=embedding_dim, + heads=heads, + dropout=dropout, + num_layers=num_layers, + ) + + self.criterion = nn.BCEWithLogitsLoss() + + def _encode_tokens(self, seqs: List[List[str]], device: torch.device): + """ + Turn a batch of token sequences into contextual embeddings and a padding mask. + + seqs: list of list of token strings, e.g. [["250.0","401.9"], ["414.0"], ...] + """ + token_ids = self.tokenizer.batch_encode_2d(seqs, padding=True) + token_ids = torch.tensor(token_ids, dtype=torch.long, device=device) # (B, T) + pad_idx = self.tokenizer.get_padding_index() + mask = token_ids != pad_idx # (B, T) + + emb = self.embedding(token_ids) # (B, T, D) + encoded, _ = self.encoder(emb) # (B, T, D) + return encoded, mask, token_ids + + def _multi_hot_targets(self, token_ids: torch.Tensor) -> torch.Tensor: + """ + Build a multi-hot target vector for each sequence in the batch. 
+ """ + multi_hot = batch_to_multi_hot(token_ids, self.vocab_size) # (B, V) + # Clear special tokens + pad_id = self.tokenizer.vocabulary("") + cls_id = self.tokenizer.vocabulary("") + if pad_id is not None: + multi_hot[:, pad_id] = 0.0 + if cls_id is not None: + multi_hot[:, cls_id] = 0.0 + return multi_hot + + def logits_and_targets( + self, + seqs: List[List[str]], + vocab_embeddings: torch.Tensor, + device: torch.device, + ): + """ + Compute: + - per-token logits against vocab embeddings + - multi-hot label vectors for the sequence. + + Returns + ------- + logits: (B, V) tensor + targets: (B, V) tensor multi-hot + """ + encoded, mask, token_ids = self._encode_tokens(seqs, device=device) # (B,T,D), (B,T), (B,T) + targets = self._multi_hot_targets(token_ids) # (B,V) + + # encoded: (B,T,D), vocab_embeddings: (V,D) + # per-token logits: (B,T,V) + logits = torch.matmul(encoded, vocab_embeddings.T) # (B,T,V) + # mask padded positions with large negative value + mask_expanded = mask.unsqueeze(-1) # (B,T,1) + logits = logits.masked_fill(~mask_expanded, -1e9) + # max-pool over time + logits = logits.max(dim=1).values # (B,V) + + return logits, targets + + def classification_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + BCE loss on multi-hot labels; handles potential size mismatches defensively. + """ + # In case of tiny mismatches, truncate to the smaller dimension. + batch = min(logits.size(0), targets.size(0)) + logits = logits[:batch] + targets = targets[:batch] + return self.criterion(logits, targets) + + +class MedLink(BaseModel): + """ + MedLink: de-identified patient record linkage model (KDD 2023). + + This model links de-identified patient records using admission sequences + and a transformer-based architecture. It is designed to operate on PyHealth's `SampleDataset`. + + Inputs: + - dataset (SampleDataset): The dataset containing patient admission sequences. 
+    - feature_keys (List[str]): List with the key for patient admission codes (only the first is used).
+    - embedding_dim (int, default=128): Embedding dimension for learned token embeddings.
+    - alpha, beta, gamma (float): Loss weights for model's multi-loss objective.
+    - heads (int): Number of transformer heads.
+    - dropout (float): Dropout rate for transformer encoders.
+    - num_layers (int): Number of layers in transformer encoders.
+
+    Outputs:
+    - The model primarily outputs a dictionary {"loss": loss_tensor} during training (see forward method).
+    - For retrieval/evaluation, the model provides embeddings and search utilities to score record similarity.
+
+    Example:
+        >>> from pyhealth.datasets import SampleDataset
+        >>> from pyhealth.models import MedLink
+        >>> samples = [{"patient_id": "1", "admissions": ["ICD9_430", "ICD9_401"]}, ...]
+        >>> input_schema = {"admissions": "code"}
+        >>> output_schema = {"label": "binary"}
+        >>> dataset = SampleDataset(path="/some/path", samples=samples, input_schema=input_schema, output_schema=output_schema)
+        >>> model = MedLink(dataset=dataset, feature_keys=["admissions"])
+        >>> batch = {"query_id": [...], "id_p": [...], "s_q": [["ICD9_430", "ICD9_401"]], "s_p": [[...]], "s_n": None}
+        >>> out = model(**batch)
+        >>> print(out["loss"])
+
+    Notes:
+    - Only works with a single feature_key (list of length 1).
+    - Specialized for code sequence/text-based features (e.g., admissions).
+    - Retrieval is performed via TF-IDF-style similarity on learned multi-hot embeddings.
+ """ + + def __init__( + self, + dataset: SampleDataset, + feature_keys: List[str], + embedding_dim: int = 128, + alpha: float = 0.5, + beta: float = 0.5, + gamma: float = 1.0, + heads: int = 2, + dropout: float = 0.5, + num_layers: int = 1, + **kwargs, + ) -> None: + # MedLink is defined over a single textual / code sequence feature + assert len(feature_keys) == 1, "MedLink supports exactly one feature key" + # BaseModel only accepts dataset parameter, not feature_keys, label_key, or mode + super().__init__(dataset=dataset) + # Set feature_keys manually since BaseModel extracts it from dataset.input_schema + # but MedLink needs to use the provided feature_keys + self.feature_keys = feature_keys + self.feature_key = feature_keys[0] + self.embedding_dim = embedding_dim + self.alpha = alpha + self.beta = beta + self.gamma = gamma + + # Build vocabulary for both queries and corpus sides + q_tokens = dataset.get_all_tokens(key=self.feature_key) + d_tokens = dataset.get_all_tokens(key="d_" + self.feature_key) + + tokenizer = Tokenizer( + tokens=q_tokens + d_tokens, + special_tokens=["", "", ""], + ) + self.tokenizer = tokenizer + self.vocab_size = tokenizer.get_vocabulary_size() + + # Two direction-specific encoders (forward / backward) + self.forward_encoder = AdmissionEncoder( + tokenizer=tokenizer, + embedding_dim=embedding_dim, + heads=heads, + dropout=dropout, + num_layers=num_layers, + ) + self.backward_encoder = AdmissionEncoder( + tokenizer=tokenizer, + embedding_dim=embedding_dim, + heads=heads, + dropout=dropout, + num_layers=num_layers, + ) + + # Retrieval / ranking loss + self.rank_loss = nn.CrossEntropyLoss() + + # ------------------------ + # Encoding utilities + # ------------------------ + def _all_vocab_ids(self) -> torch.Tensor: + return torch.arange(self.vocab_size, device=self.device, dtype=torch.long) + + def encode_queries(self, queries: List[List[str]]) -> torch.Tensor: + """ + Encode query records into embeddings for retrieval. 
+ + queries: list of token sequences, e.g. [["250.0","401.9"], ...] + Returns: (num_queries, vocab_size) embedding matrix. + """ + all_vocab = self._all_vocab_ids() # (V,) + bwd_vocab_emb = self.backward_encoder.embedding(all_vocab) # (V,D) + + logits, multi_hot = self.backward_encoder.logits_and_targets( + seqs=queries, + vocab_embeddings=bwd_vocab_emb, + device=self.device, + ) + logits = torch.log1p(F.relu(logits)) # smooth nonlinearity + return logits + multi_hot # (Q,V) + + def encode_corpus(self, corpus: List[List[str]]) -> torch.Tensor: + """ + Encode corpus records into embeddings for retrieval. + + corpus: list of token sequences. + Returns: (num_docs, vocab_size) embedding matrix. + """ + all_vocab = self._all_vocab_ids() + fwd_vocab_emb = self.forward_encoder.embedding(all_vocab) # (V,D) + + logits, multi_hot = self.forward_encoder.logits_and_targets( + seqs=corpus, + vocab_embeddings=fwd_vocab_emb, + device=self.device, + ) + logits = torch.log1p(F.relu(logits)) + return logits + multi_hot # (D,V) + + # ------------------------ + # Retrieval scoring + # ------------------------ + @staticmethod + def compute_scores(queries_emb: torch.Tensor, corpus_emb: torch.Tensor) -> torch.Tensor: + """ + Compute TF-IDF-like matching scores between queries and corpus. + + queries_emb: (Q,V) + corpus_emb: (D,V) + + Returns: + scores: (Q,D) + """ + # Inverse document frequency per term + n_docs = torch.tensor(corpus_emb.shape[0], device=corpus_emb.device, dtype=torch.float32) + df = (corpus_emb > 0).sum(dim=0) # (V,) + idf = torch.log1p(n_docs) - torch.log1p(df) + + # Term-frequency contribution per (query, doc, term) + tf = torch.einsum("qv,dv->qdv", queries_emb, corpus_emb) # (Q,D,V) + tf_idf = tf * idf # broadcast idf over last dim + + scores = tf_idf.sum(dim=-1) # (Q,D) + return scores + + def get_loss(self, scores: torch.Tensor) -> torch.Tensor: + """ + Retrieval loss: each query is matched to its corresponding positive + document at the same index. 
+ """ + num_queries = scores.size(0) + target = torch.arange(num_queries, device=scores.device, dtype=torch.long) + return self.rank_loss(scores, target) + + # ------------------------ + # Training forward + # ------------------------ + def forward( + self, + query_id, + id_p, + s_q, + s_p, + s_n=None, + ) -> Dict[str, torch.Tensor]: + """ + Forward pass used for training. + + Parameters in the batch (dict passed as **batch): + - query_id: list of query identifiers (unused by the loss) + - id_p: list of positive record ids (unused here, used for evaluation) + - s_q: list of query sequences (list[list[str]]) + - s_p: list of positive corpus sequences (list[list[str]]) + - s_n: optional list of negative corpus sequences (list[list[str]]) + + Returns + ------- + dict with key "loss": scalar tensor. + """ + # Build full corpus: positives plus negatives if provided + if s_n is None: + corpus = s_p + else: + corpus = s_p + s_n + queries = s_q + + # Precompute vocab embeddings for both encoders + all_vocab = self._all_vocab_ids() + fwd_vocab_emb = self.forward_encoder.embedding(all_vocab) # (V,D) + bwd_vocab_emb = self.backward_encoder.embedding(all_vocab) # (V,D) + + # Forward and backward admission prediction losses + # Corpus -> query distributions + pred_queries, corpus_targets = self.forward_encoder.logits_and_targets( + seqs=corpus, + vocab_embeddings=fwd_vocab_emb, + device=self.device, + ) + # Query -> corpus distributions + pred_corpus, query_targets = self.backward_encoder.logits_and_targets( + seqs=queries, + vocab_embeddings=bwd_vocab_emb, + device=self.device, + ) + + fwd_cls_loss = self.forward_encoder.classification_loss(pred_queries, query_targets) + bwd_cls_loss = self.backward_encoder.classification_loss(pred_corpus, corpus_targets) + + # Turn predictions into dense embeddings + pred_queries_act = torch.log1p(F.relu(pred_queries)) + pred_corpus_act = torch.log1p(F.relu(pred_corpus)) + + corpus_emb = corpus_targets + pred_queries_act + queries_emb = 
query_targets + pred_corpus_act + + scores = self.compute_scores(queries_emb, corpus_emb) + retrieval_loss = self.get_loss(scores) + + total_loss = ( + self.alpha * fwd_cls_loss + + self.beta * bwd_cls_loss + + self.gamma * retrieval_loss + ) + return {"loss": total_loss} + + # ------------------------ + # Retrieval helpers + # ------------------------ + def search( + self, + queries_ids: List[str], + queries_embeddings: torch.Tensor, + corpus_ids: List[str], + corpus_embeddings: torch.Tensor, + ) -> Dict[str, Dict[str, float]]: + """ + Compute scores for all (query, corpus) pairs and return as nested dict: + {query_id: {corpus_id: score, ...}, ...} + """ + scores = self.compute_scores(queries_embeddings, corpus_embeddings) # (Q,D) + results: Dict[str, Dict[str, float]] = {} + for q_idx, q_id in enumerate(queries_ids): + row_scores = scores[q_idx] + results[q_id] = {c_id: row_scores[c_idx].item() for c_idx, c_id in enumerate(corpus_ids)} + return results + + def evaluate(self, corpus_dataloader, queries_dataloader) -> Dict[str, Dict[str, float]]: + """ + Run MedLink in retrieval mode on dataloaders for corpus and queries. + + corpus_dataloader yields batches with keys: "corpus_id", "s". + queries_dataloader yields batches with keys: "query_id", "s". 
+ """ + self.eval() + all_corpus_ids: List[str] = [] + all_queries_ids: List[str] = [] + all_corpus_embeddings: List[torch.Tensor] = [] + all_queries_embeddings: List[torch.Tensor] = [] + + with torch.no_grad(): + for batch in tqdm.tqdm(corpus_dataloader): + corpus_ids = batch["corpus_id"] + corpus_seqs = batch["s"] + corpus_emb = self.encode_corpus(corpus_seqs) + all_corpus_ids.extend(corpus_ids) + all_corpus_embeddings.append(corpus_emb) + + for batch in tqdm.tqdm(queries_dataloader): + query_ids = batch["query_id"] + query_seqs = batch["s"] + query_emb = self.encode_queries(query_seqs) + all_queries_ids.extend(query_ids) + all_queries_embeddings.append(query_emb) + + corpus_mat = torch.cat(all_corpus_embeddings, dim=0) + queries_mat = torch.cat(all_queries_embeddings, dim=0) + + return self.search( + queries_ids=all_queries_ids, + queries_embeddings=queries_mat, + corpus_ids=all_corpus_ids, + corpus_embeddings=corpus_mat, + ) diff --git a/pyhealth/models/medlink/model.py b/pyhealth/models/medlink/model.py index f4e2b4ddb..ad007becb 100644 --- a/pyhealth/models/medlink/model.py +++ b/pyhealth/models/medlink/model.py @@ -1,185 +1,359 @@ -from typing import Dict, List +from __future__ import annotations + +from typing import Dict, List, Any, Optional, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F import tqdm +from torch.nn.utils.rnn import pad_sequence + +from ...datasets import SampleDataset +from ..base_model import BaseModel +from ..transformer import TransformerLayer +from ...processors import SequenceProcessor + +from ..embedding import init_embedding_with_pretrained + + +def _build_shared_vocab( + q_processor: SequenceProcessor, + d_processor: SequenceProcessor, + pad_token: str = "", + unk_token: str = "", +) -> Dict[str, int]: + """Build a shared token->index mapping from two fitted SequenceProcessors. + + The returned vocabulary is deterministic (sorted token order) and always + includes `pad_token` and `unk_token`. 
+ """ -from pyhealth.datasets import SampleEHRDataset -from pyhealth.models import BaseModel -from pyhealth.models.transformer import TransformerLayer -from pyhealth.tokenizer import Tokenizer + vocab: Dict[str, int] = {pad_token: 0, unk_token: 1} + tokens = set(str(t) for t in q_processor.code_vocab.keys()) | set( + str(t) for t in d_processor.code_vocab.keys() + ) + tokens.discard(pad_token) + tokens.discard(unk_token) + + for t in sorted(tokens): + if t not in vocab: + vocab[t] = len(vocab) + return vocab + + +def _build_index_remap( + processor: SequenceProcessor, + shared_vocab: Dict[str, int], + unk_idx: int, +) -> torch.Tensor: + """Build a dense remap tensor old_idx -> shared_idx.""" + + size = len(processor.code_vocab) + remap = torch.full((size,), unk_idx, dtype=torch.long) + for tok, old_idx in processor.code_vocab.items(): + tok_s = str(tok) + remap[old_idx] = shared_vocab.get(tok_s, unk_idx) + return remap + + +def _to_index_tensor( + seq: Any, + processor: SequenceProcessor, +) -> torch.Tensor: + """Converts a single sequence to an index tensor using the fitted processor.""" + if isinstance(seq, torch.Tensor): + return seq.long() + if isinstance(seq, (list, tuple)): + return processor.process(seq) + # single token + return processor.process([seq]) + + +def _pad_and_remap( + sequences: Sequence[Any], + processor: SequenceProcessor, + remap: torch.Tensor, + pad_value: int = 0, + device: Optional[torch.device] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Pads a batch of sequences and remaps indices into the shared vocab. + + Returns: + ids_shared: LongTensor [B, L] + mask: BoolTensor [B, L] where True indicates valid token positions. 
+ """ -def batch_to_one_hot(label_batch, num_class): - """ convert to one hot label """ - label_batch_onehot = [] - for label in label_batch: - label_batch_onehot.append(F.one_hot(label, num_class).sum(dim=0)) - label_batch_onehot = torch.stack(label_batch_onehot, dim=0) - label_batch_onehot[label_batch_onehot > 1] = 1 - return label_batch_onehot + ids = [_to_index_tensor(s, processor) for s in sequences] + ids_padded = pad_sequence(ids, batch_first=True, padding_value=pad_value) + if device is not None: + ids_padded = ids_padded.to(device) + remap = remap.to(device) + ids_shared = remap[ids_padded] + mask = ids_shared != 0 + return ids_shared, mask class AdmissionPrediction(nn.Module): - def __init__(self, tokenizer, embedding_dim, heads=2, dropout=0.5, num_layers=1): - super(AdmissionPrediction, self).__init__() - self.tokenizer = tokenizer - self.vocabs_size = tokenizer.get_vocabulary_size() + """Admission prediction module used by MedLink. + + This is a lightly-adapted version of the original MedLink implementation, + refactored to work with PyHealth 2.0 processors (i.e., indexed tensors). 
+ """ + + def __init__( + self, + code_vocab: Dict[str, int], + embedding_dim: int, + heads: int = 2, + dropout: float = 0.5, + num_layers: int = 1, + pretrained_emb_path: Optional[str] = None, + freeze_pretrained: bool = False, + ): + super().__init__() + self.code_vocab = code_vocab + self.vocab_size = len(code_vocab) + self.pad_idx = code_vocab.get("", 0) + self.cls_idx = code_vocab.get("") + self.embedding = nn.Embedding( - self.vocabs_size, - embedding_dim, - padding_idx=tokenizer.get_padding_index() + num_embeddings=self.vocab_size, + embedding_dim=embedding_dim, + padding_idx=self.pad_idx, ) + if pretrained_emb_path: + init_embedding_with_pretrained( + self.embedding, + code_vocab, + pretrained_emb_path, + embedding_dim=embedding_dim, + freeze=freeze_pretrained, + ) + self.encoder = TransformerLayer( feature_size=embedding_dim, heads=heads, dropout=dropout, - num_layers=num_layers + num_layers=num_layers, ) self.criterion = nn.BCEWithLogitsLoss() - def encode_one_hot(self, input: List[str], device): - input_batch = self.tokenizer.batch_encode_2d(input, padding=True) - input_batch = torch.tensor(input_batch, dtype=torch.long, device=device) - input_onehot = batch_to_one_hot(input_batch, self.vocabs_size) - input_onehot = input_onehot.float().to(device) - input_onehot[:, self.tokenizer.vocabulary("")] = 0 - input_onehot[:, self.tokenizer.vocabulary("")] = 0 - return input_onehot - - def encode_dense(self, input: List[str], device): - input_batch = self.tokenizer.batch_encode_2d(input, padding=True) - input_batch = torch.tensor(input_batch, dtype=torch.long, device=device) - mask = input_batch != 0 - input_embeddings = self.embedding(input_batch) - input_embeddings, _ = self.encoder(input_embeddings) - return input_embeddings, mask - - def get_loss(self, logits, target_onehot): - true_batch_size = min(logits.shape[0], target_onehot.shape[0]) - loss = self.criterion(logits[:true_batch_size], target_onehot[:true_batch_size]) - return loss - - def forward(self, 
input, vocab_emb, device): - input_dense, mask = self.encode_dense(input, device) - input_one_hot = self.encode_one_hot(input, device) - logits = torch.matmul(input_dense, vocab_emb.T) - logits[~mask] = -1e9 - logits = logits.max(dim=1)[0] - return logits, input_one_hot + def _multi_hot(self, input_ids: torch.Tensor) -> torch.Tensor: + """Builds a multi-hot label vector per sample.""" + + # input_ids: [B, L] + bsz = input_ids.size(0) + out = torch.zeros(bsz, self.vocab_size, device=input_ids.device, dtype=torch.float) + src = torch.ones_like(input_ids, dtype=torch.float) + out.scatter_add_(1, input_ids, src) + out = (out > 0).float() + # Remove special tokens from labels. + if self.pad_idx is not None: + out[:, self.pad_idx] = 0 + if self.cls_idx is not None: + out[:, self.cls_idx] = 0 + return out + + def get_loss(self, logits: torch.Tensor, target_multi_hot: torch.Tensor) -> torch.Tensor: + true_batch_size = min(logits.shape[0], target_multi_hot.shape[0]) + return self.criterion(logits[:true_batch_size], target_multi_hot[:true_batch_size]) + + def forward(self, input_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute vocabulary logits and target multi-hot labels. + + Args: + input_ids: LongTensor [B, L] in shared vocabulary indices. + + Returns: + logits: FloatTensor [B, V] + target: FloatTensor [B, V] multi-hot labels. + """ + + mask = input_ids != self.pad_idx + x = self.embedding(input_ids) + x, _ = self.encoder(x, mask=mask) + + # Use embedding table as vocabulary embedding. + vocab_emb = self.embedding.weight # [V, D] + logits = torch.matmul(x, vocab_emb.T) # [B, L, V] + logits = logits.masked_fill(~mask.unsqueeze(-1), -1e9) + logits = logits.max(dim=1).values # [B, V] + + target = self._multi_hot(input_ids) + return logits, target class MedLink(BaseModel): - """MedLink model. + """MedLink model (KDD 2023). - Paper: Zhenbang Wu et al. MedLink: De-Identified Patient Health - Record Linkage. KDD 2023. + Paper: Zhenbang Wu et al. 
MedLink: De-Identified Patient Health Record + Linkage. KDD 2023. - IMPORTANT: This implementation differs from the original paper in order to - make it work with the PyHealth framework. Specifically, we do not use the - pre-trained GloVe embeddings. And we only monitor the loss on the validation - set instead of the ranking metrics. As a result, the performance of this model - is different from the original paper. To reproduce the results in the paper, - please use the official GitHub repo: https://github.com/zzachw/MedLink. + IMPORTANT: This implementation differs from the original paper to fit the + PyHealth 2.0 framework. By default, it uses randomly-initialized embeddings. + Optionally, you may initialize the embedding tables using a GloVe-style + text vector file. Args: - dataset: SampleEHRDataset. - feature_keys: List of feature keys. MedLink only supports one feature key. - embedding_dim: Dimension of embedding. - alpha: Weight of the forward prediction loss. - beta: Weight of the backward prediction loss. - gamma: Weight of the retrieval loss. + dataset: SampleDataset. + feature_keys: List of feature keys. MedLink only supports one feature. + embedding_dim: embedding dimension. + alpha: weight for forward prediction loss. + beta: weight for backward prediction loss. + gamma: weight for retrieval loss. + pretrained_emb_path: optional path to a GloVe-style embedding file. + freeze_pretrained: if True, freezes embedding weights after init. 
""" def __init__( self, - dataset: SampleEHRDataset, + dataset: SampleDataset, feature_keys: List[str], embedding_dim: int = 128, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.0, + pretrained_emb_path: Optional[str] = None, + freeze_pretrained: bool = False, **kwargs, ): assert len(feature_keys) == 1, "MedLink only supports one feature key" - super(MedLink, self).__init__( - dataset=dataset, - feature_keys=feature_keys, - label_key=None, - mode=None, - ) + super().__init__(dataset=dataset) + self.feature_key = feature_keys[0] self.embedding_dim = embedding_dim self.alpha = alpha self.beta = beta self.gamma = gamma - q_tokens = self.dataset.get_all_tokens(key=self.feature_key) - d_tokens = self.dataset.get_all_tokens(key="d_" + self.feature_key) - tokenizer = Tokenizer( - tokens=q_tokens + d_tokens, - special_tokens=["", "", ""], - ) - self.fwd_adm_pred = AdmissionPrediction(tokenizer, embedding_dim, **kwargs) - self.bwd_adm_pred = AdmissionPrediction(tokenizer, embedding_dim, **kwargs) - self.criterion = nn.CrossEntropyLoss() - self.vocabs_size = tokenizer.get_vocabulary_size() - return - - def encode_queries(self, queries: List[str]): - all_vocab = torch.tensor(list(range(self.vocabs_size)), device=self.device) - bwd_vocab_emb = self.bwd_adm_pred.embedding(all_vocab) - pred_corpus, queries_one_hot = self.bwd_adm_pred( - queries, bwd_vocab_emb, device=self.device - ) - pred_corpus = torch.log(1 + torch.relu(pred_corpus)) - queries_emb = pred_corpus + queries_one_hot - return queries_emb - def encode_corpus(self, corpus: List[str]): - all_vocab = torch.tensor(list(range(self.vocabs_size)), device=self.device) - fwd_vocab_emb = self.fwd_adm_pred.embedding(all_vocab) - pred_queries, corpus_one_hot = self.fwd_adm_pred( - corpus, fwd_vocab_emb, device=self.device + q_field = self.feature_key + d_field = "d_" + self.feature_key + if q_field not in self.dataset.input_processors or d_field not in self.dataset.input_processors: + raise KeyError( + f"MedLink 
expects both '{q_field}' and '{d_field}' in dataset.input_schema" + ) + + q_processor = self.dataset.input_processors[q_field] + d_processor = self.dataset.input_processors[d_field] + if not isinstance(q_processor, SequenceProcessor) or not isinstance(d_processor, SequenceProcessor): + raise TypeError( + "MedLink currently supports SequenceProcessor for both query and corpus fields" + ) + + self.q_processor = q_processor + self.d_processor = d_processor + + # Shared vocabulary across query/corpus streams. + self.code_vocab = _build_shared_vocab(q_processor, d_processor) + self.vocab_size = len(self.code_vocab) + self.unk_idx = self.code_vocab.get("", 1) + + # Remap tensors from per-field vocab -> shared vocab. + self.q_remap = _build_index_remap(q_processor, self.code_vocab, self.unk_idx) + self.d_remap = _build_index_remap(d_processor, self.code_vocab, self.unk_idx) + + self.fwd_adm_pred = AdmissionPrediction( + code_vocab=self.code_vocab, + embedding_dim=embedding_dim, + pretrained_emb_path=pretrained_emb_path, + freeze_pretrained=freeze_pretrained, + **kwargs, ) - pred_queries = torch.log(1 + torch.relu(pred_queries)) - corpus_emb = corpus_one_hot + pred_queries - return corpus_emb + self.forward_encoder = self.fwd_adm_pred.encoder + + self.bwd_adm_pred = AdmissionPrediction( + code_vocab=self.code_vocab, + embedding_dim=embedding_dim, + pretrained_emb_path=pretrained_emb_path, + freeze_pretrained=freeze_pretrained, + **kwargs, + ) + self.backward_encoder = self.bwd_adm_pred.encoder - def compute_scores(self, queries_emb, corpus_emb): - n = torch.tensor(corpus_emb.shape[0]).to(queries_emb.device) - df = (corpus_emb > 0).sum(dim=0) - idf = torch.log(1 + n) - torch.log(1 + df) + self.criterion = nn.CrossEntropyLoss() - tf = torch.einsum('ac,bc->abc', queries_emb, corpus_emb) + # ------------------------------------------------------------------ + # Encoding helpers + # ------------------------------------------------------------------ + + def 
_prepare_queries(self, queries: Sequence[Any]) -> Tuple[torch.Tensor, torch.Tensor]: + return _pad_and_remap( + queries, + processor=self.q_processor, + remap=self.q_remap, + pad_value=0, + device=self.device, + ) - tf_idf = tf * idf - final_scores = tf_idf.sum(dim=-1) - return final_scores + def _prepare_corpus(self, corpus: Sequence[Any]) -> Tuple[torch.Tensor, torch.Tensor]: + return _pad_and_remap( + corpus, + processor=self.d_processor, + remap=self.d_remap, + pad_value=0, + device=self.device, + ) - def get_loss(self, scores): - label = torch.tensor(list(range(scores.shape[0])), device=scores.device) - loss = self.criterion(scores, label) - return loss + def encode_queries(self, queries: Sequence[Any]) -> torch.Tensor: + q_ids, _ = self._prepare_queries(queries) + pred_corpus, queries_one_hot = self.bwd_adm_pred(q_ids) + pred_corpus = torch.log1p(torch.relu(pred_corpus)) + emb = pred_corpus + queries_one_hot + # Keep special tokens out of retrieval scoring. + emb[:, self.code_vocab.get("", 0)] = 0 + if "" in self.code_vocab: + emb[:, self.code_vocab[""]] = 0 + return emb + + def encode_corpus(self, corpus: Sequence[Any]) -> torch.Tensor: + c_ids, _ = self._prepare_corpus(corpus) + pred_queries, corpus_one_hot = self.fwd_adm_pred(c_ids) + pred_queries = torch.log1p(torch.relu(pred_queries)) + emb = corpus_one_hot + pred_queries + emb[:, self.code_vocab.get("", 0)] = 0 + if "" in self.code_vocab: + emb[:, self.code_vocab[""]] = 0 + return emb + + # ------------------------------------------------------------------ + # Scoring / losses + # ------------------------------------------------------------------ + + def compute_scores(self, queries_emb: torch.Tensor, corpus_emb: torch.Tensor) -> torch.Tensor: + """TF-IDF-like score used by MedLink. 
+ + queries_emb: [Q, V] + corpus_emb: [C, V] + returns: [Q, C] + """ + + n = torch.tensor(float(corpus_emb.shape[0]), device=queries_emb.device) + df = (corpus_emb > 0).sum(dim=0).float() + idf = torch.log1p(n) - torch.log1p(df) + # Equivalent to sum_c q[c] * d[c] * idf[c] + return torch.matmul(queries_emb * idf, corpus_emb.T) + + def get_loss(self, scores: torch.Tensor) -> torch.Tensor: + label = torch.arange(scores.shape[0], device=scores.device) + return self.criterion(scores, label) def forward(self, query_id, id_p, s_q, s_p, s_n=None) -> Dict[str, torch.Tensor]: - corpus = s_p if s_n is None else s_p + s_n + # corpus is positives optionally concatenated with negatives. + corpus = s_p if s_n is None else (s_p + s_n) queries = s_q - all_vocab = torch.tensor(list(range(self.vocabs_size)), device=self.device) - fwd_vocab_emb = self.fwd_adm_pred.embedding(all_vocab) - bwd_vocab_emb = self.bwd_adm_pred.embedding(all_vocab) - pred_queries, corpus_one_hot = self.fwd_adm_pred( - corpus, fwd_vocab_emb, self.device - ) - pred_corpus, queries_one_hot = self.bwd_adm_pred( - queries, bwd_vocab_emb, self.device - ) + + q_ids, _ = self._prepare_queries(queries) + c_ids, _ = self._prepare_corpus(corpus) + + pred_queries, corpus_one_hot = self.fwd_adm_pred(c_ids) + pred_corpus, queries_one_hot = self.bwd_adm_pred(q_ids) fwd_cls_loss = self.fwd_adm_pred.get_loss(pred_queries, queries_one_hot) bwd_cls_loss = self.bwd_adm_pred.get_loss(pred_corpus, corpus_one_hot) - pred_queries = torch.log(1 + torch.relu(pred_queries)) - pred_corpus = torch.log(1 + torch.relu(pred_corpus)) + pred_queries = torch.log1p(torch.relu(pred_queries)) + pred_corpus = torch.log1p(torch.relu(pred_corpus)) corpus_emb = corpus_one_hot + pred_queries queries_emb = pred_corpus + queries_one_hot @@ -187,11 +361,13 @@ def forward(self, query_id, id_p, s_q, s_p, s_n=None) -> Dict[str, torch.Tensor] scores = self.compute_scores(queries_emb, corpus_emb) ret_loss = self.get_loss(scores) - loss = self.alpha * 
fwd_cls_loss + \ - self.beta * bwd_cls_loss + \ - self.gamma * ret_loss + loss = self.alpha * fwd_cls_loss + self.beta * bwd_cls_loss + self.gamma * ret_loss return {"loss": loss} + # ------------------------------------------------------------------ + # Retrieval API + # ------------------------------------------------------------------ + def search(self, queries_ids, queries_embeddings, corpus_ids, corpus_embeddings): scores = self.compute_scores(queries_embeddings, corpus_embeddings) results = {} @@ -206,30 +382,29 @@ def evaluate(self, corpus_dataloader, queries_dataloader): all_corpus_ids, all_corpus_embeddings = [], [] all_queries_ids, all_queries_embeddings = [], [] with torch.no_grad(): - for i, batch in enumerate(tqdm.tqdm(corpus_dataloader)): + for batch in tqdm.tqdm(corpus_dataloader): corpus_ids, corpus = batch["corpus_id"], batch["s"] corpus_embeddings = self.encode_corpus(corpus) all_corpus_ids.extend(corpus_ids) all_corpus_embeddings.append(corpus_embeddings) - for i, batch in enumerate(tqdm.tqdm(queries_dataloader)): + for batch in tqdm.tqdm(queries_dataloader): queries_ids, queries = batch["query_id"], batch["s"] queries_embeddings = self.encode_queries(queries) all_queries_ids.extend(queries_ids) all_queries_embeddings.append(queries_embeddings) - all_corpus_embeddings = torch.cat(all_corpus_embeddings) - all_queries_embeddings = torch.cat(all_queries_embeddings) - results = self.search( + all_corpus_embeddings = torch.cat(all_corpus_embeddings, dim=0) + all_queries_embeddings = torch.cat(all_queries_embeddings, dim=0) + return self.search( all_queries_ids, all_queries_embeddings, all_corpus_ids, - all_corpus_embeddings + all_corpus_embeddings, ) - return results if __name__ == "__main__": + # Minimal smoke-test matching the public example script. 
from pyhealth.datasets import MIMIC3Dataset - from pyhealth.models import MedLink from pyhealth.models.medlink import ( convert_to_ir_format, get_train_dataloader, @@ -246,20 +421,10 @@ def evaluate(self, corpus_dataloader, queries_dataloader): ) sample_dataset = base_dataset.set_task(patient_linkage_mimic3_fn) - corpus, queries, qrels = convert_to_ir_format(sample_dataset.samples) - tr_queries, va_queries, te_queries, tr_qrels, va_qrels, te_qrels = tvt_split( - queries, qrels - ) - train_dataloader = get_train_dataloader( - corpus, tr_queries, tr_qrels, batch_size=32, shuffle=True - ) + corpus, queries, qrels, *_ = convert_to_ir_format(sample_dataset.samples) + tr_queries, _, _, tr_qrels, _, _ = tvt_split(queries, qrels) + train_dataloader = get_train_dataloader(corpus, tr_queries, tr_qrels, batch_size=4) batch = next(iter(train_dataloader)) - model = MedLink( - dataset=sample_dataset, - feature_keys=["conditions"], - embedding_dim=128, - ) - with torch.autograd.detect_anomaly(): - o = model(**batch) - print("loss:", o["loss"]) - o["loss"].backward() + model = MedLink(dataset=sample_dataset, feature_keys=["conditions"], embedding_dim=32) + out = model(**batch) + print("loss:", out["loss"].item()) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index bcfb19f7a..520201c25 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -66,3 +66,4 @@ MutationPathogenicityPrediction, VariantClassificationClinVar, ) +from .patient_linkage_mimic3 import PatientLinkageMIMIC3Task diff --git a/pyhealth/tasks/patient_linkage_mimic3.py b/pyhealth/tasks/patient_linkage_mimic3.py new file mode 100644 index 000000000..3772a0811 --- /dev/null +++ b/pyhealth/tasks/patient_linkage_mimic3.py @@ -0,0 +1,106 @@ +from datetime import datetime +from collections import defaultdict +import math +from pyhealth.tasks import BaseTask + +class PatientLinkageMIMIC3Task(BaseTask): + """ + Patient linkage task for MIMIC-III using the Patient/Visit/Event API. 
+ + Produces the same sample keys as the original patient_linkage_mimic3 task + so pyhealth.models.medlink.convert_to_ir_format works as usual + + Output sample schema: + - patient_id: ground-truth entity id (equivalent to "master patient record id" in MIMIC) + - visit_id: query admission id (hadm_id) + - conditions, age, identifiers: query side + - d_visit_id: doc admission id (hadm_id) + - d_conditions, d_age, d_identifiers: doc side + """ + + task_name = "patient_linkage_mimic3" + input_schema = {"conditions": "sequence", "d_conditions": "sequence"} + output_schema = {} + + def __call__(self, patient): + admissions = patient.get_events(event_type="admissions") + if len(admissions) < 2: + return [] + + admissions = sorted(admissions, key=lambda e: e.timestamp) + q_visit = admissions[-1] + d_visit = admissions[-2] + + patients_events = patient.get_events(event_type="patients") + if not patients_events: + return [] + demo = patients_events[0] + + gender = str(demo.attr_dict.get("gender") or "") + + dob_raw = demo.attr_dict.get("dob") + birth_dt = None + if isinstance(dob_raw, datetime): + birth_dt = dob_raw + elif dob_raw is not None: + try: + birth_dt = datetime.fromisoformat(str(dob_raw)) + except Exception: + birth_dt = None + + def compute_age(ts): + if birth_dt is None or ts is None: + return None + return int((ts - birth_dt).days // 365.25) + + q_age = compute_age(q_visit.timestamp) + d_age = compute_age(d_visit.timestamp) + if q_age is None or d_age is None or q_age < 18 or d_age < 18: + return [] + + diag_events = patient.get_events(event_type="diagnoses_icd") + hadm_to_codes = defaultdict(list) + for ev in diag_events: + hadm = ev.attr_dict.get("hadm_id") + code = ev.attr_dict.get("icd9_code") + if hadm is None or code is None: + continue + hadm_to_codes[str(hadm)].append(str(code)) + + q_hadm = str(q_visit.attr_dict.get("hadm_id")) + d_hadm = str(d_visit.attr_dict.get("hadm_id")) + + q_conditions = hadm_to_codes.get(q_hadm, []) + d_conditions = 
hadm_to_codes.get(d_hadm, []) + if len(q_conditions) == 0 or len(d_conditions) == 0: + return [] + + def clean(x): + if x is None: + return "" + if isinstance(x, float) and math.isnan(x): + return "" + return str(x) + + def build_identifiers(adm_event): + insurance = clean(adm_event.attr_dict.get("insurance")) + language = clean(adm_event.attr_dict.get("language")) + religion = clean(adm_event.attr_dict.get("religion")) + marital_status = clean(adm_event.attr_dict.get("marital_status")) + ethnicity = clean(adm_event.attr_dict.get("ethnicity")) + return "+".join([gender, insurance, language, religion, marital_status, ethnicity]) + + sample = { + "patient_id": patient.patient_id, + + "visit_id": q_hadm, + "conditions": [""] + q_conditions, + "age": q_age, + "identifiers": build_identifiers(q_visit), + + "d_visit_id": d_hadm, + "d_conditions": [""] + d_conditions, + "d_age": d_age, + "d_identifiers": build_identifiers(d_visit), + } + return [sample] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e7991eef9 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,2 @@ +# Tests package + diff --git a/tests/core/test_medlink.py b/tests/core/test_medlink.py new file mode 100644 index 000000000..80c9a01bc --- /dev/null +++ b/tests/core/test_medlink.py @@ -0,0 +1,124 @@ +import unittest +import torch + +from pyhealth.datasets import create_sample_dataset +from pyhealth.models import MedLink + + +class TestMedLink(unittest.TestCase): + """Basic tests for the MedLink model on pseudo data.""" + + def setUp(self): + # Each "sample" here is a simple patient-record placeholder. + # The dataset is used to fit SequenceProcessors (vocabularies), which + # MedLink reuses for processor-native indexing. 
+ self.samples = [ + { + "patient_id": "p0", + "visit_id": "v0", + # query-side codes + "conditions": ["A", "B", "C"], + # corpus-side codes ("d_" + feature_key) + "d_conditions": ["A", "D"], + }, + { + "patient_id": "p1", + "visit_id": "v1", + "conditions": ["B", "E"], + "d_conditions": ["C", "E", "F"], + }, + ] + + # Two sequence-type inputs: conditions and d_conditions + self.input_schema = { + "conditions": "sequence", + "d_conditions": "sequence", + } + # No labels are needed; MedLink is self-supervised + self.output_schema = {} + + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="medlink_test", + in_memory=True, + ) + + self.model = MedLink( + dataset=self.dataset, + feature_keys=["conditions"], + embedding_dim=32, + alpha=0.5, + beta=0.5, + gamma=1.0, + ) + + def _make_batch(self): + # Construct a tiny batch in the format expected by MedLink.forward + # s_q: list of query sequences + s_q = [ + ["A", "B", "C"], + ["B", "E"], + ] + # s_p: list of positive corpus sequences + s_p = [ + ["A", "D"], + ["C", "E", "F"], + ] + # Optionally you could also define negatives s_n = [...] 
+ batch = { + "query_id": ["q0", "q1"], + "id_p": ["p0", "p1"], + "s_q": s_q, + "s_p": s_p, + # no s_n -> defaults to None + } + return batch + + def test_model_initialization(self): + """Model constructs with correct vocabulary size and encoders.""" + self.assertIsInstance(self.model, MedLink) + self.assertEqual(self.model.feature_key, "conditions") + self.assertGreater(self.model.vocab_size, 0) + self.assertIsNotNone(self.model.forward_encoder) + self.assertIsNotNone(self.model.backward_encoder) + + def test_forward_and_backward(self): + """Forward pass returns a scalar loss and backward computes gradients.""" + batch = self._make_batch() + + # Forward + ret = self.model(**batch) + self.assertIn("loss", ret) + loss = ret["loss"] + self.assertTrue(torch.is_tensor(loss)) + self.assertEqual(loss.dim(), 0) # scalar + + # Backward + loss.backward() + has_grad = any( + p.requires_grad and p.grad is not None + for p in self.model.parameters() + ) + self.assertTrue(has_grad, "No gradients after backward pass") + + def test_encoding_helpers(self): + """encode_queries / encode_corpus produce consistent shapes.""" + queries = [["A", "B"], ["C"]] + corpus = [["A"], ["B", "C"]] + + q_emb = self.model.encode_queries(queries) + c_emb = self.model.encode_corpus(corpus) + + self.assertEqual(q_emb.shape[1], self.model.vocab_size) + self.assertEqual(c_emb.shape[1], self.model.vocab_size) + self.assertEqual(q_emb.shape[0], len(queries)) + self.assertEqual(c_emb.shape[0], len(corpus)) + + scores = self.model.compute_scores(q_emb, c_emb) + self.assertEqual(scores.shape, (len(queries), len(corpus))) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/test_sdoh.py b/tests/core/test_sdoh.py index 0e8096268..f2eaf98a7 100644 --- a/tests/core/test_sdoh.py +++ b/tests/core/test_sdoh.py @@ -1,5 +1,5 @@ from typing import Set -from base import BaseTestCase +from tests.base import BaseTestCase from pyhealth.models.sdoh import SdohClassifier diff --git 
a/tests/nlp/test_metrics.py b/tests/nlp/test_metrics.py index 0536153e8..a0fd76a68 100644 --- a/tests/nlp/test_metrics.py +++ b/tests/nlp/test_metrics.py @@ -1,6 +1,6 @@ from typing import List import logging -from base import BaseTestCase +from tests.base import BaseTestCase from pathlib import Path import pandas as pd from pyhealth.nlp.metrics import ( diff --git a/tests/todo/test_datasets/test_eicu.py b/tests/todo/test_datasets/test_eicu.py index fdb466273..1bd0ce470 100644 --- a/tests/todo/test_datasets/test_eicu.py +++ b/tests/todo/test_datasets/test_eicu.py @@ -5,7 +5,7 @@ import pandas from pyhealth.datasets import eICUDataset -from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion +from tests.todo.test_datasets.utils import EHRDatasetStatAssertion class TesteICUDataset(unittest.TestCase): diff --git a/tests/todo/test_datasets/test_mimic3.py b/tests/todo/test_datasets/test_mimic3.py index 2957add0d..fe0fee6ae 100644 --- a/tests/todo/test_datasets/test_mimic3.py +++ b/tests/todo/test_datasets/test_mimic3.py @@ -2,7 +2,7 @@ import unittest from pyhealth.datasets import MIMIC3Dataset -from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion +from tests.todo.test_datasets.utils import EHRDatasetStatAssertion import os, sys current = os.path.dirname(os.path.realpath(__file__)) @@ -30,8 +30,6 @@ class TestsMimic3Dataset(unittest.TestCase): dataset_name=DATASET_NAME, root=ROOT, tables=TABLES, - code_mapping=CODE_MAPPING, - refresh_cache=REFRESH_CACHE, ) def setUp(self): diff --git a/tests/todo/test_datasets/test_mimic4.py b/tests/todo/test_datasets/test_mimic4.py index 0133cbb93..bf21b64ef 100644 --- a/tests/todo/test_datasets/test_mimic4.py +++ b/tests/todo/test_datasets/test_mimic4.py @@ -2,7 +2,7 @@ import unittest from pyhealth.datasets import MIMIC4Dataset -from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion +from tests.todo.test_datasets.utils import EHRDatasetStatAssertion import os, sys @@ -25,17 
+25,14 @@ class TestMimic4Dataset(unittest.TestCase): DEV = True # not needed when using demo set since its 100 patients large REFRESH_CACHE = True - dataset = MIMIC4Dataset( - dataset_name=DATASET_NAME, - root=ROOT, - tables=TABLES, - code_mapping=CODE_MAPPING, - dev=DEV, - refresh_cache=REFRESH_CACHE, - ) - def setUp(self): - pass + # Initialize dataset in setUp to avoid loading during test collection + self.dataset = MIMIC4Dataset( + dataset_name=self.DATASET_NAME, + ehr_root=self.ROOT, + ehr_tables=self.TABLES, + dev=self.DEV, + ) # test the dataset integrity based on a single sample. def test_patient(self): diff --git a/tests/todo/test_datasets/test_omop.py b/tests/todo/test_datasets/test_omop.py index f57420659..765c88920 100644 --- a/tests/todo/test_datasets/test_omop.py +++ b/tests/todo/test_datasets/test_omop.py @@ -5,7 +5,7 @@ import collections from pyhealth.datasets import OMOPDataset -from pyhealth.unittests.test_datasets.utils import EHRDatasetStatAssertion +from tests.todo.test_datasets.utils import EHRDatasetStatAssertion class TestOMOPDataset(unittest.TestCase): @@ -25,9 +25,7 @@ class TestOMOPDataset(unittest.TestCase): dataset_name=DATASET_NAME, root=ROOT, tables=TABLES, - code_mapping=CODE_MAPPING, dev=DEV, - refresh_cache=REFRESH_CACHE, ) def setUp(self): diff --git a/tests/todo/test_mortality_prediction.py b/tests/todo/test_mortality_prediction.py index 729640abb..749a944cb 100644 --- a/tests/todo/test_mortality_prediction.py +++ b/tests/todo/test_mortality_prediction.py @@ -53,7 +53,7 @@ def test_mortality_prediction_mimic4(): # Enable dev mode to limit memory usage dataset = MIMIC4Dataset( ehr_root=mimic_iv_root, - notes_root=mimic_note_root, + note_root=mimic_note_root, ehr_tables=[ "patients", # Demographics "admissions", # Admission/discharge info @@ -152,7 +152,7 @@ def test_multimodal_mortality_prediction_mimic4(): # Initialize dataset with comprehensive tables dataset = MIMIC4Dataset( ehr_root=mimic_iv_root, - 
notes_root=mimic_note_root, + note_root=mimic_note_root, cxr_root=mimic_cxr_root, ehr_tables=[ "patients", # Demographics @@ -274,7 +274,7 @@ def test_multimodal_mortality_prediction_with_images(): # Initialize the dataset with all required tables dataset = MIMIC4Dataset( ehr_root=mimic_iv_root, - notes_root=mimic_note_root, + note_root=mimic_note_root, cxr_root=mimic_cxr_root, ehr_tables=[ "patients",