Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for vila-m3 #65

Merged
merged 4 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 146 additions & 9 deletions plugins/RadCoPilot_Slicer/RadCoPilot.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,23 @@ def __init__(self, parent):
groupLayout.addRow(_("Server address:"), serverUrl)
parent.registerProperty("RadCoPilot/serverUrl", serverUrl, "text", str(qt.SIGNAL("textChanged(QString)")))

scanUrl = qt.QLineEdit()
groupLayout.addRow(_("Scan address:"), scanUrl)
parent.registerProperty("RadCoPilot/scanUrl", scanUrl, "text", str(qt.SIGNAL("textChanged(QString)")))


serverUrlHistory = qt.QLineEdit()
groupLayout.addRow(_("Server address history:"), serverUrlHistory)
parent.registerProperty(
"RadCoPilot/serverUrlHistory", serverUrlHistory, "text", str(qt.SIGNAL("textChanged(QString)"))
)

scanUrlHistory = qt.QLineEdit()
groupLayout.addRow(_("Scan address history:"), scanUrlHistory)
parent.registerProperty(
"RadCoPilot/scanUrlHistory", scanUrlHistory, "text", str(qt.SIGNAL("textChanged(QString)"))
)

fileExtension = qt.QLineEdit()
fileExtension.setText(".nii.gz")
fileExtension.toolTip = _("Default extension for uploading volumes")
Expand Down Expand Up @@ -110,6 +121,7 @@ def __init__(self, parent=None):
self._updatingGUIFromParameterNode = False

