diff --git a/plugins/RadCoPilot_Slicer/RadCoPilot.py b/plugins/RadCoPilot_Slicer/RadCoPilot.py index 8649777..144530f 100644 --- a/plugins/RadCoPilot_Slicer/RadCoPilot.py +++ b/plugins/RadCoPilot_Slicer/RadCoPilot.py @@ -70,12 +70,23 @@ def __init__(self, parent): groupLayout.addRow(_("Server address:"), serverUrl) parent.registerProperty("RadCoPilot/serverUrl", serverUrl, "text", str(qt.SIGNAL("textChanged(QString)"))) + scanUrl = qt.QLineEdit() + groupLayout.addRow(_("Scan address:"), scanUrl) + parent.registerProperty("RadCoPilot/scanUrl", scanUrl, "text", str(qt.SIGNAL("textChanged(QString)"))) + + serverUrlHistory = qt.QLineEdit() groupLayout.addRow(_("Server address history:"), serverUrlHistory) parent.registerProperty( "RadCoPilot/serverUrlHistory", serverUrlHistory, "text", str(qt.SIGNAL("textChanged(QString)")) ) + scanUrlHistory = qt.QLineEdit() + groupLayout.addRow(_("Scan address history:"), scanUrlHistory) + parent.registerProperty( + "RadCoPilot/scanUrlHistory", scanUrlHistory, "text", str(qt.SIGNAL("textChanged(QString)")) + ) + fileExtension = qt.QLineEdit() fileExtension.setText(".nii.gz") fileExtension.toolTip = _("Default extension for uploading volumes") @@ -110,6 +121,7 @@ def __init__(self, parent=None): self._updatingGUIFromParameterNode = False self.info = {} + self.current_model = None self.current_sample = None self.samples = {} self.state = { @@ -147,15 +159,19 @@ def setup(self): # in batch mode, without a graphical user interface. self.tmpdir = slicer.util.tempDirectory("slicer-radcopilot") self.logic = RadCoPilotLogic() + self.ui.scanComboBox.setEditable(True) + self.ui.serverComboBox.setEditable(True) # Set icons and tune widget properties self.ui.serverComboBox.lineEdit().setPlaceholderText("enter server address or leave empty to use default") + self.ui.scanComboBox.lineEdit().setPlaceholderText("enter scan address") self.ui.fetchServerInfoButton.setIcon(self.icon("refresh-icon.png")) self.ui.uploadImageButton.setIcon(self.icon("upload.svg")) # start with button disabled self.ui.sendPrompt.setEnabled(False) self.ui.uploadImageButton.setEnabled(False) + self.ui.fetchImageButton.setEnabled(False) self.ui.outputText.setReadOnly(True) # Connections @@ -164,6 +180,11 @@ def setup(self): self.ui.sendPrompt.connect("clicked(bool)", self.onClickSendPrompt) self.ui.cleanOutputButton.connect("clicked(bool)", self.onClickCleanOutputButton) self.ui.uploadImageButton.connect("clicked(bool)", self.onUploadImage) + self.ui.fetchImageButton.connect("clicked(bool)", self.onFetchImage) + + + self.updateServerUrlGUIFromSettings() + self.updateScanUrlGUIFromSettings() def icon(self, name="RadCoPilot.png"): '''Get the icon for the RadCoPilot module.''' @@ -178,6 +199,10 @@ def updateServerSettings(self): self.logic.setServer(self.serverUrl()) self.saveServerUrl() + def updateScanSettings(self): + '''Update the scan settings based on the current UI state.''' + self.saveScanUrl() + def serverUrl(self): '''Get the current server URL from the UI.''' serverUrl = self.ui.serverComboBox.currentText.strip() @@ -185,6 +210,38 @@ def serverUrl(self): serverUrl = "http://localhost:8000" # return serverUrl.rstrip("/") return serverUrl + + def scanUrl(self): + '''Get the current scan URL from the UI.''' + scanUrl = self.ui.scanComboBox.currentText.strip() + if not scanUrl: + print("Scan address is empty ... ") + return None + return scanUrl + + def updateServerUrlGUIFromSettings(self): + '''Save current server URL to the top of history.''' + settings = qt.QSettings() + serverUrlHistory = settings.value("RadCoPilot/serverUrlHistory") + + wasBlocked = self.ui.serverComboBox.blockSignals(True) + self.ui.serverComboBox.clear() + if serverUrlHistory: + self.ui.serverComboBox.addItems(serverUrlHistory.split(";")) + self.ui.serverComboBox.setCurrentText(settings.value("RadCoPilot/serverUrl")) + self.ui.serverComboBox.blockSignals(wasBlocked) + + def updateScanUrlGUIFromSettings(self): + '''Save current scan URL to the top of history.''' + settings = qt.QSettings() + scanUrlHistory = settings.value("RadCoPilot/scanUrlHistory") + + wasBlocked = self.ui.scanComboBox.blockSignals(True) + self.ui.scanComboBox.clear() + if scanUrlHistory: + self.ui.scanComboBox.addItems(scanUrlHistory.split(";")) + self.ui.scanComboBox.setCurrentText(settings.value("RadCoPilot/scanUrl")) + self.ui.scanComboBox.blockSignals(wasBlocked) def saveServerUrl(self): '''Save the current server URL to settings and update history.''' @@ -210,7 +267,34 @@ def saveServerUrl(self): serverUrlHistory = serverUrlHistory[:10] # keep up to first 10 elements settings.setValue("RadCoPilot/serverUrlHistory", ";".join(serverUrlHistory)) - # self.updateServerUrlGUIFromSettings() + self.updateServerUrlGUIFromSettings() + + + def saveScanUrl(self): + '''Save the current scan URL to settings and update history.''' + # self.updateParameterNodeFromGUI() + + # Save selected server URL + settings = qt.QSettings() + scanUrl = self.ui.scanComboBox.currentText + settings.setValue("RadCoPilot/scanUrl", scanUrl) + + # Save current scan URL to the top of history + scanUrlHistory = settings.value("RadCoPilot/serverUrlHistory") + if scanUrlHistory: + scanUrlHistory = scanUrlHistory.split(";") + else: + scanUrlHistory = [] + try: + scanUrlHistory.remove(scanUrl) + except ValueError: + pass + + scanUrlHistory.insert(0, scanUrl) + scanUrlHistory = scanUrlHistory[:10] # keep up to first 10 elements + settings.setValue("RadCoPilot/scanUrlHistory", ";".join(scanUrlHistory)) + + self.updateScanUrlGUIFromSettings() def show_popup(self, title, message): '''Display a popup message box with the given title and message.''' @@ -225,16 +309,17 @@ def onClickFetchInfo(self): try: self.updateServerSettings() + self.updateScanSettings() info = self.logic.info() + self.current_model = info self.info = info print(f"Connected to RadCoPilot Server - Obtained info from server: {self.info}") - self.show_popup("Information", "Connected to RadCoPilot Server") self.ui.sendPrompt.setEnabled(True) self.ui.uploadImageButton.setEnabled(True) + self.ui.fetchImageButton.setEnabled(True) # Updating model name - self.ui.appComboBox.clear() - self.ui.appComboBox.addItem(self.info) + self.ui.appDescriptionLabel.text = self.info except AttributeError as e: slicer.util.errorDisplay( @@ -250,13 +335,53 @@ def onClickCleanOutputButton(self): '''Handle the click event for cleaning the output text.''' self.ui.outputText.clear() - def onUploadImage(self): - '''Gets the volume and sen it to the server.''' - volumeNode = slicer.mrmlScene.GetFirstNodeByClass("vtkMRMLScalarVolumeNode") - image_id = volumeNode.GetName() + def reportProgress2(self, msg, level=None): + '''Print progress in the console.''' + print("Loading... {0}%".format(self.sampleDataLogic.downloadPercent)) + # Abort download if cancel is clicked in progress bar + if self.progressWindow.wasCanceled: + raise Exception("download aborted") + # Update progress window + self.progressWindow.show() + self.progressWindow.activateWindow() + self.progressWindow.setValue(int(self.sampleDataLogic.downloadPercent)) + self.progressWindow.setLabelText("Downloading...") + # Process events to allow screen to refresh + slicer.app.processEvents() + + def onFetchImage(self): + '''Fetch an image from an URL.''' + import SampleData + try: + scan_url = self.ui.scanComboBox.currentText + volumeNode = None + self.progressWindow = slicer.util.createProgressDialog() + self.sampleDataLogic = SampleData.SampleDataLogic() + self.sampleDataLogic.logMessage = self.reportProgress2 + loadedNodes = self.sampleDataLogic.downloadFromURL( + nodeNames="CTVolume", + fileNames="ct_liver_0.nii.gz", + uris=scan_url, + #checksums="SHA256:cc211f0dfd9a05ca3841ce1141b292898b2dd2d3f08286affadf823a7e58df93" + ) + volumeNode = loadedNodes[0] + self.ui.inputVolumeNodeComboBox.setCurrentNode(volumeNode) + # Sending url to the server + info = self.logic.uploadScan(scan_url) + self.info = info + print(f"Response from the upload text call: {self.info['status']}") + + finally: + self.progressWindow.close() + + def onUploadImage(self): + '''Gets the text from scanComboBox or the volume from the viewport and sends it to the server.''' try: qt.QApplication.setOverrideCursor(qt.Qt.WaitCursor) + + volumeNode = self.ui.inputVolumeNodeComboBox.currentNode() + image_id = volumeNode.GetName() in_file = tempfile.NamedTemporaryFile(suffix=self.file_ext, dir=self.tmpdir).name self.current_sample = in_file self.reportProgress(5) @@ -279,8 +404,10 @@ def onUploadImage(self): msg = f"Message: {e.msg}" if hasattr(e, "msg") else "" self.reportProgress(100) qt.QApplication.restoreOverrideCursor() + self.show_popup("Error", f"Upload failed: {msg}") return False + def reportProgress(self, progressPercentage): '''Reports progress of an event.''' if not self.progressBar: @@ -293,7 +420,14 @@ def reportProgress(self, progressPercentage): def has_text(self, ui_text): '''Check if the given UI text element has any content.''' return len(ui_text.toPlainText()) < 1 + + def _get_answer_text(self, response): + '''Get the VILA-M3 response.''' + final_response = response['choices'][0]['message']['content']['content'][0]['text'] + + return final_response + def onClickSendPrompt(self): '''Handle the click event for sending a prompt to the server.''' if not self.logic: @@ -313,7 +447,10 @@ def onClickSendPrompt(self): info = self.logic.getAnswer(inputText=inText, volumePath=self.current_sample) if info is not None: self.info = info - self.ui.outputText.setText(info['choices'][0]['message']['content']) + if 'VILA-M3' in self.current_model: + self.ui.outputText.setText(self._get_answer_text(info)) + else: + self.ui.outputText.setText(info['choices'][0]['message']['content']) logging.info(f"Time consumed by fetch info: {time.time() - start:3.1f}") diff --git a/plugins/RadCoPilot_Slicer/RadCoPilotLib/client.py b/plugins/RadCoPilot_Slicer/RadCoPilotLib/client.py index ab4fc10..fda72a8 100644 --- a/plugins/RadCoPilot_Slicer/RadCoPilotLib/client.py +++ b/plugins/RadCoPilot_Slicer/RadCoPilotLib/client.py @@ -77,15 +77,15 @@ def uploadFile(self, volumePath): """ Upload a file to the RadCoPilot server using the fileUploadRouter. - This method sends a file to the '/upload' endpoint of the RadCoPilot server, + This method sends either a local file or a URL to the '/upload' endpoint of the RadCoPilot server, which stores it as the last received file. This uploaded file can then be used in subsequent requests to the chat_completions API if no file is provided in those requests. Parameters: ----------- - filePath : str - The path to the file that should be uploaded to the server. + volumePath : str + The path to the local file or a URL that should be uploaded to the server. Returns: -------- @@ -103,13 +103,21 @@ def uploadFile(self, volumePath): print("Uploading file...") selector = "/upload/" - url = f"{self._server_url}{selector}" + # Check if volumePath is a URL + parsed_url = urlparse(volumePath) + is_url = bool(parsed_url.scheme and parsed_url.netloc) - with open(volumePath, 'rb') as file: - files = {"file": (os.path.basename(volumePath), file, "application/octet-stream")} + if is_url: + # If it's a URL, send it as a form-data field + files = {"file": (None, str(volumePath))} response = requests.post(url=url, files=files) + else: + # If it's a local file, open and send it + with open(volumePath, 'rb') as file: + files = {"file": (os.path.basename(volumePath), file, "application/octet-stream")} + response = requests.post(url=url, files=files) if response.status_code == 200: return response.json() diff --git a/plugins/RadCoPilot_Slicer/Resources/UI/RadCoPilot.ui b/plugins/RadCoPilot_Slicer/Resources/UI/RadCoPilot.ui index 5987827..86884af 100644 --- a/plugins/RadCoPilot_Slicer/Resources/UI/RadCoPilot.ui +++ b/plugins/RadCoPilot_Slicer/Resources/UI/RadCoPilot.ui @@ -30,116 +30,183 @@ 5 - - - - - - 0 - 0 - - - - true - - - - - - - true - - - Fetch/Refresh models from Server - - - - - - - - - - RadCoPilot Server - - - - - - - Model Name: - - - - - - - Qt::LeftToRight - - - QComboBox::AdjustToContents - - - - - - - - - - - Submit scan - - - - - - - - - Input Prompt - - - - - - - - 0 - 0 - + + + Setup + + + + + RadCoPilot Server: + + + + + + + + + + 0 + 0 + + + + true + + + + + + + + + + + + + + + + Model Name: + + + + + + + + + + + + + + Input volume: + + + + + + + + vtkMRMLScalarVolumeNode + + + + + + + true + + + false + + + false + + + + + + (upload from URL) + + + + + + + + 0 + 0 + + + + URL: + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + 10 + + + + + + + + + + + + Fetch volume + + + + + + + Submit volume + + + + + + - - - Send Prompt + + + Input Prompt + + + + + + 0 + 0 + + + + + + + + Send Prompt + + + + - - + + Generated Text - - - - - - - 0 - 0 - - - - - - - - Clean - + + + + + + 0 + 0 + + + + + + + + Clean + + + + @@ -158,11 +225,22 @@ + + ctkCollapsibleGroupBox + QGroupBox +
ctkCollapsibleGroupBox.h
+ 1 +
ctkPushButton QPushButton
ctkPushButton.h
+ + qMRMLNodeComboBox + QWidget +
qMRMLNodeComboBox.h
+
qMRMLWidget QWidget @@ -171,5 +249,70 @@
- + + + inputVolumeNodeComboBox + currentNodeChanged(bool) + scanComboBox + setDisabled(bool) + + + 179 + 97 + + + 201 + 121 + + + + + MONAILabel + mrmlSceneChanged(vtkMRMLScene*) + inputVolumeNodeComboBox + setMRMLScene(vtkMRMLScene*) + + + 55 + 837 + + + 157 + 102 + + + + + inputVolumeNodeComboBox + currentNodeChanged(bool) + fetchImageButton + setHidden(bool) + + + 239 + 96 + + + 143 + 192 + + + + + inputVolumeNodeComboBox + currentNodeChanged(bool) + uploadImageButton + setVisible(bool) + + + 289 + 101 + + + 373 + 195 + + + +