1 change: 1 addition & 0 deletions docs/api/tasks.rst
@@ -79,6 +79,7 @@ Available Tasks
MIMIC-III ICD-9 Coding <tasks/pyhealth.tasks.MIMIC3ICD9Coding>
Cardiology Detection <tasks/pyhealth.tasks.cardiology_detect>
COVID-19 CXR Classification <tasks/pyhealth.tasks.COVID19CXRClassification>
DKA Prediction (MIMIC-IV) <tasks/pyhealth.tasks.dka>
Drug Recommendation <tasks/pyhealth.tasks.drug_recommendation>
EEG Abnormal <tasks/pyhealth.tasks.EEG_abnormal>
EEG Events <tasks/pyhealth.tasks.EEG_events>
8 changes: 8 additions & 0 deletions docs/api/tasks/pyhealth.tasks.dka.rst
@@ -0,0 +1,8 @@
pyhealth.tasks.dka
==================

.. autoclass:: pyhealth.tasks.dka.DKAPredictionMIMIC4
:members:
:undoc-members:
:show-inheritance:
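
For context, a minimal usage sketch of the class documented above, following the pattern of examples/clinical_tasks/dka_mimic4.py later in this PR (the ehr_root path is a placeholder, and the padding and num_workers values simply mirror the example script):

from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import DKAPredictionMIMIC4

# Placeholder path; point this at a local MIMIC-IV 2.2 installation.
base_dataset = MIMIC4Dataset(
    ehr_root="/path/to/mimiciv/2.2/",
    ehr_tables=["admissions", "diagnoses_icd", "procedures_icd", "labevents"],
)

# padding=10 mirrors the example script below; num_workers is arbitrary here.
sample_dataset = base_dataset.set_task(DKAPredictionMIMIC4(padding=10), num_workers=4)
print(f"Total samples: {len(sample_dataset)}")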

188 changes: 188 additions & 0 deletions examples/benchmark_perf/benchmark_workers_12.py
@@ -0,0 +1,188 @@
"""Benchmark script for MIMIC-IV mortality prediction with num_workers=4.

This benchmark measures:
1. Time to load base dataset
2. Time to process task with num_workers=4
Comment on lines +1 to +5
Copilot AI · Dec 24, 2025

The docstring says 'num_workers=4' (lines 1 and 5), but the actual configuration and execution use num_workers=12 (line 142). The docstring should be updated to match the implementation and avoid confusion.

Suggested change:

"""Benchmark script for MIMIC-IV mortality prediction with num_workers=12.

This benchmark measures:
1. Time to load base dataset
2. Time to process task with num_workers=12
3. Total processing time
4. Cache sizes
5. Peak memory usage (with optional memory limit)
"""

import time
import os
import threading
from pathlib import Path
import psutil
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4

try:
import resource

HAS_RESOURCE = True
except ImportError:
HAS_RESOURCE = False


PEAK_MEM_USAGE = 0
SELF_PROC = psutil.Process(os.getpid())


def track_mem():
"""Background thread to track peak memory usage."""
global PEAK_MEM_USAGE
while True:
m = SELF_PROC.memory_info().rss
if m > PEAK_MEM_USAGE:
PEAK_MEM_USAGE = m
time.sleep(0.1)


def set_memory_limit(max_memory_gb):
"""Set hard memory limit for the process.

Args:
max_memory_gb: Maximum memory in GB (e.g., 8 for 8GB)

Note:
If limit is exceeded, the process will raise MemoryError.
Only works on Unix-like systems (Linux, macOS).
"""
if not HAS_RESOURCE:
print(
"Warning: resource module not available (Windows?). "
"Memory limit not enforced."
)
return

max_memory_bytes = int(max_memory_gb * 1024**3)
try:
resource.setrlimit(resource.RLIMIT_AS, (max_memory_bytes, max_memory_bytes))
print(f"✓ Memory limit set to {max_memory_gb} GB")
except Exception as e:
print(f"Warning: Failed to set memory limit: {e}")


def get_directory_size(path):
"""Calculate total size of a directory in bytes."""
total = 0
try:
for entry in Path(path).rglob("*"):
if entry.is_file():
total += entry.stat().st_size
except Exception as e:
print(f"Error calculating size for {path}: {e}")
return total


def format_size(size_bytes):
"""Format bytes to human-readable size."""
for unit in ["B", "KB", "MB", "GB", "TB"]:
if size_bytes < 1024.0:
return f"{size_bytes:.2f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.2f} PB"


def main():
"""Main benchmark function."""
# Configuration
dev = False # Set to True for development/testing
enable_memory_limit = False # Set to True to enforce memory limit
max_memory_gb = 32 # Memory limit in GB (if enable_memory_limit=True)
Comment on lines +91 to +92
Copilot AI · Dec 24, 2025

Variable max_memory_gb is effectively unused, since enable_memory_limit is hardcoded to False. Consider reading both settings from the environment.

Suggested change:

# Enable memory limit via environment variable: set ENABLE_MEMORY_LIMIT=1 to enable
enable_memory_limit = os.getenv("ENABLE_MEMORY_LIMIT", "0") == "1"
# Configure memory limit in GB via MAX_MEMORY_GB (defaults to 32 if not set)
max_memory_gb = int(os.getenv("MAX_MEMORY_GB", "32"))

# Apply memory limit if enabled
if enable_memory_limit:
set_memory_limit(max_memory_gb)
Copilot AI · Dec 24, 2025

This statement is unreachable, since enable_memory_limit is hardcoded to False.

# Start memory tracking thread
mem_thread = threading.Thread(target=track_mem, daemon=True)
mem_thread.start()

print("=" * 80)
print(f"BENCHMARK: num_workers=4, dev={dev}")
if enable_memory_limit:
print(f"Memory Limit: {max_memory_gb} GB (ENFORCED)")
Copilot AI · Dec 24, 2025

This statement is unreachable, since enable_memory_limit is hardcoded to False.
else:
print("Memory Limit: None (unrestricted)")
print("=" * 80)

# Define cache directories based on dev mode
cache_root = "/shared/rsaas/pyhealth/"
if dev:
cache_root += "_dev"
Copilot AI · Dec 24, 2025

This statement is unreachable, since dev is hardcoded to False.

# Track total time
total_start = time.time()

# STEP 1: Load MIMIC-IV base dataset
print("\n[1/3] Loading MIMIC-IV base dataset...")
dataset_start = time.time()

base_dataset = MIMIC4Dataset(
ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
ehr_tables=[
"patients",
"admissions",
"diagnoses_icd",
"procedures_icd",
"labevents",
],
dev=dev,
cache_dir=f"{cache_root}/base_dataset",
)

dataset_time = time.time() - dataset_start
print(f"✓ Dataset loaded in {dataset_time:.2f} seconds")

# STEP 2: Apply StageNet mortality prediction task with num_workers=12
Copilot AI · Dec 24, 2025

The comment and the print statement below (line 137) consistently say num_workers=12, but the summary output (line 171) prints 'num_workers: 4'. All references should use the same value.
print("\n[2/2] Applying mortality prediction task (num_workers=12)...")
task_start = time.time()

sample_dataset = base_dataset.set_task(
MortalityPredictionStageNetMIMIC4(),
num_workers=12,
cache_dir=f"{cache_root}/task_samples",
)

task_time = time.time() - task_start
print(f"✓ Task processing completed in {task_time:.2f} seconds")

# Measure cache sizes
print("\n[3/3] Measuring cache sizes...")
base_cache_dir = f"{cache_root}/base_dataset"
task_cache_dir = f"{cache_root}/task_samples"

base_cache_size = get_directory_size(base_cache_dir)
task_cache_size = get_directory_size(task_cache_dir)
total_cache_size = base_cache_size + task_cache_size

print(f"✓ Base dataset cache: {format_size(base_cache_size)}")
print(f"✓ Task samples cache: {format_size(task_cache_size)}")
print(f"✓ Total cache size: {format_size(total_cache_size)}")

# Total time and peak memory
total_time = time.time() - total_start
peak_mem = PEAK_MEM_USAGE

# Print summary
print("\n" + "=" * 80)
print("BENCHMARK RESULTS")
print("=" * 80)
print("Configuration:")
print(" - num_workers: 4")
Copilot AI · Dec 24, 2025

The printed configuration shows 'num_workers: 4' but the actual value used is 12 (line 142). This inconsistency between the printed summary and actual execution could mislead users analyzing benchmark results.
print(f" - dev mode: {dev}")
print(f" - Total samples: {len(sample_dataset)}")
print("\nTiming:")
print(f" - Dataset loading: {dataset_time:.2f}s")
print(f" - Task processing: {task_time:.2f}s")
print(f" - Total time: {total_time:.2f}s")
print("\nCache Sizes:")
print(f" - Base dataset cache: {format_size(base_cache_size)}")
print(f" - Task samples cache: {format_size(task_cache_size)}")
print(f" - Total cache: {format_size(total_cache_size)}")
print("\nMemory:")
print(f" - Peak memory usage: {format_size(peak_mem)}")
print("=" * 80)


if __name__ == "__main__":
main()
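
Taken together, the review comments above all point at configuration drift. A small sketch (hypothetical, consolidating Copilot's env-var suggestions rather than code from this PR) of a single-sourced header that keeps the docstring, banner, and summary in agreement:

import os

# Hypothetical consolidation of the review suggestions: one source of truth
# for each setting, read from the environment with the PR's values as defaults.
NUM_WORKERS = int(os.getenv("NUM_WORKERS", "12"))
ENABLE_MEMORY_LIMIT = os.getenv("ENABLE_MEMORY_LIMIT", "0") == "1"
MAX_MEMORY_GB = int(os.getenv("MAX_MEMORY_GB", "32"))

if __name__ == "__main__":
    # Every printout derives from the same constants, so they cannot disagree.
    print(f"BENCHMARK: num_workers={NUM_WORKERS}")
    if ENABLE_MEMORY_LIMIT:
        print(f"Memory Limit: {MAX_MEMORY_GB} GB (ENFORCED)")
    else:
        print("Memory Limit: None (unrestricted)")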
186 changes: 186 additions & 0 deletions examples/clinical_tasks/dka_mimic4.py
@@ -0,0 +1,186 @@
"""
Example of using StageNet for DKA (Diabetic Ketoacidosis) prediction on MIMIC-IV.

This example demonstrates:
1. Loading MIMIC-IV data with relevant tables for DKA prediction
2. Applying the DKAPredictionMIMIC4 task (general population)
3. Creating a SampleDataset with StageNet processors
4. Training a StageNet model for DKA prediction

Target Population:
- ALL patients in MIMIC-IV (no diabetes filtering)
- Much larger patient pool with more negative samples
- Label: 1 if patient has ANY DKA diagnosis, 0 otherwise

Note: For T1DM-specific DKA prediction, see t1dka_mimic4.py
"""

import os
import torch

from pyhealth.datasets import (
MIMIC4Dataset,
get_dataloader,
split_by_patient,
)
from pyhealth.datasets.utils import save_processors, load_processors
from pyhealth.models import StageNet
from pyhealth.tasks import DKAPredictionMIMIC4
from pyhealth.trainer import Trainer


def main():
"""Main function to run DKA prediction pipeline on general population."""

# Configuration
MIMIC4_ROOT = "/srv/local/data/physionet.org/files/mimiciv/2.2/"
DATASET_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dataset"
TASK_CACHE_DIR = "/shared/rsaas/pyhealth/cache/mimic4_dka_general_stagenet"
PROCESSOR_DIR = "/shared/rsaas/pyhealth/processors/stagenet_dka_general_mimic4"
DEVICE = "cuda:5" if torch.cuda.is_available() else "cpu"

print("=" * 60)
print("DKA PREDICTION (GENERAL POPULATION) WITH STAGENET ON MIMIC-IV")
print("=" * 60)

# STEP 1: Load MIMIC-IV base dataset
print("\n=== Step 1: Loading MIMIC-IV Dataset ===")
base_dataset = MIMIC4Dataset(
ehr_root=MIMIC4_ROOT,
ehr_tables=[
"admissions",
"diagnoses_icd",
"procedures_icd",
"labevents",
],
cache_dir=DATASET_CACHE_DIR,
# dev=True, # Uncomment for faster development iteration
)

print("Dataset initialized, proceeding to task processing...")

# STEP 2: Apply DKA prediction task (general population)
print("\n=== Step 2: Applying DKA Prediction Task (General Population) ===")

# Create task with padding for unseen sequences
# No T1DM filtering - includes ALL patients
dka_task = DKAPredictionMIMIC4(padding=10)

print(f"Task: {dka_task.task_name}")
print(f"Input schema: {list(dka_task.input_schema.keys())}")
print(f"Output schema: {list(dka_task.output_schema.keys())}")
print("Note: This includes ALL patients (not just diabetics)")

# Check for pre-fitted processors
if os.path.exists(os.path.join(PROCESSOR_DIR, "input_processors.pkl")):
print("\nLoading pre-fitted processors...")
input_processors, output_processors = load_processors(PROCESSOR_DIR)

sample_dataset = base_dataset.set_task(
dka_task,
num_workers=4,
cache_dir=TASK_CACHE_DIR,
input_processors=input_processors,
output_processors=output_processors,
)
else:
print("\nFitting new processors...")
sample_dataset = base_dataset.set_task(
dka_task,
num_workers=4,
cache_dir=TASK_CACHE_DIR,
)

# Save processors for future runs
print("Saving processors...")
os.makedirs(PROCESSOR_DIR, exist_ok=True)
save_processors(sample_dataset, PROCESSOR_DIR)

print(f"\nTotal samples: {len(sample_dataset)}")

# Count label distribution
label_counts = {0: 0, 1: 0}
for sample in sample_dataset:
label_counts[int(sample["label"].item())] += 1

print(f"Label distribution:")
print(f" No DKA (0): {label_counts[0]} ({100*label_counts[0]/len(sample_dataset):.1f}%)")
print(f" Has DKA (1): {label_counts[1]} ({100*label_counts[1]/len(sample_dataset):.1f}%)")

# Inspect a sample
sample = sample_dataset[0]
print("\nSample structure:")
print(f" Patient ID: {sample['patient_id']}")
print(f" ICD codes (diagnoses + procedures): {sample['icd_codes'][1].shape} (visits x codes)")
print(f" Labs: {sample['labs'][0].shape} (timesteps x features)")
print(f" Label: {sample['label']}")

# STEP 3: Split dataset
print("\n=== Step 3: Splitting Dataset ===")
train_dataset, val_dataset, test_dataset = split_by_patient(
sample_dataset, [0.8, 0.1, 0.1]
)

print(f"Train: {len(train_dataset)} samples")
print(f"Validation: {len(val_dataset)} samples")
print(f"Test: {len(test_dataset)} samples")

# Create dataloaders
train_loader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
val_loader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
test_loader = get_dataloader(test_dataset, batch_size=256, shuffle=False)

# STEP 4: Initialize StageNet model
print("\n=== Step 4: Initializing StageNet Model ===")
model = StageNet(
dataset=sample_dataset,
embedding_dim=128,
chunk_size=128,
levels=3,
dropout=0.3,
)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

# STEP 5: Train the model
print("\n=== Step 5: Training Model ===")
trainer = Trainer(
model=model,
device=DEVICE,
metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
)

trainer.train(
train_dataloader=train_loader,
val_dataloader=val_loader,
epochs=50,
monitor="roc_auc",
optimizer_params={"lr": 1e-5},
)

# STEP 6: Evaluate on test set
print("\n=== Step 6: Evaluation ===")
results = trainer.evaluate(test_loader)
print("\nTest Results:")
for metric, value in results.items():
print(f" {metric}: {value:.4f}")

# STEP 7: Inspect model predictions
print("\n=== Step 7: Sample Predictions ===")
sample_batch = next(iter(test_loader))
with torch.no_grad():
output = model(**sample_batch)

print(f"Predicted probabilities: {output['y_prob'][:5]}")
print(f"True labels: {output['y_true'][:5]}")

print("\n" + "=" * 60)
print("DKA PREDICTION (GENERAL POPULATION) TRAINING COMPLETED!")
print("=" * 60)

return results


if __name__ == "__main__":
main()