diff --git a/changelogs/Spyder-6.md b/changelogs/Spyder-6.md index e6fd6ac3efc..f688cd4e832 100644 --- a/changelogs/Spyder-6.md +++ b/changelogs/Spyder-6.md @@ -6,6 +6,9 @@ * Add `give_focus` kwarg to the `create_client_for_kernel` method of the IPython console plugin. +* Add `early_return` and `return_awaitable` kwargs to `AsyncDispatcher` API. +* Add `register_api` and `get_api` methods to `RemoteClient` plugin in order to get and register new rest API modules for the remote api. +* Add `get_file_api` method to `RemoteClient` to get the `SpyderRemoteFileServicesAPI` rest API module for remote file systems API. ## Version 6.0.3 (2024/12/10) diff --git a/external-deps/spyder-remote-services/.gitrepo b/external-deps/spyder-remote-services/.gitrepo index 57a5e844ea9..717089b3b72 100644 --- a/external-deps/spyder-remote-services/.gitrepo +++ b/external-deps/spyder-remote-services/.gitrepo @@ -6,7 +6,7 @@ [subrepo] remote = https://github.com/spyder-ide/spyder-remote-services branch = main - commit = fc60fbbb1ab95dc6a78c1c805debb54377381bb3 - parent = c9c71251b9b9f18031ffce1dc335a9720fc0b297 + commit = 2b8dcf2aa3da5764136ee842724086aced71cba7 + parent = 9c058adfa8ef3ed13cf31bc30b8400a36c819491 method = merge cmdver = 0.4.9 diff --git a/external-deps/spyder-remote-services/jupyter-config/spyder_remote_services.json b/external-deps/spyder-remote-services/jupyter-config/spyder_remote_services.json index 56b6f94ad25..979cd34229f 100644 --- a/external-deps/spyder-remote-services/jupyter-config/spyder_remote_services.json +++ b/external-deps/spyder-remote-services/jupyter-config/spyder_remote_services.json @@ -1,7 +1,7 @@ { "ServerApp": { "jpserver_extensions": { - "spyder_remote_services": true + "spyder-services": true } } } \ No newline at end of file diff --git a/external-deps/spyder-remote-services/pyproject.toml b/external-deps/spyder-remote-services/pyproject.toml index 28e4933673e..b6185b334a9 100644 --- a/external-deps/spyder-remote-services/pyproject.toml +++ b/external-deps/spyder-remote-services/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "jupyter_server >=2.14.2,<3.0", "jupyter_client >=8.6.2,<9.0", "envs-manager <1.0.0", + "orjson >=3.10.12,<4.0", ] [tool.setuptools.dynamic] diff --git a/external-deps/spyder-remote-services/spyder_remote_services/app.py b/external-deps/spyder-remote-services/spyder_remote_services/app.py index 5e3cbbc673b..3c30d056bcd 100644 --- a/external-deps/spyder-remote-services/spyder_remote_services/app.py +++ b/external-deps/spyder-remote-services/spyder_remote_services/app.py @@ -45,8 +45,20 @@ def _port_default(self): def info_file(self): return str(Path(self.runtime_dir) / self.spyder_server_info_file) - def write_server_info_file(self) -> None: - if os.path.exists(self.info_file): + def write_server_info_file(self, *, __pid_check=True) -> None: + info_file = Path(self.info_file) + if info_file.exists(): + if __pid_check: + with info_file.open(mode="rb") as f: + info = json.load(f) + + # Simple check whether that process is really still running + if ("pid" in info) and not check_pid(info["pid"]): + # If the process has died, try to delete its info file + with suppress(OSError): + info_file.unlink() + self.write_server_info_file(__pid_check=False) + raise FileExistsError( f"Server info file {self.info_file} already exists." "Muliple servers are not supported, please make sure" @@ -64,12 +76,12 @@ def start(self): # The runtime dir might not exist if not runtime_dir.is_dir(): - return None + return conf_file = runtime_dir / SpyderServerApp.spyder_server_info_file if not conf_file.exists(): - return None + return with conf_file.open(mode="rb") as f: info = json.load(f) @@ -88,7 +100,7 @@ class SpyderRemoteServices(ExtensionApp): """A simple jupyter server application.""" # The name of the extension. - name = "spyder_remote_services" + name = "spyder-services" open_browser = False @@ -100,7 +112,7 @@ class SpyderRemoteServices(ExtensionApp): def initialize_handlers(self): """Initialize handlers.""" - self.handlers.extend(handlers) + self.handlers.extend([(rf"/{self.name}{h[0]}", h[1]) for h in handlers]) def initialize(self): super().initialize() diff --git a/external-deps/spyder-remote-services/spyder_remote_services/services/__init__.py b/external-deps/spyder-remote-services/spyder_remote_services/services/__init__.py index bb79e4b06bc..9ce0e6fb673 100644 --- a/external-deps/spyder-remote-services/spyder_remote_services/services/__init__.py +++ b/external-deps/spyder-remote-services/spyder_remote_services/services/__init__.py @@ -1,3 +1,4 @@ from spyder_remote_services.services.envs_manager.handlers import handlers as envs_manager_handlers +from spyder_remote_services.services.files.handlers import handlers as files_handlers -handlers = envs_manager_handlers +handlers = envs_manager_handlers + files_handlers diff --git a/external-deps/spyder-remote-services/spyder_remote_services/services/files/__init__.py b/external-deps/spyder-remote-services/spyder_remote_services/services/files/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/external-deps/spyder-remote-services/spyder_remote_services/services/files/base.py b/external-deps/spyder-remote-services/spyder_remote_services/services/files/base.py new file mode 100644 index 00000000000..364431c12ae --- /dev/null +++ b/external-deps/spyder-remote-services/spyder_remote_services/services/files/base.py @@ -0,0 +1,435 @@ +from __future__ import annotations +import asyncio +import base64 +import datetime +import errno +from http import HTTPStatus +from io import FileIO +import os +from pathlib import Path +from shutil import copy, copy2 +import stat +import threading +import time +import traceback + +import orjson +from tornado.websocket import WebSocketHandler + + +class FileWebSocketHandler(WebSocketHandler): + """ + WebSocket handler for opening files and streaming data. + + The protocol on message receive (JSON messages): + { + "method": "read", # "write", "seek", etc. (required) + "kwargs": {...}, (optional) + "data": "", # all data is base64-encoded (optional) + } + + The protocol for sending data back to the client: + { + "status": 200, # HTTP status code (required) + "data": "", # response data if any (optional) + "error": {"message": "error message", (required) + "traceback": ["line1", "line2", ...] (optional)} # if an error occurred (optional) + } + """ + + LOCK_TIMEOUT = 100 # seconds + + max_message_size = 5 * 1024 * 1024 * 1024 # 5 GB + + __thread_lock = threading.Lock() + + # ---------------------------------------------------------------- + # Tornado WebSocket / Handler Hooks + # ---------------------------------------------------------------- + async def open(self, path): + """Open file.""" + self.mode = self.get_argument("mode", default="r") + self.atomic = self.get_argument("atomic", default="false") == "true" + lock = self.get_argument("lock", default="false") == "true" + self.encoding = self.get_argument("encoding", default="utf-8") + + self.file: FileIO = None + try: + self.path = self._load_path(path) + + if lock and not await self._acquire_lock(path): + self.close( + 1002, + self._parse_json( + HTTPStatus.LOCKED, message="File is locked" + ), + ) + return + + self.file = await self._open_file() + except OSError as e: + self.log.warning("Error opening file", exc_info=e) + self.close(1002, self._parse_os_error(e)) + except Exception as e: + self.log.exception("Error opening file") + self.close(1002, self._parse_error(e)) + else: + await self._send_json(HTTPStatus.OK) + + def on_close(self): + """Close file.""" + if self.file is not None: + self._close_file() + if self.__locked: + self._release_lock() + + async def on_message(self, raw_message): + """Handle incoming messages.""" + self.log.debug("Received message: %s", raw_message) + try: + await self.handle_message(raw_message) + except Exception as e: + self.log.exception("Error handling message") + await self.write_message(self._parse_error(e), binary=True) + + # ---------------------------------------------------------------- + # Internal Helpers + # ---------------------------------------------------------------- + async def handle_message(self, raw_message): + msg = self._decode_json(raw_message) + method, kwargs = await self._parse_message(msg) + await self._run_method(method, kwargs) + + async def _open_file(self): + """Open the file in the requested mode.""" + if self.atomic and ("+" in self.mode or + "a" in self.mode or + "w" in self.mode): + if self.path.exists() and "w" not in self.mode: + copy2(self.path, self.atomic_path) + return self.atomic_path.open(self.mode) + + return self.path.open(self.mode) + + def _close_file(self): + self.file.close() + if self.atomic: + self.atomic_path.replace(self.path) + + async def _run_method(self, method, kwargs): + """Run a method with kwargs.""" + try: + result = await getattr(self, f"_handle_{method}")(**kwargs) + except OSError as e: + self.log.warning("Error handling method: %s", method) + await self.write_message(self._parse_os_error(e), binary=True) + else: + await self._send_result(result) + + async def _parse_message(self, msg): + """Parse a message into method and kwargs.""" + method = msg.pop("method", None) + + if "data" in msg and isinstance(msg["data"], list): + msg["data"] = [self._decode_data(d) for d in msg["data"]] + elif "data" in msg: + msg["data"] = self._decode_data(msg["data"]) + + return method, msg + + async def _acquire_lock(self, __start_time=None): + """Acquire a lock on the file.""" + if __start_time is None: + __start_time = time.time() + + while self.__locked: + await asyncio.sleep(1) + if time.time() - __start_time > self.LOCK_TIMEOUT: + return False + + with self.__thread_lock: + if self.__locked: + return await self._acquire_lock(__start_time=__start_time) + self.lock_path.touch(exist_ok=False) + + return True + + def _release_lock(self): + """Release the lock on the file.""" + with self.__thread_lock: + self.lock_path.unlink(missing_ok=True) + + @property + def atomic_path(self): + """Get the path to the atomic file.""" + return self.path.parent / f".{self.path.name}.spyder.tmp" + + @property + def lock_path(self): + """Get the path to the atomic file.""" + return self.path.parent / f".{self.path.name}.spyder.lck" + + @property + def __locked(self): + return Path(self.lock_path).exists() + + def _decode_json(self, raw_message): + """Decode a JSON message (non-streamed).""" + return orjson.loads(raw_message) + + async def _send_json(self, status: HTTPStatus, **data: dict): + """Send a single JSON message.""" + await self.write_message(self._parse_json(status, **data), binary=True) + + def _parse_json(self, status: HTTPStatus, **data: dict) -> bytes: + """Parse a single JSON message.""" + return orjson.dumps({"status": status.value, **data}) + + def _parse_error(self, error: BaseException) -> bytes: + """Parse an error response to the client.""" + return self._parse_json( + HTTPStatus.INTERNAL_SERVER_ERROR, + message=str(error), + tracebacks=traceback.format_exception( + type(error), error, error.__traceback__ + ), + type=str(type(error)), + ) + + def _parse_os_error(self, e: OSError) -> bytes: + """Parse an OSError response to the client.""" + return self._parse_json( + HTTPStatus.EXPECTATION_FAILED, + strerror=e.strerror, + filename=e.filename, + errno=e.errno, + ) + + async def _send_msg_error(self, message): + await self._send_json( + HTTPStatus.BAD_REQUEST, message=message, + ) + + async def _send_result(self, result): + if result is None: + await self._send_json(HTTPStatus.NO_CONTENT) + elif isinstance(result, list): + await self._send_json( + HTTPStatus.OK, data=[self._encode_data(r) for r in result], + ) + else: + await self._send_json( + HTTPStatus.OK, data=self._encode_data(result), + ) + + def _decode_data(self, data: str | object) -> str | bytes | object: + """Decode data from a message.""" + if not isinstance(data, str): + return data + + if "b" in self.mode: + return base64.b64decode(data) + + return base64.b64decode(data).decode(self.encoding) + + def _encode_data(self, data: bytes | str | object) -> str: + """Encode data for a message.""" + if isinstance(data, bytes): + return base64.b64encode(data).decode("ascii") + if isinstance(data, str): + return base64.b64encode(data.encode(self.encoding)).decode("ascii") + return data + + def _load_path(self, path_str: str) -> Path: + """Convert path string to a Path object.""" + return Path(path_str).expanduser() + + # ---------------------------------------------------------------- + # File Operation + # ---------------------------------------------------------------- + async def _handle_write(self, data: bytes | str) -> int: + """Write data to the file.""" + return self.file.write(data) + + async def _handle_flush(self): + """Flush the file.""" + return self.file.flush() + + async def _handle_read(self, n: int = -1) -> bytes | str: + """Read data from the file.""" + return self.file.read(n) + + async def _handle_seek(self, offset: int, whence: int = 0) -> int: + """Seek to a new position in the file.""" + return self.file.seek(offset, whence) + + async def _handle_tell(self) -> int: + """Get the current file position.""" + return self.file.tell() + + async def _handle_truncate(self, size: int | None = None) -> int: + """Truncate the file to a new size.""" + return self.file.truncate(size) + + async def _handle_fileno(self): + """Flush the file to disk.""" + return self.file.fileno() + + async def _handle_readline(self, size: int = -1) -> bytes | str: + """Read a line from the file.""" + return self.file.readline(size) + + async def _handle_readlines(self, hint: int = -1) -> list[bytes | str]: + """Read lines from the file.""" + return self.file.readlines(hint) + + async def _handle_writelines(self, lines: list[bytes | str]): + """Write lines to the file.""" + return self.file.writelines(lines) + + async def _handle_isatty(self) -> bool: + """Check if the file is a TTY.""" + return self.file.isatty() + + async def _handle_readable(self) -> bool: + """Check if the file is readable.""" + return self.file.readable() + + async def _handle_writable(self) -> bool: + """Check if the file is writable.""" + return self.file.writable() + + +class FilesRESTMixin: + """ + REST handler for fsspec-like filesystem operations, using pathlib.Path. + + Supports: + - fs_ls(path_str, detail=True) + - fs_info(path_str) + - fs_exists(path_str) + - fs_isfile(path_str) + - fs_isdir(path_str) + - fs_mkdir(path_str, create_parents=True, exist_ok=False) + - fs_rmdir(path_str) + - fs_rm_file(path_str, missing_ok=False) + - fs_touch(path_str, truncate=True) + """ + + def _info_for_path(self, path: Path) -> dict: + """Get fsspec-like info about a single path.""" + out = path.stat(follow_symlinks=False) + link = stat.S_ISLNK(out.st_mode) + if link: + # If it's a link, stat the target + out = path.stat(follow_symlinks=True) + size = out.st_size + if stat.S_ISDIR(out.st_mode): + t = "directory" + elif stat.S_ISREG(out.st_mode): + t = "file" + else: + t = "other" + result = { + "name": str(path), + "size": size, + "type": t, + "created": out.st_ctime, + "islink": link, + } + for field in ["mode", "uid", "gid", "mtime", "ino", "nlink"]: + result[field] = getattr(out, f"st_{field}", None) + if link: + result["destination"] = str(path.resolve()) + + return result + + def _load_path(self, path_str: str) -> Path | None: + """Convert a path string to a pathlib.Path object.""" + return Path(path_str).expanduser() + + def fs_ls(self, path_str: str, detail: bool = True): + """List objects at path, like fsspec.ls().""" + path = self._load_path(path_str) + if not path.exists(): + raise FileNotFoundError(errno.ENOENT, + os.strerror(errno.ENOENT), + str(path)) + if path.is_file(): + # fsspec.ls of a file often returns a single entry + if detail: + return [self._info_for_path(path)] + + return [str(path)] + + # Otherwise, it's a directory + results = [] + for p in path.glob("*"): + if detail: + results.append(self._info_for_path(p)) + else: + results.append(str(p)) + return results + + def fs_info(self, path_str: str): + """Get info about a single path, like fsspec.info().""" + path = self._load_path(path_str) + return self._info_for_path(path) + + def fs_exists(self, path_str: str) -> bool: + """Like fsspec.exists().""" + path = self._load_path(path_str) + return path.exists() + + def fs_isfile(self, path_str: str) -> bool: + """Like fsspec.isfile().""" + path = self._load_path(path_str) + return path.is_file() + + def fs_isdir(self, path_str: str) -> bool: + """Like fsspec.isdir().""" + path = self._load_path(path_str) + return path.is_dir() + + def fs_mkdir(self, path_str: str, create_parents: bool = True, exist_ok: bool = False): + """Like fsspec.mkdir().""" + path = self._load_path(path_str) + path.mkdir(parents=create_parents, exist_ok=exist_ok) + return {"success": True} + + def fs_rmdir(self, path_str: str): + """Like fsspec.rmdir() - remove if empty.""" + path = self._load_path(path_str) + path.rmdir() + return {"success": True} + + def fs_rm_file(self, path_str: str, missing_ok: bool = False): + """Like fsspec.rm_file(), remove a single file.""" + path = self._load_path(path_str) + path.unlink(missing_ok=missing_ok) + return {"success": True} + + def fs_touch(self, path_str: str, truncate: bool = True): + """ + Like fsspec.touch(path, truncate=True). + If truncate=True, zero out file if exists. Otherwise just update mtime. + """ + path = self._load_path(path_str) + if path.exists() and not truncate: + now = datetime.datetime.now().timestamp() + os.utime(path, (now, now)) + else: + # create or overwrite + with path.open("wb"): + pass + return {"success": True} + + def fs_copy(self, src_str: str, dst_str: str, metadata: bool=False): + """Like fsspec.copy().""" + src = self._load_path(src_str) + dst = self._load_path(dst_str) + if metadata: + copy2(src, dst) + else: + copy(src, dst) + return {"success": True} diff --git a/external-deps/spyder-remote-services/spyder_remote_services/services/files/handlers.py b/external-deps/spyder-remote-services/spyder_remote_services/services/files/handlers.py new file mode 100644 index 00000000000..9cb3a687b02 --- /dev/null +++ b/external-deps/spyder-remote-services/spyder_remote_services/services/files/handlers.py @@ -0,0 +1,190 @@ +from __future__ import annotations +from http import HTTPStatus +from http.client import responses +import re +from typing import Any +import traceback + +from jupyter_server.auth.decorator import authorized, ws_authenticated +from jupyter_server.base.handlers import JupyterHandler +from jupyter_server.base.websocket import WebSocketMixin +import orjson +from tornado import web + +from spyder_remote_services.services.files.base import ( + FileWebSocketHandler, + FilesRESTMixin, +) + + +class ReadWriteWebsocketHandler( + WebSocketMixin, + FileWebSocketHandler, + JupyterHandler, +): + auth_resource = "spyder-services" + + @ws_authenticated + async def get(self, *args, **kwargs): + """Handle the initial websocket upgrade GET request.""" + await super().get(*args, **kwargs) + + +class BaseFSHandler(FilesRESTMixin, JupyterHandler): + auth_resource = "spyder-services" + + def write_json(self, data, status=200): + self.set_status(status) + self.set_header("Content-Type", "application/json") + self.finish(orjson.dumps(data)) + + def write_error(self, status_code, **kwargs): + """APIHandler errors are JSON, not human pages.""" + self.set_header("Content-Type", "application/json") + reply: dict[str, Any] = {} + exc_info = kwargs.get("exc_info") + if exc_info: + e = exc_info[1] + if isinstance(e, web.HTTPError): + reply["message"] = e.log_message or responses.get(status_code, "Unknown HTTP Error") + reply["reason"] = e.reason + elif isinstance(e, OSError): + self.set_status(HTTPStatus.EXPECTATION_FAILED) + reply["strerror"] = e.strerror + reply["errno"] = e.errno + reply["filename"] = e.filename + else: + self.set_status(HTTPStatus.INTERNAL_SERVER_ERROR) + reply["type"] = str(type(e)) + reply["message"] = str(e) + reply["traceback"] = traceback.format_exception(*exc_info) + else: + reply["message"] = responses.get(status_code, "Unknown HTTP Error") + self.finish(orjson.dumps(reply)) + + def log_exception(self, typ, value, tb): + """Log uncaught exceptions.""" + if isinstance(value, web.HTTPError): + if value.log_message: + format = "%d %s: " + value.log_message + args = [value.status_code, self._request_summary()] + list(value.args) + self.log.warning(format, *args) + elif isinstance(value, OSError): + self.log.debug( + "OSError [Errno %s] %s", + value.errno, + self._request_summary(), + exc_info=(typ, value, tb), # type: ignore + ) + else: + self.log.warning( + "Uncaught exception %s\n%r", + self._request_summary(), + self.request, + exc_info=(typ, value, tb), # type: ignore + ) + + +class LsHandler(BaseFSHandler): + @web.authenticated + @authorized + def get(self, path): + detail_arg = self.get_argument("detail", default="true").lower() + detail = detail_arg == "true" + result = self.fs_ls(path, detail=detail) + self.write_json(result) + + +class InfoHandler(BaseFSHandler): + @web.authenticated + @authorized + def get(self, path): + result = self.fs_info(path) + self.write_json(result) + + +class ExistsHandler(BaseFSHandler): + @web.authenticated + @authorized + def get(self, path): + result = self.fs_exists(path) + self.write_json({"exists": result}) + + +class IsFileHandler(BaseFSHandler): + @web.authenticated + @authorized + def get(self, path): + result = self.fs_isfile(path) + self.write_json({"isfile": result}) + + +class IsDirHandler(BaseFSHandler): + @web.authenticated + @authorized + def get(self, path): + result = self.fs_isdir(path) + self.write_json({"isdir": result}) + + +class MkdirHandler(BaseFSHandler): + @web.authenticated + @authorized + def post(self, path): + create_parents = (self.get_argument("create_parents", "true").lower() == "true") + exist_ok = (self.get_argument("exist_ok", "false").lower() == "true") + result = self.fs_mkdir(path, create_parents=create_parents, exist_ok=exist_ok) + self.write_json(result) + + +class RmdirHandler(BaseFSHandler): + @web.authenticated + @authorized + def delete(self, path): + result = self.fs_rmdir(path) + self.write_json(result) + + +class RemoveFileHandler(BaseFSHandler): + @web.authenticated + @authorized + def delete(self, path): + missing_ok = (self.get_argument("missing_ok", "false").lower() == "true") + result = self.fs_rm_file(path, missing_ok=missing_ok) + self.write_json(result) + + +class TouchHandler(BaseFSHandler): + @web.authenticated + @authorized + def post(self, path): + truncate = (self.get_argument("truncate", "true").lower() == "true") + result = self.fs_touch(path, truncate=truncate) + self.write_json(result) + + +class CopyHandler(BaseFSHandler): + @web.authenticated + @authorized + def post(self, path): + dest = re.match(_path_regex, self.get_argument("dest")).group("path") + metadata = (self.get_argument("metadata", "false").lower() == "true") + result = self.fs_copy(path, dest, metadata=metadata) + self.write_json(result) + + +_path_regex = r"file://(?P.+)" + +handlers = [ + (rf"/fs/open/{_path_regex}", ReadWriteWebsocketHandler), # WebSocket + (rf"/fs/ls/{_path_regex}", LsHandler), # GET + (rf"/fs/info/{_path_regex}", InfoHandler), # GET + (rf"/fs/exists/{_path_regex}", ExistsHandler), # GET + (rf"/fs/isfile/{_path_regex}", IsFileHandler), # GET + (rf"/fs/isdir/{_path_regex}", IsDirHandler), # GET + (rf"/fs/mkdir/{_path_regex}", MkdirHandler), # POST + (rf"/fs/rmdir/{_path_regex}", RmdirHandler), # DELETE + (rf"/fs/file/{_path_regex}", RemoveFileHandler), # DELETE + (rf"/fs/touch/{_path_regex}", TouchHandler), # POST + (rf"/fs/copy/{_path_regex}", CopyHandler), # POST +] diff --git a/spyder/api/asyncdispatcher.py b/spyder/api/asyncdispatcher.py index b874927c561..5a895d2dffd 100644 --- a/spyder/api/asyncdispatcher.py +++ b/spyder/api/asyncdispatcher.py @@ -73,7 +73,8 @@ def __init__(self, typing.Any, typing.Any, typing.Any]], *, loop: LoopID | None = None, - early_return: bool = True): + early_return: bool = True, + return_awaitable: bool = False): """Initialize the decorator. Parameters @@ -82,6 +83,11 @@ def __init__(self, The coroutine to be wrapped. loop : asyncio.AbstractEventLoop, optional The event loop to be used, by default get the current event loop. + early_return : bool, optional + Return the coroutine as a Future object before it is done + or wait for it to finish and return the result. + return_awaitable : bool, optional + Return the coroutine as an awaitable object instead of a Future. """ if not asyncio.iscoroutinefunction(async_func): msg = f"{async_func} is not a coroutine function" @@ -89,11 +95,15 @@ def __init__(self, self._async_func = async_func self._loop = self._ensure_running_loop(loop) self._early_return = early_return + self._return_awaitable = return_awaitable def __call__(self, *args, **kwargs): task = asyncio.run_coroutine_threadsafe( self._async_func(*args, **kwargs), loop=self._loop ) + if self._return_awaitable: + return asyncio.wrap_future(task, loop=asyncio.get_running_loop()) + if self._early_return: AsyncDispatcher._running_tasks.append(task) task.add_done_callback(self._callback_task_done) @@ -104,16 +114,20 @@ def __call__(self, *args, **kwargs): def dispatch(cls, *, loop: LoopID | None = None, - early_return: bool = True): + early_return: bool = True, + return_awaitable: bool = False): """Create a decorator to run the coroutine with a given event loop.""" def decorator( async_func: typing.Callable[P, typing.Coroutine[typing.Any, typing.Any, T]] - ) -> typing.Callable[P, Future[T] | T]: + ) -> typing.Callable[P, asyncio.Future[T] | Future[T] | T]: @functools.wraps(async_func) def wrapper(*args, **kwargs): - return cls(async_func, loop=loop, early_return=early_return)( + return cls(async_func, + loop=loop, + early_return=early_return, + return_awaitable=return_awaitable)( *args, **kwargs ) diff --git a/spyder/api/utils.py b/spyder/api/utils.py index c62c562802b..678d8470520 100644 --- a/spyder/api/utils.py +++ b/spyder/api/utils.py @@ -8,6 +8,7 @@ """ API utilities. """ +from abc import ABCMeta as BaseABCMeta def get_class_values(cls): @@ -64,3 +65,73 @@ class classproperty(property): def __get__(self, cls, owner): return classmethod(self.fget).__get__(None, owner)() + + +class DummyAttribute: + """ + Dummy class to mark abstract attributes. + """ + pass + + +def abstract_attribute(obj=None): + """ + Decorator to mark abstract attributes. Must be used in conjunction with the + ABCMeta metaclass. + """ + if obj is None: + obj = DummyAttribute() + obj.__is_abstract_attribute__ = True + return obj + + +class ABCMeta(BaseABCMeta): + """ + Metaclass to mark abstract classes. + + Adds support for abstract attributes. If a class has abstract attributes + and is instantiated, a NotImplementedError is raised. + + Usage + ----- + class MyABC(metaclass=ABCMeta): + @abstract_attribute + def my_abstract_attribute(self): + pass + + class MyClassOK(MyABC): + def __init__(self): + self.my_abstract_attribute = 1 + + class MyClassNotOK(MyABC): + pass + + Raises + ------ + NotImplementedError + When it's not possible to instantiate an abstract class with abstract + attributes. + """ + + def __call__(cls, *args, **kwargs): + # Collect all abstract-attribute names from the entire MRO + abstract_attr_names = set() + for base in cls.__mro__: + for name, value in base.__dict__.items(): + if getattr(value, '__is_abstract_attribute__', False): + abstract_attr_names.add(name) + + for name, value in cls.__dict__.items(): + if not getattr(value, '__is_abstract_attribute__', False): + abstract_attr_names.discard(name) + + if abstract_attr_names: + raise NotImplementedError( + "Can't instantiate abstract class " + "{} with abstract attributes: {}".format( + cls.__name__, + ", ".join(abstract_attr_names) + ) + ) + + return super().__call__(*args, **kwargs) diff --git a/spyder/plugins/remoteclient/api/__init__.py b/spyder/plugins/remoteclient/api/__init__.py index a290d92b1c3..a702cbd9917 100644 --- a/spyder/plugins/remoteclient/api/__init__.py +++ b/spyder/plugins/remoteclient/api/__init__.py @@ -11,6 +11,7 @@ Remote Client Plugin API. """ +from spyder.plugins.remoteclient.api.manager import SpyderRemoteAPIManager # noqa # ---- Constants # ----------------------------------------------------------------------------- diff --git a/spyder/plugins/remoteclient/api/jupyterhub/auth.py b/spyder/plugins/remoteclient/api/jupyterhub/auth.py deleted file mode 100644 index a2f558930f8..00000000000 --- a/spyder/plugins/remoteclient/api/jupyterhub/auth.py +++ /dev/null @@ -1,58 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright © Spyder Project Contributors -# Licensed under the terms of the MIT License -# (see spyder/__init__.py for details) - -import re - -import aiohttp -import yarl - - -async def token_authentication(api_token, verify_ssl=True): - return aiohttp.ClientSession( - headers={"Authorization": f"token {api_token}"}, - connector=aiohttp.TCPConnector(ssl=None if verify_ssl else False), - ) - - -async def basic_authentication(hub_url, username, password, verify_ssl=True): - session = aiohttp.ClientSession( - headers={"Referer": str(yarl.URL(hub_url) / "hub" / "api")}, - connector=aiohttp.TCPConnector(ssl=None if verify_ssl else False), - ) - - await session.post( - yarl.URL(hub_url) / "hub" / "login", - data={ - "username": username, - "password": password, - }, - ) - - return session - - -async def keycloak_authentication( - hub_url, username, password, verify_ssl=True -): - session = aiohttp.ClientSession( - headers={"Referer": str(yarl.URL(hub_url) / "hub" / "api")}, - connector=aiohttp.TCPConnector(ssl=None if verify_ssl else False), - ) - - response = await session.get(yarl.URL(hub_url) / "hub" / "oauth_login") - content = await response.content.read() - auth_url = re.search('action="([^"]+)"', content.decode("utf8")).group(1) - - response = await session.post( - auth_url.replace("&", "&"), - headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={ - "username": username, - "password": password, - "credentialId": "", - }, - ) - return session diff --git a/spyder/plugins/remoteclient/api/jupyterhub/execute.py b/spyder/plugins/remoteclient/api/jupyterhub/execute.py deleted file mode 100644 index 360bb9fcd8a..00000000000 --- a/spyder/plugins/remoteclient/api/jupyterhub/execute.py +++ /dev/null @@ -1,175 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright © Spyder Project Contributors -# Licensed under the terms of the MIT License -# (see spyder/__init__.py for details) - -import uuid -import difflib -import logging -import textwrap - -from spyder.plugins.remoteclient.api.jupyterhub import JupyterHubAPI -from spyder.plugins.remoteclient.api.jupyterhub.utils import ( - parse_notebook_cells, -) - -logger = logging.getLogger(__name__) - - -DAEMONIZED_STOP_SERVER_HEADER = """ -def _client_stop_server(): - import urllib.request - request = urllib.request.Request(url="{delete_server_endpoint}", method= "DELETE") - request.add_header("Authorization", "token {api_token}") - urllib.request.urlopen(request) - -def custom_exc(shell, etype, evalue, tb, tb_offset=None): - _jupyerhub_client_stop_server() - -get_ipython().set_custom_exc((Exception,), custom_exc) -""" - - -async def determine_username( - hub, - username=None, - user_format="user-{user}-{id}", - service_format="service-{name}-{id}", - temporary_user=False, -): - token = await hub.identify_token(hub.api_token) - - if username is None and not temporary_user: - if token["kind"] == "service": - logger.error( - "cannot execute without specified username or " - "temporary_user=True for service api token" - ) - raise ValueError( - "Service api token cannot execute without specified username " - "or temporary_user=True for" - ) - return token["name"] - elif username is None and temporary_user: - if token["kind"] == "service": - return service_format.format( - id=str(uuid.uuid4()), name=token["name"] - ) - else: - return user_format.format(id=str(uuid.uuid4()), name=token["name"]) - else: - return username - - -async def execute_code( - hub_url, - cells, - username=None, - temporary_user=False, - create_user=False, - delete_user=False, - server_creation_timeout=60, - server_deletion_timeout=60, - kernel_execution_timeout=60, - daemonized=False, - validate=False, - stop_server=True, - user_options=None, - kernel_spec=None, - auth_type="token", - verify_ssl=True, -): - hub = JupyterHubAPI(hub_url, auth_type=auth_type, verify_ssl=verify_ssl) - result_cells = [] - - async with hub: - username = await determine_username( - hub, username, temporary_user=temporary_user - ) - try: - jupyter = await hub.ensure_server( - username, - create_user=create_user, - user_options=user_options, - timeout=server_creation_timeout, - ) - - async with jupyter: - kernel_id, kernel = await jupyter.ensure_kernel( - kernel_spec=kernel_spec - ) - async with kernel: - if daemonized and stop_server: - await kernel.send_code( - username, - DAEMONIZED_STOP_SERVER_HEADER.format( - delete_server_endpoint=hub.api_url - / "users" - / username - / "server", - api_token=hub.api_token, - ), - wait=False, - ) - - for i, (code, expected_result) in enumerate(cells): - kernel_result = await kernel.send_code( - username, - code, - timeout=kernel_execution_timeout, - wait=(not daemonized), - ) - result_cells.append((code, kernel_result)) - if daemonized: - logger.debug( - f"kernel submitted cell={i} " - f'code=\n{textwrap.indent(code, " >>> ")}' - ) - else: - logger.debug( - f"kernel executing cell={i} " - f'code=\n{textwrap.indent(code, " >>> ")}' - ) - logger.debug( - f"kernel result cell={i} result=\n" - f'{textwrap.indent(kernel_result, " | ")}' - ) - if validate and ( - kernel_result.strip() - != expected_result.strip() - ): - diff = "".join( - difflib.unified_diff( - kernel_result, expected_result - ) - ) - logger.error( - f"kernel result did not match expected " - f"result diff={diff}" - ) - raise ValueError( - f"execution of cell={i} did not match " - f"expected result diff={diff}" - ) - - if daemonized and stop_server: - await kernel.send_code( - username, "__client_stop_server()", wait=False - ) - if not daemonized: - await jupyter.delete_kernel(kernel_id) - if not daemonized and stop_server: - await hub.ensure_server_deleted( - username, timeout=server_deletion_timeout - ) - finally: - if delete_user and not daemonized: - await hub.delete_user(username) - - return result_cells - - -async def execute_notebook(hub_url, notebook_path, **kwargs): - cells = parse_notebook_cells(notebook_path) - return await execute_code(hub_url, cells, **kwargs) diff --git a/spyder/plugins/remoteclient/api/jupyterhub/utils.py b/spyder/plugins/remoteclient/api/jupyterhub/utils.py deleted file mode 100644 index 2125428ff14..00000000000 --- a/spyder/plugins/remoteclient/api/jupyterhub/utils.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright © Spyder Project Contributors -# Licensed under the terms of the MIT License -# (see spyder/__init__.py for details) - -import json - - -def parse_notebook_cells(notebook_path): - with open(notebook_path) as f: - notebook_data = json.load(f) - - cells = [] - for cell in notebook_data["cells"]: - if cell["cell_type"] == "code": - source = "".join(cell["source"]) - outputs = [] - for output in cell["outputs"]: - if output["output_type"] == "stream": - outputs.append("".join(output["text"])) - elif output["output_type"] == "execute_result": - outputs.append("".join(output["data"]["text/plain"])) - result = "\n".join(outputs) - cells.append((source, result)) - - return cells - - -def render_notebook(cells): - notebook_template = { - "cells": [], - "nbformat": 4, - "nbformat_minor": 4, - "metadata": {}, - } - - for i, (code, result) in enumerate(cells, start=1): - notebook_template["cells"].append( - { - "cell_type": "code", - "execution_count": i, - "metadata": {}, - "outputs": [ - { - "data": {"text/plain": result}, - "execution_count": i, - "metadata": {}, - "output_type": "execute_result", - } - ], - "source": code, - } - ) - - return notebook_template - - -TEMPLATE_SCRIPT_HEADER = """ -import os -import sys -import logging - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('client') - -OUTPUT_FORMAT = '{output_format}' -STDOUT_FILENAME = os.path.expanduser('{stdout_filename}') -STDERR_FILENAME = os.path.expanduser('{stderr_filename}') - -if OUTPUT_FORMAT == 'file': - logger.info('writting output to files stdout={stdout_filename} and stderr={stderr_filename}') - sys.stdout = open(STDOUT_FILENAME, 'w') - sys.stderr = open(STDERR_FILENAME, 'w') - -""" - - -def tangle_cells( - cells, output_format="file", stdout_filename=None, stderr_filename=None -): - # TODO: eventually support writing output to notebook - - tangled_code = [] - for i, (code, expected_result) in enumerate(cells): - tangled_code.append('logger.info("beginning execution cell={i}")') - tangled_code.append(code) - tangled_code.append('logger.info("completed execution cell={i}")') - return TEMPLATE_SCRIPT_HEADER + "\n".join(tangled_code) diff --git a/spyder/plugins/remoteclient/api/client.py b/spyder/plugins/remoteclient/api/manager.py similarity index 82% rename from spyder/plugins/remoteclient/api/client.py rename to spyder/plugins/remoteclient/api/manager.py index f52ed3b907a..f42459e056f 100644 --- a/spyder/plugins/remoteclient/api/client.py +++ b/spyder/plugins/remoteclient/api/manager.py @@ -4,11 +4,20 @@ # Licensed under the terms of the MIT License # (see spyder/__init__.py for details) +""" +spyder.plugins.remoteclient.api.manager +======================================= + +Remote Client Plugin API Manager. +""" + from __future__ import annotations import asyncio +from functools import partial import json import logging import socket +import typing import asyncssh from packaging.version import Version @@ -19,7 +28,7 @@ SPYDER_REMOTE_MAX_VERSION, SPYDER_REMOTE_MIN_VERSION, ) -from spyder.plugins.remoteclient.api.jupyterhub import JupyterAPI +from spyder.plugins.remoteclient.api.modules.base import JupyterAPI from spyder.plugins.remoteclient.api.protocol import ( ConnectionInfo, ConnectionStatus, @@ -35,16 +44,19 @@ SERVER_ENV, ) +if typing.TYPE_CHECKING: + from spyder.plugins.remoteclient.api.modules.base import ( + SpyderBaseJupyterAPIType, + ) + -class SpyderRemoteClientLoggerHandler(logging.Handler): +class SpyderRemoteAPILoggerHandler(logging.Handler): def __init__(self, client, *args, **kwargs): self._client = client super().__init__(*args, **kwargs) log_format = "%(message)s — %(asctime)s" - formatter = logging.Formatter( - log_format, datefmt="%H:%M:%S %d/%m/%Y" - ) + formatter = logging.Formatter(log_format, datefmt="%H:%M:%S %d/%m/%Y") self.setFormatter(formatter) def emit(self, record): @@ -58,15 +70,23 @@ def emit(self, record): ) -class SpyderRemoteClient: - """Class to manage a remote server and its kernels.""" +class SpyderRemoteAPIManager: + """Class to manage a remote server and its APIs.""" + + REGISTERED_MODULE_APIS: typing.ClassVar[ + dict[str,type[SpyderBaseJupyterAPIType]]] = {} JUPYTER_SERVER_TIMEOUT = 5 # seconds _extra_options = ["platform", "id"] - START_SERVER_COMMAND = f"/${{HOME}}/.local/bin/micromamba run -n {SERVER_ENV} spyder-server" - GET_SERVER_INFO_COMMAND = f"/${{HOME}}/.local/bin/micromamba run -n {SERVER_ENV} spyder-server info" + START_SERVER_COMMAND = ( + f"/${{HOME}}/.local/bin/micromamba run -n {SERVER_ENV} spyder-server" + ) + GET_SERVER_INFO_COMMAND = ( + f"/${{HOME}}/.local/bin/micromamba run" + f" -n {SERVER_ENV} spyder-server info" + ) def __init__(self, conf_id, options: SSHClientOptions, _plugin=None): self._config_id = conf_id @@ -84,17 +104,18 @@ def __init__(self, conf_id, options: SSHClientOptions, _plugin=None): self._remote_server_process: asyncssh.SSHClientProcess = None self._port_forwarder: asyncssh.SSHListener = None self._server_info = {} + self._local_port = None # For logging - self._logger = logging.getLogger( + self.logger = logging.getLogger( f"{__name__}.{self.__class__.__name__}({self.config_id})" ) if not get_debug_level(): - self._logger.setLevel(logging.DEBUG) + self.logger.setLevel(logging.DEBUG) if self._plugin is not None: - self._logger.addHandler(SpyderRemoteClientLoggerHandler(self)) + self.logger.addHandler(SpyderRemoteAPILoggerHandler(self)) def __emit_connection_status(self, status, message): if self._plugin is not None: @@ -159,9 +180,9 @@ def server_url(self): ValueError If the local port is not set. """ - if not self.local_port: + if not self._local_port: raise ValueError("Local port is not set") - return f"http://127.0.0.1:{self.local_port}" + return f"http://127.0.0.1:{self._local_port}" @property def api_token(self): @@ -211,7 +232,7 @@ def _handle_connection_lost(self, exc: Exception | None = None): self.__server_started.clear() self._port_forwarder = None if exc: - self._logger.error( + self.logger.error( f"Connection to {self.peer_host} was lost", exc_info=exc, ) @@ -223,7 +244,7 @@ def _handle_connection_lost(self, exc: Exception | None = None): async def get_server_info(self): """Check if the remote server is running.""" if self._ssh_connection is None: - self._logger.debug("ssh connection was not established") + self.logger.debug("ssh connection was not established") return None try: @@ -231,29 +252,27 @@ async def get_server_info(self): self.GET_SERVER_INFO_COMMAND, check=True ) except asyncssh.TimeoutError: - self._logger.error("Getting server info timed out") + self.logger.error("Getting server info timed out") return None except asyncssh.misc.ChannelOpenError: - self._logger.error( + self.logger.error( "The connection is closed, so it's not possible to get the " "server info" ) return None except asyncssh.ProcessError as err: - self._logger.debug(f"Error getting server info: {err.stderr}") + self.logger.debug(f"Error getting server info: {err.stderr}") return None try: info = json.loads(output.stdout.splitlines()[-1]) except (json.JSONDecodeError, IndexError): - self._logger.debug( - f"Issue parsing server info: {output.stdout}" - ) + self.logger.debug(f"Issue parsing server info: {output.stdout}") return None return info - # -- Connection and server management + # ---- Connection and server management async def connect_and_install_remote_server(self) -> bool: """Connect to the remote server and install the server.""" if await self.create_new_connection(): @@ -329,7 +348,7 @@ async def start_remote_server(self): async def __start_remote_server(self): """Start remote server.""" if not self.ssh_is_connected: - self._logger.error("SSH connection is not open") + self.logger.error("SSH connection is not open") self.__emit_connection_status( ConnectionStatus.Error, _("The SSH connection is not open"), @@ -337,14 +356,14 @@ async def __start_remote_server(self): return False if info := await self.get_server_info(): - self._logger.warning( + self.logger.warning( f"Remote server is already running for {self.peer_host}" ) - self._logger.debug("Checking server info") + self.logger.debug("Checking server info") if self._server_info != info: self._server_info = info - self._logger.info( + self.logger.info( "Different server info, updating info " f"for {self.peer_host}" ) @@ -355,7 +374,7 @@ async def __start_remote_server(self): ) return True - self._logger.error( + self.logger.error( "Error forwarding local port, server might not be " "reachable" ) @@ -371,7 +390,7 @@ async def __start_remote_server(self): return True - self._logger.debug(f"Starting remote server for {self.peer_host}") + self.logger.debug(f"Starting remote server for {self.peer_host}") try: self._remote_server_process = ( await self._ssh_connection.create_process( @@ -380,7 +399,7 @@ async def __start_remote_server(self): ) ) except (OSError, asyncssh.Error, ValueError) as e: - self._logger.error(f"Error starting remote server: {e}") + self.logger.error(f"Error starting remote server: {e}") self._remote_server_process = None self.__emit_connection_status( ConnectionStatus.Error, _("Error starting the remote server") @@ -393,7 +412,7 @@ async def __start_remote_server(self): _time += 1 if info is None: - self._logger.error("Faield to get server info") + self.logger.error("Faield to get server info") self.__emit_connection_status( ConnectionStatus.Error, _( @@ -405,7 +424,7 @@ async def __start_remote_server(self): self._server_info = info - self._logger.info( + self.logger.info( f"Remote server started for {self.peer_host} at port " f"{self.server_port}" ) @@ -417,7 +436,7 @@ async def __start_remote_server(self): ) return True - self._logger.error("Error forwarding local port.") + self.logger.error("Error forwarding local port.") self.__emit_connection_status( ConnectionStatus.Error, _("It was not possible to forward the local port"), @@ -427,7 +446,7 @@ async def __start_remote_server(self): async def ensure_server_installed(self) -> bool: """Check remote server version.""" if not self.ssh_is_connected: - self._logger.error("SSH connection is not open") + self.logger.error("SSH connection is not open") self.__emit_connection_status( ConnectionStatus.Error, _("The SSH connection is not open"), @@ -437,17 +456,13 @@ async def ensure_server_installed(self) -> bool: commnad = get_server_version_command(self.options["platform"]) try: - output = await self._ssh_connection.run( - commnad, check=True - ) + output = await self._ssh_connection.run(commnad, check=True) except asyncssh.ProcessError as err: # Server is not installed - self._logger.warning( - f"Issue checking server version: {err.stderr}" - ) + self.logger.warning(f"Issue checking server version: {err.stderr}") return await self.install_remote_server() except asyncssh.TimeoutError: - self._logger.error("Checking server version timed out") + self.logger.error("Checking server version timed out") self.__emit_connection_status( ConnectionStatus.Error, _("The server version check timed out"), @@ -457,7 +472,7 @@ async def ensure_server_installed(self) -> bool: version = output.stdout.splitlines()[-1].strip() if Version(version) >= Version(SPYDER_REMOTE_MAX_VERSION): - self._logger.error( + self.logger.error( f"Server version mismatch: {version} is greater than " f"the maximum supported version {SPYDER_REMOTE_MAX_VERSION}" ) @@ -469,14 +484,14 @@ async def ensure_server_installed(self) -> bool: return False if Version(version) < Version(SPYDER_REMOTE_MIN_VERSION): - self._logger.warning( + self.logger.warning( f"Server version mismatch: {version} is lower than " f"the minimum supported version {SPYDER_REMOTE_MIN_VERSION}. " f"A more recent version will be installed." ) return await self.install_remote_server() - self._logger.info(f"Supported Server version: {version}") + self.logger.info(f"Supported Server version: {version}") return True @@ -500,21 +515,21 @@ async def install_remote_server(self) -> bool: async def __install_remote_server(self): """Install remote server.""" if not self.ssh_is_connected: - self._logger.error("SSH connection is not open") + self.logger.error("SSH connection is not open") self.__emit_connection_status( ConnectionStatus.Error, _("The SSH connection is not open"), ) return False - self._logger.debug( + self.logger.debug( f"Installing spyder-remote-server on {self.peer_host}" ) try: command = get_installer_command(self.options["platform"]) except NotImplementedError: - self._logger.error( + self.logger.error( f"Cannot install spyder-remote-server on " f"{self.options['platform']} automatically. Please install it " f"manually." @@ -528,21 +543,21 @@ async def __install_remote_server(self): try: await self._ssh_connection.run(command, check=True) except asyncssh.ProcessError as err: - self._logger.error(f"Installation script failed: {err.stderr}") + self.logger.error(f"Installation script failed: {err.stderr}") self.__emit_connection_status( status=ConnectionStatus.Error, message=_("There was an error installing the remote server"), ) return False except asyncssh.TimeoutError: - self._logger.error("Installation script timed out") + self.logger.error("Installation script timed out") self.__emit_connection_status( status=ConnectionStatus.Error, message=_("There was an error installing the remote server"), ) return False - self._logger.info( + self.logger.info( f"Successfully installed spyder-remote-server on {self.peer_host}" ) @@ -578,7 +593,7 @@ async def __create_new_connection(self) -> bool: True if the connection was successful, False otherwise. """ if self.ssh_is_connected: - self._logger.debug( + self.logger.debug( f"Atempting to create a new connection with an existing for " f"{self.peer_host}" ) @@ -594,27 +609,27 @@ async def __create_new_connection(self) -> bool: for k, v in self.options.items() if k not in self._extra_options } - self._logger.debug("Opening SSH connection") + self.logger.debug("Opening SSH connection") try: self._ssh_connection = await asyncssh.connect( **connect_kwargs, client_factory=self.client_factory ) except (OSError, asyncssh.Error) as e: - self._logger.error(f"Failed to open ssh connection: {e}") + self.logger.error(f"Failed to open ssh connection: {e}") self.__emit_connection_status( ConnectionStatus.Error, _("It was not possible to open a connection to this machine"), ) return False - self._logger.info(f"SSH connection opened for {self.peer_host}") + self.logger.info(f"SSH connection opened for {self.peer_host}") return True async def forward_local_port(self): """Forward local port.""" if not self.server_port: - self._logger.error("Server port is not set") + self.logger.error("Server port is not set") self.__emit_connection_status( status=ConnectionStatus.Error, message=_("The server port is not set"), @@ -622,21 +637,21 @@ async def forward_local_port(self): return False if not self.ssh_is_connected: - self._logger.error("SSH connection is not open") + self.logger.error("SSH connection is not open") self.__emit_connection_status( status=ConnectionStatus.Error, message=_("The SSH connection is not open"), ) return False - self._logger.debug( + self.logger.debug( f"Forwarding an free local port to remote port {self.server_port}" ) if self._port_forwarder: - self._logger.warning( + self.logger.warning( f"Port forwarder is already open for host {self.peer_host} " - f"with local port {self.local_port} and remote port " + f"with local port {self._local_port} and remote port " f"{self.server_port}" ) await self.close_port_forwarder() @@ -652,9 +667,9 @@ async def forward_local_port(self): self.server_port, ) - self.local_port = local_port + self._local_port = local_port - self._logger.debug( + self.logger.debug( f"Forwarded local port {local_port} to remote server at " f"{server_host}:{self.server_port}" ) @@ -664,28 +679,28 @@ async def forward_local_port(self): async def close_port_forwarder(self): """Close port forwarder.""" if self.port_is_forwarded: - self._logger.debug( + self.logger.debug( f"Closing port forwarder for host {self.peer_host} with local " - f"port {self.local_port}" + f"port {self._local_port}" ) self._port_forwarder.close() await self._port_forwarder.wait_closed() self._port_forwarder = None - self._logger.debug( + self.logger.debug( f"Port forwarder closed for host {self.peer_host} with local " - f"port {self.local_port}" + f"port {self._local_port}" ) async def stop_remote_server(self): """Close remote server.""" if not self.server_started: - self._logger.warning( + self.logger.warning( f"Remote server is not running for {self.peer_host}" ) return False if not self.ssh_is_connected: - self._logger.error("SSH connection is not open") + self.logger.error("SSH connection is not open") self.__emit_connection_status( ConnectionStatus.Error, _("The SSH connection is not open"), @@ -693,7 +708,7 @@ async def stop_remote_server(self): return False # bug in jupyterhub, need to send SIGINT twice - self._logger.debug( + self.logger.debug( f"Stopping remote server for {self.peer_host} with pid " f"{self._server_info['pid']}" ) @@ -703,9 +718,7 @@ async def stop_remote_server(self): ) as jupyter: await jupyter.shutdown_server() except Exception as err: - self._logger.exception( - "Error stopping remote server", exc_info=err - ) + self.logger.exception("Error stopping remote server", exc_info=err) if ( self._remote_server_process @@ -716,27 +729,56 @@ async def stop_remote_server(self): self.__server_started.clear() self._remote_server_process = None - self._logger.info(f"Remote server process closed for {self.peer_host}") + self.logger.info(f"Remote server process closed for {self.peer_host}") return True async def close_ssh_connection(self): """Close SSH connection.""" if not self.ssh_is_connected: - self._logger.debug("SSH connection is not open") + self.logger.debug("SSH connection is not open") return - self._logger.debug(f"Closing SSH connection for {self.peer_host}") + self.logger.debug(f"Closing SSH connection for {self.peer_host}") self._ssh_connection.close() await self._ssh_connection.wait_closed() self._ssh_connection = None - self._logger.info("SSH connection closed") + self.logger.info("SSH connection closed") self.__connection_established.clear() self.__emit_connection_status( ConnectionStatus.Inactive, _("The connection was closed successfully"), ) - # --- Kernel Management + @staticmethod + def get_free_port(): + """Request a free port from the OS.""" + with socket.socket() as s: + s.bind(("", 0)) + return s.getsockname()[1] + + # ---- API Management + @classmethod + def register_api( + cls, kclass: type[SpyderBaseJupyterAPIType] + ) -> type[SpyderBaseJupyterAPIType]: + """Register a REST API class.""" + cls.REGISTERED_MODULE_APIS[kclass.__qualname__] = kclass + return kclass + + def get_api( + self, api: str | type[SpyderBaseJupyterAPIType] + ) -> typing.Callable[..., SpyderBaseJupyterAPIType]: + """Get a registered REST API class.""" + if isinstance(api, type): + api = api.__qualname__ + + api_class = self.REGISTERED_MODULE_APIS.get(api) + if api_class is None: + raise ValueError(f"API {api} is not registered") + + return partial(api_class, manager=self) + + # ---- Kernel Management async def start_new_kernel_ensure_server( self, _retries=5 ) -> KernelConnectionInfo: @@ -753,7 +795,7 @@ async def start_new_kernel_ensure_server( The kernel connection information. """ if not await self.ensure_connection_and_server(): - self._logger.error( + self.logger.error( "Cannot launch kernel, remote server is not running" ) return {} @@ -767,7 +809,7 @@ async def start_new_kernel_ensure_server( while not kernel_id and retries < _retries: await asyncio.sleep(1) kernel_id = await self.start_new_kernel() - self._logger.debug( + self.logger.debug( f"Server might not be ready yet, retrying kernel launch " f"({retries + 1}/{_retries})" ) @@ -791,7 +833,7 @@ async def get_kernel_info_ensure_server( The kernel connection information. """ if not await self.ensure_connection_and_server(): - self._logger.error( + self.logger.error( "Cannot launch kernel, remote server is not running" ) return {} @@ -805,7 +847,7 @@ async def get_kernel_info_ensure_server( while not kernel_info and retries < _retries: await asyncio.sleep(1) kernel_info = await self.get_kernel_info(kernel_id) - self._logger.debug( + self.logger.debug( f"Server might not be ready yet, retrying kernel launch " f"({retries + 1}/{_retries})" ) @@ -819,7 +861,7 @@ async def start_new_kernel(self, kernel_spec=None) -> KernelInfo: self.server_url, api_token=self.api_token ) as jupyter: response = await jupyter.create_kernel(kernel_spec=kernel_spec) - self._logger.info(f"Kernel started with ID {response['id']}") + self.logger.info(f"Kernel started with ID {response['id']}") return response async def list_kernels(self) -> list[KernelInfo]: @@ -829,7 +871,7 @@ async def list_kernels(self) -> list[KernelInfo]: ) as jupyter: response = await jupyter.list_kernels() - self._logger.info(f"Kernels listed for {self.peer_host}") + self.logger.info(f"Kernels listed for {self.peer_host}") return response async def get_kernel_info(self, kernel_id) -> KernelInfo: @@ -839,7 +881,7 @@ async def get_kernel_info(self, kernel_id) -> KernelInfo: ) as jupyter: response = await jupyter.get_kernel(kernel_id=kernel_id) - self._logger.info(f"Kernel info retrieved for ID {kernel_id}") + self.logger.info(f"Kernel info retrieved for ID {kernel_id}") return response async def terminate_kernel(self, kernel_id) -> bool: @@ -849,7 +891,7 @@ async def terminate_kernel(self, kernel_id) -> bool: ) as jupyter: response = await jupyter.delete_kernel(kernel_id=kernel_id) - self._logger.info(f"Kernel terminated for ID {kernel_id}") + self.logger.info(f"Kernel terminated for ID {kernel_id}") return response async def interrupt_kernel(self, kernel_id) -> bool: @@ -859,7 +901,7 @@ async def interrupt_kernel(self, kernel_id) -> bool: ) as jupyter: response = await jupyter.interrupt_kernel(kernel_id=kernel_id) - self._logger.info(f"Kernel interrupted for ID {kernel_id}") + self.logger.info(f"Kernel interrupted for ID {kernel_id}") return response async def restart_kernel(self, kernel_id) -> bool: @@ -869,12 +911,5 @@ async def restart_kernel(self, kernel_id) -> bool: ) as jupyter: response = await jupyter.restart_kernel(kernel_id=kernel_id) - self._logger.info(f"Kernel restarted for ID {kernel_id}") + self.logger.info(f"Kernel restarted for ID {kernel_id}") return response - - @staticmethod - def get_free_port(): - """Request a free port from the OS.""" - with socket.socket() as s: - s.bind(("", 0)) - return s.getsockname()[1] diff --git a/spyder/plugins/remoteclient/api/modules/__init__.py b/spyder/plugins/remoteclient/api/modules/__init__.py new file mode 100644 index 00000000000..8ad9e2aaa6c --- /dev/null +++ b/spyder/plugins/remoteclient/api/modules/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# +# Copyright © Spyder Project Contributors +# Licensed under the terms of the MIT License +# (see spyder/__init__.py for details) + +""" +spyder.plugins.remoteclient.api.modules +======================================= + +Remote Client Plugin Modules API. + +This package contains the API for the Remote Client Plugin +to interact with the Jupyter Server. + +These modules provides an interface to control the Jupyter Server that sits +in the remote machine. +""" diff --git a/spyder/plugins/remoteclient/api/jupyterhub/__init__.py b/spyder/plugins/remoteclient/api/modules/base.py similarity index 78% rename from spyder/plugins/remoteclient/api/jupyterhub/__init__.py rename to spyder/plugins/remoteclient/api/modules/base.py index 020779df554..9bb7557b70b 100644 --- a/spyder/plugins/remoteclient/api/jupyterhub/__init__.py +++ b/spyder/plugins/remoteclient/api/modules/base.py @@ -4,15 +4,28 @@ # Licensed under the terms of the MIT License # (see spyder/__init__.py for details) +from __future__ import annotations +from abc import abstractmethod import uuid import logging import time +import typing import asyncio +import re import yarl import aiohttp -from spyder.plugins.remoteclient.api.jupyterhub import auth +from spyder.api.asyncdispatcher import AsyncDispatcher +from spyder.api.utils import ABCMeta, abstract_attribute + +if typing.TYPE_CHECKING: + from spyder.plugins.remoteclient.api import SpyderRemoteAPIManager + + +SpyderBaseJupyterAPIType = typing.TypeVar( + "SpyderBaseJupyterAPIType", bound="SpyderBaseJupyterAPI" +) logger = logging.getLogger(__name__) @@ -21,6 +34,54 @@ REQUEST_TIMEOUT = 5 # seconds +async def token_authentication(api_token, verify_ssl=True): + return aiohttp.ClientSession( + headers={"Authorization": f"token {api_token}"}, + connector=aiohttp.TCPConnector(ssl=None if verify_ssl else False), + ) + + +async def basic_authentication(hub_url, username, password, verify_ssl=True): + session = aiohttp.ClientSession( + headers={"Referer": str(yarl.URL(hub_url) / "hub" / "api")}, + connector=aiohttp.TCPConnector(ssl=None if verify_ssl else False), + ) + + await session.post( + yarl.URL(hub_url) / "hub" / "login", + data={ + "username": username, + "password": password, + }, + ) + + return session + + +async def keycloak_authentication( + hub_url, username, password, verify_ssl=True +): + session = aiohttp.ClientSession( + headers={"Referer": str(yarl.URL(hub_url) / "hub" / "api")}, + connector=aiohttp.TCPConnector(ssl=None if verify_ssl else False), + ) + + response = await session.get(yarl.URL(hub_url) / "hub" / "oauth_login") + content = await response.content.read() + auth_url = re.search('action="([^"]+)"', content.decode("utf8")).group(1) + + response = await session.post( + auth_url.replace("&", "&"), + headers={"Content-Type": "application/x-www-form-urlencoded"}, + data={ + "username": username, + "password": password, + "credentialId": "", + }, + ) + return session + + class JupyterHubAPI: def __init__(self, hub_url, auth_type="token", verify_ssl=True, **kwargs): self.hub_url = yarl.URL(hub_url) @@ -36,11 +97,11 @@ def __init__(self, hub_url, auth_type="token", verify_ssl=True, **kwargs): async def __aenter__(self): if self.auth_type == "token": - self.session = await auth.token_authentication( + self.session = await token_authentication( self.api_token, verify_ssl=self.verify_ssl ) elif self.auth_type == "basic": - self.session = await auth.basic_authentication( + self.session = await basic_authentication( self.hub_url, self.username, self.password, @@ -51,11 +112,11 @@ async def __aenter__(self): logger.debug( "upgrading basic authentication to token authentication" ) - self.session = await auth.token_authentication( + self.session = await token_authentication( self.api_token, verify_ssl=self.verify_ssl ) elif self.auth_type == "keycloak": - self.session = await auth.keycloak_authentication( + self.session = await keycloak_authentication( self.hub_url, self.username, self.password, @@ -66,7 +127,7 @@ async def __aenter__(self): logger.debug( "upgrading keycloak authentication to token authentication" ) - self.session = await auth.token_authentication( + self.session = await token_authentication( self.api_token, verify_ssl=self.verify_ssl ) return self @@ -299,12 +360,8 @@ async def create_kernel(self, kernel_spec=None): self.api_url / "kernels", json=data ) as response: if response.status != 201: - logger.error( - f"failed to create kernel_spec={kernel_spec}" - ) - raise ValueError( - await response.text() - ) + logger.error(f"failed to create kernel_spec={kernel_spec}") + raise ValueError(await response.text()) return await response.json() async def list_kernel_specs(self): @@ -386,9 +443,7 @@ async def restart_kernel(self, kernel_id): return False async def shutdown_server(self): - async with self.session.post( - self.api_url / "shutdown" - ) as response: + async with self.session.post(self.api_url / "shutdown") as response: if response.status == 200: logger.info(f"Server for jupyter has been shutdown") return True @@ -470,3 +525,78 @@ async def send_code(self, username, code, wait=True, timeout=None): # cell did not produce output elif msg["content"].get("execution_state") == "idle": return "" + + +class SpyderBaseJupyterAPI(metaclass=ABCMeta): + """ + Base class for Jupyter API plugins. + + This class must be subclassed to implement the API for a specific + Jupyter extension. Provides a context manager for the API session. + + Class Attributes + ---------------- + base_url: str + The base URL for the Jupyter Extension's rest API. + + Attributes + ---------- + api_url: yarl.URL + The full URL for the rest API. + + api_token: str + The API token for the Jupyter API. + + verify_ssl: bool + Whether to verify SSL certificates. + + session: aiohttp.ClientSession + The session for the Jupyter API requests. + """ + + @abstract_attribute + def base_url(self): ... + + def __init__(self, manager: SpyderRemoteAPIManager): + self.manager = manager + self.session = None + + @property + def api_url(self): + return yarl.URL(self.manager.server_url) / self.base_url + + async def connect(self): + if not await AsyncDispatcher( + self.manager.ensure_connection_and_server, + loop="asyncssh", + return_awaitable=True, + )(): + raise RuntimeError("Failed to connect to Jupyter server") + + if self.session is not None and not self.session.closed: + return + + self.session = aiohttp.ClientSession( + headers={"Authorization": f"token {self.manager.api_token}"}, + connector=aiohttp.TCPConnector(ssl=None), + raise_for_status=self._raise_for_status, + ) + + async def __aenter__(self): + await self.connect() + return self + + async def close(self): + await self.session.close() + + async def __aexit__(self, exc_type, exc, tb): + await self.close() + + @property + def closed(self): + if self.session is None: + return True + return self.session.closed + + @abstractmethod + async def _raise_for_status(self, response: aiohttp.ClientResponse): ... diff --git a/spyder/plugins/remoteclient/api/modules/file_services.py b/spyder/plugins/remoteclient/api/modules/file_services.py new file mode 100644 index 00000000000..9280afe59d1 --- /dev/null +++ b/spyder/plugins/remoteclient/api/modules/file_services.py @@ -0,0 +1,463 @@ +# -*- coding: utf-8 -*- +# +# Copyright © Spyder Project Contributors +# Licensed under the terms of the MIT License +# (see spyder/__init__.py for details) + +from __future__ import annotations +import base64 +from http import HTTPStatus +from io import RawIOBase +import json +from pathlib import Path + +import aiohttp + +from spyder.plugins.remoteclient.api.modules.base import SpyderBaseJupyterAPI +from spyder.plugins.remoteclient.api import SpyderRemoteAPIManager + +# jupyter server's extension name for spyder-remote-services +SPYDER_PLUGIN_NAME = "spyder-services" + + +class SpyderServicesError(Exception): + """ + Exception for errors related to Spyder services. + """ + ... + + +class RemoteFileServicesError(SpyderServicesError): + """ + Exception for errors related to remote file services. + """ + def __init__(self, type, message, url, tracebacks): + self.type = type + self.message = message + self.url = url + self.tracebacks = tracebacks + + def __str__(self): + return ( + f"(type='{self.type}', message='{self.message}', url='{self.url}')" + ) + + +class RemoteOSError(OSError, RemoteFileServicesError): + """ + Exception for OSErrors raised on the remote server. + """ + def __init__(self, errno, strerror, filename, url): + super().__init__(errno, strerror, filename) + super(OSError, self).__init__(OSError, super().__str__(), url, []) + + @classmethod + def from_json(cls, data, url): + return cls(data["errno"], data["strerror"], data["filename"], url) + + def __str__(self): + return super(OSError, self).__str__() + + +@SpyderRemoteAPIManager.register_api +class SpyderRemoteFileIOAPI(SpyderBaseJupyterAPI, RawIOBase): + """API for remote file I/O. + + This API is a RawIOBase subclass that allows reading and writing files + on a remote server. + + The file is open upon the websocket connection and closed when the + connection is closed. + + If lock is True, the file will be locked on the remote server. + And any other attempts to open the file will wait until the lock is + released. + + If atomic is True, any operations on the file will be done on a temporary + copy of the file, and then the file will be replaced with the copy upon + closing. + + Parameters + ---------- + file : str + The path to the file to open. + mode : str, optional + The mode to open the file in, by default "r". + atomic : bool, optional + Whether to open the file atomically, by default False. + lock : bool, optional + Whether to lock the file, by default False. + encoding : str, optional + The encoding to use when reading and writing the file, by default "utf-8". + + Raises + ------ + RemoteFileServicesError + If an error occurs when opening the file. + RemoteOSError + If an OSError occured on the remote server. + """ + base_url = SPYDER_PLUGIN_NAME + "/fs/open" + + def __init__( + self, + file, + mode="r", + atomic=False, + lock=False, + encoding="utf-8", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.name = file + self.mode = mode + self.encoding = encoding + self.atomic = atomic + self.lock = lock + + self._websocket: aiohttp.ClientWebSocketResponse = None + + async def _raise_for_status(self, response): + response.raise_for_status() + + async def connect(self): + await super().connect() + + if self._websocket is not None and not self._websocket.closed: + return + + self._websocket = await self.session.ws_connect( + self.api_url / f"file://{self.name}", + params={ + "mode": self.mode, + "atomic": str(self.atomic).lower(), + "lock": str(self.lock).lower(), + "encoding": self.encoding, + }, + ) + + try: + await self._check_connection() + except Exception as e: + self._websocket = None + raise e + + async def _check_connection(self): + status = await self._websocket.receive() + + if status.type == aiohttp.WSMsgType.CLOSE: + await self._websocket.close() + if status.data == 1002: + data = json.loads(status.extra) + if data["status"] in ( + HTTPStatus.LOCKED, + HTTPStatus.EXPECTATION_FAILED, + ): + raise RemoteOSError.from_json( + data, url=self._websocket._response.url + ) + + raise RemoteFileServicesError( + data.get("type", "UnknownError"), + data.get("message", "Unknown error"), + self._websocket._response.url, + data.get("tracebacks", []), + ) + else: + raise RemoteFileServicesError( + "UnknownError", + "Failed to open file", + self._websocket._response.url, + [], + ) + + async def close(self): + await self._websocket.close() + try: + await self._websocket.receive() + except Exception: + pass + await super().close() + + @property + def closed(self): + if self._websocket is None: + return super().closed + return self._websocket.closed and super().closed + + def _decode_data(self, data: str | object) -> str | bytes | object: + """Decode data from a message.""" + if not isinstance(data, str): + return data + + if "b" in self.mode: + return base64.b64decode(data) + + return base64.b64decode(data).decode(self.encoding) + + def _encode_data(self, data: bytes | str | object) -> str: + """Encode data for a message.""" + if isinstance(data, bytes): + return base64.b64encode(data).decode("ascii") + if isinstance(data, str): + return base64.b64encode(data.encode(self.encoding)).decode("ascii") + return data + + async def _send_request(self, method: str, **args): + await self._websocket.send_json({"method": method, **args}) + + async def _get_response(self, timeout=None): + message = json.loads( + await self._websocket.receive_bytes(timeout=timeout) + ) + + if message["status"] > 400: + if message["status"] == HTTPStatus.EXPECTATION_FAILED: + raise RemoteOSError.from_json( + message, url=self._websocket._response.url + ) + + raise RemoteFileServicesError( + message.get("type", "UnknownError"), + message.get("message", "Unknown error"), + self._websocket._response.url, + message.get("tracebacks", []), + ) + + data = message.get("data") + if data is None: + return None + + if isinstance(data, list): + return [self._decode_data(d) for d in data] + + return self._decode_data(data) + + @property + def closefd(self): + return True + + async def __iter__(self): + while response := await self.readline(): + yield response + + async def __next__(self): + response = await self.readline() + if not response: + raise StopIteration + return response + + async def write(self, s: bytes | str) -> int: + """Write data to the file.""" + await self._send_request("write", data=self._encode_data(s)) + return await self._get_response() + + async def flush(self): + """Flush the file.""" + await self._send_request("flush") + return await self._get_response() + + async def read(self, size: int = -1) -> bytes | str: + """Read data from the file.""" + await self._send_request("read", n=size) + return await self._get_response() + + async def readall(self): + """Read all data from the file.""" + return await self.read(size=-1) + + async def readinto(self, b) -> int: + """Read data into a buffer.""" + raise NotImplementedError( + "readinto() is not supported by the remote file API" + ) + + async def seek(self, pos: int, whence: int = 0) -> int: + """Seek to a new position in the file.""" + await self._send_request("seek", offset=pos, whence=whence) + return await self._get_response() + + async def tell(self) -> int: + """Get the current file position.""" + await self._send_request("tell") + return await self._get_response() + + async def truncate(self, size: int | None = None) -> int: + """Truncate the file to a new size.""" + await self._send_request("truncate", size=size) + return await self._get_response() + + async def fileno(self): + """Flush the file to disk.""" + await self._send_request("fileno") + return await self._get_response() + + async def readline(self, size: int = -1) -> bytes | str: + """Read a line from the file.""" + await self._send_request("readline", size=size) + return await self._get_response() + + async def readlines(self, hint: int = -1) -> list[bytes | str]: + """Read lines from the file.""" + await self._send_request("readlines", hint=hint) + return await self._get_response() + + async def writelines(self, lines: list[bytes | str]): + """Write lines to the file.""" + await self._send_request( + "writelines", lines=list(map(self._encode_data, lines)) + ) + return await self._get_response() + + async def isatty(self) -> bool: + """Check if the file is a TTY.""" + await self._send_request("isatty") + return await self._get_response() + + async def readable(self) -> bool: + """Check if the file is readable.""" + await self._send_request("readable") + return await self._get_response() + + async def writable(self) -> bool: + """Check if the file is writable.""" + await self._send_request("writable") + return await self._get_response() + + +@SpyderRemoteAPIManager.register_api +class SpyderRemoteFileServicesAPI(SpyderBaseJupyterAPI): + """API for remote file services. + + This API allows for interacting with files on a remote server. + + Raises + ------ + RemoteFileServicesError + If an error occurs when interacting with the file services. + RemoteOSError + If an OSError occured on the remote server. + """ + + base_url = SPYDER_PLUGIN_NAME + "/fs" + + async def _raise_for_status(self, response: aiohttp.ClientResponse): + if response.status not in ( + HTTPStatus.INTERNAL_SERVER_ERROR, + HTTPStatus.EXPECTATION_FAILED, + ): + return response.raise_for_status() + + try: + data = await response.json() + except json.JSONDecodeError: + data = {} + + # If we're in a context we can rely on __aexit__() to release as the + # exception propagates. + if not response._in_context: + response.release() + + if response.status == HTTPStatus.EXPECTATION_FAILED: + raise RemoteOSError.from_json(data, response.url) + + raise RemoteFileServicesError( + data.get("type", "UnknownError"), + data.get("message", "Unknown error"), + response.url, + data.get("tracebacks", []), + ) + + async def ls(self, path: Path, detail: bool = True): + async with self.session.get( + self.api_url / "ls" / f"file://{path}", + params={"detail": str(detail).lower()}, + ) as response: + return await response.json() + + async def info(self, path: Path): + async with self.session.get( + self.api_url / "info" / f"file://{path}" + ) as response: + return await response.json() + + async def exists(self, path: Path): + async with self.session.get( + self.api_url / "exists" / f"file://{path}" + ) as response: + return await response.json() + + async def is_file(self, path: Path): + async with self.session.get( + self.api_url / "isfile" / f"file://{path}" + ) as response: + return await response.json() + + async def is_dir(self, path: Path): + async with self.session.get( + self.api_url / "isdir" / f"file://{path}" + ) as response: + return await response.json() + + async def mkdir( + self, path: Path, create_parents: bool = True, exist_ok: bool = False + ): + async with self.session.post( + self.api_url / "mkdir" / f"file://{path}", + params={ + "create_parents": str(create_parents).lower(), + "exist_ok": str(exist_ok).lower(), + }, + ) as response: + return await response.json() + + async def rmdir(self, path: Path): + async with self.session.delete( + self.api_url / "rmdir" / f"file://{path}" + ) as response: + return await response.json() + + async def unlink(self, path: Path, missing_ok: bool = False): + async with self.session.delete( + self.api_url / "file" / f"file://{path}", + params={"missing_ok": str(missing_ok).lower()}, + ) as response: + return await response.json() + + async def copy(self, path1: Path, path2: Path): + async with self.session.post( + self.api_url / "copy" / f"file://{path1}", + params={"dest": f"file://{path2}"}, + ) as response: + return await response.json() + + async def copy2(self, path1: Path, path2: Path): + async with self.session.post( + self.api_url / "copy" / f"file://{path1}", + params={"dest": f"file://{path2}", "metadata": "true"}, + ) as response: + return await response.json() + + async def replace(self, path1: Path, path2: Path): + async with self.session.post( + self.api_url / "move" / f"file://{path1}", + params={"dest": f"file://{path2}"}, + ) as response: + return await response.json() + + async def touch(self, path: Path, truncate: bool = True): + async with self.session.post( + self.api_url / "touch" / f"file://{path}", + params={"truncate": str(truncate).lower()}, + ) as response: + return await response.json() + + async def open( + self, path, mode="r", atomic=False, lock=False, encoding="utf-8" + ): + file = SpyderRemoteFileIOAPI( + path, mode, atomic, lock, encoding, manager=self.manager + ) + await file.connect() + return file diff --git a/spyder/plugins/remoteclient/plugin.py b/spyder/plugins/remoteclient/plugin.py index fd85bd4c257..dc0b0d8cde6 100644 --- a/spyder/plugins/remoteclient/plugin.py +++ b/spyder/plugins/remoteclient/plugin.py @@ -10,8 +10,10 @@ """ # Standard library imports +from __future__ import annotations import logging import contextlib +import typing # Third-party imports from qtpy.QtCore import Signal, Slot @@ -34,13 +36,21 @@ RemoteClientActions, RemoteClientMenus, ) -from spyder.plugins.remoteclient.api.client import SpyderRemoteClient +from spyder.plugins.remoteclient.api import SpyderRemoteAPIManager from spyder.plugins.remoteclient.api.protocol import ( SSHClientOptions, ConnectionStatus, ) +from spyder.plugins.remoteclient.api.modules.base import SpyderBaseJupyterAPI +from spyder.plugins.remoteclient.api.modules.file_services import ( + SpyderRemoteFileServicesAPI, +) from spyder.plugins.remoteclient.widgets.container import RemoteClientContainer +if typing.TYPE_CHECKING: + from spyder.plugins.remoteclient.api.modules.base import SpyderBaseJupyterAPIType + + _logger = logging.getLogger(__name__) @@ -73,7 +83,7 @@ class RemoteClient(SpyderPluginV2): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._remote_clients: dict[str, SpyderRemoteClient] = {} + self._remote_clients: dict[str, SpyderRemoteAPIManager] = {} # ---- SpyderPluginV2 API # ------------------------------------------------------------------------- @@ -114,9 +124,7 @@ def on_initialize(self): self.sig_client_message_logged.connect( container.sig_client_message_logged ) - self.sig_version_mismatch.connect( - container.on_server_version_mismatch - ) + self.sig_version_mismatch.connect(container.on_server_version_mismatch) self._sig_kernel_started.connect(container.on_kernel_started) def on_first_registration(self): @@ -221,7 +229,7 @@ def load_client_from_id(self, config_id): def load_client(self, config_id: str, options: SSHClientOptions): """Load remote server.""" - client = SpyderRemoteClient(config_id, options, _plugin=self) + client = SpyderRemoteAPIManager(config_id, options, _plugin=self) self._remote_clients[config_id] = client def load_conf(self, config_id): @@ -300,6 +308,53 @@ def create_ipyclient_for_server(self, config_id): ) ) + @staticmethod + def register_api(kclass: typing.Type[SpyderBaseJupyterAPIType]): + """Register Remote Client API. + + This method is used to register a new API class that will be used to + interact with the remote server. + + Can be used as a decorator. + + Parameters + ---------- + kclass: Type[SpyderBaseJupyterAPI] + Class to be registered. + + Returns + ------- + Type[SpyderBaseJupyterAPI] + Class that was registered. + """ + return SpyderRemoteAPIManager.register_api(kclass) + + def get_api( + self, config_id: str, api: str | typing.Type[SpyderBaseJupyterAPIType] + ): + """Get the API for a remote server. + + Get the registered API class for a given remote server. + + Parameters + ---------- + config_id: str + Configuration id of the remote server. + api: str | Type[SpyderBaseJupyterAPI] + API class to be retrieved. + + Returns + ------- + SpyderBaseJupyterAPI + API class instance. + """ + if config_id not in self._remote_clients: + self.load_client_from_id(config_id) + + client = self._remote_clients[config_id] + + return client.get_api(api) + # ---- Private API # ------------------------------------------------------------------------- # --- Remote Server Kernel Methods @@ -375,3 +430,12 @@ def _add_remote_consoles_menu(self): ) self._is_consoles_menu_added = True + + def get_file_api(self, config_id): + """Get file API.""" + if config_id not in self._remote_clients: + self.load_client_from_id(config_id) + + client = self._remote_clients[config_id] + + return client.get_api(SpyderRemoteFileServicesAPI) diff --git a/spyder/plugins/remoteclient/tests/conftest.py b/spyder/plugins/remoteclient/tests/conftest.py index cdf69208b1d..295715eaa5a 100644 --- a/spyder/plugins/remoteclient/tests/conftest.py +++ b/spyder/plugins/remoteclient/tests/conftest.py @@ -184,9 +184,9 @@ def ipyconsole( @pytest.fixture(scope="session") -def ipyconsole_and_remoteclient(qapp) -> ( - typing.Iterator[typing.Tuple[IPythonConsole, RemoteClient]] -): +def ipyconsole_and_remoteclient( + qapp, +) -> typing.Iterator[typing.Tuple[IPythonConsole, RemoteClient]]: """ Start the Spyder Remote Client plugin with IPython Console. diff --git a/spyder/plugins/remoteclient/tests/test_files.py b/spyder/plugins/remoteclient/tests/test_files.py new file mode 100644 index 00000000000..010c2a9bdc0 --- /dev/null +++ b/spyder/plugins/remoteclient/tests/test_files.py @@ -0,0 +1,155 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2009- Spyder Project Contributors +# +# Distributed under the terms of the MIT License +# (see spyder/__init__.py for details) +# ----------------------------------------------------------------------------- + +"""Tests for the remote files API.""" + +# Third party imports +import pytest + +from spyder.api.asyncdispatcher import AsyncDispatcher +from spyder.plugins.remoteclient.plugin import RemoteClient +from spyder.plugins.remoteclient.api.modules.file_services import RemoteOSError + + +class TestRemoteFilesAPI: + remote_temp_dir = "/tmp/spyder-remote-tests" + + @AsyncDispatcher.dispatch(early_return=False) + async def test_create_dir( + self, + remote_client: RemoteClient, + remote_client_id: str, + ): + """Test that a directory can be created on the remote server.""" + file_api_class = remote_client.get_file_api(remote_client_id) + assert file_api_class is not None + + async with file_api_class() as file_api: + assert await file_api.mkdir(self.remote_temp_dir) == { + "success": True + } + + @AsyncDispatcher.dispatch(early_return=False) + async def test_write_file( + self, + remote_client: RemoteClient, + remote_client_id: str, + ): + """Test that a file can be written to the remote server.""" + file_api_class = remote_client.get_file_api(remote_client_id) + assert file_api_class is not None + + async with file_api_class() as file_api: + async with await file_api.open( + self.remote_temp_dir + "/test.txt", "w+" + ) as f: + await f.write("Hello, world!") + await f.flush() + await f.seek(0) + assert await f.read() == "Hello, world!" + + @AsyncDispatcher.dispatch(early_return=False) + async def test_list_directories( + self, + remote_client: RemoteClient, + remote_client_id: str, + ): + """Test that a directory can be listed on the remote server.""" + file_api_class = remote_client.get_file_api(remote_client_id) + assert file_api_class is not None + + async with file_api_class() as file_api: + ls_content = await file_api.ls(self.remote_temp_dir) + assert len(ls_content) == 1 + assert ls_content[0]["name"] == self.remote_temp_dir + "/test.txt" + assert ls_content[0]["size"] == 13 + assert ls_content[0]["type"] == "file" + assert not ls_content[0]["islink"] + assert ls_content[0]["created"] > 0 + assert ls_content[0]["mode"] == 0o100644 + assert ls_content[0]["uid"] > 0 + assert ls_content[0]["gid"] >= 0 + assert ls_content[0]["mtime"] > 0 + assert ls_content[0]["ino"] > 0 + assert ls_content[0]["nlink"] == 1 + + @AsyncDispatcher.dispatch(early_return=False) + async def test_copy_file( + self, + remote_client: RemoteClient, + remote_client_id: str, + ): + """Test that a file can be copied on the remote server.""" + file_api_class = remote_client.get_file_api(remote_client_id) + assert file_api_class is not None + + async with file_api_class() as file_api: + assert await file_api.copy( + self.remote_temp_dir + "/test.txt", + self.remote_temp_dir + "/test2.txt", + ) == {"success": True} + + async with file_api_class() as file_api: + ls_content = await file_api.ls(self.remote_temp_dir) + assert len(ls_content) == 2 + idx = [ + item["name"] for item in ls_content + ].index(self.remote_temp_dir + "/test.txt") + assert ls_content[not idx]["name"] == self.remote_temp_dir + "/test2.txt" + assert ls_content[0]["size"] == ls_content[1]["size"] + + @AsyncDispatcher.dispatch(early_return=False) + async def test_rm_file( + self, + remote_client: RemoteClient, + remote_client_id: str, + ): + """Test that a file can be removed from the remote server.""" + file_api_class = remote_client.get_file_api(remote_client_id) + assert file_api_class is not None + + async with file_api_class() as file_api: + assert await file_api.unlink( + self.remote_temp_dir + "/test.txt" + ) == {"success": True} + assert await file_api.unlink( + self.remote_temp_dir + "/test2.txt" + ) == {"success": True} + + @AsyncDispatcher.dispatch(early_return=False) + async def test_rm_dir( + self, + remote_client: RemoteClient, + remote_client_id: str, + ): + """Test that a directory can be removed from the remote server.""" + file_api_class = remote_client.get_file_api(remote_client_id) + assert file_api_class is not None + + async with file_api_class() as file_api: + assert await file_api.rmdir(self.remote_temp_dir) == { + "success": True + } + + @AsyncDispatcher.dispatch(early_return=False) + async def test_ls_nonexistent_dir( + self, + remote_client: RemoteClient, + remote_client_id: str, + ): + """Test that listing a nonexistent directory raises an error.""" + file_api_class = remote_client.get_file_api(remote_client_id) + assert file_api_class is not None + + async with file_api_class() as file_api: + with pytest.raises(RemoteOSError) as exc_info: + await file_api.ls(self.remote_temp_dir) + + assert exc_info.value.errno == 2 # ENOENT: No such file or directory + +if __name__ == "__main__": + pytest.main() diff --git a/spyder/plugins/remoteclient/tests/test_plugin.py b/spyder/plugins/remoteclient/tests/test_plugin.py index 8f99095056b..f04cba91917 100644 --- a/spyder/plugins/remoteclient/tests/test_plugin.py +++ b/spyder/plugins/remoteclient/tests/test_plugin.py @@ -176,7 +176,7 @@ def test_wrong_version( qtbot, ): monkeypatch.setattr( - "spyder.plugins.remoteclient.api.client.SPYDER_REMOTE_MAX_VERSION", + "spyder.plugins.remoteclient.api.SPYDER_REMOTE_MAX_VERSION", "0.0.1", ) monkeypatch.setattr( diff --git a/spyder/plugins/remoteclient/utils/installation.py b/spyder/plugins/remoteclient/utils/installation.py index c871d7bc795..e0f970c0b2f 100644 --- a/spyder/plugins/remoteclient/utils/installation.py +++ b/spyder/plugins/remoteclient/utils/installation.py @@ -12,16 +12,19 @@ SERVER_ENV = "spyder-remote" PACKAGE_NAME = "spyder-remote-services" SCRIPT_URL = ( - f"https://raw.githubusercontent.com/spyder-ide/{PACKAGE_NAME}/master/scripts" + f"https://raw.githubusercontent.com/spyder-ide/" + f"{PACKAGE_NAME}/master/scripts" ) def get_installer_command(platform: str) -> str: if platform == "win": raise NotImplementedError("Windows is not supported yet") - + if running_remoteclient_tests(): - return '\n' # server should be aready installed in the test environment + return ( + "\n" # server should be aready installed in the test environment + ) return ( f'"${{SHELL}}" <(curl -L {SCRIPT_URL}/installer.sh) ' diff --git a/spyder/plugins/remoteclient/widgets/connectiondialog.py b/spyder/plugins/remoteclient/widgets/connectiondialog.py index 7deea0efd5a..3775910ba13 100644 --- a/spyder/plugins/remoteclient/widgets/connectiondialog.py +++ b/spyder/plugins/remoteclient/widgets/connectiondialog.py @@ -121,10 +121,7 @@ def set_text(self, reasons: ValidationReasons): ) if reasons.get("missing_info"): - text += ( - prefix - + _("There are missing fields on this page.") - ) + text += prefix + _("There are missing fields on this page.") self.setAlignment(Qt.AlignCenter if n_reasons == 1 else Qt.AlignLeft) self.setText(text) @@ -235,8 +232,8 @@ def create_connection_info_widget(self): ) intro_tip = TipWidget( tip_text=intro_tip_text, - icon=ima.icon('info_tip'), - hover_icon=ima.icon('info_tip_hover'), + icon=ima.icon("info_tip"), + hover_icon=ima.icon("info_tip_hover"), size=AppStyle.ConfigPageIconSize + 2, wrap_text=True, ) @@ -258,15 +255,13 @@ def create_connection_info_widget(self): # TODO: The config file method is not implemented yet, so we need to # disable it for now. methods = ( - (_('Password'), AuthenticationMethod.Password), - (_('Key file'), AuthenticationMethod.KeyFile), + (_("Password"), AuthenticationMethod.Password), + (_("Key file"), AuthenticationMethod.KeyFile), # (_('Configuration file'), AuthenticationMethod.ConfigFile), ) self._auth_methods = self.create_combobox( - _("Authentication method:"), - methods, - f"{self.host_id}/auth_method" + _("Authentication method:"), methods, f"{self.host_id}/auth_method" ) # Subpages @@ -333,7 +328,7 @@ def _create_common_elements(self, auth_method): suffix="", option=f"{self.host_id}/{auth_method}/port", min_=1, - max_=65535 + max_=65535, ) port.spinbox.setStyleSheet("margin-left: 5px") @@ -392,7 +387,7 @@ def _create_password_subpage(self): else _("Your password is saved securely by Spyder") ), status_icon=ima.icon("error"), - password=True + password=True, ) validation_label = ValidationLabel(self) @@ -402,9 +397,9 @@ def _create_password_subpage(self): password ) - self._validation_labels[ - AuthenticationMethod.Password - ] = validation_label + self._validation_labels[AuthenticationMethod.Password] = ( + validation_label + ) # Layout password_layout = QVBoxLayout() @@ -447,7 +442,7 @@ def _create_keyfile_subpage(self): if self.NEW_CONNECTION else _("Your passphrase is saved securely by Spyder") ), - password=True + password=True, ) validation_label = ValidationLabel(self) @@ -457,9 +452,9 @@ def _create_keyfile_subpage(self): keyfile ) - self._validation_labels[ - AuthenticationMethod.KeyFile - ] = validation_label + self._validation_labels[AuthenticationMethod.KeyFile] = ( + validation_label + ) # Layout keyfile_layout = QVBoxLayout() @@ -507,9 +502,9 @@ def _create_configfile_subpage(self): name, configfile, ] - self._validation_labels[ - AuthenticationMethod.ConfigFile - ] = validation_label + self._validation_labels[AuthenticationMethod.ConfigFile] = ( + validation_label + ) # Layout configfile_layout = QVBoxLayout() @@ -548,18 +543,18 @@ def _validate_address(self, address): """Validate if address introduced by users is correct.""" # Regex pattern for a valid domain name (simplified version) domain_pattern = ( - r'^([a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9]\.){1,}[a-zA-Z]{2,}$' + r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,61}[a-zA-Z0-9]\.){1,}[a-zA-Z]{2,}$" ) # Regex pattern for a valid IPv4 address - ipv4_pattern = r'^(\d{1,3}\.){3}\d{1,3}$' + ipv4_pattern = r"^(\d{1,3}\.){3}\d{1,3}$" # Regex pattern for a valid IPv6 address (simplified version) - ipv6_pattern = r'^([\da-fA-F]{1,4}:){7}[\da-fA-F]{1,4}$' + ipv6_pattern = r"^([\da-fA-F]{1,4}:){7}[\da-fA-F]{1,4}$" # Combined pattern to check all three formats combined_pattern = ( - f'({domain_pattern})|({ipv4_pattern})|({ipv6_pattern})' + f"({domain_pattern})|({ipv4_pattern})|({ipv6_pattern})" ) address_re = re.compile(combined_pattern) @@ -737,7 +732,7 @@ class ConnectionDialog(SidebarDialog): sig_connections_changed = Signal() def __init__(self, parent=None): - self.ICON = ima.icon('remote_server') + self.ICON = ima.icon("remote_server") super().__init__(parent) self._container = parent diff --git a/spyder/plugins/remoteclient/widgets/connectionstatus.py b/spyder/plugins/remoteclient/widgets/connectionstatus.py index bb51aa630c9..3703d6f7e6b 100644 --- a/spyder/plugins/remoteclient/widgets/connectionstatus.py +++ b/spyder/plugins/remoteclient/widgets/connectionstatus.py @@ -69,15 +69,15 @@ logging.INFO: "Info:", logging.WARNING: ( f'' - f'Warning:' + f"Warning:" ), logging.ERROR: ( f'' - f'Error:' + f"Error:" ), logging.CRITICAL: ( f'' - f'Critical:' + f"Critical:" ), } @@ -220,7 +220,7 @@ def _set_stylesheet(self): ) # Remove automatic indent added by Qt - important_labels_css.setValues(**{'qproperty-indent': '0'}) + important_labels_css.setValues(**{"qproperty-indent": "0"}) for label in [self._connection_label, self._message_label]: label.setStyleSheet(important_labels_css.toString()) @@ -242,8 +242,8 @@ def _set_stylesheet(self): # the same amount of attention to it. backgroundColor=SpyderPalette.COLOR_BACKGROUND_3, # Remove bottom rounded borders - borderBottomLeftRadius='0px', - borderBottomRightRadius='0px', + borderBottomLeftRadius="0px", + borderBottomRightRadius="0px", # This is necessary to align the label to the text above it marginLeft="2px", ) @@ -252,8 +252,8 @@ def _set_stylesheet(self): self._log_widget.css.QPlainTextEdit.setValues( # Remove these borders to make it appear attached to the top label borderTop="0px", - borderTopLeftRadius='0px', - borderTopRightRadius='0px', + borderTopLeftRadius="0px", + borderTopRightRadius="0px", # Match border color with the top label one and avoid to change # that color when the widget is given focus borderLeft=f"1px solid {SpyderPalette.COLOR_BACKGROUND_3}",