Skip to content

Commit 7ccc628

Browse files
committed
Merge base surgical copilot code
1 parent 9df1f4e commit 7ccc628

31 files changed

+14882
-0
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
models/*
2+
!models/.gitignore

THIRD_PARTY_NOTICES.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
------------------------------------------------------------
2+
Bootswatch v5.3.1 (https://bootswatch.com)
3+
Theme: darkly_green
4+
Copyright 2012-2023 Thomas Park
5+
Licensed under MIT
6+
Based on Bootstrap
7+
------------------------------------------------------------
8+
Bootstrap v5.3.1 (https://getbootstrap.com/)
9+
Copyright 2011-2023 The Bootstrap Authors
10+
Licensed under MIT (https://github.com/twbs/bootstrap/blob/main/LICENSE)
11+
------------------------------------------------------------
12+
Bootstrap Bundle v5.3.1 (https://getbootstrap.com/)
13+
Copyright 2011-2023 The Bootstrap Authors
14+
Licensed under MIT (https://github.com/twbs/bootstrap/blob/main/LICENSE)
15+
------------------------------------------------------------
16+
jQuery v3.6.3
17+
Copyright (c) OpenJS Foundation and other contributors
18+
Licensed under the jQuery license (jquery.org/license)
19+
------------------------------------------------------------

agents/annotation_agent.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""
2+
# Copyright (c) MONAI Consortium
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
"""
13+
14+
import threading
15+
import time
16+
import logging
17+
import os
18+
import json
19+
import queue
20+
from typing import List
21+
from pydantic import BaseModel
22+
from .base_agent import Agent
23+
24+
class SurgeryAnnotation(BaseModel):
    """Schema for one structured annotation of the surgical scene.

    Instances are validated from the vision model's guided-JSON output via
    ``SurgeryAnnotation.model_validate_json`` in
    ``AnnotationAgent._generate_annotation``.
    """

    # Wall-clock time of the annotation; overwritten by AnnotationAgent after
    # parsing with strftime("%Y-%m-%d %H:%M:%S").
    timestamp: str
    # Seconds since the procedure started; also overwritten after parsing.
    elapsed_time_seconds: float
    # Surgical tools the model reports as visible in the frame.
    tools: List[str]
    # Anatomical structures the model reports as visible in the frame.
    anatomy: List[str]
    # Current phase of the procedure (free-form string from the model).
    surgical_phase: str
    # Natural-language description of the scene.
    description: str
31+
32+
class AnnotationAgent(Agent):
    """Background agent that turns video frames into structured annotations.

    A daemon thread wakes every ``time_step_seconds`` (from settings, default
    10), pulls the most recent frame from ``frame_queue``, asks the vision
    model for a JSON annotation constrained by the agent's grammar, validates
    it against :class:`SurgeryAnnotation`, and appends the result to
    ``annotation.json`` inside a per-procedure output folder.
    """

    def __init__(self, settings_path, response_handler, frame_queue, agent_key=None, procedure_start_str=None):
        """Set up output paths and start the background annotation thread.

        Args:
            settings_path: Path to the YAML settings file (consumed by Agent).
            response_handler: Handler receiving streamed model output.
            frame_queue: Queue supplying base64-encoded frames.
            agent_key: Optional sub-config key within the settings file.
            procedure_start_str: Timestamp string naming the output folder;
                defaults to the current local time.
        """
        super().__init__(settings_path, response_handler, agent_key=agent_key)
        self._logger = logging.getLogger(__name__)
        self.frame_queue = frame_queue
        # Seconds between annotation attempts.
        self.time_step = self.agent_settings.get("time_step_seconds", 10)

        if procedure_start_str is None:
            procedure_start_str = time.strftime("%Y_%m_%d__%H_%M_%S", time.localtime())
        self.procedure_start_str = procedure_start_str
        # Reference point for elapsed_time_seconds in each annotation.
        self.procedure_start = time.time()

        base_output_dir = self.agent_settings.get("annotation_output_dir", "procedure_outputs")
        subfolder = os.path.join(base_output_dir, f"procedure_{self.procedure_start_str}")
        os.makedirs(subfolder, exist_ok=True)

        self.annotation_filepath = os.path.join(subfolder, "annotation.json")
        self._logger.info(f"AnnotationAgent writing annotations to: {self.annotation_filepath}")

        self.annotations = []
        self.stop_event = threading.Event()

        # Start the background loop in a separate thread.
        self.thread = threading.Thread(target=self._background_loop, daemon=True)
        self.thread.start()
        self._logger.info(f"AnnotationAgent background thread started (interval={self.time_step}s).")

    def _background_loop(self):
        """Poll the frame queue and persist one annotation per interval.

        Runs until :meth:`stop` sets ``stop_event``. Errors are logged and the
        loop continues, so a single bad frame never kills the thread.
        """
        while not self.stop_event.is_set():
            try:
                # Attempt to get image data from the frame queue.
                try:
                    frame_data = self.frame_queue.get_nowait()
                except queue.Empty:
                    self._logger.debug("No image data available; skipping annotation generation.")
                    # Wait on the event (not time.sleep) so stop() returns
                    # promptly instead of blocking up to time_step seconds.
                    self.stop_event.wait(self.time_step)
                    continue

                annotation = self._generate_annotation(frame_data)
                if annotation:
                    self.annotations.append(annotation)
                    self.append_json_to_file(annotation, self.annotation_filepath)
                    self._logger.debug(f"New annotation appended: {annotation}")
            except Exception as e:
                self._logger.error(f"Error generating annotation: {e}", exc_info=True)
            self.stop_event.wait(self.time_step)

    def _generate_annotation(self, frame_data):
        """Request, validate, and timestamp one annotation for a frame.

        Args:
            frame_data: Base64-encoded image (with or without data-URI prefix).

        Returns:
            A dict conforming to :class:`SurgeryAnnotation` with refreshed
            ``timestamp``/``elapsed_time_seconds`` fields, or ``None`` when
            the request or validation fails.
        """
        user_content = "Please produce an annotation of the surgical scene based on the provided image, following the required schema."
        try:
            # Constrain the model's output to the annotation JSON schema.
            guided_params = {"guided_json": json.loads(self.grammar)}
            raw_json_str = self.stream_image_response(
                prompt=self.generate_prompt(user_content, []),
                image_b64=frame_data,
                temperature=0.3,
                extra_body=guided_params
            )
            self._logger.debug(f"Raw annotation response: {raw_json_str}")

            try:
                parsed = SurgeryAnnotation.model_validate_json(raw_json_str)
            except Exception as e:
                self._logger.warning(f"Annotation parse error: {e}")
                return None

            # model_dump() is the pydantic-v2 API; .dict() is deprecated and
            # inconsistent with the model_validate_json call above.
            annotation_dict = parsed.model_dump()
            annotation_dict["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            annotation_dict["elapsed_time_seconds"] = time.time() - self.procedure_start

            return annotation_dict

        except Exception as e:
            self._logger.warning(f"Annotation generation error: {e}")
            return None

    def process_request(self, input_data, chat_history):
        """Synchronous entry point; this agent only works in the background."""
        return {
            "name": "AnnotationAgent",
            "response": "AnnotationAgent runs in the background and generates annotations only when image data is available."
        }

    def stop(self):
        """Signal the background thread to exit and wait for it to finish."""
        self.stop_event.set()
        self._logger.info("Stopping AnnotationAgent background thread.")
        self.thread.join()

agents/base_agent.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""
2+
# Copyright (c) MONAI Consortium
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
"""
13+
14+
from abc import ABC, abstractmethod
15+
import json
16+
import logging
17+
import yaml
18+
import time
19+
import tiktoken
20+
from threading import Lock
21+
import base64
22+
import tempfile
23+
import os
24+
import requests
25+
from openai import OpenAI
26+
27+
class Agent(ABC):
    """Abstract base class for LLM-backed agents.

    Handles loading per-agent YAML settings, connecting to an OpenAI-compatible
    (vLLM) server, token-budgeted prompt assembly, and text/vision chat
    requests. Subclasses implement :meth:`process_request`.
    """

    # Serializes chat requests across all agent instances so only one request
    # is in flight to the LLM server at a time.
    _llm_lock = Lock()

    def __init__(self, settings_path, response_handler, agent_key=None):
        """Load settings, build the client, and block until the server is up.

        Args:
            settings_path: Path to the YAML settings file.
            response_handler: Object receiving response text; must provide
                ``add_response`` and ``end_response``.
            agent_key: Optional top-level key selecting this agent's
                sub-config; the whole file is used when absent.
        """
        self._logger = logging.getLogger(f"{__name__}.{type(self).__name__}")
        self.load_settings(settings_path, agent_key=agent_key)
        self.response_handler = response_handler
        self.tokenizer = tiktoken.get_encoding("cl100k_base")
        # vLLM ignores the API key, but the OpenAI client requires one.
        self.client = OpenAI(api_key="EMPTY", base_url=self.llm_url)
        self._wait_for_server()

    def load_settings(self, settings_path, agent_key=None):
        """Read the YAML config and populate agent attributes with defaults."""
        with open(settings_path, 'r') as f:
            full_config = yaml.safe_load(f)
        if agent_key and agent_key in full_config:
            self.agent_settings = full_config[agent_key]
        else:
            self.agent_settings = full_config
        self.description = self.agent_settings.get('description', '')
        self.max_prompt_tokens = self.agent_settings.get('max_prompt_tokens', 3000)
        self.ctx_length = self.agent_settings.get('ctx_length', 2048)
        self.agent_prompt = self.agent_settings.get('agent_prompt', '').strip()
        self.user_prefix = self.agent_settings.get('user_prefix', '')
        self.bot_prefix = self.agent_settings.get('bot_prefix', '')
        self.bot_rule_prefix = self.agent_settings.get('bot_rule_prefix', '')
        self.end_token = self.agent_settings.get('end_token', '')
        self.grammar = self.agent_settings.get('grammar', None)
        self.model_name = self.agent_settings.get('model_name', 'llama3.2')
        self.publish_settings = self.agent_settings.get('publish', {})
        self.llm_url = self.agent_settings.get('llm_url', "http://localhost:8000/v1")
        self.tools = self.agent_settings.get('tools', {})
        self._logger.debug(f"Agent config loaded. llm_url={self.llm_url}, model_name={self.model_name}")

    def _wait_for_server(self, timeout=30):
        """Poll the server's /models endpoint once a second until it answers.

        Args:
            timeout: Maximum number of one-second attempts.

        Raises:
            ConnectionError: if the server never responds with HTTP 200.
        """
        attempts = 0
        check_url = f"{self.llm_url}/models"
        while attempts < timeout:
            try:
                # Per-request timeout so one hung connection cannot stall the
                # retry loop indefinitely (the original request had none).
                r = requests.get(check_url, timeout=5)
                if r.status_code == 200:
                    self._logger.debug(f"Connected to vLLM server at {self.llm_url}")
                    return
            except Exception as e:
                self._logger.debug(f"Waiting for vLLM server (attempt {attempts+1}): {e}")
            time.sleep(1)
            attempts += 1
        raise ConnectionError(f"Unable to connect to vLLM server at {self.llm_url} after {timeout} seconds")

    def stream_response(self, prompt, grammar=None, temperature=0.0, display_output=True):
        """Send a text chat request and return the model's reply.

        The user turn is recovered from the ChatML-formatted ``prompt`` by
        slicing between the "<|im_start|>user" and "<|im_end|>" markers.
        Returns "" on failure (errors are logged, not raised).
        """
        with Agent._llm_lock:
            user_message = prompt.split("<|im_start|>user\n")[-1].split("<|im_end|>")[0].strip()
            request_messages = []
            if self.agent_prompt:
                request_messages.append({"role": "system", "content": self.agent_prompt})
            request_messages.append({"role": "user", "content": user_message})
            self._logger.debug(
                f"Sending chat request to vLLM/OpenAI client. Model={self.model_name}, temperature={temperature}\nUser message:\n{user_message[:500]}"
            )
            try:
                completion = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=request_messages,
                    temperature=temperature,
                    max_tokens=self.ctx_length
                )
                response_text = completion.choices[0].message.content if completion.choices else ""
                if display_output and self.response_handler:
                    self.response_handler.add_response(response_text)
                    self.response_handler.end_response()
                return response_text
            except Exception as e:
                self._logger.error(f"vLLM chat request failed: {e}", exc_info=True)
                return ""

    def stream_image_response(self, prompt, image_b64, grammar=None, temperature=0.0, display_output=True, extra_body=None):
        """Send a vision chat request with the image written to a temp file.

        Args:
            prompt: ChatML-formatted prompt; the user turn is extracted.
            image_b64: Base64 image data, with or without a data-URI prefix.
            grammar: Unused here; kept for signature compatibility.
            temperature: Sampling temperature.
            display_output: Forward the reply to the response handler.
            extra_body: Optional extra request parameters (e.g. guided_json).

        Raises:
            ValueError: when ``image_b64`` is empty.
            Exception: request failures are logged and re-raised.
        """
        self._logger.debug(f"stream_image_response with model={self.model_name}")
        if not image_b64:
            raise ValueError("No image data provided for image response")
        user_message = prompt.split("<|im_start|>user\n")[-1].split("<|im_end|>")[0].strip()
        file_path = None
        try:
            raw_b64 = self._extract_raw_base64(image_b64)
            with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file:
                file_path = tmp_file.name
                tmp_file.write(base64.b64decode(raw_b64))
            self._logger.debug(f"Temp image file created: {file_path}")
            messages = []
            if self.agent_prompt:
                messages.append({"role": "system", "content": self.agent_prompt})
            # NOTE(review): "images" is not the OpenAI content-part format;
            # presumably the server accepts this extension — confirm.
            messages.append({
                "role": "user",
                "content": user_message,
                "images": [file_path]
            })
            request_kwargs = {
                "model": self.model_name,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": self.ctx_length
            }
            if extra_body is not None:
                request_kwargs["extra_body"] = extra_body
            result = self.client.chat.completions.create(**request_kwargs)
            raw_text = result.choices[0].message.content
            if display_output and self.response_handler:
                self.response_handler.add_response(raw_text)
                self.response_handler.end_response()
            return raw_text
        except Exception as e:
            self._logger.error(f"vLLM vision request failed: {e}", exc_info=True)
            raise
        finally:
            # Always delete the temp file; the original removed it only on
            # success and leaked it when the request raised.
            if file_path is not None and os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except OSError:
                    pass

    def _extract_raw_base64(self, image_b64: str) -> str:
        """Strip a "data:image/...;base64," prefix, if present, from the data."""
        prefix = "data:image/"
        if image_b64.startswith(prefix):
            parts = image_b64.split(',', 1)
            if len(parts) == 2:
                return parts[1]
            else:
                return image_b64
        else:
            return image_b64

    def generate_prompt(self, text, chat_history):
        """Assemble the full ChatML-style prompt: system + history + user turn."""
        system_prompt = f"{self.bot_rule_prefix}\n{self.agent_prompt}\n{self.end_token}"
        user_prompt = f"\n{self.user_prefix}\n{text}\n{self.end_token}"
        token_usage = self.calculate_token_usage(system_prompt + user_prompt)
        chat_prompt = self.create_conversation_str(chat_history, token_usage)
        prompt = system_prompt + chat_prompt + user_prompt
        prompt += f"\n{self.bot_prefix}\n"
        return prompt

    def create_conversation_str(self, chat_history, token_usage, conversation_length=2):
        """Render recent history, newest-first, until the token budget is hit.

        Walks up to ``conversation_length`` (user, bot) pairs — excluding the
        in-progress last entry — from newest to oldest, stopping as soon as
        adding a message would exceed ``max_prompt_tokens``; the collected
        messages are then re-joined in chronological order.
        """
        total_tokens = token_usage
        msg_hist = []
        for user_msg, bot_msg in chat_history[:-1][-conversation_length:][::-1]:
            if bot_msg:
                bot_msg_str = f"\n{self.bot_prefix}\n{bot_msg}\n{self.end_token}"
                bot_tokens = self.calculate_token_usage(bot_msg_str)
                if total_tokens + bot_tokens > self.max_prompt_tokens:
                    break
                total_tokens += bot_tokens
                msg_hist.append(bot_msg_str)
            if user_msg:
                user_msg_str = f"\n{self.user_prefix}\n{user_msg}\n{self.end_token}"
                user_tokens = self.calculate_token_usage(user_msg_str)
                if total_tokens + user_tokens > self.max_prompt_tokens:
                    break
                total_tokens += user_tokens
                msg_hist.append(user_msg_str)
        return "".join(msg_hist[::-1])

    def calculate_token_usage(self, text):
        """Return the number of cl100k_base tokens in ``text``."""
        return len(self.tokenizer.encode(text))

    @abstractmethod
    def process_request(self, input_data, chat_history):
        """Handle one user request; implemented by each concrete agent."""
        pass

    def append_json_to_file(self, json_object, file_path):
        """Append ``json_object`` to a JSON array stored at ``file_path``.

        Creates the file with a one-element array when missing; unreadable or
        non-list contents are reset to an empty list before appending. Errors
        are logged, never raised.
        """
        try:
            if not os.path.isfile(file_path):
                with open(file_path, 'w') as f:
                    json.dump([json_object], f, indent=2)
            else:
                with open(file_path, 'r') as f:
                    try:
                        data = json.load(f)
                    except json.JSONDecodeError:
                        data = []
                if not isinstance(data, list):
                    data = []
                data.append(json_object)
                with open(file_path, 'w') as f:
                    json.dump(data, f, indent=2)
        except Exception as e:
            self._logger.error(f"append_json_to_file error: {e}", exc_info=True)

0 commit comments

Comments
 (0)