self.info = {}
self.current_model = None
self.current_sample = None
self.samples = {}
self.state = {
Expand Down Expand Up @@ -147,15 +159,19 @@ def setup(self):
# in batch mode, without a graphical user interface.
self.tmpdir = slicer.util.tempDirectory("slicer-radcopilot")
self.logic = RadCoPilotLogic()
self.ui.scanComboBox.setEditable(True)
self.ui.serverComboBox.setEditable(True)

# Set icons and tune widget properties
self.ui.serverComboBox.lineEdit().setPlaceholderText("enter server address or leave empty to use default")
self.ui.scanComboBox.lineEdit().setPlaceholderText("enter scan address")
self.ui.fetchServerInfoButton.setIcon(self.icon("refresh-icon.png"))
self.ui.uploadImageButton.setIcon(self.icon("upload.svg"))

# start with button disabled
self.ui.sendPrompt.setEnabled(False)
self.ui.uploadImageButton.setEnabled(False)
self.ui.fetchImageButton.setEnabled(False)
self.ui.outputText.setReadOnly(True)

# Connections
Expand All @@ -164,6 +180,11 @@ def setup(self):
self.ui.sendPrompt.connect("clicked(bool)", self.onClickSendPrompt)
self.ui.cleanOutputButton.connect("clicked(bool)", self.onClickCleanOutputButton)
self.ui.uploadImageButton.connect("clicked(bool)", self.onUploadImage)
self.ui.fetchImageButton.connect("clicked(bool)", self.onFetchImage)


self.updateServerUrlGUIFromSettings()
self.updateScanUrlGUIFromSettings()

def icon(self, name="RadCoPilot.png"):
'''Get the icon for the RadCoPilot module.'''
Expand All @@ -178,13 +199,49 @@ def updateServerSettings(self):
self.logic.setServer(self.serverUrl())
self.saveServerUrl()

def updateScanSettings(self):
'''Update the scan settings based on the current UI state.'''
self.saveScanUrl()

def serverUrl(self):
'''Get the current server URL from the UI.'''
serverUrl = self.ui.serverComboBox.currentText.strip()
if not serverUrl:
serverUrl = "http://localhost:8000"
# return serverUrl.rstrip("/")
return serverUrl

def scanUrl(self):
'''Get the current scan URL from the UI.'''
scanUrl = self.ui.scanComboBox.currentText.strip()
if not scanUrl:
print("Scan address is empty ... ")
return None
return scanUrl

def updateServerUrlGUIFromSettings(self):
'''Save current server URL to the top of history.'''
settings = qt.QSettings()
serverUrlHistory = settings.value("RadCoPilot/serverUrlHistory")

wasBlocked = self.ui.serverComboBox.blockSignals(True)
self.ui.serverComboBox.clear()
if serverUrlHistory:
self.ui.serverComboBox.addItems(serverUrlHistory.split(";"))
self.ui.serverComboBox.setCurrentText(settings.value("RadCoPilot/serverUrl"))
self.ui.serverComboBox.blockSignals(wasBlocked)

def updateScanUrlGUIFromSettings(self):
'''Save current scan URL to the top of history.'''
settings = qt.QSettings()
scanUrlHistory = settings.value("RadCoPilot/scanUrlHistory")

wasBlocked = self.ui.scanComboBox.blockSignals(True)
self.ui.scanComboBox.clear()
if scanUrlHistory:
self.ui.scanComboBox.addItems(scanUrlHistory.split(";"))
self.ui.scanComboBox.setCurrentText(settings.value("RadCoPilot/scanUrl"))
self.ui.scanComboBox.blockSignals(wasBlocked)

def saveServerUrl(self):
'''Save the current server URL to settings and update history.'''
Expand All @@ -210,7 +267,34 @@ def saveServerUrl(self):
serverUrlHistory = serverUrlHistory[:10] # keep up to first 10 elements
settings.setValue("RadCoPilot/serverUrlHistory", ";".join(serverUrlHistory))

# self.updateServerUrlGUIFromSettings()
self.updateServerUrlGUIFromSettings()


def saveScanUrl(self):
'''Save the current scan URL to settings and update history.'''
# self.updateParameterNodeFromGUI()

# Save selected server URL
settings = qt.QSettings()
scanUrl = self.ui.scanComboBox.currentText
settings.setValue("RadCoPilot/scanUrl", scanUrl)

# Save current scan URL to the top of history
scanUrlHistory = settings.value("RadCoPilot/serverUrlHistory")
if scanUrlHistory:
scanUrlHistory = scanUrlHistory.split(";")
else:
scanUrlHistory = []
try:
scanUrlHistory.remove(scanUrl)
except ValueError:
pass

scanUrlHistory.insert(0, scanUrl)
scanUrlHistory = scanUrlHistory[:10] # keep up to first 10 elements
settings.setValue("RadCoPilot/scanUrlHistory", ";".join(scanUrlHistory))

self.updateScanUrlGUIFromSettings()

def show_popup(self, title, message):
'''Display a popup message box with the given title and message.'''
Expand All @@ -225,16 +309,17 @@ def onClickFetchInfo(self):

try:
self.updateServerSettings()
self.updateScanSettings()
info = self.logic.info()
self.current_model = info
self.info = info

print(f"Connected to RadCoPilot Server - Obtained info from server: {self.info}")
self.show_popup("Information", "Connected to RadCoPilot Server")
self.ui.sendPrompt.setEnabled(True)
self.ui.uploadImageButton.setEnabled(True)
self.ui.fetchImageButton.setEnabled(True)
# Updating model name
self.ui.appComboBox.clear()
self.ui.appComboBox.addItem(self.info)
self.ui.appDescriptionLabel.text = self.info

except AttributeError as e:
slicer.util.errorDisplay(
Expand All @@ -250,13 +335,53 @@ def onClickCleanOutputButton(self):
'''Handle the click event for cleaning the output text.'''
self.ui.outputText.clear()

def onUploadImage(self):
'''Gets the volume and sen it to the server.'''
volumeNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLScalarVolumeNode")
image_id = volumeNode.GetName()
def reportProgress2(self, msg, level=None):
'''Print progress in the console.'''
print("Loading... {0}%".format(self.sampleDataLogic.downloadPercent))
# Abort download if cancel is clicked in progress bar
if self.progressWindow.wasCanceled:
raise Exception("download aborted")
# Update progress window
self.progressWindow.show()
self.progressWindow.activateWindow()
self.progressWindow.setValue(int(self.sampleDataLogic.downloadPercent))
self.progressWindow.setLabelText("Downloading...")
# Process events to allow screen to refresh
slicer.app.processEvents()

def onFetchImage(self):
'''Fetch an image from an URL.'''
import SampleData

try:
scan_url = self.ui.scanComboBox.currentText
volumeNode = None
self.progressWindow = slicer.util.createProgressDialog()
self.sampleDataLogic = SampleData.SampleDataLogic()
self.sampleDataLogic.logMessage = self.reportProgress2
loadedNodes = self.sampleDataLogic.downloadFromURL(
nodeNames="CTVolume",
fileNames="ct_liver_0.nii.gz",
uris=scan_url,
#checksums="SHA256:cc211f0dfd9a05ca3841ce1141b292898b2dd2d3f08286affadf823a7e58df93"
)
volumeNode = loadedNodes[0]
self.ui.inputVolumeNodeComboBox.setCurrentNode(volumeNode)
# Sending url to the server
info = self.logic.uploadScan(scan_url)
self.info = info
print(f"Response from the upload text call: {self.info['status']}")

finally:
self.progressWindow.close()

def onUploadImage(self):
'''Gets the text from scanComboBox or the volume from the viewport and sends it to the server.'''
try:
qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor)

volumeNode = self.ui.inputVolumeNodeComboBox.currentNode()
image_id = volumeNode.GetName()
in_file = tempfile.NamedTemporaryFile(suffix=self.file_ext, dir=self.tmpdir).name
self.current_sample = in_file
self.reportProgress(5)
Expand All @@ -279,8 +404,10 @@ def onUploadImage(self):
msg = f"Message: {e.msg}" if hasattr(e, "msg") else ""
self.reportProgress(100)
qt.QApplication.restoreOverrideCursor()
self.show_popup("Error", f"Upload failed: {msg}")
return False


def reportProgress(self, progressPercentage):
'''Reports progress of an event.'''
if not self.progressBar:
Expand All @@ -293,7 +420,14 @@ def reportProgress(self, progressPercentage):
def has_text(self, ui_text):
'''Check if the given UI text element has any content.'''
return len(ui_text.toPlainText()) < 1

def _get_answer_text(self, response):
'''Get the VILA-M3 response.'''

final_response = response['choices'][0]['message']['content']['content'][0]['text']

return final_response

def onClickSendPrompt(self):
'''Handle the click event for sending a prompt to the server.'''
if not self.logic:
Expand All @@ -313,7 +447,10 @@ def onClickSendPrompt(self):
info = self.logic.getAnswer(inputText=inText, volumePath=self.current_sample)
if info is not None:
self.info = info
self.ui.outputText.setText(info['choices'][0]['message']['content'])
if 'VILA-M3' in self.current_model:
self.ui.outputText.setText(self._get_answer_text(info))
else:
self.ui.outputText.setText(info['choices'][0]['message']['content'])
logging.info(f"Time consumed by fetch info: {time.time() - start:3.1f}")


Expand Down
20 changes: 14 additions & 6 deletions plugins/RadCoPilot_Slicer/RadCoPilotLib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ def uploadFile(self, volumePath):
"""
Upload a file to the RadCoPilot server using the fileUploadRouter.

This method sends a file to the '/upload' endpoint of the RadCoPilot server,
This method sends either a local file or a URL to the '/upload' endpoint of the RadCoPilot server,
which stores it as the last received file. This uploaded file can then be
used in subsequent requests to the chat_completions API if no file is
provided in those requests.

Parameters:
-----------
filePath : str
The path to the file that should be uploaded to the server.
volumePath : str
The path to the local file or a URL that should be uploaded to the server.

Returns:
--------
Expand All @@ -103,13 +103,21 @@ def uploadFile(self, volumePath):
print("Uploading file...")

selector = "/upload/"

url = f"{self._server_url}{selector}"

# Check if volumePath is a URL
parsed_url = urlparse(volumePath)
is_url = bool(parsed_url.scheme and parsed_url.netloc)

with open(volumePath, 'rb') as file:
files = {"file": (os.path.basename(volumePath), file, "application/octet-stream")}
if is_url:
# If it's a URL, send it as a form-data field
files = {"file": (None, str(volumePath))}
response = requests.post(url=url, files=files)
else:
# If it's a local file, open and send it
with open(volumePath, 'rb') as file:
files = {"file": (os.path.basename(volumePath), file, "application/octet-stream")}
response = requests.post(url=url, files=files)

if response.status_code == 200:
return response.json()
Expand Down
Loading