From c9afeb1938893a12d02a2a7739b7c853a7ae9d6b Mon Sep 17 00:00:00 2001 From: charanw Date: Tue, 9 Dec 2025 20:14:05 -0600 Subject: [PATCH] Add BHCToAVS model for patient-friendly summaries Introduces the BHCToAVS model, which converts clinical Brief Hospital Course (BHC) notes into After-Visit Summaries (AVS) using a fine-tuned Mistral 7B model with LoRA adapters. Adds model implementation, documentation, an example usage script, and unit tests. --- docs/api/models.rst | 1 + docs/api/models/pyhealth.models.BHCToAVS.rst | 11 +++ examples/bhc_to_avs_example.py | 21 +++++ pyhealth/models/__init__.py | 3 +- pyhealth/models/bhc_to_avs.py | 98 ++++++++++++++++++++ tests/core/test_bhc_to_avs.py | 36 +++++++ 6 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 docs/api/models/pyhealth.models.BHCToAVS.rst create mode 100644 examples/bhc_to_avs_example.py create mode 100644 pyhealth/models/bhc_to_avs.py create mode 100644 tests/core/test_bhc_to_avs.py diff --git a/docs/api/models.rst b/docs/api/models.rst index a0df0d943..4ab9d6ed8 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -8,6 +8,7 @@ We implement the following models for supporting multiple healthcare predictive :maxdepth: 3 models/pyhealth.models.BaseModel + models/pyhealth.models.BHCToAVS models/pyhealth.models.LogisticRegression models/pyhealth.models.MLP models/pyhealth.models.CNN diff --git a/docs/api/models/pyhealth.models.BHCToAVS.rst b/docs/api/models/pyhealth.models.BHCToAVS.rst new file mode 100644 index 000000000..594493186 --- /dev/null +++ b/docs/api/models/pyhealth.models.BHCToAVS.rst @@ -0,0 +1,11 @@ +pyhealth.models.bhc_to_avs +========================== + +BHCToAVS +------------------------------ + +.. autoclass:: pyhealth.models.bhc_to_avs.BHCToAVS + :members: + :inherited-members: + :show-inheritance: + :undoc-members: \ No newline at end of file diff --git a/examples/bhc_to_avs_example.py b/examples/bhc_to_avs_example.py new file mode 100644 index 000000000..ef438622f --- /dev/null +++ b/examples/bhc_to_avs_example.py @@ -0,0 +1,21 @@ +from pyhealth.models.bhc_to_avs import BHCToAVS + +# Initialize the model +model = BHCToAVS() + +# Example Brief Hospital Course (BHC) text with common clinical abbreviations generated synthetically via ChatGPT 5.1 +bhc = ( + "Pt admitted with acute onset severe epigastric pain and hypotension. " + "Labs notable for elevated lactate, WBC 18K, mild AST/ALT elevation, and Cr 1.4 (baseline 0.9). " + "CT A/P w/ contrast demonstrated peripancreatic fat stranding c/w acute pancreatitis; " + "no necrosis or peripancreatic fluid collection. " + "Pt received aggressive IVFs, electrolyte repletion, IV analgesia, and NPO status initially. " + "Serial abd exams remained benign with no rebound or guarding. " + "BP stabilized, lactate downtrended, and pt tolerated ADAT to low-fat diet without recurrence of sx. " + "Discharged in stable condition w/ instructions for GI f/u and outpatient CMP in 1 week." +) + +# Generate a patient-friendly After-Visit Summary +print(model.predict(bhc)) + +# Expected output: A simplified, patient-friendly summary explaining the hospital stay without medical jargon. \ No newline at end of file diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5c3683bc1..659bb7f88 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -1,6 +1,7 @@ from .adacare import AdaCare, AdaCareLayer from .agent import Agent, AgentLayer from .base_model import BaseModel +from .bhc_to_avs import BHCToAVS from .cnn import CNN, CNNLayer from .concare import ConCare, ConCareLayer from .contrawr import ContraWR, ResBlock2D @@ -26,4 +27,4 @@ from .transformer import Transformer, TransformerLayer from .transformers_model import TransformersModel from .vae import VAE -from .sdoh import SdohClassifier \ No newline at end of file +from .sdoh import SdohClassifier diff --git a/pyhealth/models/bhc_to_avs.py b/pyhealth/models/bhc_to_avs.py new file mode 100644 index 000000000..b57e6fbfb --- /dev/null +++ b/pyhealth/models/bhc_to_avs.py @@ -0,0 +1,98 @@ +# Author: Charan Williams +# NetID: charanw2 +# Description: Converts clinical brief hospital course (BHC) data to after visit summaries using a fine-tuned Mistral 7B model. + +from typing import Dict, Any +from dataclasses import dataclass, field +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from peft import PeftModelForCausalLM +from pyhealth.models.base_model import BaseModel + +_PROMPT = """Summarize for the patient what happened during the hospital stay: + +### Brief Hospital Course: +{bhc} + +### Patient Summary: +""" + +# System prompt used during inference +_SYSTEM_PROMPT = ( + "You are a clinical summarization model. Produce accurate, patient-friendly summaries " + "using only information from the doctor's note. Do not add new details.\n\n" +) + +# Prompt used during fine-tuning +_PROMPT = ( + "Summarize for the patient what happened during the hospital stay based on this doctor's note:\n" + "{bhc}\n\n" + "Summary for the patient:\n" +) + +@dataclass +class BHCToAVS(BaseModel): + base_model_id: str = field(default="mistralai/Mistral-7B-Instruct") + """HuggingFace repo containing the base Mistral 7B model.""" + + adapter_model_id: str = field(default="williach31/mistral-7b-bhc-to-avs-lora") + """HuggingFace repo containing only LoRA adapter weights.""" + + def _get_pipeline(self): + """Create and cache the text-generation pipeline.""" + if not hasattr(self, "_pipeline"): + # Load base model + base = AutoModelForCausalLM.from_pretrained( + self.base_model_id, + torch_dtype=torch.bfloat16, + device_map="auto" + ) + + # Load LoRA adapter + model = PeftModelForCausalLM.from_pretrained( + base, + self.adapter_model_id, + torch_dtype=torch.bfloat16 + ) + + tokenizer = AutoTokenizer.from_pretrained(self.base_model_id) + + # Create HF pipeline + self._pipeline = pipeline( + "text-generation", + model=model, + tokenizer=tokenizer, + device_map="auto", + model_kwargs={"torch_dtype": torch.bfloat16} + ) + + return self._pipeline + + def predict(self, bhc_text: str) -> str: + """ + Generate an After-Visit Summary (AVS) from a Brief Hospital Course (BHC) note. + + Parameters + ---------- + bhc_text : str + Raw BHC text. + + Returns + ------- + str + Patient-friendly summary. + """ + + prompt = _SYSTEM_PROMPT + _PROMPT.format(bhc=bhc_text) + + pipe = self._get_pipeline() + outputs = pipe( + prompt, + max_new_tokens=512, + temperature=0.0, + eos_token_id=[pipe.tokenizer.eos_token_id], + pad_token_id=pipe.tokenizer.eos_token_id, + ) + + # Output is a single text string + return outputs[0]["generated_text"].strip() \ No newline at end of file diff --git a/tests/core/test_bhc_to_avs.py b/tests/core/test_bhc_to_avs.py new file mode 100644 index 000000000..04c45fbf5 --- /dev/null +++ b/tests/core/test_bhc_to_avs.py @@ -0,0 +1,36 @@ +from tests.base import BaseTestCase +from pyhealth.models.bhc_to_avs import BHCToAVS + + +class TestBHCToAVS(BaseTestCase): + """Unit tests for the BHCToAVS model.""" + + def setUp(self): + self.set_random_seed() + + def test_predict(self): + """Test the predict method of BHCToAVS.""" + bhc_text = ( + "Patient admitted with abdominal pain. Imaging showed no acute findings. " + "Pain improved with supportive care and the patient was discharged in stable condition." + ) + model = BHCToAVS() + try: + + summary = model.predict(bhc_text) + + # Output must be type str + self.assertIsInstance(summary, str) + + # Output should not be empty + self.assertGreater(len(summary.strip()), 0) + + # Output should be different from input + self.assertNotIn(bhc_text[:40], summary) + + except OSError as e: + # Allow test to pass if model download fails on e.g. on GitHub workflows + if "gated repo" in str(e).lower() or "404" in str(e): + pass + else: + raise e