diff --git a/README.rst b/README.rst index ad8d1b257..683302ea1 100644 --- a/README.rst +++ b/README.rst @@ -13,7 +13,7 @@ Welcome to PyHealth! .. image:: https://readthedocs.org/projects/pyhealth/badge/?version=latest :target: https://pyhealth.readthedocs.io/en/latest/ :alt: Documentation status - + .. image:: https://img.shields.io/github/stars/sunlabuiuc/pyhealth.svg :target: https://github.com/sunlabuiuc/pyhealth/stargazers @@ -121,7 +121,7 @@ PyHealth is a comprehensive deep learning toolkit for supporting clinical predic You can use the following functions independently: - **Dataset**: ``MIMIC-III``, ``MIMIC-IV``, ``eICU``, ``OMOP-CDM``, ``customized EHR datasets``, etc. -- **Tasks**: ``diagnosis-based drug recommendation``, ``patient hospitalization and mortality prediction``, ``length stay forecasting``, etc. +- **Tasks**: ``diagnosis-based drug recommendation``, ``patient hospitalization and mortality prediction``, ``length stay forecasting``, etc. - **ML models**: ``CNN``, ``LSTM``, ``GRU``, ``LSTM``, ``RETAIN``, ``SafeDrug``, ``Deepr``, etc. *Building a healthcare AI pipeline can be as short as 10 lines of code in PyHealth*. @@ -130,7 +130,7 @@ You can use the following functions independently: 3. Build ML Pipelines :trophy: --------------------------------- -All healthcare tasks in our package follow a **five-stage pipeline**: +All healthcare tasks in our package follow a **five-stage pipeline**: .. image:: figure/five-stage-pipeline.png :width: 640 @@ -150,7 +150,7 @@ Module 1: mimic3base = MIMIC3Dataset( # root directory of the dataset - root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/", + root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/", # raw CSV table name tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], # map all NDC codes to CCS codes in these tables @@ -169,9 +169,9 @@ Module 2: .. code-block:: python - from pyhealth.tasks import readmission_prediction_mimic3_fn + from pyhealth.tasks import ReadmissionPredictionMIMIC3 - mimic3sample = mimic3base.set_task(task_fn=readmission_prediction_mimic3_fn) # use default task + mimic3sample = mimic3base.set_task(ReadmissionPredictionMIMIC3()) mimic3sample.samples[0] # show the information of the first sample """ { @@ -180,7 +180,7 @@ Module 2: 'conditions': ['5990', '4280', '2851', '4240', '2749', '9982', 'E8499', '42831', '34600'], 'procedures': ['0040', '3931', '7769'], 'drugs': ['N06DA02', 'V06DC01', 'B01AB01', 'A06AA02', 'R03AC02', 'H03AA01', 'J01FA09'], - 'label': 0 + 'readmission': 0 } """ @@ -213,7 +213,7 @@ Module 4: ``pyhealth.trainer`` can specify training arguments, such as epochs, optimizer, learning rate, etc. The trainer will automatically save the best model and output the path in the end. .. code-block:: python - + from pyhealth.trainer import Trainer trainer = Trainer(model=model) @@ -233,19 +233,19 @@ Module 5: # method 1 trainer.evaluate(test_loader) - + # method 2 from pyhealth.metrics.binary import binary_metrics_fn y_true, y_prob, loss = trainer.inference(test_loader) binary_metrics_fn(y_true, y_prob, metrics=["pr_auc", "roc_auc"]) -4. Medical Code Map :hospital: +4. Medical Code Map :hospital: --------------------------------- ``pyhealth.codemap`` provides two core functionalities. **This module can be used independently.** -* For code ontology lookup within one medical coding system (e.g., name, category, sub-concept); +* For code ontology lookup within one medical coding system (e.g., name, category, sub-concept); .. code-block:: python @@ -256,7 +256,7 @@ Module 5: # `Congestive heart failure, unspecified` icd9cm.get_ancestors("428.0") # ['428', '420-429.99', '390-459.99', '001-999.99'] - + atc = InnerMap.load("ATC") atc.lookup("M01AE51") # `ibuprofen, combinations` @@ -267,7 +267,7 @@ Module 5: atc.lookup("M01AE51", "indication") # Ibuprofen is the most commonly used and prescribed NSAID. It is very common over the ... -* For code mapping between two coding systems (e.g., ICD9CM to CCSCM). +* For code mapping between two coding systems (e.g., ICD9CM to CCSCM). .. code-block:: python @@ -300,12 +300,12 @@ Module 5: 'A12B', 'A12C', 'A13A', 'A14A', 'A14B', 'A16A'] tokenizer = Tokenizer(tokens=token_space, special_tokens=["", ""]) - # 2d encode + # 2d encode tokens = [['A03C', 'A03D', 'A03E', 'A03F'], ['A04A', 'B035', 'C129']] - indices = tokenizer.batch_encode_2d(tokens) + indices = tokenizer.batch_encode_2d(tokens) # [[8, 9, 10, 11], [12, 1, 1, 0]] - # 2d decode + # 2d decode indices = [[8, 9, 10, 11], [12, 1, 1, 0]] tokens = tokenizer.batch_decode_2d(indices) # [['A03C', 'A03D', 'A03E', 'A03F'], ['A04A', '', '']] @@ -331,69 +331,69 @@ Module 5: .. - We provide the following tutorials to help users get started with our pyhealth. + We provide the following tutorials to help users get started with our pyhealth. -`Tutorial 0: Introduction to pyhealth.data `_ `[Video] `__ +`Tutorial 0: Introduction to pyhealth.data `_ `[Video] `__ -`Tutorial 1: Introduction to pyhealth.datasets `_ `[Video] `__ +`Tutorial 1: Introduction to pyhealth.datasets `_ `[Video] `__ -`Tutorial 2: Introduction to pyhealth.tasks `_ `[Video] `__ +`Tutorial 2: Introduction to pyhealth.tasks `_ `[Video] `__ -`Tutorial 3: Introduction to pyhealth.models `_ `[Video] `__ +`Tutorial 3: Introduction to pyhealth.models `_ `[Video] `__ -`Tutorial 4: Introduction to pyhealth.trainer `_ `[Video] `__ +`Tutorial 4: Introduction to pyhealth.trainer `_ `[Video] `__ -`Tutorial 5: Introduction to pyhealth.metrics `_ `[Video] `__ +`Tutorial 5: Introduction to pyhealth.metrics `_ `[Video] `__ -`Tutorial 6: Introduction to pyhealth.tokenizer `_ `[Video] `__ +`Tutorial 6: Introduction to pyhealth.tokenizer `_ `[Video] `__ -`Tutorial 7: Introduction to pyhealth.medcode `_ `[Video] `__ +`Tutorial 7: Introduction to pyhealth.medcode `_ `[Video] `__ The following tutorials will help users build their own task pipelines. `Pipeline 1: Drug Recommendation `_ `[Video] `__ +www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`__ `Pipeline 2: Length of Stay Prediction `_ `[Video] `__ +www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`__ `Pipeline 3: Readmission Prediction `_ `[Video] `__ +www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`__ `Pipeline 4: Mortality Prediction `_ `[Video] `__ +www.youtube.com/watch?v=GGP3Dhfyisc&list=PLR3CNIF8DDHJUl8RLhyOVpX_kT4bxulEV&index=12>`__ -`Pipeline 5: Sleep Staging `_ `[Video] `__ +`Pipeline 5: Sleep Staging `_ `[Video] `__ - We provided the advanced tutorials for supporting various needs. + We provided the advanced tutorials for supporting various needs. -`Advanced Tutorial 1: Fit your dataset into our pipeline `_ `[Video] `__ +`Advanced Tutorial 1: Fit your dataset into our pipeline `_ `[Video] `__ -`Advanced Tutorial 2: Define your own healthcare task `_ +`Advanced Tutorial 2: Define your own healthcare task `_ -`Advanced Tutorial 3: Adopt customized model into pyhealth `_ `[Video] `__ +`Advanced Tutorial 3: Adopt customized model into pyhealth `_ `[Video] `__ -`Advanced Tutorial 4: Load your own processed data into pyhealth and try out our ML models `_ `[Video] `__ +`Advanced Tutorial 4: Load your own processed data into pyhealth and try out our ML models `_ `[Video] `__ 7. Datasets :mountain_snow: ----------------------------- We provide the processing files for the following open EHR datasets: -=================== ======================================= ======================================== ======================================================================================================== -Dataset Module Year Information =================== ======================================= ======================================== ======================================================================================================== -MIMIC-III ``pyhealth.datasets.MIMIC3Dataset`` 2016 `MIMIC-III Clinical Database `_ -MIMIC-IV ``pyhealth.datasets.MIMIC4Dataset`` 2020 `MIMIC-IV Clinical Database `_ -eICU ``pyhealth.datasets.eICUDataset`` 2018 `eICU Collaborative Research Database `_ -OMOP ``pyhealth.datasets.OMOPDataset`` `OMOP-CDM schema based dataset `_ +Dataset Module Year Information +=================== ======================================= ======================================== ======================================================================================================== +MIMIC-III ``pyhealth.datasets.MIMIC3Dataset`` 2016 `MIMIC-III Clinical Database `_ +MIMIC-IV ``pyhealth.datasets.MIMIC4Dataset`` 2020 `MIMIC-IV Clinical Database `_ +eICU ``pyhealth.datasets.eICUDataset`` 2018 `eICU Collaborative Research Database `_ +OMOP ``pyhealth.datasets.OMOPDataset`` `OMOP-CDM schema based dataset `_ SleepEDF ``pyhealth.datasets.SleepEDFDataset`` 2018 `Sleep-EDF dataset `_ -SHHS ``pyhealth.datasets.SHHSDataset`` 2016 `Sleep Heart Health Study dataset `_ -ISRUC ``pyhealth.datasets.ISRUCDataset`` 2016 `ISRUC-SLEEP dataset `_ +SHHS ``pyhealth.datasets.SHHSDataset`` 2016 `Sleep Heart Health Study dataset `_ +ISRUC ``pyhealth.datasets.ISRUCDataset`` 2016 `ISRUC-SLEEP dataset `_ =================== ======================================= ======================================== ======================================================================================================== diff --git a/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst b/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst index f9c56f3aa..cfa5ea6e6 100644 --- a/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst +++ b/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst @@ -1,7 +1,7 @@ pyhealth.tasks.readmission_prediction ======================================= -.. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_mimic3_fn +.. autofunction:: pyhealth.tasks.readmission_prediction.ReadmissionPredictionMIMIC3 .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_mimic4_fn .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn .. autofunction:: pyhealth.tasks.readmission_prediction.readmission_prediction_eicu_fn2 diff --git a/examples/readmission_mimic3_fairness.py b/examples/readmission_mimic3_fairness.py index d53b67c9e..63cdcce26 100644 --- a/examples/readmission_mimic3_fairness.py +++ b/examples/readmission_mimic3_fairness.py @@ -1,5 +1,5 @@ from pyhealth.datasets import MIMIC3Dataset -from pyhealth.tasks import readmission_prediction_mimic3_fn +from pyhealth.tasks import ReadmissionPredictionMIMIC3 from pyhealth.datasets import split_by_patient, get_dataloader from pyhealth.metrics import fairness_metrics_fn from pyhealth.models import Transformer @@ -11,11 +11,10 @@ root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/", tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], ) -base_dataset.stat() +base_dataset.stats() # STEP 2: set task -sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn) -sample_dataset.stat() +sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3(exclude_minors=False)) # Must include minors to get any readmission samples on the synthetic dataset train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) diff --git a/examples/readmission_mimic3_rnn.py b/examples/readmission_mimic3_rnn.py index 9870fcd64..18eb0df1c 100644 --- a/examples/readmission_mimic3_rnn.py +++ b/examples/readmission_mimic3_rnn.py @@ -1,22 +1,18 @@ from pyhealth.datasets import MIMIC3Dataset from pyhealth.datasets import split_by_patient, get_dataloader from pyhealth.models import RNN -from pyhealth.tasks import readmission_prediction_mimic3_fn +from pyhealth.tasks import ReadmissionPredictionMIMIC3 from pyhealth.trainer import Trainer # STEP 1: load data base_dataset = MIMIC3Dataset( - root="/srv/local/data/physionet.org/files/mimiciii/1.4", + root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III", tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], - code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, - dev=False, - refresh_cache=True, ) -base_dataset.stat() +base_dataset.stats() # STEP 2: set task -sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn) -sample_dataset.stat() +sample_dataset = base_dataset.set_task(ReadmissionPredictionMIMIC3(exclude_minors=False)) # Must include minors to get any readmission samples on the synthetic dataset train_dataset, val_dataset, test_dataset = split_by_patient( sample_dataset, [0.8, 0.1, 0.1] @@ -28,9 +24,6 @@ # STEP 3: define model model = RNN( dataset=sample_dataset, - feature_keys=["conditions", "procedures", "drugs"], - label_key="label", - mode="binary", ) # STEP 4: define trainer @@ -38,7 +31,7 @@ trainer.train( train_dataloader=train_dataloader, val_dataloader=val_dataloader, - epochs=50, + epochs=1, monitor="roc_auc", ) diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py index bcfb19f7a..76632467b 100644 --- a/pyhealth/tasks/__init__.py +++ b/pyhealth/tasks/__init__.py @@ -49,9 +49,9 @@ from .patient_linkage import patient_linkage_mimic3_fn from .readmission_30days_mimic4 import Readmission30DaysMIMIC4 from .readmission_prediction import ( + ReadmissionPredictionMIMIC3, readmission_prediction_eicu_fn, readmission_prediction_eicu_fn2, - readmission_prediction_mimic3_fn, readmission_prediction_mimic4_fn, readmission_prediction_omop_fn, ) diff --git a/pyhealth/tasks/readmission_prediction.py b/pyhealth/tasks/readmission_prediction.py index 9d04b5eb7..cd8b94b38 100644 --- a/pyhealth/tasks/readmission_prediction.py +++ b/pyhealth/tasks/readmission_prediction.py @@ -1,67 +1,111 @@ -from pyhealth.data import Patient, Visit +from datetime import datetime, timedelta +from typing import Dict, List +from pyhealth.data import Event, Patient +from pyhealth.tasks import BaseTask -# TODO: time_window cannot be passed in to base_dataset -def readmission_prediction_mimic3_fn(patient: Patient, time_window=15): - """Processes a single patient for the readmission prediction task. - - Readmission prediction aims at predicting whether the patient will be readmitted - into hospital within time_window days based on the clinical information from - current visit (e.g., conditions and procedures). - - Args: - patient: a Patient object - time_window: the time window threshold (gap < time_window means label=1 for - the task) - - Returns: - samples: a list of samples, each sample is a dict with patient_id, visit_id, - and other task-specific attributes as key - - Note that we define the task as a binary classification task. - - Examples: - >>> from pyhealth.datasets import MIMIC3Dataset - >>> mimic3_base = MIMIC3Dataset( - ... root="/srv/local/data/physionet.org/files/mimiciii/1.4", - ... tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], - ... code_mapping={"ICD9CM": "CCSCM"}, - ... ) - >>> from pyhealth.tasks import readmission_prediction_mimic3_fn - >>> mimic3_sample = mimic3_base.set_task(readmission_prediction_mimic3_fn) - >>> mimic3_sample.samples[0] - [{'visit_id': '130744', 'patient_id': '103', 'conditions': [['42', '109', '19', '122', '98', '663', '58', '51']], 'procedures': [['1']], 'label': 1}] +class ReadmissionPredictionMIMIC3(BaseTask): """ - samples = [] + Readmission prediction on the MIMIC3 dataset. - # we will drop the last visit - for i in range(len(patient) - 1): - visit: Visit = patient[i] - next_visit: Visit = patient[i + 1] + This task aims at predicting whether the patient will be readmitted into hospital within + a specified number of days based on clinical information from the current visit. - # get time difference between current visit and next visit - time_diff = (next_visit.encounter_time - visit.encounter_time).days - readmission_label = 1 if time_diff < time_window else 0 + Attributes: + task_name (str): The name of the task. + input_schema (Dict[str, str]): The schema for the task input. + output_schema (Dict[str, str]): The schema for the task output. + """ + task_name: str = "ReadmissionPredictionMIMIC3" + input_schema: Dict[str, str] = {"conditions": "sequence", "procedures": "sequence", "drugs": "sequence"} + output_schema: Dict[str, str] = {"readmission": "binary"} + + def __init__(self, window: timedelta=timedelta(days=15), exclude_minors: bool=True) -> None: + """ + Initializes the task object. + + Args: + window (timedelta): If two admissions are closer than this window, it is considered a readmission. Defaults to 15 days. + exclude_minors (bool): Whether to exclude visits where the patient was under 18 years old. Defaults to True. + """ + self.window = window + self.exclude_minors = exclude_minors + + def __call__(self, patient: Patient) -> List[Dict]: + """ + Generates binary classification data samples for a single patient. + + Visits with no conditions OR no procedures OR no drugs are excluded from the output but are still used to calculate readmission for prior visits. + + Args: + patient (Patient): A patient object. + + Returns: + List[Dict]: A list containing a dictionary for each patient visit with: + - 'visit_id': MIMIC3 hadm_id. + - 'patient_id': MIMIC3 subject_id. + - 'conditions': MIMIC3 diagnoses_icd table ICD-9 codes. + - 'procedures': MIMIC3 procedures_icd table ICD-9 codes. + - 'drugs': MIMIC3 prescriptions table drug column entries. + - 'readmission': binary label. + """ + patients: List[Event] = patient.get_events(event_type="patients") + assert len(patients) == 1 + + if self.exclude_minors: + try: + dob = datetime.strptime(patients[0].dob, "%Y-%m-%d %H:%M:%S") + except ValueError: + dob = datetime.strptime(patients[0].dob, "%Y-%m-%d") + + admissions: List[Event] = patient.get_events(event_type="admissions") + if len(admissions) < 2: + return [] + + samples = [] + for i in range(len(admissions) - 1): # Skip the last admission since we need a "next" admission + if self.exclude_minors: + age = admissions[i].timestamp.year - dob.year + age = age-1 if ((admissions[i].timestamp.month, admissions[i].timestamp.day) < (dob.month, dob.day)) else age + if age < 18: + continue + + filter = ("hadm_id", "==", admissions[i].hadm_id) + + diagnoses = patient.get_events(event_type="diagnoses_icd", filters=[filter]) + diagnoses = [event.icd9_code for event in diagnoses] + if len(diagnoses) == 0: + continue + + procedures = patient.get_events(event_type="procedures_icd", filters=[filter]) + procedures = [event.icd9_code for event in procedures] + if len(procedures) == 0: + continue + + prescriptions = patient.get_events(event_type="prescriptions", filters=[filter]) + prescriptions = [event.drug for event in prescriptions] + if len(prescriptions) == 0: + continue + + try: + discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d %H:%M:%S") + except ValueError: + discharge_time = datetime.strptime(admissions[i].dischtime, "%Y-%m-%d") + + readmission = int((admissions[i + 1].timestamp - discharge_time) < self.window) + + samples.append( + { + "visit_id": admissions[i].hadm_id, + "patient_id": patient.patient_id, + "conditions": diagnoses, + "procedures": procedures, + "drugs": prescriptions, + "readmission": readmission, + } + ) - conditions = visit.get_code_list(table="DIAGNOSES_ICD") - procedures = visit.get_code_list(table="PROCEDURES_ICD") - drugs = visit.get_code_list(table="PRESCRIPTIONS") - # exclude: visits without condition, procedure, or drug code - if len(conditions) * len(procedures) * len(drugs) == 0: - continue - # TODO: should also exclude visit with age < 18 - samples.append( - { - "visit_id": visit.visit_id, - "patient_id": patient.patient_id, - "conditions": [conditions], - "procedures": [procedures], - "drugs": [drugs], - "label": readmission_label, - } - ) - # no cohort selection - return samples + return samples def readmission_prediction_mimic4_fn(patient: Patient, time_window=15): @@ -328,19 +372,6 @@ def readmission_prediction_omop_fn(patient: Patient, time_window=15): if __name__ == "__main__": - from pyhealth.datasets import MIMIC3Dataset - - base_dataset = MIMIC3Dataset( - root="/srv/local/data/physionet.org/files/mimiciii/1.4", - tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], - dev=True, - code_mapping={"ICD9CM": "CCSCM", "NDC": "ATC"}, - refresh_cache=False, - ) - sample_dataset = base_dataset.set_task(task_fn=readmission_prediction_mimic3_fn) - sample_dataset.stat() - print(sample_dataset.available_keys) - from pyhealth.datasets import MIMIC4Dataset base_dataset = MIMIC4Dataset( diff --git a/tests/core/test_mimic3_readmission_prediction.py b/tests/core/test_mimic3_readmission_prediction.py new file mode 100644 index 000000000..103728db4 --- /dev/null +++ b/tests/core/test_mimic3_readmission_prediction.py @@ -0,0 +1,252 @@ +from datetime import timedelta +import os +import unittest + +from pyhealth.datasets import MIMIC3Dataset +from pyhealth.tasks import ReadmissionPredictionMIMIC3 + + +class TestReadmissionPredictionMIMIC3(unittest.TestCase): + def setUp(self): + """Seed dataset with neg and pos 15 day readmission examples (min required for sample generation)""" + self.mock = MockMICIC3Dataset() + patient = self.mock.add_patient() + self.admission1 = self.mock.add_admission(patient, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + self.admission2 = self.mock.add_admission(patient, "2020-01-16 12:00:00", "2020-01-16 12:00:01") # Exactly 15 days later + self.admission3 = self.mock.add_admission(patient, "2020-01-31 12:00:00", "2020-01-31 12:00:01") # 15 days later less 1 second + + def tearDown(self): + self.mock.destroy() + + def test_patient_with_pos_and_neg_samples(self): + dataset = self.mock.create() + + task = ReadmissionPredictionMIMIC3() + samples = dataset.set_task(task) + + self.assertIn("task_name", vars(ReadmissionPredictionMIMIC3)) + self.assertIn("input_schema", vars(ReadmissionPredictionMIMIC3)) + self.assertIn("output_schema", vars(ReadmissionPredictionMIMIC3)) + + self.assertEqual(task.window, timedelta(days=15)) + + for sample in samples: + self.assertIn("visit_id", sample) + self.assertIn("patient_id", sample) + self.assertIn("conditions", sample) + self.assertIn("procedures", sample) + self.assertIn("drugs", sample) + self.assertIn("readmission", sample) + + self.assertEqual(len(samples), 2) + + neg_samples = [s for s in samples if s["readmission"] == 0] + pos_samples = [s for s in samples if s["readmission"] == 1] + + self.assertEqual(len(neg_samples), 1) + self.assertEqual(len(pos_samples), 1) + + self.assertEqual(neg_samples[0]["visit_id"], str(self.admission1)) + self.assertEqual(pos_samples[0]["visit_id"], str(self.admission2)) + + self.assertTrue(all(s["visit_id"] != str(self.admission3) for s in samples)) # Patient's last admission not included + + def test_explicit_time_window(self): + patient = self.mock.add_patient() + admission1 = self.mock.add_admission(patient, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission2 = self.mock.add_admission(patient, "2020-01-06 12:00:00", "2020-01-06 12:00:01") # Exactly 5 days later + admission3 = self.mock.add_admission(patient, "2020-01-11 12:00:00", "2020-01-11 12:00:01") # 5 days later less 1 second + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(timedelta(days=5))) + + visit1 = [s for s in samples if s["visit_id"] == str(admission1)] + visit2 = [s for s in samples if s["visit_id"] == str(admission2)] + + self.assertEqual(len(visit1), 1) + self.assertEqual(len(visit2), 1) + + self.assertEqual(visit1[0]["readmission"], 0) + self.assertEqual(visit2[0]["readmission"], 1) + + self.assertTrue(all(s["visit_id"] != str(admission3) for s in samples)) # Patient's last admission not included + + def test_patient_with_only_one_visit_is_excluded(self): + patient = self.mock.add_patient() + admission = self.mock.add_admission(patient, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3()) + + self.assertTrue(all(s["patient_id"] != str(patient) for s in samples)) + self.assertTrue(all(s["visit_id"] != str(admission) for s in samples)) + + def test_admissions_without_diagnoses_are_excluded(self): + patient1 = self.mock.add_patient() + patient2 = self.mock.add_patient() + admission1 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission2 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission3 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00", add_diagnosis=False) + admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3()) + + visit_ids = [int(s["visit_id"]) for s in samples] + + self.assertIn (admission1, visit_ids) + self.assertNotIn(admission2, visit_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, visit_ids) + self.assertNotIn(admission4, visit_ids) # Patient's last admission should not be included + + def test_admissions_without_prescriptions_are_excluded(self): + patient1 = self.mock.add_patient() + patient2 = self.mock.add_patient() + admission1 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission2 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission3 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00", add_prescription=False) + admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3()) + + visit_ids = [int(s["visit_id"]) for s in samples] + + self.assertIn (admission1, visit_ids) + self.assertNotIn(admission2, visit_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, visit_ids) + self.assertNotIn(admission4, visit_ids) # Patient's last admission should not be included + + def test_admissions_without_procedures_are_excluded(self): + patient1 = self.mock.add_patient() + patient2 = self.mock.add_patient() + admission1 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission2 = self.mock.add_admission(patient1, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + admission3 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00", add_procedure=False) + admission4 = self.mock.add_admission(patient2, "2020-01-01 00:00:00", "2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3()) + + visit_ids = [int(s["visit_id"]) for s in samples] + + self.assertIn (admission1, visit_ids) + self.assertNotIn(admission2, visit_ids) # Patient's last admission should not be included + self.assertNotIn(admission3, visit_ids) + self.assertNotIn(admission4, visit_ids) # Patient's last admission should not be included + + def test_admissions_of_minors_are_excluded(self): + patient = self.mock.add_patient(dob="2000-01-01 00:00:00") + admission1 = self.mock.add_admission(patient, admittime="2017-12-31 23:59:59", dischtime="2018-01-01 00:00:00") # Admitted 1 second before turning 18 + admission2 = self.mock.add_admission(patient, admittime="2018-01-01 00:00:00", dischtime="2018-01-01 12:00:00") # Admitted at exactly 18 + admission3 = self.mock.add_admission(patient, admittime="2020-01-01 00:00:00", dischtime="2020-01-01 12:00:00") + dataset = self.mock.create() + + task = ReadmissionPredictionMIMIC3() + samples = dataset.set_task(task) + + self.assertTrue(task.exclude_minors) + + visit_ids = [int(s["visit_id"]) for s in samples] + + self.assertNotIn(admission1, visit_ids) + self.assertIn (admission2, visit_ids) + self.assertNotIn(admission3, visit_ids) # Patient's last admission should not be included + + def test_exclude_minors_flag(self): + patient = self.mock.add_patient(dob="2000-01-01 00:00:00") + admission1 = self.mock.add_admission(patient, admittime="2017-12-31 23:59:59", dischtime="2018-01-01 00:00:00") # Admitted 1 second before turning 18 + admission2 = self.mock.add_admission(patient, admittime="2018-01-01 00:00:00", dischtime="2018-01-01 12:00:00") # Admitted at exactly 18 + admission3 = self.mock.add_admission(patient, admittime="2020-01-01 00:00:00", dischtime="2020-01-01 12:00:00") + dataset = self.mock.create() + + samples = dataset.set_task(ReadmissionPredictionMIMIC3(exclude_minors=False)) + + visit_ids = [int(s["visit_id"]) for s in samples] + + self.assertIn (admission1, visit_ids) + self.assertIn (admission2, visit_ids) + self.assertNotIn(admission3, visit_ids) # Patient's last admission should not be included + + +class MockMICIC3Dataset: + def __init__(self): + self.patients = ["row_id,subject_id,gender,dob,dod,dod_hosp,dod_ssn,expire_flag"] + self.admissions = ["row_id,subject_id,hadm_id,admittime,dischtime,deathtime,admission_type,admission_location,discharge_location,insurance,language,religion,marital_status,ethnicity,edregtime,edouttime,diagnosis,hospital_expire_flag,has_chartevents_data"] + self.icu_stays = ["subject_id,intime,icustay_id,first_careunit,dbsource,last_careunit,outtime"] + self.diagnoses = ["row_id,subject_id,hadm_id,seq_num,icd9_code"] + self.prescriptions = ["row_id,subject_id,hadm_id,icustay_id,startdate,enddate,drug_type,drug,drug_name_poe,drug_name_generic,formulary_drug_cd,gsn,ndc,prod_strength,dose_val_rx,dose_unit_rx,form_val_disp,form_unit_disp,route"] + self.procedures = ["row_id,subject_id,hadm_id,seq_num,icd9_code"] + + self.next_subject_id = 1 + self.next_hadm_id = 1 + self.next_diagnosis_id = 1 + self.next_prescription_id = 1 + self.next_procedure_id = 1 + + def add_patient(self,dob: str="2000-01-01 00:00:00") -> int: + subject_id = self.next_subject_id + self.next_subject_id += 1 + self.patients.append(f"{subject_id},{subject_id},,{dob},,,,") + return subject_id + + def add_admission(self, + subject_id: int, + admittime: str, + dischtime: str, + add_diagnosis: bool=True, + add_prescription: bool=True, + add_procedure: bool=True + ) -> int: + hadm_id = self.next_hadm_id + self.next_hadm_id += 1 + self.admissions.append(f"{hadm_id},{subject_id},{hadm_id},{admittime},{dischtime},,,,,,,,,,,,,,") + if add_diagnosis: self.add_diagnosis(subject_id, hadm_id) + if add_prescription: self.add_prescription(subject_id, hadm_id) + if add_procedure: self.add_procedure(subject_id, hadm_id) + return hadm_id + + def add_diagnosis(self, subject_id: int, hadm_id: int, seq_num: int=1, icd9_code: str="") -> int: + row_id = self.next_diagnosis_id + self.next_diagnosis_id += 1 + self.diagnoses.append(f"{row_id},{subject_id},{hadm_id},{seq_num},{icd9_code}") + return row_id + + def add_prescription(self, subject_id: int, hadm_id: int) -> int: + row_id = self.next_prescription_id + self.next_prescription_id += 1 + self.prescriptions.append(f"{row_id},{subject_id},{hadm_id},,,,,,,,,,,,,,,,") + return row_id + + def add_procedure(self, subject_id: int, hadm_id: int, seq_num: int=1, icd9_code: str="") -> int: + row_id = self.next_procedure_id + self.next_procedure_id += 1 + self.procedures.append(f"{row_id},{subject_id},{hadm_id},{seq_num},{icd9_code}") + return row_id + + def create(self, tables: list=["diagnoses_icd", "prescriptions", "procedures_icd"]): + files = { + "PATIENTS.csv": "\n".join(self.patients), + "ADMISSIONS.csv": "\n".join(self.admissions), + "ICUSTAYS.csv": "\n".join(self.icu_stays), + "DIAGNOSES_ICD.csv": "\n".join(self.diagnoses), + "PRESCRIPTIONS.csv": "\n".join(self.prescriptions), + "PROCEDURES_ICD.csv": "\n".join(self.procedures), + } + + for k, v in files.items(): + with open(k, 'w') as f: f.write(v) + + return MIMIC3Dataset(root=".", tables=tables) + + def destroy(self): + if os.path.exists("PATIENTS.csv"): os.remove("PATIENTS.csv") + if os.path.exists("ADMISSIONS.csv"): os.remove("ADMISSIONS.csv") + if os.path.exists("ICUSTAYS.csv"): os.remove("ICUSTAYS.csv") + if os.path.exists("DIAGNOSES_ICD.csv"): os.remove("DIAGNOSES_ICD.csv") + if os.path.exists("PRESCRIPTIONS.csv"): os.remove("PRESCRIPTIONS.csv") + if os.path.exists("PROCEDURES_ICD.csv"): os.remove("PROCEDURES_ICD.csv") + + +if __name__ == "__main__": + unittest.main()