From 81b98d1d46a511b565da14d51d6dcfa2ad8c01ba Mon Sep 17 00:00:00 2001 From: Arjun Purushothaman Date: Tue, 25 Feb 2025 16:36:25 +0000 Subject: [PATCH] add type annotations and docstrings to devlib Most of the files are covered, but some of the instruments and unused platforms are not augmented --- .gitignore | 4 + devlib/_target_runner.py | 129 +- devlib/collector/__init__.py | 71 +- devlib/collector/dmesg.py | 138 +- devlib/collector/ftrace.py | 336 +- devlib/collector/logcat.py | 60 +- devlib/collector/perf.py | 265 +- devlib/collector/perfetto.py | 72 +- devlib/collector/screencapture.py | 83 +- devlib/collector/serial_trace.py | 87 +- devlib/collector/systrace.py | 78 +- devlib/connection.py | 594 +++- devlib/exception.py | 30 +- devlib/host.py | 283 +- devlib/instrument/__init__.py | 459 ++- devlib/instrument/acmecape.py | 9 +- devlib/instrument/arm_energy_probe.py | 19 +- devlib/instrument/daq.py | 129 +- devlib/instrument/frames.py | 78 +- devlib/instrument/hwmon.py | 33 +- devlib/module/__init__.py | 229 +- devlib/module/android.py | 64 +- devlib/module/biglittle.py | 262 +- devlib/module/cgroups.py | 373 ++- devlib/module/cgroups2.py | 377 ++- devlib/module/cooling.py | 44 +- devlib/module/cpufreq.py | 269 +- devlib/module/cpuidle.py | 120 +- devlib/module/devfreq.py | 102 +- devlib/module/gpufreq.py | 35 +- devlib/module/hotplug.py | 83 +- devlib/module/hwmon.py | 124 +- devlib/module/sched.py | 231 +- devlib/module/thermal.py | 135 +- devlib/module/vexpress.py | 169 +- devlib/platform/__init__.py | 51 +- devlib/platform/arm.py | 173 +- devlib/target.py | 4233 +++++++++++++++++++------ devlib/utils/android.py | 1171 +++++-- devlib/utils/annotation_helpers.py | 72 + devlib/utils/asyn.py | 618 +++- devlib/utils/misc.py | 767 ++++- devlib/utils/rendering.py | 37 +- devlib/utils/serial_port.py | 46 +- devlib/utils/ssh.py | 1207 +++++-- devlib/utils/types.py | 25 +- devlib/utils/version.py | 17 +- mypy.ini | 6 + py.typed | 0 setup.py | 20 +- tests/test_target.py | 6 +- 51 files changed, 10203 insertions(+), 3820 deletions(-) create mode 100644 devlib/utils/annotation_helpers.py create mode 100644 mypy.ini create mode 100644 py.typed diff --git a/.gitignore b/.gitignore index 291b5354d..6ae2b075d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,7 @@ devlib/bin/scripts/shutils doc/_build/ build/ dist/ +.venv/ +.vscode/ +venv/ +.history/ \ No newline at end of file diff --git a/devlib/_target_runner.py b/devlib/_target_runner.py index a45612354..77723ca42 100644 --- a/devlib/_target_runner.py +++ b/devlib/_target_runner.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,12 +20,22 @@ import logging import os import time + from platform import machine +from typing import Optional, cast, Protocol, TYPE_CHECKING, TypedDict, Union +from typing_extensions import NotRequired, LiteralString +if TYPE_CHECKING: + from _typeshed import StrPath, BytesPath + from devlib.platform import Platform +else: + StrPath = str + BytesPath = bytes from devlib.exception import (TargetStableError, HostError) -from devlib.target import LinuxTarget +from devlib.target import LinuxTarget, Target from devlib.utils.misc import get_subprocess, which from devlib.utils.ssh import SshConnection +from devlib.utils.annotation_helpers import SubprocessCommand, SshUserConnectionSettings class TargetRunner: @@ -40,15 +50,22 @@ class TargetRunner: """ def __init__(self, - target): + target: Target) -> None: self.target = target - self.logger = logging.getLogger(self.__class__.__name__) def __enter__(self): + """ + Enter the context for this runner. + :return: This runner instance. + :rtype: TargetRunner + """ return self def __exit__(self, *_): + """ + Exit the context for this runner. + """ pass @@ -77,20 +94,20 @@ class SubprocessTargetRunner(TargetRunner): """ def __init__(self, - runner_cmd, - target, - connect=True, - boot_timeout=60): + runner_cmd: SubprocessCommand, + target: Target, + connect: bool = True, + boot_timeout: int = 60): super().__init__(target=target) - self.boot_timeout = boot_timeout + self.boot_timeout: int = boot_timeout self.logger.info('runner_cmd: %s', runner_cmd) try: self.runner_process = get_subprocess(runner_cmd) except Exception as ex: - raise HostError(f'Error while running "{runner_cmd}": {ex}') from ex + raise HostError(f'Error while running "{runner_cmd!r}": {ex}') from ex if connect: self.wait_boot_complete() @@ -107,16 +124,16 @@ def __exit__(self, *_): self.terminate() - def wait_boot_complete(self): + def wait_boot_complete(self) -> None: """ - Wait for target OS to finish boot up and become accessible over SSH in at most - ``SubprocessTargetRunner.boot_timeout`` seconds. + Wait for the target OS to finish booting and become accessible within + :attr:`boot_timeout` seconds. - :raises TargetStableError: In case of timeout. + :raises TargetStableError: If the target is inaccessible after the timeout. """ start_time = time.time() - elapsed = 0 + elapsed: float = 0.0 while self.boot_timeout >= elapsed: try: self.target.connect(timeout=self.boot_timeout - elapsed) @@ -132,9 +149,9 @@ def wait_boot_complete(self): self.terminate() raise TargetStableError(f'Target is inaccessible for {self.boot_timeout} seconds!') - def terminate(self): + def terminate(self) -> None: """ - Terminate ``SubprocessTargetRunner.runner_process``. + Terminate the subprocess associated with this runner. """ self.logger.debug('Killing target runner...') @@ -150,7 +167,7 @@ class NOPTargetRunner(TargetRunner): :type target: Target """ - def __init__(self, target): + def __init__(self, target: Target) -> None: super().__init__(target=target) def __enter__(self): @@ -159,11 +176,63 @@ def __enter__(self): def __exit__(self, *_): pass - def terminate(self): + def terminate(self) -> None: """ Nothing to terminate for NOP target runners. Defined to be compliant with other runners (e.g., ``SubprocessTargetRunner``). """ + pass + + +QEMUTargetUserSettings = TypedDict("QEMUTargetUserSettings", { + 'kernel_image': str, + 'arch': NotRequired[str], + 'cpu_type': NotRequired[str], + 'initrd_image': str, + 'mem_size': NotRequired[int], + 'num_cores': NotRequired[int], + 'num_threads': NotRequired[int], + 'cmdline': NotRequired[str], + 'enable_kvm': NotRequired[bool], +}) + +QEMUTargetRunnerSettings = TypedDict("QEMUTargetRunnerSettings", { + 'kernel_image': str, + 'arch': str, + 'cpu_type': str, + 'initrd_image': str, + 'mem_size': int, + 'num_cores': int, + 'num_threads': int, + 'cmdline': str, + 'enable_kvm': bool, +}) + + +SshConnectionSettings = TypedDict("SshConnectionSettings", { + 'username': str, + 'password': str, + 'keyfile': Optional[Union[LiteralString, StrPath, BytesPath]], + 'host': str, + 'port': int, + 'timeout': float, + 'platform': 'Platform', + 'sudo_cmd': str, + 'strict_host_check': bool, + 'use_scp': bool, + 'poll_transfers': bool, + 'start_transfer_poll_delay': int, + 'total_transfer_timeout': int, + 'transfer_poll_period': int, +}) + + +class QEMUTargetRunnerTargetFactory(Protocol): + """ + Protocol for Lambda function for creating :class:`Target` based object. + """ + def __call__(self, *, connect: bool, conn_cls, connection_settings: SshConnectionSettings) -> Target: + ... class QEMUTargetRunner(SubprocessTargetRunner): @@ -177,7 +246,7 @@ class QEMUTargetRunner(SubprocessTargetRunner): * ``arch``: Architecture type. Defaults to ``aarch64``. - * ``cpu_types``: List of CPU ids for QEMU. The list only contains ``cortex-a72`` by + * ``cpu_type``: List of CPU ids for QEMU. The list only contains ``cortex-a72`` by default. This parameter is valid for Arm architectures only. * ``initrd_image``: This points to the location of initrd image (e.g., @@ -212,21 +281,25 @@ class QEMUTargetRunner(SubprocessTargetRunner): """ def __init__(self, - qemu_settings, - connection_settings=None, - make_target=LinuxTarget, - **args): + qemu_settings: QEMUTargetUserSettings, + connection_settings: Optional[SshUserConnectionSettings] = None, + make_target: QEMUTargetRunnerTargetFactory = cast(QEMUTargetRunnerTargetFactory, LinuxTarget), + **args) -> None: - self.connection_settings = { + default_connection_settings = { 'host': '127.0.0.1', 'port': 8022, 'username': 'root', 'password': 'root', 'strict_host_check': False, } - self.connection_settings = {**self.connection_settings, **(connection_settings or {})} - qemu_args = { + self.connection_settings: SshConnectionSettings = cast(SshConnectionSettings, { + **default_connection_settings, + **(connection_settings or {}) + }) + + qemu_default_args = { 'arch': 'aarch64', 'cpu_type': 'cortex-a72', 'mem_size': 512, @@ -235,7 +308,7 @@ def __init__(self, 'cmdline': 'console=ttyAMA0', 'enable_kvm': True, } - qemu_args = {**qemu_args, **qemu_settings} + qemu_args: QEMUTargetRunnerSettings = cast(QEMUTargetRunnerSettings, {**qemu_default_args, **qemu_settings}) qemu_executable = f'qemu-system-{qemu_args["arch"]}' qemu_path = which(qemu_executable) diff --git a/devlib/collector/__init__.py b/devlib/collector/__init__.py index 0bc22ff07..fafbb7d3f 100644 --- a/devlib/collector/__init__.py +++ b/devlib/collector/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2015 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,27 +16,62 @@ import logging from devlib.utils.types import caseless_string +from typing import TYPE_CHECKING, Optional, List +if TYPE_CHECKING: + from devlib.target import Target + class CollectorBase(object): + """ + The ``Collector`` API provide a consistent way of collecting arbitrary data from + a target. Data is collected via an instance of a class derived from :class:`CollectorBase`. - def __init__(self, target): + :param target: The devlib Target from which data will be collected. + """ + def __init__(self, target: 'Target'): self.target = target - self.logger = logging.getLogger(self.__class__.__name__) - self.output_path = None - - def reset(self): + self.logger: logging.Logger = logging.getLogger(self.__class__.__name__) + self.output_path: Optional[str] = None + + def reset(self) -> None: + """ + This can be used to configure a collector for collection. This must be invoked + before :meth:`start()` is called to begin collection. + """ pass - def start(self): + def start(self) -> None: + """ + Starts collecting from the target. + """ pass def stop(self): + """ + Stops collecting from target. Must be called after + :func:`start()`. + """ pass - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: + """ + Configure the output path for the particular collector. This will be either + a directory or file path which will be used when storing the data. Please see + the individual Collector documentation for more information. + + :param output_path: The path (file or directory) to which data will be saved. + """ self.output_path = output_path - def get_data(self): + def get_data(self) -> 'CollectorOutput': + """ + The collected data will be return via the previously specified output_path. + This method will return a :class:`CollectorOutput` object which is a subclassed + list object containing individual ``CollectorOutputEntry`` objects with details + about the individual output entry. + + :raises RuntimeError: If ``output_path`` has not been set. + """ return CollectorOutput() def __enter__(self): @@ -47,11 +82,25 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.stop() + class CollectorOutputEntry(object): + """ + This object is designed to allow for the output of a collector to be processed + generically. The object will behave as a regular string containing the path to + underlying output path and can be used directly in ``os.path`` operations. + + .. attribute:: CollectorOutputEntry.path + + The file path for the corresponding output item. + + .. attribute:: CollectorOutputEntry.path_kind - path_kinds = ['file', 'directory'] + :param path: The file path of the collected output data. + :param path_kind: The type of output. Must be one of ``file`` or ``directory``. + """ + path_kinds: List[str] = ['file', 'directory'] - def __init__(self, path, path_kind): + def __init__(self, path: str, path_kind: str): self.path = path path_kind = caseless_string(path_kind) diff --git a/devlib/collector/dmesg.py b/devlib/collector/dmesg.py index 06676aaa6..82d32d556 100644 --- a/devlib/collector/dmesg.py +++ b/devlib/collector/dmesg.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,8 +23,12 @@ from devlib.exception import TargetStableError from devlib.utils.misc import memoized +from typing import (Pattern, Optional, Match, Tuple, List, + Union, Any, Generator, TYPE_CHECKING) +if TYPE_CHECKING: + from devlib.target import Target -_LOGGER = logging.getLogger('dmesg') +_LOGGER: logging.Logger = logging.getLogger('dmesg') class KernelLogEntry(object): @@ -49,11 +53,12 @@ class KernelLogEntry(object): :type line_nr: int """ - _TIMESTAMP_MSG_REGEX = re.compile(r'\[(.*?)\] (.*)') - _RAW_LEVEL_REGEX = re.compile(r'<([0-9]+)>(.*)') - _PRETTY_LEVEL_REGEX = re.compile(r'\s*([a-z]+)\s*:([a-z]+)\s*:\s*(.*)') + _TIMESTAMP_MSG_REGEX: Pattern[str] = re.compile(r'\[(.*?)\] (.*)') + _RAW_LEVEL_REGEX: Pattern[str] = re.compile(r'<([0-9]+)>(.*)') + _PRETTY_LEVEL_REGEX: Pattern[str] = re.compile(r'\s*([a-z]+)\s*:([a-z]+)\s*:\s*(.*)') - def __init__(self, facility, level, timestamp, msg, line_nr=0): + def __init__(self, facility: Optional[str], level: str, + timestamp: timedelta, msg: str, line_nr: int = 0): self.facility = facility self.level = level self.timestamp = timestamp @@ -61,7 +66,7 @@ def __init__(self, facility, level, timestamp, msg, line_nr=0): self.line_nr = line_nr @classmethod - def from_str(cls, line, line_nr=0): + def from_str(cls, line: str, line_nr: int = 0) -> 'KernelLogEntry': """ Parses a "dmesg --decode" output line, formatted as following: kern :err : [3618282.310743] nouveau 0000:01:00.0: systemd-logind[988]: nv50cal_space: -16 @@ -69,10 +74,18 @@ def from_str(cls, line, line_nr=0): Or the more basic output given by "dmesg -r": <3>[3618282.310743] nouveau 0000:01:00.0: systemd-logind[988]: nv50cal_space: -16 + :param line: A string from dmesg. + :type line: str + :param line_nr: The line number in the overall log. + :type line_nr: int + :raises ValueError: If the line format is invalid. + :return: A constructed :class:`KernelLogEntry`. + :rtype: KernelLogEntry + """ - def parse_raw_level(line): - match = cls._RAW_LEVEL_REGEX.match(line) + def parse_raw_level(line: str) -> Tuple[str, Union[str, Any]]: + match: Optional[Match[str]] = cls._RAW_LEVEL_REGEX.match(line) if not match: raise ValueError(f'dmesg entry format not recognized: {line}') level, remainder = match.groups() @@ -81,15 +94,15 @@ def parse_raw_level(line): level = levels[int(level) % len(levels)] return level, remainder - def parse_pretty_level(line): - match = cls._PRETTY_LEVEL_REGEX.match(line) + def parse_pretty_level(line: str) -> Tuple[str, str, str]: + match: Optional[Match[str]] = cls._PRETTY_LEVEL_REGEX.match(line) if not match: raise ValueError(f'dmesg entry pretty format not recognized: {line}') facility, level, remainder = match.groups() return facility, level, remainder - def parse_timestamp_msg(line): - match = cls._TIMESTAMP_MSG_REGEX.match(line) + def parse_timestamp_msg(line: str) -> Tuple[timedelta, str]: + match: Optional[Match[str]] = cls._TIMESTAMP_MSG_REGEX.match(line) if not match: raise ValueError(f'dmesg entry timestamp format not recognized: {line}') timestamp, msg = match.groups() @@ -101,7 +114,7 @@ def parse_timestamp_msg(line): # If we can parse the raw prio directly, that is a basic line try: level, remainder = parse_raw_level(line) - facility = None + facility: Optional[str] = None except ValueError: facility, level, remainder = parse_pretty_level(line) @@ -116,21 +129,26 @@ def parse_timestamp_msg(line): ) @classmethod - def from_dmesg_output(cls, dmesg_out, error=None): + def from_dmesg_output(cls, dmesg_out: Optional[str], error: Optional[str] = None) -> Generator['KernelLogEntry', Any, None]: """ Return a generator of :class:`KernelLogEntry` for each line of the output of dmesg command. + :param dmesg_out: The dmesg output to parse. + :type dmesg_out - str + :param error: If ``"raise"`` or ``None``, an exception will be raised if a parsing error occurs. If ``"warn"``, it will be logged at WARNING level. If ``"ignore"``, it will be ignored. If a callable is passed, the exception will be passed to it. :type error: str or None or typing.Callable[[BaseException], None] + :return: A generator of parsed :class:`KernelLogEntry` objects. + :rtype: Generator[KernelLogEntry, Any, None] .. note:: The same restrictions on the dmesg output format as for :meth:`from_str` apply. """ - for i, line in enumerate(dmesg_out.splitlines()): + for i, line in enumerate(dmesg_out.splitlines() if dmesg_out else ''): if line.strip(): try: yield cls.from_str(line, line_nr=i) @@ -160,6 +178,9 @@ class DmesgCollector(CollectorBase): """ Dmesg output collector. + :param target: The devlib Target (must be rooted). + :type target:Target + :param level: Minimum log level to enable. All levels that are more critical will be collected as well. :type level: str @@ -172,13 +193,15 @@ class DmesgCollector(CollectorBase): so it's not recommended unless it's really necessary. :type empty_buffer: bool + :param parse_error: A string to be appended to error lines if parse fails. + :type parse_error : str .. warning:: If BusyBox dmesg is used, facility and level will be ignored, and the parsed entries will also lack that information. """ # taken from "dmesg --help" # This list needs to be ordered by priority - LOG_LEVELS = [ + LOG_LEVELS: List[str] = [ "emerg", # system is unusable "alert", # action must be taken immediately "crit", # critical conditions @@ -189,13 +212,15 @@ class DmesgCollector(CollectorBase): "debug", # debug-level messages ] - def __init__(self, target, level=LOG_LEVELS[-1], facility='kern', empty_buffer=False, parse_error=None): + def __init__(self, target: 'Target', level: str = LOG_LEVELS[-1], + facility: str = 'kern', empty_buffer: bool = False, + parse_error: Optional[str] = None): super(DmesgCollector, self).__init__(target) if not target.is_rooted: raise TargetStableError('Cannot collect dmesg on non-rooted target') - self.output_path = None + self.output_path: Optional[str] = None if level not in self.LOG_LEVELS: raise ValueError('level needs to be one of: {}'.format( @@ -207,42 +232,48 @@ def __init__(self, target, level=LOG_LEVELS[-1], facility='kern', empty_buffer=F # e.g. busybox's dmesg or the one shipped on some Android versions # (toybox). Note: BusyBox dmesg does not support -h, but will still # print the help with an exit code of 1 - help_ = self.target.execute('dmesg -h', check_exit_code=False) - self.basic_dmesg = not all( + help_: str = self.target.execute('dmesg -h', check_exit_code=False) + self.basic_dmesg: bool = not all( opt in help_ for opt in ('--facility', '--force-prefix', '--decode', '--level') ) self.facility = facility try: - needs_root = target.read_sysctl('kernel.dmesg_restrict') + needs_root: bool = target.read_sysctl('kernel.dmesg_restrict') except ValueError: needs_root = True else: needs_root = bool(int(needs_root)) self.needs_root = needs_root - self._begin_timestamp = None - self.empty_buffer = empty_buffer - self._dmesg_out = None - self._parse_error = parse_error + self._begin_timestamp: Optional[timedelta] = None + self.empty_buffer: bool = empty_buffer + self._dmesg_out: Optional[str] = None + self._parse_error: Optional[str] = parse_error @property - def dmesg_out(self): - out = self._dmesg_out + def dmesg_out(self) -> Optional[str]: + """ + get the dmesg output + """ + out: Optional[str] = self._dmesg_out if out is None: return None else: try: - entry = self.entries[0] + entry: KernelLogEntry = self.entries[0] except IndexError: return '' else: - i = entry.line_nr + i: int = entry.line_nr return '\n'.join(out.splitlines()[i:]) @property - def entries(self): + def entries(self) -> List[KernelLogEntry]: + """ + get the entries as a list of class:KernelLogEntry + """ return self._get_entries( self._dmesg_out, self._begin_timestamp, @@ -250,14 +281,15 @@ def entries(self): ) @memoized - def _get_entries(self, dmesg_out, timestamp, error): - entries = KernelLogEntry.from_dmesg_output(dmesg_out, error=error) - entries = list(entries) + def _get_entries(self, dmesg_out: Optional[str], timestamp: Optional[timedelta], + error: Optional[str]) -> List[KernelLogEntry]: + entry_ = KernelLogEntry.from_dmesg_output(dmesg_out, error=error) + entries = list(entry_) if timestamp is None: return entries else: try: - first = entries[0] + first: KernelLogEntry = entries[0] except IndexError: pass else: @@ -273,14 +305,17 @@ def _get_entries(self, dmesg_out, timestamp, error): if entry.timestamp > timestamp ] - def _get_output(self): - levels_list = list(takewhile( + def _get_output(self) -> None: + """ + get the dmesg collector output into _dmesg_out local variable + """ + levels_list: List[str] = list(takewhile( lambda level: level != self.level, self.LOG_LEVELS )) levels_list.append(self.level) if self.basic_dmesg: - cmd = 'dmesg -r' + cmd: str = 'dmesg -r' else: cmd = 'dmesg --facility={facility} --force-prefix --decode --level={levels}'.format( levels=','.join(levels_list), @@ -289,10 +324,17 @@ def _get_output(self): self._dmesg_out = self.target.execute(cmd, as_root=self.needs_root) - def reset(self): + def reset(self) -> None: + """ + Reset the collector's internal state (e.g., cached dmesg output). + """ self._dmesg_out = None - def start(self): + def start(self) -> None: + """ + Start collecting dmesg logs. If ``empty_buffer`` is true, clear them first. + :raises TargetStableError: If the target is not rooted. + """ # If the buffer is emptied on start(), it does not matter as we will # not end up with entries dating from before start() if self.empty_buffer: @@ -307,13 +349,23 @@ def start(self): else: self._begin_timestamp = entry.timestamp - def stop(self): + def stop(self) -> None: + """ + Stop collecting logs and retrieve the latest dmesg output. + """ self._get_output() - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Write the dmesg output to :attr:`output_path` and return a :class:`CollectorOutput`. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing the saved dmesg file. + :rtype: CollectorOutput + """ if self.output_path is None: raise RuntimeError("Output path was not set.") with open(self.output_path, 'wt') as f: diff --git a/devlib/collector/ftrace.py b/devlib/collector/ftrace.py index 9a887efdf..abe9081b2 100644 --- a/devlib/collector/ftrace.py +++ b/devlib/collector/ftrace.py @@ -1,4 +1,4 @@ -# Copyright 2015-2018 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,12 +29,18 @@ from devlib.utils.misc import check_output, which, memoized from devlib.utils.asyn import asyncf - -TRACE_MARKER_START = 'TRACE_MARKER_START' -TRACE_MARKER_STOP = 'TRACE_MARKER_STOP' -OUTPUT_TRACE_FILE = 'trace.dat' -OUTPUT_PROFILE_FILE = 'trace_stat.dat' -DEFAULT_EVENTS = [ +from devlib.module.cpufreq import CpufreqModule +from devlib.module.cpuidle import Cpuidle +from typing import (cast, List, Pattern, TYPE_CHECKING, Optional, + Dict, Union, Match) +if TYPE_CHECKING: + from devlib.target import Target + +TRACE_MARKER_START: str = 'TRACE_MARKER_START' +TRACE_MARKER_STOP: str = 'TRACE_MARKER_STOP' +OUTPUT_TRACE_FILE: str = 'trace.dat' +OUTPUT_PROFILE_FILE: str = 'trace_stat.dat' +DEFAULT_EVENTS: List[str] = [ 'cpu_frequency', 'cpu_idle', 'sched_migrate_task', @@ -45,32 +51,70 @@ 'sched_wakeup', 'sched_wakeup_new', ] -TIMEOUT = 180 +TIMEOUT: int = 180 # Regexps for parsing of function profiling data -CPU_RE = re.compile(r' Function \(CPU([0-9]+)\)') -STATS_RE = re.compile(r'([^ ]*) +([0-9]+) +([0-9.]+) us +([0-9.]+) us +([0-9.]+) us') +CPU_RE: Pattern[str] = re.compile(r' Function \(CPU([0-9]+)\)') +STATS_RE: Pattern[str] = re.compile(r'([^ ]*) +([0-9]+) +([0-9.]+) us +([0-9.]+) us +([0-9.]+) us') -class FtraceCollector(CollectorBase): +class FtraceCollector(CollectorBase): + """ + Collector using ftrace to trace kernel events and functions. + + :param target: The devlib Target (must be rooted). + :type target: Target + :param events: A list of events to trace (defaults to `DEFAULT_EVENTS`). + :type events: List(str) + :param functions: A list of functions to trace, if function tracing is used. + :type functions: List(str) + :param tracer: The tracer to use (e.g., 'function_graph'), or ``None``. + :type tracer: str + :param trace_children_functions: If ``True``, trace child functions as well. + :type trace_children_functions: bool + :param buffer_size: The size of the trace buffer in KB. + :type buffer_size: int + :param top_buffer_size: The top-level buffer size in KB, if different. + :type top_buffer_size: int + :param buffer_size_step: The step size for increasing the buffer. + :type buffer_size_step: int + :param tracing_path: The path to the tracefs mount point, if not auto-detected. + :type tracing_path: str + :param automark: If ``True``, automatically mark start and stop in the trace. + :type automark: bool + :param autoreport: If ``True``, generate a textual trace report automatically. + :type autoreport: bool + :param autoview: If ``True``, open KernelShark for a graphical view of the trace. + :type autoview: bool + :param no_install: If ``True``, assume trace-cmd is already installed on target. + :type no_install: bool + :param strict: If ``True``, raise errors if requested events/functions are not available. + :type strict: bool + :param report_on_target: If ``True``, generate the trace report on the target side. + :type report_on_target: bool + :param trace_clock: The clock source for the trace. + :type trace_clock: bool + :param saved_cmdlines_nr: The number of cmdlines to save in the trace buffer. + :type saved_cmdlines_nr: int + """ # pylint: disable=too-many-locals,too-many-branches,too-many-statements - def __init__(self, target, - events=None, - functions=None, - tracer=None, - trace_children_functions=False, - buffer_size=None, - top_buffer_size=None, - buffer_size_step=1000, - tracing_path=None, - automark=True, - autoreport=True, - autoview=False, - no_install=False, - strict=False, - report_on_target=False, - trace_clock='local', - saved_cmdlines_nr=4096, + def __init__(self, target: 'Target', + events: Optional[List[str]] = None, + functions: Optional[List[str]] = None, + tracer: Optional[str] = None, + trace_children_functions: bool = False, + buffer_size: Optional[int] = None, + top_buffer_size: Optional[int] = None, + buffer_size_step: int = 1000, + tracing_path: Optional[str] = None, + automark: bool = True, + autoreport: bool = True, + autoview: bool = False, + no_install: bool = False, + strict: bool = False, + report_on_target: bool = False, + trace_clock: str = 'local', + saved_cmdlines_nr: int = 4096, ): super(FtraceCollector, self).__init__(target) self.events = events if events is not None else DEFAULT_EVENTS @@ -79,39 +123,39 @@ def __init__(self, target, self.trace_children_functions = trace_children_functions self.buffer_size = buffer_size self.top_buffer_size = top_buffer_size - self.tracing_path = self._resolve_tracing_path(target, tracing_path) + self.tracing_path: str = self._resolve_tracing_path(target, tracing_path) self.automark = automark self.autoreport = autoreport self.autoview = autoview self.strict = strict self.report_on_target = report_on_target - self.target_output_file = target.path.join(self.target.working_directory, OUTPUT_TRACE_FILE) - text_file_name = target.path.splitext(OUTPUT_TRACE_FILE)[0] + '.txt' - self.target_text_file = target.path.join(self.target.working_directory, text_file_name) - self.output_path = None - self.target_binary = None - self.host_binary = None - self.start_time = None - self.stop_time = None - self.event_string = None - self.function_string = None + self.target_output_file: str = target.path.join(self.target.working_directory, OUTPUT_TRACE_FILE) if target.path else '' + text_file_name: str = target.path.splitext(OUTPUT_TRACE_FILE)[0] + '.txt' if target.path else '' + self.target_text_file: str = target.path.join(self.target.working_directory, text_file_name) if target.path else '' + self.output_path: Optional[str] = None + self.target_binary: Optional[str] = None + self.host_binary: Optional[str] = None + self.start_time: Optional[float] = None + self.stop_time: Optional[float] = None + self.event_string: Optional[str] = None + self.function_string: Optional[str] = None self.trace_clock = trace_clock self.saved_cmdlines_nr = saved_cmdlines_nr - self._reset_needed = True + self._reset_needed: bool = True # pylint: disable=bad-whitespace # Setup tracing paths - self.available_events_file = self.target.path.join(self.tracing_path, 'available_events') - self.available_functions_file = self.target.path.join(self.tracing_path, 'available_filter_functions') - self.current_tracer_file = self.target.path.join(self.tracing_path, 'current_tracer') - self.function_profile_file = self.target.path.join(self.tracing_path, 'function_profile_enabled') - self.marker_file = self.target.path.join(self.tracing_path, 'trace_marker') - self.ftrace_filter_file = self.target.path.join(self.tracing_path, 'set_ftrace_filter') - self.available_tracers_file = self.target.path.join(self.tracing_path, 'available_tracers') - self.kprobe_events_file = self.target.path.join(self.tracing_path, 'kprobe_events') + self.available_events_file: str = self.target.path.join(self.tracing_path, 'available_events') if self.target.path else '' + self.available_functions_file: str = self.target.path.join(self.tracing_path, 'available_filter_functions') if self.target.path else '' + self.current_tracer_file: str = self.target.path.join(self.tracing_path, 'current_tracer') if self.target.path else '' + self.function_profile_file: str = self.target.path.join(self.tracing_path, 'function_profile_enabled') if self.target.path else '' + self.marker_file: str = self.target.path.join(self.tracing_path, 'trace_marker') if self.target.path else '' + self.ftrace_filter_file: str = self.target.path.join(self.tracing_path, 'set_ftrace_filter') if self.target.path else '' + self.available_tracers_file: str = self.target.path.join(self.tracing_path, 'available_tracers') if self.target.path else '' + self.kprobe_events_file: str = self.target.path.join(self.tracing_path, 'kprobe_events') if self.target.path else '' self.host_binary = which('trace-cmd') - self.kernelshark = which('kernelshark') + self.kernelshark: Optional[str] = which('kernelshark') if not self.target.is_rooted: raise TargetStableError('trace-cmd instrument cannot be used on an unrooted device.') @@ -120,7 +164,7 @@ def __init__(self, target, if self.autoview and self.kernelshark is None: raise HostError('kernelshark binary must be installed on the host if autoview=True.') if not no_install: - host_file = os.path.join(PACKAGE_BIN_DIRECTORY, self.target.abi, 'trace-cmd') + host_file = os.path.join(PACKAGE_BIN_DIRECTORY, self.target.abi or '', 'trace-cmd') self.target_binary = self.target.install(host_file) else: if not self.target.is_installed('trace-cmd'): @@ -128,26 +172,32 @@ def __init__(self, target, self.target_binary = 'trace-cmd' # Validate required events to be traced - def event_to_regex(event): + def event_to_regex(event: str) -> Pattern[str]: + """ + to use for finding events to be traced + """ if not event.startswith('*'): event = '*' + event return re.compile(event.replace('*', '.*')) - def event_is_in_list(event, events): + def event_is_in_list(event: str, events: List[str]) -> bool: + """ + true if event is in the list of events + """ return any( event_to_regex(event).match(_event) for _event in events ) - available_events = self.available_events - unavailable_events = [ + available_events: List[str] = self.available_events + unavailable_events: List[str] = [ event for event in self.events if not event_is_in_list(event, available_events) ] if unavailable_events: - message = 'Events not available for tracing: {}'.format( + message: str = 'Events not available for tracing: {}'.format( ', '.join(unavailable_events) ) if self.strict: @@ -155,7 +205,7 @@ def event_is_in_list(event, events): else: self.target.logger.warning(message) - selected_events = sorted(set(self.events) - set(unavailable_events)) + selected_events: List[str] = sorted(set(self.events) - set(unavailable_events)) if self.tracer and self.tracer not in self.available_tracers: raise TargetStableError('Unsupported tracer "{}". Available tracers: {}'.format( @@ -164,7 +214,7 @@ def event_is_in_list(event, events): # Check for function tracing support if self.functions: # Validate required functions to be traced - selected_functions = [] + selected_functions: List[str] = [] for function in self.functions: if function not in self.available_functions: message = 'Function [{}] not available for tracing/profiling'.format(function) @@ -177,7 +227,7 @@ def event_is_in_list(event, events): # Function profiling if self.tracer is None: if not self.target.file_exists(self.function_profile_file): - raise TargetStableError('Function profiling not supported. '\ + raise TargetStableError('Function profiling not supported. ' 'A kernel build with CONFIG_FUNCTION_PROFILER enable is required') self.function_string = _build_trace_functions(selected_functions) # If function profiling is enabled we always need at least one event. @@ -198,14 +248,20 @@ def event_is_in_list(event, events): self.event_string = _build_trace_events(selected_events) @classmethod - def _resolve_tracing_path(cls, target, path): + def _resolve_tracing_path(cls, target: 'Target', path: Optional[str]) -> str: + """ + Find path for tracefs + """ if path is None: return cls.find_tracing_path(target) else: return path @classmethod - def find_tracing_path(cls, target): + def find_tracing_path(cls, target: 'Target') -> str: + """ + get tracefs path from mount point + """ fs_list = [ fs.mount_point for fs in target.list_file_systems() @@ -219,14 +275,14 @@ def find_tracing_path(cls, target): @property @memoized - def available_tracers(self): + def available_tracers(self) -> List[str]: """ List of ftrace tracers supported by the target's kernel. """ return self.target.read_value(self.available_tracers_file).split(' ') @property - def available_events(self): + def available_events(self) -> List[str]: """ List of ftrace events supported by the target's kernel. """ @@ -234,16 +290,16 @@ def available_events(self): @property @memoized - def available_functions(self): + def available_functions(self) -> List[str]: """ List of functions whose tracing/profiling is supported by the target's kernel. """ return self.target.read_value(self.available_functions_file).splitlines() - def reset(self): + def reset(self) -> None: # Save kprobe events try: - kprobe_events = self.target.read_value(self.kprobe_events_file) + kprobe_events: Optional[str] = self.target.read_value(self.kprobe_events_file) except TargetStableError: kprobe_events = None @@ -254,10 +310,10 @@ def reset(self): # parameter, but unfortunately some events still end up there (e.g. # print event). So we still need to set that size, otherwise the buffer # might be too small and some event lost. - top_buffer_size = self.top_buffer_size if self.top_buffer_size else self.buffer_size + top_buffer_size: Optional[int] = self.top_buffer_size if self.top_buffer_size else self.buffer_size if top_buffer_size: self.target.write_value( - self.target.path.join(self.tracing_path, 'buffer_size_kb'), + self.target.path.join(self.tracing_path, 'buffer_size_kb') if self.target.path else '', top_buffer_size, verify=False ) @@ -271,17 +327,22 @@ def reset(self): self._reset_needed = False @asyncf - async def start(self): + async def start(self) -> None: + """ + Start capturing ftrace events according to the selected events/functions. + + :raises TargetStableError: If the target is unrooted or tracing setup fails. + """ self.start_time = time.time() if self._reset_needed: self.reset() if self.tracer is not None and 'function' in self.tracer: - tracecmd_functions = self.function_string + tracecmd_functions: Optional[str] = self.function_string else: tracecmd_functions = '' - tracer_string = '-p {}'.format(self.tracer) if self.tracer else '' + tracer_string: str = '-p {}'.format(self.tracer) if self.tracer else '' # Ensure kallsyms contains addresses if possible, so that function the # collected trace contains enough data for pretty printing @@ -304,33 +365,35 @@ async def start(self): self.mark_start() if 'cpufreq' in self.target.modules: self.logger.debug('Trace CPUFreq frequencies') - self.target.cpufreq.trace_frequencies() + cast(CpufreqModule, self.target.cpufreq).trace_frequencies() if 'cpuidle' in self.target.modules: self.logger.debug('Trace CPUIdle states') - self.target.cpuidle.perturb_cpus() + cast(Cpuidle, self.target.cpuidle).perturb_cpus() # Enable kernel function profiling if self.functions and self.tracer is None: target = self.target await target.async_manager.concurrently( - execute.asyn('echo nop > {}'.format(self.current_tracer_file), + target.execute.asyn('echo nop > {}'.format(self.current_tracer_file), as_root=True), - execute.asyn('echo 0 > {}'.format(self.function_profile_file), + target.execute.asyn('echo 0 > {}'.format(self.function_profile_file), + as_root=True), # type: ignore + target.execute.asyn('echo {} > {}'.format(self.function_string, self.ftrace_filter_file), as_root=True), - execute.asyn('echo {} > {}'.format(self.function_string, self.ftrace_filter_file), - as_root=True), - execute.asyn('echo 1 > {}'.format(self.function_profile_file), + target.execute.asyn('echo 1 > {}'.format(self.function_profile_file), as_root=True), ) - - def stop(self): + def stop(self) -> None: + """ + Stop capturing ftrace events. + """ # Disable kernel function profiling if self.functions and self.tracer is None: self.target.execute('echo 0 > {}'.format(self.function_profile_file), as_root=True) if 'cpufreq' in self.target.modules: self.logger.debug('Trace CPUFreq frequencies') - self.target.cpufreq.trace_frequencies() + cast(CpufreqModule, self.target.cpufreq).trace_frequencies() self.stop_time = time.time() if self.automark: self.mark_stop() @@ -338,22 +401,30 @@ def stop(self): timeout=TIMEOUT, as_root=True) self._reset_needed = True - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: if os.path.isdir(output_path): output_path = os.path.join(output_path, os.path.basename(self.target_output_file)) self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Pull the captured trace data from the target, optionally generate a report, + and return a :class:`CollectorOutput`. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing ftrace data. + :rtype: CollectorOutput + """ if self.output_path is None: raise RuntimeError("Output path was not set.") self.target.execute('{0} extract -B devlib -o {1}; chmod 666 {1}'.format(self.target_binary, - self.target_output_file), + self.target_output_file), timeout=TIMEOUT, as_root=True) # The size of trace.dat will depend on how long trace-cmd was running. # Therefore timout for the pull command must also be adjusted # accordingly. - pull_timeout = 10 * (self.stop_time - self.start_time) + pull_timeout: float = 10 * (cast(float, self.stop_time) - cast(float, self.start_time)) self.target.pull(self.target_output_file, self.output_path, timeout=pull_timeout) output = CollectorOutput() if not os.path.isfile(self.output_path): @@ -361,7 +432,7 @@ def get_data(self): else: output.append(CollectorOutputEntry(self.output_path, 'file')) if self.autoreport: - textfile = os.path.splitext(self.output_path)[0] + '.txt' + textfile: str = os.path.splitext(self.output_path)[0] + '.txt' if self.report_on_target: self.generate_report_on_target() self.target.pull(self.target_text_file, @@ -373,20 +444,27 @@ def get_data(self): self.view(self.output_path) return output - def get_stats(self, outfile): + def get_stats(self, outfile: str) -> Optional[Dict[int, + Dict[str, + Dict[str, Union[int, float]]]]]: + """ + get the processing statistics for the cpu + :param outfile: path to the output file + :type outfile: str + """ if not (self.functions and self.tracer is None): - return + return None if os.path.isdir(outfile): outfile = os.path.join(outfile, OUTPUT_PROFILE_FILE) # pylint: disable=protected-access - output = self.target._execute_util('ftrace_get_function_stats', - as_root=True) + output: str = self.target._execute_util('ftrace_get_function_stats', + as_root=True) - function_stats = {} + function_stats: Dict[int, Dict[str, Dict[str, Union[int, float]]]] = {} for line in output.splitlines(): # Match a new CPU dataset - match = CPU_RE.search(line) + match: Optional[Match[str]] = CPU_RE.search(line) if match: cpu_id = int(match.group(1)) function_stats[cpu_id] = {} @@ -397,13 +475,13 @@ def get_stats(self, outfile): if match: fname = match.group(1) function_stats[cpu_id][fname] = { - 'hits' : int(match.group(2)), - 'time' : float(match.group(3)), - 'avg' : float(match.group(4)), - 's_2' : float(match.group(5)), - } + 'hits': int(match.group(2)), + 'time': float(match.group(3)), + 'avg': float(match.group(4)), + 's_2': float(match.group(5)), + } self.logger.debug(" %s: %s", - fname, function_stats[cpu_id][fname]) + fname, function_stats[cpu_id][fname]) self.logger.debug("FTrace stats output [%s]...", outfile) with open(outfile, 'w') as fh: @@ -412,15 +490,25 @@ def get_stats(self, outfile): return function_stats - def report(self, binfile, destfile): + def report(self, binfile: str, destfile: str) -> None: + """ + Generate a textual report from a captured trace.dat file on the host. + + :param binfile: The path to the binary trace file. + :type binfile: str + :param destfile: The path to write the report. + :type destfile: str + :raises TargetStableError: If trace-cmd returns a non-zero exit code. + :raises HostError: If trace-cmd is not found on the host. + """ # To get the output of trace.dat, trace-cmd must be installed # This is done host-side because the generated file is very large try: - command = '{} report {} > {}'.format(self.host_binary, binfile, destfile) + command: str = '{} report {} > {}'.format(self.host_binary, binfile, destfile) self.logger.debug(command) process = subprocess.Popen(command, stderr=subprocess.PIPE, shell=True) - _, error = process.communicate() - error = error.decode(sys.stdout.encoding or 'utf-8', 'replace') + _, error_b = process.communicate() + error = error_b.decode(sys.stdout.encoding or 'utf-8', 'replace') if process.returncode: raise TargetStableError('trace-cmd returned non-zero exit code {}'.format(process.returncode)) if error: @@ -441,34 +529,52 @@ def report(self, binfile, destfile): except OSError: raise HostError('Could not find trace-cmd. Please make sure it is installed and is in PATH.') - def generate_report_on_target(self): - command = '{} report {} > {}'.format(self.target_binary, - self.target_output_file, - self.target_text_file) + def generate_report_on_target(self) -> None: + """ + generate report on target + """ + command: str = '{} report {} > {}'.format(self.target_binary, + self.target_output_file, + self.target_text_file) self.target.execute(command, timeout=TIMEOUT) - def view(self, binfile): + def view(self, binfile: str) -> None: + """ + KernelShark is a graphical front-end tool for visualizing trace data collected by trace-cmd. + It allows users to view and analyze kernel tracing data in a more intuitive and interactive way. + """ check_output('{} {}'.format(self.kernelshark, binfile), shell=True) - def teardown(self): - self.target.remove(self.target.path.join(self.target.working_directory, OUTPUT_TRACE_FILE)) + def teardown(self) -> None: + """ + Remove the trace.dat file from the target, cleaning up after data collection. + """ + self.target.remove(self.target.path.join(self.target.working_directory, OUTPUT_TRACE_FILE) if self.target.path else '') - def mark_start(self): + def mark_start(self) -> None: + """ + Write a start marker into the ftrace marker file. + """ self.target.write_value(self.marker_file, TRACE_MARKER_START, verify=False) - def mark_stop(self): + def mark_stop(self) -> None: + """ + Write a stop marker into the ftrace marker file. + """ self.target.write_value(self.marker_file, TRACE_MARKER_STOP, verify=False) -def _build_trace_events(events): - event_string = ' '.join(['-e {}'.format(e) for e in events]) +def _build_trace_events(events: List[str]) -> str: + event_string: str = ' '.join(['-e {}'.format(e) for e in events]) return event_string -def _build_trace_functions(functions): - function_string = " ".join(functions) + +def _build_trace_functions(functions: List[str]) -> str: + function_string: str = " ".join(functions) return function_string -def _build_graph_functions(functions, trace_children_functions): + +def _build_graph_functions(functions: List[str], trace_children_functions: bool) -> str: opt = 'g' if trace_children_functions else 'l' return ' '.join( '-{} {}'.format(opt, quote(f)) diff --git a/devlib/collector/logcat.py b/devlib/collector/logcat.py index 770c9054b..74314bca3 100644 --- a/devlib/collector/logcat.py +++ b/devlib/collector/logcat.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,19 +18,36 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.utils.android import LogcatMonitor +from typing import (cast, TYPE_CHECKING, List, Optional, + Union) +from io import TextIOWrapper +from tempfile import _TemporaryFileWrapper +if TYPE_CHECKING: + from devlib.target import AndroidTarget, Target + class LogcatCollector(CollectorBase): + """ + A collector that retrieves logs via `adb logcat` from an Android target. - def __init__(self, target, regexps=None, logcat_format=None): + :param target: The devlib Target (must be Android). + :type target: Target + :param regexps: A list of regular expressions to filter log lines (optional). + :type regexps: List(str) + :param logcat_format: The desired logcat output format (optional). + :type logcat_format: str + """ + def __init__(self, target: 'Target', regexps: Optional[List[str]] = None, + logcat_format: Optional[str] = None): super(LogcatCollector, self).__init__(target) self.regexps = regexps self.logcat_format = logcat_format - self.output_path = None - self._collecting = False - self._prev_log = None - self._monitor = None + self.output_path: Optional[str] = None + self._collecting: bool = False + self._prev_log: Optional[Union[TextIOWrapper, _TemporaryFileWrapper[str]]] = None + self._monitor: Optional[LogcatMonitor] = None - def reset(self): + def reset(self) -> None: """ Clear Collector data but do not interrupt collection """ @@ -40,39 +57,46 @@ def reset(self): if self._collecting: self._monitor.clear_log() elif self._prev_log: - os.remove(self._prev_log) + os.remove(cast(str, self._prev_log)) self._prev_log = None - def start(self): + def start(self) -> None: """ - Start collecting logcat lines + Start capturing logcat output. Raises RuntimeError if no output path is set. """ if self.output_path is None: raise RuntimeError("Output path was not set.") - self._monitor = LogcatMonitor(self.target, self.regexps, logcat_format=self.logcat_format) + self._monitor = LogcatMonitor(cast('AndroidTarget', self.target), self.regexps, logcat_format=self.logcat_format) if self._prev_log: # Append new data collection to previous collection - self._monitor.start(self._prev_log) + self._monitor.start(cast(str, self._prev_log)) else: self._monitor.start(self.output_path) self._collecting = True - def stop(self): + def stop(self) -> None: """ Stop collecting logcat lines """ if not self._collecting: raise RuntimeError('Logcat monitor not running, nothing to stop') - - self._monitor.stop() + if self._monitor: + self._monitor.stop() self._collecting = False - self._prev_log = self._monitor.logfile + self._prev_log = self._monitor.logfile if self._monitor else None - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Return a :class:`CollectorOutput` for the captured logcat data. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing the logcat file. + :rtype: CollectorOutput + """ if self.output_path is None: raise RuntimeError("No data collected.") return CollectorOutput([CollectorOutputEntry(self.output_path, 'file')]) diff --git a/devlib/collector/perf.py b/devlib/collector/perf.py index a1389967a..03ac66581 100644 --- a/devlib/collector/perf.py +++ b/devlib/collector/perf.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,25 +16,30 @@ import os import re import time -from past.builtins import basestring, zip +from past.builtins import zip from devlib.host import PACKAGE_BIN_DIRECTORY from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.utils.misc import ensure_file_directory_exists as _f +from typing import (cast, List, Dict, TYPE_CHECKING, Optional, + Union, Pattern) +from signal import Signals +if TYPE_CHECKING: + from devlib.target import Target -PERF_STAT_COMMAND_TEMPLATE = '{binary} {command} {options} {events} {sleep_cmd} > {outfile} 2>&1 ' -PERF_REPORT_COMMAND_TEMPLATE= '{binary} report {options} -i {datafile} > {outfile} 2>&1 ' -PERF_REPORT_SAMPLE_COMMAND_TEMPLATE= '{binary} report-sample {options} -i {datafile} > {outfile} ' -PERF_RECORD_COMMAND_TEMPLATE= '{binary} record {options} {events} -o {outfile}' +PERF_STAT_COMMAND_TEMPLATE: str = '{binary} {command} {options} {events} {sleep_cmd} > {outfile} 2>&1 ' +PERF_REPORT_COMMAND_TEMPLATE: str = '{binary} report {options} -i {datafile} > {outfile} 2>&1 ' +PERF_REPORT_SAMPLE_COMMAND_TEMPLATE: str = '{binary} report-sample {options} -i {datafile} > {outfile} ' +PERF_RECORD_COMMAND_TEMPLATE: str = '{binary} record {options} {events} -o {outfile}' -PERF_DEFAULT_EVENTS = [ +PERF_DEFAULT_EVENTS: List[str] = [ 'cpu-migrations', 'context-switches', ] -SIMPLEPERF_DEFAULT_EVENTS = [ +SIMPLEPERF_DEFAULT_EVENTS: List[str] = [ 'raw-cpu-cycles', 'raw-l1-dcache', 'raw-l1-dcache-refill', @@ -42,7 +47,8 @@ 'raw-instruction-retired', ] -DEFAULT_EVENTS = {'perf':PERF_DEFAULT_EVENTS, 'simpleperf':SIMPLEPERF_DEFAULT_EVENTS} +DEFAULT_EVENTS: Dict[str, List[str]] = {'perf': PERF_DEFAULT_EVENTS, 'simpleperf': SIMPLEPERF_DEFAULT_EVENTS} + class PerfCollector(CollectorBase): """ @@ -82,43 +88,66 @@ class PerfCollector(CollectorBase): Options can be obtained by running the following in the command line :: man perf-stat + + :param target: The devlib Target (rooted if on Android). + :type target: Target + :param perf_type: Either 'perf' or 'simpleperf'. + :type perf_type: str + :param command: The perf command to run (e.g. 'stat' or 'record'). + :type command: str + :param events: A list of events to collect. Defaults to built-in sets. + :type events: list(str) or None + :param optionstring: Extra CLI options (a string or list of strings). + :type optionstring: str or list(str) or None + :param report_options: Additional options for ``perf report``. + :type report_options: str or None + :param run_report_sample: If True, run the ``report-sample`` subcommand. + :type run_report_sample: bool + :param report_sample_options: Additional options for ``report-sample``. + :type report_sample_options: str or None + :param labels: Unique labels for each command or option set. + :type labels: list(str) or None + :param force_install: If True, reinstall perf even if it's already on the target. + :type force_install: bool + :param validate_events: If True, verify that requested events are available. + :type validate_events: bool """ def __init__(self, - target, - perf_type='perf', - command='stat', - events=None, - optionstring=None, - report_options=None, - run_report_sample=False, - report_sample_options=None, - labels=None, - force_install=False, - validate_events=True): + target: 'Target', + perf_type: str = 'perf', + command: str = 'stat', + events: Optional[List[str]] = None, + optionstring: Optional[Union[str, List[str]]] = None, + report_options: Optional[str] = None, + run_report_sample: bool = False, + report_sample_options: Optional[str] = None, + labels: Optional[List[str]] = None, + force_install: bool = False, + validate_events: bool = True): super(PerfCollector, self).__init__(target) self.force_install = force_install self.labels = labels self.report_options = report_options self.run_report_sample = run_report_sample self.report_sample_options = report_sample_options - self.output_path = None + self.output_path: Optional[str] = None self.validate_events = validate_events # Validate parameters if isinstance(optionstring, list): - self.optionstrings = optionstring + self.optionstrings: List[str] = optionstring else: - self.optionstrings = [optionstring] + self.optionstrings = [optionstring] if optionstring else [] if perf_type in ['perf', 'simpleperf']: - self.perf_type = perf_type + self.perf_type: str = perf_type else: raise ValueError('Invalid perf type: {}, must be perf or simpleperf'.format(perf_type)) if not events: - self.events = DEFAULT_EVENTS[self.perf_type] + self.events: List[str] = DEFAULT_EVENTS[self.perf_type] else: self.events = events - if isinstance(self.events, basestring): + if isinstance(self.events, str): self.events = [self.events] if not self.labels: self.labels = ['perf_{}'.format(i) for i in range(len(self.optionstrings))] @@ -133,51 +162,73 @@ def __init__(self, if report_sample_options and (command != 'record'): raise ValueError('report_sample_options specified, but command is not record') - self.binary = self.target.get_installed(self.perf_type) + self.binary: str = self.target.get_installed(self.perf_type) if self.force_install or not self.binary: self.binary = self._deploy_perf() if self.validate_events: self._validate_events(self.events) - self.commands = self._build_commands() + self.commands: List[str] = self._build_commands() - def reset(self): + def reset(self) -> None: self.target.killall(self.perf_type, as_root=self.target.is_rooted) self.target.remove(self.target.get_workpath('TemporaryFile*')) - for label in self.labels: - filepath = self._get_target_file(label, 'data') - self.target.remove(filepath) - filepath = self._get_target_file(label, 'rpt') - self.target.remove(filepath) - filepath = self._get_target_file(label, 'rptsamples') - self.target.remove(filepath) - - def start(self): + if self.labels: + for label in self.labels: + filepath = self._get_target_file(label, 'data') + self.target.remove(filepath) + filepath = self._get_target_file(label, 'rpt') + self.target.remove(filepath) + filepath = self._get_target_file(label, 'rptsamples') + self.target.remove(filepath) + + def start(self) -> None: + """ + Start the perf command(s) in the background on the target. + """ for command in self.commands: self.target.background(command, as_root=self.target.is_rooted) - def stop(self): - self.target.killall(self.perf_type, signal='SIGINT', + def stop(self) -> None: + """ + Send SIGINT to terminate the perf tool, finalizing any data files. + """ + self.target.killall(self.perf_type, signal=cast(Signals, 'SIGINT'), as_root=self.target.is_rooted) if self.perf_type == "perf" and self.command == "stat": # perf doesn't transmit the signal to its sleep call so handled here: self.target.killall('sleep', as_root=self.target.is_rooted) # NB: we hope that no other "important" sleep is on-going - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: + """ + Define where perf data or reports will be stored on the host. + + :param output_path: A directory or file path for storing perf results. + :type output_path: str + """ self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Pull the perf data from the target to the host and optionally generate + textual reports. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing the saved perf files. + :rtype: CollectorOutput + """ if self.output_path is None: raise RuntimeError("Output path was not set.") output = CollectorOutput() - + if self.labels is None: + raise RuntimeError("labels not set") for label in self.labels: if self.command == 'record': self._wait_for_data_file_write(label, self.output_path) - path = self._pull_target_file_to_host(label, 'rpt', self.output_path) + path: str = self._pull_target_file_to_host(label, 'rpt', self.output_path) output.append(CollectorOutputEntry(path, 'file')) if self.run_report_sample: report_samples_path = self._pull_target_file_to_host(label, 'rptsamples', self.output_path) @@ -187,16 +238,25 @@ def get_data(self): output.append(CollectorOutputEntry(path, 'file')) return output - def _deploy_perf(self): - host_executable = os.path.join(PACKAGE_BIN_DIRECTORY, - self.target.abi, self.perf_type) + def _deploy_perf(self) -> str: + """ + install perf on target + """ + host_executable: str = os.path.join(PACKAGE_BIN_DIRECTORY, + cast(str, self.target.abi), self.perf_type) return self.target.install(host_executable) - def _get_target_file(self, label, extension): + def _get_target_file(self, label: str, extension: str) -> Optional[str]: + """ + get file path on target + """ return self.target.get_workpath('{}.{}'.format(label, extension)) - def _build_commands(self): - commands = [] + def _build_commands(self) -> List[str]: + """ + build perf commands + """ + commands: List[str] = [] for opts, label in zip(self.optionstrings, self.labels): if self.command == 'stat': commands.append(self._build_perf_stat_command(opts, self.events, label)) @@ -204,50 +264,86 @@ def _build_commands(self): commands.append(self._build_perf_record_command(opts, label)) return commands - def _build_perf_stat_command(self, options, events, label): - event_string = ' '.join(['-e {}'.format(e) for e in events]) - sleep_cmd = 'sleep 1000' if self.perf_type == 'perf' else '' - command = PERF_STAT_COMMAND_TEMPLATE.format(binary = self.binary, - command = self.command, - options = options or '', - events = event_string, - sleep_cmd = sleep_cmd, - outfile = self._get_target_file(label, 'out')) + def _build_perf_stat_command(self, options: str, events: List[str], label) -> str: + """ + Construct a perf stat command string. + + :param options: Additional perf stat options. + :type options: str + :param events: The list of events to measure. + :type events: list(str) + :param label: A label to identify this command/run. + :type label: str + :return: A command string suitable for running on the target. + :rtype: str + """ + event_string: str = ' '.join(['-e {}'.format(e) for e in events]) + sleep_cmd: str = 'sleep 1000' if self.perf_type == 'perf' else '' + command: str = PERF_STAT_COMMAND_TEMPLATE.format(binary=self.binary, + command=self.command, + options=options or '', + events=event_string, + sleep_cmd=sleep_cmd, + outfile=self._get_target_file(label, 'out')) return command - def _build_perf_report_command(self, report_options, label): + def _build_perf_report_command(self, report_options: Optional[str], label: str) -> str: + """ + Construct a perf stat command string. + + :param options: Additional perf stat options. + :type options: str + :param events: The list of events to measure. + :type events: list(str) + :param label: A label to identify this command/run. + :type label: str + :return: A command string suitable for running on the target. + :rtype: str + """ command = PERF_REPORT_COMMAND_TEMPLATE.format(binary=self.binary, options=report_options or '', datafile=self._get_target_file(label, 'data'), outfile=self._get_target_file(label, 'rpt')) return command - def _build_perf_report_sample_command(self, label): + def _build_perf_report_sample_command(self, label: str) -> str: + """ + build perf report sample command + """ command = PERF_REPORT_SAMPLE_COMMAND_TEMPLATE.format(binary=self.binary, - options=self.report_sample_options or '', - datafile=self._get_target_file(label, 'data'), - outfile=self._get_target_file(label, 'rptsamples')) + options=self.report_sample_options or '', + datafile=self._get_target_file(label, 'data'), + outfile=self._get_target_file(label, 'rptsamples')) return command - def _build_perf_record_command(self, options, label): - event_string = ' '.join(['-e {}'.format(e) for e in self.events]) - command = PERF_RECORD_COMMAND_TEMPLATE.format(binary=self.binary, - options=options or '', - events=event_string, - outfile=self._get_target_file(label, 'data')) + def _build_perf_record_command(self, options: Optional[str], label: str) -> str: + """ + build perf record command + """ + event_string: str = ' '.join(['-e {}'.format(e) for e in self.events]) + command: str = PERF_RECORD_COMMAND_TEMPLATE.format(binary=self.binary, + options=options or '', + events=event_string, + outfile=self._get_target_file(label, 'data')) return command - def _pull_target_file_to_host(self, label, extension, output_path): - target_file = self._get_target_file(label, extension) - host_relpath = os.path.basename(target_file) - host_file = _f(os.path.join(output_path, host_relpath)) + def _pull_target_file_to_host(self, label: str, extension: str, output_path: str) -> str: + """ + pull a file from target to host + """ + target_file: Optional[str] = self._get_target_file(label, extension) + host_relpath: str = os.path.basename(target_file or '') + host_file: str = _f(os.path.join(output_path, host_relpath)) self.target.pull(target_file, host_file) return host_file - def _wait_for_data_file_write(self, label, output_path): - data_file_finished_writing = False - max_tries = 80 - current_tries = 0 + def _wait_for_data_file_write(self, label: str, output_path: str) -> None: + """ + wait for file write operation by perf + """ + data_file_finished_writing: bool = False + max_tries: int = 80 + current_tries: int = 0 while not data_file_finished_writing: files = self.target.execute('cd {} && ls'.format(self.target.get_workpath(''))) # Perf stores data in tempory files whilst writing to data output file. Check if they have been removed. @@ -259,15 +355,18 @@ def _wait_for_data_file_write(self, label, output_path): self.logger.warning('''writing {}.data file took longer than expected, file may not have written correctly'''.format(label)) data_file_finished_writing = True - report_command = self._build_perf_report_command(self.report_options, label) + report_command: str = self._build_perf_report_command(self.report_options, label) self.target.execute(report_command) if self.run_report_sample: report_sample_command = self._build_perf_report_sample_command(label) self.target.execute(report_sample_command) - def _validate_events(self, events): - available_events_string = self.target.execute('{} list | {} cat'.format(self.perf_type, self.target.busybox)) - available_events = available_events_string.splitlines() + def _validate_events(self, events: List[str]) -> None: + """ + validate events against available perf events on target + """ + available_events_string: str = self.target.execute('{} list | {} cat'.format(self.perf_type, self.target.busybox)) + available_events: List[str] = available_events_string.splitlines() for available_event in available_events: if available_event == '': continue @@ -275,7 +374,7 @@ def _validate_events(self, events): available_events.append(available_event.split('OR')[1]) available_events[available_events.index(available_event)] = available_event.split()[0].strip() # Raw hex event codes can also be passed in that do not appear on perf/simpleperf list, prefixed with 'r' - raw_event_code_regex = re.compile(r"^r(0x|0X)?[A-Fa-f0-9]+$") + raw_event_code_regex: Pattern[str] = re.compile(r"^r(0x|0X)?[A-Fa-f0-9]+$") for event in events: if event in available_events or re.match(raw_event_code_regex, event): continue diff --git a/devlib/collector/perfetto.py b/devlib/collector/perfetto.py index c5070e03a..a30848337 100644 --- a/devlib/collector/perfetto.py +++ b/devlib/collector/perfetto.py @@ -1,4 +1,4 @@ -# Copyright 2023 ARM Limited +# Copyright 2023-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,6 +21,9 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.exception import TargetStableError, HostError +from typing import TYPE_CHECKING, Optional +if TYPE_CHECKING: + from devlib.target import Target, BackgroundCommand OUTPUT_PERFETTO_TRACE = 'devlib-trace.perfetto-trace' @@ -53,29 +56,36 @@ class PerfettoCollector(CollectorBase): For more information consult the official documentation: https://perfetto.dev/docs/ + + :param target: The devlib Target. + :type target: Target + :param config: Path to a Perfetto text config (proto) if any. + :type config: str or None + :param force_tracebox: If True, force usage of tracebox instead of native Perfetto. + :type force_tracebox: bool """ - def __init__(self, target, config=None, force_tracebox=False): + def __init__(self, target: 'Target', config: Optional[str] = None, force_tracebox: bool = False): super().__init__(target) - self.bg_cmd = None + self.bg_cmd: Optional['BackgroundCommand'] = None self.config = config - self.target_binary = 'perfetto' - target_output_path = self.target.working_directory + self.target_binary: str = 'perfetto' + target_output_path: Optional[str] = self.target.working_directory - install_tracebox = force_tracebox or (target.os in ['linux', 'android'] and not target.is_running('traced')) + install_tracebox: bool = force_tracebox or (target.os in ['linux', 'android'] and not target.is_running('traced')) # Install Perfetto through tracebox if install_tracebox: self.target_binary = 'tracebox' if not self.target.get_installed(self.target_binary): - host_executable = os.path.join(PACKAGE_BIN_DIRECTORY, - self.target.abi, self.target_binary) + host_executable: str = os.path.join(PACKAGE_BIN_DIRECTORY, + self.target.abi or '', self.target_binary) if not os.path.exists(host_executable): raise HostError("{} not found on the host".format(self.target_binary)) self.target.install(host_executable) # Use Android's built-in Perfetto elif target.os == 'android': - os_version = target.os_version['release'] + os_version: str = target.os_version['release'] if int(os_version) >= 9: # Android requires built-in Perfetto to write to this directory target_output_path = '/data/misc/perfetto-traces' @@ -83,11 +93,16 @@ def __init__(self, target, config=None, force_tracebox=False): if int(os_version) <= 10: target.execute('setprop persist.traced.enable 1') - self.target_output_file = target.path.join(target_output_path, OUTPUT_PERFETTO_TRACE) + self.target_output_file = target.path.join(target_output_path, OUTPUT_PERFETTO_TRACE) if target.path else '' + + def start(self) -> None: + """ + Start Perfetto tracing by feeding the config to the perfetto (or tracebox) binary. - def start(self): - cmd = "{} cat {} | {} --txt -c - -o {}".format( - quote(self.target.busybox), quote(self.config), quote(self.target_binary), quote(self.target_output_file) + :raises TargetStableError: If perfetto/tracebox cannot be started on the target. + """ + cmd: str = "{} cat {} | {} --txt -c - -o {}".format( + quote(self.target.busybox or ''), quote(self.config or ''), quote(self.target_binary), quote(self.target_output_file) ) # start tracing if self.bg_cmd is None: @@ -95,17 +110,34 @@ def start(self): else: raise TargetStableError('Perfetto collector is not re-entrant') - def stop(self): - # stop tracing - self.bg_cmd.cancel() - self.bg_cmd = None - - def set_output(self, output_path): + def stop(self) -> None: + """ + Stop Perfetto tracing and finalize the trace file. + """ + if self.bg_cmd: + # stop tracing + self.bg_cmd.cancel() + self.bg_cmd = None + + def set_output(self, output_path: str) -> None: + """ + Specify where the trace file will be pulled on the host. + + :param output_path: The file path or directory on the host. + :type output_path: str + """ if os.path.isdir(output_path): output_path = os.path.join(output_path, os.path.basename(self.target_output_file)) self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Pull the trace file from the target and return a :class:`CollectorOutput`. + + :raises RuntimeError: If :attr:`output_path` is unset or if no trace file exists. + :return: A collector output referencing the Perfetto trace file. + :rtype: CollectorOutput + """ if self.output_path is None: raise RuntimeError("Output path was not set.") if not self.target.file_exists(self.target_output_file): diff --git a/devlib/collector/screencapture.py b/devlib/collector/screencapture.py index 399227fc8..d80698a92 100644 --- a/devlib/collector/screencapture.py +++ b/devlib/collector/screencapture.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,27 +22,42 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.exception import WorkerThreadError +from typing import TYPE_CHECKING, Optional, cast +if TYPE_CHECKING: + from devlib.target import Target class ScreenCapturePoller(threading.Thread): - - def __init__(self, target, period, timeout=30): + """ + A background thread that periodically captures screenshots from the target. + + :param target: The devlib Target. + :type target: Target + :param period: Interval in seconds between captures. If None, the logic may differ. + :type period: float or None + :param timeout: Maximum time to wait for the poller thread to stop. + :type timeout: int + """ + def __init__(self, target: 'Target', period: Optional[float], timeout: int = 30): super(ScreenCapturePoller, self).__init__() self.target = target - self.logger = logging.getLogger('screencapture') + self.logger: logging.Logger = logging.getLogger('screencapture') self.period = period self.timeout = timeout self.stop_signal = threading.Event() self.lock = threading.Lock() - self.last_poll = 0 - self.daemon = True - self.exc = None - self.output_path = None + self.last_poll: float = 0 + self.daemon: bool = True + self.exc: Optional[Exception] = None + self.output_path: Optional[str] = None - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def run(self): + def run(self) -> None: + """ + Continuously capture screenshots at the specified interval until stopped. + """ self.logger.debug('Starting screen capture polling') try: if self.output_path is None: @@ -52,13 +67,16 @@ def run(self): break with self.lock: current_time = time.time() - if (current_time - self.last_poll) >= self.period: + if (current_time - self.last_poll) >= cast(float, self.period): self.poll() time.sleep(0.5) except Exception: # pylint: disable=W0703 self.exc = WorkerThreadError(self.name, sys.exc_info()) - def stop(self): + def stop(self) -> None: + """ + Signal the thread to stop and wait for it to exit, up to :attr:`timeout`. + """ self.logger.debug('Stopping screen capture polling') self.stop_signal.set() self.join(self.timeout) @@ -67,34 +85,49 @@ def stop(self): if self.exc: raise self.exc # pylint: disable=E0702 - def poll(self): + def poll(self) -> None: self.last_poll = time.time() - self.target.capture_screen(os.path.join(self.output_path, "screencap_{ts}.png")) + self.target.capture_screen(os.path.join(self.output_path or '', "screencap_{ts}.png")) class ScreenCaptureCollector(CollectorBase): - - def __init__(self, target, period=None): + """ + A collector that periodically captures screenshots from a target device. + + :param target: The devlib Target. + :type target: Target + :param period: Interval in seconds between captures. + :type period: float or None + """ + def __init__(self, target: 'Target', period: Optional[float] = None): super(ScreenCaptureCollector, self).__init__(target) - self._collecting = False - self.output_path = None + self._collecting: bool = False + self.output_path: Optional[str] = None self.period = period self.target = target - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def reset(self): + def reset(self) -> None: self._poller = ScreenCapturePoller(self.target, self.period) - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Return a :class:`CollectorOutput` referencing the directory of captured screenshots. + + :return: A collector output referencing the screenshot directory. + :rtype: CollectorOutput + """ if self.output_path is None: raise RuntimeError("No data collected.") return CollectorOutput([CollectorOutputEntry(self.output_path, 'directory')]) - def start(self): + def start(self) -> None: """ - Start collecting the screenshots + Start the screen capture poller thread. + + :raises RuntimeError: If :attr:`output_path` is unset. """ if self.output_path is None: raise RuntimeError("Output path was not set.") @@ -102,9 +135,9 @@ def start(self): self._poller.start() self._collecting = True - def stop(self): + def stop(self) -> None: """ - Stop collecting the screenshots + Stop the screen capture poller thread. """ if not self._collecting: raise RuntimeError('Screen capture collector is not running, nothing to stop') diff --git a/devlib/collector/serial_trace.py b/devlib/collector/serial_trace.py index 7df9ab3ff..35cbcb308 100644 --- a/devlib/collector/serial_trace.py +++ b/devlib/collector/serial_trace.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,27 +18,45 @@ from devlib.collector import (CollectorBase, CollectorOutput, CollectorOutputEntry) from devlib.utils.serial_port import get_connection +from typing import TextIO, cast, TYPE_CHECKING, Optional +from pexpect import fdpexpect +from serial import Serial +from io import BufferedWriter +if TYPE_CHECKING: + from devlib.target import Target class SerialTraceCollector(CollectorBase): - + """ + A collector that reads serial output and saves it to a file. + + :param target: The devlib Target. + :type target: Target + :param serial_port: The serial port to open. + :type serial_port: int + :param baudrate: The baud rate (bits per second). + :type baudrate: int + :param timeout: A timeout for serial reads, in seconds. + :type timeout: int + """ @property - def collecting(self): + def collecting(self) -> bool: return self._collecting - def __init__(self, target, serial_port, baudrate, timeout=20): + def __init__(self, target: 'Target', serial_port: int, + baudrate: int, timeout: int = 20): super(SerialTraceCollector, self).__init__(target) self.serial_port = serial_port self.baudrate = baudrate self.timeout = timeout - self.output_path = None + self.output_path: Optional[str] = None - self._serial_target = None - self._conn = None - self._outfile_fh = None - self._collecting = False + self._serial_target: Optional[fdpexpect.fdspawn] = None + self._conn: Optional[Serial] = None + self._outfile_fh: Optional[BufferedWriter] = None + self._collecting: bool = False - def reset(self): + def reset(self) -> None: if self._collecting: raise RuntimeError("reset was called whilst collecting") @@ -46,24 +64,34 @@ def reset(self): self._outfile_fh.close() self._outfile_fh = None - def start(self): + def start(self) -> None: + """ + Open the serial connection and write all data to :attr:`output_path`. + + :raises RuntimeError: If already collecting or :attr:`output_path` is unset. + """ if self._collecting: raise RuntimeError("start was called whilst collecting") if self.output_path is None: raise RuntimeError("Output path was not set.") self._outfile_fh = open(self.output_path, 'wb') - start_marker = "-------- Starting serial logging --------\n" + start_marker: str = "-------- Starting serial logging --------\n" self._outfile_fh.write(start_marker.encode('utf-8')) self._serial_target, self._conn = get_connection(port=self.serial_port, baudrate=self.baudrate, timeout=self.timeout, - logfile=self._outfile_fh, - init_dtr=0) + logfile=cast(TextIO, self._outfile_fh), + init_dtr=False) self._collecting = True - def stop(self): + def stop(self) -> None: + """ + Close the serial connection and finalize the log file. + + :raises RuntimeError: If not currently collecting. + """ if not self._collecting: raise RuntimeError("stop was called whilst not collecting") @@ -71,25 +99,34 @@ def stop(self): # do something so that it interacts with the serial device, # and hence updates the logfile. try: - self._serial_target.expect(".", timeout=1) + if self._serial_target: + self._serial_target.expect(".", timeout=1) except TIMEOUT: pass - - self._serial_target.close() + if self._serial_target: + self._serial_target.close() del self._conn - stop_marker = "-------- Stopping serial logging --------\n" - self._outfile_fh.write(stop_marker.encode('utf-8')) - self._outfile_fh.flush() - self._outfile_fh.close() - self._outfile_fh = None + stop_marker: str = "-------- Stopping serial logging --------\n" + if self._outfile_fh: + self._outfile_fh.write(stop_marker.encode('utf-8')) + self._outfile_fh.flush() + self._outfile_fh.close() + self._outfile_fh = None self._collecting = False - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Return a :class:`CollectorOutput` referencing the saved serial log file. + + :raises RuntimeError: If :attr:`output_path` is unset. + :return: A collector output referencing the serial log file. + :rtype: CollectorOutput + """ if self._collecting: raise RuntimeError("get_data was called whilst collecting") if self.output_path is None: diff --git a/devlib/collector/systrace.py b/devlib/collector/systrace.py index 4e29cf11a..b4871dd90 100644 --- a/devlib/collector/systrace.py +++ b/devlib/collector/systrace.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,9 +21,12 @@ from devlib.exception import TargetStableError, HostError import devlib.utils.android from devlib.utils.misc import memoized +from typing import TYPE_CHECKING, List, Optional, Union, TextIO +from subprocess import Popen +if TYPE_CHECKING: + from devlib.target import AndroidTarget - -DEFAULT_CATEGORIES = [ +DEFAULT_CATEGORIES: List[str] = [ 'gfx', 'view', 'sched', @@ -31,6 +34,7 @@ 'idle' ] + class SystraceCollector(CollectorBase): """ A trace collector based on Systrace @@ -56,32 +60,35 @@ class SystraceCollector(CollectorBase): @property @memoized - def available_categories(self): - lines = subprocess.check_output( - [self.systrace_binary, '-l'], universal_newlines=True + def available_categories(self) -> List[str]: + """ + list of available categories + """ + lines: List[str] = subprocess.check_output( + [self.systrace_binary or '', '-l'], universal_newlines=True ).splitlines() return [line.split()[0] for line in lines if line] - def __init__(self, target, - categories=None, - buffer_size=None, - strict=False): + def __init__(self, target: 'AndroidTarget', + categories: Optional[str] = None, + buffer_size: Optional[int] = None, + strict: bool = False): super(SystraceCollector, self).__init__(target) - self.categories = categories or DEFAULT_CATEGORIES + self.categories: Union[str, List[str]] = categories or DEFAULT_CATEGORIES self.buffer_size = buffer_size - self.output_path = None + self.output_path: Optional[str] = None - self._systrace_process = None - self._outfile_fh = None + self._systrace_process: Optional[Popen] = None + self._outfile_fh: Optional[TextIO] = None # Try to find a systrace binary - self.systrace_binary = None + self.systrace_binary: Optional[str] = None - platform_tools = devlib.utils.android.platform_tools - systrace_binary_path = os.path.join(platform_tools, 'systrace', 'systrace.py') + platform_tools: str = devlib.utils.android.platform_tools # type: ignore + systrace_binary_path: str = os.path.join(platform_tools, 'systrace', 'systrace.py') if not os.path.isfile(systrace_binary_path): raise HostError('Could not find any systrace binary under {}'.format(platform_tools)) @@ -90,7 +97,7 @@ def __init__(self, target, # Filter the requested categories for category in self.categories: if category not in self.available_categories: - message = 'Category [{}] not available for tracing'.format(category) + message: str = 'Category [{}] not available for tracing'.format(category) if strict: raise TargetStableError(message) self.logger.warning(message) @@ -102,11 +109,14 @@ def __init__(self, target, def __del__(self): self.reset() - def _build_cmd(self): - self._outfile_fh = open(self.output_path, 'w') + def _build_cmd(self) -> None: + """ + build command + """ + self._outfile_fh = open(self.output_path or '', 'w') # pylint: disable=attribute-defined-outside-init - self.systrace_cmd = 'python2 -u {} -o {} -e {}'.format( + self.systrace_cmd: str = 'python2 -u {} -o {} -e {}'.format( self.systrace_binary, self._outfile_fh.name, self.target.adb_name @@ -117,11 +127,14 @@ def _build_cmd(self): self.systrace_cmd += ' {}'.format(' '.join(self.categories)) - def reset(self): + def reset(self) -> None: if self._systrace_process: self.stop() - def start(self): + def start(self) -> None: + """ + Start systrace, typically running a systrace command in the background. + """ if self._systrace_process: raise RuntimeError("Tracing is already underway, call stop() first") if self.output_path is None: @@ -138,9 +151,13 @@ def start(self): shell=True, universal_newlines=True ) - self._systrace_process.stdout.read(1) + if self._systrace_process.stdout: + self._systrace_process.stdout.read(1) - def stop(self): + def stop(self) -> None: + """ + Stop systrace and finalize the trace file. + """ if not self._systrace_process: raise RuntimeError("No tracing to stop, call start() first") @@ -152,10 +169,17 @@ def stop(self): self._outfile_fh.close() self._outfile_fh = None - def set_output(self, output_path): + def set_output(self, output_path: str) -> None: self.output_path = output_path - def get_data(self): + def get_data(self) -> CollectorOutput: + """ + Pull the trace HTML (or raw data) from the target and return a + :class:`CollectorOutput`. + + :return: A collector output referencing the systrace file. + :rtype: CollectorOutput + """ if self._systrace_process: raise RuntimeError("Tracing is underway, call stop() first") if self.output_path is None: diff --git a/devlib/connection.py b/devlib/connection.py index 460997580..d43e1f0b1 100644 --- a/devlib/connection.py +++ b/devlib/connection.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,15 +26,88 @@ import fcntl from devlib.utils.misc import InitCheckpoint +from devlib.utils.annotation_helpers import SubprocessCommand +from typing import (Optional, TYPE_CHECKING, Set, + Tuple, IO, Dict, List, Union, + Generator, Callable) +from typing_extensions import Protocol, Literal + +if TYPE_CHECKING: + from signal import Signals + from subprocess import Popen + from threading import Lock, Thread, Event + from logging import Logger + from paramiko.channel import Channel + from paramiko.sftp_client import SFTPClient + from scp import SCPClient + + +class HasInitialized(Protocol): + """ + Protocol indicating that the object includes an ``initialized`` property + and a ``close()`` method. Used to ensure safe clean-up in destructors. + + :ivar initialized: ``True`` if the object finished initializing successfully, + otherwise ``False`` if initialization failed or is incomplete. + :vartype initialized: bool + """ + initialized: bool + + # other functions referred by the object with the initialized property + def close(self) -> None: + """ + Close method expected on objects that provide ``initialized``. + """ + ... -_KILL_TIMEOUT = 3 +_KILL_TIMEOUT: int = 3 +""" +int: The default time (in seconds) to wait between sending SIGTERM and SIGKILL +during process cancellation (see :meth:`BackgroundCommand.cancel`). +""" -def _kill_pgid_cmd(pgid, sig, busybox): + +def _kill_pgid_cmd(pgid: int, sig: 'Signals', busybox: Optional[str]) -> str: + """ + Construct a shell command string that sends a specified signal to a given + process group. + + :param pgid: The process group ID (PGID) to signal. + :type pgid: int + :param sig: The signal to send (e.g., SIGTERM, SIGKILL). + :type sig: signal.Signals + :param busybox: Path to a busybox binary on the target, if any. If None, + the command may assume `kill` is already in PATH. + :type busybox: str or None + :return: A complete shell command that, when run, kills the PGID with the given signal. + :rtype: str + """ return '{} kill -{} -{}'.format(busybox, sig.value, pgid) -def _popen_communicate(bg, popen, input, timeout): + +def _popen_communicate(bg: 'BackgroundCommand', popen: 'Popen', input: bytes, + timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: + """ + Wrapper around ``popen.communicate(...)`` to handle timeouts and + cancellation of a background command. + + :param bg: The associated :class:`BackgroundCommand` object that may be canceled. + :type bg: BackgroundCommand + :param popen: The :class:`subprocess.Popen` instance to communicate with. + :type popen: subprocess.Popen + :param input: Bytes to send to stdin. + :type input: bytes + :param timeout: The timeout in seconds or None for no timeout. + :type timeout: int or None + :return: A tuple (stdout, stderr) if the command completes successfully. + :rtype: (bytes or None, bytes or None) + :raises subprocess.TimeoutExpired: If the command doesn't complete in time. + :raises subprocess.CalledProcessError: If the command exits with a non-zero return code. + """ try: + stdout: Optional[bytes] + stderr: Optional[bytes] stdout, stderr = popen.communicate(input=input, timeout=timeout) except subprocess.TimeoutExpired: bg.cancel() @@ -55,19 +128,43 @@ def _popen_communicate(bg, popen, input, timeout): class ConnectionBase(InitCheckpoint): """ Base class for all connections. + A :class:`Connection` abstracts an actual physical connection to a device. The + first connection is created when :func:`Target.connect` method is called. If a + :class:`~devlib.target.Target` is used in a multi-threaded environment, it will + maintain a connection for each thread in which it is invoked. This allows + the same target object to be used in parallel in multiple threads. + + :class:`Connection` s will be automatically created and managed by + :class:`~devlib.target.Target` s, so there is usually no reason to create one + manually. Instead, configuration for a :class:`Connection` is passed as + `connection_settings` parameter when creating a + :class:`~devlib.target.Target`. The connection to be used target is also + specified on instantiation by `conn_cls` parameter, though all concrete + :class:`~devlib.target.Target` implementations will set an appropriate + default, so there is typically no need to specify this explicitly. + + :param poll_transfers: If True, manage file transfers by polling for progress. + :type poll_transfers: bool + :param start_transfer_poll_delay: Delay in seconds before first checking a + file transfer's progress. + :type start_transfer_poll_delay: int + :param total_transfer_timeout: Cancel transfers if they exceed this many seconds. + :type total_transfer_timeout: int + :param transfer_poll_period: Interval (seconds) between transfer progress checks. + :type transfer_poll_period: int """ def __init__( self, - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30, + poll_transfers: bool = False, + start_transfer_poll_delay: int = 30, + total_transfer_timeout: int = 3600, + transfer_poll_period: int = 30, ): - self._current_bg_cmds = set() - self._closed = False - self._close_lock = threading.Lock() - self.busybox = None - self.logger = logging.getLogger('Connection') + self._current_bg_cmds: Set['BackgroundCommand'] = set() + self._closed: bool = False + self._close_lock: Lock = threading.Lock() + self.busybox: Optional[str] = None + self.logger: Logger = logging.getLogger('Connection') self.transfer_manager = TransferManager( self, @@ -76,14 +173,17 @@ def __init__( transfer_poll_period=transfer_poll_period, ) if poll_transfers else NoopTransferManager() - - def cancel_running_command(self): - bg_cmds = set(self._current_bg_cmds) + def cancel_running_command(self) -> Optional[bool]: + """ + Cancel all active background commands tracked by this connection. + """ + bg_cmds: Set['BackgroundCommand'] = set(self._current_bg_cmds) for bg_cmd in bg_cmds: bg_cmd.cancel() + return None @abstractmethod - def _close(self): + def _close(self) -> None: """ Close the connection. @@ -92,11 +192,14 @@ def _close(self): be called from multiple threads at once. """ - def close(self): - - def finish_bg(): - bg_cmds = set(self._current_bg_cmds) - n = len(bg_cmds) + def close(self) -> None: + """ + Cancel any ongoing commands and finalize the connection. Safe to call multiple times, + does nothing after the first invocation. + """ + def finish_bg() -> None: + bg_cmds: Set['BackgroundCommand'] = set(self._current_bg_cmds) + n: int = len(bg_cmds) if n: self.logger.debug(f'Canceling {n} background commands before closing connection') for bg_cmd in bg_cmds: @@ -113,13 +216,42 @@ def finish_bg(): # Ideally, that should not be relied upon but that will improve the chances # of the connection being properly cleaned up when it's not in use anymore. - def __del__(self): + def __del__(self: HasInitialized): + """ + Destructor ensuring the connection is closed if not already. Only runs + if object initialization succeeded (initialized=True). + """ # Since __del__ will be called if an exception is raised in __init__ # (e.g. we cannot connect), we only run close() when we are sure # __init__ has completed successfully. if self.initialized: self.close() + @abstractmethod + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + check_exit_code: bool = True, as_root: Optional[bool] = False, + strip_colors: bool = True, will_succeed: bool = False) -> str: + """ + Execute a shell command and return the combined stdout/stderr. + + :param command: Command string or SubprocessCommand detailing the command to run. + :type command: SubprocessCommand + :param timeout: Timeout in seconds (None for no limit). + :type timeout: int or None + :param check_exit_code: If True, raise an error if exit code is non-zero. + :type check_exit_code: bool + :param as_root: If True, attempt to run with elevated privileges. + :type as_root: bool or None + :param strip_colors: Remove ANSI color codes from output if True. + :type strip_colors: bool + :param will_succeed: If True, interpret a failing command as a transient environment error. + :type will_succeed: bool + :returns: The command's combined stdout and stderr. + :rtype: str + :raises DevlibTransientError: If the command fails and is considered transient (will_succeed=True). + :raises DevlibStableError: If the command fails in a stable way (exit code != 0, or other error). + """ + class BackgroundCommand(ABC): """ @@ -128,9 +260,12 @@ class BackgroundCommand(ABC): Instances of this class can be used as context managers, with the same semantic as :class:`subprocess.Popen`. + + :param conn: The connection that owns this background command. + :type conn: ConnectionBase """ - def __init__(self, conn): + def __init__(self, conn: 'ConnectionBase'): self.conn = conn # Poll currently opened background commands on that connection to make @@ -147,17 +282,23 @@ def __init__(self, conn): conn._current_bg_cmds.add(self) - def _deregister(self): + def _deregister(self) -> None: + """ + deregister the background command + """ try: self.conn._current_bg_cmds.remove(self) except KeyError: pass @abstractmethod - def _send_signal(self, sig): + def _send_signal(self, sig: 'Signals') -> None: + """ + Subclass-specific implementation to send a signal (e.g., SIGTERM) to the process group. + """ pass - def send_signal(self, sig): + def send_signal(self, sig: 'Signals') -> None: """ Send a POSIX signal to the background command's process group ID (PGID). @@ -171,16 +312,19 @@ def send_signal(self, sig): # Deregister if the command has finished self.poll() - def kill(self): + def kill(self) -> None: """ Send SIGKILL to the background command. """ self.send_signal(signal.SIGKILL) - def cancel(self, kill_timeout=_KILL_TIMEOUT): + def cancel(self, kill_timeout: int = _KILL_TIMEOUT) -> None: """ Try to gracefully terminate the process by sending ``SIGTERM``, then waiting for ``kill_timeout`` to send ``SIGKILL``. + + :param kill_timeout: Seconds to wait between SIGTERM and SIGKILL. + :type kill_timeout: int """ try: if self.poll() is None: @@ -189,30 +333,44 @@ def cancel(self, kill_timeout=_KILL_TIMEOUT): self._deregister() @abstractmethod - def _cancel(self, kill_timeout): + def _cancel(self, kill_timeout: int) -> None: """ - Method to override in subclasses to implement :meth:`cancel`. + Subclass-specific logic for :meth:`cancel`. Usually sends SIGTERM, waits, + then sends SIGKILL if needed. """ pass @abstractmethod - def _wait(self): + def _wait(self) -> int: + """ + Wait for the command to complete. Return its exit code. + """ pass - def wait(self): + def wait(self) -> int: """ - Block until the background command completes, and return its exit code. + Block until the command completes, returning the exit code. + + :returns: The exit code of the command. + :rtype: int """ try: return self._wait() finally: self._deregister() - def communicate(self, input=b'', timeout=None): + def communicate(self, input: bytes = b'', timeout: Optional[int] = None) -> Tuple[Optional[bytes], Optional[bytes]]: """ - Block until the background command completes while reading stdout and stderr. - Return ``tuple(stdout, stderr)``. If the return code is non-zero, - raises a :exc:`subprocess.CalledProcessError` exception. + Write to stdin and read all data from stdout/stderr until the command exits. + + :param input: Bytes to send to stdin. + :type input: bytes + :param timeout: Max time to wait for the command to exit, or None if indefinite. + :type timeout: int or None + :returns: A tuple of (stdout, stderr) if the command exits cleanly. + :rtype: Tuple[Optional[bytes], Optional[bytes]] + :raises subprocess.TimeoutExpired: If the process runs past the timeout. + :raises subprocess.CalledProcessError: If the process exits with a non-zero code. """ try: return self._communicate(input=input, timeout=timeout) @@ -220,16 +378,26 @@ def communicate(self, input=b'', timeout=None): self.close() @abstractmethod - def _communicate(self, input, timeout): + def _communicate(self, input: bytes, timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: + """ + Method to override in subclasses to implement :meth:`communicate`. + """ pass @abstractmethod - def _poll(self): + def _poll(self) -> Optional[int]: + """ + Method to override in subclasses to implement :meth:`poll`. + """ pass - def poll(self): + def poll(self) -> Optional[int]: """ - Return exit code if the command has exited, None otherwise. + Return the exit code if the command has finished, otherwise None. + Deregisters if the command is done. + + :returns: Exit code or None if ongoing. + :rtype: int or None """ retcode = self._poll() if retcode is not None: @@ -238,28 +406,28 @@ def poll(self): @property @abstractmethod - def stdin(self): + def stdin(self) -> Optional[IO]: """ - File-like object connected to the background's command stdin. + A file-like object representing this command's standard input. May be None if unsupported. """ @property @abstractmethod - def stdout(self): + def stdout(self) -> Optional[IO]: """ - File-like object connected to the background's command stdout. + A file-like object representing this command's standard output. May be None. """ @property @abstractmethod - def stderr(self): + def stderr(self) -> Optional[IO]: """ - File-like object connected to the background's command stderr. + A file-like object representing this command's standard error. May be None. """ @property @abstractmethod - def pid(self): + def pid(self) -> int: """ Process Group ID (PGID) of the background command. @@ -271,14 +439,18 @@ def pid(self): """ @abstractmethod - def _close(self): + def _close(self) -> int: + """ + Subclass hook for final cleanup: close streams, wait for exit, return exit code. + """ pass - def close(self): + def close(self) -> int: """ - Close all opened streams and then wait for command completion. + Close any open streams and finalize the command. Return exit code. - :returns: Exit code of the command. + :returns: The command's final exit code. + :rtype: int .. note:: If the command is writing to its stdout/stderr, it might be blocked on that and die when the streams are closed. @@ -297,42 +469,51 @@ def __exit__(self, *args, **kwargs): class PopenBackgroundCommand(BackgroundCommand): """ - :class:`subprocess.Popen`-based background command. + Runs a command via ``subprocess.Popen`` in the background. Signals are sent + to the process group. Streams are accessible via ``stdin``, ``stdout``, and ``stderr``. + + :param conn: The parent connection. + :type conn: ConnectionBase + :param popen: The Popen object controlling the shell command. + :type popen: Popen """ - def __init__(self, conn, popen): + def __init__(self, conn: 'ConnectionBase', popen: 'Popen'): super().__init__(conn=conn) self.popen = popen - def _send_signal(self, sig): + def _send_signal(self, sig: 'Signals') -> None: + """ + Send a signal to the process group + """ return os.killpg(self.popen.pid, sig) @property - def stdin(self): + def stdin(self) -> Optional[IO]: return self.popen.stdin @property - def stdout(self): + def stdout(self) -> Optional[IO]: return self.popen.stdout @property - def stderr(self): + def stderr(self) -> Optional[IO]: return self.popen.stderr @property - def pid(self): + def pid(self) -> int: return self.popen.pid - def _wait(self): + def _wait(self) -> int: return self.popen.wait() - def _communicate(self, input, timeout): + def _communicate(self, input: bytes, timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: return _popen_communicate(self, self.popen, input, timeout) - def _poll(self): + def _poll(self) -> Optional[int]: return self.popen.poll() - def _cancel(self, kill_timeout): + def _cancel(self, kill_timeout: int) -> None: popen = self.popen os.killpg(os.getpgid(popen.pid), signal.SIGTERM) try: @@ -340,7 +521,7 @@ def _cancel(self, kill_timeout): except subprocess.TimeoutExpired: os.killpg(os.getpgid(popen.pid), signal.SIGKILL) - def _close(self): + def _close(self) -> int: self.popen.__exit__(None, None, None) return self.popen.returncode @@ -352,9 +533,32 @@ def __enter__(self): class ParamikoBackgroundCommand(BackgroundCommand): """ - :mod:`paramiko`-based background command. + Background command using a Paramiko :class:`Channel` for remote SSH-based execution. + Handles signals by running kill commands on the remote, using the PGID. + + :param conn: The SSH-based connection. + :type conn: ConnectionBase + :param chan: The Paramiko channel running the remote command. + :type chan: Channel + :param pid: Remote process group ID for signaling. + :type pid: int + :param as_root: True if run with elevated privileges. + :type as_root: bool or None + :param cmd: The shell command executed (for reference). + :type cmd: SubprocessCommand + :param stdin: A file-like object to write into the remote stdin. + :type stdin: IO + :param stdout: A file-like object for reading from the remote stdout. + :type stdout: IO + :param stderr: A file-like object for reading from the remote stderr. + :type stderr: IO + :param redirect_thread: A thread that captures data from the channel and writes to + stdout/stderr pipes. + :type redirect_thread: Thread """ - def __init__(self, conn, chan, pid, as_root, cmd, stdin, stdout, stderr, redirect_thread): + def __init__(self, conn: 'ConnectionBase', chan: 'Channel', pid: int, + as_root: Optional[bool], cmd: 'SubprocessCommand', stdin: IO, + stdout: IO, stderr: IO, redirect_thread: 'Thread'): super().__init__(conn=conn) self.chan = chan self.as_root = as_root @@ -365,7 +569,7 @@ def __init__(self, conn, chan, pid, as_root, cmd, stdin, stdout, stderr, redirec self.redirect_thread = redirect_thread self.cmd = cmd - def _send_signal(self, sig): + def _send_signal(self, sig: 'Signals') -> None: # If the command has already completed, we don't want to send a signal # to another process that might have gotten that PID in the meantime. if self.poll() is not None: @@ -376,20 +580,23 @@ def _send_signal(self, sig): self.conn.execute(cmd, as_root=self.as_root) @property - def pid(self): + def pid(self) -> int: return self._pid - def _wait(self): + def _wait(self) -> int: status = self.chan.recv_exit_status() # Ensure that the redirection thread is finished copying the content # from paramiko to the pipe. self.redirect_thread.join() return status - def _communicate(self, input, timeout): + def _communicate(self, input: bytes, timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: + """ + Implementation for reading from stdout/stderr, writing to stdin, + handling timeouts, etc. Raise an error if non-zero exit or timeout. + """ stdout = self._stdout stderr = self._stderr - stdin = self._stdin chan = self.chan # For some reason, file descriptors in the read-list of select() can @@ -400,21 +607,21 @@ def _communicate(self, input, timeout): for s in (stdout, stderr): fcntl.fcntl(s.fileno(), fcntl.F_SETFL, os.O_NONBLOCK) - out = {stdout: [], stderr: []} - ret = None - can_send = True + out: Dict[IO, List[bytes]] = {stdout: [], stderr: []} + ret: Optional[int] = None + can_send: bool = True - select_timeout = 1 + select_timeout: int = 1 if timeout is not None: select_timeout = min(select_timeout, 1) - def create_out(): + def create_out() -> Tuple[bytes, bytes]: return ( b''.join(out[stdout]), b''.join(out[stderr]) ) - start = time.monotonic() + start: float = time.monotonic() while ret is None: # Even if ret is not None anymore, we need to drain the streams @@ -426,11 +633,11 @@ def create_out(): raise subprocess.TimeoutExpired(self.cmd, timeout, _stdout, _stderr) can_send &= (not chan.closed) & bool(input) - wlist = [chan] if can_send else [] + wlist: List[Channel] = [chan] if can_send else [] if can_send and chan.send_ready(): try: - n = chan.send(input) + n: int = chan.send(input) # stdin might have been closed already except OSError: can_send = False @@ -440,7 +647,8 @@ def create_out(): if not input: # Send EOF on stdin chan.shutdown_write() - + rs: List[IO] + ws: List[IO] rs, ws, _ = select.select( [x for x in (stdout, stderr) if not x.closed], wlist, @@ -449,7 +657,7 @@ def create_out(): ) for r in rs: - chunk = r.read() + chunk: bytes = r.read() if chunk: out[r].append(chunk) @@ -465,7 +673,7 @@ def create_out(): else: return (_stdout, _stderr) - def _poll(self): + def _poll(self) -> Optional[int]: # Wait for the redirection thread to finish, otherwise we would # indicate the caller that the command is finished and that the streams # are safe to drain, but actually the redirection thread is not @@ -477,7 +685,7 @@ def _poll(self): else: return None - def _cancel(self, kill_timeout): + def _cancel(self, kill_timeout: int) -> None: self.send_signal(signal.SIGTERM) # Check if the command terminated quickly time.sleep(10e-3) @@ -488,24 +696,24 @@ def _cancel(self, kill_timeout): self.wait() @property - def stdin(self): + def stdin(self) -> Optional[IO]: return self._stdin @property - def stdout(self): + def stdout(self) -> Optional[IO]: return self._stdout @property - def stderr(self): + def stderr(self) -> Optional[IO]: return self._stderr - def _close(self): + def _close(self) -> int: for x in (self.stdin, self.stdout, self.stderr): if x is not None: x.close() - exit_code = self.wait() - thread = self.redirect_thread + exit_code: int = self.wait() + thread: Thread = self.redirect_thread if thread: thread.join() @@ -514,47 +722,59 @@ def _close(self): class AdbBackgroundCommand(BackgroundCommand): """ - ``adb``-based background command. + A background command launched through ADB. Manages signals by sending + kill commands on the remote Android device. + + :param conn: The ADB-based connection. + :type conn: ConnectionBase + :param adb_popen: A subprocess.Popen object representing 'adb shell' or similar. + :type adb_popen: Popen + :param pid: Remote process group ID used for signals. + :type pid: int + :param as_root: If True, signals are sent as root. + :type as_root: bool or None """ - def __init__(self, conn, adb_popen, pid, as_root): + def __init__(self, conn: 'ConnectionBase', adb_popen: 'Popen', + pid: int, as_root: Optional[bool]): super().__init__(conn=conn) self.as_root = as_root self.adb_popen = adb_popen self._pid = pid - def _send_signal(self, sig): + def _send_signal(self, sig: 'Signals') -> None: self.conn.execute( _kill_pgid_cmd(self.pid, sig, self.conn.busybox), as_root=self.as_root, ) @property - def stdin(self): + def stdin(self) -> Optional[IO]: return self.adb_popen.stdin @property - def stdout(self): + def stdout(self) -> Optional[IO]: return self.adb_popen.stdout @property - def stderr(self): + def stderr(self) -> Optional[IO]: return self.adb_popen.stderr @property - def pid(self): + def pid(self) -> int: return self._pid - def _wait(self): + def _wait(self) -> int: return self.adb_popen.wait() - def _communicate(self, input, timeout): + def _communicate(self, input: bytes, + timeout: Optional[int]) -> Tuple[Optional[bytes], Optional[bytes]]: return _popen_communicate(self, self.adb_popen, input, timeout) - def _poll(self): + def _poll(self) -> Optional[int]: return self.adb_popen.poll() - def _cancel(self, kill_timeout): + def _cancel(self, kill_timeout: int) -> None: self.send_signal(signal.SIGTERM) try: self.adb_popen.wait(timeout=kill_timeout) @@ -562,7 +782,7 @@ def _cancel(self, kill_timeout): self.send_signal(signal.SIGKILL) self.adb_popen.kill() - def _close(self): + def _close(self) -> int: self.adb_popen.__exit__(None, None, None) return self.adb_popen.returncode @@ -573,7 +793,21 @@ def __enter__(self): class TransferManager: - def __init__(self, conn, transfer_poll_period=30, start_transfer_poll_delay=30, total_transfer_timeout=3600): + """ + Monitors active file transfers (push or pull) in a background thread + and aborts them if they exceed a time limit or appear inactive. + + :param conn: The ConnectionBase owning this manager. + :type conn: ConnectionBase + :param transfer_poll_period: Interval (seconds) between checks for activity. + :type transfer_poll_period: int + :param start_transfer_poll_delay: Delay (seconds) before starting to poll a new transfer. + :type start_transfer_poll_delay: int + :param total_transfer_timeout: Cancel the transfer if it exceeds this duration. + :type total_transfer_timeout: int + """ + def __init__(self, conn: 'ConnectionBase', transfer_poll_period: int = 30, + start_transfer_poll_delay: int = 30, total_transfer_timeout: int = 3600): self.conn = conn self.transfer_poll_period = transfer_poll_period self.total_transfer_timeout = total_transfer_timeout @@ -582,14 +816,36 @@ def __init__(self, conn, transfer_poll_period=30, start_transfer_poll_delay=30, self.logger = logging.getLogger('FileTransfer') @contextmanager - def manage(self, sources, dest, direction, handle): - excep = None - stop_thread = threading.Event() + def manage(self, sources: Tuple[str, ...], dest: str, + direction: Union[Literal['push'], Literal['pull']], + handle: 'TransferHandleBase') -> Generator: + """ + A context manager that spawns a thread to monitor file transfer progress. + If the transfer stalls or times out, it cancels the operation. + + :param sources: Paths being transferred. + :type sources: Tuple[str, ...] + :param dest: Destination path. + :type dest: str + :param direction: 'push' or 'pull' for transfer direction. + :type direction: Literal['push', 'pull'] + :param handle: A TransferHandleBase for polling/canceling. + :type handle: TransferHandleBase + :raises TimeoutError: If the transfer times out. + """ + excep: Optional[TimeoutError] = None + stop_thread: Event = threading.Event() - def monitor(): + def monitor() -> None: + """ + thread to monitor the file transfer + """ nonlocal excep - def cancel(reason): + def cancel(reason: str) -> None: + """ + cancel the file transfer + """ self.logger.warning( f'Cancelling file transfer {sources} -> {dest} due to: {reason}' ) @@ -604,7 +860,7 @@ def cancel(reason): cancel(reason='transfer timed out') excep = TimeoutError(f'{direction}: {sources} -> {dest}') - m_thread = threading.Thread(target=monitor, daemon=True) + m_thread: Thread = threading.Thread(target=monitor, daemon=True) try: m_thread.start() yield self @@ -616,33 +872,64 @@ def cancel(reason): class NoopTransferManager: - def manage(self, *args, **kwargs): + """ + A manager that does nothing for transfers. Used if polling is disabled. + """ + def manage(self, *args, **kwargs) -> nullcontext: return nullcontext(self) class TransferHandleBase(ABC): - def __init__(self, manager): + """ + Abstract base for objects tracking a file transfer's progress and allowing cancellations. + + :param manager: The TransferManager that created this handle. + :type manager: TransferManager + """ + def __init__(self, manager: 'TransferManager'): self.manager = manager @property def logger(self): + """ + get the logger for transfer manager + """ return self.manager.logger @abstractmethod - def isactive(self): + def isactive(self) -> bool: + """ + Check if the transfer still appears to be making progress (return True) + or if it is idle/complete (return False). + """ pass @abstractmethod - def cancel(self): + def cancel(self) -> None: + """ + cancel ongoing file transfer + """ pass class PopenTransferHandle(TransferHandleBase): - def __init__(self, bg_cmd, dest, direction, *args, **kwargs): + """ + File transfer handle implemented using a background command (e.g., scp/rsync). + It regularly checks the destination size to see if it is increasing. + + :param bg_cmd: The BackgroundCommand driving the file transfer. + :type bg_cmd: BackgroundCommand + :param dest: Destination path (local or remote). + :type dest: str + :param direction: 'push' or 'pull'. + :type direction: Literal['push', 'pull'] + """ + def __init__(self, bg_cmd: 'BackgroundCommand', dest: str, + direction: Union[Literal['push'], Literal['pull']], *args, **kwargs): super().__init__(*args, **kwargs) if direction == 'push': - sample_size = self._push_dest_size + sample_size: Callable[[str], Optional[int]] = self._push_dest_size elif direction == 'pull': sample_size = self._pull_dest_size else: @@ -651,38 +938,55 @@ def __init__(self, bg_cmd, dest, direction, *args, **kwargs): self.sample_size = lambda: sample_size(dest) self.bg_cmd = bg_cmd - self.last_sample = 0 + self.last_sample: int = 0 @staticmethod - def _pull_dest_size(dest): + def _pull_dest_size(dest: str) -> Optional[int]: + """ + Compute total size of a directory or file at the local ``dest`` path. + Returns None if it does not exist. + """ if os.path.isdir(dest): return sum( os.stat(os.path.join(dirpath, f)).st_size - for dirpath, _, fnames in os.walk(dest) - for f in fnames + for dirpath, _, fnames in os.walk(dest) + for f in fnames ) else: return os.stat(dest).st_size - def _push_dest_size(self, dest): - conn = self.manager.conn - cmd = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest)) - out = conn.execute(cmd) - return int(out.split()[0]) - - def cancel(self): + def _push_dest_size(self, dest: str) -> Optional[int]: + """ + Compute total size of a directory or file on the remote device, + using busybox du if available. + """ + conn: 'ConnectionBase' = self.manager.conn + if conn.busybox: + cmd: str = '{} du -s -- {}'.format(quote(conn.busybox), quote(dest)) + out: str = conn.execute(cmd) + return int(out.split()[0]) + return None + + def cancel(self) -> None: + """ + Cancel the underlying background command, aborting the file transfer. + """ self.bg_cmd.cancel() - def isactive(self): + def isactive(self) -> bool: + """ + Check if the file size at the destination has grown since the last poll. + Returns True if so, otherwise might still be True if we can't read size. + """ try: - curr_size = self.sample_size() + curr_size: Optional[int] = self.sample_size() except Exception as e: self.logger.debug(f'File size polling failed: {e}') return True else: self.logger.debug(f'Polled file transfer, destination size: {curr_size}') if curr_size: - active = curr_size > self.last_sample + active: bool = curr_size > self.last_sample self.last_sample = curr_size return active # If the file is empty it will never grow in size, so we assume @@ -692,21 +996,33 @@ def isactive(self): class SSHTransferHandle(TransferHandleBase): + """ + SCP or SFTP-based file transfer handle that uses a callback to track progress. + + :param handle: The SCPClient or SFTPClient controlling the file transfer. + :type handle: SCPClient or SFTPClient + """ - def __init__(self, handle, *args, **kwargs): + def __init__(self, handle: Union['SCPClient', 'SFTPClient'], *args, **kwargs): super().__init__(*args, **kwargs) # SFTPClient or SSHClient self.handle = handle - self.progressed = False - self.transferred = 0 - self.to_transfer = 0 + self.progressed: bool = False + self.transferred: int = 0 + self.to_transfer: int = 0 - def cancel(self): + def cancel(self) -> None: + """ + Close the underlying SCP or SFTP client, presumably aborting the transfer. + """ self.handle.close() def isactive(self): + """ + Return True if we've seen progress since last poll, otherwise False. + """ progressed = self.progressed if progressed: self.progressed = False @@ -716,7 +1032,15 @@ def isactive(self): ) return progressed - def progress_cb(self, transferred, to_transfer): + def progress_cb(self, transferred: int, to_transfer: int) -> None: + """ + Callback to be called by the SCP/SFTP library on each progress update. + + :param transferred: Bytes transferred so far. + :type transferred: int + :param to_transfer: Total bytes to transfer, or 0 if unknown. + :type to_transfer: int + """ self.progressed = True self.transferred = transferred self.to_transfer = to_transfer diff --git a/devlib/exception.py b/devlib/exception.py index 33ef3c099..403fe2c19 100644 --- a/devlib/exception.py +++ b/devlib/exception.py @@ -1,4 +1,4 @@ -# Copyright 2013-2018 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +14,16 @@ # import subprocess +from typing import cast, Optional, List + +from devlib.utils.annotation_helpers import SubprocessCommand + class DevlibError(Exception): """Base class for all Devlib exceptions.""" - def __init__(self, *args): - message = args[0] if args else None + def __init__(self, *args) -> None: + message: Optional[object] = args[0] if args else None self._message = message @property @@ -73,18 +77,20 @@ class TargetStableError(TargetError, DevlibStableError): class TargetCalledProcessError(subprocess.CalledProcessError, TargetError): """Exception raised when a command executed on the target fails.""" - def __str__(self): + + def __str__(self) -> str: msg = super().__str__() - def decode(s): + + def decode(s: bytes) -> str: try: - s = s.decode() + st = s.decode() except AttributeError: - s = str(s) + st = str(s) - return s.strip() + return st.strip() if self.stdout is not None and self.stderr is None: - out = ['OUTPUT: {}'.format(decode(self.output))] + out: List[str] = ['OUTPUT: {}'.format(decode(self.output))] else: out = [ 'STDOUT: {}'.format(decode(self.output)) if self.output is not None else '', @@ -124,13 +130,13 @@ class TimeoutError(DevlibTransientError): programming error (e.g. not setting long enough timers), it is often due to some failure in the environment, and there fore should be classed as a "user error".""" - def __init__(self, command, output): - super(TimeoutError, self).__init__('Timed out: {}'.format(command)) + def __init__(self, command: Optional[SubprocessCommand], output: Optional[str]): + super(TimeoutError, self).__init__('Timed out: {}'.format(cast(str, command))) self.command = command self.output = output def __str__(self): - return '\n'.join([self.message, 'OUTPUT:', self.output or '']) + return '\n'.join([cast(str, self.message), 'OUTPUT:', self.output or '']) class WorkerThreadError(DevlibError): diff --git a/devlib/host.py b/devlib/host.py index a9958a349..923f182f0 100644 --- a/devlib/host.py +++ b/devlib/host.py @@ -1,4 +1,4 @@ -# Copyright 2015-2024 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,11 +26,30 @@ ) from devlib.utils.misc import check_output from devlib.connection import ConnectionBase, PopenBackgroundCommand - +from typing import Optional, TYPE_CHECKING, cast, Union, List, Tuple +from typing_extensions import Literal +if TYPE_CHECKING: + from devlib.platform import Platform + from devlib.utils.annotation_helpers import SubprocessCommand + from signal import Signals + from logging import Logger + from subprocess import Popen if sys.version_info >= (3, 8): - def copy_tree(src, dst): - from shutils import copy, copytree + def copy_tree(src: str, dst: str) -> None: + """ + Recursively copy an entire directory tree from ``src`` to ``dst``, + preserving the directory structure but **not** file metadata + (modification times, modes, etc.). If ``dst`` already exists, this + overwrites matching files. + + :param src: The source directory path. + :type src: str + :param dst: The destination directory path. + :type dst: str + :raises OSError: If any file or directory within ``src`` cannot be copied. + """ + from shutil import copy, copytree copytree( src, dst, @@ -42,17 +61,44 @@ def copy_tree(src, dst): ) else: def copy_tree(src, dst): + """ + Recursively copy an entire directory tree from ``src`` to ``dst``, + preserving the directory structure but **not** file metadata + (modification times, modes, etc.). If ``dst`` already exists, this + overwrites matching files. + + :param src: The source directory path. + :type src: str + :param dst: The destination directory path. + :type dst: str + :raises OSError: If any file or directory within ``src`` cannot be copied. + + .. note:: + This uses :func:`distutils.dir_util.copy_tree` under Python < 3.8, which + does not support ``dirs_exist_ok=True``. The behavior is effectively the same + for overwriting existing paths. + """ from distutils.dir_util import copy_tree # Mirror the behavior of all other targets which only copy the # content without metadata copy_tree(src, dst, preserve_mode=False, preserve_times=False) -PACKAGE_BIN_DIRECTORY = os.path.join(os.path.dirname(__file__), 'bin') +PACKAGE_BIN_DIRECTORY: str = os.path.join(os.path.dirname(__file__), 'bin') # pylint: disable=redefined-outer-name -def kill_children(pid, signal=signal.SIGKILL): +def kill_children(pid: int, signal: 'Signals' = signal.SIGKILL) -> None: + """ + Recursively kill all child processes of the specified process ID, then kill + the process itself with the given signal. + + :param pid: The process ID whose children (and itself) will be killed. + :type pid: int + :param signal_: The signal to send (defaults to SIGKILL). + :type signal_: signal.Signals + :raises ProcessLookupError: If any child process does not exist (e.g., race conditions). + """ with open('/proc/{0}/task/{0}/children'.format(pid), 'r') as fd: for cpid in map(int, fd.read().strip().split()): kill_children(cpid, signal) @@ -60,62 +106,173 @@ def kill_children(pid, signal=signal.SIGKILL): class LocalConnection(ConnectionBase): + """ + A connection to the local host, allowing the local system to be treated as a + devlib Target. Commands are run directly via :mod:`subprocess`, rather than + an SSH or ADB connection. + :param platform: A devlib Platform object for describing this local system + (e.g., CPU topology). If None, defaults may be used. + :type platform: Platform, optional + :param keep_password: If ``True``, cache the user’s sudo password in memory + after prompting. Defaults to True. + :type keep_password: bool + :param unrooted: If ``True``, assume the local system is non-root and do not + attempt root commands. This avoids prompting for a password. + :type unrooted: bool + :param password: Password for sudo. If provided, will not prompt the user. + :type password: str, optional + :param timeout: A default timeout (in seconds) for connection-based operations. + :type timeout: int, optional + """ name = 'local' host = 'localhost' + # pylint: disable=unused-argument + def __init__(self, platform: Optional['Platform'] = None, + keep_password: bool = True, unrooted: bool = False, + password: Optional[str] = None, timeout: Optional[int] = None): + """ + Initialize the LocalConnection instance. + """ + super().__init__() + self._connected_as_root: Optional[bool] = None + self.logger: Logger = logging.getLogger('local_connection') + self.keep_password: bool = keep_password + self.unrooted: bool = unrooted + self.password: Optional[str] = password + @property - def connected_as_root(self): + def connected_as_root(self) -> Optional[bool]: + """ + Indicate whether the current user context is effectively 'root' (uid=0). + + :return: + - True if root + - False if not root + - None if undetermined + :rtype: bool or None + """ if self._connected_as_root is None: - result = self.execute('id', as_root=False) + result: str = self.execute('id', as_root=False) self._connected_as_root = 'uid=0(' in result return self._connected_as_root @connected_as_root.setter - def connected_as_root(self, state): - self._connected_as_root = state + def connected_as_root(self, state: Optional[bool]) -> None: + """ + Override the known 'connected_as_root' state, if needed. - # pylint: disable=unused-argument - def __init__(self, platform=None, keep_password=True, unrooted=False, - password=None, timeout=None): - super().__init__() - self._connected_as_root = None - self.logger = logging.getLogger('local_connection') - self.keep_password = keep_password - self.unrooted = unrooted - self.password = password + :param state: True if effectively root, False if not, or None if unknown. + :type state: bool or None + """ + self._connected_as_root = state + def _copy_path(self, source: str, dest: str) -> None: + """ + Copy a single file or directory from ``source`` to ``dest``. If ``source`` + is a directory, it is copied recursively. - def _copy_path(self, source, dest): + :param source: The path to the file or directory on the local system. + :type source: str + :param dest: Destination path. + :type dest: str + :raises OSError: If any part of the copy operation fails. + """ self.logger.debug('copying {} to {}'.format(source, dest)) if os.path.isdir(source): copy_tree(source, dest) else: shutil.copy(source, dest) - def _copy_paths(self, sources, dest): + def _copy_paths(self, sources: Tuple[str, ...], dest: str) -> None: + """ + Copy multiple paths (files or directories) to the same destination. + + :param sources: A tuple of file or directory paths to copy. + :type sources: tuple of str + :param dest: The destination path, which may be a directory. + :type dest: str + :raises OSError: If any part of a copy operation fails. + """ for source in sources: self._copy_path(source, dest) - def push(self, sources, dest, timeout=None, as_root=False): # pylint: disable=unused-argument + def push(self, sources: Tuple[str, ...], dest: str, timeout: Optional[int] = None, + as_root: bool = False) -> None: # pylint: disable=unused-argument + """ + Transfer a list of files **from the local system** to itself (no-op in some contexts). + In practice, this copies each file in ``sources`` to ``dest``. + + :param sources: Tuple of file or directory paths on the local system. + :type sources: tuple of str + :param dest: Destination path on the local system. + :type dest: str + :param timeout: Timeout in seconds for each file copy; unused here (local copy). + :type timeout: int, optional + :param as_root: If True, tries to escalate with sudo. Typically a no-op locally. + :type as_root: bool + :raises TargetStableError: If the system is set to unrooted but as_root=True is used. + :raises OSError: If copying fails at any point. + """ self._copy_paths(sources, dest) - def pull(self, sources, dest, timeout=None, as_root=False): # pylint: disable=unused-argument + def pull(self, sources: Tuple[str, ...], dest: str, timeout: Optional[int] = None, + as_root: bool = False) -> None: # pylint: disable=unused-argument + """ + Transfer a list of files **from the local system** to the local system (similar to :meth:`push`). + + :param sources: Tuple of paths on the local system. + :type sources: tuple of str + :param dest: Destination directory or file path on local system. + :type dest: str + :param timeout: Timeout in seconds; typically unused. + :type timeout: int, optional + :param as_root: If True, attempts to use sudo for the copy, if not already root. + :type as_root: bool + :raises TargetStableError: If the system is set to unrooted but as_root=True is used. + :raises OSError: If copying fails. + """ self._copy_paths(sources, dest) # pylint: disable=unused-argument - def execute(self, command, timeout=None, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False): + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + check_exit_code: bool = True, as_root: Optional[bool] = False, + strip_colors: bool = True, will_succeed: bool = False) -> str: + """ + Execute a command locally (via :func:`subprocess.check_output`), returning + combined stdout+stderr output. Optionally escalates privileges with sudo. + + :param command: The command to execute (string or SubprocessCommand). + :type command: SubprocessCommand or str + :param timeout: Time in seconds after which the command is forcibly terminated. + :type timeout: int, optional + :param check_exit_code: If True, raise an error on nonzero exit codes. + :type check_exit_code: bool + :param as_root: If True, attempt sudo unless already root. Fails if ``unrooted=True``. + :type as_root: bool + :param strip_colors: If True, attempt to remove ANSI color codes from output. + (Not used in this local example.) + :type strip_colors: bool + :param will_succeed: If True, treat a failing command as a transient error + rather than stable. + :type will_succeed: bool + :return: The combined stdout+stderr of the command. + :rtype: str + :raises TargetTransientCalledProcessError: If the command fails but is considered transient. + :raises TargetStableCalledProcessError: If the command fails and is considered stable. + :raises TargetStableError: If run as root is requested but unrooted is True. + """ self.logger.debug(command) - use_sudo = as_root and not self.connected_as_root + use_sudo: Optional[bool] = as_root and not self.connected_as_root if use_sudo: if self.unrooted: raise TargetStableError('unrooted') - password = self._get_password() + password: str = self._get_password() # Empty prompt with -p '' to avoid adding a leading space to the # output. - command = "echo {} | sudo -k -p '' -S -- sh -c {}".format(quote(password), quote(command)) - ignore = None if check_exit_code else 'all' + command = "echo {} | sudo -k -p '' -S -- sh -c {}".format(quote(password), quote(cast(str, command))) + ignore: Optional[Union[int, List[int], Literal['all']]] = None if check_exit_code else 'all' try: stdout, stderr = check_output(command, shell=True, timeout=timeout, ignore=ignore) except subprocess.CalledProcessError as e: @@ -133,21 +290,40 @@ def execute(self, command, timeout=None, check_exit_code=True, return stdout + stderr - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + def background(self, command: 'SubprocessCommand', stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False) -> PopenBackgroundCommand: + """ + Launch a command on the local system in the background, returning + a handle to manage its execution via :class:`PopenBackgroundCommand`. + + :param command: The command or SubprocessCommand to run. + :type command: SubprocessCommand or str + :param stdout: File handle or constant (e.g. subprocess.PIPE) for capturing stdout. + :type stdout: int + :param stderr: File handle or constant for capturing stderr. + :type stderr: int + :param as_root: If True, attempt to run with sudo unless already root. + :type as_root: bool + :return: A background command object that can be polled, waited on, or killed. + :rtype: PopenBackgroundCommand + :raises TargetStableError: If unrooted is True but as_root is requested. + + .. note:: This **will block the connection** until the command completes. + """ if as_root and not self.connected_as_root: if self.unrooted: raise TargetStableError('unrooted') - password = self._get_password() + password: str = self._get_password() # Empty prompt with -p '' to avoid adding a leading space to the # output. - command = "echo {} | sudo -k -p '' -S -- sh -c {}".format(quote(password), quote(command)) + command = "echo {} | sudo -k -p '' -S -- sh -c {}".format(quote(password), quote(cast(str, command))) # Make sure to get a new PGID so PopenBackgroundCommand() can kill # all sub processes that could be started without troubles. def preexec_fn(): os.setpgrp() - popen = subprocess.Popen( + popen: Popen[bytes] = subprocess.Popen( command, stdout=stdout, stderr=stderr, @@ -158,22 +334,51 @@ def preexec_fn(): bg_cmd = PopenBackgroundCommand(self, popen) return bg_cmd - def _close(self): + def _close(self) -> None: + """ + Close the connection to the device. The :class:`Connection` object should not + be used after this method is called. There is no way to reopen a previously + closed connection, a new connection object should be created instead. + """ pass - def cancel_running_command(self): + def cancel_running_command(self) -> None: + """ + Cancel a running command (previously started with :func:`background`) and free up the connection. + It is valid to call this if the command has already terminated (or if no + command was issued), in which case this is a no-op. + """ pass - def wait_for_device(self, timeout=30): + def wait_for_device(self, timeout: int = 30) -> None: + """ + Wait for the local system to be 'available'. In practice, this is always a no-op + since we are already local. + :param timeout: Ignored. + :type timeout: int + """ return - def reboot_bootloader(self, timeout=30): + def reboot_bootloader(self, timeout: int = 30) -> None: + """ + Attempt to reboot into a bootloader mode. Not implemented for local usage. + + :param timeout: Time in seconds to wait for the operation to complete. + :type timeout: int + :raises NotImplementedError: Always, as local usage does not support bootloader reboots. + """ raise NotImplementedError() - def _get_password(self): + def _get_password(self) -> str: + """ + Prompt for the user's sudo password if not already cached. + + :return: The password string, either from cache or via user input. + :rtype: str + """ if self.password: return self.password - password = getpass('sudo password:') + password: str = getpass('sudo password:') if self.keep_password: self.password = password return password diff --git a/devlib/instrument/__init__.py b/devlib/instrument/__init__.py index 6dca81cbe..850d5ea18 100644 --- a/devlib/instrument/__init__.py +++ b/devlib/instrument/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,29 +14,69 @@ # import logging import collections - -from past.builtins import basestring - +from abc import abstractmethod from devlib.utils.csvutil import csvreader from devlib.utils.types import numeric from devlib.utils.types import identifier +from typing import (Dict, Optional, List, OrderedDict, + TYPE_CHECKING, Union, Callable, Generator, + Any, Tuple) +if TYPE_CHECKING: + from devlib.target import Target # Channel modes describe what sort of measurement the instrument supports. # Values must be powers of 2 INSTANTANEOUS = 1 CONTINUOUS = 2 -MEASUREMENT_TYPES = {} # populated further down +MEASUREMENT_TYPES: Dict[str, 'MeasurementType'] = {} # populated further down class MeasurementType(object): - - def __init__(self, name, units, category=None, conversions=None): + """ + In order to make instruments easer to use, and to make it easier to swap them + out when necessary (e.g. change method of collecting power), a number of + standard measurement types are defined. This way, for example, power will + always be reported as "power" in Watts, and never as "pwr" in milliWatts. + Currently defined measurement types are + + + +-------------+-------------+---------------+ + | name | units | category | + +=============+=============+===============+ + | count | count | | + +-------------+-------------+---------------+ + | percent | percent | | + +-------------+-------------+---------------+ + | time_us | microseconds| time | + +-------------+-------------+---------------+ + | time_ms | milliseconds| time | + +-------------+-------------+---------------+ + | temperature | degrees | thermal | + +-------------+-------------+---------------+ + | power | watts | power/energy | + +-------------+-------------+---------------+ + | voltage | volts | power/energy | + +-------------+-------------+---------------+ + | current | amps | power/energy | + +-------------+-------------+---------------+ + | energy | joules | power/energy | + +-------------+-------------+---------------+ + | tx | bytes | data transfer | + +-------------+-------------+---------------+ + | rx | bytes | data transfer | + +-------------+-------------+---------------+ + | tx/rx | bytes | data transfer | + +-------------+-------------+---------------+ + + """ + def __init__(self, name: str, units: Optional[str], + category: Optional[str] = None, conversions: Optional[Dict[str, Callable]] = None): self.name = name self.units = units self.category = category - self.conversions = {} + self.conversions: Dict[str, Callable] = {} if conversions is not None: for key, value in conversions.items(): if not callable(value): @@ -44,24 +84,48 @@ def __init__(self, name, units, category=None, conversions=None): raise ValueError(msg.format(type(value), value)) self.conversions[key] = value - def convert(self, value, to): - if isinstance(to, basestring) and to in MEASUREMENT_TYPES: + def convert(self, value: str, to: Union[str, 'MeasurementType']) -> Union[str, 'MeasurementType']: + if isinstance(to, str) and to in MEASUREMENT_TYPES: to = MEASUREMENT_TYPES[to] if not isinstance(to, MeasurementType): - msg = 'Unexpected conversion target: "{}"' + msg: str = 'Unexpected conversion target: "{}"' raise ValueError(msg.format(to)) if to.name == self.name: return value - if not to.name in self.conversions: + if to.name not in self.conversions: msg = 'No conversion from {} to {} available' raise ValueError(msg.format(self.name, to.name)) return self.conversions[to.name](value) - # pylint: disable=undefined-variable - def __cmp__(self, other): + def __lt__(self, other): + if isinstance(other, MeasurementType): + return self.name < other.name + return self.name < other + + def __le__(self, other): + if isinstance(other, MeasurementType): + return self.name <= other.name + return self.name <= other + + def __eq__(self, other): + if isinstance(other, MeasurementType): + return self.name == other.name + return self.name == other + + def __ne__(self, other): if isinstance(other, MeasurementType): - other = other.name - return cmp(self.name, other) + return self.name != other.name + return self.name != other + + def __gt__(self, other): + if isinstance(other, MeasurementType): + return self.name > other.name + return self.name > other + + def __ge__(self, other): + if isinstance(other, MeasurementType): + return self.name >= other.name + return self.name >= other def __str__(self): return self.name @@ -79,7 +143,7 @@ def __repr__(self): # to particular insturments (e.g. a particular method of mearuing power), instruments # must, where possible, resport their measurments formatted as on of the standard types # defined here. -_measurement_types = [ +_measurement_types: List[MeasurementType] = [ # For whatever reason, the type of measurement could not be established. MeasurementType('unknown', None), @@ -95,33 +159,33 @@ def __repr__(self): # processors that expect all times time be at a particular scale can automatically # covert without being familar with individual instruments. MeasurementType('time', 'seconds', 'time', - conversions={ - 'time_us': lambda x: x * 1e6, - 'time_ms': lambda x: x * 1e3, - 'time_ns': lambda x: x * 1e9, - } - ), + conversions={ + 'time_us': lambda x: x * 1e6, + 'time_ms': lambda x: x * 1e3, + 'time_ns': lambda x: x * 1e9, + } + ), MeasurementType('time_us', 'microseconds', 'time', - conversions={ - 'time': lambda x: x / 1e6, - 'time_ms': lambda x: x / 1e3, - 'time_ns': lambda x: x * 1e3, - } - ), + conversions={ + 'time': lambda x: x / 1e6, + 'time_ms': lambda x: x / 1e3, + 'time_ns': lambda x: x * 1e3, + } + ), MeasurementType('time_ms', 'milliseconds', 'time', - conversions={ - 'time': lambda x: x / 1e3, - 'time_us': lambda x: x * 1e3, - 'time_ns': lambda x: x * 1e6, - } - ), + conversions={ + 'time': lambda x: x / 1e3, + 'time_us': lambda x: x * 1e3, + 'time_ns': lambda x: x * 1e6, + } + ), MeasurementType('time_ns', 'nanoseconds', 'time', - conversions={ - 'time': lambda x: x / 1e9, - 'time_ms': lambda x: x / 1e6, - 'time_us': lambda x: x / 1e3, - } - ), + conversions={ + 'time': lambda x: x / 1e9, + 'time_ms': lambda x: x / 1e6, + 'time_us': lambda x: x / 1e3, + } + ), # Measurements related to thermals. MeasurementType('temperature', 'degrees', 'thermal'), @@ -150,23 +214,52 @@ class Measurement(object): __slots__ = ['value', 'channel'] @property - def name(self): + def name(self) -> str: + """ + name of the measurement + """ return '{}_{}'.format(self.channel.site, self.channel.kind) @property - def units(self): + def units(self) -> Optional[str]: + """ + Units in which measurement will be reported. + """ return self.channel.units - def __init__(self, value, channel): + def __init__(self, value: Union[int, float], channel: 'InstrumentChannel'): self.value = value self.channel = channel - # pylint: disable=undefined-variable - def __cmp__(self, other): + def __lt__(self, other): if hasattr(other, 'value'): - return cmp(self.value, other.value) - else: - return cmp(self.value, other) + return self.value < other.value + return self.value < other + + def __eq__(self, other): + if hasattr(other, 'value'): + return self.value == other.value + return self.value == other + + def __le__(self, other): + if hasattr(other, 'value'): + return self.value <= other.value + return self.value <= other + + def __ne__(self, other): + if hasattr(other, 'value'): + return self.value != other.value + return self.value != other + + def __gt__(self, other): + if hasattr(other, 'value'): + return self.value > other.value + return self.value > other + + def __ge__(self, other): + if hasattr(other, 'value'): + return self.value >= other.value + return self.value >= other def __str__(self): if self.units: @@ -179,44 +272,47 @@ def __str__(self): class MeasurementsCsv(object): - def __init__(self, path, channels=None, sample_rate_hz=None): + def __init__(self, path, channels: Optional[List['InstrumentChannel']] = None, + sample_rate_hz: Optional[float] = None): self.path = path self.channels = channels self.sample_rate_hz = sample_rate_hz if self.channels is None: self._load_channels() - headings = [chan.label for chan in self.channels] - self.data_tuple = collections.namedtuple('csv_entry', + headings = [chan.label for chan in self.channels] if self.channels else [] + + self.data_tuple = collections.namedtuple('csv_entry', # type:ignore map(identifier, headings)) - def measurements(self): + def measurements(self) -> List[List['Measurement']]: return list(self.iter_measurements()) - def iter_measurements(self): + def iter_measurements(self) -> Generator[List['Measurement'], Any, None]: for row in self._iter_rows(): values = map(numeric, row) - yield [Measurement(v, c) for (v, c) in zip(values, self.channels)] + if self.channels: + yield [Measurement(v, c) for (v, c) in zip(values, self.channels)] - def values(self): + def values(self) -> List: return list(self.iter_values()) - def iter_values(self): + def iter_values(self) -> Generator[Tuple[Any], Any, None]: for row in self._iter_rows(): values = list(map(numeric, row)) yield self.data_tuple(*values) - def _load_channels(self): - header = [] + def _load_channels(self) -> None: + header: List[str] = [] with csvreader(self.path) as reader: header = next(reader) self.channels = [] for entry in header: for mt in MEASUREMENT_TYPES: - suffix = '_{}'.format(mt) + suffix: str = '_{}'.format(mt) if entry.endswith(suffix): - site = entry[:-len(suffix)] - measure = mt + site: Optional[str] = entry[:-len(suffix)] + measure: str = mt break else: if entry in MEASUREMENT_TYPES: @@ -225,12 +321,12 @@ def _load_channels(self): else: site = entry measure = 'unknown' - - chan = InstrumentChannel(site, measure) - self.channels.append(chan) + if site: + chan = InstrumentChannel(site, measure) + self.channels.append(chan) # pylint: disable=stop-iteration-return - def _iter_rows(self): + def _iter_rows(self) -> Generator[List[str], Any, None]: with csvreader(self.path) as reader: next(reader) # headings for row in reader: @@ -238,9 +334,41 @@ def _iter_rows(self): class InstrumentChannel(object): + """ + An :class:`InstrumentChannel` describes a single type of measurement that may + be collected by an :class:`~devlib.instrument.Instrument`. A channel is + primarily defined by a ``site`` and a ``measurement_type``. + + A ``site`` indicates where on the target a measurement is collected from + (e.g. a voltage rail or location of a sensor). + + A ``measurement_type`` is an instance of :class:`MeasurmentType` that + describes what sort of measurement this is (power, temperature, etc). Each + measurement type has a standard unit it is reported in, regardless of an + instrument used to collect it. + + A channel (i.e. site/measurement_type combination) is unique per instrument, + however there may be more than one channel associated with one site (e.g. for + both voltage and power). + It should not be assumed that any site/measurement_type combination is valid. + The list of available channels can queried with + :func:`Instrument.list_channels()`. + + .. attribute:: InstrumentChannel.site + + The name of the "site" from which the measurements are collected (e.g. voltage + rail, sensor, etc). + + """ @property - def label(self): + def label(self) -> str: + """ + A label that can be attached to measurements associated with with channel. + This is constructed with :: + + '{}_{}'.format(self.site, self.kind) + """ if self.site is not None: return '{}_{}'.format(self.site, self.kind) return self.kind @@ -248,14 +376,22 @@ def label(self): name = label @property - def kind(self): + def kind(self) -> str: + """ + A string indicating the type of measurement that will be collected. This is + the ``name`` of the :class:`MeasurmentType` associated with this channel. + """ return self.measurement_type.name @property - def units(self): + def units(self) -> Optional[str]: + """ + Units in which measurement will be reported. this is determined by the + underlying :class:`MeasurmentType`. + """ return self.measurement_type.units - def __init__(self, site, measurement_type, **attrs): + def __init__(self, site: str, measurement_type: Union[str, MeasurementType], **attrs): self.site = site if isinstance(measurement_type, MeasurementType): self.measurement_type = measurement_type @@ -277,39 +413,117 @@ def __str__(self): class Instrument(object): + """ + The ``Instrument`` API provide a consistent way of collecting measurements from + a target. Measurements are collected via an instance of a class derived from + :class:`~devlib.instrument.Instrument`. An ``Instrument`` allows collection of + measurement from one or more channels. An ``Instrument`` may support + ``INSTANTANEOUS`` or ``CONTINUOUS`` collection, or both. + + .. attribute:: Instrument.mode + + A bit mask that indicates collection modes that are supported by this + instrument. Possible values are: - mode = 0 + :INSTANTANEOUS: The instrument supports taking a single sample via + ``take_measurement()``. + :CONTINUOUS: The instrument supports collecting measurements over a + period of time via ``start()``, ``stop()``, ``get_data()``, + and (optionally) ``get_raw`` methods. - def __init__(self, target): + .. note:: It's possible for one instrument to support more than a single + mode. + + .. attribute:: Instrument.active_channels + + Channels that have been activated via ``reset()``. Measurements will only be + collected for these channels. + .. attribute:: Instrument.sample_rate_hz + + Sample rate of the instrument in Hz. Assumed to be the same for all channels. + + .. note:: This attribute is only provided by + :class:`~devlib.instrument.Instrument` s that + support ``CONTINUOUS`` measurement. + """ + mode: int = 0 + + def __init__(self, target: 'Target'): self.target = target self.logger = logging.getLogger(self.__class__.__name__) - self.channels = collections.OrderedDict() - self.active_channels = [] - self.sample_rate_hz = None + self.channels: OrderedDict[str, InstrumentChannel] = collections.OrderedDict() + self.active_channels: List[InstrumentChannel] = [] + self.sample_rate_hz: Optional[float] = None # channel management - def list_channels(self): + def list_channels(self) -> List[InstrumentChannel]: + """ + Returns a list of :class:`InstrumentChannel` instances that describe what + this instrument can measure on the current target. A channel is a combination + of a ``kind`` of measurement (power, temperature, etc) and a ``site`` that + indicates where on the target the measurement will be collected from. + """ return list(self.channels.values()) - def get_channels(self, measure): - if hasattr(measure, 'name'): - measure = measure.name + def get_channels(self, measure: Union[str, MeasurementType]): + """ + Returns channels for a particular ``measure`` type. A ``measure`` can be + either a string (e.g. ``"power"``) or a :class:`MeasurmentType` instance. + """ + if isinstance(measure, MeasurementType): + if hasattr(measure, 'name'): + measure = measure.name return [c for c in self.list_channels() if c.kind == measure] - def add_channel(self, site, measure, **attrs): + def add_channel(self, site: str, measure: Union[str, MeasurementType], **attrs) -> None: + """ + add channel to channels dict + """ chan = InstrumentChannel(site, measure, **attrs) self.channels[chan.label] = chan # initialization and teardown - def setup(self, *args, **kwargs): + def setup(self, *args, **kwargs) -> None: + """ + This will set up the instrument on the target. Parameters this method takes + are particular to subclasses (see documentation for specific instruments + below). What actions are performed by this method are also + instrument-specific. Usually these will be things like installing + executables, starting services, deploying assets, etc. Typically, this method + needs to be invoked at most once per reboot of the target (unless + ``teardown()`` has been called), but see documentation for the instrument + you're interested in. + """ pass - def teardown(self): + def teardown(self) -> None: + """ + Performs any required clean up of the instrument. This usually includes + removing temporary and raw files (if ``keep_raw`` is set to ``False`` on relevant + instruments), stopping services etc. + """ pass - def reset(self, sites=None, kinds=None, channels=None): + def reset(self, sites: Optional[List[str]] = None, + kinds: Optional[List[str]] = None, + channels: Optional[OrderedDict[str, InstrumentChannel]] = None) -> None: + """ + This is used to configure an instrument for collection. This must be invoked + before ``start()`` is called to begin collection. This methods sets the + ``active_channels`` attribute of the ``Instrument``. + + If ``channels`` is provided, it is a list of names of channels to enable and + ``sites`` and ``kinds`` must both be ``None``. + + Otherwise, if one of ``sites`` or ``kinds`` is provided, all channels + matching the given sites or kinds are enabled. If both are provided then all + channels of the given kinds at the given sites are enabled. + + If none of ``sites``, ``kinds`` or ``channels`` are provided then all + available channels are enabled. + """ if channels is not None: if sites is not None or kinds is not None: raise ValueError('sites and kinds should not be set if channels is set') @@ -317,36 +531,93 @@ def reset(self, sites=None, kinds=None, channels=None): try: self.active_channels = [self.channels[ch] for ch in channels] except KeyError as e: - msg = 'Unexpected channel "{}"; must be in {}' + msg: str = 'Unexpected channel "{}"; must be in {}' raise ValueError(msg.format(e, self.channels.keys())) elif sites is None and kinds is None: self.active_channels = sorted(self.channels.values(), key=lambda x: x.label) else: - if isinstance(sites, basestring): + if isinstance(sites, str): sites = [sites] - if isinstance(kinds, basestring): + if isinstance(kinds, str): kinds = [kinds] wanted = lambda ch: ((kinds is None or ch.kind in kinds) and - (sites is None or ch.site in sites)) + (sites is None or ch.site in sites)) self.active_channels = list(filter(wanted, self.channels.values())) # instantaneous - - def take_measurement(self): + @abstractmethod + def take_measurement(self) -> List[Measurement]: + """ + Take a single measurement from ``active_channels``. Returns a list of + :class:`Measurement` objects (one for each active channel). + + .. note:: This method is only implemented by + :class:`~devlib.instrument.Instrument's that + support ``INSTANTANEOUS`` measurement. + """ pass # continuous - def start(self): + def start(self) -> None: + """ + Starts collecting measurements from ``active_channels``. + + .. note:: This method is only implemented by + :class:`~devlib.instrument.Instrument` s that + support ``CONTINUOUS`` measurement. + """ pass - def stop(self): + def stop(self) -> None: + """ + Stops collecting measurements from ``active_channels``. Must be called after + :func:`start()`. + + .. note:: This method is only implemented by + :class:`~devlib.instrument.Instrument` s that + support ``CONTINUOUS`` measurement. + """ pass - # pylint: disable=no-self-use - def get_data(self, outfile): + @abstractmethod + def get_data(self, outfile: str) -> MeasurementsCsv: + """ + Write collected data into ``outfile``. Must be called after :func:`stop()`. + Data will be written in CSV format with a column for each channel and a row + for each sample. Column heading will be channel, labels in the form + ``_`` (see :class:`InstrumentChannel`). The order of the columns + will be the same as the order of channels in ``Instrument.active_channels``. + + If reporting timestamps, one channel must have a ``site`` named + ``"timestamp"`` and a ``kind`` of a :class:`MeasurmentType` of an appropriate + time unit which will be used, if appropriate, during any post processing. + + .. note:: Currently supported time units are seconds, milliseconds and + microseconds, other units can also be used if an appropriate + conversion is provided. + + This returns a :class:`MeasurementCsv` instance associated with the outfile + that can be used to stream :class:`Measurement` s lists (similar to what is + returned by ``take_measurement()``. + + .. note:: This method is only implemented by + :class:`~devlib.instrument.Instrument` s that + support ``CONTINUOUS`` measurement. + """ pass - def get_raw(self): + def get_raw(self) -> List[str]: + """ + Returns a list of paths to files containing raw output from the underlying + source(s) that is used to produce the data CSV. If no raw output is + generated or saved, an empty list will be returned. The format of the + contents of the raw files is entirely source-dependent. + + .. note:: This method is not guaranteed to return valid filepaths after the + :meth:`teardown` method has been invoked as the raw files may have + been deleted. Please ensure that copies are created manually + prior to calling :meth:`teardown` if the files are to be retained. + """ return [] diff --git a/devlib/instrument/acmecape.py b/devlib/instrument/acmecape.py index cfbcbe071..1d2f13197 100644 --- a/devlib/instrument/acmecape.py +++ b/devlib/instrument/acmecape.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License. # -#pylint: disable=attribute-defined-outside-init +# pylint: disable=attribute-defined-outside-init import os import sys import time @@ -34,6 +34,7 @@ ${iio_capture} -n ${host} -b ${buffer_size} -c -f ${outfile} ${iio_device} """) + def _read_nonblock(pipe, size=1024): fd = pipe.fileno() flags = fcntl(fd, F_GETFL) @@ -93,7 +94,7 @@ def reset(self, sites=None, kinds=None, channels=None): iio_device=self.iio_device, outfile=self.raw_data_file ) - params = {k: quote(v) for k, v in params.items()} + params = {k: quote(v or '') for k, v in params.items()} self.command = IIOCAP_CMD_TEMPLATE.substitute(**params) self.logger.debug('ACME cape command: {}'.format(self.command)) @@ -115,7 +116,7 @@ def stop(self): if self.process.poll() is None: msg = 'Could not terminate iio-capture:\n{}' raise HostError(msg.format(output)) - if self.process.returncode != 15: # iio-capture exits with 15 when killed + if self.process.returncode != 15: # iio-capture exits with 15 when killed output += self.process.stdout.read().decode(sys.stdout.encoding or 'utf-8', 'replace') self.logger.info('ACME instrument encountered an error, ' 'you may want to try rebooting the ACME device:\n' diff --git a/devlib/instrument/arm_energy_probe.py b/devlib/instrument/arm_energy_probe.py index 80ef643da..1e697ec28 100644 --- a/devlib/instrument/arm_energy_probe.py +++ b/devlib/instrument/arm_energy_probe.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,6 +45,7 @@ from devlib.utils.parse_aep import AepParser + class ArmEnergyProbeInstrument(Instrument): """ Collects power traces using the ARM Energy Probe. @@ -68,23 +69,23 @@ class ArmEnergyProbeInstrument(Instrument): mode = CONTINUOUS - MAX_CHANNELS = 12 # 4 Arm Energy Probes + MAX_CHANNELS = 12 # 4 Arm Energy Probes def __init__(self, target, config_file='./config-aep', keep_raw=False): super(ArmEnergyProbeInstrument, self).__init__(target) self.arm_probe = which('arm-probe') if self.arm_probe is None: raise HostError('arm-probe must be installed on the host') - #todo detect is config file exist + # todo detect is config file exist self.attributes = ['power', 'voltage', 'current'] self.sample_rate_hz = 10000 self.config_file = config_file self.keep_raw = keep_raw self.parser = AepParser() - #TODO make it generic + # TODO make it generic topo = self.parser.topology_from_config(self.config_file) - for item in topo: + for item in topo or []: if item == 'time': self.add_channel('timestamp', 'time') else: @@ -103,9 +104,9 @@ def reset(self, sites=None, kinds=None, channels=None): def start(self): self.logger.debug(self.command) self.armprobe = subprocess.Popen(self.command, - stderr=self.output_fd_error, - preexec_fn=os.setpgrp, - shell=True) + stderr=self.output_fd_error, + preexec_fn=os.setpgrp, + shell=True) def stop(self): self.logger.debug("kill running arm-probe") @@ -132,7 +133,7 @@ def get_data(self, outfile): # pylint: disable=R0914 if len(row) < len(active_channels): continue # all data are in micro (seconds/watt) - new = [float(row[i])/1000000 for i in active_indexes] + new = [float(row[i]) / 1000000 for i in active_indexes] writer.writerow(new) self.output_fd_error.close() diff --git a/devlib/instrument/daq.py b/devlib/instrument/daq.py index 97c638fd8..da85d0740 100644 --- a/devlib/instrument/daq.py +++ b/devlib/instrument/daq.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,11 +20,11 @@ from itertools import chain, zip_longest from devlib.host import PACKAGE_BIN_DIRECTORY -from devlib.instrument import Instrument, MeasurementsCsv, CONTINUOUS +from devlib.instrument import Instrument, MeasurementsCsv, CONTINUOUS, InstrumentChannel from devlib.exception import HostError from devlib.utils.csvutil import csvwriter, create_reader from devlib.utils.misc import unique - +import daqpower.server try: from daqpower.client import DaqClient from daqpower.config import DeviceConfiguration @@ -32,31 +32,35 @@ DaqClient = None DeviceConfiguration = None import_error_mesg = e.args[0] if e.args else str(e) +from typing import (TYPE_CHECKING, List, Union, Optional, Tuple, + cast, Dict, TextIO, Any, OrderedDict) +if TYPE_CHECKING: + from devlib.target import Target class DaqInstrument(Instrument): mode = CONTINUOUS - def __init__(self, target, resistor_values, # pylint: disable=R0914 - labels=None, - host='localhost', - port=45677, - device_id='Dev1', - v_range=2.5, - dv_range=0.2, - sample_rate_hz=10000, - channel_map=(0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23), - keep_raw=False, - time_as_clock_boottime=True + def __init__(self, target: 'Target', resistor_values: List[Union[int, str]], # pylint: disable=R0914 + labels: Optional[List[str]] = None, + host: str = 'localhost', + port: int = 45677, + device_id: str = 'Dev1', + v_range: float = 2.5, + dv_range: float = 0.2, + sample_rate_hz: int = 10000, + channel_map: Tuple = (0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23), + keep_raw: bool = False, + time_as_clock_boottime: bool = True ): # pylint: disable=no-member super(DaqInstrument, self).__init__(target) self.keep_raw = keep_raw - self._need_reset = True - self._raw_files = [] - self.tempdir = None - self.target_boottime_clock_at_start = 0.0 + self._need_reset: bool = True + self._raw_files: List[str] = [] + self.tempdir: Optional[str] = None + self.target_boottime_clock_at_start: float = 0.0 if DaqClient is None: raise HostError('Could not import "daqpower": {}'.format(import_error_mesg)) if labels is None: @@ -65,20 +69,20 @@ def __init__(self, target, resistor_values, # pylint: disable=R0914 raise ValueError('"labels" and "resistor_values" must be of the same length') self.daq_client = DaqClient(host, port) try: - devices = self.daq_client.list_devices() + devices: List[str] = cast(daqpower.server.DaqServer, self.daq_client).list_devices() if device_id not in devices: msg = 'Device "{}" is not found on the DAQ server. Available devices are: "{}"' raise ValueError(msg.format(device_id, ', '.join(devices))) except Exception as e: raise HostError('Problem querying DAQ server: {}'.format(e)) - - self.device_config = DeviceConfiguration(device_id=device_id, - v_range=v_range, - dv_range=dv_range, - sampling_rate=sample_rate_hz, - resistor_values=resistor_values, - channel_map=channel_map, - labels=labels) + if DeviceConfiguration: + self.device_config = DeviceConfiguration(device_id=device_id, + v_range=v_range, + dv_range=dv_range, + sampling_rate=sample_rate_hz, + resistor_values=resistor_values, + channel_map=channel_map, + labels=labels) self.sample_rate_hz = sample_rate_hz self.time_as_clock_boottime = time_as_clock_boottime @@ -88,62 +92,67 @@ def __init__(self, target, resistor_values, # pylint: disable=R0914 self.add_channel(label, kind) if time_as_clock_boottime: - host_path = os.path.join(PACKAGE_BIN_DIRECTORY, self.target.abi, - 'get_clock_boottime') + host_path: str = os.path.join(PACKAGE_BIN_DIRECTORY, self.target.abi or '', + 'get_clock_boottime') self.clock_boottime_cmd = self.target.install_if_needed(host_path, search_system_binaries=False) - def calculate_boottime_offset(self): - time_before = time.time() - out = self.target.execute(self.clock_boottime_cmd) - time_after = time.time() + def calculate_boottime_offset(self) -> float: + """ + calculate boot time offset + """ + time_before: float = time.time() + out: str = self.target.execute(self.clock_boottime_cmd) + time_after: float = time.time() remote_clock_boottime = float(out) - propagation_delay = (time_after - time_before) / 2 - boottime_at_end = remote_clock_boottime + propagation_delay + propagation_delay: float = (time_after - time_before) / 2 + boottime_at_end: float = remote_clock_boottime + propagation_delay return time_after - boottime_at_end - def reset(self, sites=None, kinds=None, channels=None): + def reset(self, sites: Optional[List[str]] = None, + kinds: Optional[List[str]] = None, + channels: Optional[OrderedDict[str, InstrumentChannel]] = None) -> None: super(DaqInstrument, self).reset(sites, kinds, channels) - self.daq_client.close() - self.daq_client.configure(self.device_config) + cast(daqpower.server.DaqServer, self.daq_client).close() + cast(daqpower.server.DaqServer, self.daq_client).configure(self.device_config) self._need_reset = False self._raw_files = [] - def start(self): + def start(self) -> None: if self._need_reset: # Preserve channel order - self.reset(channels=self.channels.keys()) + self.reset(channels=cast(OrderedDict[str, InstrumentChannel], self.channels.keys())) if self.time_as_clock_boottime: target_boottime_offset = self.calculate_boottime_offset() time_start = time.time() - self.daq_client.start() + cast(daqpower.server.DaqServer, self.daq_client).start() if self.time_as_clock_boottime: - time_end = time.time() + time_end: float = time.time() self.target_boottime_clock_at_start = (time_start + time_end) / 2 - target_boottime_offset - def stop(self): - self.daq_client.stop() + def stop(self) -> None: + cast(daqpower.server.DaqServer, self.daq_client).stop() self._need_reset = True - def get_data(self, outfile): # pylint: disable=R0914 + def get_data(self, outfile: str) -> MeasurementsCsv: # pylint: disable=R0914 self.tempdir = tempfile.mkdtemp(prefix='daq-raw-') self.daq_client.get_data(self.tempdir) - raw_file_map = {} + raw_file_map: Dict[str, str] = {} for entry in os.listdir(self.tempdir): - site = os.path.splitext(entry)[0] - path = os.path.join(self.tempdir, entry) + site: str = os.path.splitext(entry)[0] + path: str = os.path.join(self.tempdir, entry) raw_file_map[site] = path self._raw_files.append(path) - active_sites = unique([c.site for c in self.active_channels]) - file_handles = [] + active_sites: List[str] = unique([c.site for c in self.active_channels]) + file_handles: List[TextIO] = [] try: - site_readers = {} + site_readers: Dict[str, Any] = {} for site in active_sites: try: site_file = raw_file_map[site] @@ -152,11 +161,11 @@ def get_data(self, outfile): # pylint: disable=R0914 file_handles.append(fh) except KeyError: if not site.startswith("Time"): - message = 'Could not get DAQ trace for {}; Obtained traces are in {}' + message: str = 'Could not get DAQ trace for {}; Obtained traces are in {}' raise HostError(message.format(site, self.tempdir)) # The first row is the headers - channel_order = ['Time_time'] + channel_order: List[str] = ['Time_time'] for site, reader in site_readers.items(): channel_order.extend(['{}_{}'.format(site, kind) for kind in next(reader)]) @@ -167,15 +176,15 @@ def _read_rows(): raw_row = list(chain.from_iterable(raw_row)) raw_row.insert(0, _read_rows.row_time_s) yield raw_row - _read_rows.row_time_s += 1.0 / self.sample_rate_hz + _read_rows.row_time_s += 1.0 / cast(float, self.sample_rate_hz) - _read_rows.row_time_s = self.target_boottime_clock_at_start + _read_rows.row_time_s = self.target_boottime_clock_at_start # type:ignore with csvwriter(outfile) as writer: - field_names = [c.label for c in self.active_channels] + field_names: List[str] = [c.label for c in self.active_channels] writer.writerow(field_names) for raw_row in _read_rows(): - row = [raw_row[channel_order.index(f)] for f in field_names] + row: List[str] = [raw_row[channel_order.index(f)] for f in field_names] writer.writerow(row) return MeasurementsCsv(outfile, self.active_channels, self.sample_rate_hz) @@ -183,11 +192,11 @@ def _read_rows(): for fh in file_handles: fh.close() - def get_raw(self): + def get_raw(self) -> List[str]: return self._raw_files - def teardown(self): - self.daq_client.close() + def teardown(self) -> None: + cast(daqpower.server.DaqServer, self.daq_client).close() if not self.keep_raw: if self.tempdir and os.path.isdir(self.tempdir): shutil.rmtree(self.tempdir) diff --git a/devlib/instrument/frames.py b/devlib/instrument/frames.py index 402c48194..a5cd2c8ce 100644 --- a/devlib/instrument/frames.py +++ b/devlib/instrument/frames.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,74 +16,88 @@ import os from devlib.instrument import (Instrument, CONTINUOUS, - MeasurementsCsv, MeasurementType) + MeasurementsCsv, MeasurementType, + InstrumentChannel) from devlib.utils.rendering import (GfxinfoFrameCollector, SurfaceFlingerFrameCollector, SurfaceFlingerFrame, - read_gfxinfo_columns) + read_gfxinfo_columns, + FrameCollector) +from typing import (TYPE_CHECKING, Optional, Type, + OrderedDict, Any, List) +if TYPE_CHECKING: + from devlib.target import Target class FramesInstrument(Instrument): mode = CONTINUOUS - collector_cls = None + collector_cls: Optional[Type[FrameCollector]] = None - def __init__(self, target, collector_target, period=2, keep_raw=True): + def __init__(self, target: 'Target', collector_target: Any, + period: int = 2, keep_raw: bool = True): super(FramesInstrument, self).__init__(target) self.collector_target = collector_target self.period = period self.keep_raw = keep_raw - self.sample_rate_hz = 1 / self.period - self.collector = None - self.header = None - self._need_reset = True - self._raw_file = None + self.sample_rate_hz: float = 1 / self.period + self.collector: Optional[FrameCollector] = None + self.header: Optional[List[str]] = None + self._need_reset: bool = True + self._raw_file: Optional[str] = None self._init_channels() - def reset(self, sites=None, kinds=None, channels=None): + def reset(self, sites: Optional[List[str]] = None, + kinds: Optional[List[str]] = None, + channels: Optional[OrderedDict[str, InstrumentChannel]] = None) -> None: super(FramesInstrument, self).reset(sites, kinds, channels) - # pylint: disable=not-callable - self.collector = self.collector_cls(self.target, self.period, - self.collector_target, self.header) + if self.collector_cls: + # pylint: disable=not-callable + self.collector = self.collector_cls(self.target, self.period, + self.collector_target, self.header) # type: ignore self._need_reset = False self._raw_file = None - def start(self): + def start(self) -> None: if self._need_reset: self.reset() - self.collector.start() + if self.collector: + self.collector.start() - def stop(self): - self.collector.stop() + def stop(self) -> None: + if self.collector: + self.collector.stop() self._need_reset = True - def get_data(self, outfile): + def get_data(self, outfile: str) -> MeasurementsCsv: if self.keep_raw: self._raw_file = outfile + '.raw' - self.collector.process_frames(self._raw_file) - active_sites = [chan.label for chan in self.active_channels] - self.collector.write_frames(outfile, columns=active_sites) + if self.collector: + self.collector.process_frames(self._raw_file) + active_sites: List[str] = [chan.label for chan in self.active_channels] + if self.collector: + self.collector.write_frames(outfile, columns=active_sites) return MeasurementsCsv(outfile, self.active_channels, self.sample_rate_hz) - def get_raw(self): + def get_raw(self) -> List[str]: return [self._raw_file] if self._raw_file else [] - def _init_channels(self): + def _init_channels(self) -> None: raise NotImplementedError() - def teardown(self): + def teardown(self) -> None: if not self.keep_raw: - if os.path.isfile(self._raw_file): - os.remove(self._raw_file) + if os.path.isfile(self._raw_file or ''): + os.remove(self._raw_file or '') class GfxInfoFramesInstrument(FramesInstrument): - mode = CONTINUOUS + mode: int = CONTINUOUS collector_cls = GfxinfoFrameCollector - def _init_channels(self): - columns = read_gfxinfo_columns(self.target) + def _init_channels(self) -> None: + columns: List[str] = read_gfxinfo_columns(self.target) for entry in columns: if entry == 'Flags': self.add_channel('Flags', MeasurementType('flags', 'flags')) @@ -94,10 +108,10 @@ def _init_channels(self): class SurfaceFlingerFramesInstrument(FramesInstrument): - mode = CONTINUOUS + mode: int = CONTINUOUS collector_cls = SurfaceFlingerFrameCollector - def _init_channels(self): + def _init_channels(self) -> None: for field in SurfaceFlingerFrame._fields: # remove the "_time" from filed names to avoid duplication self.add_channel(field[:-5], 'time_us') diff --git a/devlib/instrument/hwmon.py b/devlib/instrument/hwmon.py index 7c1cb7d1a..16795eaf2 100644 --- a/devlib/instrument/hwmon.py +++ b/devlib/instrument/hwmon.py @@ -1,4 +1,4 @@ -# Copyright 2015-2017 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,15 +16,20 @@ from devlib.instrument import Instrument, Measurement, INSTANTANEOUS from devlib.exception import TargetStableError +from typing import (Dict, Tuple, Callable, Union, TYPE_CHECKING, + cast, List) +from devlib.module.hwmon import HwmonModule, HwmonSensor +if TYPE_CHECKING: + from devlib.target import Target class HwmonInstrument(Instrument): - name = 'hwmon' - mode = INSTANTANEOUS + name: str = 'hwmon' + mode: int = INSTANTANEOUS # sensor kind --> (meaure, standard unit conversion) - measure_map = { + measure_map: Dict[str, Tuple[str, Callable[[Union[int, float]], float]]] = { 'temp': ('temperature', lambda x: x / 1000), 'in': ('voltage', lambda x: x / 1000), 'curr': ('current', lambda x: x / 1000), @@ -32,16 +37,18 @@ class HwmonInstrument(Instrument): 'energy': ('energy', lambda x: x / 1000000), } - def __init__(self, target): + def __init__(self, target: 'Target'): if not hasattr(target, 'hwmon'): raise TargetStableError('Target does not support HWMON') super(HwmonInstrument, self).__init__(target) self.logger.debug('Discovering available HWMON sensors...') - for ts in self.target.hwmon.sensors: + for ts in cast(HwmonModule, self.target.hwmon).sensors: try: ts.get_file('input') - measure = self.measure_map.get(ts.kind)[0] + measure_map = self.measure_map.get(ts.kind) + if measure_map: + measure: str = measure_map[0] if measure: self.logger.debug('\tAdding sensor {}'.format(ts.name)) self.add_channel(_guess_site(ts), measure, sensor=ts) @@ -52,16 +59,16 @@ def __init__(self, target): self.logger.debug(message.format(ts.name)) continue - def take_measurement(self): - result = [] + def take_measurement(self) -> List[Measurement]: + result: List[Measurement] = [] for chan in self.active_channels: - convert = self.measure_map[chan.sensor.kind][1] - value = convert(chan.sensor.get('input')) + convert = self.measure_map[chan.sensor.kind][1] # type: ignore + value = convert(chan.sensor.get('input')) # type: ignore result.append(Measurement(value, chan)) return result -def _guess_site(sensor): +def _guess_site(sensor: HwmonSensor): """ HWMON does not specify a standard for labeling its sensors, or for device/item split (the implication is that each hwmon device a separate chip @@ -74,7 +81,7 @@ def _guess_site(sensor): # If no label has been specified for the sensor (in which case, it # defaults to the sensor's name), assume that the "site" of the sensor # is identified by the HWMON device - text = sensor.device.name + text: str = sensor.device.name else: # If a label has been specified, assume multiple sensors controlled by # the same device and the label identifies the site. diff --git a/devlib/module/__init__.py b/devlib/module/__init__.py index c450ba17e..0c6c6dadc 100644 --- a/devlib/module/__init__.py +++ b/devlib/module/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,10 +18,45 @@ from devlib.exception import TargetStableError from devlib.utils.types import identifier from devlib.utils.misc import walk_modules +from typing import (Optional, Dict, Union, Type, + TYPE_CHECKING, Any) +if TYPE_CHECKING: + from devlib.target import Target +_module_registry: Dict[str, Type['Module']] = {} -_module_registry = {} -def register_module(mod): +def register_module(mod: Type['Module']) -> None: + """ + Modules are specified on :class:`~devlib.target.Target` or + :class:`~devlib.platform.Platform` creation by name. In order to find the class + associated with the name, the module needs to be registered with ``devlib``. + This is accomplished by passing the module class into :func:`register_module` + method once it is defined. + + .. note:: If you're wiring a module to be included as part of ``devlib`` code + base, you can place the file with the module class under + ``devlib/modules/`` in the source and it will be automatically + enumerated. There is no need to explicitly register it in that case. + + The code snippet below illustrates an implementation of a hard reset function + for an "Acme" device. + + .. code:: python + + import os + from devlib import HardResetModule, register_module + + + class AcmeHardReset(HardResetModule): + + name = 'acme_hard_reset' + + def __call__(self): + # Assuming Acme board comes with a "reset-acme-board" utility + os.system('reset-acme-board {}'.format(self.target.name)) + + register_module(AcmeHardReset) + """ if not issubclass(mod, Module): raise ValueError('A module must subclass devlib.Module') @@ -39,34 +74,87 @@ def register_module(mod): class Module: - - name = None - kind = None - # This is the stage at which the module will be installed. Current valid - # stages are: - # 'early' -- installed when the Target is first created. This should be - # used for modules that do not rely on the main connection - # being established (usually because the commumnitcate with the - # target through some sorto of secondary connection, e.g. via - # serial). - # 'connected' -- installed when a connection to to the target has been - # established. This is the default. - # 'setup' -- installed after initial setup of the device has been performed. - # This allows the module to utilize assets deployed during the - # setup stage for example 'Busybox'. - stage = 'connected' + """ + Modules add additional functionality to the core :class:`~devlib.target.Target` + interface. Usually, it is support for specific subsystems on the target. Modules + are instantiated as attributes of the :class:`~devlib.target.Target` instance. + + Modules implement discrete, optional pieces of functionality ("optional" in the + sense that the functionality may or may not be present on the target device, or + that it may or may not be necessary for a particular application). + + Every module (ultimately) derives from :class:`devlib.module.Module` class. A + module must define the following class attributes: + + :name: A unique name for the module. This cannot clash with any of the existing + names and must be a valid Python identifier, but is otherwise free-form. + :kind: This identifies the type of functionality a module implements, which in + turn determines the interface implemented by the module (all modules of + the same kind must expose a consistent interface). This must be a valid + Python identifier, but is otherwise free-form, though, where possible, + one should try to stick to an already-defined kind/interface, lest we end + up with a bunch of modules implementing similar functionality but + exposing slightly different interfaces. + + .. note:: It is possible to omit ``kind`` when defining a module, in + which case the module's ``name`` will be treated as its + ``kind`` as well. + + :stage: This defines when the module will be installed into a + :class:`~devlib.target.Target`. Currently, the following values are + allowed: + + :connected: The module is installed after a connection to the target has + been established. This is the default. + :early: The module will be installed when a + :class:`~devlib.target.Target` is first created. This should be + used for modules that do not rely on a live connection to the + target. + :setup: The module will be installed after initial setup of the device + has been performed. This allows the module to utilize assets + deployed during the setup stage for example 'Busybox'. + + Additionally, a module must implement a static (or class) method :func:`probe`: + """ + name: Optional[str] = None + kind: Optional[str] = None + attr_name: Optional[str] = None + stage: str = 'connected' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: + """ + This method takes a :class:`~devlib.target.Target` instance and returns + ``True`` if this module is supported by that target, or ``False`` otherwise. + + .. note:: If the module ``stage`` is ``"early"``, this method cannot assume + that a connection has been established (i.e. it can only access + attributes of the Target that do not rely on a connection). + """ raise NotImplementedError() @classmethod - def install(cls, target, **params): - attr_name = cls.attr_name - installed = target._installed_modules + def install(cls, target: 'Target', **params: Type['Module']): + """ + The default installation method will create an instance of a module (the + :class:`~devlib.target.Target` instance being the sole argument) and assign it + to the target instance attribute named after the module's ``kind`` (or + ``name`` if ``kind`` is ``None``). + + It is possible to change the installation procedure for a module by overriding + the default :func:`install` method. The method must have the following + signature: + + .. method:: Module.install(cls, target, **kwargs) + + Install the module into the target instance. + """ + attr_name: Optional[str] = cls.attr_name + installed: Dict[str, 'Module'] = target._installed_modules try: - mod = installed[attr_name] + if attr_name: + mod: 'Module' = installed[attr_name] except KeyError: mod = cls(target, **params) mod.logger.debug(f'Installing module {cls.name}') @@ -79,8 +167,8 @@ def install(cls, target, **params): ): if name is not None: installed[name] = mod - - target._modules[cls.name] = params + if cls.name: + target._modules[cls.name] = params return mod else: raise TargetStableError(f'Module "{cls.name}" is not supported by the target') @@ -89,15 +177,14 @@ def install(cls, target, **params): f'Attempting to install module "{cls.name}" but a module is already installed as attribute "{attr_name}": {mod}' ) - def __init__(self, target): + def __init__(self, target: 'Target'): self.target = target self.logger = logging.getLogger(self.name) - - def __init_subclass__(cls, *args, **kwargs): + def __init_subclass__(cls, *args, **kwargs) -> None: super().__init_subclass__(*args, **kwargs) - attr_name = cls.kind or cls.name + attr_name: Optional[str] = cls.kind or cls.name cls.attr_name = identifier(attr_name) if attr_name else None if cls.name is not None: @@ -105,21 +192,60 @@ def __init_subclass__(cls, *args, **kwargs): class HardRestModule(Module): + """ + .. attribute:: HardResetModule.kind - kind = 'hard_reset' + "hard_reset" + """ + + kind: str = 'hard_reset' def __call__(self): + """ + .. method:: HardResetModule.__call__() + + Must be implemented by derived classes. + + Implements hard reset for a target devices. The equivalent of physically + power cycling the device. This may be used by client code in situations + where the target becomes unresponsive and/or a regular reboot is not + possible. + """ raise NotImplementedError() class BootModule(Module): + """ + .. attribute:: BootModule.kind - kind = 'boot' + "boot" + """ + + kind: str = 'boot' def __call__(self): + """ + .. method:: BootModule.__call__() + + Must be implemented by derived classes. + + Implements a boot procedure. This takes the device from (hard or soft) + reset to a booted state where the device is ready to accept connections. For + a lot of commercial devices the process is entirely automatic, however some + devices (e.g. development boards), my require additional steps, such as + interactions with the bootloader, in order to boot into the OS. + """ raise NotImplementedError() - def update(self, **kwargs): + def update(self, **kwargs) -> None: + """ + .. method:: Bootmodule.update(**kwargs) + + Update the boot settings. Some boot sequences allow specifying settings + that will be utilized during boot (e.g. linux kernel boot command line). The + default implementation will set each setting in ``kwargs`` as an attribute of + the boot module (or update the existing attribute). + """ for name, value in kwargs.items(): if not hasattr(self, name): raise ValueError('Unknown parameter "{}" for {}'.format(name, self.name)) @@ -128,15 +254,40 @@ def update(self, **kwargs): class FlashModule(Module): - - kind = 'flash' - - def __call__(self, image_bundle=None, images=None, boot_config=None, connect=True): + """ + .. attribute:: FlashModule.kind + + "flash" + """ + kind: str = 'flash' + + def __call__(self, image_bundle: Optional[str] = None, + images: Optional[Dict[str, str]] = None, + boot_config: Any = None, connect: bool = True) -> None: + """ + .. method:: __call__(image_bundle=None, images=None, boot_config=None, connect=True) + + Must be implemented by derived classes. + + Flash the target platform with the specified images. + + :param image_bundle: A compressed bundle of image files with any associated + metadata. The format of the bundle is specific to a + particular implementation. + :param images: A dict mapping image names/identifiers to the path on the + host file system of the corresponding image file. If both + this and ``image_bundle`` are specified, individual images + will override those in the bundle. + :param boot_config: Some platforms require specifying boot arguments at the + time of flashing the images, rather than during each + reboot. For other platforms, this will be ignored. + :connect: Specifiy whether to try and connect to the target after flashing. + """ raise NotImplementedError() -def get_module(mod): - def from_registry(mod): +def get_module(mod: Union[str, Type[Module]]) -> Type[Module]: + def from_registry(mod: str): try: return _module_registry[mod] except KeyError: diff --git a/devlib/module/android.py b/devlib/module/android.py index 70564fd05..c9c75ec11 100644 --- a/devlib/module/android.py +++ b/devlib/module/android.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,15 +23,18 @@ from devlib.exception import HostError from devlib.utils.android import fastboot_flash_partition, fastboot_command from devlib.utils.misc import merge_dicts, safe_extract +from typing import (TYPE_CHECKING, Any, Optional, Dict, List, cast) +if TYPE_CHECKING: + from devlib.target import Target, AndroidTarget class FastbootFlashModule(FlashModule): - name = 'fastboot' - description = """ + name: str = 'fastboot' + description: str = """ Enables automated flashing of images using the fastboot utility. - To use this flasher, a set of image files to be flused are required. + To use this flasher, a set of image files to be flashed are required. In addition a mapping between partitions and image file is required. There are two ways to specify those requirements: @@ -47,59 +50,68 @@ class FastbootFlashModule(FlashModule): """ - delay = 0.5 - partitions_file_name = 'partitions.txt' + delay: float = 0.5 + partitions_file_name: str = 'partitions.txt' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return target.os == 'android' - def __call__(self, image_bundle=None, images=None, bootargs=None, connect=True): + def __call__(self, image_bundle: Optional[str] = None, + images: Optional[Dict[str, str]] = None, + bootargs: Any = None, connect: bool = True) -> None: if bootargs: raise ValueError('{} does not support boot configuration'.format(self.name)) - self.prelude_done = False - to_flash = {} + self.prelude_done: bool = False + to_flash: Dict[str, str] = {} if image_bundle: # pylint: disable=access-member-before-definition image_bundle = expand_path(image_bundle) to_flash = self._bundle_to_images(image_bundle) to_flash = merge_dicts(to_flash, images or {}, should_normalize=False) for partition, image_path in to_flash.items(): self.logger.debug('flashing {}'.format(partition)) - self._flash_image(self.target, partition, expand_path(image_path)) + self._flash_image(cast('AndroidTarget', self.target), partition, expand_path(image_path)) fastboot_command('reboot') if connect: self.target.connect(timeout=180) - def _validate_image_bundle(self, image_bundle): + def _validate_image_bundle(self, image_bundle: str) -> None: + """ + make sure the image bundle is a tarfile and it can be opened and it contains the + required partition file + """ if not tarfile.is_tarfile(image_bundle): raise HostError('File {} is not a tarfile'.format(image_bundle)) with tarfile.open(image_bundle) as tar: - files = [tf.name for tf in tar.getmembers()] + files: List[str] = [tf.name for tf in tar.getmembers()] if not any(pf in files for pf in (self.partitions_file_name, '{}/{}'.format(files[0], self.partitions_file_name))): HostError('Image bundle does not contain the required partition file (see documentation)') - def _bundle_to_images(self, image_bundle): + def _bundle_to_images(self, image_bundle: str) -> Dict[str, str]: """ Extracts the bundle to a temporary location and creates a mapping between the contents of the bundle - and images to be flushed. + and images to be flashed. """ self._validate_image_bundle(image_bundle) - extract_dir = tempfile.mkdtemp() + extract_dir: str = tempfile.mkdtemp() with tarfile.open(image_bundle) as tar: safe_extract(tar, path=extract_dir) - files = [tf.name for tf in tar.getmembers()] + files: List[str] = [tf.name for tf in tar.getmembers()] if self.partitions_file_name not in files: extract_dir = os.path.join(extract_dir, files[0]) - partition_file = os.path.join(extract_dir, self.partitions_file_name) + partition_file: str = os.path.join(extract_dir, self.partitions_file_name) return get_mapping(extract_dir, partition_file) - def _flash_image(self, target, partition, image_path): + def _flash_image(self, target: 'AndroidTarget', partition: str, image_path: str) -> None: + """ + flash the image into the partition using fastboot + """ if not self.prelude_done: self._fastboot_prelude(target) fastboot_flash_partition(partition, image_path) time.sleep(self.delay) - def _fastboot_prelude(self, target): + def _fastboot_prelude(self, target: 'AndroidTarget') -> None: target.reset(fastboot=True) time.sleep(self.delay) self.prelude_done = True @@ -107,15 +119,21 @@ def _fastboot_prelude(self, target): # utility functions -def expand_path(original_path): +def expand_path(original_path: str) -> str: + """ + expand ~ and ~user in the path + """ path = os.path.abspath(os.path.expanduser(original_path)) if not os.path.exists(path): raise HostError('{} does not exist.'.format(path)) return path -def get_mapping(base_dir, partition_file): - mapping = {} +def get_mapping(base_dir: str, partition_file: str) -> Dict[str, str]: + """ + get the image and partition mapping info from partition txt file + """ + mapping: Dict[str, str] = {} with open(partition_file) as pf: for line in pf: pair = line.split() diff --git a/devlib/module/biglittle.py b/devlib/module/biglittle.py index 7124f65a5..f3e60258c 100644 --- a/devlib/module/biglittle.py +++ b/devlib/module/biglittle.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,12 @@ # from devlib.module import Module +from devlib.module.hotplug import HotplugModule +from devlib.module.cpufreq import CpufreqModule +from typing import (TYPE_CHECKING, cast, List, + Optional, Dict) +if TYPE_CHECKING: + from devlib.target import Target class BigLittleModule(Module): @@ -21,189 +27,307 @@ class BigLittleModule(Module): name = 'bl' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return target.big_core is not None @property - def bigs(self): + def bigs(self) -> List[int]: + """ + get the list of big cores + """ return [i for i, c in enumerate(self.target.platform.core_names) if c == self.target.platform.big_core] @property - def littles(self): + def littles(self) -> List[int]: + """ + get the list of little cores + """ return [i for i, c in enumerate(self.target.platform.core_names) if c == self.target.platform.little_core] @property - def bigs_online(self): + def bigs_online(self) -> List[int]: + """ + get the list of big cores which are online + """ return list(sorted(set(self.bigs).intersection(self.target.list_online_cpus()))) @property - def littles_online(self): + def littles_online(self) -> List[int]: + """ + get the list of little cores which are online + """ return list(sorted(set(self.littles).intersection(self.target.list_online_cpus()))) # hotplug - def online_all_bigs(self): - self.target.hotplug.online(*self.bigs) - - def offline_all_bigs(self): - self.target.hotplug.offline(*self.bigs) - - def online_all_littles(self): - self.target.hotplug.online(*self.littles) - - def offline_all_littles(self): - self.target.hotplug.offline(*self.littles) + def online_all_bigs(self) -> None: + """ + make all big cores go online + """ + cast(HotplugModule, self.target.hotplug).online(*self.bigs) + + def offline_all_bigs(self) -> None: + """ + make all big cores go offline + """ + cast(HotplugModule, self.target.hotplug).offline(*self.bigs) + + def online_all_littles(self) -> None: + """ + make all little cores go online + """ + cast(HotplugModule, self.target.hotplug).online(*self.littles) + + def offline_all_littles(self) -> None: + """ + make all little cores go offline + """ + cast(HotplugModule, self.target.hotplug).offline(*self.littles) # cpufreq - def list_bigs_frequencies(self): + def list_bigs_frequencies(self) -> Optional[List[int]]: + """ + get the big cores frequencies + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.list_frequencies(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_frequencies(bigs_online[0]) + return None - def list_bigs_governors(self): + def list_bigs_governors(self) -> Optional[List[str]]: + """ + get the governors supported for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.list_governors(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_governors(bigs_online[0]) + return None - def list_bigs_governor_tunables(self): + def list_bigs_governor_tunables(self) -> Optional[List[str]]: + """ + get the tunable governors supported for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.list_governor_tunables(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_governor_tunables(bigs_online[0]) + return None - def list_littles_frequencies(self): + def list_littles_frequencies(self) -> Optional[List[int]]: + """ + get the little cores frequencies + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.list_frequencies(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_frequencies(littles_online[0]) + return None - def list_littles_governors(self): + def list_littles_governors(self) -> Optional[List[str]]: + """ + get the governors supported for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.list_governors(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_governors(littles_online[0]) + return None - def list_littles_governor_tunables(self): + def list_littles_governor_tunables(self) -> Optional[List[str]]: + """ + get the tunable governors supported for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.list_governor_tunables(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).list_governor_tunables(littles_online[0]) + return None - def get_bigs_governor(self): + def get_bigs_governor(self) -> Optional[str]: + """ + get the current governor set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_governor(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_governor(bigs_online[0]) + return None - def get_bigs_governor_tunables(self): + def get_bigs_governor_tunables(self) -> Optional[Dict[str, str]]: + """ + get the current governor tunables set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_governor_tunables(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_governor_tunables(bigs_online[0]) + return None - def get_bigs_frequency(self): + def get_bigs_frequency(self) -> Optional[int]: + """ + get the current frequency that is set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_frequency(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_frequency(bigs_online[0]) + return None - def get_bigs_min_frequency(self): + def get_bigs_min_frequency(self) -> Optional[int]: + """ + get the current minimum frequency that is set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_min_frequency(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_min_frequency(bigs_online[0]) + return None - def get_bigs_max_frequency(self): + def get_bigs_max_frequency(self) -> Optional[int]: + """ + get the current maximum frequency that is set for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - return self.target.cpufreq.get_max_frequency(bigs_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_max_frequency(bigs_online[0]) + return None - def get_littles_governor(self): + def get_littles_governor(self) -> Optional[str]: + """ + get the current governor set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_governor(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_governor(littles_online[0]) + return None - def get_littles_governor_tunables(self): + def get_littles_governor_tunables(self) -> Optional[Dict[str, str]]: + """ + get the current governor tunables set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_governor_tunables(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_governor_tunables(littles_online[0]) + return None - def get_littles_frequency(self): + def get_littles_frequency(self) -> Optional[int]: + """ + get the current frequency that is set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_frequency(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_frequency(littles_online[0]) + return None - def get_littles_min_frequency(self): + def get_littles_min_frequency(self) -> Optional[int]: + """ + get the current minimum frequency that is set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_min_frequency(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_min_frequency(littles_online[0]) + return None - def get_littles_max_frequency(self): + def get_littles_max_frequency(self) -> Optional[int]: + """ + get the current maximum frequency that is set for the first little core that is online + """ littles_online = self.littles_online if littles_online: - return self.target.cpufreq.get_max_frequency(littles_online[0]) + return cast(CpufreqModule, self.target.cpufreq).get_max_frequency(littles_online[0]) + return None - def set_bigs_governor(self, governor, **kwargs): + def set_bigs_governor(self, governor: str, **kwargs) -> None: + """ + set governor for the first online big core + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_governor(bigs_online[0], governor, **kwargs) + cast(CpufreqModule, self.target.cpufreq).set_governor(bigs_online[0], governor, **kwargs) else: raise ValueError("All bigs appear to be offline") - def set_bigs_governor_tunables(self, governor, **kwargs): + def set_bigs_governor_tunables(self, governor: str, **kwargs) -> None: + """ + set governor tunables for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_governor_tunables(bigs_online[0], governor, **kwargs) + cast(CpufreqModule, self.target.cpufreq).set_governor_tunables(bigs_online[0], governor, **kwargs) else: raise ValueError("All bigs appear to be offline") - def set_bigs_frequency(self, frequency, exact=True): + def set_bigs_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the frequency for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_frequency(bigs_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_frequency(bigs_online[0], frequency, exact) else: raise ValueError("All bigs appear to be offline") - def set_bigs_min_frequency(self, frequency, exact=True): + def set_bigs_min_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the minimum value for the cpu frequency for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_min_frequency(bigs_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_min_frequency(bigs_online[0], frequency, exact) else: raise ValueError("All bigs appear to be offline") - def set_bigs_max_frequency(self, frequency, exact=True): + def set_bigs_max_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the minimum value for the cpu frequency for the first big core that is online + """ bigs_online = self.bigs_online if bigs_online: - self.target.cpufreq.set_max_frequency(bigs_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_max_frequency(bigs_online[0], frequency, exact) else: raise ValueError("All bigs appear to be offline") - def set_littles_governor(self, governor, **kwargs): + def set_littles_governor(self, governor: str, **kwargs) -> None: + """ + set governor for the first online little core + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_governor(littles_online[0], governor, **kwargs) + cast(CpufreqModule, self.target.cpufreq).set_governor(littles_online[0], governor, **kwargs) else: raise ValueError("All littles appear to be offline") - def set_littles_governor_tunables(self, governor, **kwargs): + def set_littles_governor_tunables(self, governor: str, **kwargs) -> None: + """ + set governor tunables for the first little core that is online + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_governor_tunables(littles_online[0], governor, **kwargs) + cast(CpufreqModule, self.target.cpufreq).set_governor_tunables(littles_online[0], governor, **kwargs) else: raise ValueError("All littles appear to be offline") - def set_littles_frequency(self, frequency, exact=True): + def set_littles_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the frequency for the first little core that is online + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_frequency(littles_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_frequency(littles_online[0], frequency, exact) else: raise ValueError("All littles appear to be offline") - def set_littles_min_frequency(self, frequency, exact=True): + def set_littles_min_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the minimum value for the cpu frequency for the first little core that is online + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_min_frequency(littles_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_min_frequency(littles_online[0], frequency, exact) else: raise ValueError("All littles appear to be offline") - def set_littles_max_frequency(self, frequency, exact=True): + def set_littles_max_frequency(self, frequency: int, exact: bool = True) -> None: + """ + set the maximum value for the cpu frequency for the first little core that is online + """ littles_online = self.littles_online if littles_online: - self.target.cpufreq.set_max_frequency(littles_online[0], frequency, exact) + cast(CpufreqModule, self.target.cpufreq).set_max_frequency(littles_online[0], frequency, exact) else: raise ValueError("All littles appear to be offline") diff --git a/devlib/module/cgroups.py b/devlib/module/cgroups.py index a7edf879c..3bcf6f6dc 100644 --- a/devlib/module/cgroups.py +++ b/devlib/module/cgroups.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,11 +25,16 @@ from devlib.utils.misc import list_to_ranges, isiterable from devlib.utils.types import boolean from devlib.utils.asyn import asyncf, run +from typing import (TYPE_CHECKING, Optional, List, Dict, + Union, Tuple, cast, Set) + +if TYPE_CHECKING: + from devlib.target import Target, FstabEntry class Controller(object): - def __init__(self, kind, hid, clist): + def __init__(self, kind: str, hid: int, clist: List[str]): """ Initialize a controller given the hierarchy it belongs to. @@ -42,51 +47,54 @@ def __init__(self, kind, hid, clist): :param clist: the list of controller mounted in the same hierarchy :type clist: list(str) """ - self.mount_name = 'devlib_cgh{}'.format(hid) - self.kind = kind - self.hid = hid - self.clist = clist - self.target = None - self._noprefix = False - - self.logger = logging.getLogger('CGroup.'+self.kind) + self.mount_name: str = 'devlib_cgh{}'.format(hid) + self.kind: str = kind + self.hid: int = hid + self.clist: List[str] = clist + self.target: Optional['Target'] = None + self._noprefix: bool = False + + self.logger: logging.Logger = logging.getLogger('CGroup.' + self.kind) self.logger.debug('Initialized [%s, %d, %s]', self.kind, self.hid, self.clist) - self.mount_point = None - self._cgroups = {} + self.mount_point: Optional[str] = None + self._cgroups: Dict[str, 'CGroup'] = {} @asyncf - async def mount(self, target, mount_root): - - mounted = target.list_file_systems() + async def mount(self, target: 'Target', mount_root: str) -> None: + """ + mount the controller in mount point + """ + mounted: List[FstabEntry] = target.list_file_systems() if self.mount_name in [e.device for e in mounted]: # Identify mount point if controller is already in use self.mount_point = [ - fs.mount_point - for fs in mounted - if fs.device == self.mount_name - ][0] + fs.mount_point + for fs in mounted + if fs.device == self.mount_name + ][0] else: # Mount the controller if not already in use - self.mount_point = target.path.join(mount_root, self.mount_name) - await target.execute.asyn('mkdir -p {} 2>/dev/null'\ - .format(self.mount_point), as_root=True) - await target.execute.asyn('mount -t cgroup -o {} {} {}'\ - .format(','.join(self.clist), - self.mount_name, - self.mount_point), - as_root=True) + if target.path: + self.mount_point = target.path.join(mount_root, self.mount_name) + await target.execute.asyn('mkdir -p {} 2>/dev/null' + .format(self.mount_point), as_root=True) + await target.execute.asyn('mount -t cgroup -o {} {} {}' + .format(','.join(self.clist), + self.mount_name, + self.mount_point), + as_root=True) # Check if this controller uses "noprefix" option - output = await target.execute.asyn('mount | grep "{} "'.format(self.mount_name)) + output: str = await target.execute.asyn('mount | grep "{} "'.format(self.mount_name)) if 'noprefix' in output: self._noprefix = True # self.logger.debug('Controller %s using "noprefix" option', # self.kind) self.logger.debug('Controller %s mounted under: %s (noprefix=%s)', - self.kind, self.mount_point, self._noprefix) + self.kind, self.mount_point, self._noprefix) # Mark this contoller as available self.target = target @@ -94,39 +102,51 @@ async def mount(self, target, mount_root): # Create root control group self.cgroup('/') - def cgroup(self, name): + def cgroup(self, name: str) -> 'CGroup': + """ + get the control group with the name + """ if not self.target: - raise RuntimeError('CGroup creation failed: {} controller not mounted'\ - .format(self.kind)) + raise RuntimeError('CGroup creation failed: {} controller not mounted' + .format(self.kind)) if name not in self._cgroups: self._cgroups[name] = CGroup(self, name) return self._cgroups[name] - def exists(self, name): + def exists(self, name: str) -> bool: + """ + returns True if the control group with this name exists + """ if not self.target: - raise RuntimeError('CGroup creation failed: {} controller not mounted'\ - .format(self.kind)) + raise RuntimeError('CGroup creation failed: {} controller not mounted' + .format(self.kind)) if name not in self._cgroups: self._cgroups[name] = CGroup(self, name, create=False) return self._cgroups[name].exists() - def list_all(self): + def list_all(self) -> List[str]: + """ + List all control groups for this controller + """ self.logger.debug('Listing groups for %s controller', self.kind) - output = self.target.execute('{} find {} -type d'\ - .format(self.target.busybox, self.mount_point), - as_root=True) - cgroups = [] - for cg in output.splitlines(): - cg = cg.replace(self.mount_point + '/', '/') - cg = cg.replace(self.mount_point, '/') - cg = cg.strip() - if cg == '': - continue - self.logger.debug('Populate %s cgroup: %s', self.kind, cg) - cgroups.append(cg) + if self.target: + output: str = self.target.execute('{} find {} -type d' + .format(self.target.busybox, self.mount_point), + as_root=True) + cgroups: List[str] = [] + for cg in output.splitlines(): + if self.mount_point: + cg = cg.replace(self.mount_point + '/', '/') + cg = cg.replace(self.mount_point, '/') + cg = cg.strip() + if cg == '': + continue + self.logger.debug('Populate %s cgroup: %s', self.kind, cg) + cgroups.append(cg) return cgroups - def move_tasks(self, source, dest, exclude=None): + def move_tasks(self, source: str, dest: str, + exclude: Optional[Union[str, List[str]]] = None) -> None: if isinstance(exclude, str): warnings.warn("Controller.move_tasks() takes needs a _list_ of exclude patterns, not a string", DeprecationWarning) exclude = [exclude] @@ -143,17 +163,17 @@ def move_tasks(self, source, dest, exclude=None): srcg = self.cgroup(source) dstg = self.cgroup(dest) + if self.target and srcg.directory and dstg.directory: + self.target._execute_util( # pylint: disable=protected-access + 'cgroups_tasks_move {src} {dst} {exclude}'.format( + src=quote(srcg.directory), + dst=quote(dstg.directory), + exclude=exclude, + ), + as_root=True, + ) - self.target._execute_util( # pylint: disable=protected-access - 'cgroups_tasks_move {src} {dst} {exclude}'.format( - src=quote(srcg.directory), - dst=quote(dstg.directory), - exclude=exclude, - ), - as_root=True, - ) - - def move_all_tasks_to(self, dest, exclude=None): + def move_all_tasks_to(self, dest: str, exclude: Optional[Union[str, List[str]]] = None) -> None: """ Move all the tasks to the specified CGroup @@ -187,10 +207,10 @@ def move_all_tasks_to(self, dest, exclude=None): self.move_tasks(cgroup, dest, exclude) # pylint: disable=too-many-locals - def tasks(self, cgroup, - filter_tid='', - filter_tname='', - filter_tcmdline=''): + def tasks(self, cgroup: str, + filter_tid: str = '', + filter_tname: str = '', + filter_tcmdline: str = '') -> Dict[int, Tuple[str, str]]: """ Report the tasks that are included in a cgroup. The tasks can be filtered by their tid, tname or tcmdline if filter_tid, filter_tname or @@ -222,14 +242,16 @@ def tasks(self, cgroup, cg = self._cgroups[cgroup] except KeyError as e: raise ValueError('Unknown group: {}'.format(e)) - output = self.target._execute_util( # pylint: disable=protected-access - 'cgroups_tasks_in {}'.format(cg.directory), - as_root=True) - entries = output.splitlines() - tasks = {} + if self.target is None: + raise ValueError("Target is None") + output: str = self.target._execute_util( # pylint: disable=protected-access + 'cgroups_tasks_in {}'.format(cg.directory), + as_root=True) + entries: List[str] = output.splitlines() + tasks: Dict[int, Tuple[str, str]] = {} for task in entries: - fields = task.split(',', 2) - nr_fields = len(fields) + fields: List[str] = task.split(',', 2) + nr_fields: int = len(fields) if nr_fields < 2: continue elif nr_fields == 2: @@ -248,65 +270,86 @@ def tasks(self, cgroup, tasks[int(tid_str)] = (tname, tcmdline) return tasks - def tasks_count(self, cgroup): + def tasks_count(self, cgroup: str) -> int: + """ + count of the number of tasks in the cgroup + """ try: cg = self._cgroups[cgroup] except KeyError as e: raise ValueError('Unknown group: {}'.format(e)) + if self.target is None: + raise ValueError("Target is None") output = self.target.execute( - '{} wc -l {}/tasks'.format( - self.target.busybox, cg.directory), - as_root=True) + '{} wc -l {}/tasks'.format( + self.target.busybox, cg.directory), + as_root=True) return int(output.split()[0]) - def tasks_per_group(self): - tasks = {} + def tasks_per_group(self) -> Dict[str, int]: + """ + tasks in all cgroups + """ + tasks: Dict[str, int] = {} for cg in self.list_all(): tasks[cg] = self.tasks_count(cg) return tasks + class CGroup(object): - def __init__(self, controller, name, create=True): - self.logger = logging.getLogger('cgroups.' + controller.kind) - self.target = controller.target - self.controller = controller - self.name = name + def __init__(self, controller: 'Controller', name: str, create: bool = True): + self.logger: logging.Logger = logging.getLogger('cgroups.' + controller.kind) + self.target: Optional['Target'] = controller.target + self.controller: Controller = controller + self.name: str = name # Control cgroup path - self.directory = controller.mount_point - + self.directory: Optional[str] = controller.mount_point + if self.target is None: + raise ValueError("Target is None") + if self.target.path is None: + raise ValueError("Target.path is None") if name != '/': self.directory = self.target.path.join(controller.mount_point, name.strip('/')) # Setup path for tasks file - self.tasks_file = self.target.path.join(self.directory, 'tasks') - self.procs_file = self.target.path.join(self.directory, 'cgroup.procs') + self.tasks_file: str = self.target.path.join(self.directory, 'tasks') + self.procs_file: str = self.target.path.join(self.directory, 'cgroup.procs') if not create: return self.logger.debug('Creating cgroup %s', self.directory) - self.target.execute('[ -d {0} ] || mkdir -p {0}'\ - .format(self.directory), as_root=True) + self.target.execute('[ -d {0} ] || mkdir -p {0}' + .format(self.directory), as_root=True) - def exists(self): + def exists(self) -> bool: + """ + return true if the directory of the control group exists + """ try: - self.target.execute('[ -d {0} ]'\ - .format(self.directory), as_root=True) + if self.target is None: + raise TargetStableError + self.target.execute('[ -d {0} ]' + .format(self.directory), as_root=True) return True except TargetStableError: return False - def get(self): - conf = {} - + def get(self) -> Dict[str, str]: + """ + get attributes and associated value from control groups + """ + conf: Dict[str, str] = {} + if self.target is None: + raise ValueError("Target is None") self.logger.debug('Reading %s attributes from:', self.controller.kind) self.logger.debug(' %s', self.directory) - output = self.target._execute_util( # pylint: disable=protected-access - 'cgroups_get_attributes {} {}'.format( - self.directory, self.controller.kind), - as_root=True) + output: str = self.target._execute_util( # pylint: disable=protected-access + 'cgroups_get_attributes {} {}'.format( + self.directory, self.controller.kind), + as_root=True) for res in output.splitlines(): attr = res.split(':')[0] value = res.split(':')[1] @@ -314,78 +357,98 @@ def get(self): return conf - def set(self, **attrs): + def set(self, **attrs: Union[str, List[int], int]) -> None: + """ + set attributes to the control group + """ for idx in attrs: if isiterable(attrs[idx]): - attrs[idx] = list_to_ranges(attrs[idx]) + attrs[idx] = list_to_ranges(cast(List, attrs[idx])) # Build attribute path if self.controller._noprefix: # pylint: disable=protected-access attr_name = '{}'.format(idx) else: attr_name = '{}.{}'.format(self.controller.kind, idx) - path = self.target.path.join(self.directory, attr_name) + path: str = self.target.path.join(self.directory, attr_name) if self.target and self.target.path else '' self.logger.debug('Set attribute [%s] to: %s"', - path, attrs[idx]) + path, attrs[idx]) # Set the attribute value try: - self.target.write_value(path, attrs[idx]) + if self.target: + self.target.write_value(path, attrs[idx]) except TargetStableError: # Check if the error is due to a non-existing attribute - attrs = self.get() - if idx not in attrs: - raise ValueError('Controller [{}] does not provide attribute [{}]'\ + attrs_int = self.get() + if idx not in attrs_int: + raise ValueError('Controller [{}] does not provide attribute [{}]' .format(self.controller.kind, attr_name)) raise - def get_tasks(self): - task_ids = self.target.read_value(self.tasks_file).split() + def get_tasks(self) -> List[int]: + """ + get the ids of tasks in the control group + """ + task_ids: List[str] = self.target.read_value(self.tasks_file).split() if self.target else [] self.logger.debug('Tasks: %s', task_ids) return list(map(int, task_ids)) - def add_task(self, tid): - self.target.write_value(self.tasks_file, tid, verify=False) + def add_task(self, tid: int) -> None: + """ + add task to the control group + """ + if self.target: + self.target.write_value(self.tasks_file, tid, verify=False) - def add_tasks(self, tasks): + def add_tasks(self, tasks: List[int]) -> None: + """ + add multiple tasks to the control group + """ for tid in tasks: self.add_task(tid) - def add_proc(self, pid): - self.target.write_value(self.procs_file, pid, verify=False) + def add_proc(self, pid: int) -> None: + """ + add process to the control group + """ + if self.target: + self.target.write_value(self.procs_file, pid, verify=False) + CgroupSubsystemEntry = namedtuple('CgroupSubsystemEntry', 'name hierarchy num_cgroups enabled') + class CgroupsModule(Module): - name = 'cgroups' - stage = 'setup' + name: str = 'cgroups' + stage: str = 'setup' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: if not target.is_rooted: return False if target.file_exists('/proc/cgroups'): return True return target.config.has('cgroups') - def __init__(self, target): + def __init__(self, target: 'Target'): super(CgroupsModule, self).__init__(target) - self.logger = logging.getLogger('CGroups') + self.logger: logging.Logger = logging.getLogger('CGroups') # Set Devlib's CGroups mount point - self.cgroup_root = target.path.join( - target.working_directory, 'cgroups') + self.cgroup_root: str = target.path.join( + target.working_directory, 'cgroups') if target.path else '' # Get the list of the available controllers - subsys = self.list_subsystems() + subsys: List['CgroupSubsystemEntry'] = self.list_subsystems() if not subsys: self.logger.warning('No CGroups controller available') return # Map hierarchy IDs into a list of controllers - hierarchy = {} + hierarchy: Dict[int, List[str]] = {} for ss in subsys: try: hierarchy[ss.hierarchy].append(ss.name) @@ -395,10 +458,13 @@ def __init__(self, target): # Initialize controllers self.logger.info('Available controllers:') - self.controllers = {} + self.controllers: Dict[str, Controller] = {} - async def register_controller(ss): - hid = ss.hierarchy + async def register_controller(ss: 'CgroupSubsystemEntry') -> None: + """ + register controller to control group module + """ + hid: int = ss.hierarchy controller = Controller(ss.name, hid, hierarchy[hid]) try: await controller.mount.asyn(self.target, self.cgroup_root) @@ -416,11 +482,13 @@ async def register_controller(ss): ) ) - - def list_subsystems(self): - subsystems = [] - for line in self.target.execute('{} cat /proc/cgroups'\ - .format(self.target.busybox), as_root=self.target.is_rooted).splitlines()[1:]: + def list_subsystems(self) -> List['CgroupSubsystemEntry']: + """ + get the list of subsystems as a list of class:CgroupSubsystemEntry objects + """ + subsystems: List['CgroupSubsystemEntry'] = [] + for line in self.target.execute('{} cat /proc/cgroups' + .format(self.target.busybox), as_root=self.target.is_rooted).splitlines()[1:]: line = line.strip() if not line or line.startswith('#') or line.endswith('0'): continue @@ -431,14 +499,16 @@ def list_subsystems(self): boolean(enabled))) return subsystems - - def controller(self, kind): + def controller(self, kind: str) -> Optional[Controller]: + """ + get the controller of the specified kind + """ if kind not in self.controllers: self.logger.warning('Controller %s not available', kind) return None return self.controllers[kind] - def run_into_cmd(self, cgroup, cmdline): + def run_into_cmd(self, cgroup: str, cmdline: str) -> str: """ Get the command to run a command into a given cgroup @@ -450,10 +520,10 @@ def run_into_cmd(self, cgroup, cmdline): message = 'cgroup name "{}" must start with "/"'.format(cgroup) raise ValueError(message) return 'CGMOUNT={} {} cgroups_run_into {} {}'\ - .format(self.cgroup_root, self.target.shutils, - cgroup, cmdline) + .format(self.cgroup_root, self.target.shutils, + cgroup, cmdline) - def run_into(self, cgroup, cmdline, as_root=None): + def run_into(self, cgroup: str, cmdline: str, as_root: Optional[bool] = None) -> str: """ Run the specified command into the specified CGroup @@ -465,13 +535,13 @@ def run_into(self, cgroup, cmdline, as_root=None): """ if as_root is None: as_root = self.target.is_rooted - cmd = self.run_into_cmd(cgroup, cmdline) - raw_output = self.target.execute(cmd, as_root=as_root) + cmd: str = self.run_into_cmd(cgroup, cmdline) + raw_output: str = self.target.execute(cmd, as_root=as_root) # First line of output comes from shutils; strip it out. return raw_output.split('\n', 1)[1] - def cgroups_tasks_move(self, srcg, dstg, exclude=''): + def cgroups_tasks_move(self, srcg: str, dstg: str, exclude: str = '') -> str: """ Move all the tasks from the srcg CGroup to the dstg one. A regexps of tasks names can be used to defined tasks which should not @@ -481,7 +551,7 @@ def cgroups_tasks_move(self, srcg, dstg, exclude=''): 'cgroups_tasks_move {} {} {}'.format(srcg, dstg, exclude), as_root=True) - def isolate(self, cpus, exclude=None): + def isolate(self, cpus: List[int], exclude: Optional[List[str]] = None) -> Tuple[CGroup, CGroup]: """ Remove all userspace tasks from specified CPUs. @@ -500,12 +570,14 @@ def isolate(self, cpus, exclude=None): """ if exclude is None: exclude = [] - all_cpus = set(range(self.target.number_of_cpus)) - sbox_cpus = list(all_cpus - set(cpus)) - isol_cpus = list(all_cpus - set(sbox_cpus)) + all_cpus: Set[int] = set(range(self.target.number_of_cpus)) + sbox_cpus: List[int] = list(all_cpus - set(cpus)) + isol_cpus: List[int] = list(all_cpus - set(sbox_cpus)) # Create Sandbox and Isolated cpuset CGroups - cpuset = self.controller('cpuset') + cpuset: Optional[Controller] = self.controller('cpuset') + if cpuset is None: + raise ValueError("cpuset is None") sbox_cg = cpuset.cgroup('/DEVLIB_SBOX') isol_cg = cpuset.cgroup('/DEVLIB_ISOL') @@ -518,7 +590,8 @@ def isolate(self, cpus, exclude=None): return sbox_cg, isol_cg - def freeze(self, exclude=None, thaw=False): + def freeze(self, exclude: Optional[List[str]] = None, + thaw: bool = False) -> Optional[Dict[int, Tuple[str, str]]]: """ Freeze all user-space tasks but the specified ones @@ -549,10 +622,11 @@ def freeze(self, exclude=None, thaw=False): if thaw: # Restart frozen tasks # pylint: disable=protected-access - freezer.target._execute_util(cmd.format('THAWED'), as_root=True) - # Remove all tasks from freezer - freezer.move_all_tasks_to('/') - return + if freezer.target: + freezer.target._execute_util(cmd.format('THAWED'), as_root=True) + # Remove all tasks from freezer + freezer.move_all_tasks_to('/') + return None # Move all tasks into the freezer group freezer.move_all_tasks_to('/DEVLIB_FREEZER', exclude) @@ -562,6 +636,7 @@ def freeze(self, exclude=None, thaw=False): # Freeze all tasks # pylint: disable=protected-access - freezer.target._execute_util(cmd.format('FROZEN'), as_root=True) + if freezer.target: + freezer.target._execute_util(cmd.format('FROZEN'), as_root=True) return tasks diff --git a/devlib/module/cgroups2.py b/devlib/module/cgroups2.py index 83cbf3948..bf4f937e1 100644 --- a/devlib/module/cgroups2.py +++ b/devlib/module/cgroups2.py @@ -1,4 +1,4 @@ -# Copyright 2022 ARM Limited +# Copyright 2022-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -101,7 +101,10 @@ from abc import ABC, abstractmethod from contextlib import ExitStack, contextmanager from shlex import quote -from typing import Dict, Set, List, Union, Any +from typing import (Dict, Set, List, Union, Any, + Tuple, cast, Callable, Optional, + Pattern, Generator, Match) +from contextlib import _GeneratorContextManager from uuid import uuid4 from devlib import LinuxTarget @@ -112,6 +115,9 @@ from devlib.target import FstabEntry from devlib.utils.misc import memoized +# dictionary type frequently being used in this module +ControllerDict = Dict[str, Dict[str, Union[str, int]]] + def _is_systemd_online(target: LinuxTarget): """ @@ -132,7 +138,7 @@ def _is_systemd_online(target: LinuxTarget): return True -def _read_lines(target: LinuxTarget, path: str): +def _read_lines(target: LinuxTarget, path: str) -> List[str]: """ Reads the lines of a file stored on the target device. @@ -169,7 +175,10 @@ def _add_controller_versions(controllers: Dict[str, Dict[str, int]]): # https://man7.org/linux/man-pages/man7/cgroups.7.html # (Under NOTES) [Dated 12/08/2022] - def infer_version(config): + def infer_version(config: Dict[str, int]) -> Optional[int]: + """ + determine the controller version + """ if config["hierarchy"] != 0: return 1 elif config["hierarchy"] == 0 and config["num_cgroups"] > 1: @@ -188,7 +197,7 @@ def infer_version(config): def _add_controller_mounts( controllers: Dict[str, Dict[str, int]], target_fs_list: List[FstabEntry] -): +) -> Dict[str, Dict[str, Union[str, int]]]: """ Find the CGroup controller's mount point and adds it as ``mount_point`` key. @@ -208,11 +217,14 @@ def _add_controller_mounts( """ # Filter the mounted filesystems on the target device, obtaining the respective V1/V2 FstabEntries. - v1_mounts = [fs for fs in target_fs_list if fs.fs_type == "cgroup"] - v2_mounts = [fs for fs in target_fs_list if fs.fs_type == "cgroup2"] + v1_mounts: List[FstabEntry] = [fs for fs in target_fs_list if fs.fs_type == "cgroup"] + v2_mounts: List[FstabEntry] = [fs for fs in target_fs_list if fs.fs_type == "cgroup2"] - def _infer_mount(controller: str, configuration: Dict): - controller_version = configuration.get("version") + def _infer_mount(controller: str, configuration: Dict[str, int]) -> Optional[str]: + """ + determine the controller mount point + """ + controller_version: Optional[int] = configuration.get("version") if controller_version == 1: for mount in v1_mounts: if controller in mount.options.strip().split(","): @@ -225,7 +237,7 @@ def _infer_mount(controller: str, configuration: Dict): return None - return { + return cast(Dict[str, Dict[str, Union[str, int]]], { controller: {**config, "mount_point": path if path is not None else config} for (controller, config, path) in ( ( @@ -235,10 +247,10 @@ def _infer_mount(controller: str, configuration: Dict): ) for (controller, config) in controllers.items() ) - } + }) -def _get_cgroup_controllers(target: LinuxTarget): +def _get_cgroup_controllers(target: LinuxTarget) -> ControllerDict: """ Returns the CGroup controllers that are currently enabled on the target device, alongside their appropriate configurations. @@ -257,19 +269,23 @@ def _get_cgroup_controllers(target: LinuxTarget): # #subsys_name hierarchy num_cgroups enabled # cpuset 3 1 1 - PROC_MOUNT_REGEX = re.compile( + PROC_MOUNT_REGEX: Pattern[str] = re.compile( r"^(?!#)(?P.+)\t(?P.+)\t(?P.+)\t(?P.+)" ) - proc_cgroup_file = _read_lines(target=target, path="/proc/cgroups") + proc_cgroup_file: List[str] = _read_lines(target=target, path="/proc/cgroups") - def _parse_controllers(controller): + def _parse_controllers(controller: str) -> Union[Tuple[str, Dict[str, int]], + Tuple[None, None]]: + """ + parse the controllers information from cgroups file + """ match = PROC_MOUNT_REGEX.match(controller.strip()) if match: - name = match.group("name") - enabled = int(match.group("enabled")) - hierarchy = int(match.group("hierarchy")) - num_cgroups = int(match.group("num_cgroups")) + name: str = match.group("name") + enabled: int = int(match.group("enabled")) + hierarchy: int = int(match.group("hierarchy")) + num_cgroups: int = int(match.group("num_cgroups")) # We should ignore disabled controllers. if enabled != 0: config = { @@ -279,9 +295,9 @@ def _parse_controllers(controller): return (name, config) return (None, None) - controllers = dict(map(_parse_controllers, proc_cgroup_file)) - controllers.pop(None) - controllers = _add_controller_versions(controllers=controllers) + controllers_temp = dict(map(_parse_controllers, proc_cgroup_file)) + controllers_temp.pop(None) + controllers = _add_controller_versions(controllers=cast(Dict[str, Dict[str, int]], controllers_temp)) controllers = _add_controller_mounts( controllers=controllers, target_fs_list=target.list_file_systems(), @@ -291,7 +307,7 @@ def _parse_controllers(controller): @contextmanager -def _request_delegation(target: LinuxTarget): +def _request_delegation(target: LinuxTarget) -> Generator[int]: """ Requests systemd to delegate a subtree CGroup hierarchy to our transient service unit. @@ -304,7 +320,7 @@ def _request_delegation(target: LinuxTarget): try: target.execute( 'systemd-run --no-block --property Delegate="yes" --unit {name} --quiet {busybox} sh -c "while true; do sleep 1d; done"'.format( - name=quote(service_name), busybox=quote(target.busybox) + name=quote(service_name), busybox=quote(target.busybox or '') ), as_root=True, ) @@ -326,7 +342,7 @@ def _request_delegation(target: LinuxTarget): @contextmanager -def _mount_v2_controllers(target: LinuxTarget): +def _mount_v2_controllers(target: LinuxTarget) -> Generator[str]: """ Mounts the V2 unified CGroup controller hierarchy. @@ -335,23 +351,22 @@ def _mount_v2_controllers(target: LinuxTarget): :yield: The path to the root of the mounted V2 controller hierarchy. :rtype: str - - :raises TargetStableError: Occurs in the case where the root directory of the requested CGroup V2 Controller hierarchy + + :raises TargetStableError: Occurs in the case where the root directory of the requested CGroup V2 Controller hierarchy is unable to be created up on the target system. """ - path = target.tempfile() - + path: str = target.tempfile() + try: target.makedirs(path, as_root=True) except TargetStableCalledProcessError: raise TargetStableError("Un-able to create the root directory of the requested CGroup V2 hierarchy") - - + try: target.execute( "{busybox} mount -t cgroup2 none {path}".format( - busybox=quote(target.busybox), path=quote(path) + busybox=quote(target.busybox or ''), path=quote(path) ), as_root=True, ) @@ -359,7 +374,7 @@ def _mount_v2_controllers(target: LinuxTarget): finally: target.execute( "{busybox} umount {path} && {busybox} rmdir -- {path}".format( - busybox=quote(target.busybox), + busybox=quote(target.busybox or ''), path=quote(path), ), as_root=True, @@ -367,7 +382,7 @@ def _mount_v2_controllers(target: LinuxTarget): @contextmanager -def _mount_v1_controllers(target: LinuxTarget, controllers: Set[str]): +def _mount_v1_controllers(target: LinuxTarget, controllers: Set[str]) -> Generator[Dict[str, str]]: """ Mounts the V1 split CGroup controller hierarchies. @@ -379,27 +394,27 @@ def _mount_v1_controllers(target: LinuxTarget, controllers: Set[str]): :yield: A dictionary mapping CGroup controller names to the paths that they're currently mounted at. :rtype: Dict[str,str] - - :raises TargetStableError: Occurs in the case where the root directory of a requested CGroup V1 Controller hierarchy + + :raises TargetStableError: Occurs in the case where the root directory of a requested CGroup V1 Controller hierarchy is unable to be created up on the target system. """ # Internal helper function which mounts a single V1 controller hierarchy and returns # its mount path. @contextmanager - def _mount_controller(controller): + def _mount_controller(controller: str) -> Generator[str]: + + path: str = target.tempfile() - path = target.tempfile() - try: target.makedirs(path, as_root=True) - except TargetStableCalledProcessError as err: - raise TargetStableError("Un-able to create the root directory of the {controller} CGroup V1 hierarchy".format(controller = controller)) + except TargetStableCalledProcessError: + raise TargetStableError("Un-able to create the root directory of the {controller} CGroup V1 hierarchy".format(controller=controller)) try: target.execute( "{busybox} mount -t cgroup -o {controller} none {path}".format( - busybox=quote(target.busybox), + busybox=quote(target.busybox or ''), controller=quote(controller), path=quote(path), ), @@ -410,7 +425,7 @@ def _mount_controller(controller): finally: target.execute( "{busybox} umount {path} && {busybox} rmdir -- {path}".format( - busybox=quote(target.busybox), + busybox=quote(target.busybox or ''), path=quote(path), ), as_root=True, @@ -424,8 +439,8 @@ def _mount_controller(controller): def _validate_requested_hierarchy( - requested_controllers: Set[str], available_controllers: Dict -): + requested_controllers: Set[str], available_controllers: Dict[str, Any] +) -> None: """ Validates that the requested hierarchy is valid using the controllers available on the target system. @@ -443,7 +458,7 @@ def _validate_requested_hierarchy( # Will determine if there are any controllers present within the requested controllers # and not within the available controllers - diff = set(requested_controllers) - available_controllers.keys() + diff: Set[str] = set(requested_controllers) - available_controllers.keys() if diff: raise TargetStableError( @@ -474,7 +489,7 @@ def __init__( self, name: str, parent_path: str, - active_controllers: Dict[str, Dict[str, str]], + active_controllers: ControllerDict, target: LinuxTarget, ): self.name = name @@ -483,12 +498,12 @@ def __init__( self._parent_path = parent_path @property - def group_path(self): + def group_path(self) -> str: return self.target.path.join(self._parent_path, self.name) def _set_controller_attribute( self, controller: str, attribute: str, value: Union[int, str], verify=False - ): + ) -> None: """ Writes the specified ``value`` into the interface file specified by the ``controller`` and ``attribute`` parameters. In the case where no ``controller`` name is specified, the ``attribute`` argument is assumed to be the name of the @@ -510,13 +525,13 @@ def _set_controller_attribute( str_value = str(value) # Some CGroup interface files don't have a controller name prefix, we accommodate that here. - interface_file = controller + "." + attribute if controller else attribute + interface_file: str = controller + "." + attribute if controller else attribute - full_path = self.target.path.join(self.group_path, interface_file) + full_path: str = self.target.path.join(self.group_path, interface_file) self.target.write_value(full_path, str_value, verify=verify) - def _create_directory(self, path: str): + def _create_directory(self, path: str) -> None: """ Creates a new directory at the given path, creating the parent directories if required. If the directory already exists, no exception is thrown. @@ -527,7 +542,7 @@ def _create_directory(self, path: str): self.target.makedirs(path, as_root=True) - def _delete_directory(self, path: str): + def _delete_directory(self, path: str) -> None: """ Removes the directory at the given path. @@ -539,12 +554,12 @@ def _delete_directory(self, path: str): # tries to delete the interface/controller files as well which isn't needed nor permitted. self.target.execute( "{busybox} rmdir -- {path}".format( - busybox=quote(self.target.busybox), path=quote(path) + busybox=quote(self.target.busybox or ''), path=quote(path) ), as_root=True, ) - def _add_process(self, pid: Union[str, int]): + def _add_process(self, pid: Union[str, int]) -> Optional[TargetStableError]: """ Adds the process associated with the ``pid`` to the CGroup, only if the process is not already a member of the CGroup. @@ -554,6 +569,7 @@ def _add_process(self, pid: Union[str, int]): """ if not self.target.file_exists(filepath="/proc/{pid}/status".format(pid=pid)): + # FIXME - is this return of the error intentional or was it meant to be raised return TargetStableError( "The Process ID: {pid} does not exists.".format(pid=pid) ) @@ -569,12 +585,13 @@ def _add_process(self, pid: Union[str, int]): ) except TargetStableError: self._set_controller_attribute("cgroup", "procs", pid) - + else: if str(pid) not in member_processes: self._set_controller_attribute("cgroup", "procs", pid) + return None - def _get_pid_from_tid(self, tid: int): + def _get_pid_from_tid(self, tid: int) -> int: """ Retrieves the ``pid`` (Process ID) that the ``tid`` (Thread ID) is a part of. @@ -584,15 +601,15 @@ def _get_pid_from_tid(self, tid: int): :return: The ``pid`` (Process ID) associated with the ``tid`` (Thread ID). :rtype: int """ - status = _read_lines( + status: List[str] = _read_lines( target=self.target, path="/proc/{tid}/status".format(tid=tid) ) for line in status: # the Tgid entry contains the thread group ID, which is the PID of # the process this thread belongs to. - match = re.match(r"\s*Tgid:\s*(\d+)\s*", line) + match: Optional[Match] = re.match(r"\s*Tgid:\s*(\d+)\s*", line) if match: - pid = match.group(1) + pid: str = match.group(1) break else: raise TargetStableError( @@ -602,7 +619,7 @@ def _get_pid_from_tid(self, tid: int): return int(pid) @abstractmethod - def _add_thread(self, tid: int, threaded_domain): + def _add_thread(self, tid: int, threaded_domain: Union['ResponseTree', '_TreeBase']): """ Ensures all sub-classes have the ability to add threads to their CGroups where their differences dont allow for a common approach. @@ -610,7 +627,7 @@ def _add_thread(self, tid: int, threaded_domain): pass @abstractmethod - def _init_cgroup(self): + def _init_cgroup(self) -> None: """ Ensures all sub-classes are able to initialise their respective CGroup directories as per defined by their user configurations. @@ -662,8 +679,8 @@ def __init__( self, name: str, parent_path: str, - active_controllers: Dict[str, Dict[str, str]], - subtree_controllers: set, + active_controllers: ControllerDict, + subtree_controllers: set[str], is_threaded: bool, target: LinuxTarget, ): @@ -701,7 +718,7 @@ def __enter__(self): def __exit__(self, *exc): self._delete_directory(path=self.group_path) - def _init_cgroup(self): + def _init_cgroup(self) -> None: """ Performs the required steps in order to initialize the CGroup to the user specified configuration: @@ -742,7 +759,7 @@ def _init_cgroup(self): value="+{cont}".format(cont=controller), ) - def _add_thread(self, tid: int, threaded_domain): + def _add_thread(self, tid: int, threaded_domain: Union['ResponseTree', '_TreeBase']) -> None: """ Attempts to add the thread associated with ``tid`` to the CGroup. Due to the requirements imposed by the kernel regarding thread management within a V2 CGroup hierarchy, @@ -760,9 +777,9 @@ def _add_thread(self, tid: int, threaded_domain): :type threaded_domain: :class:`ResponseTree` """ - pid_of_tid = self._get_pid_from_tid(tid=tid) + pid_of_tid: int = self._get_pid_from_tid(tid=tid) - for low_level in threaded_domain.low_levels.values(): + for low_level in cast(ResponseTree, threaded_domain).low_levels.values(): low_level._add_process(pid_of_tid) self._set_controller_attribute( @@ -788,8 +805,8 @@ class _CGroupV2Root(_CGroupV2): @classmethod def _v2_controller_translation( - cls, controllers: Dict[str, Dict[str, Union[str, int]]] - ): + cls, controllers: ControllerDict + ) -> ControllerDict: """ Given the new controller names within V2, rename the controllers to provide CGroupV2 compatibility. At this point in time, the ``blkio`` controller has been renamed to ``io`` in V2, while the V2 ``cpu`` controller @@ -809,7 +826,7 @@ def _v2_controller_translation( :rtype: Dict[str, Dict[str, Union[str,int]]] """ - translation = {} + translation: ControllerDict = {} if "blkio" in controllers: translation["io"] = controllers["blkio"] @@ -839,7 +856,7 @@ def _v2_controller_translation( } @classmethod - def _get_delegated_sub_path(cls, delegated_pid: int, target: LinuxTarget): + def _get_delegated_sub_path(cls, delegated_pid: int, target: LinuxTarget) -> Optional[str]: """ Returns the relative sub-path the delegated root of the V2 hierarchy is mounted on, via the parsing of the /proc//cgroup file of the delegated process associated with ``delegated_pid``. @@ -854,26 +871,27 @@ def _get_delegated_sub_path(cls, delegated_pid: int, target: LinuxTarget): :rtype: str """ - relative_delegated_mount_paths = _read_lines( + relative_delegated_mount_paths: List[str] = _read_lines( target=target, path="/proc/{pid}/cgroup".format(pid=delegated_pid) ) # Following Regex matches the line that contains the relative sub path. - REL_PATH_REGEX = re.compile(r"0::\/(?P.+)") + REL_PATH_REGEX: Pattern[str] = re.compile(r"0::\/(?P.+)") for mount_path in relative_delegated_mount_paths: - m = REL_PATH_REGEX.match(mount_path) + m: Optional[Match[str]] = REL_PATH_REGEX.match(mount_path) if m: return m.group("path") else: raise TargetStableError( "A V2 CGroup hierarchy was not delegated by systemd." ) + return None @classmethod def _get_available_controllers( - cls, controllers: Dict[str, Dict[str, Union[str, int]]] - ): + cls, controllers: ControllerDict + ) -> ControllerDict: """ Returns the CGroup controllers that are currently not in use on the target device, which can be taken control over and used in a manually mounted V2 hierarchy. @@ -892,7 +910,7 @@ def _get_available_controllers( """ # Filters the controllers dict to entries where the version is == 2. - mounted_v2_controllers = { + mounted_v2_controllers: Set[str] = { controller for controller, configuration in controllers.items() if (configuration.get("version") == 2) @@ -911,8 +929,8 @@ def _get_available_controllers( @classmethod def _path_to_delegated_root( - cls, controllers: Dict[str, Dict[str, Union[int, str]]], sub_path: str - ): + cls, controllers: ControllerDict, sub_path: str + ) -> str: """ Return the full path to the delegated root. This occurs in 2 stages: @@ -938,7 +956,7 @@ def _path_to_delegated_root( """ # Filter out non v2 controller mounts and append the "mount_point" to a set - v2_mount_point = { + v2_mount_point: Set[Union[str, int]] = { configuration["mount_point"] for configuration in controllers.values() if configuration.get("version") == 2 @@ -949,17 +967,17 @@ def _path_to_delegated_root( ) else: # Since there can only be a single V2 hierarchy (ignoring bind mounts), this should be totally legal. - mount_path_to_unified_hierarchy = v2_mount_point.pop() - return str(os.path.join(mount_path_to_unified_hierarchy, sub_path)) + mount_path_to_unified_hierarchy: Union[str, int] = v2_mount_point.pop() + return str(os.path.join(cast(str, mount_path_to_unified_hierarchy), sub_path)) @classmethod @contextmanager def _systemd_offline_mount( cls, target: LinuxTarget, - all_controllers: Dict[str, Dict[str, Union[str, int]]], + all_controllers: ControllerDict, requested_controllers: Set[str], - ): + ) -> Generator[str]: """ Manually mounts the V2 hierarchy on the target device. Occurs in the absence of systemd. @@ -978,7 +996,7 @@ def _systemd_offline_mount( :rtype: str """ - unused_controllers = _CGroupV2Root._get_available_controllers( + unused_controllers: ControllerDict = _CGroupV2Root._get_available_controllers( controllers=all_controllers ) _validate_requested_hierarchy( @@ -994,9 +1012,9 @@ def _systemd_offline_mount( def _systemd_online_setup( cls, target: LinuxTarget, - all_controllers: Dict[str, Dict[str, int]], + all_controllers: ControllerDict, requested_controllers: Set[str], - ): + ) -> Generator[str]: """ Sets up the required V2 hierarchy on the target device. Occurs in the presence of systemd. @@ -1015,15 +1033,15 @@ def _systemd_online_setup( :rtype: str """ with _request_delegation(target=target) as main_pid: - delegated_sub_path = _CGroupV2Root._get_delegated_sub_path( + delegated_sub_path: Optional[str] = _CGroupV2Root._get_delegated_sub_path( delegated_pid=main_pid, target=target ) - delegated_path = _CGroupV2Root._path_to_delegated_root( + delegated_path: str = _CGroupV2Root._path_to_delegated_root( controllers=all_controllers, - sub_path=delegated_sub_path, + sub_path=cast(str, delegated_sub_path), ) - delegated_controllers_path = "{path}/cgroup.controllers".format( + delegated_controllers_path: str = "{path}/cgroup.controllers".format( path=delegated_path ) @@ -1033,7 +1051,7 @@ def _systemd_online_setup( # by _read_file and splitting said element (str) using the white space character # as the delimiter. # (The _validate_requested_hierarchy requires the available_controllers argument to be a dict, necessitating this dict structure.) - delegated_controllers = { + delegated_controllers: Dict[str, None] = { controller: None for controller in _read_lines( target=target, path=delegated_controllers_path @@ -1048,7 +1066,7 @@ def _systemd_online_setup( @classmethod @contextmanager - def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]): + def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) -> Generator[str]: """ Mounts/Sets-up a V2 hierarchy on the target device, covering contexts where systemd is both present and absent. @@ -1063,13 +1081,13 @@ def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) :rtype: str """ - systemd_online = _is_systemd_online(target=target) - controllers = _CGroupV2Root._v2_controller_translation( + systemd_online: bool = _is_systemd_online(target=target) + controllers: ControllerDict = _CGroupV2Root._v2_controller_translation( _get_cgroup_controllers(target=target) ) if systemd_online: - cm = _CGroupV2Root._systemd_online_setup( + cm: _GeneratorContextManager[str] = _CGroupV2Root._systemd_online_setup( target=target, all_controllers=controllers, requested_controllers=requested_controllers, @@ -1089,7 +1107,7 @@ def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) def __init__( self, mount_point: str, - subtree_controllers: set, + subtree_controllers: set[str], target: LinuxTarget, ): @@ -1103,7 +1121,7 @@ def __init__( is_threaded=False, target=target, ) - self.target = target + self.target: LinuxTarget = target def __enter__(self): """ @@ -1128,7 +1146,7 @@ def __enter__(self): def __exit__(self, *exc): pass - def _init_root_cgroup(self): + def _init_root_cgroup(self) -> None: """ Performs the required actions in order to initialise a Root V2 CGroup. In the case where systemd is active, there is a required need to create a leaf CGroup from the Root, where the PIDs @@ -1139,11 +1157,11 @@ def _init_root_cgroup(self): if _is_systemd_online(target=self.target): # Create the leaf CGroup directory - group_name = "devlib-" + str(uuid4().hex) - full_path = self.target.path.join(self.group_path, group_name) + group_name: str = "devlib-" + str(uuid4().hex) + full_path: str = self.target.path.join(self.group_path, group_name) self._create_directory(full_path) - delegated_pids = _read_lines( + delegated_pids: List[str] = _read_lines( target=self.target, path="{path}/cgroup.procs".format(path=self.group_path), ) @@ -1224,7 +1242,7 @@ def _init_cgroup(self): controller=controller, attribute=attr, value=val, verify=True ) - def _add_thread(self, tid: int, threaded_domain): + def _add_thread(self, tid: int, threaded_domain: Union['ResponseTree', '_TreeBase']) -> None: """ Adds the thread associated with ``tid`` to the CGroup. While thread level management suffers from no restrictions within a V1 hierarchy, @@ -1243,9 +1261,9 @@ def _add_thread(self, tid: int, threaded_domain): :type threaded_domain: :class:`ResponseTree` """ - pid_of_tid = self._get_pid_from_tid(tid=tid) + pid_of_tid: int = self._get_pid_from_tid(tid=tid) - for low_level in threaded_domain.low_levels.values(): + for low_level in cast(ResponseTree, threaded_domain).low_levels.values(): low_level._add_process(pid_of_tid) self._set_controller_attribute("", "tasks", tid) @@ -1267,10 +1285,10 @@ class _CGroupV1Root(_CGroupV1): @classmethod def _get_delegated_paths( cls, - controllers: Dict[str, Dict[str, Union[str, int]]], + controllers: ControllerDict, delegated_pid: int, target: LinuxTarget, - ): + ) -> Dict[str, str]: """ Returns the relative sub-paths the delegated roots of the V1 hierarchies, via the parsing of the /proc//cgroup file of the delegated PID. @@ -1292,7 +1310,7 @@ def _get_delegated_paths( :rtype: Dict[str, str] """ - delegated_mount_paths = _read_lines( + delegated_mount_paths: List[str] = _read_lines( target=target, path="/proc/{pid}/cgroup".format(pid=delegated_pid) ) @@ -1303,22 +1321,22 @@ def _get_delegated_paths( # # The regex is structured to only match V1 controller hierarchies. - REL_PATH_REGEX = re.compile( + REL_PATH_REGEX: Pattern[str] = re.compile( r"\d+:(?P.+):\/(?P.*)" ) - delegated_controllers = {} + delegated_controllers: Dict[str, str] = {} for mount_path in delegated_mount_paths: - regex_match = REL_PATH_REGEX.match(mount_path) + regex_match: Optional[Match[str]] = REL_PATH_REGEX.match(mount_path) if regex_match: - con = regex_match.group("controllers") - path = regex_match.group("path_to_delegated_service_root") + con: str = regex_match.group("controllers") + path: str = regex_match.group("path_to_delegated_service_root") # Multiple v1 controllers can be co-mounted on a single folder hierarchy. - co_mounted_controllers = con.strip().split(",") + co_mounted_controllers: List[str] = con.strip().split(",") for controller in co_mounted_controllers: try: - configuration = controllers[controller] + configuration: Dict[str, Union[str, int]] = controllers[controller] except KeyError: pass else: @@ -1338,7 +1356,7 @@ def _get_delegated_paths( def _systemd_offline_mount( cls, requested_controllers: Set[str], - all_controllers: Dict[str, Dict[str, Union[str, int]]], + all_controllers: ControllerDict, target: LinuxTarget, ): """ @@ -1359,7 +1377,7 @@ def _systemd_offline_mount( :rtype: Dict[str,str] """ - available_controllers = _CGroupV1Root._get_available_v1_controllers( + available_controllers: ControllerDict = _CGroupV1Root._get_available_v1_controllers( controllers=all_controllers ) _validate_requested_hierarchy( @@ -1373,10 +1391,13 @@ def _systemd_offline_mount( @classmethod def _get_available_v1_controllers( - cls, controllers: Dict[str, Dict[str, Union[int, str]]] - ): + cls, controllers: ControllerDict + ) -> ControllerDict: + """ + helper function to get the available v1 controllers + """ - unused_controllers = { + unused_controllers: ControllerDict = { controller: configuration for controller, configuration in controllers.items() if configuration.get("version") is None @@ -1393,8 +1414,8 @@ def _systemd_online_setup( cls, target: LinuxTarget, requested_controllers: Set[str], - all_controllers: Dict[str, Dict[str, str]], - ): + all_controllers: ControllerDict, + ) -> Generator[Dict[str, str]]: """ Sets up the required V1 hierarchy on the target device. Occurs in the presence of systemd. @@ -1414,7 +1435,7 @@ def _systemd_online_setup( """ with _request_delegation(target) as pid: - delegated_controllers = _CGroupV1Root._get_delegated_paths( + delegated_controllers: Dict[str, str] = _CGroupV1Root._get_delegated_paths( controllers=all_controllers, delegated_pid=pid, target=target, @@ -1428,7 +1449,7 @@ def _systemd_online_setup( @classmethod @contextmanager - def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]): + def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) -> Generator[Dict[str, str]]: """ A context manager which Mounts/Sets-up a V1 split hierarchy on the target device, covering contexts where systemd is both present and absent. This context manager Mounts/Sets-up a split V1 hierarchy (if possible) @@ -1445,17 +1466,17 @@ def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) :rtype: dict[str,str] """ - systemd_online = _is_systemd_online(target=target) - controllers = _get_cgroup_controllers(target=target) + systemd_online: bool = _is_systemd_online(target=target) + controllers: ControllerDict = _get_cgroup_controllers(target=target) if systemd_online: - cm = _CGroupV1Root._systemd_online_setup( + cm: _GeneratorContextManager[Dict[str, str]] = _CGroupV1Root._systemd_online_setup( target=target, requested_controllers=requested_controllers, all_controllers=controllers, ) - with cm as controllers: - yield controllers + with cm as controllers_temp: + yield controllers_temp else: cm = _CGroupV1Root._systemd_offline_mount( @@ -1463,8 +1484,8 @@ def _mount_filesystem(cls, target: LinuxTarget, requested_controllers: Set[str]) requested_controllers=requested_controllers, all_controllers=controllers, ) - with cm as controllers: - yield controllers + with cm as controllers_temp: + yield controllers_temp def __init__(self, mount_point: str, target: LinuxTarget): @@ -1510,14 +1531,14 @@ def __init__(self, name: str, is_threaded: bool): # Propagates Threaded Property to # sub-tree. - def make_threaded(grp): + def make_threaded(grp: '_TreeBase'): grp.is_threaded = True for child in grp._children_list: make_threaded(child) # Propagates the Threaded domain # to sub-tree. - def set_domain(grp): + def set_domain(grp: '_TreeBase'): grp.threaded_domain = domain for child in grp._children_list: set_domain(child) @@ -1525,14 +1546,18 @@ def set_domain(grp): if is_threaded: make_threaded(self) else: - domain = self - if any([child.is_threaded for child in self._children_list]): - for child in self._children_list: - make_threaded(child) - set_domain(child) + domain: '_TreeBase' = self + if self._children_list: + if any([child.is_threaded for child in self._children_list]): + for child in self._children_list: + make_threaded(child) + set_domain(child) @property - def is_threaded_domain(self): + def is_threaded_domain(self) -> bool: + """ + check if the is_threaded property is set in the domain + """ return ( True if any([child.is_threaded for child in self._children_list]) @@ -1542,7 +1567,10 @@ def is_threaded_domain(self): @property @memoized - def group_type(self): + def group_type(self) -> str: + """ + get the type of the group + """ if self.is_threaded_domain: return "threaded domain" elif self.is_threaded: @@ -1578,7 +1606,7 @@ def __str__(self, level=0): @property @abstractmethod - def _node_information(self): + def _node_information(self) -> str: """ Returns a formatted string displaying the information the :class:`_TreeBase` object represents. """ @@ -1586,7 +1614,7 @@ def _node_information(self): @property @abstractmethod - def _children_list(self): + def _children_list(self) -> List['_TreeBase']: """ Returns List[:class:`_TreeBase`]. """ @@ -1618,18 +1646,18 @@ class RequestTree(_TreeBase): def __init__( self, name: str, - children: Union[list, None] = None, - controllers: Union[Dict[str, Dict[str, Any]], None] = None, - threaded=False, + children: Union[list['RequestTree'], None] = None, + controllers: Optional[ControllerDict] = None, + threaded: bool = False, ): self.children = children or [] self.controllers = controllers or {} super().__init__(name=name, is_threaded=threaded) @property - def _node_information(self): + def _node_information(self) -> str: # Returns Requests Tree Node Information. - active_controllers = [ + active_controllers: List[str] = [ "({controller}) {config}".format( controller=controller, config=configuration ) @@ -1642,8 +1670,10 @@ def _node_information(self): @property @memoized - def _all_controllers(self): - # Returns a set of all the controllers that are active in that subtree, including its own. + def _all_controllers(self) -> Set[str]: + """ + Returns a set of all the controllers that are active in that subtree, including its own. + """ return set( itertools.chain( self.controllers.keys(), @@ -1654,8 +1684,10 @@ def _all_controllers(self): ) @property - def _subtree_controllers(self): - # Returns a set of all the controllers that are active in that subtree, excluding its own. + def _subtree_controllers(self) -> Set[str]: + """ + Returns a set of all the controllers that are active in that subtree, excluding its own. + """ return set( itertools.chain.from_iterable( map(lambda child: child._all_controllers, self.children) @@ -1663,11 +1695,11 @@ def _subtree_controllers(self): ) @property - def _children_list(self): + def _children_list(self) -> List['_TreeBase']: return list(self.children) @contextmanager - def setup_hierarchy(self, version: int, target: LinuxTarget): + def setup_hierarchy(self, version: int, target: LinuxTarget) -> Generator['ResponseTree']: """ A context manager which processes the user defined hierarchy and sets-up said hierarchy on the ``target`` device. Uses an internal exit stack to the handle the entering and safe exiting of the lower level @@ -1689,15 +1721,16 @@ def setup_hierarchy(self, version: int, target: LinuxTarget): """ with ExitStack() as exit_stack: + make_groups: Callable if version == 1: # Returns a {controller_name: controller_mount_point} dict - controller_paths = exit_stack.enter_context( + controller_paths: Dict[str, str] = exit_stack.enter_context( _CGroupV1Root._mount_filesystem( target=target, requested_controllers=self._all_controllers ) ) # Mounts the Roots Controller Parents. - root_parents = { + root_parents: Union[Dict[str, _CGroupV1], _CGroupV2] = { controller: _CGroupV1Root( mount_point=mount_path, target=target, @@ -1706,7 +1739,7 @@ def setup_hierarchy(self, version: int, target: LinuxTarget): if controller in self._all_controllers } - def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): + def make_groups_v1(request: RequestTree, parents: Dict[str, _CGroupBase]): """ Defines and instantiates the low-level :class:`_CGroupV1` objects as per defined by the configuration of the ``request`` :class:`RequestTree` object. @@ -1733,7 +1766,7 @@ def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): :rtype: tuple(Dict[str,:class:`_CGroupV1`], Dict[str,:class:`_CGroupV1`], Dict[str,:class:`_CGroupV1`]) """ - request_defined_cgroups = { + request_defined_cgroups: Dict[str, _CGroupV1] = { controller: _CGroupV1( name=request.name, parent_path=parents[controller].group_path, @@ -1745,9 +1778,11 @@ def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): # Parent dict updated to include the newly created leaf CGroups. parents = {**parents, **request_defined_cgroups} - all_cgroups = parents + all_cgroups: Dict[str, _CGroupBase] = parents return (request_defined_cgroups, all_cgroups, parents) + make_groups = make_groups_v1 + elif version == 2: # Returns a string representing the root of the V2 hierarchy @@ -1767,7 +1802,7 @@ def make_groups(request: RequestTree, parents: Dict[str, _CGroupBase]): # root CGroup setup defined within the __enter__ method. exit_stack.enter_context(root_parents) - def make_groups(request: RequestTree, parent: _CGroupV2): + def make_groups_v2(request: RequestTree, parent: _CGroupV2): """ Defines and instantiates the low-level :class:`_CGroupV2` object as per defined by the configuration of the ``request`` :class:`RequestTree` object. The parents of said :class:`_CGroupV2` object @@ -1802,13 +1837,15 @@ def make_groups(request: RequestTree, parent: _CGroupV2): # Creates a mapping between the enabled controllers within this CGroup to the low-level # _CGroupV2 object - controllers_to_cgroup = dict.fromkeys( + controllers_to_cgroup: Dict[str, _CGroupV2] = dict.fromkeys( request.controllers, request_group ) # Creating 'parent' variable for readability’s sake. parent = request_group return (controllers_to_cgroup, controllers_to_cgroup, parent) + make_groups = make_groups_v2 + else: raise TargetStableError( "A {version} version hierarchy cannot be mounted. Ensure requested hierarchy version is 1 or 2.".format( @@ -1817,9 +1854,9 @@ def make_groups(request: RequestTree, parent: _CGroupV2): ) # Create the Response Tree from the Request Tree. - response = self._create_response(root_parents, make_groups=make_groups) + response: 'ResponseTree' = self._create_response(root_parents, make_groups=make_groups) # Returns a list of all the Low-level _CGroupBase objects the response object represents in the right order - groups = response._all_nodes + groups: List[_CGroupBase] = response._all_nodes # Remove duplicates while preserving order. groups = sorted(set(groups), key=groups.index) # Enter the context for each object @@ -1828,7 +1865,7 @@ def make_groups(request: RequestTree, parent: _CGroupV2): yield response - def _create_response(self, low_level_parent, make_groups): + def _create_response(self, low_level_parent: Union[Dict[str, _CGroupV1], _CGroupV2], make_groups) -> 'ResponseTree': """ Creates the :class:`ResponseTree` object tree, using the appropriately defined :meth:`make_group` callable (defined as a local function internally within :meth:`setup_hierarchy`) alongside the ``low_level_parent`` object to create the low-level CGroups a particular :class:`RequestTree` object represents. @@ -1911,7 +1948,7 @@ def __init__( super().__init__(name=name, is_threaded=is_threaded) @property - def _node_information(self): + def _node_information(self) -> str: # Returns a formatted string, displaying the enabled user-defined controllers and their paths # (alongside the type of CGroup the controller resides in). return ", ".join( @@ -1924,22 +1961,22 @@ def _node_information(self): ) @property - def _children_list(self): + def _children_list(self) -> List['_TreeBase']: # Children Objects are the values in our self.children dict. return list(self.children.values()) @property - def _all_nodes(self): + def _all_nodes(self) -> List[_CGroupBase]: return list( itertools.chain( self.low_levels.values(), itertools.chain.from_iterable( - map(lambda child: child._all_nodes, self.children.values()), + map(lambda child: cast(ResponseTree, child)._all_nodes, self.children.values()), ), ) ) - def add_process(self, pid: int): + def add_process(self, pid: int) -> None: """ Adds the process associated with ``pid`` to the low level CGroups this :class:`ResponseTree` object represents. @@ -1960,7 +1997,7 @@ def add_process(self, pid: int): ) ) - def add_thread(self, tid: int): + def add_thread(self, tid: int) -> None: """ Adds the thread associated with the ``tid`` to the low level CGroups this :class:`ResponseTree` object represents. diff --git a/devlib/module/cooling.py b/devlib/module/cooling.py index 413360a7c..39c43aeb1 100644 --- a/devlib/module/cooling.py +++ b/devlib/module/cooling.py @@ -1,4 +1,4 @@ -# Copyright 2014-2015 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,50 +16,68 @@ from devlib.module import Module from devlib.utils.serial_port import open_serial_connection +from typing import TYPE_CHECKING, cast +from pexpect import fdpexpect +if TYPE_CHECKING: + from devlib.target import Target class MbedFanActiveCoolingModule(Module): - - name = 'mbed-fan' - timeout = 30 + """ + Module to control active cooling using fan + """ + name: str = 'mbed-fan' + timeout: int = 30 @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return True - def __init__(self, target, port='/dev/ttyACM0', baud=115200, fan_pin=0): + def __init__(self, target: 'Target', port: str = '/dev/ttyACM0', baud: int = 115200, fan_pin: int = 0): super(MbedFanActiveCoolingModule, self).__init__(target) self.port = port self.baud = baud self.fan_pin = fan_pin - def start(self): + def start(self) -> None: + """ + send motor start to fan + """ with open_serial_connection(timeout=self.timeout, port=self.port, baudrate=self.baud) as target: # pylint: disable=no-member - target.sendline('motor_{}_1'.format(self.fan_pin)) + cast(fdpexpect.fdspawn, target).sendline('motor_{}_1'.format(self.fan_pin)) - def stop(self): + def stop(self) -> None: + """ + send motor stop to fan + """ with open_serial_connection(timeout=self.timeout, port=self.port, baudrate=self.baud) as target: # pylint: disable=no-member - target.sendline('motor_{}_0'.format(self.fan_pin)) + cast(fdpexpect.fdspawn, target).sendline('motor_{}_0'.format(self.fan_pin)) class OdroidXU3ctiveCoolingModule(Module): - name = 'odroidxu3-fan' + name: str = 'odroidxu3-fan' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return target.file_exists('/sys/devices/odroid_fan.15/fan_mode') - def start(self): + def start(self) -> None: + """ + start fan + """ self.target.write_value('/sys/devices/odroid_fan.15/fan_mode', 0, verify=False) self.target.write_value('/sys/devices/odroid_fan.15/pwm_duty', 255, verify=False) def stop(self): + """ + stop fan + """ self.target.write_value('/sys/devices/odroid_fan.15/fan_mode', 0, verify=False) self.target.write_value('/sys/devices/odroid_fan.15/pwm_duty', 1, verify=False) diff --git a/devlib/module/cpufreq.py b/devlib/module/cpufreq.py index 2640a9a8e..b8536b59c 100644 --- a/devlib/module/cpufreq.py +++ b/devlib/module/cpufreq.py @@ -1,4 +1,4 @@ -# Copyright 2014-2024 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,22 +17,38 @@ from devlib.exception import TargetStableError from devlib.utils.misc import memoized import devlib.utils.asyn as asyn - +from typing import (TYPE_CHECKING, Dict, List, Tuple, Union, + cast, Optional, AsyncGenerator, Any, Set, + Coroutine) +if TYPE_CHECKING: + from devlib.target import Target # a dict of governor name and a list of it tunables that can't be read -WRITE_ONLY_TUNABLES = { +WRITE_ONLY_TUNABLES: Dict[str, List[str]] = { 'interactive': ['boostpulse'] } class CpufreqModule(Module): - - name = 'cpufreq' + """ + ``cpufreq`` is the kernel subsystem for managing DVFS (Dynamic Voltage and + Frequency Scaling). It allows controlling frequency ranges and switching + policies (governors). The ``devlib`` module exposes the following interface + + .. note:: On ARM big.LITTLE systems, all cores on a cluster (usually all cores + of the same type) are in the same frequency domain, so setting + ``cpufreq`` state on one core on a cluster will affect all cores on + that cluster. Because of this, some devices only expose cpufreq sysfs + interface (which is what is used by the ``devlib`` module) on the + first cpu in a cluster. So to keep your scripts portable, always use + the fist (online) CPU in a cluster to set ``cpufreq`` state. + """ + name: str = 'cpufreq' @staticmethod @asyn.asyncf - async def probe(target): - paths = [ + async def probe(target: 'Target') -> bool: + paths_tmp: List[Tuple[bool, str]] = [ # x86 with Intel P-State driver (target.abi == 'x86_64', '/sys/devices/system/cpu/intel_pstate'), # Generic CPUFreq support (single policy) @@ -40,8 +56,8 @@ async def probe(target): # Generic CPUFreq support (per CPU policy) (True, '/sys/devices/system/cpu/cpu0/cpufreq'), ] - paths = [ - path[1] for path in paths + paths: List[str] = [ + path[1] for path in paths_tmp if path[0] ] @@ -52,38 +68,48 @@ async def probe(target): return any(exists.values()) - def __init__(self, target): + def __init__(self, target: 'Target'): super(CpufreqModule, self).__init__(target) - self._governor_tunables = {} + self._governor_tunables: Dict[str, Tuple[str, bool, List[str]]] = {} @asyn.asyncf @asyn.memoized_method - async def list_governors(self, cpu): - """Returns a list of governors supported by the cpu.""" + async def list_governors(self, cpu: Union[int, str]) -> List[str]: + """List cpufreq governors available for the specified cpu. Returns a list of + strings. + + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_available_governors'.format(cpu) - output = await self.target.read_value.asyn(sysfile) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_available_governors'.format(cpu) + output: str = await self.target.read_value.asyn(sysfile) return output.strip().split() @asyn.asyncf - async def get_governor(self, cpu): - """Returns the governor currently set for the specified CPU.""" + async def get_governor(self, cpu: Union[int, str]) -> str: + """ + Returns the name of the currently set governor for the specified cpu. + + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) return await self.target.read_value.asyn(sysfile) @asyn.asyncf - async def set_governor(self, cpu, governor, **kwargs): + async def set_governor(self, cpu: Union[int, str], governor: str, **kwargs) -> None: """ Set the governor for the specified CPU. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt - :param cpu: The CPU for which the governor is to be set. This must be - the full name as it appears in sysfs, e.g. "cpu0". + :param cpu: The CPU for which the governor is to be set. It could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). :param governor: The name of the governor to be used. This must be - supported by the specific device. + supported by the specific device (as returned by ``list_governors()``. Additional keyword arguments can be used to specify governor tunables for governors that support them. @@ -98,7 +124,7 @@ async def set_governor(self, cpu, governor, **kwargs): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - supported = await self.list_governors.asyn(cpu) + supported: List[str] = await self.list_governors.asyn(cpu) if governor not in supported: raise TargetStableError('Governor {} not supported for cpu {}'.format(governor, cpu)) sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_governor'.format(cpu) @@ -106,7 +132,7 @@ async def set_governor(self, cpu, governor, **kwargs): await self.set_governor_tunables.asyn(cpu, governor, **kwargs) @asyn.asynccontextmanager - async def use_governor(self, governor, cpus=None, **kwargs): + async def use_governor(self, governor: str, cpus: Optional[List[str]] = None, **kwargs) -> AsyncGenerator: """ Use a given governor, then restore previous governor(s) @@ -121,7 +147,7 @@ async def use_governor(self, governor, cpus=None, **kwargs): if not cpus: cpus = await self.target.list_online_cpus.asyn() - async def get_cpu_info(cpu): + async def get_cpu_info(cpu) -> List[Any]: return await self.target.async_manager.concurrently(( self.get_affected_cpus.asyn(cpu), self.get_governor.asyn(cpu), @@ -131,12 +157,12 @@ async def get_cpu_info(cpu): self.get_frequency.asyn(cpu), )) - cpus_infos = await self.target.async_manager.map_concurrently(get_cpu_info, cpus) + cpus_infos: Dict[int, List[Any]] = await self.target.async_manager.map_concurrently(get_cpu_info, cpus) # Setting a governor & tunables for a cpu will set them for all cpus in # the same cpufreq policy, so only manipulating one cpu per domain is # enough - domains = set( + domains: Set[Any] = set( info[0][0] for info in cpus_infos.values() ) @@ -149,7 +175,7 @@ async def get_cpu_info(cpu): try: yield finally: - async def set_per_cpu_tunables(cpu): + async def set_per_cpu_tunables(cpu: int) -> None: domain, prev_gov, tunables, freq = cpus_infos[cpu] # Per-cpu tunables are safe to set concurrently await self.set_governor_tunables.asyn(cpu, prev_gov, per_cpu=True, **tunables) @@ -157,7 +183,7 @@ async def set_per_cpu_tunables(cpu): if prev_gov == "userspace": await self.set_frequency.asyn(cpu, freq) - per_cpu_tunables = self.target.async_manager.concurrently( + per_cpu_tunables: Coroutine = self.target.async_manager.concurrently( set_per_cpu_tunables(cpu) for cpu in domains ) @@ -165,14 +191,14 @@ async def set_per_cpu_tunables(cpu): # Non-per-cpu tunables have to be set one after the other, for each # governor that we had to deal with. - global_tunables = { + global_tunables_dict: Dict[str, Tuple[int, Dict[str, List[str]]]] = { prev_gov: (cpu, tunables) for cpu, (domain, prev_gov, tunables, freq) in cpus_infos.items() } - global_tunables = self.target.async_manager.concurrently( + global_tunables: Coroutine = self.target.async_manager.concurrently( self.set_governor_tunables.asyn(cpu, gov, per_cpu=False, **tunables) - for gov, (cpu, tunables) in global_tunables.items() + for gov, (cpu, tunables) in global_tunables_dict.items() ) global_tunables.__qualname__ = 'CpufreqModule.use_governor..global_tunables' @@ -188,7 +214,11 @@ async def set_per_cpu_tunables(cpu): ) @asyn.asyncf - async def _list_governor_tunables(self, cpu, governor=None): + async def _list_governor_tunables(self, cpu: Union[int, str], + governor: Optional[str] = None) -> Tuple[str, bool, List[str]]: + """ + helper function for list_governor_tunables + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) @@ -196,6 +226,8 @@ async def _list_governor_tunables(self, cpu, governor=None): governor = await self.get_governor.asyn(cpu) try: + if not governor: + raise TargetStableError return self._governor_tunables[governor] except KeyError: for per_cpu, path in ( @@ -204,7 +236,7 @@ async def _list_governor_tunables(self, cpu, governor=None): (False, '/sys/devices/system/cpu/cpufreq/{}'.format(governor)), ): try: - tunables = await self.target.list_directory.asyn(path) + tunables: List[str] = await self.target.list_directory.asyn(path) except TargetStableError: continue else: @@ -213,34 +245,49 @@ async def _list_governor_tunables(self, cpu, governor=None): per_cpu = False tunables = [] - data = (governor, per_cpu, tunables) - self._governor_tunables[governor] = data + data: Tuple[str, bool, List[str]] = (cast(str, governor), per_cpu, tunables) + if governor: + self._governor_tunables[governor] = data return data @asyn.asyncf - async def list_governor_tunables(self, cpu): - """Returns a list of tunables available for the governor on the specified CPU.""" + async def list_governor_tunables(self, cpu: Union[int, str]) -> Tuple[str, bool, List[str]]: + """ + List the tunables for the specified cpu's current governor. + + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). + """ _, _, tunables = await self._list_governor_tunables.asyn(cpu) return tunables @asyn.asyncf - async def get_governor_tunables(self, cpu): + async def get_governor_tunables(self, cpu: Union[int, str]) -> Dict[str, List[str]]: + """ + Return a dict with the values of the specified CPU's current governor. + + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) + governor: str + tunable_list: List[str] governor, _, tunable_list = await self._list_governor_tunables.asyn(cpu) - write_only = set(WRITE_ONLY_TUNABLES.get(governor, [])) + write_only: Set[str] = set(WRITE_ONLY_TUNABLES.get(governor, [])) tunable_list = [ tunable for tunable in tunable_list if tunable not in write_only ] - tunables = {} - async def get_tunable(tunable): + tunables: Dict[str, List[str]] = {} + + async def get_tunable(tunable: str) -> str: try: - path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) - x = await self.target.read_value.asyn(path) + path: str = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) + x: str = await self.target.read_value.asyn(path) except TargetStableError: # May be an older kernel path = '/sys/devices/system/cpu/cpufreq/{}/{}'.format(governor, tunable) x = await self.target.read_value.asyn(path) @@ -250,7 +297,8 @@ async def get_tunable(tunable): return tunables @asyn.asyncf - async def set_governor_tunables(self, cpu, governor=None, per_cpu=None, **kwargs): + async def set_governor_tunables(self, cpu: Union[int, str], governor: Optional[str] = None, + per_cpu: Optional[bool] = None, **kwargs) -> None: """ Set tunables for the specified governor. Tunables should be specified as keyword arguments. Which tunables and values are valid depends on the @@ -276,6 +324,8 @@ async def set_governor_tunables(self, cpu, governor=None, per_cpu=None, **kwargs if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) + gov_per_cpu: bool + valid_tunables: List[str] governor, gov_per_cpu, valid_tunables = await self._list_governor_tunables.asyn(cpu, governor=governor) for tunable, value in kwargs.items(): if tunable in valid_tunables: @@ -283,34 +333,38 @@ async def set_governor_tunables(self, cpu, governor=None, per_cpu=None, **kwargs continue if gov_per_cpu: - path = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) + path: str = '/sys/devices/system/cpu/{}/cpufreq/{}/{}'.format(cpu, governor, tunable) else: path = '/sys/devices/system/cpu/cpufreq/{}/{}'.format(governor, tunable) await self.target.write_value.asyn(path, value) else: - message = 'Unexpected tunable {} for governor {} on {}.\n'.format(tunable, governor, cpu) + message: str = 'Unexpected tunable {} for governor {} on {}.\n'.format(tunable, governor, cpu) message += 'Available tunables are: {}'.format(valid_tunables) raise TargetStableError(message) @asyn.asyncf @asyn.memoized_method - async def list_frequencies(self, cpu): - """Returns a sorted list of frequencies supported by the cpu or an empty list - if not could be found.""" + async def list_frequencies(self, cpu: Union[int, str]) -> List[int]: + """ + Returns a sorted list of frequencies supported by the cpu or an empty list + if not could be found. + :param cpu: The cpu; could be a numeric or the corresponding string (e.g. + ``1`` or ``"cpu1"``). + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) try: - cmd = 'cat /sys/devices/system/cpu/{}/cpufreq/scaling_available_frequencies'.format(cpu) - output = await self.target.execute.asyn(cmd) - available_frequencies = list(map(int, output.strip().split())) # pylint: disable=E1103 + cmd: str = 'cat /sys/devices/system/cpu/{}/cpufreq/scaling_available_frequencies'.format(cpu) + output: str = await self.target.execute.asyn(cmd) + available_frequencies: List[int] = list(map(int, output.strip().split())) # pylint: disable=E1103 except TargetStableError: # On some devices scaling_frequencies is not generated. # http://adrynalyne-teachtofish.blogspot.co.uk/2011/11/how-to-enable-scalingavailablefrequenci.html # Fall back to parsing stats/time_in_state - path = '/sys/devices/system/cpu/{}/cpufreq/stats/time_in_state'.format(cpu) + path: str = '/sys/devices/system/cpu/{}/cpufreq/stats/time_in_state'.format(cpu) try: - out_iter = (await self.target.read_value.asyn(path)).split() + out_iter: List[str] = cast(str, (await self.target.read_value.asyn(path))).split() except TargetStableError: if not self.target.file_exists(path): # Probably intel_pstate. Can't get available freqs. @@ -321,7 +375,7 @@ async def list_frequencies(self, cpu): return sorted(available_frequencies) @memoized - def get_max_available_frequency(self, cpu): + def get_max_available_frequency(self, cpu: Union[str, int]) -> Optional[int]: """ Returns the maximum available frequency for a given core or None if could not be found. @@ -330,16 +384,16 @@ def get_max_available_frequency(self, cpu): return max(freqs) if freqs else None @memoized - def get_min_available_frequency(self, cpu): + def get_min_available_frequency(self, cpu: Union[str, int]) -> Optional[int]: """ Returns the minimum available frequency for a given core or None if could not be found. """ - freqs = self.list_frequencies(cpu) + freqs: List[int] = self.list_frequencies(cpu) return min(freqs) if freqs else None @asyn.asyncf - async def get_min_frequency(self, cpu): + async def get_min_frequency(self, cpu: Union[str, int]) -> int: """ Returns the min frequency currently set for the specified CPU. @@ -352,11 +406,11 @@ async def get_min_frequency(self, cpu): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) return await self.target.read_int.asyn(sysfile) @asyn.asyncf - async def set_min_frequency(self, cpu, frequency, exact=True): + async def set_min_frequency(self, cpu: Union[str, int], frequency: Union[int, str], exact: bool = True) -> None: """ Set's the minimum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be @@ -375,20 +429,20 @@ async def set_min_frequency(self, cpu, frequency, exact=True): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - available_frequencies = await self.list_frequencies.asyn(cpu) + available_frequencies: List[int] = await self.list_frequencies.asyn(cpu) try: value = int(frequency) if exact and available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(cpu, - value, - available_frequencies)) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) + value, + available_frequencies)) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_min_freq'.format(cpu) await self.target.write_value.asyn(sysfile, value) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) @asyn.asyncf - async def get_frequency(self, cpu, cpuinfo=False): + async def get_frequency(self, cpu: Union[str, int], cpuinfo: bool = False) -> int: """ Returns the current frequency currently set for the specified CPU. @@ -405,18 +459,20 @@ async def get_frequency(self, cpu, cpuinfo=False): if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/{}'.format( - cpu, - 'cpuinfo_cur_freq' if cpuinfo else 'scaling_cur_freq') + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/{}'.format( + cpu, + 'cpuinfo_cur_freq' if cpuinfo else 'scaling_cur_freq') return await self.target.read_int.asyn(sysfile) @asyn.asyncf - async def set_frequency(self, cpu, frequency, exact=True): + async def set_frequency(self, cpu: Union[str, int], frequency: Union[int, str], exact: bool = True) -> None: """ - Set's the minimum value for CPU frequency. Actual frequency will + Sets the minimum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be either an int or a string representing an integer. + `set_frequency`` is only available if the current governor is ``userspace``. + If ``exact`` flag is set (the default), the Value must also be supported by the device. The available frequencies can be obtained by calling get_frequencies() or examining @@ -435,16 +491,16 @@ async def set_frequency(self, cpu, frequency, exact=True): try: value = int(frequency) if exact: - available_frequencies = await self.list_frequencies.asyn(cpu) + available_frequencies: List[int] = await self.list_frequencies.asyn(cpu) if available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(cpu, - value, - available_frequencies)) + value, + available_frequencies)) if await self.get_governor.asyn(cpu) != 'userspace': raise TargetStableError('Can\'t set {} frequency; governor must be "userspace"'.format(cpu)) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_setspeed'.format(cpu) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_setspeed'.format(cpu) await self.target.write_value.asyn(sysfile, value, verify=False) - cpuinfo = await self.get_frequency.asyn(cpu, cpuinfo=True) + cpuinfo: int = await self.get_frequency.asyn(cpu, cpuinfo=True) if cpuinfo != value: self.logger.warning( 'The cpufreq value has not been applied properly cpuinfo={} request={}'.format(cpuinfo, value)) @@ -452,7 +508,7 @@ async def set_frequency(self, cpu, frequency, exact=True): raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) @asyn.asyncf - async def get_max_frequency(self, cpu): + async def get_max_frequency(self, cpu: Union[str, int]) -> int: """ Returns the max frequency currently set for the specified CPU. @@ -464,13 +520,14 @@ async def get_max_frequency(self, cpu): """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) return await self.target.read_int.asyn(sysfile) @asyn.asyncf - async def set_max_frequency(self, cpu, frequency, exact=True): + async def set_max_frequency(self, cpu: Union[str, int], + frequency: Union[str, int], exact: bool = True) -> None: """ - Set's the minimum value for CPU frequency. Actual frequency will + Set's the maximum value for CPU frequency. Actual frequency will depend on the Governor used and may vary during execution. The value should be either an int or a string representing an integer. The Value must also be supported by the device. The available frequencies can be obtained by calling @@ -492,58 +549,58 @@ async def set_max_frequency(self, cpu, frequency, exact=True): value = int(frequency) if exact and available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(cpu, - value, - available_frequencies)) - sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) + value, + available_frequencies)) + sysfile: str = '/sys/devices/system/cpu/{}/cpufreq/scaling_max_freq'.format(cpu) await self.target.write_value.asyn(sysfile, value) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) @asyn.asyncf - async def set_governor_for_cpus(self, cpus, governor, **kwargs): + async def set_governor_for_cpus(self, cpus: List[Union[str, int]], governor: str, **kwargs) -> None: """ Set the governor for the specified list of CPUs. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt :param cpus: The list of CPU for which the governor is to be set. """ - await self.target.async_manager.map_concurrently( + await self.target.async_manager.concurrently( self.set_governor(cpu, governor, **kwargs) for cpu in sorted(set(cpus)) ) @asyn.asyncf - async def set_frequency_for_cpus(self, cpus, freq, exact=False): + async def set_frequency_for_cpus(self, cpus: List[Union[int, str]], freq: int, exact: bool = False) -> None: """ Set the frequency for the specified list of CPUs. See https://www.kernel.org/doc/Documentation/cpu-freq/governors.txt :param cpus: The list of CPU for which the frequency has to be set. """ - await self.target.async_manager.map_concurrently( + await self.target.async_manager.concurrently( self.set_frequency(cpu, freq, exact) for cpu in sorted(set(cpus)) ) @asyn.asyncf - async def set_all_frequencies(self, freq): + async def set_all_frequencies(self, freq: int) -> None: """ Set the specified (minimum) frequency for all the (online) CPUs """ # pylint: disable=protected-access return await self.target._execute_util.asyn( - 'cpufreq_set_all_frequencies {}'.format(freq), - as_root=True) + 'cpufreq_set_all_frequencies {}'.format(freq), + as_root=True) @asyn.asyncf - async def get_all_frequencies(self): + async def get_all_frequencies(self) -> Dict[str, str]: """ Get the current frequency for all the (online) CPUs """ # pylint: disable=protected-access - output = await self.target._execute_util.asyn( - 'cpufreq_get_all_frequencies', as_root=True) - frequencies = {} + output: str = await self.target._execute_util.asyn( + 'cpufreq_get_all_frequencies', as_root=True) + frequencies: Dict[str, str] = {} for x in output.splitlines(): kv = x.split(' ') if kv[0] == '': @@ -552,7 +609,7 @@ async def get_all_frequencies(self): return frequencies @asyn.asyncf - async def set_all_governors(self, governor): + async def set_all_governors(self, governor: str) -> None: """ Set the specified governor for all the (online) CPUs """ @@ -563,10 +620,10 @@ async def set_all_governors(self, governor): as_root=True) except TargetStableError as e: if ("echo: I/O error" in str(e) or - "write error: Invalid argument" in str(e)): + "write error: Invalid argument" in str(e)): - cpus_unsupported = [c for c in await self.target.list_online_cpus.asyn() - if governor not in await self.list_governors.asyn(c)] + cpus_unsupported: List[int] = [c for c in await self.target.list_online_cpus.asyn() + if governor not in await self.list_governors.asyn(c)] raise TargetStableError("Governor {} unsupported for CPUs {}".format( governor, cpus_unsupported)) else: @@ -579,7 +636,7 @@ async def get_all_governors(self): """ # pylint: disable=protected-access output = await self.target._execute_util.asyn( - 'cpufreq_get_all_governors', as_root=True) + 'cpufreq_get_all_governors', as_root=True) governors = {} for x in output.splitlines(): kv = x.split(' ') @@ -597,7 +654,7 @@ async def trace_frequencies(self): return await self.target._execute_util.asyn('cpufreq_trace_all_frequencies', as_root=True) @asyn.asyncf - async def get_affected_cpus(self, cpu): + async def get_affected_cpus(self, cpu: Union[str, int]) -> List[int]: """ Get the online CPUs that share a frequency domain with the given CPU """ @@ -611,7 +668,7 @@ async def get_affected_cpus(self, cpu): @asyn.asyncf @asyn.memoized_method - async def get_related_cpus(self, cpu): + async def get_related_cpus(self, cpu: Union[str, int]) -> List[int]: """ Get the CPUs that share a frequency domain with the given CPU """ @@ -620,11 +677,11 @@ async def get_related_cpus(self, cpu): sysfile = '/sys/devices/system/cpu/{}/cpufreq/related_cpus'.format(cpu) - return [int(c) for c in (await self.target.read_value.asyn(sysfile)).split()] + return [int(c) for c in cast(str, (await self.target.read_value.asyn(sysfile))).split()] @asyn.asyncf @asyn.memoized_method - async def get_driver(self, cpu): + async def get_driver(self, cpu: Union[str, int]) -> str: """ Get the name of the driver used by this cpufreq policy. """ @@ -633,16 +690,16 @@ async def get_driver(self, cpu): sysfile = '/sys/devices/system/cpu/{}/cpufreq/scaling_driver'.format(cpu) - return (await self.target.read_value.asyn(sysfile)).strip() + return cast(str, (await self.target.read_value.asyn(sysfile))).strip() @asyn.asyncf - async def iter_domains(self): + async def iter_domains(self) -> AsyncGenerator[Set[int], None]: """ Iterate over the frequency domains in the system """ cpus = set(range(self.target.number_of_cpus)) while cpus: cpu = next(iter(cpus)) # pylint: disable=stop-iteration-return - domain = await self.target.cpufreq.get_related_cpus.asyn(cpu) + domain: Set[int] = await cast(CpufreqModule, self.target.cpufreq).get_related_cpus.asyn(cpu) yield domain cpus = cpus.difference(domain) diff --git a/devlib/module/cpuidle.py b/devlib/module/cpuidle.py index a7d0fef64..bfed7d2c5 100644 --- a/devlib/module/cpuidle.py +++ b/devlib/module/cpuidle.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # limitations under the License. # # pylint: disable=attribute-defined-outside-init -from past.builtins import basestring from operator import attrgetter from pprint import pformat @@ -23,24 +22,27 @@ from devlib.utils.types import integer, boolean from devlib.utils.misc import memoized import devlib.utils.asyn as asyn +from typing import Optional, TYPE_CHECKING, Union, Dict, List +if TYPE_CHECKING: + from devlib.target import Target, Node class CpuidleState(object): @property - def usage(self): + def usage(self) -> int: return integer(self.get('usage')) @property - def time(self): + def time(self) -> int: return integer(self.get('time')) @property - def is_enabled(self): + def is_enabled(self) -> bool: return not boolean(self.get('disable')) @property - def ordinal(self): + def ordinal(self) -> int: i = len(self.id) while self.id[i - 1].isdigit(): i -= 1 @@ -48,7 +50,8 @@ def ordinal(self): raise ValueError('invalid idle state name: "{}"'.format(self.id)) return int(self.id[i:]) - def __init__(self, target, index, path, name, desc, power, latency, residency): + def __init__(self, target: 'Target', index: int, path: str, name: str, + desc: str, power: int, latency: int, residency: Optional[int]): self.target = target self.index = index self.path = path @@ -57,31 +60,43 @@ def __init__(self, target, index, path, name, desc, power, latency, residency): self.power = power self.latency = latency self.residency = residency - self.id = self.target.path.basename(self.path) - self.cpu = self.target.path.basename(self.target.path.dirname(path)) + self.id: str = self.target.path.basename(self.path) if self.target.path else '' + self.cpu: str = self.target.path.basename(self.target.path.dirname(path)) if self.target.path else '' @asyn.asyncf - async def enable(self): + async def enable(self) -> None: + """ + enable idle state + """ await self.set.asyn('disable', 0) @asyn.asyncf - async def disable(self): + async def disable(self) -> None: + """ + disable idle state + """ await self.set.asyn('disable', 1) @asyn.asyncf - async def get(self, prop): - property_path = self.target.path.join(self.path, prop) + async def get(self, prop: str) -> str: + """ + get the property + """ + property_path = self.target.path.join(self.path, prop) if self.target.path else '' return await self.target.read_value.asyn(property_path) @asyn.asyncf - async def set(self, prop, value): - property_path = self.target.path.join(self.path, prop) + async def set(self, prop: str, value: str) -> None: + """ + set the property + """ + property_path = self.target.path.join(self.path, prop) if self.target.path else '' await self.target.write_value.asyn(property_path, value) def __eq__(self, other): if isinstance(other, CpuidleState): return (self.name == other.name) and (self.desc == other.desc) - elif isinstance(other, basestring): + elif isinstance(other, str): return (self.name == other) or (self.desc == other) else: return False @@ -96,19 +111,23 @@ def __str__(self): class Cpuidle(Module): - + """ + ``cpuidle`` is the kernel subsystem for managing CPU low power (idle) states. + """ name = 'cpuidle' root_path = '/sys/devices/system/cpu/cpuidle' @staticmethod @asyn.asyncf - async def probe(target): + async def probe(target: 'Target') -> bool: return await target.file_exists.asyn(Cpuidle.root_path) - def __init__(self, target): + def __init__(self, target: 'Target'): super(Cpuidle, self).__init__(target) - basepath = '/sys/devices/system/cpu/' + basepath: str = '/sys/devices/system/cpu/' + # FIXME - annotating the values_tree based on read_tree_values return type is causing errors due to recursive + # definition of the Node type. leaving it out for now values_tree = self.target.read_tree_values(basepath, depth=4, check_exit_code=False) self._states = { @@ -118,7 +137,7 @@ def __init__(self, target): self.target, # state_name is formatted as "state42" index=int(state_name[len('state'):]), - path=self.target.path.join(basepath, cpu_name, 'cpuidle', state_name), + path=self.target.path.join(basepath, cpu_name, 'cpuidle', state_name) if self.target.path else '', name=state_node['name'], desc=state_node['desc'], power=int(state_node['power']), @@ -137,12 +156,18 @@ def __init__(self, target): self.logger.debug('Adding cpuidle states:\n{}'.format(pformat(self._states))) - def get_states(self, cpu=0): + def get_states(self, cpu: Union[int, str] = 0) -> List[CpuidleState]: + """ + get the cpu idle states + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) return self._states.get(cpu, []) - def get_state(self, state, cpu=0): + def get_state(self, state: Union[str, int], cpu: Union[str, int] = 0) -> CpuidleState: + """ + get the specific cpuidle state values + """ if isinstance(state, int): try: return self.get_states(cpu)[state] @@ -155,29 +180,41 @@ def get_state(self, state, cpu=0): raise ValueError('Cpuidle state {} does not exist'.format(state)) @asyn.asyncf - async def enable(self, state, cpu=0): + async def enable(self, state: Union[str, int], cpu: Union[str, int] = 0) -> None: + """ + enable the specific cpu idle state + """ await self.get_state(state, cpu).enable.asyn() @asyn.asyncf - async def disable(self, state, cpu=0): + async def disable(self, state: Union[str, int], cpu: Union[str, int] = 0) -> None: + """ + disable the specific cpu idle state + """ await self.get_state(state, cpu).disable.asyn() @asyn.asyncf - async def enable_all(self, cpu=0): + async def enable_all(self, cpu: Union[str, int] = 0) -> None: + """ + enable all the cpu idle states + """ await self.target.async_manager.concurrently( state.enable.asyn() for state in self.get_states(cpu) ) @asyn.asyncf - async def disable_all(self, cpu=0): + async def disable_all(self, cpu: Union[str, int] = 0) -> None: + """ + disable all cpu idle states + """ await self.target.async_manager.concurrently( state.disable.asyn() for state in self.get_states(cpu) ) @asyn.asyncf - async def perturb_cpus(self): + async def perturb_cpus(self) -> None: """ Momentarily wake each CPU. Ensures cpu_idle events in trace file. """ @@ -185,25 +222,30 @@ async def perturb_cpus(self): await self.target._execute_util.asyn('cpuidle_wake_all_cpus') @asyn.asyncf - async def get_driver(self): - return await self.target.read_value.asyn(self.target.path.join(self.root_path, 'current_driver')) + async def get_driver(self) -> Optional[str]: + """ + get the current driver of idle states + """ + if self.target.path: + return await self.target.read_value.asyn(self.target.path.join(self.root_path, 'current_driver')) + return None @memoized - def list_governors(self): + def list_governors(self) -> List[str]: """Returns a list of supported idle governors.""" - sysfile = self.target.path.join(self.root_path, 'available_governors') - output = self.target.read_value(sysfile) + sysfile: str = self.target.path.join(self.root_path, 'available_governors') if self.target.path else '' + output: str = self.target.read_value(sysfile) return output.strip().split() @asyn.asyncf - async def get_governor(self): + async def get_governor(self) -> str: """Returns the currently selected idle governor.""" - path = self.target.path.join(self.root_path, 'current_governor_ro') + path = self.target.path.join(self.root_path, 'current_governor_ro') if self.target.path else '' if not await self.target.file_exists.asyn(path): - path = self.target.path.join(self.root_path, 'current_governor') + path = self.target.path.join(self.root_path, 'current_governor') if self.target.path else '' return await self.target.read_value.asyn(path) - def set_governor(self, governor): + def set_governor(self, governor: str) -> None: """ Set the idle governor for the system. @@ -213,8 +255,8 @@ def set_governor(self, governor): :raises TargetStableError if governor is not supported by the CPU, or if, for some reason, the governor could not be set. """ - supported = self.list_governors() + supported: List[str] = self.list_governors() if governor not in supported: raise TargetStableError('Governor {} not supported'.format(governor)) - sysfile = self.target.path.join(self.root_path, 'current_governor') + sysfile: str = self.target.path.join(self.root_path, 'current_governor') if self.target.path else '' self.target.write_value(sysfile, governor) diff --git a/devlib/module/devfreq.py b/devlib/module/devfreq.py index 00c3154c8..0fd75e7ac 100644 --- a/devlib/module/devfreq.py +++ b/devlib/module/devfreq.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,14 +15,20 @@ from devlib.module import Module from devlib.exception import TargetStableError from devlib.utils.misc import memoized +from typing import TYPE_CHECKING, List, Union, Dict +if TYPE_CHECKING: + from devlib.target import Target -class DevfreqModule(Module): +class DevfreqModule(Module): + """ + The devfreq framework in Linux is used for dynamic voltage and frequency scaling (DVFS) of various devices. + """ name = 'devfreq' @staticmethod - def probe(target): - path = '/sys/class/devfreq/' + def probe(target: 'Target') -> bool: + path: str = '/sys/class/devfreq/' if not target.file_exists(path): return False @@ -33,26 +39,26 @@ def probe(target): return True @memoized - def list_devices(self): + def list_devices(self) -> List[str]: """Returns a list of devfreq devices supported by the target platform.""" sysfile = '/sys/class/devfreq/' return self.target.list_directory(sysfile) @memoized - def list_governors(self, device): + def list_governors(self, device: str) -> List[str]: """Returns a list of governors supported by the device.""" - sysfile = '/sys/class/devfreq/{}/available_governors'.format(device) - output = self.target.read_value(sysfile) + sysfile: str = '/sys/class/devfreq/{}/available_governors'.format(device) + output: str = self.target.read_value(sysfile) return output.strip().split() - def get_governor(self, device): + def get_governor(self, device: Union[str, int]) -> str: """Returns the governor currently set for the specified device.""" if isinstance(device, int): device = 'device{}'.format(device) sysfile = '/sys/class/devfreq/{}/governor'.format(device) return self.target.read_value(sysfile) - def set_governor(self, device, governor): + def set_governor(self, device: str, governor: str) -> None: """ Set the governor for the specified device. @@ -68,25 +74,25 @@ def set_governor(self, device, governor): for some reason, the governor could not be set. """ - supported = self.list_governors(device) + supported: List[str] = self.list_governors(device) if governor not in supported: raise TargetStableError('Governor {} not supported for device {}'.format(governor, device)) - sysfile = '/sys/class/devfreq/{}/governor'.format(device) + sysfile: str = '/sys/class/devfreq/{}/governor'.format(device) self.target.write_value(sysfile, governor) @memoized - def list_frequencies(self, device): + def list_frequencies(self, device: str) -> List[int]: """ Returns a list of frequencies supported by the device or an empty list if could not be found. """ - cmd = 'cat /sys/class/devfreq/{}/available_frequencies'.format(device) - output = self.target.execute(cmd) - available_frequencies = [int(freq) for freq in output.strip().split()] + cmd: str = 'cat /sys/class/devfreq/{}/available_frequencies'.format(device) + output: str = self.target.execute(cmd) + available_frequencies: List[int] = [int(freq) for freq in output.strip().split()] return available_frequencies - def get_min_frequency(self, device): + def get_min_frequency(self, device: str) -> int: """ Returns the min frequency currently set for the specified device. @@ -100,7 +106,7 @@ def get_min_frequency(self, device): sysfile = '/sys/class/devfreq/{}/min_freq'.format(device) return self.target.read_int(sysfile) - def set_min_frequency(self, device, frequency, exact=True): + def set_min_frequency(self, device: str, frequency: Union[int, str], exact: bool = True) -> None: """ Sets the minimum value for device frequency. Actual frequency will depend on the thermal governor used and may vary during execution. The @@ -117,19 +123,19 @@ def set_min_frequency(self, device, frequency, exact=True): :raises: ValueError if ``frequency`` is not an integer. """ - available_frequencies = self.list_frequencies(device) + available_frequencies: List[int] = self.list_frequencies(device) try: value = int(frequency) if exact and available_frequencies and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(device, - value, - available_frequencies)) - sysfile = '/sys/class/devfreq/{}/min_freq'.format(device) + value, + available_frequencies)) + sysfile: str = '/sys/class/devfreq/{}/min_freq'.format(device) self.target.write_value(sysfile, value) except ValueError: raise ValueError('Frequency must be an integer; got: "{}"'.format(frequency)) - def get_frequency(self, device): + def get_frequency(self, device: str) -> int: """ Returns the current frequency currently set for the specified device. @@ -140,10 +146,10 @@ def get_frequency(self, device): :raises: TargetStableError if for some reason the frequency could not be read. """ - sysfile = '/sys/class/devfreq/{}/cur_freq'.format(device) + sysfile: str = '/sys/class/devfreq/{}/cur_freq'.format(device) return self.target.read_int(sysfile) - def get_max_frequency(self, device): + def get_max_frequency(self, device: str) -> int: """ Returns the max frequency currently set for the specified device. @@ -153,10 +159,10 @@ def get_max_frequency(self, device): :raises: TargetStableError if for some reason the frequency could not be read. """ - sysfile = '/sys/class/devfreq/{}/max_freq'.format(device) + sysfile: str = '/sys/class/devfreq/{}/max_freq'.format(device) return self.target.read_int(sysfile) - def set_max_frequency(self, device, frequency, exact=True): + def set_max_frequency(self, device: str, frequency: Union[int, str], exact: bool = True) -> None: """ Sets the maximum value for device frequency. Actual frequency will depend on the Governor used and may vary during execution. The value @@ -173,7 +179,7 @@ def set_max_frequency(self, device, frequency, exact=True): :raises: ValueError if ``frequency`` is not an integer. """ - available_frequencies = self.list_frequencies(device) + available_frequencies: List[int] = self.list_frequencies(device) try: value = int(frequency) except ValueError: @@ -181,12 +187,12 @@ def set_max_frequency(self, device, frequency, exact=True): if exact and value not in available_frequencies: raise TargetStableError('Can\'t set {} frequency to {}\nmust be in {}'.format(device, - value, - available_frequencies)) - sysfile = '/sys/class/devfreq/{}/max_freq'.format(device) + value, + available_frequencies)) + sysfile: str = '/sys/class/devfreq/{}/max_freq'.format(device) self.target.write_value(sysfile, value) - def set_governor_for_devices(self, devices, governor): + def set_governor_for_devices(self, devices: List[str], governor: str) -> None: """ Set the governor for the specified list of devices. @@ -195,7 +201,7 @@ def set_governor_for_devices(self, devices, governor): for device in devices: self.set_governor(device, governor) - def set_all_governors(self, governor): + def set_all_governors(self, governor: str) -> None: """ Set the specified governor for all the (available) devices """ @@ -204,22 +210,22 @@ def set_all_governors(self, governor): 'devfreq_set_all_governors {}'.format(governor), as_root=True) except TargetStableError as e: if ("echo: I/O error" in str(e) or - "write error: Invalid argument" in str(e)): + "write error: Invalid argument" in str(e)): - devs_unsupported = [d for d in self.target.list_devices() - if governor not in self.list_governors(d)] + devs_unsupported: List[str] = [d for d in self.list_devices() + if governor not in self.list_governors(d)] raise TargetStableError("Governor {} unsupported for devices {}".format( governor, devs_unsupported)) else: raise - def get_all_governors(self): + def get_all_governors(self) -> Dict[str, str]: """ Get the current governor for all the (online) CPUs """ - output = self.target._execute_util( # pylint: disable=protected-access - 'devfreq_get_all_governors', as_root=True) - governors = {} + output: str = self.target._execute_util( # pylint: disable=protected-access + 'devfreq_get_all_governors', as_root=True) + governors: Dict[str, str] = {} for x in output.splitlines(): kv = x.split(' ') if kv[0] == '': @@ -227,7 +233,7 @@ def get_all_governors(self): governors[kv[0]] = kv[1] return governors - def set_frequency_for_devices(self, devices, freq, exact=False): + def set_frequency_for_devices(self, devices: List[str], freq: Union[int, str], exact: bool = False) -> None: """ Set the frequency for the specified list of devices. @@ -237,21 +243,21 @@ def set_frequency_for_devices(self, devices, freq, exact=False): self.set_max_frequency(device, freq, exact) self.set_min_frequency(device, freq, exact) - def set_all_frequencies(self, freq): + def set_all_frequencies(self, freq: Union[int, str]) -> None: """ Set the specified (minimum) frequency for all the (available) devices """ return self.target._execute_util( # pylint: disable=protected-access - 'devfreq_set_all_frequencies {}'.format(freq), - as_root=True) + 'devfreq_set_all_frequencies {}'.format(freq), + as_root=True) - def get_all_frequencies(self): + def get_all_frequencies(self) -> Dict[str, str]: """ Get the current frequency for all the (available) devices """ - output = self.target._execute_util( # pylint: disable=protected-access - 'devfreq_get_all_frequencies', as_root=True) - frequencies = {} + output: str = self.target._execute_util( # pylint: disable=protected-access + 'devfreq_get_all_frequencies', as_root=True) + frequencies: Dict[str, str] = {} for x in output.splitlines(): kv = x.split(' ') if kv[0] == '': diff --git a/devlib/module/gpufreq.py b/devlib/module/gpufreq.py index 9f0a9529d..b8be14f4d 100644 --- a/devlib/module/gpufreq.py +++ b/devlib/module/gpufreq.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,41 +31,50 @@ from devlib.module import Module from devlib.exception import TargetStableError from devlib.utils.misc import memoized +from typing import TYPE_CHECKING, List +if TYPE_CHECKING: + from devlib.target import Target -class GpufreqModule(Module): +class GpufreqModule(Module): + """ + module that handles gpu frequency scaling + """ name = 'gpufreq' path = '' - def __init__(self, target): + def __init__(self, target: 'Target'): super(GpufreqModule, self).__init__(target) - frequencies_str = self.target.read_value("/sys/kernel/gpu/gpu_freq_table") - self.frequencies = list(map(int, frequencies_str.split(" "))) + frequencies_str: str = self.target.read_value("/sys/kernel/gpu/gpu_freq_table") + self.frequencies: List[int] = list(map(int, frequencies_str.split(" "))) self.frequencies.sort() - self.governors = self.target.read_value("/sys/kernel/gpu/gpu_available_governor").split(" ") + self.governors: List[str] = self.target.read_value("/sys/kernel/gpu/gpu_available_governor").split(" ") @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: # kgsl/Adreno - probe_path = '/sys/kernel/gpu/' + probe_path: str = '/sys/kernel/gpu/' if target.file_exists(probe_path): - model = target.read_value(probe_path + "gpu_model") + model: str = target.read_value(probe_path + "gpu_model") if re.search('adreno', model, re.IGNORECASE): return True return False - def set_governor(self, governor): + def set_governor(self, governor: str) -> None: + """ + set the governor to the gpu + """ if governor not in self.governors: raise TargetStableError('Governor {} not supported for gpu'.format(governor)) self.target.write_value("/sys/kernel/gpu/gpu_governor", governor) - def get_frequencies(self): + def get_frequencies(self) -> List[int]: """ Returns the list of frequencies that the GPU can have """ return self.frequencies - def get_current_frequency(self): + def get_current_frequency(self) -> int: """ Returns the current frequency currently set for the GPU. @@ -79,7 +88,7 @@ def get_current_frequency(self): return int(self.target.read_value("/sys/kernel/gpu/gpu_clock")) @memoized - def get_model_name(self): + def get_model_name(self) -> str: """ Returns the model name reported by the GPU. """ diff --git a/devlib/module/hotplug.py b/devlib/module/hotplug.py index 7d5ea5f64..731770bf6 100644 --- a/devlib/module/hotplug.py +++ b/devlib/module/hotplug.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,32 +15,48 @@ from devlib.module import Module from devlib.exception import TargetTransientError +from typing import TYPE_CHECKING, Dict, cast, Union, List +if TYPE_CHECKING: + from devlib.target import Target class HotplugModule(Module): - + """ + Kernel ``hotplug`` subsystem allows offlining ("removing") cores from the + system, and onlining them back in. The ``devlib`` module exposes a simple + interface to this subsystem + """ name = 'hotplug' base_path = '/sys/devices/system/cpu' @classmethod - def probe(cls, target): # pylint: disable=arguments-differ + def probe(cls, target: 'Target') -> bool: # pylint: disable=arguments-differ # If a system has just 1 CPU, it makes not sense to hotplug it. # If a system has more than 1 CPU, CPU0 could be configured to be not # hotpluggable. Thus, check for hotplug support by looking at CPU1 path = cls._cpu_path(target, 1) - return target.file_exists(path) and target.is_rooted + return cast(bool, target.file_exists(path) and target.is_rooted) @classmethod - def _cpu_path(cls, target, cpu): + def _cpu_path(cls, target: 'Target', cpu: Union[int, str]) -> str: + """ + get path to cpu online + """ if isinstance(cpu, int): cpu = 'cpu{}'.format(cpu) - return target.path.join(cls.base_path, cpu, 'online') + return target.path.join(cls.base_path, cpu, 'online') if target.path else '' - def list_hotpluggable_cpus(self): + def list_hotpluggable_cpus(self) -> List[int]: + """ + get the list of hotpluggable cpus + """ return [cpu for cpu in range(self.target.number_of_cpus) if self.target.file_exists(self._cpu_path(self.target, cpu))] - def online_all(self, verify=True): + def online_all(self, verify: bool = True) -> None: + """ + bring all cpus online + """ self.target._execute_util('hotplug_online_all', # pylint: disable=protected-access as_root=self.target.is_rooted) if verify: @@ -48,37 +64,60 @@ def online_all(self, verify=True): if offline: raise TargetTransientError('The following CPUs failed to come back online: {}'.format(offline)) - def online(self, *args): + def online(self, *args) -> None: + """ + bring online specific cpus + """ for cpu in args: self.hotplug(cpu, online=True) - def offline(self, *args): + def offline(self, *args) -> None: + """ + take specific cpus offline + """ for cpu in args: self.hotplug(cpu, online=False) - def hotplug(self, cpu, online): + def hotplug(self, cpu: Union[int, str], online: bool) -> None: + """ + bring cpus online or offline + """ path = self._cpu_path(self.target, cpu) if not self.target.file_exists(path): return value = 1 if online else 0 self.target.write_value(path, value) - def _get_path(self, path): + def _get_path(self, path: str) -> str: + """ + get path to cpu directory + """ return self.target.path.join(self.base_path, - path) + path) if self.target.path else '' - def fail(self, cpu, state): + def fail(self, cpu: Union[str, int], state: str) -> None: + """ + set fail status for cpu hotplug + """ path = self._get_path('cpu{}/hotplug/fail'.format(cpu)) return self.target.write_value(path, state) - def get_state(self, cpu): + def get_state(self, cpu: Union[int, str]) -> str: + """ + get the hotplug state of the cpu + """ path = self._get_path('cpu{}/hotplug/state'.format(cpu)) return self.target.read_value(path) - def get_states(self): - path = self._get_path('hotplug/states') - states_string = self.target.read_value(path) - return dict( - map(str.strip, string.split(':', 1)) - for string in states_string.strip().splitlines() - ) + def get_states(self) -> Dict[str, str]: + """ + get the possible values for hotplug states + """ + path: str = self._get_path('hotplug/states') + states_string: str = self.target.read_value(path) + return { + key.strip(): value.strip() + for line in states_string.strip().splitlines() + if ':' in line + for key, value in [line.split(':', 1)] + } diff --git a/devlib/module/hwmon.py b/devlib/module/hwmon.py index 3ecc55ca9..ccd32f2a7 100644 --- a/devlib/module/hwmon.py +++ b/devlib/module/hwmon.py @@ -1,4 +1,4 @@ -# Copyright 2015-2018 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,6 +18,10 @@ from devlib import TargetStableError from devlib.module import Module from devlib.utils.types import integer +from typing import (TYPE_CHECKING, Set, Union, cast, DefaultDict, + Dict, List, Match, Optional) +if TYPE_CHECKING: + from devlib.target import Target HWMON_ROOT = '/sys/class/hwmon' @@ -25,39 +29,54 @@ class HwmonSensor(object): - - def __init__(self, device, path, kind, number): + """ + hardware monitoring sensor + """ + def __init__(self, device: 'HwmonDevice', path: str, + kind: str, number: int): self.device = device self.path = path self.kind = kind self.number = number - self.target = self.device.target - self.name = '{}/{}{}'.format(self.device.name, self.kind, self.number) + self.target: 'Target' = self.device.target + self.name: str = '{}/{}{}'.format(self.device.name, self.kind, self.number) self.label = self.name - self.items = set() + self.items: Set[str] = set() - def add(self, item): + def add(self, item: str) -> None: + """ + add item to items set + """ self.items.add(item) if item == 'label': - self.label = self.get('label') + self.label = cast(str, self.get('label')) - def get(self, item): + def get(self, item: str) -> Union[int, str]: + """ + get the value of the item + """ path = self.get_file(item) value = self.target.read_value(path) try: - return integer(value) + return integer(value) except (TypeError, ValueError): return value - def set(self, item, value): - path = self.get_file(item) + def set(self, item: str, value: Union[int, str]) -> None: + """ + set value to the item + """ + path: str = self.get_file(item) self.target.write_value(path, value) - def get_file(self, item): + def get_file(self, item: str) -> str: + """ + get file path + """ if item not in self.items: raise ValueError('item "{}" does not exist for {}'.format(item, self.name)) filename = '{}{}_{}'.format(self.kind, self.number, item) - return self.target.path.join(self.path, filename) + return self.target.path.join(self.path, filename) if self.target.path else '' def __str__(self): if self.name != self.label: @@ -70,34 +89,43 @@ def __str__(self): class HwmonDevice(object): - + """ + Hardware monitor device + """ @property - def sensors(self): - all_sensors = [] + def sensors(self) -> List[HwmonSensor]: + """ + get all the hardware monitoring sensors + """ + all_sensors: List[HwmonSensor] = [] for sensors_of_kind in self._sensors.values(): all_sensors.extend(list(sensors_of_kind.values())) return all_sensors - def __init__(self, target, path, name, fields): + def __init__(self, target: 'Target', path: str, name: str, fields: List[str]): self.target = target self.path = path self.name = name - self._sensors = defaultdict(dict) + self._sensors: DefaultDict[str, Dict[int, HwmonSensor]] = defaultdict(dict) path = self.path - if not path.endswith(self.target.path.sep): - path += self.target.path.sep - for entry in fields: - match = HWMON_FILE_REGEX.search(entry) - if match: - kind = match.group('kind') - number = int(match.group('number')) - item = match.group('item') - if number not in self._sensors[kind]: - sensor = HwmonSensor(self, self.path, kind, number) - self._sensors[kind][number] = sensor - self._sensors[kind][number].add(item) - - def get(self, kind, number=None): + if self.target.path: + if not path.endswith(self.target.path.sep): + path += self.target.path.sep + for entry in fields: + match: Optional[Match[str]] = HWMON_FILE_REGEX.search(entry) + if match: + kind: str = match.group('kind') + number: int = int(match.group('number')) + item: str = match.group('item') + if number not in self._sensors[kind]: + sensor = HwmonSensor(self, self.path, kind, number) + self._sensors[kind][number] = sensor + self._sensors[kind][number].add(item) + + def get(self, kind: str, number: Optional[int] = None) -> Union[List[HwmonSensor], HwmonSensor, None]: + """ + get the hardware monitor sensors of the specified kind + """ if number is None: return [s for _, s in sorted(self._sensors[kind].items(), key=lambda x: x[0])] @@ -111,11 +139,15 @@ def __str__(self): class HwmonModule(Module): - + """ + The hwmon (hardware monitoring) subsystem in Linux is used to monitor various hardware parameters + such as temperature, voltage, and fan speed. This subsystem provides a standardized interface for + accessing sensor data from different hardware components. + """ name = 'hwmon' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: try: target.list_directory(HWMON_ROOT, as_root=target.is_rooted) except TargetStableError: @@ -124,23 +156,29 @@ def probe(target): return True @property - def sensors(self): - all_sensors = [] + def sensors(self) -> List[HwmonSensor]: + """ + hardware monitoring sensors in all hardware monitoring devices + """ + all_sensors: List[HwmonSensor] = [] for device in self.devices: all_sensors.extend(device.sensors) return all_sensors - def __init__(self, target): + def __init__(self, target: 'Target'): super(HwmonModule, self).__init__(target) - self.root = HWMON_ROOT - self.devices = [] + self.root: str = HWMON_ROOT + self.devices: List[HwmonDevice] = [] self.scan() - def scan(self): + def scan(self) -> None: + """ + scan and add devices to the hardware mpnitor module + """ values_tree = self.target.read_tree_values(self.root, depth=3, tar=True) for entry_id, fields in values_tree.items(): - path = self.target.path.join(self.root, entry_id) - name = fields.pop('name', None) + path: str = self.target.path.join(self.root, entry_id) if self.target.path else '' + name: Optional[str] = fields.pop('name', None) if name is None: continue self.logger.debug('Adding device {}'.format(name)) diff --git a/devlib/module/sched.py b/devlib/module/sched.py index e1d526dfd..248fcda09 100644 --- a/devlib/module/sched.py +++ b/devlib/module/sched.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,12 +16,16 @@ import logging import re -from past.builtins import basestring - from devlib.module import Module from devlib.utils.misc import memoized from devlib.utils.types import boolean from devlib.exception import TargetStableError +from typing import (TYPE_CHECKING, cast, Match, Dict, + Any, List, Pattern, Union, Optional, + Tuple, Set) +if TYPE_CHECKING: + from devlib.target import Target + class SchedProcFSNode(object): """ @@ -49,30 +53,33 @@ class SchedProcFSNode(object): MC """ - _re_procfs_node = re.compile(r"(?P.*\D)(?P\d+)$") + _re_procfs_node: Pattern[str] = re.compile(r"(?P.*\D)(?P\d+)$") - PACKABLE_ENTRIES = [ + PACKABLE_ENTRIES: List[str] = [ "cpu", "domain", "group" ] @staticmethod - def _ends_with_digits(node): - if not isinstance(node, basestring): + def _ends_with_digits(node: str) -> bool: + """ + returns True if the node ends with digits + """ + if not isinstance(node, str): return False - return re.search(SchedProcFSNode._re_procfs_node, node) != None + return re.search(SchedProcFSNode._re_procfs_node, node) is not None @staticmethod - def _node_digits(node): + def _node_digits(node: str) -> int: """ :returns: The ending digits of the procfs node """ - return int(re.search(SchedProcFSNode._re_procfs_node, node).group("digits")) + return int(cast(Match, re.search(SchedProcFSNode._re_procfs_node, node)).group("digits")) @staticmethod - def _node_name(node): + def _node_name(node: str) -> str: """ :returns: The name of the procfs node """ @@ -83,7 +90,7 @@ def _node_name(node): return node @classmethod - def _packable(cls, node): + def _packable(cls, node: str) -> bool: """ :returns: Whether it makes sense to pack a node into a common entry """ @@ -91,14 +98,18 @@ def _packable(cls, node): SchedProcFSNode._node_name(node) in cls.PACKABLE_ENTRIES) @staticmethod - def _build_directory(node_name, node_data): + def _build_directory(node_name: str, + node_data: Any) -> Union['SchedDomain', 'SchedProcFSNode']: + """ + create a new sched domain or a new procfs node + """ if node_name.startswith("domain"): return SchedDomain(node_data) else: return SchedProcFSNode(node_data) @staticmethod - def _build_entry(node_data): + def _build_entry(node_data: Any) -> Union[int, Any]: value = node_data # Most nodes just contain numerical data, try to convert @@ -110,32 +121,33 @@ def _build_entry(node_data): return value @staticmethod - def _build_node(node_name, node_data): + def _build_node(node_name: str, node_data: Any) -> Union['SchedDomain', 'SchedProcFSNode', + int, Any]: if isinstance(node_data, dict): return SchedProcFSNode._build_directory(node_name, node_data) else: return SchedProcFSNode._build_entry(node_data) - def __getattr__(self, name): + def __getattr__(self, name: str): return self._dyn_attrs[name] - def __init__(self, nodes): + def __init__(self, nodes: Dict[str, 'SchedProcFSNode']): self.procfs = nodes # First, reduce the procs fields by packing them if possible # Find which entries can be packed into a common entry - packables = { - node : SchedProcFSNode._node_name(node) + "s" - for node in list(nodes.keys()) if SchedProcFSNode._packable(node) + packables: Dict[str, str] = { + node: SchedProcFSNode._node_name(node) + "s" + for node in list(cast(SchedProcFSNode, nodes).keys()) if SchedProcFSNode._packable(node) } - self._dyn_attrs = {} + self._dyn_attrs: Dict[str, Any] = {} for dest in set(packables.values()): self._dyn_attrs[dest] = {} # Pack common entries for key, dest in packables.items(): - i = SchedProcFSNode._node_digits(key) + i: int = SchedProcFSNode._node_digits(key) self._dyn_attrs[dest][i] = self._build_node(key, nodes[key]) # Build the other nodes @@ -153,13 +165,15 @@ class _SchedDomainFlag: exposed. """ - _INSTANCES = {} + _INSTANCES: Dict['_SchedDomainFlag', '_SchedDomainFlag'] = {} """ Dictionary storing the instances so that they can be compared with ``is`` operator. """ + name: str + _value: Optional[int] - def __new__(cls, name, value, doc=None): + def __new__(cls, name: str, value: Optional[int], doc: Optional[str] = None): self = super().__new__(cls) self.name = name self._value = value @@ -175,7 +189,7 @@ def __hash__(self): return hash((self.name, self._value)) @property - def value(self): + def value(self) -> Optional[int]: value = self._value if value is None: raise AttributeError('The kernel does not expose the sched domain flag values') @@ -183,14 +197,14 @@ def value(self): return value @staticmethod - def check_version(target, logger): + def check_version(target: 'Target', logger: logging.Logger) -> None: """ Check the target and see if its kernel version matches our view of the world """ - parts = target.kernel_version.parts + parts: Tuple[Optional[int], Optional[int], Optional[int]] = target.kernel_version.parts # Checked to be valid from v4.4 # Not saved as a class attribute else it'll be converted to an enum - ref_parts = (4, 4, 0) + ref_parts: Tuple[int, int, int] = (4, 4, 0) if parts < ref_parts: logger.warn( "Sched domain flags are defined for kernels v{} and up, " @@ -212,7 +226,7 @@ class _SchedDomainFlagMeta(type): backward compatibility. """ @property - def _flags(self): + def _flags(self) -> List[Any]: return [ attr for name, attr in self.__dict__.items() @@ -280,10 +294,10 @@ class SchedDomain(SchedProcFSNode): """ Represents a sched domain as seen through procfs """ - def __init__(self, nodes): + def __init__(self, nodes: Dict[str, SchedProcFSNode]): super().__init__(nodes) - flags = self.flags + flags: Union[Set[_SchedDomainFlag], str] = self.flags # Recent kernels now have a space-separated list of flags instead of a # packed bitfield if isinstance(flags, str): @@ -292,8 +306,8 @@ def __init__(self, nodes): for name in flags.split() } else: - def has_flag(flags, flag): - return flags & flag.value == flag.value + def has_flag(flags: Set[_SchedDomainFlag], flag: _SchedDomainFlag): + return any(f.value == flag.value for f in flags) flags = { flag @@ -303,69 +317,79 @@ def has_flag(flags, flag): self.flags = flags -def _select_path(target, paths, name): + +def _select_path(target: 'Target', paths: List[str], name: str) -> str: + """ + select existing file path + """ for p in paths: if target.file_exists(p): return p raise TargetStableError('No {} found. Tried: {}'.format(name, ', '.join(paths))) + class SchedProcFSData(SchedProcFSNode): """ Root class for creating & storing SchedProcFSNode instances """ - _read_depth = 6 + _read_depth: int = 6 @classmethod - def get_data_root(cls, target): + def get_data_root(cls, target: 'Target'): # Location differs depending on kernel version paths = ['/sys/kernel/debug/sched/domains/', '/proc/sys/kernel/sched_domain'] return _select_path(target, paths, "sched_domain debug directory") @staticmethod - def available(target): + def available(target: 'Target') -> bool: + """ + check availability of sched domains + """ try: path = SchedProcFSData.get_data_root(target) except TargetStableError: return False - cpus = target.list_directory(path, as_root=target.is_rooted) + cpus: List[str] = target.list_directory(path, as_root=target.is_rooted) if not cpus: return False # Even if we have a CPU entry, it can be empty (e.g. hotplugged out) # Make sure some data is there for cpu in cpus: - if target.file_exists(target.path.join(path, cpu, "domain0", "flags")): + if target.file_exists(target.path.join(path, cpu, "domain0", "flags") if target.path else ''): return True return False - def __init__(self, target, path=None): + def __init__(self, target: 'Target', path: Optional[str] = None): if path is None: path = SchedProcFSData.get_data_root(target) - procfs = target.read_tree_values(path, depth=self._read_depth) + procfs: Dict[str, 'SchedProcFSNode'] = target.read_tree_values(path, depth=self._read_depth) super(SchedProcFSData, self).__init__(procfs) class SchedModule(Module): + """ + scheduler module + """ + name: str = 'sched' - name = 'sched' - - cpu_sysfs_root = '/sys/devices/system/cpu' + cpu_sysfs_root: str = '/sys/devices/system/cpu' @staticmethod - def probe(target): - logger = logging.getLogger(SchedModule.name) + def probe(target: 'Target') -> bool: + logger: logging.Logger = logging.getLogger(SchedModule.name) SchedDomainFlag.check_version(target, logger) # It makes sense to load this module if at least one of those # functionalities is enabled - schedproc = SchedProcFSData.available(target) - debug = SchedModule.target_has_debug(target) - dmips = any([target.file_exists(SchedModule.cpu_dmips_capacity_path(target, cpu)) - for cpu in target.list_online_cpus()]) + schedproc: bool = SchedProcFSData.available(target) + debug: bool = SchedModule.target_has_debug(target) + dmips: bool = any([target.file_exists(SchedModule.cpu_dmips_capacity_path(target, cpu)) + for cpu in target.list_online_cpus()]) logger.info("Scheduler sched_domain procfs entries %s", "found" if schedproc else "not found") @@ -376,16 +400,17 @@ def probe(target): return schedproc or debug or dmips - def __init__(self, target): + def __init__(self, target: 'Target'): super().__init__(target) @classmethod - def get_sched_features_path(cls, target): + def get_sched_features_path(cls, target: 'Target') -> str: # Location differs depending on kernel version - paths = ['/sys/kernel/debug/sched/features', '/sys/kernel/debug/sched_features'] + paths: List[str] = ['/sys/kernel/debug/sched/features', '/sys/kernel/debug/sched_features'] return _select_path(target, paths, "sched_features file") - def get_kernel_attributes(self, matching=None, check_exit_code=True): + def get_kernel_attributes(self, matching: Optional[str] = None, + check_exit_code: bool = True) -> Dict[str, Union[int, bool]]: """ Get the value of scheduler attributes. @@ -406,21 +431,22 @@ def get_kernel_attributes(self, matching=None, check_exit_code=True): command = 'sched_get_kernel_attributes {}'.format( matching if matching else '' ) - output = self.target._execute_util(command, as_root=self.target.is_rooted, - check_exit_code=check_exit_code) - result = {} + output: str = self.target._execute_util(command, as_root=self.target.is_rooted, + check_exit_code=check_exit_code) + result: Dict[str, Union[int, bool]] = {} for entry in output.strip().split('\n'): if ':' not in entry: continue - path, value = entry.strip().split(':', 1) - if value in ['0', '1']: - value = bool(int(value)) - elif value.isdigit(): - value = int(value) + path, value_s = entry.strip().split(':', 1) + if value_s in ['0', '1']: + value: Union[int, bool] = bool(int(value_s)) + elif value_s.isdigit(): + value = int(value_s) result[path] = value return result - def set_kernel_attribute(self, attr, value, verify=True): + def set_kernel_attribute(self, attr: str, value: Union[bool, int, str], + verify: bool = True) -> None: """ Set the value of a scheduler attribute. @@ -434,11 +460,14 @@ def set_kernel_attribute(self, attr, value, verify=True): value = '1' if value else '0' elif isinstance(value, int): value = str(value) - path = '/proc/sys/kernel/sched_' + attr + path: str = '/proc/sys/kernel/sched_' + attr self.target.write_value(path, value, verify) @classmethod - def target_has_debug(cls, target): + def target_has_debug(cls, target: 'Target') -> bool: + """ + True if target has SCHED_DEBUG config set and has sched features + """ if target.config.get('SCHED_DEBUG') != 'y': return False @@ -448,23 +477,23 @@ def target_has_debug(cls, target): except TargetStableError: return False - def get_features(self): + def get_features(self) -> Dict[str, bool]: """ Get the status of each sched feature :returns: a dictionary of features and their "is enabled" status """ - feats = self.target.read_value(self.get_sched_features_path(self.target)) - features = {} + feats: str = self.target.read_value(self.get_sched_features_path(self.target)) + features: Dict[str, bool] = {} for feat in feats.split(): - value = True + value: bool = True if feat.startswith('NO'): feat = feat.replace('NO_', '', 1) value = False features[feat] = value return features - def set_feature(self, feature, enable, verify=True): + def set_feature(self, feature: str, enable: bool, verify: bool = True): """ Set the status of a specified scheduler feature @@ -475,63 +504,63 @@ def set_feature(self, feature, enable, verify=True): :raise RuntimeError: if the specified feature cannot be set """ feature = feature.upper() - feat_value = feature + feat_value: str = feature if not boolean(enable): feat_value = 'NO_' + feat_value self.target.write_value(self.get_sched_features_path(self.target), feat_value, verify=False) if not verify: return - msg = 'Failed to set {}, feature not supported?'.format(feat_value) - features = self.get_features() - feat_value = features.get(feature, not enable) - if feat_value != enable: + msg: str = 'Failed to set {}, feature not supported?'.format(feat_value) + features: Dict[str, bool] = self.get_features() + feat_ = features.get(feature, not enable) + if feat_ != enable: raise RuntimeError(msg) - def get_cpu_sd_info(self, cpu): + def get_cpu_sd_info(self, cpu: int) -> SchedProcFSData: """ :returns: An object view of the sched_domain debug directory of 'cpu' """ path = self.target.path.join( SchedProcFSData.get_data_root(self.target), "cpu{}".format(cpu) - ) + ) if self.target.path else '' return SchedProcFSData(self.target, path) - def get_sd_info(self): + def get_sd_info(self) -> SchedProcFSData: """ :returns: An object view of the entire sched_domain debug directory """ return SchedProcFSData(self.target) - def get_capacity(self, cpu): + def get_capacity(self, cpu: int) -> int: """ :returns: The capacity of 'cpu' """ return self.get_capacities()[cpu] @memoized - def has_em(self, cpu, sd=None): + def has_em(self, cpu: int, sd: Optional[SchedProcFSData] = None) -> bool: """ :returns: Whether energy model data is available for 'cpu' """ if not sd: sd = self.get_cpu_sd_info(cpu) - return sd.procfs["domain0"].get("group0", {}).get("energy", {}).get("cap_states") != None + return sd.procfs["domain0"].get("group0", {}).get("energy", {}).get("cap_states") is not None @classmethod - def cpu_dmips_capacity_path(cls, target, cpu): + def cpu_dmips_capacity_path(cls, target: 'Target', cpu: int): """ :returns: The target sysfs path where the dmips capacity data should be """ return target.path.join( cls.cpu_sysfs_root, - 'cpu{}/cpu_capacity'.format(cpu)) + 'cpu{}/cpu_capacity'.format(cpu)) if target.path else '' @memoized - def has_dmips_capacity(self, cpu): + def has_dmips_capacity(self, cpu: int) -> bool: """ :returns: Whether dmips capacity data is available for 'cpu' """ @@ -540,21 +569,21 @@ def has_dmips_capacity(self, cpu): ) @memoized - def get_em_capacity(self, cpu, sd=None): + def get_em_capacity(self, cpu: int, sd: Optional[SchedProcFSData] = None) -> int: """ :returns: The maximum capacity value exposed by the EAS energy model """ if not sd: sd = self.get_cpu_sd_info(cpu) - cap_states = sd.domains[0].groups[0].energy.cap_states - cap_states_list = cap_states.split('\t') - num_cap_states = sd.domains[0].groups[0].energy.nr_cap_states - max_cap_index = -1 * int(len(cap_states_list) / num_cap_states) + cap_states: str = sd.domains[0].groups[0].energy.cap_states + cap_states_list: List[str] = cap_states.split('\t') + num_cap_states: int = sd.domains[0].groups[0].energy.nr_cap_states + max_cap_index: int = -1 * int(len(cap_states_list) / num_cap_states) return int(cap_states_list[max_cap_index]) @memoized - def get_dmips_capacity(self, cpu): + def get_dmips_capacity(self, cpu: int) -> int: """ :returns: The capacity value generated from the capacity-dmips-mhz DT entry """ @@ -562,7 +591,7 @@ def get_dmips_capacity(self, cpu): self.cpu_dmips_capacity_path(self.target, cpu), int ) - def get_capacities(self, default=None): + def get_capacities(self, default: Optional[int] = None) -> Dict[int, int]: """ :param default: Default capacity value to find if no data is found in procfs @@ -572,41 +601,41 @@ def get_capacities(self, default=None): :raises RuntimeError: Raised when no capacity information is found and 'default' is None """ - cpus = self.target.list_online_cpus() + cpus: List[int] = self.target.list_online_cpus() - capacities = {} + capacities: Dict[int, int] = {} for cpu in cpus: if self.has_dmips_capacity(cpu): capacities[cpu] = self.get_dmips_capacity(cpu) - missing_cpus = set(cpus).difference(capacities.keys()) + missing_cpus: Set[int] = set(cpus).difference(capacities.keys()) if not missing_cpus: return capacities if not SchedProcFSData.available(self.target): - if default != None: - capacities.update({cpu : default for cpu in missing_cpus}) + if default is not None: + capacities.update({cpu: cast(int, default) for cpu in missing_cpus}) return capacities else: raise RuntimeError( 'No capacity data for cpus {}'.format(sorted(missing_cpus))) - sd_info = self.get_sd_info() + sd_info: SchedProcFSData = self.get_sd_info() for cpu in missing_cpus: if self.has_em(cpu, sd_info.cpus[cpu]): capacities[cpu] = self.get_em_capacity(cpu, sd_info.cpus[cpu]) else: - if default != None: - capacities[cpu] = default + if default is not None: + capacities[cpu] = cast(int, default) else: raise RuntimeError('No capacity data for cpu{}'.format(cpu)) return capacities @memoized - def get_hz(self): + def get_hz(self) -> int: """ :returns: The scheduler tick frequency on the target """ - return int(self.target.config.get('CONFIG_HZ', strict=True)) + return int(cast(str, self.target.config.get('CONFIG_HZ', strict=True))) diff --git a/devlib/module/thermal.py b/devlib/module/thermal.py index d23739ea2..1a288c5e4 100644 --- a/devlib/module/thermal.py +++ b/devlib/module/thermal.py @@ -1,4 +1,4 @@ -# Copyright 2015-2018 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,137 +18,174 @@ from devlib.module import Module from devlib.exception import TargetStableCalledProcessError +from typing import (TYPE_CHECKING, Dict, Match, Optional, + Tuple, List) +if TYPE_CHECKING: + from devlib.target import Target + class TripPoint(object): - def __init__(self, zone, _id): + """ + Trip points are predefined temperature thresholds within a thermal zone. When the temperature reaches these points, + specific actions are triggered to manage the system's thermal state. There are typically three types of trip points: + + Active Trip Points: Trigger active cooling mechanisms like fans when the temperature exceeds a certain threshold. + Passive Trip Points: Initiate passive cooling strategies, such as reducing the processor's clock speed, to lower the temperature. + Critical Trip Points: Indicate a critical temperature level that requires immediate action, such as shutting down the system to prevent damage + """ + def __init__(self, zone: 'ThermalZone', _id: str): self._id = _id self.zone = zone - self.temp_node = 'trip_point_' + _id + '_temp' - self.type_node = 'trip_point_' + _id + '_type' + self.temp_node: str = 'trip_point_' + _id + '_temp' + self.type_node: str = 'trip_point_' + _id + '_type' @property - def target(self): + def target(self) -> 'Target': + """ + target of the trip point + """ return self.zone.target @asyn.asyncf - async def get_temperature(self): + async def get_temperature(self) -> int: """Returns the currently configured temperature of the trip point""" - temp_file = self.target.path.join(self.zone.path, self.temp_node) + temp_file: str = self.target.path.join(self.zone.path, self.temp_node) if self.target.path else '' return await self.target.read_int.asyn(temp_file) @asyn.asyncf - async def set_temperature(self, temperature): - temp_file = self.target.path.join(self.zone.path, self.temp_node) + async def set_temperature(self, temperature: int) -> None: + """ + set temperature threshold for the trip point + """ + temp_file: str = self.target.path.join(self.zone.path, self.temp_node) if self.target.path else '' await self.target.write_value.asyn(temp_file, temperature) @asyn.asyncf - async def get_type(self): + async def get_type(self) -> str: """Returns the type of trip point""" - type_file = self.target.path.join(self.zone.path, self.type_node) + type_file: str = self.target.path.join(self.zone.path, self.type_node) if self.target.path else '' return await self.target.read_value.asyn(type_file) + class ThermalZone(object): - def __init__(self, target, root, _id): + """ + A thermal zone is a logical collection of interfaces to temperature sensors, trip points, + thermal property information, and thermal controls. These zones help manage the temperature + of various components within a system, such as CPUs, GPUs, and other hardware. + """ + def __init__(self, target: 'Target', root: str, _id: str): self.target = target self.name = 'thermal_zone' + _id - self.path = target.path.join(root, self.name) - self.trip_points = {} - self.type = self.target.read_value(self.target.path.join(self.path, 'type')) + self.path = target.path.join(root, self.name) if target.path else '' + self.trip_points: Dict[int, TripPoint] = {} + self.type: str = self.target.read_value(self.target.path.join(self.path, 'type') if self.target.path else '') for entry in self.target.list_directory(self.path, as_root=target.is_rooted): - re_match = re.match('^trip_point_([0-9]+)_temp', entry) + re_match: Optional[Match[str]] = re.match('^trip_point_([0-9]+)_temp', entry) if re_match is not None: self._add_trip_point(re_match.group(1)) - def _add_trip_point(self, _id): + def _add_trip_point(self, _id: str) -> None: + """ + add a trip point to the thermal zone + """ self.trip_points[int(_id)] = TripPoint(self, _id) @asyn.asyncf - async def is_enabled(self): + async def is_enabled(self) -> bool: """Returns a boolean representing the 'mode' of the thermal zone""" - value = await self.target.read_value.asyn(self.target.path.join(self.path, 'mode')) + value: str = await self.target.read_value.asyn(self.target.path.join(self.path, 'mode') if self.target.path else '') return value == 'enabled' @asyn.asyncf - async def set_enabled(self, enabled=True): + async def set_enabled(self, enabled: bool = True) -> None: + """ + enable or disable the thermal zone + """ value = 'enabled' if enabled else 'disabled' - await self.target.write_value.asyn(self.target.path.join(self.path, 'mode'), value) + await self.target.write_value.asyn(self.target.path.join(self.path, 'mode') if self.target.path else '', value) @asyn.asyncf - async def get_temperature(self): + async def get_temperature(self) -> int: """Returns the temperature of the thermal zone""" - sysfs_temperature_file = self.target.path.join(self.path, 'temp') + sysfs_temperature_file = self.target.path.join(self.path, 'temp') if self.target.path else '' return await self.target.read_int.asyn(sysfs_temperature_file) @asyn.asyncf - async def get_policy(self): + async def get_policy(self) -> str: """Returns the policy of the thermal zone""" - temp_file = self.target.path.join(self.path, 'policy') + temp_file = self.target.path.join(self.path, 'policy') if self.target.path else '' return await self.target.read_value.asyn(temp_file) @asyn.asyncf - async def set_policy(self, policy): + async def set_policy(self, policy: str) -> None: """ Sets the policy of the thermal zone :params policy: Thermal governor name :type policy: str """ - await self.target.write_value.asyn(self.target.path.join(self.path, 'policy'), policy) + await self.target.write_value.asyn(self.target.path.join(self.path, 'policy') if self.target.path else '', policy) @asyn.asyncf - async def get_offset(self): + async def get_offset(self) -> int: """Returns the temperature offset of the thermal zone""" - offset_file = self.target.path.join(self.path, 'offset') + offset_file: str = self.target.path.join(self.path, 'offset') if self.target.path else '' return await self.target.read_value.asyn(offset_file) @asyn.asyncf - async def set_offset(self, offset): + async def set_offset(self, offset: int) -> None: """ Sets the temperature offset in milli-degrees of the thermal zone :params offset: Temperature offset in milli-degrees - :type policy: int + :type offset: int """ - await self.target.write_value.asyn(self.target.path.join(self.path, 'offset'), policy) + await self.target.write_value.asyn(self.target.path.join(self.path, 'offset') if self.target.path else '', offset) @asyn.asyncf - async def set_emul_temp(self, offset): + async def set_emul_temp(self, offset: int) -> None: """ Sets the emulated temperature in milli-degrees of the thermal zone :params offset: Emulated temperature in milli-degrees - :type policy: int + :type offset: int """ - await self.target.write_value.asyn(self.target.path.join(self.path, 'emul_temp'), policy) + await self.target.write_value.asyn(self.target.path.join(self.path, 'emul_temp') if self.target.path else '', offset) @asyn.asyncf - async def get_available_policies(self): + async def get_available_policies(self) -> str: """Returns the policies available for the thermal zone""" - temp_file = self.target.path.join(self.path, 'available_policies') + temp_file: str = self.target.path.join(self.path, 'available_policies') if self.target.path else '' return await self.target.read_value.asyn(temp_file) + class ThermalModule(Module): + """ + The /sys/class/thermal directory in Linux provides a sysfs interface for thermal management. + This directory contains subdirectories and files that represent thermal zones and cooling devices, + allowing users and applications to monitor and manage system temperatures. + """ name = 'thermal' thermal_root = '/sys/class/thermal' @staticmethod - def probe(target): - + def probe(target: 'Target') -> bool: if target.file_exists(ThermalModule.thermal_root): return True + return False - def __init__(self, target): + def __init__(self, target: 'Target'): super(ThermalModule, self).__init__(target) - self.logger = logging.getLogger(self.name) + self.logger: logging.Logger = logging.getLogger(self.name) self.logger.debug('Initialized [%s] module', self.name) - self.zones = {} - self.cdevs = [] + self.zones: Dict[int, ThermalZone] = {} + self.cdevs: List = [] for entry in target.list_directory(self.thermal_root): - re_match = re.match('^(thermal_zone|cooling_device)([0-9]+)', entry) + re_match: Optional[Match[str]] = re.match('^(thermal_zone|cooling_device)([0-9]+)', entry) if not re_match: self.logger.warning('unknown thermal entry: %s', entry) continue @@ -159,16 +196,16 @@ def __init__(self, target): # TODO pass - def _add_thermal_zone(self, _id): + def _add_thermal_zone(self, _id: str) -> None: self.zones[int(_id)] = ThermalZone(self.target, self.thermal_root, _id) - def disable_all_zones(self): + def disable_all_zones(self) -> None: """Disables all the thermal zones in the target""" for zone in self.zones.values(): zone.set_enabled(False) @asyn.asyncf - async def get_all_temperatures(self, error='raise'): + async def get_all_temperatures(self, error: str = 'raise') -> Dict[str, int]: """ Returns dictionary with current reading of all thermal zones. @@ -178,10 +215,10 @@ async def get_all_temperatures(self, error='raise'): :returns: a dictionary in the form: {tz_type:temperature} """ - async def get_temperature_noexcep(item): + async def get_temperature_noexcep(item: Tuple[str, ThermalZone]) -> Optional[int]: tzid, tz = item try: - temperature = await tz.get_temperature.asyn() + temperature: int = await tz.get_temperature.asyn() except TargetStableCalledProcessError as e: if error == 'raise': raise e diff --git a/devlib/module/vexpress.py b/devlib/module/vexpress.py index c597747be..f44d99ab1 100644 --- a/devlib/module/vexpress.py +++ b/devlib/module/vexpress.py @@ -1,5 +1,5 @@ # -# Copyright 2015-2018 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,25 +25,35 @@ from devlib.utils.serial_port import open_serial_connection, pulse_dtr, write_characters from devlib.utils.uefi import UefiMenu, UefiConfig from devlib.utils.uboot import UbootMenu +from devlib.platform.arm import VersatileExpressPlatform +# pylint: disable=ungrouped-imports +try: + from pexpect import fdpexpect +# pexpect < 4.0.0 does not have fdpexpect module +except ImportError: + import fdpexpect # type:ignore +from typing import TYPE_CHECKING, cast, Optional, Dict, Union, Any +if TYPE_CHECKING: + from devlib.target import Target -OLD_AUTOSTART_MESSAGE = 'Press Enter to stop auto boot...' -AUTOSTART_MESSAGE = 'Hit any key to stop autoboot:' -POWERUP_MESSAGE = 'Powering up system...' -DEFAULT_MCC_PROMPT = 'Cmd>' +OLD_AUTOSTART_MESSAGE: str = 'Press Enter to stop auto boot...' +AUTOSTART_MESSAGE: str = 'Hit any key to stop autoboot:' +POWERUP_MESSAGE: str = 'Powering up system...' +DEFAULT_MCC_PROMPT: str = 'Cmd>' class VexpressDtrHardReset(HardRestModule): - name = 'vexpress-dtr' - stage = 'early' + name: str = 'vexpress-dtr' + stage: str = 'early' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return True - def __init__(self, target, port='/dev/ttyS0', baudrate=115200, - mcc_prompt=DEFAULT_MCC_PROMPT, timeout=300): + def __init__(self, target: 'Target', port: str = '/dev/ttyS0', baudrate: int = 115200, + mcc_prompt: str = DEFAULT_MCC_PROMPT, timeout: int = 300): super(VexpressDtrHardReset, self).__init__(target) self.port = port self.baudrate = baudrate @@ -59,7 +69,7 @@ def __call__(self): with open_serial_connection(port=self.port, baudrate=self.baudrate, timeout=self.timeout, - init_dtr=0, + init_dtr=False, get_conn=True) as (_, conn): pulse_dtr(conn, state=True, duration=0.1) # TRM specifies a pulse of >=100ms @@ -70,13 +80,13 @@ class VexpressReboottxtHardReset(HardRestModule): stage = 'early' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return True - def __init__(self, target, - port='/dev/ttyS0', baudrate=115200, - path='/media/VEMSD', - mcc_prompt=DEFAULT_MCC_PROMPT, timeout=30, short_delay=1): + def __init__(self, target: 'Target', + port: str = '/dev/ttyS0', baudrate: int = 115200, + path: str = '/media/VEMSD', + mcc_prompt: str = DEFAULT_MCC_PROMPT, timeout: int = 30, short_delay: int = 1): super(VexpressReboottxtHardReset, self).__init__(target) self.port = port self.baudrate = baudrate @@ -98,7 +108,7 @@ def __call__(self): with open_serial_connection(port=self.port, baudrate=self.baudrate, timeout=self.timeout, - init_dtr=0) as tty: + init_dtr=False) as tty: wait_for_vemsd(self.path, tty, self.mcc_prompt, self.short_delay) with open(self.filepath, 'w'): pass @@ -109,13 +119,13 @@ class VexpressBootModule(BootModule): stage = 'early' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: return True - def __init__(self, target, uefi_entry=None, - port='/dev/ttyS0', baudrate=115200, - mcc_prompt=DEFAULT_MCC_PROMPT, - timeout=120, short_delay=1): + def __init__(self, target: 'Target', uefi_entry: Optional[str] = None, + port: str = '/dev/ttyS0', baudrate: int = 115200, + mcc_prompt: str = DEFAULT_MCC_PROMPT, + timeout: int = 120, short_delay: int = 1): super(VexpressBootModule, self).__init__(target) self.port = port self.baudrate = baudrate @@ -128,18 +138,24 @@ def __call__(self): with open_serial_connection(port=self.port, baudrate=self.baudrate, timeout=self.timeout, - init_dtr=0) as tty: + init_dtr=False) as tty: self.get_through_early_boot(tty) self.perform_boot_sequence(tty) self.wait_for_shell_prompt(tty) - def perform_boot_sequence(self, tty): + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: + """ + boot up the vexpress + """ raise NotImplementedError() - def get_through_early_boot(self, tty): + def get_through_early_boot(self, tty: fdpexpect.fdspawn) -> None: + """ + do the things necessary during early boot + """ self.logger.debug('Establishing initial state...') tty.sendline('') - i = tty.expect([AUTOSTART_MESSAGE, OLD_AUTOSTART_MESSAGE, POWERUP_MESSAGE, self.mcc_prompt]) + i: int = tty.expect([AUTOSTART_MESSAGE, OLD_AUTOSTART_MESSAGE, POWERUP_MESSAGE, self.mcc_prompt]) if i == 3: self.logger.debug('Saw MCC prompt.') time.sleep(self.short_delay) @@ -154,13 +170,13 @@ def get_through_early_boot(self, tty): tty.sendline('reboot') tty.sendline('reset') - def get_uefi_menu(self, tty): + def get_uefi_menu(self, tty: fdpexpect.fdspawn) -> UefiMenu: menu = UefiMenu(tty) self.logger.debug('Waiting for UEFI menu...') menu.wait(timeout=self.timeout) return menu - def wait_for_shell_prompt(self, tty): + def wait_for_shell_prompt(self, tty: fdpexpect.fdspawn) -> None: self.logger.debug('Waiting for the shell prompt.') tty.expect(self.target.shell_prompt, timeout=self.timeout) # This delay is needed to allow the platform some time to finish @@ -171,17 +187,17 @@ def wait_for_shell_prompt(self, tty): class VexpressUefiBoot(VexpressBootModule): - name = 'vexpress-uefi' + name: str = 'vexpress-uefi' - def __init__(self, target, uefi_entry, - image, fdt, bootargs, initrd, + def __init__(self, target: 'Target', uefi_entry: Optional[str], + image: str, fdt: str, bootargs: str, initrd: str, *args, **kwargs): - super(VexpressUefiBoot, self).__init__(target, uefi_entry=uefi_entry, + super(VexpressUefiBoot, self).__init__(target, uefi_entry, *args, **kwargs) - self.uefi_config = self._create_config(image, fdt, bootargs, initrd) + self.uefi_config: UefiConfig = self._create_config(image, fdt, bootargs, initrd) - def perform_boot_sequence(self, tty): - menu = self.get_uefi_menu(tty) + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: + menu: UefiMenu = self.get_uefi_menu(tty) try: menu.select(self.uefi_entry) except LookupError: @@ -190,8 +206,8 @@ def perform_boot_sequence(self, tty): menu.create_entry(self.uefi_entry, self.uefi_config) menu.select(self.uefi_entry) - def _create_config(self, image, fdt, bootargs, initrd): # pylint: disable=R0201 - config_dict = { + def _create_config(self, image: str, fdt: str, bootargs: str, initrd: str): # pylint: disable=R0201 + config_dict: Dict[str, Union[str, bool]] = { 'image_name': image, 'image_args': bootargs, 'initrd': initrd, @@ -208,21 +224,21 @@ def _create_config(self, image, fdt, bootargs, initrd): # pylint: disable=R0201 class VexpressUefiShellBoot(VexpressBootModule): - name = 'vexpress-uefi-shell' + name: str = 'vexpress-uefi-shell' # pylint: disable=keyword-arg-before-vararg - def __init__(self, target, uefi_entry='^Shell$', - efi_shell_prompt='Shell>', - image='kernel', bootargs=None, + def __init__(self, target: 'Target', uefi_entry: Optional[str] = '^Shell$', + efi_shell_prompt: str = 'Shell>', + image: str = 'kernel', bootargs: Optional[str] = None, *args, **kwargs): - super(VexpressUefiShellBoot, self).__init__(target, uefi_entry=uefi_entry, + super(VexpressUefiShellBoot, self).__init__(target, uefi_entry, *args, **kwargs) self.efi_shell_prompt = efi_shell_prompt self.image = image self.bootargs = bootargs - def perform_boot_sequence(self, tty): - menu = self.get_uefi_menu(tty) + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: + menu: UefiMenu = self.get_uefi_menu(tty) try: menu.select(self.uefi_entry) except LookupError: @@ -239,15 +255,15 @@ def perform_boot_sequence(self, tty): class VexpressUBoot(VexpressBootModule): - name = 'vexpress-u-boot' + name: str = 'vexpress-u-boot' # pylint: disable=keyword-arg-before-vararg - def __init__(self, target, env=None, + def __init__(self, target: 'Target', env: Optional[Dict] = None, *args, **kwargs): super(VexpressUBoot, self).__init__(target, *args, **kwargs) self.env = env - def perform_boot_sequence(self, tty): + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: if self.env is None: return # Will boot automatically @@ -261,13 +277,13 @@ def perform_boot_sequence(self, tty): class VexpressBootmon(VexpressBootModule): - name = 'vexpress-bootmon' + name: str = 'vexpress-bootmon' # pylint: disable=keyword-arg-before-vararg - def __init__(self, target, - image, fdt, initrd, bootargs, - uses_bootscript=False, - bootmon_prompt='>', + def __init__(self, target: 'Target', + image: str, fdt: str, initrd: str, bootargs: str, + uses_bootscript: bool = False, + bootmon_prompt: str = '>', *args, **kwargs): super(VexpressBootmon, self).__init__(target, *args, **kwargs) self.image = image @@ -277,7 +293,7 @@ def __init__(self, target, self.uses_bootscript = uses_bootscript self.bootmon_prompt = bootmon_prompt - def perform_boot_sequence(self, tty): + def perform_boot_sequence(self, tty: fdpexpect.fdspawn) -> None: if self.uses_bootscript: return # Will boot automatically @@ -286,7 +302,7 @@ def perform_boot_sequence(self, tty): with open_serial_connection(port=self.port, baudrate=self.baudrate, timeout=self.timeout, - init_dtr=0) as tty_conn: + init_dtr=False) as tty_conn: write_characters(tty_conn, 'fl linux fdt {}'.format(self.fdt)) write_characters(tty_conn, 'fl linux initrd {}'.format(self.initrd)) write_characters(tty_conn, 'fl linux boot {} {}'.format(self.image, @@ -295,8 +311,8 @@ def perform_boot_sequence(self, tty): class VersatileExpressFlashModule(FlashModule): - name = 'vexpress-vemsd' - description = """ + name: str = 'vexpress-vemsd' + description: str = """ Enables flashing of kernels and firmware to ARM Versatile Express devices. This modules enables flashing of image bundles or individual images to ARM @@ -311,31 +327,34 @@ class VersatileExpressFlashModule(FlashModule): """ - stage = 'early' + stage: str = 'early' @staticmethod - def probe(target): + def probe(target: 'Target') -> bool: if not target.has('hard_reset'): return False return True - def __init__(self, target, vemsd_mount, mcc_prompt=DEFAULT_MCC_PROMPT, timeout=30, short_delay=1): + def __init__(self, target: 'Target', vemsd_mount: str, + mcc_prompt: str = DEFAULT_MCC_PROMPT, timeout: int = 30, short_delay: int = 1): super(VersatileExpressFlashModule, self).__init__(target) self.vemsd_mount = vemsd_mount self.mcc_prompt = mcc_prompt self.timeout = timeout self.short_delay = short_delay - def __call__(self, image_bundle=None, images=None, bootargs=None, connect=True): - self.target.hard_reset() - with open_serial_connection(port=self.target.platform.serial_port, - baudrate=self.target.platform.baudrate, + def __call__(self, image_bundle: Optional[str] = None, + images: Optional[Dict[str, str]] = None, + bootargs: Any = None, connect: bool = True): + cast(HardRestModule, self.target.hard_reset)() + with open_serial_connection(port=cast(VersatileExpressPlatform, self.target.platform).serial_port, + baudrate=cast(VersatileExpressPlatform, self.target.platform).baudrate, timeout=self.timeout, - init_dtr=0) as tty: + init_dtr=False) as tty: # pylint: disable=no-member - i = tty.expect([self.mcc_prompt, AUTOSTART_MESSAGE, OLD_AUTOSTART_MESSAGE]) + i: int = cast(fdpexpect.fdspawn, tty).expect([self.mcc_prompt, AUTOSTART_MESSAGE, OLD_AUTOSTART_MESSAGE]) if i: - tty.sendline('') # pylint: disable=no-member + cast(fdpexpect.fdspawn, tty).sendline('') # pylint: disable=no-member wait_for_vemsd(self.vemsd_mount, tty, self.mcc_prompt, self.short_delay) try: if image_bundle: @@ -344,20 +363,20 @@ def __call__(self, image_bundle=None, images=None, bootargs=None, connect=True): self._overlay_images(images) os.system('sync') except (IOError, OSError) as e: - msg = 'Could not deploy images to {}; got: {}' + msg: str = 'Could not deploy images to {}; got: {}' raise TargetStableError(msg.format(self.vemsd_mount, e)) - self.target.boot() + cast(BootModule, self.target.boot)() if connect: self.target.connect(timeout=30) - def _deploy_image_bundle(self, bundle): + def _deploy_image_bundle(self, bundle: str) -> None: self.logger.debug('Validating {}'.format(bundle)) validate_image_bundle(bundle) self.logger.debug('Extracting {} into {}...'.format(bundle, self.vemsd_mount)) with tarfile.open(bundle) as tar: safe_extract(tar, self.vemsd_mount) - def _overlay_images(self, images): + def _overlay_images(self, images: Dict[str, str]): for dest, src in images.items(): dest = os.path.join(self.vemsd_mount, dest) self.logger.debug('Copying {} to {}'.format(src, dest)) @@ -366,7 +385,7 @@ def _overlay_images(self, images): # utility functions -def validate_image_bundle(bundle): +def validate_image_bundle(bundle: str) -> None: if not tarfile.is_tarfile(bundle): raise HostError('Image bundle {} does not appear to be a valid TAR file.'.format(bundle)) with tarfile.open(bundle) as tar: @@ -380,9 +399,11 @@ def validate_image_bundle(bundle): raise HostError(msg.format(bundle)) -def wait_for_vemsd(vemsd_mount, tty, mcc_prompt=DEFAULT_MCC_PROMPT, short_delay=1, retries=3): - attempts = 1 + retries - path = os.path.join(vemsd_mount, 'config.txt') +def wait_for_vemsd(vemsd_mount: str, tty: fdpexpect.fdspawn, + mcc_prompt: str = DEFAULT_MCC_PROMPT, short_delay: int = 1, + retries: int = 3) -> None: + attempts: int = 1 + retries + path: str = os.path.join(vemsd_mount, 'config.txt') if os.path.exists(path): return for _ in range(attempts): diff --git a/devlib/platform/__init__.py b/devlib/platform/__init__.py index 205b5c624..94ffadcc2 100644 --- a/devlib/platform/__init__.py +++ b/devlib/platform/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,11 @@ # import logging +from typing import Optional, List, TYPE_CHECKING, cast, Dict +from devlib.module import Module +if TYPE_CHECKING: + from devlib.target import Target, AndroidTarget + from devlib.utils.types import caseless_string BIG_CPUS = ['A15', 'A57', 'A72', 'A73'] @@ -22,34 +27,37 @@ class Platform(object): @property - def number_of_clusters(self): + def number_of_clusters(self) -> int: return len(set(self.core_clusters)) def __init__(self, - name=None, - core_names=None, - core_clusters=None, - big_core=None, - model=None, - modules=None, + name: Optional[str] = None, + core_names: Optional[List['caseless_string']] = None, + core_clusters: Optional[List[int]] = None, + big_core: Optional[str] = None, + model: Optional[str] = None, + modules: Optional[List[Dict[str, Dict]]] = None, ): self.name = name self.core_names = core_names or [] self.core_clusters = core_clusters or [] self.big_core = big_core - self.little_core = None + self.little_core: Optional[caseless_string] = None self.model = model self.modules = modules or [] self.logger = logging.getLogger(self.name) if not self.core_clusters and self.core_names: self._set_core_clusters_from_core_names() - def init_target_connection(self, target): + def init_target_connection(self, target: 'Target') -> None: + """ + do platform specific initialization for the connection + """ # May be ovewritten by subclasses to provide target-specific # connection initialisation. pass - def update_from_target(self, target): + def update_from_target(self, target: 'Target') -> None: if not self.core_names: self.core_names = target.cpuinfo.cpu_names self._set_core_clusters_from_core_names() @@ -63,25 +71,28 @@ def update_from_target(self, target): self.name = self.model self._validate() - def setup(self, target): + def setup(self, target: 'Target') -> None: + """ + Platform specific setup + """ # May be overwritten by subclasses to provide platform-specific # setup procedures. pass - def _set_core_clusters_from_core_names(self): + def _set_core_clusters_from_core_names(self) -> None: self.core_clusters = [] - clusters = [] + clusters: List[str] = [] for cn in self.core_names: if cn not in clusters: clusters.append(cn) self.core_clusters.append(clusters.index(cn)) - def _set_model_from_target(self, target): + def _set_model_from_target(self, target: 'Target'): if target.os == 'android': try: - self.model = target.getprop(prop='ro.product.device') + self.model = cast('AndroidTarget', target).getprop(prop='ro.product.device') except KeyError: - self.model = target.getprop('ro.product.model') + self.model = cast('AndroidTarget', target).getprop('ro.product.model') elif target.file_exists("/proc/device-tree/model"): # There is currently no better way to do this cross platform. # ARM does not have dmidecode @@ -95,21 +106,21 @@ def _set_model_from_target(self, target): except Exception: # pylint: disable=broad-except pass # this is best-effort - def _identify_big_core(self): + def _identify_big_core(self) -> 'caseless_string': for core in self.core_names: if core.upper() in BIG_CPUS: return core big_idx = self.core_clusters.index(max(self.core_clusters)) return self.core_names[big_idx] - def _validate(self): + def _validate(self) -> None: if len(self.core_names) != len(self.core_clusters): raise ValueError('core_names and core_clusters are of different lengths.') if self.big_core and self.number_of_clusters != 2: raise ValueError('attempting to set big_core on non-big.LITTLE device. ' '(number of clusters is not 2)') if self.big_core and self.big_core not in self.core_names: - message = 'Invalid big_core value "{}"; must be in [{}]' + message: str = 'Invalid big_core value "{}"; must be in [{}]' raise ValueError(message.format(self.big_core, ', '.join(set(self.core_names)))) if self.big_core: diff --git a/devlib/platform/arm.py b/devlib/platform/arm.py index 6499ec88e..e97bcf31e 100644 --- a/devlib/platform/arm.py +++ b/devlib/platform/arm.py @@ -1,4 +1,4 @@ -# Copyright 2015-2024 ARM Limited +# Copyright 2015-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,38 +25,53 @@ from devlib.utils.csvutil import csvreader, csvwriter from devlib.utils.serial_port import open_serial_connection +# pylint: disable=ungrouped-imports +try: + from pexpect import fdpexpect +# pexpect < 4.0.0 does not have fdpexpect module +except ImportError: + import fdpexpect # type:ignore + +from typing import (cast, TYPE_CHECKING, Match, Optional, + List, Dict, OrderedDict) +from devlib.utils.types import caseless_string +from devlib.utils.annotation_helpers import AdbUserConnectionSettings +from signal import Signals +if TYPE_CHECKING: + from devlib.target import Target + class VersatileExpressPlatform(Platform): - def __init__(self, name, # pylint: disable=too-many-locals + def __init__(self, name: str, # pylint: disable=too-many-locals - core_names=None, - core_clusters=None, - big_core=None, - model=None, - modules=None, + core_names: Optional[List[caseless_string]] = None, + core_clusters: Optional[List[int]] = None, + big_core: Optional[str] = None, + model: Optional[str] = None, + modules: Optional[List[Dict[str, Dict]]] = None, # serial settings - serial_port='/dev/ttyS0', - baudrate=115200, + serial_port: str = '/dev/ttyS0', + baudrate: int = 115200, # VExpress MicroSD mount point - vemsd_mount=None, + vemsd_mount: Optional[str] = None, # supported: dtr, reboottxt - hard_reset_method=None, + hard_reset_method: Optional[str] = None, # supported: uefi, uefi-shell, u-boot, bootmon - bootloader=None, + bootloader: Optional[str] = None, # supported: vemsd - flash_method='vemsd', + flash_method: str = 'vemsd', - image=None, - fdt=None, - initrd=None, - bootargs=None, + image: Optional[str] = None, + fdt: Optional[str] = None, + initrd: Optional[str] = None, + bootargs: Optional[str] = None, - uefi_entry=None, # only used if bootloader is "uefi" - ready_timeout=60, + uefi_entry: Optional[str] = None, # only used if bootloader is "uefi" + ready_timeout: int = 60, ): super(VersatileExpressPlatform, self).__init__(name, core_names, @@ -73,56 +88,56 @@ def __init__(self, name, # pylint: disable=too-many-locals self.bootargs = bootargs self.uefi_entry = uefi_entry self.ready_timeout = ready_timeout - self.bootloader = None - self.hard_reset_method = None + self.bootloader: Optional[str] = None + self.hard_reset_method: Optional[str] = None self._set_bootloader(bootloader) self._set_hard_reset_method(hard_reset_method) self._set_flash_method(flash_method) - def init_target_connection(self, target): + def init_target_connection(self, target: 'Target') -> None: if target.os == 'android': self._init_android_target(target) else: self._init_linux_target(target) - def _init_android_target(self, target): + def _init_android_target(self, target: 'Target') -> None: if target.connection_settings.get('device') is None: addr = self._get_target_ip_address(target) - target.connection_settings['device'] = addr + ':5555' + cast(AdbUserConnectionSettings, target.connection_settings)['device'] = addr + ':5555' - def _init_linux_target(self, target): + def _init_linux_target(self, target: 'Target') -> None: if target.connection_settings.get('host') is None: addr = self._get_target_ip_address(target) target.connection_settings['host'] = addr # pylint: disable=no-member - def _get_target_ip_address(self, target): + def _get_target_ip_address(self, target: 'Target') -> str: with open_serial_connection(port=self.serial_port, baudrate=self.baudrate, timeout=30, - init_dtr=0) as tty: - tty.sendline('su') # this is, apprently, required to query network device - # info by name on recent Juno builds... + init_dtr=False) as tty: + cast(fdpexpect.fdspawn, tty).sendline('su') # this is, apprently, required to query network device + # info by name on recent Juno builds... self.logger.debug('Waiting for the shell prompt.') - tty.expect(target.shell_prompt) + cast(fdpexpect.fdspawn, tty).expect(target.shell_prompt) self.logger.debug('Waiting for IP address...') - wait_start_time = time.time() + wait_start_time: float = time.time() try: while True: - tty.sendline('ip addr list eth0') + cast(fdpexpect.fdspawn, tty).sendline('ip addr list eth0') time.sleep(1) try: - tty.expect(r'inet ([1-9]\d*.\d+.\d+.\d+)', timeout=10) - return tty.match.group(1).decode('utf-8') + cast(fdpexpect.fdspawn, tty).expect(r'inet ([1-9]\d*.\d+.\d+.\d+)', timeout=10) + return cast(Match[bytes], cast(fdpexpect.fdspawn, tty).match).group(1).decode('utf-8') except pexpect.TIMEOUT: pass # We have our own timeout -- see below. if (time.time() - wait_start_time) > self.ready_timeout: raise TargetTransientError('Could not acquire IP address.') finally: - tty.sendline('exit') # exit shell created by "su" call at the start + cast(fdpexpect.fdspawn, tty).sendline('exit') # exit shell created by "su" call at the start - def _set_hard_reset_method(self, hard_reset_method): + def _set_hard_reset_method(self, hard_reset_method: Optional[str]) -> None: if hard_reset_method == 'dtr': self.modules.append({'vexpress-dtr': {'port': self.serial_port, 'baudrate': self.baudrate, @@ -135,7 +150,7 @@ def _set_hard_reset_method(self, hard_reset_method): else: ValueError('Invalid hard_reset_method: {}'.format(hard_reset_method)) - def _set_bootloader(self, bootloader): + def _set_bootloader(self, bootloader: Optional[str]) -> None: self.bootloader = bootloader if self.bootloader == 'uefi': self.modules.append({'vexpress-uefi': {'port': self.serial_port, @@ -152,7 +167,7 @@ def _set_bootloader(self, bootloader): 'bootargs': self.bootargs, }}) elif self.bootloader == 'u-boot': - uboot_env = None + uboot_env: Optional[Dict[str, str]] = None if self.bootargs: uboot_env = {'bootargs': self.bootargs} self.modules.append({'vexpress-u-boot': {'port': self.serial_port, @@ -170,7 +185,7 @@ def _set_bootloader(self, bootloader): else: ValueError('Invalid hard_reset_method: {}'.format(bootloader)) - def _set_flash_method(self, flash_method): + def _set_flash_method(self, flash_method: str) -> None: if flash_method == 'vemsd': self.modules.append({'vexpress-vemsd': {'vemsd_mount': self.vemsd_mount}}) else: @@ -180,10 +195,10 @@ def _set_flash_method(self, flash_method): class Juno(VersatileExpressPlatform): def __init__(self, - vemsd_mount='/media/JUNO', - baudrate=115200, - bootloader='u-boot', - hard_reset_method='dtr', + vemsd_mount: str = '/media/JUNO', + baudrate: int = 115200, + bootloader: str = 'u-boot', + hard_reset_method: str = 'dtr', **kwargs ): super(Juno, self).__init__('juno', @@ -197,10 +212,10 @@ def __init__(self, class TC2(VersatileExpressPlatform): def __init__(self, - vemsd_mount='/media/VEMSD', - baudrate=38400, - bootloader='bootmon', - hard_reset_method='reboottxt', + vemsd_mount: str = '/media/VEMSD', + baudrate: int = 38400, + bootloader: str = 'bootmon', + hard_reset_method: str = 'reboottxt', **kwargs ): super(TC2, self).__init__('tc2', @@ -213,10 +228,10 @@ def __init__(self, class JunoEnergyInstrument(Instrument): - binname = 'readenergy' - mode = CONTINUOUS | INSTANTANEOUS + binname: str = 'readenergy' + mode: int = CONTINUOUS | INSTANTANEOUS - _channels = [ + _channels: List[InstrumentChannel] = [ InstrumentChannel('sys', 'current'), InstrumentChannel('a57', 'current'), InstrumentChannel('a53', 'current'), @@ -235,45 +250,47 @@ class JunoEnergyInstrument(Instrument): InstrumentChannel('gpu', 'energy'), ] - def __init__(self, target): + def __init__(self, target: 'Target'): super(JunoEnergyInstrument, self).__init__(target) - self.on_target_file = None - self.command = None - self.binary = self.target.bin(self.binname) + self.on_target_file: Optional[str] = None + self.command: Optional[str] = None + self.binary: str = self.target.bin(self.binname) for chan in self._channels: - self.channels[chan.name] = chan - self.on_target_file = self.target.tempfile('energy', '.csv') - self.sample_rate_hz = 10 # DEFAULT_PERIOD is 100[ms] in readenergy.c + self.channels[cast(str, chan.name)] = chan + self.on_target_file = cast(Target, self.target).tempfile('energy', '.csv') + self.sample_rate_hz: int = 10 # DEFAULT_PERIOD is 100[ms] in readenergy.c self.command = '{} -o {}'.format(self.binary, self.on_target_file) - self.command2 = '{}'.format(self.binary) + self.command2: str = '{}'.format(self.binary) - def setup(self): # pylint: disable=arguments-differ - self.binary = self.target.install(os.path.join(PACKAGE_BIN_DIRECTORY, - self.target.abi, self.binname)) + def setup(self) -> None: # pylint: disable=arguments-differ + self.binary = cast(Target, self.target).install(os.path.join(PACKAGE_BIN_DIRECTORY, + self.target.abi or '', self.binname)) self.command = '{} -o {}'.format(self.binary, self.on_target_file) self.command2 = '{}'.format(self.binary) - def reset(self, sites=None, kinds=None, channels=None): + def reset(self, sites: Optional[List[str]] = None, + kinds: Optional[List[str]] = None, + channels: Optional[OrderedDict[str, InstrumentChannel]] = None): super(JunoEnergyInstrument, self).reset(sites, kinds, channels) - self.target.killall(self.binname, as_root=True) + cast(Target, self.target).killall(self.binname, as_root=True) - def start(self): - self.target.kick_off(self.command, as_root=True) + def start(self) -> None: + cast(Target, self.target).kick_off(self.command, as_root=True) - def stop(self): - self.target.killall(self.binname, signal='TERM', as_root=True) + def stop(self) -> None: + cast(Target, self.target).killall(self.binname, signal=cast(Signals, 'TERM'), as_root=True) # pylint: disable=arguments-differ - def get_data(self, output_file): - temp_file = tempfile.mktemp() - self.target.pull(self.on_target_file, temp_file) - self.target.remove(self.on_target_file) + def get_data(self, output_file: str) -> MeasurementsCsv: + temp_file: str = tempfile.mktemp() + cast(Target, self.target).pull(self.on_target_file, temp_file) + cast(Target, self.target).remove(self.on_target_file) with csvreader(temp_file) as reader: headings = next(reader) # Figure out which columns from the collected csv we actually want - select_columns = [] + select_columns: List[int] = [] for chan in self.active_channels: try: select_columns.append(headings.index(chan.name)) @@ -281,22 +298,22 @@ def get_data(self, output_file): raise HostError('Channel "{}" is not in {}'.format(chan.name, temp_file)) with csvwriter(output_file) as writer: - write_headings = ['{}_{}'.format(c.site, c.kind) - for c in self.active_channels] + write_headings: List[str] = ['{}_{}'.format(c.site, c.kind) + for c in self.active_channels] writer.writerow(write_headings) for row in reader: - write_row = [row[c] for c in select_columns] + write_row: List[str] = [row[c] for c in select_columns] writer.writerow(write_row) return MeasurementsCsv(output_file, self.active_channels, sample_rate_hz=10) - def take_measurement(self): - result = [] + def take_measurement(self) -> List[Measurement]: + result: List[Measurement] = [] output = self.target.execute(self.command2).split() with csvreader(output) as reader: headings = next(reader) values = next(reader) for chan in self.active_channels: value = values[headings.index(chan.name)] - result.append(Measurement(value, chan)) + result.append(Measurement(cast(float, value), chan)) return result diff --git a/devlib/target.py b/devlib/target.py index 44e7fc685..3847f6d62 100644 --- a/devlib/target.py +++ b/devlib/target.py @@ -1,4 +1,4 @@ -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +""" +Target module for devlib. +This module defines the Target class and supporting functionality. +""" import atexit import asyncio -from contextlib import contextmanager import io import base64 import functools @@ -24,6 +27,7 @@ import os from operator import itemgetter import re +import sys import time import logging import posixpath @@ -37,55 +41,76 @@ import inspect import itertools from collections import namedtuple, defaultdict -from past.builtins import long -from past.types import basestring from numbers import Number from shlex import quote from weakref import WeakMethod try: from collections.abc import Mapping except ImportError: - from collections import Mapping + from collections import Mapping # type: ignore from enum import Enum from concurrent.futures import ThreadPoolExecutor from devlib.host import LocalConnection, PACKAGE_BIN_DIRECTORY -from devlib.module import get_module, Module +from devlib.module import get_module, Module, HardRestModule, BootModule from devlib.platform import Platform from devlib.exception import (DevlibTransientError, TargetStableError, TargetNotRespondingError, TimeoutError, TargetTransientError, KernelConfigKeyError, - TargetError, HostError, TargetCalledProcessError) + TargetError, HostError, TargetCalledProcessError, + DevlibError) from devlib.utils.ssh import SshConnection -from devlib.utils.android import AdbConnection, AndroidProperties, LogcatMonitor, adb_command, INTENT_FLAGS -from devlib.utils.misc import memoized, isiterable, convert_new_lines, groupby_value -from devlib.utils.misc import commonprefix, merge_lists -from devlib.utils.misc import ABI_MAP, get_cpu_name, ranges_to_list -from devlib.utils.misc import batch_contextmanager, tls_property, _BoundTLSProperty, nullcontext -from devlib.utils.misc import safe_extract -from devlib.utils.types import integer, boolean, bitmask, identifier, caseless_string, bytes_regex +from devlib.utils.android import (AdbConnection, AndroidProperties, + LogcatMonitor, adb_command, INTENT_FLAGS) +from devlib.utils.misc import (memoized, isiterable, convert_new_lines, + groupby_value, commonprefix, ABI_MAP, get_cpu_name, + ranges_to_list, batch_contextmanager, tls_property, + _BoundTLSProperty, nullcontext, safe_extract) +from devlib.utils.types import (integer, boolean, bitmask, identifier, + caseless_string, bytes_regex) import devlib.utils.asyn as asyn - - -FSTAB_ENTRY_REGEX = re.compile(r'(\S+) on (.+) type (\S+) \((\S+)\)') -ANDROID_SCREEN_STATE_REGEX = re.compile('(?:mPowerState|mScreenOn|mWakefulness|Display Power: state)=([0-9]+|true|false|ON|OFF|DOZE|Dozing|Asleep|Awake)', - re.IGNORECASE) -ANDROID_SCREEN_RESOLUTION_REGEX = re.compile(r'cur=(?P\d+)x(?P\d+)') -ANDROID_SCREEN_ROTATION_REGEX = re.compile(r'orientation=(?P[0-3])') -DEFAULT_SHELL_PROMPT = re.compile(r'^.*(shell|root|juno)@?.*:[/~]\S* *[#$] ', - re.MULTILINE) -KVERSION_REGEX = re.compile( +from devlib.utils.annotation_helpers import (SshUserConnectionSettings, UserConnectionSettings, + AdbUserConnectionSettings, SupportedConnections, + SubprocessCommand, BackgroundCommand) +from typing import (List, Set, Dict, Union, Optional, Callable, TypeVar, + Any, cast, TYPE_CHECKING, AsyncGenerator, Type, Pattern, + Tuple, Iterator, AsyncContextManager, Iterable, + Mapping as Maptype) +from types import ModuleType +from typing_extensions import Literal +import signal +if TYPE_CHECKING: + from devlib.connection import ConnectionBase + from devlib.utils.misc import InitCheckpointMeta + from devlib.utils.asyn import AsyncManager, _AsyncPolymorphicFunction + from asyncio import AbstractEventLoop + from contextlib import _GeneratorContextManager + from re import Match + from xml.dom.minidom import Document + + +FSTAB_ENTRY_REGEX: Pattern[str] = re.compile(r'(\S+) on (.+) type (\S+) \((\S+)\)') +ANDROID_SCREEN_STATE_REGEX: Pattern[str] = re.compile('(?:mPowerState|mScreenOn|mWakefulness|Display Power: state)=([0-9]+|true|false|ON|OFF|DOZE|Dozing|Asleep|Awake)', + re.IGNORECASE) +ANDROID_SCREEN_RESOLUTION_REGEX: Pattern[str] = re.compile(r'cur=(?P\d+)x(?P\d+)') +ANDROID_SCREEN_ROTATION_REGEX: Pattern[str] = re.compile(r'orientation=(?P[0-3])') +DEFAULT_SHELL_PROMPT: Pattern[str] = re.compile(r'^.*(shell|root|juno)@?.*:[/~]\S* *[#$] ', + re.MULTILINE) +KVERSION_REGEX: Pattern[str] = re.compile( r'(?P\d+)(\.(?P\d+)(\.(?P\d+))?(-rc(?P\d+))?)?(-android(?P[0-9]+))?(-(?P\d+)-g(?P[0-9a-fA-F]{7,}))?(-ab(?P[0-9]+))?' ) -GOOGLE_DNS_SERVER_ADDRESS = '8.8.8.8' +GOOGLE_DNS_SERVER_ADDRESS: str = '8.8.8.8' installed_package_info = namedtuple('installed_package_info', 'apk_path package') +T = TypeVar('T', bound=Callable[..., Any]) + -def call_conn(f): +# FIXME - need to annotate to indicate the self argument needs to have a conn object of ConnectionBase type. +def call_conn(f: T) -> T: """ Decorator to be used on all :class:`devlib.target.Target` methods that directly use a method of ``self.conn``. @@ -96,13 +121,20 @@ def call_conn(f): ``__del__``, which could be executed by the garbage collector, interrupting another call to a method of the connection instance. + :param f: Method to decorate. + :type f: T + + :returns: The wrapped method that automatically creates and releases + a new connection if reentered. + :rtype: T + .. note:: This decorator could be applied directly to all methods with a metaclass or ``__init_subclass__`` but it could create issues when passing target methods as callbacks to connections' methods. """ @functools.wraps(f) - def wrapper(self, *args, **kwargs): + def wrapper(self, *args: Any, **kwargs: Any) -> Any: conn = self.conn reentered = conn.is_in_use disconnect = False @@ -132,97 +164,391 @@ def wrapper(self, *args, **kwargs): with self._lock: self._unused_conns.add(conn) - return wrapper + return cast(T, wrapper) class Target(object): + """ + An abstract base class defining the interface for a devlib target device. + + :param connection_settings: Connection parameters for the target + (e.g., SSH, ADB) in a dictionary. + :type connection_settings: UserConnectionSettings, optional + :param platform: A platform object describing architecture, ABI, kernel, + etc. If ``None``, platform info may be inferred or left unspecified. + :type platform: Platform, optional + :param working_directory: A writable directory on the target for devlib's + temporary files or scripts. If ``None``, a default path is used. + :type working_directory: str, optional + :param executables_directory: A directory on the target for storing + executables installed by devlib. If ``None``, a default path may be used. + :type executables_directory: str, optional + :param connect: If ``True``, attempt to connect to the device immediately, + else call :meth:`connect` manually. + :type connect: bool + :param modules: Dict mapping module names to their parameters. Additional + devlib modules to load on initialization. + :type modules: dict, optional + :param load_default_modules: If ``True``, load the modules specified in + :attr:`default_modules`. + :type load_default_modules: bool + :param shell_prompt: Compiled regex matching the target’s shell prompt. + :type shell_prompt: Pattern[str] + :param conn_cls: A reference to the Connection class to be used. + :type conn_cls: InitCheckpointMeta, optional + :param is_container: If ``True``, indicates the target is a container + rather than a physical or virtual machine. + :type is_container: bool + :param max_async: Number of asynchronous operations supported. Affects the + creation of parallel connections. + :type max_async: int + + :raises: Various :class:`devlib.exception` types if connection fails. + + .. note:: + Subclasses must implement device-specific methods (e.g., for Android vs. Linux or + specialized boards). The default implementation here may be incomplete. + """ + path: Optional[ModuleType] = None + os: Optional[str] = None + system_id: Optional[str] = None - path = None - os = None - system_id = None + default_modules: List[Type[Module]] = [] - default_modules = [] + def __init__(self, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + conn_cls: Optional['InitCheckpointMeta'] = None, + is_container: bool = False, + max_async: int = 50, + ): + """ + Initialize a new Target instance and optionally connect to it. + """ + self._lock = threading.RLock() + self._async_pool: Optional[ThreadPoolExecutor] = None + self._async_pool_size: Optional[int] = None + self._unused_conns: Set[ConnectionBase] = set() + + self._is_rooted: Optional[bool] = None + self.connection_settings: UserConnectionSettings = connection_settings or {} + # Set self.platform: either it's given directly (by platform argument) + # or it's given in the connection_settings argument + # If neither, create default Platform() + if platform is None: + self.platform = self.connection_settings.get('platform', Platform()) + else: + self.platform = platform + # Check if the user hasn't given two different platforms + if connection_settings and ('platform' in self.connection_settings) and ('platform' in connection_settings): + if connection_settings['platform'] is not platform: + raise TargetStableError('Platform specified in connection_settings ' + '({}) differs from that directly passed ' + '({})!)' + .format(connection_settings['platform'], + self.platform)) + self.connection_settings['platform'] = self.platform + self.working_directory = working_directory + self.executables_directory = executables_directory + self.load_default_modules = load_default_modules + self.shell_prompt: Pattern[str] = bytes_regex(shell_prompt) + self.conn_cls = conn_cls + self.is_container = is_container + self.logger = logging.getLogger(self.__class__.__name__) + self._installed_binaries: Dict[str, str] = {} + self._installed_modules: Dict[str, Module] = {} + self._shutils: Optional[str] = None + self._file_transfer_cache: Optional[str] = None + self._max_async = max_async + self.busybox: Optional[str] = None + +# FIXME - find annotation for spec + def normalize_mod_spec(spec) -> Tuple[str, Dict[str, Type[Module]]]: + if isinstance(spec, str): + return (spec, {}) + else: + [(name, params)] = spec.items() + return (name, params) + + normalized_modules: List[Tuple[str, Dict[str, Type[Module]]]] = sorted( + map( + normalize_mod_spec, + itertools.chain( + self.default_modules if load_default_modules else [], + modules or [], + self.platform.modules or [], + ) + ), + key=itemgetter(0), + ) + + # Ensure that we did not ask for the same module but different + # configurations. Empty configurations are ignored, so any + # user-provided conf will win against an empty conf. + def elect(name: str, specs: List[Tuple[str, Dict[str, Type[Module]]]]) -> Tuple[str, Dict[str, Type[Module]]]: + specs = list(specs) + + confs = set( + tuple(sorted(params.items())) + for _, params in specs + if params + ) + if len(confs) > 1: + raise ValueError(f'Attempted to load the module "{name}" with multiple different configuration') + else: + if any( + params is None + for _, params in specs + ): + params = {} + else: + params = dict(confs.pop()) if confs else {} + + return (name, params) + + modules = dict(itertools.starmap( + elect, + itertools.groupby(normalized_modules, key=itemgetter(0)) + )) + + def get_kind(name: str) -> str: + return get_module(name).kind or '' + + def kind_conflict(kind: str, names: List[str]): + if kind: + raise ValueError(f'Cannot enable multiple modules sharing the same kind "{kind}": {sorted(names)}') + + list(itertools.starmap( + kind_conflict, + itertools.groupby( + sorted( + modules.keys(), + key=get_kind + ), + key=get_kind + ) + )) + self._modules = modules + + atexit.register( + WeakMethod(self.disconnect, atexit.unregister) + ) + + self._update_modules('early') + if connect: + self.connect(max_async=max_async) @property - def core_names(self): - return self.platform.core_names + def core_names(self) -> Union[List[caseless_string], List[str]]: + """ + A list of CPU core names in the order they appear + registered with the OS. If they are not specified, + they will be queried at run time. + + :return: CPU core names in order (e.g. ["A53", "A53", "A72", "A72"]). + :rtype: list + """ + if self.platform: + return self.platform.core_names + raise ValueError("No Platform set for this target, cannot access core_names") @property - def core_clusters(self): - return self.platform.core_clusters + def core_clusters(self) -> List[int]: + """ + A list with cluster ids of each core (starting with + 0). If this is not specified, clusters will be + inferred from core names (cores with the same name are + assumed to be in a cluster). + + :return: A list of integer cluster IDs for each core. + :rtype: list + """ + if self.platform: + return self.platform.core_clusters + raise ValueError("No Platform set for this target cannot access core_clusters") @property - def big_core(self): - return self.platform.big_core + def big_core(self) -> Optional[str]: + """ + The name of the big core in a big.LITTLE system. If this is + not specified it will be inferred (on systems with exactly + two clusters). + + :return: Big core name, or None if not defined. + :rtype: str or None + """ + if self.platform: + return self.platform.big_core + raise ValueError("No Platform set for this target cannot access big_core") @property - def little_core(self): - return self.platform.little_core + def little_core(self) -> Optional[str]: + """ + The name of the little core in a big.LITTLE system. If this is + not specified it will be inferred (on systems with exactly + two clusters). + + :return: Little core name, or None if not defined. + :rtype: str or None + """ + if self.platform: + return self.platform.little_core + raise ValueError("No Platform set for this target cannot access little_core") @property - def is_connected(self): + def is_connected(self) -> bool: + """ + Indicates whether there is an active connection to the target. + + :return: True if connected, else False. + :rtype: bool + """ return self.conn is not None @property - def connected_as_root(self): - return self.conn and self.conn.connected_as_root + def connected_as_root(self) -> Optional[bool]: + """ + Indicates whether the connection user on the target is root (uid=0). + + :return: True if root, False otherwise, or None if unknown. + :rtype: bool or None + """ + if self.conn: + if self.conn.connected_as_root: + return True + return False @property - def is_rooted(self): + def is_rooted(self) -> Optional[bool]: + """ + Indicates whether superuser privileges (root or sudo) are available. + + :return: True if superuser privileges are accessible, False if not, + or None if undetermined. + :rtype: bool or None + """ if self._is_rooted is None: try: self.execute('ls /', timeout=5, as_root=True) self._is_rooted = True - except(TargetError, TimeoutError): + except (TargetError, TimeoutError): self._is_rooted = False return self._is_rooted or self.connected_as_root @property @memoized - def needs_su(self): + def needs_su(self) -> Optional[bool]: + """ + Whether the current user must escalate privileges to run root commands. + + :return: True if the device is rooted but not connected as root. + :rtype: bool + """ return not self.connected_as_root and self.is_rooted @property @memoized - def kernel_version(self): - return KernelVersion(self.execute('{} uname -r -v'.format(quote(self.busybox))).strip()) + def kernel_version(self) -> 'KernelVersion': + """ + The kernel version from ``uname -r -v``, wrapped in a KernelVersion object. + + :raises ValueError: If busybox is unavailable for executing the uname command. + :return: Kernel version details. + :rtype: KernelVersion + """ + if self.busybox: + return KernelVersion(self.execute('{} uname -r -v'.format(quote(self.busybox))).strip()) + raise ValueError("busybox not set. Cannot get kernel version") @property - def hostid(self): + def hostid(self) -> int: + """ + A numeric ID representing the system's host identity. + + :return: The hostid as an integer (parsed from hex). + :rtype: int + """ return int(self.execute('{} hostid'.format(self.busybox)).strip(), 16) @property - def hostname(self): + def hostname(self) -> str: + """ + System hostname from ``hostname`` or ``uname -n``. + + :return: Hostname of the target. + :rtype: str + """ return self.execute('{} hostname'.format(self.busybox)).strip() @property - def os_version(self): # pylint: disable=no-self-use + def os_version(self) -> Dict[str, str]: # pylint: disable=no-self-use + """ + A mapping of OS version info. Empty by default; child classes may override. + + :return: OS version details. + :rtype: dict + """ return {} @property - def model(self): + def model(self) -> Optional[str]: + """ + Hardware model name, if any. + + :return: Model name, or None if not defined. + :rtype: str or None + """ return self.platform.model @property - def abi(self): # pylint: disable=no-self-use + def abi(self) -> Optional[str]: # pylint: disable=no-self-use + """ + The primary application binary interface (ABI) of this target. + + :return: ABI name (e.g. "armeabi-v7a"), or None if unknown. + :rtype: str or None + """ return None @property - def supported_abi(self): + def supported_abi(self) -> List[Optional[str]]: + """ + A list of all supported ABIs. + + :return: List of ABI strings. + :rtype: list + """ return [self.abi] @property @memoized - def cpuinfo(self): + def cpuinfo(self) -> 'Cpuinfo': + """ + Parsed data from ``/proc/cpuinfo``. + + :return: A :class:`Cpuinfo` instance with CPU details. + """ return Cpuinfo(self.execute('cat /proc/cpuinfo')) @property @memoized - def number_of_cpus(self): - num_cpus = 0 + def number_of_cpus(self) -> int: + """ + Count of CPU cores, determined by listing ``/sys/devices/system/cpu/cpu*``. + + :return: Number of CPU cores. + :rtype: int + """ + num_cpus: int = 0 corere = re.compile(r'^\s*cpu\d+\s*$') - output = self.execute('ls /sys/devices/system/cpu', as_root=self.is_rooted) + output: str = self.execute('ls /sys/devices/system/cpu', as_root=self.is_rooted) for entry in output.split(): if corere.match(entry): num_cpus += 1 @@ -230,15 +556,24 @@ def number_of_cpus(self): @property @memoized - def number_of_nodes(self): - cmd = 'cd /sys/devices/system/node && {busybox} find . -maxdepth 1'.format(busybox=quote(self.busybox)) + def number_of_nodes(self) -> int: + """ + Number of NUMA nodes detected by enumerating ``/sys/devices/system/node``. + + :return: NUMA node count, or 1 if unavailable. + :rtype: int + """ + if self.busybox: + cmd = 'cd /sys/devices/system/node && {busybox} find . -maxdepth 1'.format(busybox=quote(self.busybox)) + else: + raise ValueError('busybox not set. cannot form cmd') try: - output = self.execute(cmd, as_root=self.is_rooted) + output: str = self.execute(cmd, as_root=self.is_rooted) except TargetStableError: return 1 else: nodere = re.compile(r'^\./node\d+\s*$') - num_nodes = 0 + num_nodes: int = 0 for entry in output.splitlines(): if nodere.match(entry): num_nodes += 1 @@ -246,17 +581,30 @@ def number_of_nodes(self): @property @memoized - def list_nodes_cpus(self): - nodes_cpus = [] + def list_nodes_cpus(self) -> List[int]: + """ + Aggregated list of CPU IDs across all NUMA nodes. + + :return: A list of CPU IDs from each detected node. + :rtype: list + """ + nodes_cpus: List[int] = [] for node in range(self.number_of_nodes): - path = self.path.join('/sys/devices/system/node/node{}/cpulist'.format(node)) - output = self.read_value(path) - nodes_cpus.append(ranges_to_list(output)) + if self.path: + path: str = self.path.join('/sys/devices/system/node/node{}/cpulist'.format(node)) + output: str = self.read_value(path) + if output: + nodes_cpus.extend(ranges_to_list(output)) return nodes_cpus @property @memoized - def config(self): + def config(self) -> 'KernelConfig': + """ + Parsed kernel config from ``/proc/config.gz`` or ``/boot/config-*``. + + :return: A :class:`KernelConfig` instance. + """ try: return KernelConfig(self.execute('zcat /proc/config.gz')) except TargetStableError: @@ -269,29 +617,67 @@ def config(self): @property @memoized - def user(self): + def user(self) -> str: + """ + The username for the active shell on the target. + + :return: Username (e.g., "root" or "shell"). + :rtype: str + """ return self.getenv('USER') @property @memoized - def page_size_kb(self): + def page_size_kb(self) -> int: + """ + Page size in kilobytes, derived from ``/proc/self/smaps``. + + :return: Page size in KiB, or 0 if unknown. + :rtype: int + """ cmd = "cat /proc/self/smaps | {0} grep KernelPageSize | {0} head -n 1 | {0} awk '{{ print $2 }}'" return int(self.execute(cmd.format(self.busybox)) or 0) @property - def shutils(self): + def shutils(self) -> Optional[str]: + """ + Path to shell utilities (if installed by devlib). Internal usage. + + :return: The path or None if uninitialized. + :rtype: str or None + """ if self._shutils is None: self._setup_shutils() return self._shutils - def is_running(self, comm): + def is_running(self, comm: str) -> bool: + """ + Check if a process with the specified name/command is running on the target. + + :param comm: The process name to search for. + :type comm: str + :return: True if a matching process is found, else False. + :rtype: bool + """ cmd_ps = f'''{self.busybox} ps -A -T -o stat,comm''' cmd_awk = f'''{self.busybox} awk 'BEGIN{{found=0}} {{state=$1; $1=""; if ($state != "Z" && $0 == " {comm}") {{found=1}}}} END {{print found}}' ''' - result = self.execute(f"{cmd_ps} | {cmd_awk}") + result: str = self.execute(f"{cmd_ps} | {cmd_awk}") return bool(int(result)) @tls_property - def _conn(self): + def _conn(self) -> 'ConnectionBase': + """ + The underlying connection object. This will be ``None`` if an active + connection does not exist (e.g. if ``connect=False`` as passed on + initialization and :meth:`connect()` has not been called). + + :returns: The thread-local :class:`ConnectionBase` instance. + :rtype: ConnectionBase + + .. note:: a :class:`~devlib.target.Target` will automatically create a + connection per thread. This will always be set to the connection + for the current thread. + """ try: with self._lock: return self._unused_conns.pop() @@ -299,147 +685,32 @@ def _conn(self): return self.get_connection() # Add a basic property that does not require calling to get the value - conn = _conn.basic_property + conn: SupportedConnections = cast(SupportedConnections, _conn.basic_property) @tls_property - def _async_manager(self): + def _async_manager(self) -> 'AsyncManager': + """ + Thread-local property that holds an async manager for concurrency tasks. + + :return: Async manager instance for the current thread. + :rtype: devlib.utils.asyn.AsyncManager + """ return asyn.AsyncManager() # Add a basic property that does not require calling to get the value - async_manager = _async_manager.basic_property + async_manager: 'AsyncManager' = cast('AsyncManager', _async_manager.basic_property) - def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - conn_cls=None, - is_container=False, - max_async=50, - ): - - self._lock = threading.RLock() - self._async_pool = None - self._async_pool_size = None - self._unused_conns = set() - - self._is_rooted = None - self.connection_settings = connection_settings or {} - # Set self.platform: either it's given directly (by platform argument) - # or it's given in the connection_settings argument - # If neither, create default Platform() - if platform is None: - self.platform = self.connection_settings.get('platform', Platform()) - else: - self.platform = platform - # Check if the user hasn't given two different platforms - if 'platform' in self.connection_settings: - if connection_settings['platform'] is not platform: - raise TargetStableError('Platform specified in connection_settings ' - '({}) differs from that directly passed ' - '({})!)' - .format(connection_settings['platform'], - self.platform)) - self.connection_settings['platform'] = self.platform - self.working_directory = working_directory - self.executables_directory = executables_directory - self.load_default_modules = load_default_modules - self.shell_prompt = bytes_regex(shell_prompt) - self.conn_cls = conn_cls - self.is_container = is_container - self.logger = logging.getLogger(self.__class__.__name__) - self._installed_binaries = {} - self._installed_modules = {} - self._cache = {} - self._shutils = None - self._file_transfer_cache = None - self._max_async = max_async - self.busybox = None - - def normalize_mod_spec(spec): - if isinstance(spec, str): - return (spec, {}) - else: - [(name, params)] = spec.items() - return (name, params) - - modules = sorted( - map( - normalize_mod_spec, - itertools.chain( - self.default_modules if load_default_modules else [], - modules or [], - self.platform.modules or [], - ) - ), - key=itemgetter(0), - ) - - # Ensure that we did not ask for the same module but different - # configurations. Empty configurations are ignored, so any - # user-provided conf will win against an empty conf. - def elect(name, specs): - specs = list(specs) - - confs = set( - tuple(sorted(params.items())) - for _, params in specs - if params - ) - if len(confs) > 1: - raise ValueError(f'Attempted to load the module "{name}" with multiple different configuration') - else: - if any( - params is None - for _, params in specs - ): - params = None - else: - params = dict(confs.pop()) if confs else {} - - return (name, params) - - modules = dict(itertools.starmap( - elect, - itertools.groupby(modules, key=itemgetter(0)) - )) - - def get_kind(name): - return get_module(name).kind or '' - - def kind_conflict(kind, names): - if kind: - raise ValueError(f'Cannot enable multiple modules sharing the same kind "{kind}": {sorted(names)}') - - list(itertools.starmap( - kind_conflict, - itertools.groupby( - sorted( - modules.keys(), - key=get_kind - ), - key=get_kind - ) - )) - self._modules = modules - - atexit.register( - WeakMethod(self.disconnect, atexit.unregister) - ) - - self._update_modules('early') - if connect: - self.connect(max_async=max_async) + def __getstate__(self) -> Dict[str, Any]: + """ + For pickling: exclude thread-local objects from the state. - def __getstate__(self): + :return: A dictionary representing the object's state. + :rtype: dict + """ # tls_property will recreate the underlying value automatically upon # access and is typically used for dynamic content that cannot be # pickled or should not transmitted to another thread. - ignored = { + ignored: set[str] = { k for k, v in inspect.getmembers(self.__class__) if isinstance(v, _BoundTLSProperty) @@ -455,7 +726,13 @@ def __getstate__(self): if k not in ignored } - def __setstate__(self, dct): + def __setstate__(self, dct: Dict[str, Any]) -> None: + """ + Restores the object's state after unpickling, reinitializing ephemeral objects. + + :param dct: The saved state dictionary. + :type dct: dict + """ self.__dict__ = dct pool_size = self._async_pool_size if pool_size is None: @@ -468,7 +745,21 @@ def __setstate__(self, dct): # connection and initialization @asyn.asyncf - async def connect(self, timeout=None, check_boot_completed=True, max_async=None): + async def connect(self, timeout: Optional[int] = None, + check_boot_completed: Optional[bool] = True, + max_async: Optional[int] = None) -> None: + """ + Connect to the target (e.g., via SSH or another transport). + + :param timeout: Timeout (in seconds) for connecting. + :type timeout: int, optional + :param check_boot_completed: If ``True``, verify the target has booted. + :type check_boot_completed: bool, optional + :param max_async: The number of parallel async connections to allow. + :type max_async: int, optional + + :raises TargetError: If the device fails to connect within the specified time. + """ self.platform.init_target_connection(self) # Forcefully set the thread-local value for the connection, with the # timeout we want @@ -477,18 +768,32 @@ async def connect(self, timeout=None, check_boot_completed=True, max_async=None) self.wait_boot_complete(timeout) self.check_connection() self._resolve_paths() - self.execute('mkdir -p {}'.format(quote(self.working_directory))) - self.execute('mkdir -p {}'.format(quote(self.executables_directory))) - self.busybox = self.install(os.path.join(PACKAGE_BIN_DIRECTORY, self.abi, 'busybox'), timeout=30) + # FIXME - should we raise exception if self.working_directory, executables_directory or abi is None? + if self.working_directory: + self.execute('mkdir -p {}'.format(quote(self.working_directory))) + if self.executables_directory: + self.execute('mkdir -p {}'.format(quote(self.executables_directory))) + if self.abi: + self.busybox = self.install(os.path.join(PACKAGE_BIN_DIRECTORY, self.abi, 'busybox'), timeout=30) self.conn.busybox = self.busybox self._detect_max_async(max_async or self._max_async) self.platform.update_from_target(self) self._update_modules('connected') - def _detect_max_async(self, max_async): + def _detect_max_async(self, max_async: int) -> None: + """ + Attempt to detect the maximum number of parallel asynchronous + commands the target can handle by opening multiple connections. + + :param max_async: Upper bound for parallel async connections. + :type max_async: int + """ self.logger.debug('Detecting max number of async commands ...') - def make_conn(_): + def make_conn(_) -> Optional[SupportedConnections]: + """ + create a connection to target to execute a command + """ try: conn = self.get_connection() except Exception: @@ -520,23 +825,27 @@ def make_conn(_): finally: logging.disable(logging.NOTSET) - conns = {conn for conn in conns if conn is not None} + resultconns = {conn for conn in conns if conn is not None} # Keep the connection so it can be reused by future threads - self._unused_conns.update(conns) - max_conns = len(conns) + self._unused_conns.update(resultconns) + max_conns = len(resultconns) self.logger.debug(f'Detected max number of async commands: {max_conns}') self._async_pool_size = max_conns self._async_pool = ThreadPoolExecutor(max_conns) @asyn.asyncf - async def check_connection(self): + async def check_connection(self) -> None: """ - Check that the connection works without obvious issues. + Perform a quick command to verify the target's shell is responsive. + + :raises TargetStableError: If the shell is present but not functioning + correctly (e.g., output on stderr). + :raises TargetNotRespondingError: If the target is unresponsive. """ - async def check(**kwargs): - out = await self.execute.asyn('true', **kwargs) + async def check(*, as_root: Union[Literal[False], Literal[True], str] = False) -> None: + out = await self.execute.asyn('true', as_root=as_root) if out: raise TargetStableError('The shell seems to not be functional and adds content to stderr: {!r}'.format(out)) @@ -547,9 +856,13 @@ async def check(**kwargs): if self.is_rooted: await check(as_root=True) - def disconnect(self): + def disconnect(self) -> None: + """ + Close all active connections to the target and terminate any + connection threads or asynchronous operations. + """ with self._lock: - thread_conns = self._conn.get_all_values() + thread_conns: Set[SupportedConnections] = self._conn.get_all_values() # Now that we have all the connection objects, we simply reset the # TLS property so that the connections we obtained will not be # reused anywhere. @@ -567,30 +880,89 @@ def disconnect(self): pool.__exit__(None, None, None) def __enter__(self): + """ + Context manager entrypoint. Returns self. + """ return self def __exit__(self, *args, **kwargs): + """ + Context manager exitpoint. Automatically disconnects from the device. + """ self.disconnect() async def __aenter__(self): + """ + Async context manager entry. + """ return self.__enter__() async def __aexit__(self, *args, **kwargs): + """ + Async context manager exit. + """ return self.__exit__(*args, **kwargs) - def get_connection(self, timeout=None): + def get_connection(self, timeout: Optional[int] = None) -> SupportedConnections: + """ + Get an additional connection to the target. A connection can be used to + execute one blocking command at time. This will return a connection that can + be used to interact with a target in parallel while a blocking operation is + being executed. + + This should *not* be used to establish an initial connection; use + :meth:`connect()` instead. + + :param timeout: Timeout (in seconds) for establishing the connection. + :type timeout: int, optional + :returns: A new connection object to be used by the caller. + :rtype: SupportedConnections + :raises ValueError: If no connection class (`conn_cls`) is set. + + .. note:: :class:`~devlib.target.Target` will automatically create a connection + per thread, so you don't normally need to use this explicitly in + threaded code. This is generally useful if you want to perform a + blocking operation (e.g. using :class:`background()`) while at the same + time doing something else in the same host-side thread. + """ if self.conn_cls is None: raise ValueError('Connection class not specified on Target creation.') - conn = self.conn_cls(timeout=timeout, **self.connection_settings) # pylint: disable=not-callable + conn: SupportedConnections = self.conn_cls(timeout=timeout, **self.connection_settings) # pylint: disable=not-callable # This allows forwarding the detected busybox for connections created in new threads. conn.busybox = self.busybox return conn - def wait_boot_complete(self, timeout=10): + def wait_boot_complete(self, timeout: Optional[int] = 10) -> None: + """ + Wait for the device to boot. Must be overridden by derived classes + if the device needs a specific boot-completion check. + + :param timeout: How long to wait for the device to finish booting. + :type timeout: int, optional + :raises NotImplementedError: If not implemented in child classes. + """ raise NotImplementedError() @asyn.asyncf - async def setup(self, executables=None): + async def setup(self, executables: Optional[List[str]] = None) -> None: + """ + This will perform an initial one-time set up of a device for devlib + interaction. This involves deployment of tools relied on the + :class:`~devlib.target.Target`, creation of working locations on the device, + etc. + + Usually, it is enough to call this method once per new device, as its effects + will persist across reboots. However, it is safe to call this method multiple + times. It may therefore be a good practice to always call it once at the + beginning of a script to ensure that subsequent interactions will succeed. + + Optionally, this may also be used to deploy additional tools to the device + by specifying a list of binaries to install in the ``executables`` parameter. + + :param executables: Optional list of host-side binaries to install + on the target during setup. + :type executables: list(str), optional + """ await self._setup_shutils.asyn() for host_exe in (executables or []): # pylint: disable=superfluous-parens @@ -601,14 +973,29 @@ async def setup(self, executables=None): # Initialize modules which requires Busybox (e.g. shutil dependent tasks) self._update_modules('setup') + if self._file_transfer_cache: + await self.execute.asyn('mkdir -p {}'.format(quote(self._file_transfer_cache))) - await self.execute.asyn('mkdir -p {}'.format(quote(self._file_transfer_cache))) - - def reboot(self, hard=False, connect=True, timeout=180): + def reboot(self, hard: bool = False, connect: bool = True, timeout: int = 180) -> None: + """ + Reboot the target. Optionally performs a hard reset if supported + by a :class:`HardRestModule`. + + :param hard: If ``True``, use a hard reset. + :type hard: bool + :param connect: If ``True``, reconnect after reboot finishes. + :type connect: bool + :param timeout: Timeout in seconds for reconnection. + :type timeout: int + + :raises TargetStableError: If hard reset is requested but not supported. + :raises TargetTransientError: If the target is not currently connected + and a soft reset is requested. + """ if hard: if not self.has('hard_reset'): raise TargetStableError('Hard reset not supported for this target.') - self.hard_reset() # pylint: disable=no-member + cast(HardRestModule, self.hard_reset)() # pylint: disable=no-member else: if not self.is_connected: message = 'Cannot reboot target because it is disconnected. ' +\ @@ -624,7 +1011,7 @@ def reboot(self, hard=False, connect=True, timeout=180): time.sleep(reset_delay) timeout = max(timeout - reset_delay, 10) if self.has('boot'): - self.boot() # pylint: disable=no-member + cast(BootModule, self.boot)() # pylint: disable=no-member self.conn.connected_as_root = None if connect: self.connect(timeout=timeout) @@ -632,32 +1019,34 @@ def reboot(self, hard=False, connect=True, timeout=180): # file transfer @asyn.asynccontextmanager - async def _xfer_cache_path(self, name): + async def _xfer_cache_path(self, name: str) -> AsyncGenerator[str, None]: """ Context manager to provide a unique path in the transfer cache with the basename of the given name. """ # Use a UUID to avoid race conditions on the target side xfer_uuid = uuid.uuid4().hex - folder = self.path.join(self._file_transfer_cache, xfer_uuid) - # Make sure basename will work on folders too - name = os.path.normpath(name) - # Ensure the name is relative so that os.path.join() will actually - # join the paths rather than ignoring the first one. - name = './{}'.format(os.path.basename(name)) - - check_rm = False - try: - await self.makedirs.asyn(folder) - # Don't check the exit code as the folder might not even exist - # before this point, if creating it failed - check_rm = True - yield self.path.join(folder, name) - finally: - await self.execute.asyn('rm -rf -- {}'.format(quote(folder)), check_exit_code=check_rm) + if self.path: + folder = self.path.join(self._file_transfer_cache, xfer_uuid) + # Make sure basename will work on folders too + name = os.path.normpath(name) + # Ensure the name is relative so that os.path.join() will actually + # join the paths rather than ignoring the first one. + name = './{}'.format(os.path.basename(name)) + + check_rm = False + try: + await self.makedirs.asyn(folder) + # Don't check the exit code as the folder might not even exist + # before this point, if creating it failed + check_rm = True + yield self.path.join(folder, name) + finally: + await self.execute.asyn('rm -rf -- {}'.format(quote(folder)), check_exit_code=check_rm) @asyn.asyncf - async def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False): + async def _prepare_xfer(self, action: str, sources: List[str], dest: str, + pattern: Optional[str] = None, as_root: bool = False) -> Dict[Tuple[str, ...], str]: """ Check the sanity of sources and destination and prepare the ground for transfering multiple sources. @@ -665,16 +1054,17 @@ async def _prepare_xfer(self, action, sources, dest, pattern=None, as_root=False once = functools.lru_cache(maxsize=None) - _target_cache = {} - def target_paths_kind(paths, as_root=False): - def process(x): + _target_cache: Dict[str, Optional[str]] = {} + + def target_paths_kind(paths: List[str], as_root: bool = False) -> List[Optional[str]]: + def process(x: str) -> Optional[str]: x = x.strip() if x == 'notexist': return None else: return x - _paths = [ + _paths: List[str] = [ path for path in paths if path not in _target_cache @@ -686,7 +1076,7 @@ def process(x): ) for path in _paths ) - res = self.execute(cmd, as_root=as_root) + res: str = self.execute(cmd, as_root=as_root) _target_cache.update(zip(_paths, map(process, res.split()))) return [ @@ -694,9 +1084,10 @@ def process(x): for path in paths ] - _host_cache = {} - def host_paths_kind(paths, as_root=False): - def path_kind(path): + _host_cache: Dict[str, Optional[str]] = {} + + def host_paths_kind(paths: List[str], as_root: bool = False) -> List[Optional[str]]: + def path_kind(path: str) -> Optional[str]: if os.path.isdir(path): return 'dir' elif os.path.exists(path): @@ -719,11 +1110,12 @@ def path_kind(path): # use SFTP for these operations, which should be cheaper than # Target.execute() if action == 'push': - src_excep = HostError + src_excep: Type[DevlibError] = HostError src_path_kind = host_paths_kind _dst_mkdir = once(self.makedirs) - dst_path_join = self.path.join + if self.path: + dst_path_join = self.path.join dst_paths_kind = target_paths_kind dst_remove_file = once(functools.partial(self.remove, as_root=as_root)) elif action == 'pull': @@ -738,12 +1130,12 @@ def path_kind(path): raise ValueError('Unknown action "{}"'.format(action)) # Handle the case where path is None - def dst_mkdir(path): + def dst_mkdir(path: Optional[str]) -> None: if path: _dst_mkdir(path) - def rewrite_dst(src, dst): - new_dst = dst_path_join(dst, os.path.basename(src)) + def rewrite_dst(src: str, dst: str) -> str: + new_dst: str = dst_path_join(dst, os.path.basename(src)) src_kind, = src_path_kind([src], as_root) # Batch both checks to avoid a costly extra execute() @@ -800,14 +1192,43 @@ def rewrite_dst(src, dst): @asyn.asyncf @call_conn - async def push(self, source, dest, as_root=False, timeout=None, globbing=False): # pylint: disable=arguments-differ + async def push(self, source: str, dest: str, as_root: bool = False, + timeout: Optional[int] = None, globbing: bool = False) -> None: # pylint: disable=arguments-differ + """ + Transfer a file from the host machine to the target device. + + If transfer polling is supported (ADB connections and SSH connections), + ``poll_transfers`` is set in the connection, and a timeout is not specified, + the push will be polled for activity. Inactive transfers will be + cancelled. (See :ref:`connection-types` for more information on polling). + + :param source: path on the host + :type source: str + :param dest: path on the target + :type destination: str + :param as_root: whether root is required. Defaults to false. + :type as_root: bool + :param timeout: timeout (in seconds) for the transfer; if the transfer does + not complete within this period, an exception will be raised. Leave unset + to utilise transfer polling if enabled. + :type timeout: int + :param globbing: If ``True``, the ``source`` is interpreted as a globbing + pattern instead of being take as-is. If the pattern has multiple + matches, ``dest`` must be a folder (or will be created as such if it + does not exists yet). + :type globbing: bool + + :raises TargetStableError: If any failure occurs in copying + (e.g., insufficient permissions). + + """ source = str(source) dest = str(dest) - sources = glob.glob(source) if globbing else [source] - mapping = await self._prepare_xfer.asyn('push', sources, dest, pattern=source if globbing else None, as_root=as_root) + sources: List[str] = glob.glob(source) if globbing else [source] + mapping: Dict[Tuple[str, ...], str] = await self._prepare_xfer.asyn('push', sources, dest, pattern=source if globbing else None, as_root=as_root) - def do_push(sources, dest): + def do_push(sources: Tuple[str, ...], dest: str) -> None: for src in sources: self.async_manager.track_access( asyn.PathAccess(namespace='host', path=src, mode='r') @@ -815,20 +1236,20 @@ def do_push(sources, dest): self.async_manager.track_access( asyn.PathAccess(namespace='target', path=dest, mode='w') ) - return self.conn.push(sources, dest, timeout=timeout) + self.conn.push(sources, dest, timeout=timeout) if as_root: - for sources, dest in mapping.items(): - for source in sources: + for sources_map, dest_map in mapping.items(): + for source in sources_map: async with self._xfer_cache_path(source) as device_tempfile: - do_push([source], device_tempfile) - await self.execute.asyn("mv -f -- {} {}".format(quote(device_tempfile), quote(dest)), as_root=True) + do_push(tuple([source]), device_tempfile) + await self.execute.asyn("mv -f -- {} {}".format(quote(device_tempfile), quote(dest_map)), as_root=True) else: - for sources, dest in mapping.items(): - do_push(sources, dest) + for sources_map, dest_map in mapping.items(): + do_push(sources_map, dest_map) @asyn.asyncf - async def _expand_glob(self, pattern, **kwargs): + async def _expand_glob(self, pattern: str, **kwargs: Dict[str, bool]) -> Optional[List[str]]: """ Expand the given path globbing pattern on the target using the shell globbing. @@ -855,27 +1276,63 @@ async def _expand_glob(self, pattern, **kwargs): pattern = pattern.replace(c, '\\' + c) cmd = "exec printf '%s\n' {}".format(pattern) - # Make sure to use the same shell everywhere for the path globbing, - # ensuring consistent results no matter what is the default platform - # shell - cmd = '{} sh -c {} 2>/dev/null'.format(quote(self.busybox), quote(cmd)) - # On some shells, match failure will make the command "return" a - # non-zero code, even though the command was not actually called - result = await self.execute.asyn(cmd, strip_colors=False, check_exit_code=False, **kwargs) - paths = result.splitlines() - if not paths: - raise TargetStableError('No file matching: {}'.format(pattern)) - - return paths + if self.busybox: + # Make sure to use the same shell everywhere for the path globbing, + # ensuring consistent results no matter what is the default platform + # shell + cmd = '{} sh -c {} 2>/dev/null'.format(quote(self.busybox), quote(cmd)) + # On some shells, match failure will make the command "return" a + # non-zero code, even though the command was not actually called + result: str = await self.execute.asyn(cmd, strip_colors=False, check_exit_code=False, **kwargs) + paths: List[str] = result.splitlines() + if not paths: + raise TargetStableError('No file matching: {}'.format(pattern)) + + return paths + return None @asyn.asyncf @call_conn - async def pull(self, source, dest, as_root=False, timeout=None, globbing=False, via_temp=False): # pylint: disable=arguments-differ + async def pull(self, source: str, dest: str, as_root: bool = False, + timeout: Optional[int] = None, globbing: bool = False, + via_temp: bool = False) -> None: # pylint: disable=arguments-differ + """ + Transfer a file from the target device to the host machine. + + If transfer polling is supported (ADB connections and SSH connections), + ``poll_transfers`` is set in the connection, and a timeout is not specified, + the pull will be polled for activity. Inactive transfers will be + cancelled. (See :ref:`connection-types` for more information on polling). + + :param source: path on the target + :type source: str + :param dest: path on the host + :type dest: str + :param as_root: whether root is required. Defaults to false. + :type as_root: bool + :param timeout: timeout (in seconds) for the transfer; if the transfer does + not complete within this period, an exception will be raised. + :type timeout: int, optional + :param globbing: If ``True``, the ``source`` is interpreted as a globbing + pattern instead of being take as-is. If the pattern has multiple + matches, ``dest`` must be a folder (or will be created as such if it + does not exists yet). + :type globbing: bool + :param via_temp: If ``True``, copy the file first to a temporary location on + the target, and then pull it. This can avoid issues some filesystems, + notably paramiko + OpenSSH combination having performance issues when + pulling big files from sysfs. + :type via_temp: bool + + :raises TargetStableError: If a transfer error occurs. + """ source = str(source) dest = str(dest) if globbing: - sources = await self._expand_glob.asyn(source, as_root=as_root) + sources: Optional[List[str]] = await self._expand_glob.asyn(source, as_root=as_root) + if sources is None: + sources = [source] else: sources = [source] @@ -883,9 +1340,9 @@ async def pull(self, source, dest, as_root=False, timeout=None, globbing=False, # so use a temporary copy instead. via_temp |= as_root - mapping = await self._prepare_xfer.asyn('pull', sources, dest, pattern=source if globbing else None, as_root=as_root) + mapping: Dict[Tuple[str, ...], str] = await self._prepare_xfer.asyn('pull', sources, dest, pattern=source if globbing else None, as_root=as_root) - def do_pull(sources, dest): + def do_pull(sources: Tuple[str, ...], dest: str) -> None: for src in sources: self.async_manager.track_access( asyn.PathAccess(namespace='target', path=src, mode='r') @@ -896,46 +1353,50 @@ def do_pull(sources, dest): self.conn.pull(sources, dest, timeout=timeout) if via_temp: - for sources, dest in mapping.items(): - for source in sources: + for sources_map, dest_map in mapping.items(): + for source in sources_map: async with self._xfer_cache_path(source) as device_tempfile: await self.execute.asyn("cp -r -- {} {}".format(quote(source), quote(device_tempfile)), as_root=as_root) await self.execute.asyn("{} chmod 0644 -- {}".format(self.busybox, quote(device_tempfile)), as_root=as_root) - do_pull([device_tempfile], dest) + do_pull(tuple([device_tempfile]), dest_map) else: - for sources, dest in mapping.items(): - do_pull(sources, dest) + for sources_map, dest_map in mapping.items(): + do_pull(sources_map, dest_map) @asyn.asyncf - async def get_directory(self, source_dir, dest, as_root=False): + async def get_directory(self, source_dir: str, dest: str, + as_root: bool = False) -> None: """ Pull a directory from the device, after compressing dir """ - # Create all file names - tar_file_name = source_dir.lstrip(self.path.sep).replace(self.path.sep, '.') - # Host location of dir - outdir = os.path.join(dest, tar_file_name) - # Host location of archive - tar_file_name = '{}.tar'.format(tar_file_name) - tmpfile = os.path.join(dest, tar_file_name) + if self.path: + # Create all file names + tar_file_name: str = source_dir.lstrip(self.path.sep).replace(self.path.sep, '.') + # Host location of dir + outdir: str = os.path.join(dest, tar_file_name) + # Host location of archive + tar_file_name = '{}.tar'.format(tar_file_name) + tmpfile: str = os.path.join(dest, tar_file_name) # If root is required, use tmp location for tar creation. - tar_file_cm = self._xfer_cache_path if as_root else nullcontext + tar_file_cm: Union[Callable[[str], AsyncContextManager[str]], Callable[[str], nullcontext]] = self._xfer_cache_path if as_root else nullcontext # Does the folder exist? await self.execute.asyn('ls -la {}'.format(quote(source_dir)), as_root=as_root) - async with tar_file_cm(tar_file_name) as tar_file_name: + async with tar_file_cm(tar_file_name) as tar_file: # Try compressing the folder try: - await self.execute.asyn('{} tar -cvf {} {}'.format( - quote(self.busybox), quote(tar_file_name), quote(source_dir) - ), as_root=as_root) + # FIXME - should we raise an error in the else case here when busybox or tar_file is None + if self.busybox and tar_file: + await self.execute.asyn('{} tar -cvf {} {}'.format( + quote(self.busybox), quote(tar_file), quote(source_dir) + ), as_root=as_root) except TargetStableError: - self.logger.debug('Failed to run tar command on target! ' \ - 'Not pulling directory {}'.format(source_dir)) + self.logger.debug('Failed to run tar command on target! ' + 'Not pulling directory {}'.format(source_dir)) # Pull the file if not os.path.exists(dest): os.mkdir(dest) - await self.pull.asyn(tar_file_name, tmpfile) + await self.pull.asyn(tar_file, tmpfile) # Decompress with tarfile.open(tmpfile, 'r') as f: safe_extract(f, outdir) @@ -943,16 +1404,27 @@ async def get_directory(self, source_dir, dest, as_root=False): # execution - def _prepare_cmd(self, command, force_locale): + def _prepare_cmd(self, command: SubprocessCommand, force_locale: str) -> SubprocessCommand: + """ + Internal helper to prepend environment settings (e.g., PATH, locale) + to a command string before execution. + + :param command: The command to execute. + :type command: SubprocessCommand or str + :param force_locale: The locale to enforce (e.g. 'C') or None for none. + :type force_locale: str + :return: The updated command string with environment preparation. + :rtype: str + """ # Force the locale if necessary for more predictable output if force_locale: # Use an explicit export so that the command is allowed to be any # shell statement, rather than just a command invocation - command = 'export LC_ALL={} && {}'.format(quote(force_locale), command) + command = 'export LC_ALL={} && {}'.format(quote(force_locale), cast(str, command)) # Ensure to use deployed command when availables if self.executables_directory: - command = "export PATH={}:$PATH && {}".format(quote(self.executables_directory), command) + command = "export PATH={}:$PATH && {}".format(quote(self.executables_directory), cast(str, command)) return command @@ -961,18 +1433,33 @@ class _BrokenConnection(Exception): @asyn.asyncf @call_conn - async def _execute_async(self, *args, **kwargs): + async def _execute_async(self, *args: Any, **kwargs: Any) -> str: + """ + Internal asynchronous handler for command execution. + + This is typically invoked by the asynchronous version of :meth:`execute`. + It may create a background thread or use an existing thread pool + to run the blocking command. + + :param args: Positional arguments forwarded to the blocking command. + :type args: Any + :param kwargs: Keyword arguments forwarded to the blocking command. + :type kwargs: Any + :return: The stdout of the command executed. + :rtype: str + :raises DevlibError: If any error occurs during command execution. + """ execute = functools.partial( self._execute, *args, **kwargs ) - pool = self._async_pool + pool: Optional[ThreadPoolExecutor] = self._async_pool if pool is None: return execute() else: - def thread_f(): + def thread_f() -> str: # If we cannot successfully connect from the thread, it might # mean that something external opened a connection on the # target, so we just revert to the blocking path. @@ -983,21 +1470,52 @@ def thread_f(): else: return execute() - loop = asyncio.get_running_loop() + loop: AbstractEventLoop = asyncio.get_running_loop() try: return await loop.run_in_executor(pool, thread_f) except self._BrokenConnection: return execute() @call_conn - def _execute(self, command, timeout=None, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False, - force_locale='C'): + def _execute(self, command: SubprocessCommand, timeout: Optional[int] = None, check_exit_code: bool = True, + as_root: bool = False, strip_colors: bool = True, will_succeed: bool = False, + force_locale: str = 'C') -> str: + """ + Internal blocking command executor. Actual synchronous logic is placed here, + usually invoked by :meth:`execute`. + + :param command: The command to be executed. + :type command: str or SubprocessCommand + :param timeout: Timeout (in seconds) for the execution of the command. If + specified, an exception will be raised if execution does not complete + with the specified period. + :type timeout: int + :param check_exit_code: If ``True`` (the default) the exit code (on target) + from execution of the command will be checked, and an exception will be + raised if it is not ``0``. + :type check_exit_code: bool + :param as_root: The command will be executed as root. This will fail on + unrooted targets. + :type as_root: bool + :param strip_colours: The command output will have colour encodings and + most ANSI escape sequences striped out before returning. + :type strip_colors: bool + :param will_succeed: The command is assumed to always succeed, unless there is + an issue in the environment like the loss of network connectivity. That + will make the method always raise an instance of a subclass of + :class:`DevlibTransientError` when the command fails, instead of a + :class:`DevlibStableError`. + :type will_succeed: bool + :param force_locale: Prepend ``LC_ALL=`` in front of the + command to get predictable output that can be more safely parsed. + If ``None``, no locale is prepended. + :type force_locale: str + """ command = self._prepare_cmd(command, force_locale) return self.conn.execute(command, timeout=timeout, - check_exit_code=check_exit_code, as_root=as_root, - strip_colors=strip_colors, will_succeed=will_succeed) + check_exit_code=check_exit_code, as_root=as_root, + strip_colors=strip_colors, will_succeed=will_succeed) execute = asyn._AsyncPolymorphicFunction( asyn=_execute_async.asyn, @@ -1005,38 +1523,77 @@ def _execute(self, command, timeout=None, check_exit_code=True, ) @call_conn - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False, - force_locale='C', timeout=None): - conn = self.conn + def background(self, command: SubprocessCommand, stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False, + force_locale: str = 'C', timeout: Optional[int] = None) -> BackgroundCommand: + """ + Execute the command on the target, invoking it via subprocess on the host. + This will return :class:`subprocess.Popen` instance for the command. + + :param command: The command to be executed. + :type command: SubprocessCommand or str + :param stdout: By default, standard output will be piped from the subprocess; + this may be used to redirect it to an alternative file handle. + :type stdout: int + :param stderr: By default, standard error will be piped from the subprocess; + this may be used to redirect it to an alternative file handle. + :type stderr: int + :param as_root: The command will be executed as root. This will fail on + unrooted targets. + :type as_root: bool + :param force_locale: Prepend ``LC_ALL=`` in front of the + command to get predictable output that can be more safely parsed. + If ``None``, no locale is prepended. + :type force_locale: str + :param timeout: Timeout (in seconds) for the execution of the command. When + the timeout expires, :meth:`BackgroundCommand.cancel` is executed to + terminate the command. + :type timeout: int, optional + + :return: A handle to the background command. + :rtype: BackgroundCommand + + .. note:: This **will block the connection** until the command completes. + """ command = self._prepare_cmd(command, force_locale) - bg_cmd = self.conn.background(command, stdout, stderr, as_root) + bg_cmd: BackgroundCommand = self.conn.background(command, stdout, stderr, as_root) if timeout is not None: timer = threading.Timer(timeout, function=bg_cmd.cancel) timer.daemon = True timer.start() return bg_cmd - def invoke(self, binary, args=None, in_directory=None, on_cpus=None, - redirect_stderr=False, as_root=False, timeout=30): + def invoke(self, binary: str, args: Optional[Union[str, Iterable[str]]] = None, in_directory: Optional[str] = None, + on_cpus: Optional[Union[int, List[int], str]] = None, redirect_stderr: bool = False, as_root: bool = False, + timeout: Optional[int] = 30) -> str: """ Executes the specified binary under the specified conditions. - :binary: binary to execute. Must be present and executable on the device. - :args: arguments to be passed to the binary. The can be either a list or + :param binary: binary to execute. Must be present and executable on the device. + :type binary: str + :param args: arguments to be passed to the binary. The can be either a list or a string. - :in_directory: execute the binary in the specified directory. This must + :type args: str or list, optional + :param in_directory: execute the binary in the specified directory. This must be an absolute path. - :on_cpus: taskset the binary to these CPUs. This may be a single ``int`` (in which + :type in_directory: str, optional + :param on_cpus: taskset the binary to these CPUs. This may be a single ``int`` (in which case, it will be interpreted as the mask), a list of ``ints``, in which case this will be interpreted as the list of cpus, or string, which will be interpreted as a comma-separated list of cpu ranges, e.g. ``"0,4-7"``. - :as_root: Specify whether the command should be run as root - :timeout: If the invocation does not terminate within this number of seconds, + :type on_cpus: int or list or str, optional + :param redirect_stderr: redirect stderr to stdout + :type redirect_stderr: bool + :param as_root: Specify whether the command should be run as root + :type as_root: bool + :param timeout: If the invocation does not terminate within this number of seconds, a ``TimeoutError`` exception will be raised. Set to ``None`` if the invocation should not timeout. + :type timeout: int, optional - :returns: output of command. + :return: The captured output of the command. + :rtype: str """ command = binary if args: @@ -1044,33 +1601,42 @@ def invoke(self, binary, args=None, in_directory=None, on_cpus=None, args = ' '.join(args) command = '{} {}'.format(command, args) if on_cpus: - on_cpus = bitmask(on_cpus) - command = '{} taskset 0x{:x} {}'.format(quote(self.busybox), on_cpus, command) + on_cpus_bitmask = bitmask(on_cpus) + if self.busybox: + command = '{} taskset 0x{:x} {}'.format(quote(self.busybox), on_cpus_bitmask, command) if in_directory: command = 'cd {} && {}'.format(quote(in_directory), command) if redirect_stderr: command = '{} 2>&1'.format(command) return self.execute(command, as_root=as_root, timeout=timeout) - def background_invoke(self, binary, args=None, in_directory=None, - on_cpus=None, as_root=False): + def background_invoke(self, binary: str, args: Optional[Union[str, Iterable[str]]] = None, in_directory: Optional[str] = None, + on_cpus: Optional[Union[int, List[int], str]] = None, as_root: bool = False) -> BackgroundCommand: """ - Executes the specified binary as a background task under the - specified conditions. + Runs the specified binary as a background task, possibly pinned to CPUs or + launched in a certain directory. - :binary: binary to execute. Must be present and executable on the device. - :args: arguments to be passed to the binary. The can be either a list or + :param binary: binary to execute. Must be present and executable on the device. + :type binary: str + :param args: arguments to be passed to the binary. The can be either a list or a string. - :in_directory: execute the binary in the specified directory. This must + :type args: str or list of str, optional + :param in_directory: execute the binary in the specified directory. This must be an absolute path. - :on_cpus: taskset the binary to these CPUs. This may be a single ``int`` (in which + :type in_directory: str, optional + :param on_cpus: taskset the binary to these CPUs. This may be a single ``int`` (in which case, it will be interpreted as the mask), a list of ``ints``, in which case this will be interpreted as the list of cpus, or string, which will be interpreted as a comma-separated list of cpu ranges, e.g. ``"0,4-7"``. - :as_root: Specify whether the command should be run as root + :type on_cpus: int or list(int) or str, optional + :param as_root: Specify whether the command should be run as root + :type as_root: bool :returns: the subprocess instance handling that command + :rtype: BackgroundCommand + + :raises TargetError: If the binary does not exist or is not executable. """ command = binary if args: @@ -1078,66 +1644,162 @@ def background_invoke(self, binary, args=None, in_directory=None, args = ' '.join(args) command = '{} {}'.format(command, args) if on_cpus: - on_cpus = bitmask(on_cpus) - command = '{} taskset 0x{:x} {}'.format(quote(self.busybox), on_cpus, command) + on_cpus_bitmask = bitmask(on_cpus) + if self.busybox: + command = '{} taskset 0x{:x} {}'.format(quote(self.busybox), on_cpus_bitmask, command) + else: + raise TargetStableError("busybox not set. cannot execute command") if in_directory: command = 'cd {} && {}'.format(quote(in_directory), command) return self.background(command, as_root=as_root) @asyn.asyncf - async def kick_off(self, command, as_root=None): + async def kick_off(self, command: str, as_root: Optional[bool] = None) -> None: """ - Like execute() but returns immediately. Unlike background(), it will - not return any handle to the command being run. + Kick off the specified command on the target and return immediately. Unlike + ``background()`` this will not block the connection; on the other hand, there + is not way to know when the command finishes (apart from calling ``ps()``) + or to get its output (unless its redirected into a file that can be pulled + later as part of the command). + + :param command: The command to be executed. + :type command: str + :param as_root: The command will be executed as root. This will fail on + unrooted targets. + :type as_root: bool, optional + + :raises TargetError: If the command cannot be launched. """ - cmd = 'cd {wd} && {busybox} sh -c {cmd} >/dev/null 2>&1'.format( - wd=quote(self.working_directory), - busybox=quote(self.busybox), - cmd=quote(command) - ) + if self.working_directory and self.busybox: + cmd = 'cd {wd} && {busybox} sh -c {cmd} >/dev/null 2>&1'.format( + wd=quote(self.working_directory), + busybox=quote(self.busybox), + cmd=quote(command) + ) + else: + raise TargetStableError("working directory or busybox not set. cannot kick off command") self.background(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, as_root=as_root) - # sysfs interaction + R = TypeVar('R') @asyn.asyncf - async def read_value(self, path, kind=None): + async def read_value(self, path: str, kind: Optional[Callable[[str], R]] = None) -> Union[str, R]: + """ + Read the value from the specified path. This is primarily intended for + sysfs/procfs/debugfs etc. + + :param path: file to read + :type path: str + :param kind: Optionally, read value will be converted into the specified + kind (which should be a callable that takes exactly one parameter) + :type kind: callable, optional + + :return: The contents of the file, possibly parsed via ``kind``. + :raises TargetStableError: If the file does not exist or is unreadable. + """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='r') ) - output = await self.execute.asyn('cat {}'.format(quote(path)), as_root=self.needs_su) # pylint: disable=E1103 + output: str = await self.execute.asyn('cat {}'.format(quote(path)), as_root=self.needs_su) # pylint: disable=E1103 output = output.strip() - if kind: - return kind(output) + if kind and callable(kind) and output: + try: + return kind(output) + except Exception as e: + raise ValueError(f"Error converting output using {kind}: {e}") else: return output @asyn.asyncf - async def read_int(self, path): + async def read_int(self, path: str) -> int: + """ + Equivalent to ``Target.read_value(path, kind=devlib.utils.types.integer)`` + + :param path: The file path to read. + :type path: str + :return: The integer value contained in the file. + :rtype: int + :raises ValueError: If the file contents cannot be parsed as an integer. + """ return await self.read_value.asyn(path, kind=integer) @asyn.asyncf - async def read_bool(self, path): + async def read_bool(self, path: str) -> bool: + """ + Equivalent to ``Target.read_value(path, kind=devlib.utils.types.boolean)`` + + :param path: File path to read. + :type path: str + :return: True or False, parsed from the file content. + :rtype: bool + :raises ValueError: If the file contents cannot be interpreted as a boolean. + """ return await self.read_value.asyn(path, kind=boolean) @asyn.asynccontextmanager - async def revertable_write_value(self, path, value, verify=True, as_root=True): - orig_value = self.read_value(path) + async def revertable_write_value(self, path: str, value: Any, verify: bool = True, as_root: bool = True) -> AsyncGenerator: + """ + Same as :meth:`Target.write_value`, but as a context manager that will write + back the previous value on exit. + + :param path: The file path to write to on the target. + :type path: str + :param value: The value to write, converted to a string. + :type value: any + :param verify: If True, read the file back to confirm the change. + :type verify: bool + :param as_root: If True, write as root. + :type as_root: bool + :yield: Allows running code in the context while the value is changed. + """ + orig_value: str = self.read_value(path) try: await self.write_value.asyn(path, value, verify=verify, as_root=as_root) yield finally: await self.write_value.asyn(path, orig_value, verify=verify, as_root=as_root) - def batch_revertable_write_value(self, kwargs_list): + def batch_revertable_write_value(self, kwargs_list: List[Dict[str, Any]]) -> '_GeneratorContextManager': + """ + Calls :meth:`Target.revertable_write_value` with all the keyword arguments + dictionary given in the list. This is a convenience method to update + multiple files at once, leaving them in their original state on exit. If one + write fails, all the already-performed writes will be reverted as well. + + :param kwargs_list: A list of dicts, each containing the kwargs for + :meth:`revertable_write_value`, e.g., {"path": , "value": , ...}. + :type kwargs_list: list of dict + :return: A context manager that applies all writes on entry, then reverts them. + :rtype: contextlib._GeneratorContextManager + """ return batch_contextmanager(self.revertable_write_value, kwargs_list) @asyn.asyncf - async def write_value(self, path, value, verify=True, as_root=True): + async def write_value(self, path: str, value: Any, verify: bool = True, as_root: bool = True) -> None: + """ + Write the value to the specified path on the target. This is primarily + intended for sysfs/procfs/debugfs etc. + + :param path: file to write into + :type path: str + :param value: value to be written + :type value: any + :param verify: If ``True`` (the default) the value will be read back after + it is written to make sure it has been written successfully. This due to + some sysfs entries silently failing to set the written value without + returning an error code. + :type verify: bool + :param as_root: specifies if writing requires being root. Its default value + is ``True``. + :type as_root: bool + + :raises TargetStableError: If the write or verification fails. + """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='w') ) - value = str(value) + string_value = str(value) if verify: # Check in a loop for a while since updates to sysfs files can take @@ -1147,10 +1809,10 @@ async def write_value(self, path, value, verify=True, as_root=True): # request, such as hotplugging a CPU. cmd = ''' orig=$(cat {path} 2>/dev/null || printf "") -printf "%s" {value} > {path} || exit 10 -if [ {value} != "$orig" ]; then +printf "%s" {string_value} > {path} || exit 10 +if [ {string_value} != "$orig" ]; then trials=0 - while [ "$(cat {path} 2>/dev/null)" != {value} ]; do + while [ "$(cat {path} 2>/dev/null)" != {string_value} ]; do if [ $trials -ge 10 ]; then cat {path} exit 11 @@ -1161,24 +1823,26 @@ async def write_value(self, path, value, verify=True, as_root=True): fi ''' else: - cmd = '{busybox} printf "%s" {value} > {path}' - cmd = cmd.format(busybox=quote(self.busybox), path=quote(path), value=quote(value)) + cmd = '{busybox} printf "%s" {string_value} > {path}' + if self.busybox: + cmd = cmd.format(busybox=quote(self.busybox), path=quote(path), string_value=quote(string_value)) try: await self.execute.asyn(cmd, check_exit_code=True, as_root=as_root) except TargetCalledProcessError as e: if e.returncode == 10: - raise TargetStableError('Could not write "{value}" to {path}: {e.output}'.format( - value=value, path=path, e=e)) + raise TargetStableError('Could not write "{string_value}" to {path}: {e.output}'.format( + string_value=string_value, path=path, e=e)) elif verify and e.returncode == 11: out = e.output - message = 'Could not set the value of {} to "{}" (read "{}")'.format(path, value, out) + message = 'Could not set the value of {} to "{}" (read "{}")'.format(path, string_value, out) raise TargetStableError(message) else: raise @asyn.asynccontextmanager - async def make_temp(self, is_directory=True, directory='', prefix='devlib-test'): + async def make_temp(self, is_directory: Optional[bool] = True, directory: Optional[str] = '', + prefix: Optional[str] = 'devlib-test') -> AsyncGenerator: """ Creates temporary file/folder on target and deletes it once it's done. @@ -1196,20 +1860,29 @@ async def make_temp(self, is_directory=True, directory='', prefix='devlib-test') :rtype: str """ - directory = directory or self.working_directory - temp_obj = None + working_directory = directory or self.working_directory + if prefix: + _prefix = prefix + else: + _prefix = '' + temp_obj: Optional[str] = None try: - cmd = f'mktemp -p {quote(directory)} {quote(prefix)}-XXXXXX' - if is_directory: - cmd += ' -d' + if working_directory: + cmd = f'mktemp -p {quote(working_directory)} {quote(_prefix)}-XXXXXX' + if is_directory: + cmd += ' -d' - temp_obj = (await self.execute.asyn(cmd)).strip() - yield temp_obj + temp_obj = (await self.execute.asyn(cmd)).strip() + yield temp_obj finally: if temp_obj is not None: await self.remove.asyn(temp_obj) - def reset(self): + def reset(self) -> None: + """ + Soft reset the target. Typically, this means executing ``reboot`` on the + target. + """ try: self.execute('reboot', as_root=self.needs_su, timeout=2) except (TargetError, subprocess.CalledProcessError): @@ -1218,7 +1891,11 @@ def reset(self): self.conn.connected_as_root = None @call_conn - def check_responsive(self, explode=True): + def check_responsive(self, explode: bool = True) -> bool: + """ + Returns ``True`` if the target appears to be responsive and ``False`` + otherwise. + """ try: self.conn.execute('ls /', timeout=5) return True @@ -1229,51 +1906,109 @@ def check_responsive(self, explode=True): # process management - def kill(self, pid, signal=None, as_root=False): + def kill(self, pid: int, signal: Optional[signal.Signals] = None, as_root: Optional[bool] = False) -> None: + """ + Send a signal (default SIGTERM) to a process by PID. + + :param pid: The PID of the process to kill. + :type pid: int + :param signal: The signal to send (e.g., signal.SIGKILL). + :type signal: signal.Signals, optional + :param as_root: If True, run the kill command as root. + :type as_root: bool + """ signal_string = '-s {}'.format(signal) if signal else '' self.execute('{} kill {} {}'.format(self.busybox, signal_string, pid), as_root=as_root) - def killall(self, process_name, signal=None, as_root=False): + def killall(self, process_name: str, signal: Optional[signal.Signals] = None, + as_root: Optional[bool] = False) -> None: + """ + Send a signal to all processes matching the given name. + + :param process_name: Name of processes to kill. + :type process_name: str + :param signal: The signal to send. + :type signal: signal.Signals, optional + :param as_root: If True, run the kill command as root. + :type as_root: bool + """ for pid in self.get_pids_of(process_name): try: self.kill(pid, signal=signal, as_root=as_root) except TargetStableError: pass - def get_pids_of(self, process_name): + def get_pids_of(self, process_name: str) -> List[int]: + """ + Return a list of PIDs of all running instances of the specified process. + """ raise NotImplementedError() - def ps(self, **kwargs): + def ps(self, **kwargs: Dict[str, Any]) -> List['PsEntry']: + """ + Return a list of :class:`PsEntry` instances for all running processes on the + system. + """ raise NotImplementedError() # files @asyn.asyncf - async def makedirs(self, path, as_root=False): + async def makedirs(self, path: str, as_root: bool = False) -> None: + """ + Create a directory (and its parents if needed) on the target. + + :param path: Directory path to create. + :type path: str + :param as_root: If True, create as root. + :type as_root: bool + """ await self.execute.asyn('mkdir -p {}'.format(quote(path)), as_root=as_root) @asyn.asyncf - async def file_exists(self, filepath): + async def file_exists(self, filepath: str) -> bool: + """ + Check if a file or directory exists at the specified path. + + :param filepath: The target path to check. + :type filepath: str + :return: True if the path exists on the target, else False. + :rtype: bool + """ command = 'if [ -e {} ]; then echo 1; else echo 0; fi' - output = await self.execute.asyn(command.format(quote(filepath)), as_root=self.is_rooted) + output: str = await self.execute.asyn(command.format(quote(filepath)), as_root=self.is_rooted) return boolean(output.strip()) @asyn.asyncf - async def directory_exists(self, filepath): + async def directory_exists(self, filepath: str) -> bool: + """ + Check if the path on the target is an existing directory. + + :param filepath: The path to check. + :type filepath: str + :return: True if a directory exists at the path, else False. + :rtype: bool + """ output = await self.execute.asyn('if [ -d {} ]; then echo 1; else echo 0; fi'.format(quote(filepath))) # output from ssh my contain part of the expression in the buffer, # split out everything except the last word. return boolean(output.split()[-1]) # pylint: disable=maybe-no-member @asyn.asyncf - async def list_file_systems(self): - output = await self.execute.asyn('mount') - fstab = [] + async def list_file_systems(self) -> List['FstabEntry']: + """ + Return a list of currently mounted file systems, parsed into FstabEntry objects. + + :return: A list of file system entries describing mount points. + :rtype: list of FstabEntry + """ + output: str = await self.execute.asyn('mount') + fstab: List['FstabEntry'] = [] for line in output.split('\n'): line = line.strip() if not line: continue - match = FSTAB_ENTRY_REGEX.search(line) + match: Optional[Match[str]] = FSTAB_ENTRY_REGEX.search(line) if match: fstab.append(FstabEntry(match.group(1), match.group(2), match.group(3), match.group(4), @@ -1283,20 +2018,64 @@ async def list_file_systems(self): return fstab @asyn.asyncf - async def list_directory(self, path, as_root=False): + async def list_directory(self, path: str, as_root: bool = False) -> List[str]: + """ + Internal method that returns the contents of a directory. Called by + :meth:`list_directory`. + + :param path: Directory path to list. + :type path: str + :param as_root: If True, list as root. + :type as_root: bool + :return: A list of filenames within the directory. + :rtype: list of str + :raises NotImplementedError: If not implemented in a subclass. + """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='r') ) return await self._list_directory(path, as_root=as_root) - def _list_directory(self, path, as_root=False): + async def _list_directory(self, path: str, as_root: bool = False) -> List[str]: + """ + List the contents of the specified directory. Optionally run as root. + + :param path: Directory path to list. + :type path: str + :param as_root: If True, run the directory listing as root. + :type as_root: bool + :return: Names of entries in the directory. + :rtype: list of str + :raises TargetStableError: If the path is not a directory or is unreadable. + """ raise NotImplementedError() - def get_workpath(self, name): - return self.path.join(self.working_directory, name) + def get_workpath(self, name: str) -> Optional[str]: + """ + Join a name with :attr:`working_directory` on the target, returning + an absolute path for convenience. + + :param name: The filename to append to the working directory. + :type name: str + :return: The combined absolute path, or None if no working directory is set. + :rtype: str or None + """ + if self.path: + return self.path.join(self.working_directory, name) + return None @asyn.asyncf - async def tempfile(self, prefix='', suffix=''): + async def tempfile(self, prefix: Optional[str] = '', suffix: Optional[str] = '') -> Optional[str]: + """ + Generate a unique path for a temporary file in the :attr:`working_directory`. + + :param prefix: An optional prefix for the file name. + :type prefix: str + :param suffix: An optional suffix (e.g. ".txt"). + :type suffix: str + :return: The full path to the file, which does not yet exist. + :rtype: str or None + """ name = '{prefix}_{uuid}_{suffix}'.format( prefix=prefix, uuid=uuid.uuid4().hex, @@ -1309,92 +2088,233 @@ async def tempfile(self, prefix='', suffix=''): return path @asyn.asyncf - async def remove(self, path, as_root=False): + async def remove(self, path: str, as_root=False) -> None: + """ + Remove a file or directory on the target. + + :param path: Path to remove. + :type path: str + :param as_root: If True, remove as root. + :type as_root: bool + """ await self.execute.asyn('rm -rf -- {}'.format(quote(path)), as_root=as_root) # misc @asyn.asyncf - async def read_sysctl(self, parameter): + async def read_sysctl(self, parameter: str) -> Optional[str]: """ - Returns the value of the given sysctl parameter as a string. + Read the specified sysctl parameter. Equivalent to reading the file under + ``/proc/sys/...``. + + :param parameter: The sysctl name, e.g. "kernel.sched_latency_ns". + :type parameter: str + :return: The value of the sysctl parameter, or None if not found. + :rtype: str or None + :raises ValueError: If the sysctl parameter doesn't exist. """ - path = self.path.join('/', 'proc', 'sys', *parameter.split('.')) - try: - return await self.read_value.asyn(path) - except FileNotFoundError as e: - raise ValueError(f'systcl parameter {parameter} was not found: {e}') + if self.path: + path: str = self.path.join('/', 'proc', 'sys', *parameter.split('.')) + try: + return await self.read_value.asyn(path) + except FileNotFoundError as e: + raise ValueError(f'systcl parameter {parameter} was not found: {e}') + return None - def core_cpus(self, core): + def core_cpus(self, core: str) -> List[int]: + """ + Return numeric CPU IDs corresponding to the given core name. + + :param core: The name of the CPU core (e.g., "A53"). + :type core: str + :return: List of CPU indices that match the given name. + :rtype: list of int + """ return [i for i, c in enumerate(self.core_names) if c == core] @asyn.asyncf - async def list_online_cpus(self, core=None): - path = self.path.join('/sys/devices/system/cpu/online') - output = await self.read_value.asyn(path) - all_online = ranges_to_list(output) - if core: - cpus = self.core_cpus(core) - if not cpus: - raise ValueError(core) - return [o for o in all_online if o in cpus] - else: - return all_online + async def list_online_cpus(self, core: Optional[str] = None) -> Optional[List[int]]: + """ + Return a list of online CPU IDs. If a core name is provided, restricts + to CPUs that match that name. + + :param core: Optional name of the CPU core (e.g., "A53") to filter results. + :type core: str, optional + :return: Online CPU IDs. + :rtype: list of int + :raises ValueError: If the specified core name is invalid. + """ + if self.path: + path: str = self.path.join('/sys/devices/system/cpu/online') + output: str = await self.read_value.asyn(path) + all_online: List[int] = ranges_to_list(output) + if core: + cpus: List[int] = self.core_cpus(core) + if not cpus: + raise ValueError(core) + return [o for o in all_online if o in cpus] + else: + return all_online + return None @asyn.asyncf - async def list_offline_cpus(self): - online = await self.list_online_cpus.asyn() + async def list_offline_cpus(self) -> List[int]: + """ + Return a list of offline CPU IDs, i.e., those not present in + :meth:`list_online_cpus`. + + :return: Offline CPU IDs. + :rtype: list of int + """ + online: List[int] = await self.list_online_cpus.asyn() return [c for c in range(self.number_of_cpus) if c not in online] @asyn.asyncf - async def getenv(self, variable): - var = await self.execute.asyn('printf "%s" ${}'.format(variable)) + async def getenv(self, variable: str) -> str: + """ + Return the value of the specified environment variable on the device + """ + var: str = await self.execute.asyn('printf "%s" ${}'.format(variable)) return var.rstrip('\r\n') - def capture_screen(self, filepath): + def capture_screen(self, filepath: str) -> None: + """ + Take a screenshot on the device and save it to the specified file on the + host. This may not be supported by the target. You can optionally insert a + ``{ts}`` tag into the file name, in which case it will be substituted with + on-target timestamp of the screen shot in ISO8601 format. + + :param filepath: Path on the host where screenshot is stored. + :type filepath: str + :raises NotImplementedError: If screenshot capture is not implemented. + """ raise NotImplementedError() - def install(self, filepath, timeout=None, with_name=None): + @asyn.asyncf + def install(self, filepath: str, timeout: Optional[int] = None, with_name: Optional[str] = None) -> str: + """ + Install an executable from the host to the target. If `with_name` is given, + the file is renamed on the target. + + :param filepath: Path on the host to the executable. + :type filepath: str + :param timeout: Timeout in seconds for the installation. + :type timeout: int, optional + :param with_name: If provided, rename the installed file on the target. + :type with_name: str, optional + :return: The path to the installed binary on the target. + :rtype: str + :raises NotImplementedError: If not implemented in a subclass. + """ raise NotImplementedError() - def uninstall(self, name): + def uninstall(self, name: str) -> None: + """ + Uninstall a previously installed executable. + + :param name: Name of the executable to remove. + :type name: str + :raises NotImplementedError: If not implemented in a subclass. + """ raise NotImplementedError() @asyn.asyncf - async def get_installed(self, name, search_system_binaries=True): + async def get_installed(self, name: str, search_system_binaries: bool = True) -> Optional[str]: + """ + Return the absolute path of an installed executable with the given name, + or None if not found. + + :param name: The name of the binary. + :type name: str + :param search_system_binaries: If True, also search the system PATH. + :type search_system_binaries: bool + :return: Full path to the binary on the target, or None if not found. + :rtype: str or None + """ # Check user installed binaries first if self.file_exists(self.executables_directory): - if name in (await self.list_directory.asyn(self.executables_directory)): + if name in (await self.list_directory.asyn(self.executables_directory)) and self.path: return self.path.join(self.executables_directory, name) # Fall back to binaries in PATH if search_system_binaries: - PATH = await self.getenv.asyn('PATH') - for path in PATH.split(self.path.pathsep): - try: - if name in (await self.list_directory.asyn(path)): - return self.path.join(path, name) - except TargetStableError: - pass # directory does not exist or no executable permissions + PATH: str = await self.getenv.asyn('PATH') + if self.path: + for path in PATH.split(self.path.pathsep): + try: + if name in (await self.list_directory.asyn(path)): + return self.path.join(path, name) + except TargetStableError: + pass # directory does not exist or no executable permissions + return None - which = get_installed + which: '_AsyncPolymorphicFunction' = get_installed @asyn.asyncf - async def install_if_needed(self, host_path, search_system_binaries=True, timeout=None): + async def install_if_needed(self, host_path: str, search_system_binaries: bool = True, + timeout: Optional[int] = None) -> str: + """ + Check whether an executable with the name of ``host_path`` is already installed + on the target. If it is not installed, install it from the specified path. + + :param host_path: The path to the executable on the host system. + :type host_path: str + :param search_system_binaries: If ``True``, also search the device's system PATH + for the binary before installing. If ``False``, only check user-installed + binaries. + :type search_system_binaries: bool + :param timeout: Maximum time in seconds to wait for installation to complete. + If ``None``, a default (implementation-defined) timeout is used. + :type timeout: int, optional + :return: The absolute path of the binary on the target after ensuring it is installed. + :rtype: str - binary_path = await self.get_installed.asyn(os.path.split(host_path)[1], - search_system_binaries=search_system_binaries) + :raises TargetError: If the target is disconnected. + :raises TargetStableError: If installation fails or times out (depending on implementation). + """ + binary_path: str = await self.get_installed.asyn(os.path.split(host_path)[1], + search_system_binaries=search_system_binaries) if not binary_path: binary_path = await self.install.asyn(host_path, timeout=timeout) return binary_path @asyn.asyncf - async def is_installed(self, name): + async def is_installed(self, name: str) -> bool: + """ + Determine whether an executable with the specified name is installed on the target. + + :param name: Name of the executable (e.g. "perf"). + :type name: str + :return: ``True`` if the executable is found, otherwise ``False``. + :rtype: bool + + :raises TargetError: If the target is not currently connected. + """ return bool(await self.get_installed.asyn(name)) - def bin(self, name): + def bin(self, name: str) -> str: + """ + Retrieve the installed path to the specified binary on the target. + + :param name: Name of the binary whose path is requested. + :type name: str + :return: The path to the binary if installed and recorded by devlib, + otherwise returns ``name`` unmodified. + :rtype: str + """ return self._installed_binaries.get(name, name) - def has(self, modname): + def has(self, modname: str) -> bool: + """ + Check whether the specified module or feature is present on the target. + + :param modname: Module name to look up. + :type modname: str + :return: ``True`` if the module is present and loadable, otherwise ``False``. + :rtype: bool + + :raises Exception: If an unexpected error occurs while querying the module. + (Can be replaced with a more specific exception if desired.) + """ modname = identifier(modname) try: self._get_module(modname, log=False) @@ -1404,9 +2324,16 @@ def has(self, modname): return True @asyn.asyncf - async def lsmod(self): - lines = (await self.execute.asyn('lsmod')).splitlines() - entries = [] + async def lsmod(self) -> List['LsmodEntry']: + """ + Run the ``lsmod`` command on the target and return the result as a list + of :class:`LsmodEntry` namedtuples. + + :return: A list of loaded kernel modules, each represented by an LsmodEntry object. + :rtype: list[LsmodEntry] + """ + lines: str = (await self.execute.asyn('lsmod')).splitlines() + entries: List['LsmodEntry'] = [] for line in lines[1:]: # first line is the header if not line.strip(): continue @@ -1419,13 +2346,21 @@ async def lsmod(self): return entries @asyn.asyncf - async def insmod(self, path): - target_path = self.get_workpath(os.path.basename(path)) + async def insmod(self, path: str) -> None: + """ + Insert a kernel module onto the target via ``insmod``. + + :param path: The path on the *host* system to the kernel module file (.ko). + :type path: str + :raises TargetStableError: If the module cannot be inserted (e.g., missing dependencies). + """ + target_path: Optional[str] = self.get_workpath(os.path.basename(path)) await self.push.asyn(path, target_path) - await self.execute.asyn('insmod {}'.format(quote(target_path)), as_root=True) + if target_path: + await self.execute.asyn('insmod {}'.format(quote(target_path)), as_root=True) @asyn.asyncf - async def extract(self, path, dest=None): + async def extract(self, path: str, dest: Optional[str] = None) -> Optional[str]: """ Extract the specified on-target file. The extraction method to be used (unzip, gunzip, bunzip2, or tar) will be based on the file's extension. @@ -1442,39 +2377,87 @@ async def extract(self, path, dest=None): (``dest`` if it was specified otherwise, the directory that contained the archive). + :param path: The on-target path of the archive or compressed file. + :type path: str + :param dest: An optional directory path on the target where the contents + should be extracted. The directory must already exist. + :type dest: str, optional + :return: Path to the extracted files. + * If a multi-file archive, returns the directory containing those files. + * If a single-file compression (e.g., .gz, .bz2), returns the path to + the decompressed file. + * If extraction fails or is unknown format, ``None`` might be returned + (depending on your usage). + :rtype: str or None + + :raises ValueError: If the file’s format is unrecognized. + :raises TargetStableError: If extraction fails on the target. """ for ending in ['.tar.gz', '.tar.bz', '.tar.bz2', '.tgz', '.tbz', '.tbz2']: if path.endswith(ending): return await self._extract_archive(path, 'tar xf {} -C {}', dest) - ext = self.path.splitext(path)[1] - if ext in ['.bz', '.bz2']: - return await self._extract_file(path, 'bunzip2 -f {}', dest) - elif ext == '.gz': - return await self._extract_file(path, 'gunzip -f {}', dest) - elif ext == '.zip': - return await self._extract_archive(path, 'unzip {} -d {}', dest) - else: - raise ValueError('Unknown compression format: {}'.format(ext)) + if self.path: + ext: str = self.path.splitext(path)[1] + if ext in ['.bz', '.bz2']: + return await self._extract_file(path, 'bunzip2 -f {}', dest) + elif ext == '.gz': + return await self._extract_file(path, 'gunzip -f {}', dest) + elif ext == '.zip': + return await self._extract_archive(path, 'unzip {} -d {}', dest) + else: + raise ValueError('Unknown compression format: {}'.format(ext)) + return None @asyn.asyncf - async def sleep(self, duration): + async def sleep(self, duration: int) -> None: + """ + Invoke a ``sleep`` command on the target to pause for the specified duration. + + :param duration: The time in seconds the target should sleep. + :type duration: int + :raises TimeoutError: If the sleep operation times out (rare, but can be forced). + """ timeout = duration + 10 await self.execute.asyn('sleep {}'.format(duration), timeout=timeout) @asyn.asyncf - async def read_tree_tar_flat(self, path, depth=1, check_exit_code=True, - decode_unicode=True, strip_null_chars=True): + async def read_tree_tar_flat(self, path: str, depth: int = 1, check_exit_code: bool = True, + decode_unicode: bool = True, strip_null_chars: bool = True) -> Dict[str, str]: + """ + Recursively read file nodes within a tar archive stored on the target, up to + a given ``depth``. The archive is temporarily extracted in memory, and the + contents are returned in a flat dictionary mapping each file path to its content. + + :param path: Path to the tar archive on the target. + :type path: str + :param depth: Maximum directory depth to traverse within the archive. + :type depth: int + :param check_exit_code: If ``True``, raise an error if the helper command exits non-zero. + :type check_exit_code: bool + :param decode_unicode: If ``True``, attempt to decode each file’s content as UTF-8. + :type decode_unicode: bool + :param strip_null_chars: If ``True``, strip out any null characters (``\\x00``) from + decoded text. + :type strip_null_chars: bool + :return: A dictionary mapping file paths (within the archive) to their textual content. + :rtype: dict(str, str) + + :raises TargetStableError: If the helper command fails or returns unexpected data. + :raises UnicodeDecodeError: If a file's content cannot be decoded when + ``decode_unicode=True``. + """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='r') ) - command = 'read_tree_tgz_b64 {} {} {}'.format(quote(path), depth, - quote(self.working_directory)) - output = await self._execute_util.asyn(command, as_root=self.is_rooted, - check_exit_code=check_exit_code) + if path and self.working_directory: + command = 'read_tree_tgz_b64 {} {} {}'.format(quote(path), depth, + quote(self.working_directory)) + output: str = await self._execute_util.asyn(command, as_root=self.is_rooted, + check_exit_code=check_exit_code) - result = {} + result: Dict[str, str] = {} # Unpack the archive in memory tar_gz = base64.b64decode(output) @@ -1493,25 +2476,40 @@ async def read_tree_tar_flat(self, path, depth=1, check_exit_code=True, content = content_f.read() if decode_unicode: try: - content = content.decode('utf-8').strip() + content_str = content.decode('utf-8').strip() if strip_null_chars: - content = content.replace('\x00', '').strip() + content_str = content_str.replace('\x00', '').strip() except UnicodeDecodeError: - content = '' - - name = self.path.join(path, member.name) - result[name] = content + content_str = '' + if self.path: + name: str = self.path.join(path, member.name) + result[name] = content_str return result @asyn.asyncf - async def read_tree_values_flat(self, path, depth=1, check_exit_code=True): + async def read_tree_values_flat(self, path: str, depth: int = 1, check_exit_code: bool = True) -> Dict[str, str]: + """ + Recursively read file nodes under a given directory (e.g., sysfs) on the target, + up to the specified depth, returning a flat dictionary of file paths to contents. + + :param path: The on-target directory path to read from. + :type path: str + :param depth: Maximum directory depth to traverse. + :type depth: int + :param check_exit_code: If ``True``, raises an error if the helper command fails. + :type check_exit_code: bool + :return: A dict mapping each discovered file path to the file's textual content. + :rtype: dict(str, str) + + :raises TargetStableError: If the read-tree helper command fails or no content is returned. + """ self.async_manager.track_access( asyn.PathAccess(namespace='target', path=path, mode='r') ) - command = 'read_tree_values {} {}'.format(quote(path), depth) - output = await self._execute_util.asyn(command, as_root=self.is_rooted, - check_exit_code=check_exit_code) + command: str = 'read_tree_values {} {}'.format(quote(path), depth) + output: str = await self._execute_util.asyn(command, as_root=self.is_rooted, + check_exit_code=check_exit_code) accumulator = defaultdict(list) for entry in output.strip().split('\n'): @@ -1520,35 +2518,52 @@ async def read_tree_values_flat(self, path, depth=1, check_exit_code=True): path, value = entry.strip().split(':', 1) accumulator[path].append(value) - result = {k: '\n'.join(v).strip() for k, v in accumulator.items()} + result: Dict[str, str] = {k: '\n'.join(v).strip() for k, v in accumulator.items()} return result @asyn.asyncf - async def read_tree_values(self, path, depth=1, dictcls=dict, - check_exit_code=True, tar=False, decode_unicode=True, - strip_null_chars=True): + async def read_tree_values(self, path: str, depth: int = 1, dictcls: Type[Dict] = dict, + check_exit_code: bool = True, tar: bool = False, decode_unicode: bool = True, + strip_null_chars: bool = True) -> Union[str, Dict[str, 'Node']]: """ - Reads the content of all files under a given tree - - :path: path to the tree - :depth: maximum tree depth to read - :dictcls: type of the dict used to store the results - :check_exit_code: raise an exception if the shutil command fails - :tar: fetch the entire tree using tar rather than just the value (more - robust but slower in some use-cases) - :decode_unicode: decode the content of tar-ed files as utf-8 - :strip_null_chars: remove '\x00' chars from the content of utf-8 - decoded files - - :returns: a tree-like dict with the content of files as leafs + Recursively read all file nodes under a given directory or tar archive on the target, + building a **tree-like** structure up to the given depth. + + :param path: On-target path to read. May be a directory path or a tar file path + if ``tar=True``. + :type path: str + :param depth: Maximum directory depth to traverse. + :type depth: int + :param dictcls: The dictionary class to use for constructing the tree + (defaults to the built-in :class:`dict`). + :type dictcls: Type[dict] + :param check_exit_code: If ``True``, raises an error if the internal helper command fails. + :type check_exit_code: bool + :param tar: If ``True``, treat ``path`` as a tar archive and read it. If ``False``, + read from a normal directory hierarchy. + :type tar: bool + :param decode_unicode: If ``True``, decode file contents (in tar mode) as UTF-8. + :type decode_unicode: bool + :param strip_null_chars: If ``True``, strip out any null characters (``\\x00``) from + decoded text. + :type strip_null_chars: bool + :return: A hierarchical dictionary (or specialized mapping) containing sub-directories + and files as nested keys, or a string in some edge cases (depending on usage). + :rtype: str or Dict[str, Node] + + :raises TargetStableError: If the read-tree operation fails. + :raises UnicodeDecodeError: If a file content cannot be decoded. """ if not tar: - value_map = await self.read_tree_values_flat.asyn(path, depth, check_exit_code) + value_map: Dict[str, str] = await self.read_tree_values_flat.asyn(path, depth, check_exit_code) else: value_map = await self.read_tree_tar_flat.asyn(path, depth, check_exit_code, - decode_unicode, - strip_null_chars) - return _build_path_tree(value_map, path, self.path.sep, dictcls) + decode_unicode, + strip_null_chars) + if self.path: + return _build_path_tree(value_map, path, self.path.sep, dictcls) + else: + return {} def install_module(self, mod, **params): mod = get_module(mod) @@ -1562,71 +2577,182 @@ def install_module(self, mod, **params): # internal methods @asyn.asyncf - async def _setup_shutils(self): - shutils_ifile = os.path.join(PACKAGE_BIN_DIRECTORY, 'scripts', 'shutils.in') + async def _setup_shutils(self) -> None: + """ + Install and prepare the ``shutils`` script on the target. This script provides + shell utility functions that may be invoked by other devlib features. + + :raises TargetStableError: + If ``busybox`` is not installed or if pushing/installing ``shutils`` fails. + :raises IOError: + If reading the local script file fails on the host system. + """ + shutils_ifile: str = os.path.join(PACKAGE_BIN_DIRECTORY, 'scripts', 'shutils.in') with open(shutils_ifile) as fh: - lines = fh.readlines() + lines: List[str] = fh.readlines() with tempfile.TemporaryDirectory() as folder: - shutils_ofile = os.path.join(folder, 'shutils') + shutils_ofile: str = os.path.join(folder, 'shutils') with open(shutils_ofile, 'w') as ofile: - for line in lines: - line = line.replace("__DEVLIB_BUSYBOX__", self.busybox) - ofile.write(line) + if self.busybox: + for line in lines: + line = line.replace("__DEVLIB_BUSYBOX__", self.busybox) + ofile.write(line) self._shutils = await self.install.asyn(shutils_ofile) @asyn.asyncf @call_conn - async def _execute_util(self, command, timeout=None, check_exit_code=True, as_root=False): - command = '{} sh {} {}'.format(quote(self.busybox), quote(self.shutils), command) - return await self.execute.asyn( - command, - timeout=timeout, - check_exit_code=check_exit_code, - as_root=as_root - ) + async def _execute_util(self, command: SubprocessCommand, timeout: Optional[int] = None, + check_exit_code: bool = True, as_root: bool = False) -> Optional[str]: + """ + Execute a shell utility command via the ``shutils`` script on the target. + This typically prepends the busybox and shutils script calls before your + specified command. + + :param command: The command (or SubprocessCommand) string to run. + :type command: str or SubprocessCommand + :param timeout: Maximum number of seconds to allow for completion. If None, + an implementation-defined default is used. + :type timeout: int, optional + :param check_exit_code: If True, raise an error when the return code is non-zero. + :type check_exit_code: bool + :param as_root: If True, attempt to run with root privileges (e.g., ``su`` + or ``sudo``). + :type as_root: bool + :return: The command’s output on success, or ``None`` if busybox/shutils is + unavailable. + :rtype: str or None + + :raises TargetStableError: If the script is not present or the command fails + with a non-zero code (while ``check_exit_code=True``). + :raises TimeoutError: If the command runs longer than the specified timeout. + """ + if self.busybox and self.shutils: + command_str = '{} sh {} {}'.format(quote(self.busybox), quote(self.shutils), cast(str, command)) + return await self.execute.asyn( + command_str, + timeout=timeout, + check_exit_code=check_exit_code, + as_root=as_root + ) + return None - async def _extract_archive(self, path, cmd, dest=None): + async def _extract_archive(self, path: str, cmd: str, dest: Optional[str] = None) -> Optional[str]: + """ + extract files of type - + '.tar.gz', '.tar.bz', '.tar.bz2', '.tgz', '.tbz', '.tbz2' + + :param path: On-target path of the compressed archive (e.g., .tar.gz). + :type path: str + :param cmd: A template string for the extraction command (e.g., 'tar xf {} -C {}'). + :type cmd: str + :param dest: Optional path to a destination directory on the target + where files are extracted. If not specified, extraction occurs in + the same directory as ``path``. + :type dest: str, optional + :return: The directory or file path where the archive's contents were extracted, + or None if ``busybox`` or other prerequisites are missing. + :rtype: str or None + + :raises TargetStableError: If extraction fails or the file/directory cannot be written. + """ cmd = '{} ' + cmd # busybox if dest: - extracted = dest + extracted: Optional[str] = dest else: - extracted = self.path.dirname(path) - cmdtext = cmd.format(quote(self.busybox), quote(path), quote(extracted)) - await self.execute.asyn(cmdtext) + if self.path: + extracted = self.path.dirname(path) + if self.busybox and extracted: + cmdtext = cmd.format(quote(self.busybox), quote(path), quote(extracted)) + await self.execute.asyn(cmdtext) return extracted - async def _extract_file(self, path, cmd, dest=None): + async def _extract_file(self, path: str, cmd: str, dest: Optional[str] = None) -> Optional[str]: + """ + Decompress a single file on the target (e.g., .gz, .bz2). + + :param path: On-target path of the compressed file. + :type path: str + :param cmd: The decompression command format string (e.g., 'gunzip -f {}'). + :type cmd: str + :param dest: Optional directory path on the target where the decompressed file + should be moved. If omitted, the file remains in its original directory + (with the extension removed). + :type dest: str, optional + :return: The path to the decompressed file after extraction, or None if + prerequisites are missing. + :rtype: str or None + + :raises TargetStableError: If decompression fails or the file/directory is unwritable. + """ cmd = '{} ' + cmd # busybox - cmdtext = cmd.format(quote(self.busybox), quote(path)) - await self.execute.asyn(cmdtext) - extracted = self.path.splitext(path)[0] - if dest: - await self.execute.asyn('mv -f {} {}'.format(quote(extracted), quote(dest))) - if dest.endswith('/'): - extracted = self.path.join(dest, self.path.basename(extracted)) - else: - extracted = dest - return extracted + if self.busybox and self.path: + cmdtext: str = cmd.format(quote(self.busybox), quote(path)) + await self.execute.asyn(cmdtext) + extracted: Optional[str] = self.path.splitext(path)[0] + if dest and extracted: + await self.execute.asyn('mv -f {} {}'.format(quote(extracted), quote(dest))) + if dest.endswith('/'): + extracted = self.path.join(dest, self.path.basename(extracted)) + else: + extracted = dest + return extracted + return None - def _install_module(self, mod, params, log=True): - mod = get_module(mod) - name = mod.name - if params is None or self._modules.get(name, {}) is None: - raise TargetStableError(f'Could not load module "{name}" as it has been explicilty disabled') - else: - try: - return mod.install(self, **params) - except Exception as e: - if log: - self.logger.error(f'Module "{name}" failed to install on target: {e}') - raise + def _install_module(self, mod: Union[str, Type[Module]], + params: Dict[str, Type[Module]], log: bool = True) -> Optional[Module]: + """ + Installs a devlib module onto the target post-setup. + + :param mod: Either the module's name (string) or a Module type object. + :type mod: str or Type[Module] + :param params: A dictionary of parameters for initializing the module. + :type params: dict + :param log: If True, logs errors if installation fails. + :type log: bool + :return: The instantiated Module object if installation succeeds, otherwise None. + :rtype: Module or None + + :raises TargetStableError: If the module has been explicitly disabled or if + initialization fails irrecoverably. + :raises Exception: If any other unexpected error occurs. + """ + module = get_module(mod) + name = module.name + if name: + if params is None or self._modules.get(name, {}) is None: + raise TargetStableError(f'Could not load module "{name}" as it has been explicilty disabled') + else: + try: + return module.install(self, **params) + except Exception as e: + if log: + self.logger.error(f'Module "{name}" failed to install on target: {e}') + raise + raise TargetStableError('Failed to install module as module name is not present') @property - def modules(self): + def modules(self) -> List[str]: + """ + A list of module names registered on this target, regardless of which + have been installed. + + :return: Sorted list of module names. + :rtype: list of str + """ return sorted(self._modules.keys()) - def _update_modules(self, stage): - to_install = [ + def _update_modules(self, stage: str) -> None: + """ + Load or install modules that match the specified stage (e.g., "early", + "connected", or "setup"). + + :param stage: The stage name used for grouping when modules should be installed. + :type stage: str + + :raises Exception: If a module fails installation or is not supported + by the target (caught and logged internally). + """ + to_install: List[Tuple[Type[Module], Dict[str, Type[Module]]]] = [ (mod, params) for mod, params in ( (get_module(name), params) @@ -1638,10 +2764,20 @@ def _update_modules(self, stage): try: self._install_module(mod, params) except Exception as e: - mod_name = mod.name self.logger.warning(f'Module {mod.name} is not supported by the target: {e}') - def _get_module(self, modname, log=True): + def _get_module(self, modname: str, log: bool = True) -> Module: + """ + Retrieve or install a module by name. If not already installed, this + attempts to install it first. + + :param modname: The name or attribute of the module to retrieve. + :type modname: str + :param log: If True, logs errors if installation fails. + :type log: bool + :return: The installed module object, if successful. + :raises AttributeError: If the module or attribute cannot be found or installed. + """ try: return self._installed_modules[modname] except KeyError: @@ -1655,12 +2791,12 @@ def _get_module(self, modname, log=True): except ValueError: for _mod, _params in self._modules.items(): try: - _mod = get_module(_mod) + _module = get_module(_mod) except ValueError: pass else: - if _mod.attr_name == modname: - mod = _mod + if _module.attr_name == modname: + mod = _module params = _params break else: @@ -1668,12 +2804,23 @@ def _get_module(self, modname, log=True): f"'{self.__class__.__name__}' object has no attribute '{modname}'" ) else: - params = self._modules.get(mod.name, {}) + if mod.name: + params = self._modules.get(mod.name, {}) self._install_module(mod, params, log=log) return self.__getattr__(modname) - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Module: + """ + Fallback attribute accessor, invoked if a normal attribute or method + is not found. This checks for a corresponding installed or installable + module whose name matches ``attr``. + + :param attr: The module name or attribute to fetch. + :type attr: str + :return: The installed module if found/installed, otherwise raises AttributeError. + :raises AttributeError: If the module does not exist or cannot be installed. + """ # When unpickled, objects will have an empty dict so fail early if attr.startswith('__') and attr.endswith('__'): raise AttributeError(attr) @@ -1685,11 +2832,29 @@ def __getattr__(self, attr): # work as expected raise AttributeError(str(e)) - def _resolve_paths(self): + def _resolve_paths(self) -> None: + """ + Perform final path resolutions, such as setting the target's working directory, + file transfer cache, or executables directory. + + :raises NotImplementedError: If the target subclass has not overridden this method. + """ raise NotImplementedError() @asyn.asyncf - async def is_network_connected(self): + async def is_network_connected(self) -> bool: + """ + Check if the target has basic network/internet connectivity by using + ``ping`` to reach a known IP (e.g., 8.8.8.8). + + :return: True if the network appears to be reachable; False otherwise. + :rtype: bool + + :raises TargetStableError: If the network is known to be unreachable or if + the shell command reports a fatal error. + :raises TimeoutError: If repeatedly pinging does not respond within + the default or user-defined time. + """ self.logger.debug('Checking for internet connectivity...') timeout_s = 5 @@ -1707,7 +2872,7 @@ async def is_network_connected(self): await self.execute.asyn(command) return True except TargetStableError as e: - err = str(e).lower() + err: str = str(e).lower() if '100% packet loss' in err: # We sent a packet but got no response. # Try again - we don't want this to fail just because of a @@ -1730,14 +2895,33 @@ async def is_network_connected(self): class LinuxTarget(Target): + """ + A specialized :class:`Target` subclass for devices or systems running Linux. + Adapts path handling to ``posixpath`` and includes additional helpers for + Linux-specific commands or filesystems. + + :ivar path: Set to ``posixpath``. + :vartype path: ModuleType + :ivar os: ``"linux"`` + :vartype os: str + """ - path = posixpath + path: ModuleType = posixpath os = 'linux' @property @memoized - def abi(self): - value = self.execute('uname -m').strip() + def abi(self) -> str: + """ + Determine the Application Binary Interface (ABI) of the device by + interpreting the output of ``uname -m`` and mapping it to known + architecture strings in ``ABI_MAP``. + + :return: The ABI string (e.g., "arm64" or "x86_64"). If unmapped, + returns the exact output of ``uname -m``. + :rtype: str + """ + value: str = self.execute('uname -m').strip() for abi, architectures in ABI_MAP.items(): if value in architectures: result = abi @@ -1748,33 +2932,48 @@ def abi(self): @property @memoized - def os_version(self): - os_version = {} + def os_version(self) -> Dict[str, str]: + """ + Gather Linux distribution or version info by scanning files in ``/etc/`` + that end with ``-release`` or ``-version``. + + :return: A dictionary mapping the filename (e.g. "os-release") to + its contents as a single line. + :rtype: dict + """ + os_version: Dict[str, str] = {} command = 'ls /etc/*-release /etc*-version /etc/*_release /etc/*_version 2>/dev/null' - version_files = self.execute(command, check_exit_code=False).strip().split() + version_files: List[str] = self.execute(command, check_exit_code=False).strip().split() for vf in version_files: - name = self.path.basename(vf) - output = self.read_value(vf) + name: str = self.path.basename(vf) + output: str = self.read_value(vf) os_version[name] = convert_new_lines(output.strip()).replace('\n', ' ') return os_version @property @memoized - def system_id(self): + def system_id(self) -> str: + """ + Retrieve a Linux-specific system ID by invoking + a specialized utility command on the target. + + :return: A string uniquely identifying the Linux system. + :rtype: str + """ return self._execute_util('get_linux_system_id').strip() def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - conn_cls=SshConnection, - is_container=False, - max_async=50, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + conn_cls: 'InitCheckpointMeta' = SshConnection, + is_container: bool = False, + max_async: int = 50, ): super(LinuxTarget, self).__init__(connection_settings=connection_settings, platform=platform, @@ -1788,23 +2987,41 @@ def __init__(self, is_container=is_container, max_async=max_async) - def wait_boot_complete(self, timeout=10): + def wait_boot_complete(self, timeout: Optional[int] = 10) -> None: + """ + wait for target to boot up + """ pass @asyn.asyncf - async def get_pids_of(self, process_name): - """Returns a list of PIDs of all processes with the specified name.""" + async def get_pids_of(self, process_name) -> List[int]: + """ + Return a list of PIDs of all running processes matching the given name. + + :param process_name: Name of the process to look up. + :type process_name: str + :return: List of matching PIDs. + :rtype: list of int + :raises NotImplementedError: If not overridden by child classes. + """ # result should be a column of PIDs with the first row as "PID" header - result = await self.execute.asyn('ps -C {} -o pid'.format(quote(process_name)), # NOQA - check_exit_code=False) - result = result.strip().split() + result_temp:str = await self.execute.asyn('ps -C {} -o pid'.format(quote(process_name)), # NOQA + check_exit_code=False) + result: List[str] = result_temp.strip().split() if len(result) >= 2: # at least one row besides the header return list(map(int, result[1:])) else: return [] @asyn.asyncf - async def ps(self, threads=False, **kwargs): + async def ps(self, threads: bool = False, **kwargs: Dict[str, Any]) -> List['PsEntry']: + """ + Return a list of PsEntry objects for each process on the system. + + :return: A list of processes. + :rtype: list of PsEntry + :raises NotImplementedError: If not overridden. + """ ps_flags = '-eo' if threads: ps_flags = '-eLo' @@ -1812,84 +3029,163 @@ async def ps(self, threads=False, **kwargs): out = await self.execute.asyn(command) - result = [] - lines = convert_new_lines(out).splitlines() + result: List['PsEntry'] = [] + lines: List[str] = convert_new_lines(out).splitlines() # Skip header for line in lines[1:]: - parts = re.split(r'\s+', line, maxsplit=9) + parts: List[str] = re.split(r'\s+', line, maxsplit=9) if parts: result.append(PsEntry(*(parts[0:1] + list(map(int, parts[1:6])) + parts[6:]))) if not kwargs: return result else: - filtered_result = [] + filtered_result: List['PsEntry'] = [] for entry in result: if all(getattr(entry, k) == v for k, v in kwargs.items()): filtered_result.append(entry) return filtered_result - async def _list_directory(self, path, as_root=False): + async def _list_directory(self, path: str, as_root: bool = False) -> List[str]: + """ + target specific implementation of list_directory + """ contents = await self.execute.asyn('ls -1 {}'.format(quote(path)), as_root=as_root) return [x.strip() for x in contents.split('\n') if x.strip()] @asyn.asyncf - async def install(self, filepath, timeout=None, with_name=None): # pylint: disable=W0221 - destpath = self.path.join(self.executables_directory, - with_name and with_name or self.path.basename(filepath)) + async def install(self, filepath: str, timeout: Optional[int] = None, + with_name: Optional[str] = None) -> str: # pylint: disable=W0221 + """ + Install an executable on the device. + + :param filepath: path to the executable on the host + :param timeout: Optional timeout (in seconds) for the installation + :param with_name: This may be used to rename the executable on the target + """ + destpath: str = self.path.join(self.executables_directory, + with_name and with_name or self.path.basename(filepath)) await self.push.asyn(filepath, destpath, timeout=timeout) await self.execute.asyn('chmod a+x {}'.format(quote(destpath)), timeout=timeout) self._installed_binaries[self.path.basename(destpath)] = destpath return destpath @asyn.asyncf - async def uninstall(self, name): - path = self.path.join(self.executables_directory, name) + async def uninstall(self, name: str) -> None: + """ + Uninstall the specified executable from the target + """ + path: str = self.path.join(self.executables_directory, name) await self.remove.asyn(path) @asyn.asyncf - async def capture_screen(self, filepath): + async def capture_screen(self, filepath: str) -> None: + """ + Take a screenshot on the device and save it to the specified file on the + host. This may not be supported by the target. You can optionally insert a + ``{ts}`` tag into the file name, in which case it will be substituted with + on-target timestamp of the screen shot in ISO8601 format. + """ if not (await self.is_installed.asyn('scrot')): self.logger.debug('Could not take screenshot as scrot is not installed.') return try: - tmpfile = await self.tempfile.asyn() + tmpfile: str = await self.tempfile.asyn() cmd = 'DISPLAY=:0.0 scrot {} && {} date -u -Iseconds' - ts = (await self.execute.asyn(cmd.format(quote(tmpfile), quote(self.busybox)))).strip() - filepath = filepath.format(ts=ts) - await self.pull.asyn(tmpfile, filepath) - await self.remove.asyn(tmpfile) + if self.busybox: + ts: str = (await self.execute.asyn(cmd.format(quote(tmpfile), quote(self.busybox)))).strip() + filepath = filepath.format(ts=ts) + await self.pull.asyn(tmpfile, filepath) + await self.remove.asyn(tmpfile) + else: + raise TargetStableError("busybox is not present") except TargetStableError as e: - if "Can't open X dispay." not in e.message: + if isinstance(e.message, str) and "Can't open X dispay." not in e.message: raise e - message = e.message.split('OUTPUT:', 1)[1].strip() # pylint: disable=no-member - self.logger.debug('Could not take screenshot: {}'.format(message)) + if isinstance(e.message, str): + message = e.message.split('OUTPUT:', 1)[1].strip() # pylint: disable=no-member + self.logger.debug('Could not take screenshot: {}'.format(message)) - def _resolve_paths(self): + def _resolve_paths(self) -> None: + """ + set paths for working directory, file transfer cache and executables directory + """ if self.working_directory is None: - self.working_directory = self.path.join(self.execute("pwd").strip(), 'devlib-target') - self._file_transfer_cache = self.path.join(self.working_directory, '.file-cache') + self.working_directory: str = self.path.join(self.execute("pwd").strip(), 'devlib-target') + self._file_transfer_cache: str = self.path.join(self.working_directory, '.file-cache') if self.executables_directory is None: - self.executables_directory = self.path.join(self.working_directory, 'bin') + self.executables_directory: str = self.path.join(self.working_directory, 'bin') class AndroidTarget(Target): - + """ + A specialized :class:`Target` subclass for devices running Android. This + provides additional Android-specific features like property retrieval + (``getprop``), APK installation, ADB connection management, screen controls, + input injection, and more. + + :param connection_settings: Parameters for connecting to the device + (e.g., ADB serial or host/port). + :type connection_settings: AdbUserConnectionSettings or None + :param platform: A ``Platform`` object describing hardware aspects. If None, + a generic or default platform is used. + :type platform: Platform, optional + :param working_directory: A directory on the device for devlib to store + temporary files. Defaults to a subfolder of external storage. + :type working_directory: str, optional + :param executables_directory: A directory on the device where devlib + installs binaries. Defaults to ``/data/local/tmp/bin``. + :type executables_directory: str, optional + :param connect: If True, automatically connect to the device upon instantiation. + Otherwise, call :meth:`connect`. + :type connect: bool + :param modules: Additional modules to load (name -> parameters). + :type modules: dict, optional + :param load_default_modules: If True, load all modules in :attr:`default_modules`. + :type load_default_modules: bool + :param shell_prompt: Regex matching the interactive shell prompt, if used. + :type shell_prompt: re.Pattern + :param conn_cls: The connection class, typically :class:`AdbConnection`. + :type conn_cls: ConnectionBase, optional + :param package_data_directory: Location where installed packages store data. + Defaults to ``"/data/data"``. + :type package_data_directory: str + :param is_container: If True, indicates the device is actually a container environment. + :type is_container: bool + :param max_async: Maximum number of asynchronous operations to allow in parallel. + :type max_async: int + """ path = posixpath os = 'android' ls_command = '' @property @memoized - def abi(self): + def abi(self) -> str: + """ + Return the main ABI (CPU architecture) by reading ``ro.product.cpu.abi`` + from the device properties. + + :return: E.g. "arm64" or "armeabi-v7a" for an Android device. + :rtype: str + """ return self.getprop()['ro.product.cpu.abi'].split('-')[0] @property @memoized - def supported_abi(self): - props = self.getprop() - result = [props['ro.product.cpu.abi']] + def supported_abi(self) -> List[Optional[str]]: + """ + List all supported ABIs found in Android system properties. Combines + values from ``ro.product.cpu.abi``, ``ro.product.cpu.abi2``, + and ``ro.product.cpu.abilist``. + + :return: A list of ABI strings (some might be mapped to devlib’s known + architecture list). + :rtype: list + """ + props: Dict[str, str] = self.getprop() + result: List[str] = [props['ro.product.cpu.abi']] if 'ro.product.cpu.abi2' in props: result.append(props['ro.product.cpu.abi2']) if 'ro.product.cpu.abilist' in props: @@ -1897,7 +3193,7 @@ def supported_abi(self): if abi not in result: result.append(abi) - mapped_result = [] + mapped_result: List[Optional[str]] = [] for supported_abi in result: for abi, architectures in ABI_MAP.items(): found = False @@ -1911,29 +3207,63 @@ def supported_abi(self): @property @memoized - def os_version(self): - os_version = {} + def os_version(self) -> Dict[str, str]: + """ + Read and parse Android build version info from properties whose keys + start with ``ro.build.version``. + + :return: Dictionary mapping the last component of each key + (e.g., "incremental", "release") to its string value. + :rtype: dict + """ + os_version: Dict[str, str] = {} for k, v in self.getprop().iteritems(): if k.startswith('ro.build.version'): - part = k.split('.')[-1] + part: str = k.split('.')[-1] os_version[part] = v return os_version @property - def adb_name(self): + def adb_name(self) -> Optional[str]: + """ + The ADB device name or serial number for the connected Android device. + + :return: + - The string serial/ID if connected via ADB (e.g. ``"0123456789ABCDEF"``). + - ``None`` if unavailable or a different connection type is used (e.g. SSH). + :rtype: str or None + """ return getattr(self.conn, 'device', None) @property - def adb_server(self): + def adb_server(self) -> Optional[str]: + """ + The hostname or IP address of the ADB server, if using a remote ADB + connection. + + :return: + - The ADB server address (e.g. ``"127.0.0.1"``). + - ``None`` if not applicable (local ADB or a non-ADB connection). + :rtype: str or None + """ return getattr(self.conn, 'adb_server', None) @property - def adb_port(self): + def adb_port(self) -> Optional[int]: + """ + The TCP port on which the ADB server is listening, if using a remote ADB + connection. + + :return: + - An integer port number (e.g. 5037). + - ``None`` if not applicable or unknown. + :rtype: int or None + """ return getattr(self.conn, 'adb_port', None) @property @memoized - def android_id(self): + def android_id(self) -> str: """ Get the device's ANDROID_ID. Which is @@ -1943,30 +3273,74 @@ def android_id(self): .. note:: This will get reset on userdata erasure. + :return: The ANDROID_ID in hexadecimal form. + :rtype: str + """ + # FIXME - would it be better to just do 'settings get secure android_id' ? when trying to execute the content command, + # getting some access issues with settings output = self.execute('content query --uri content://settings/secure --projection value --where "name=\'android_id\'"').strip() return output.split('value=')[-1] @property @memoized - def system_id(self): + def system_id(self) -> str: + """ + Obtain a unique Android system identifier by using a device utility + (e.g., 'get_android_system_id' in shutils). + + :return: A device-specific ID string. + :rtype: str + """ return self._execute_util('get_android_system_id').strip() @property @memoized - def external_storage(self): + def external_storage(self) -> str: + """ + The path to the device's external storage directory (often ``/sdcard`` or + ``/storage/emulated/0``). + + :return: + A filesystem path pointing to the shared/SD card area on the Android device. + :rtype: str + :raises TargetStableError: + If the environment variable ``EXTERNAL_STORAGE`` is unset or an error + occurs reading it. + """ return self.execute('echo $EXTERNAL_STORAGE').strip() @property @memoized - def external_storage_app_dir(self): - return self.path.join(self.external_storage, 'Android', 'data') + def external_storage_app_dir(self) -> Optional[str]: + """ + The application-specific directory within external storage + (commonly ``/sdcard/Android/data``). + + :return: + The path to the app-specific directory under external storage, or + ``None`` if not determinable (e.g. no external storage). + :rtype: str or None + """ + if self.path: + return self.path.join(self.external_storage, 'Android', 'data') + return None @property @memoized - def screen_resolution(self): - output = self.execute('dumpsys window displays') - match = ANDROID_SCREEN_RESOLUTION_REGEX.search(output) + def screen_resolution(self) -> Tuple[int, int]: + """ + The current display resolution (width, height), read from ``dumpsys window displays``. + + :return: + A tuple ``(width, height)`` of the device’s screen resolution in pixels. + :rtype: tuple(int, int) + + :raises TargetStableError: + If the resolution cannot be parsed from ``dumpsys`` output. + """ + output: str = self.execute('dumpsys window displays') + match: Optional[Match[str]] = ANDROID_SCREEN_RESOLUTION_REGEX.search(output) if match: return (int(match.group('width')), int(match.group('height'))) @@ -1974,19 +3348,23 @@ def screen_resolution(self): return (0, 0) def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - conn_cls=AdbConnection, - package_data_directory="/data/data", - is_container=False, - max_async=50, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + conn_cls: 'InitCheckpointMeta' = AdbConnection, + package_data_directory: str = "/data/data", + is_container: bool = False, + max_async: int = 50, ): + """ + Initialize an AndroidTarget instance and optionally connect to the + device via ADB. + """ super(AndroidTarget, self).__init__(connection_settings=connection_settings, platform=platform, working_directory=working_directory, @@ -2001,10 +3379,17 @@ def __init__(self, self.package_data_directory = package_data_directory self._init_logcat_lock() - def _init_logcat_lock(self): + def _init_logcat_lock(self) -> None: + """ + Initialize a lock used for serializing logcat clearing operations. + This prevents overlapping ``logcat -c`` calls from multiple threads. + """ self.clear_logcat_lock = threading.Lock() - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: + """ + Extend the base pickling to skip the `clear_logcat_lock`. + """ dct = super().__getstate__() return { k: v @@ -2012,35 +3397,69 @@ def __getstate__(self): if k not in ('clear_logcat_lock',) } - def __setstate__(self, dct): + def __setstate__(self, dct: Dict[str, Any]) -> None: + """ + Restore post-pickle state, reinitializing the logcat lock. + """ super().__setstate__(dct) self._init_logcat_lock() @asyn.asyncf async def reset(self, fastboot=False): # pylint: disable=arguments-differ + """ + Soft reset (reboot) the device. If ``fastboot=True``, attempt to reboot + into fastboot mode. + + :param fastboot: If True, reboot into fastboot instead of normal reboot. + :type fastboot: bool + :raises DevlibTransientError: If "reboot" command fails or times out. + """ try: await self.execute.asyn('reboot {}'.format(fastboot and 'fastboot' or ''), - as_root=self.needs_su, timeout=2) + as_root=self.needs_su, timeout=2) except (DevlibTransientError, subprocess.CalledProcessError): # on some targets "reboot" doesn't return gracefully pass self.conn.connected_as_root = None @asyn.asyncf - async def wait_boot_complete(self, timeout=10): - start = time.time() - boot_completed = boolean(await self.getprop.asyn('sys.boot_completed')) - while not boot_completed and timeout >= time.time() - start: - time.sleep(5) - boot_completed = boolean(await self.getprop.asyn('sys.boot_completed')) - if not boot_completed: - # Raise a TargetStableError as this usually happens because of - # an issue with Android more than a timeout that is too small. - raise TargetStableError('Connected but Android did not fully boot.') + async def wait_boot_complete(self, timeout: Optional[int] = 10) -> None: + """ + Wait for Android to finish booting, typically by polling ``sys.boot_completed`` + property. + + :param timeout: Seconds to wait. If the property isn't set by this time, raise. + :type timeout: int or None + :raises TargetStableError: If the device remains un-booted after `timeout` seconds. + """ + start: float = time.time() + boot_completed: bool = boolean(await self.getprop.asyn('sys.boot_completed')) + if timeout: + while not boot_completed and timeout >= time.time() - start: + time.sleep(5) + boot_completed = boolean(await self.getprop.asyn('sys.boot_completed')) + if not boot_completed: + # Raise a TargetStableError as this usually happens because of + # an issue with Android more than a timeout that is too small. + raise TargetStableError('Connected but Android did not fully boot.') @asyn.asyncf - async def connect(self, timeout=30, check_boot_completed=True, max_async=None): # pylint: disable=arguments-differ - device = self.connection_settings.get('device') + async def connect(self, timeout: Optional[int] = 30, + check_boot_completed: Optional[bool] = True, + max_async: Optional[int] = None) -> None: # pylint: disable=arguments-differ + """ + Establish a connection to the target. It is usually not necessary to call + this explicitly, as a connection gets automatically established on + instantiation. + + :param timeout: Time in seconds before giving up on connection attempts. + :type timeout: int, optional + :param check_boot_completed: Whether to call :meth:`wait_boot_complete`. + :type check_boot_completed: bool + :param max_async: Override the default concurrency limit if provided. + :type max_async: int, optional + :raises TargetError: If the device fails to connect. + """ await super(AndroidTarget, self).connect.asyn( timeout=timeout, check_boot_completed=check_boot_completed, @@ -2048,7 +3467,12 @@ async def connect(self, timeout=30, check_boot_completed=True, max_async=None): ) @asyn.asyncf - async def __setup_list_directory(self): + async def __setup_list_directory(self) -> None: + """ + One-time setup to determine if the device supports ``ls -1``. On older + Android versions, the ``-1`` flag might not be available, so fallback + to plain ``ls``. + """ # In at least Linaro Android 16.09 (which was their first Android 7 release) and maybe # AOSP 7.0 as well, the ls command was changed. # Previous versions default to a single column listing, which is nice and easy to parse. @@ -2061,30 +3485,72 @@ async def __setup_list_directory(self): except TargetStableError: self.ls_command = 'ls' - async def _list_directory(self, path, as_root=False): + async def _list_directory(self, path: str, as_root: bool = False) -> List[str]: + """ + Implementation of :meth:`list_directory` for Android. Uses an ls command + that might be adjusted depending on OS version. + + :param path: Directory path on the device. + :type path: str + :param as_root: If True, escalate privileges for listing. + :type as_root: bool + :return: A list of file/directory names in the specified path. + :rtype: list of str + :raises TargetStableError: If the directory doesn't exist or can't be listed. + """ if self.ls_command == '': await self.__setup_list_directory.asyn() contents = await self.execute.asyn('{} {}'.format(self.ls_command, quote(path)), as_root=as_root) return [x.strip() for x in contents.split('\n') if x.strip()] @asyn.asyncf - async def install(self, filepath, timeout=None, with_name=None): # pylint: disable=W0221 - ext = os.path.splitext(filepath)[1].lower() + async def install(self, filepath: str, timeout: Optional[int] = None, + with_name: Optional[str] = None) -> str: # pylint: disable=W0221 + """ + Install a file (APK or binary) onto the Android device. If the file is an APK, + use :meth:`install_apk`; otherwise, use :meth:`install_executable`. + + :param filepath: Path on the host to the file (APK or binary). + :type filepath: str + :param timeout: Optional time in seconds to allow the install. + :type timeout: int, optional + :param with_name: If installing a binary, rename it on the device. Ignored for APKs. + :type with_name: str, optional + :return: The path or package installed on the device. + :rtype: str + :raises TargetStableError: If the file extension is unsupported or installation fails. + """ + ext: str = os.path.splitext(filepath)[1].lower() if ext == '.apk': return await self.install_apk.asyn(filepath, timeout) else: return await self.install_executable.asyn(filepath, with_name, timeout) @asyn.asyncf - async def uninstall(self, name): + async def uninstall(self, name: str) -> None: + """ + Uninstall either a package (if installed as an APK) or an executable from + the device. + + :param name: The package name or binary name to remove. + :type name: str + """ if await self.package_is_installed.asyn(name): await self.uninstall_package.asyn(name) else: await self.uninstall_executable.asyn(name) @asyn.asyncf - async def get_pids_of(self, process_name): - result = [] + async def get_pids_of(self, process_name: str) -> List[int]: + """ + Return a list of process IDs (PIDs) for any processes matching ``process_name``. + + :param process_name: The substring or name to search for in the command name. + :type process_name: str + :return: List of integer PIDs matching the name. + :rtype: list of int + """ + result: List[int] = [] search_term = process_name[-15:] for entry in await self.ps.asyn(): if search_term in entry.name: @@ -2092,7 +3558,19 @@ async def get_pids_of(self, process_name): return result @asyn.asyncf - async def ps(self, threads=False, **kwargs): + async def ps(self, threads: bool = False, **kwargs: Dict[str, Any]) -> List['PsEntry']: + """ + Return a list of process entries on the device (like ``ps`` output), + optionally including thread info if ``threads=True``. + + :param threads: If True, use ``ps -AT`` to include threads. + :type threads: bool + :param kwargs: Key/value filters to match against the returned attributes + (like user, name, etc.). + :return: A list of PsEntry objects matching the filter. + :rtype: list of PsEntry + :raises TargetStableError: If the command fails or ps output is malformed. + """ maxsplit = 9 if threads else 8 command = 'ps' if threads: @@ -2100,13 +3578,13 @@ async def ps(self, threads=False, **kwargs): lines = iter(convert_new_lines(await self.execute.asyn(command)).split('\n')) next(lines) # header - result = [] + result: List['PsEntry'] = [] for line in lines: - parts = line.split(None, maxsplit) + parts: List[str] = line.split(None, maxsplit) if not parts: continue - wchan_missing = False + wchan_missing: bool = False if len(parts) == maxsplit: wchan_missing = True @@ -2121,30 +3599,61 @@ async def ps(self, threads=False, **kwargs): if not kwargs: return result else: - filtered_result = [] + filtered_result: List['PsEntry'] = [] for entry in result: if all(getattr(entry, k) == v for k, v in kwargs.items()): filtered_result.append(entry) return filtered_result @asyn.asyncf - async def capture_screen(self, filepath): - on_device_file = self.path.join(self.working_directory, 'screen_capture.png') - cmd = 'screencap -p {} && {} date -u -Iseconds' - ts = (await self.execute.asyn(cmd.format(quote(on_device_file), quote(self.busybox)))).strip() - filepath = filepath.format(ts=ts) - await self.pull.asyn(on_device_file, filepath) - await self.remove.asyn(on_device_file) + async def capture_screen(self, filepath: str) -> None: + """ + Take a screenshot on the device and save it to the specified file on the + host. This may not be supported by the target. You can optionally insert a + ``{ts}`` tag into the file name, in which case it will be substituted with + on-target timestamp of the screen shot in ISO8601 format. + + :param filepath: The host file path to store the screenshot. E.g. + ``"my_screenshot_{ts}.png"`` + :type filepath: str + :raises TargetStableError: If the device lacks a necessary screenshot tool (e.g. screencap). + """ + if self.path and self.working_directory: + on_device_file: str = self.path.join(self.working_directory, 'screen_capture.png') + cmd = 'screencap -p {} && {} date -u -Iseconds' + if self.busybox: + ts = (await self.execute.asyn(cmd.format(quote(on_device_file), quote(self.busybox)))).strip() + filepath = filepath.format(ts=ts) + await self.pull.asyn(on_device_file, filepath) + await self.remove.asyn(on_device_file) # Android-specific @asyn.asyncf - async def input_tap(self, x, y): + async def input_tap(self, x: int, y: int) -> None: + """ + Simulate a tap/click event at (x, y) on the device screen. + + :param x: The horizontal coordinate (pixels). + :type x: int + :param y: The vertical coordinate (pixels). + :type y: int + :raises TargetStableError: If the ``input`` command is not found or fails. + """ command = 'input tap {} {}' await self.execute.asyn(command.format(x, y)) @asyn.asyncf - async def input_tap_pct(self, x, y): + async def input_tap_pct(self, x: int, y: int): + """ + Simulate a tap event using percentage-based coordinates, relative + to the device screen size. + + :param x: Horizontal position as a percentage of screen width (0 to 100). + :type x: int + :param y: Vertical position as a percentage of screen height (0 to 100). + :type y: int + """ width, height = self.screen_resolution x = (x * width) // 100 @@ -2153,19 +3662,35 @@ async def input_tap_pct(self, x, y): await self.input_tap.asyn(x, y) @asyn.asyncf - async def input_swipe(self, x1, y1, x2, y2): + async def input_swipe(self, x1: int, y1: int, x2: int, y2: int) -> None: """ - Issue a swipe on the screen from (x1, y1) to (x2, y2) - Uses absolute screen positions + Issue a swipe gesture from (x1, y1) to (x2, y2), using absolute pixel coordinates. + + :param x1: Start X coordinate in pixels. + :type x1: int + :param y1: Start Y coordinate in pixels. + :type y1: int + :param x2: End X coordinate in pixels. + :type x2: int + :param y2: End Y coordinate in pixels. + :type y2: int """ command = 'input swipe {} {} {} {}' await self.execute.asyn(command.format(x1, y1, x2, y2)) @asyn.asyncf - async def input_swipe_pct(self, x1, y1, x2, y2): + async def input_swipe_pct(self, x1: int, y1: int, x2: int, y2: int) -> None: """ - Issue a swipe on the screen from (x1, y1) to (x2, y2) - Uses percent-based positions + Issue a swipe gesture from (x1, y1) to (x2, y2) using percentage-based coordinates. + + :param x1: Horizontal start percentage (0-100). + :type x1: int + :param y1: Vertical start percentage (0-100). + :type y1: int + :param x2: Horizontal end percentage (0-100). + :type x2: int + :param y2: Vertical end percentage (0-100). + :type y2: int """ width, height = self.screen_resolution @@ -2177,7 +3702,15 @@ async def input_swipe_pct(self, x1, y1, x2, y2): await self.input_swipe.asyn(x1, y1, x2, y2) @asyn.asyncf - async def swipe_to_unlock(self, direction="diagonal"): + async def swipe_to_unlock(self, direction: str = "diagonal") -> None: + """ + Attempt to swipe the lock screen open. Common directions are ``"horizontal"``, + ``"vertical"``, or ``"diagonal"``. + + :param direction: The direction to swipe; defaults to diagonal for maximum coverage. + :type direction: str + :raises TargetStableError: If the direction is invalid or the swipe fails. + """ width, height = self.screen_resolution if direction == "diagonal": start = 100 @@ -2190,21 +3723,41 @@ async def swipe_to_unlock(self, direction="diagonal"): stop = width - start await self.input_swipe.asyn(start, swipe_height, stop, swipe_height) elif direction == "vertical": - swipe_middle = width / 2 + swipe_middle = width // 2 swipe_height = height * 2 // 3 await self.input_swipe.asyn(swipe_middle, swipe_height, swipe_middle, 0) else: raise TargetStableError("Invalid swipe direction: {}".format(direction)) @asyn.asyncf - async def getprop(self, prop=None): + async def getprop(self, prop: Optional[str] = None) -> Optional[Union[str, AndroidProperties]]: + """ + Fetch properties from Android's ``getprop``. If ``prop`` is given, + return just that property’s value; otherwise return a dictionary-like + :class:`AndroidProperties`. + + :param prop: A specific property key to retrieve (e.g. "ro.build.version.sdk"). + :type prop: str, optional + :return: + - If ``prop`` is None, a dictionary-like object mapping all property keys to values. + - If ``prop`` is non-empty, the string value of that specific property. + :rtype: AndroidProperties or str + """ props = AndroidProperties(await self.execute.asyn('getprop')) if prop: return props[prop] return props @asyn.asyncf - async def capture_ui_hierarchy(self, filepath): + async def capture_ui_hierarchy(self, filepath: str) -> None: + """ + Capture the current UI hierarchy via ``uiautomator dump``, pull it to + the host, and optionally format it with pretty XML. + + :param filepath: The host file path to save the UI hierarchy XML. + :type filepath: str + :raises TargetStableError: If the device cannot produce a dump or fails to store it. + """ on_target_file = self.get_workpath('screen_capture.xml') try: await self.execute.asyn('uiautomator dump {}'.format(on_target_file)) @@ -2212,26 +3765,52 @@ async def capture_ui_hierarchy(self, filepath): finally: await self.remove.asyn(on_target_file) - parsed_xml = xml.dom.minidom.parse(filepath) + parsed_xml: Document = xml.dom.minidom.parse(filepath) with open(filepath, 'w') as f: f.write(parsed_xml.toprettyxml()) @asyn.asyncf - async def is_installed(self, name): + async def is_installed(self, name: str) -> bool: + """ + Returns ``True`` if an executable with the specified name is installed on the + target and ``False`` other wise. + """ return (await super(AndroidTarget, self).is_installed.asyn(name)) or (await self.package_is_installed.asyn(name)) @asyn.asyncf - async def package_is_installed(self, package_name): + async def package_is_installed(self, package_name: str) -> bool: + """ + Check if the given package name is installed on the device. + + :param package_name: Name of the Android package (e.g. "com.example.myapp"). + :type package_name: str + :return: True if installed, False otherwise. + :rtype: bool + """ return package_name in (await self.list_packages.asyn()) @asyn.asyncf - async def list_packages(self): - output = await self.execute.asyn('pm list packages') + async def list_packages(self) -> List[str]: + """ + Return a list of installed package names on the device (via ``pm list packages``). + + :return: A list of package identifiers. + :rtype: list of str + """ + output: str = await self.execute.asyn('pm list packages') output = output.replace('package:', '') return output.split() @asyn.asyncf - async def get_package_version(self, package): + async def get_package_version(self, package: str) -> Optional[str]: + """ + Obtain the versionName for a given package by parsing ``dumpsys package``. + + :param package: The package name (e.g. "com.example.myapp"). + :type package: str + :return: The versionName string if found, otherwise None. + :rtype: str or None + """ output = await self.execute.asyn('dumpsys package {}'.format(quote(package))) for line in convert_new_lines(output).split('\n'): if 'versionName' in line: @@ -2239,27 +3818,59 @@ async def get_package_version(self, package): return None @asyn.asyncf - async def get_package_info(self, package): - output = await self.execute.asyn('pm list packages -f {}'.format(quote(package))) + async def get_package_info(self, package: str) -> Optional['installed_package_info']: + """ + Return a tuple (apk_path, package_name) for the installed package, or None if not found. + + :param package: The package identifier (e.g. "com.example.myapp"). + :type package: str + :return: A namedtuple with fields (apk_path, package), or None. + :rtype: installed_package_info or None + """ + output: str = await self.execute.asyn('pm list packages -f {}'.format(quote(package))) for entry in output.strip().split('\n'): rest, entry_package = entry.rsplit('=', 1) if entry_package != package: continue _, apk_path = rest.split(':') return installed_package_info(apk_path, entry_package) + return None @asyn.asyncf - async def get_sdk_version(self): + async def get_sdk_version(self) -> Optional[int]: + """ + Return the integer value of ``ro.build.version.sdk`` if parseable; None if not. + + :return: e.g. 29 for Android 10, or None on error. + :rtype: int or None + """ try: return int(await self.getprop.asyn('ro.build.version.sdk')) except (ValueError, TypeError): return None @asyn.asyncf - async def install_apk(self, filepath, timeout=None, replace=False, allow_downgrade=False): # pylint: disable=W0221 - ext = os.path.splitext(filepath)[1].lower() + async def install_apk(self, filepath: str, timeout: Optional[int] = None, replace: Optional[bool] = False, + allow_downgrade: Optional[bool] = False) -> Optional[str]: # pylint: disable=W0221 + """ + Install an APK onto the device. If the device is connected via AdbConnection, + use an ADB install command. Otherwise, push it and run 'pm install'. + + :param filepath: The path to the APK on the host. + :type filepath: str + :param timeout: The time in seconds to wait for installation. + :type timeout: int, optional + :param replace: If True, pass -r to 'pm install' or `adb install`. + :type replace: bool + :param allow_downgrade: If True, allow installing an older version over a newer one. + :type allow_downgrade: bool + :return: The output from the install command, or None if something unexpected occurs. + :rtype: str or None + :raises TargetStableError: If the file is not an APK or installation fails. + """ + ext: str = os.path.splitext(filepath)[1].lower() if ext == '.apk': - flags = [] + flags: List[str] = [] if replace: flags.append('-r') # Replace existing APK if allow_downgrade: @@ -2273,80 +3884,132 @@ async def install_apk(self, filepath, timeout=None, replace=False, allow_downgra timeout=timeout, adb_server=self.adb_server, adb_port=self.adb_port) else: - dev_path = self.get_workpath(filepath.rsplit(os.path.sep, 1)[-1]) + dev_path: Optional[str] = self.get_workpath(filepath.rsplit(os.path.sep, 1)[-1]) await self.push.asyn(quote(filepath), dev_path, timeout=timeout) - result = await self.execute.asyn("pm install {} {}".format(' '.join(flags), quote(dev_path)), timeout=timeout) - await self.remove.asyn(dev_path) - return result + if dev_path: + result: str = await self.execute.asyn("pm install {} {}".format(' '.join(flags), quote(dev_path)), timeout=timeout) + await self.remove.asyn(dev_path) + return result + else: + raise TargetStableError('Can\'t install. could not get dev path') else: raise TargetStableError('Can\'t install {}: unsupported format.'.format(filepath)) @asyn.asyncf - async def grant_package_permission(self, package, permission): + async def grant_package_permission(self, package: str, permission: str) -> None: + """ + Run `pm grant `. Ignores some errors if the permission + cannot be granted. This is typically used for runtime permissions on modern Android. + + :param package: The target package. + :type package: str + :param permission: The permission string to grant (e.g. "android.permission.READ_LOGS"). + :type permission: str + :raises TargetStableError: If some unexpected error occurs that is not a known ignorable case. + """ try: return await self.execute.asyn('pm grant {} {}'.format(quote(package), quote(permission))) except TargetStableError as e: - if 'is not a changeable permission type' in e.message: - pass # Ignore if unchangeable - elif 'Unknown permission' in e.message: - pass # Ignore if unknown - elif 'has not requested permission' in e.message: - pass # Ignore if not requested - elif 'Operation not allowed' in e.message: - pass # Ignore if not allowed - elif 'is managed by role' in e.message: - pass # Ignore if cannot be granted + if isinstance(e.message, str): + if 'is not a changeable permission type' in e.message: + pass # Ignore if unchangeable + elif 'Unknown permission' in e.message: + pass # Ignore if unknown + elif 'has not requested permission' in e.message: + pass # Ignore if not requested + elif 'Operation not allowed' in e.message: + pass # Ignore if not allowed + elif 'is managed by role' in e.message: + pass # Ignore if cannot be granted + else: + raise else: raise @asyn.asyncf - async def refresh_files(self, file_list): + async def refresh_files(self, file_list: List[str]) -> None: """ - Depending on the android version and root status, determine the - appropriate method of forcing a re-index of the mediaserver cache for a given - list of files. + Attempt to force a re-index of the device media scanner for the given files. + On newer Android (7+), if not rooted, we fallback to scanning each file individually. + + :param file_list: A list of file paths on the device that may need indexing (e.g. new media). + :type file_list: list of str """ - if self.is_rooted or (await self.get_sdk_version.asyn()) < 24: # MM and below - common_path = commonprefix(file_list, sep=self.path.sep) + if self.path and (self.is_rooted or (await self.get_sdk_version.asyn()) < 24): # MM and below + common_path: str = commonprefix(file_list, sep=self.path.sep) await self.broadcast_media_mounted.asyn(common_path, self.is_rooted) else: for f in file_list: await self.broadcast_media_scan_file.asyn(f) @asyn.asyncf - async def broadcast_media_scan_file(self, filepath): + async def broadcast_media_scan_file(self, filepath: str) -> None: """ - Force a re-index of the mediaserver cache for the specified file. + Send a broadcast intent to the Android media scanner for a single file path. + + :param filepath: File path on the device to be scanned by mediaserver. + :type filepath: str """ command = 'am broadcast -a android.intent.action.MEDIA_SCANNER_SCAN_FILE -d {}' await self.execute.asyn(command.format(quote('file://' + filepath))) @asyn.asyncf - async def broadcast_media_mounted(self, dirpath, as_root=False): + async def broadcast_media_mounted(self, dirpath: str, as_root: bool = False) -> None: """ - Force a re-index of the mediaserver cache for the specified directory. + Broadcast that media at a directory path is newly mounted, prompting scanning + of its contents. + + :param dirpath: Directory path on the device. + :type dirpath: str + :param as_root: If True, escalate privileges for the broadcast command. + :type as_root: bool """ command = 'am broadcast -a android.intent.action.MEDIA_MOUNTED -d {} '\ '-n com.android.providers.media/.MediaScannerReceiver' - await self.execute.asyn(command.format(quote('file://'+dirpath)), as_root=as_root) + await self.execute.asyn(command.format(quote('file://' + dirpath)), as_root=as_root) @asyn.asyncf - async def install_executable(self, filepath, with_name=None, timeout=None): + async def install_executable(self, filepath: str, with_name: Optional[str] = None, + timeout: Optional[int] = None) -> Optional[str]: + """ + Install a single executable (non-APK) onto the device. Typically places + it in :attr:`executables_directory`, making it executable with chmod. + + :param filepath: The path on the host to the binary. + :type filepath: str + :param with_name: Optional name to rename the binary on the device. + :type with_name: str, optional + :param timeout: Time in seconds to allow the push & setup. + :type timeout: int, optional + :return: Path to the installed binary on the device, or None on failure. + :rtype: str or None + :raises TargetStableError: If the push or setup steps fail. + """ self._ensure_executables_directory_is_writable() - executable_name = with_name or os.path.basename(filepath) - on_device_file = self.path.join(self.working_directory, executable_name) - on_device_executable = self.path.join(self.executables_directory, executable_name) - await self.push.asyn(filepath, on_device_file, timeout=timeout) - if on_device_file != on_device_executable: - await self.execute.asyn('cp -f -- {} {}'.format(quote(on_device_file), quote(on_device_executable)), - as_root=self.needs_su, timeout=timeout) - await self.remove.asyn(on_device_file, as_root=self.needs_su) - await self.execute.asyn("chmod 0777 {}".format(quote(on_device_executable)), as_root=self.needs_su) - self._installed_binaries[executable_name] = on_device_executable - return on_device_executable - - @asyn.asyncf - async def uninstall_package(self, package): + executable_name: str = with_name or os.path.basename(filepath) + if self.path: + on_device_file: str = self.path.join(self.working_directory, executable_name) + on_device_executable: str = self.path.join(self.executables_directory, executable_name) + await self.push.asyn(filepath, on_device_file, timeout=timeout) + if on_device_file != on_device_executable: + await self.execute.asyn('cp -f -- {} {}'.format(quote(on_device_file), quote(on_device_executable)), + as_root=self.needs_su, timeout=timeout) + await self.remove.asyn(on_device_file, as_root=self.needs_su) + await self.execute.asyn("chmod 0777 {}".format(quote(on_device_executable)), as_root=self.needs_su) + self._installed_binaries[executable_name] = on_device_executable + return on_device_executable + else: + raise TargetStableError('path is not assigned') + + @asyn.asyncf + async def uninstall_package(self, package: str) -> None: + """ + Uninstall an Android package by name (using ``adb uninstall`` or + ``pm uninstall``). + + :param package: The package name to remove. + :type package: str + """ if isinstance(self.conn, AdbConnection): adb_command(self.adb_name, "uninstall {}".format(quote(package)), timeout=30, adb_server=self.adb_server, adb_port=self.adb_port) @@ -2354,14 +4017,36 @@ async def uninstall_package(self, package): await self.execute.asyn("pm uninstall {}".format(quote(package)), timeout=30) @asyn.asyncf - async def uninstall_executable(self, executable_name): - on_device_executable = self.path.join(self.executables_directory, executable_name) - self._ensure_executables_directory_is_writable() - await self.remove.asyn(on_device_executable, as_root=self.needs_su) + async def uninstall_executable(self, executable_name: str) -> None: + """ + Remove an installed executable from :attr:`executables_directory`. + + :param executable_name: The name of the binary to remove. + :type executable_name: str + """ + if self.path: + on_device_executable = self.path.join(self.executables_directory, executable_name) + self._ensure_executables_directory_is_writable() + await self.remove.asyn(on_device_executable, as_root=self.needs_su) @asyn.asyncf - async def dump_logcat(self, filepath, filter=None, logcat_format=None, append=False, - timeout=60): # pylint: disable=redefined-builtin + async def dump_logcat(self, filepath: str, filter: Optional[str] = None, + logcat_format: Optional[str] = None, + append: bool = False, timeout: int = 60) -> None: # pylint: disable=redefined-builtin + """ + Collect logcat output from the device and save it to ``filepath`` on the host. + + :param filepath: The file on the host to store the log output. + :type filepath: str + :param filter: If provided, a filter specifying which tags to match (e.g. '-s MyTag'). + :type filter: str, optional + :param logcat_format: Logcat format (e.g., 'threadtime'), if any. + :type logcat_format: str, optional + :param append: If True, append to the host file instead of overwriting. + :type append: bool + :param timeout: How many seconds to allow for reading the log. + :type timeout: int + """ op = '>>' if append else '>' filtstr = ' -s {}'.format(quote(filter)) if filter else '' formatstr = ' -v {}'.format(quote(logcat_format)) if logcat_format else '' @@ -2372,13 +4057,17 @@ async def dump_logcat(self, filepath, filter=None, logcat_format=None, append=Fa adb_port=self.adb_port) else: dev_path = self.get_workpath('logcat') - command = 'logcat {} {} {}'.format(logcat_opts, op, quote(dev_path)) - await self.execute.asyn(command, timeout=timeout) - await self.pull.asyn(dev_path, filepath) - await self.remove.asyn(dev_path) + if dev_path: + command = 'logcat {} {} {}'.format(logcat_opts, op, quote(dev_path)) + await self.execute.asyn(command, timeout=timeout) + await self.pull.asyn(dev_path, filepath) + await self.remove.asyn(dev_path) @asyn.asyncf - async def clear_logcat(self): + async def clear_logcat(self) -> None: + """ + Clear the device's logcat (``logcat -c``). Uses a lock to avoid concurrency issues. + """ locked = self.clear_logcat_lock.acquire(blocking=False) if locked: try: @@ -2390,26 +4079,65 @@ async def clear_logcat(self): finally: self.clear_logcat_lock.release() - def get_logcat_monitor(self, regexps=None): + def get_logcat_monitor(self, regexps: Optional[List[str]] = None) -> LogcatMonitor: + """ + Create a :class:`LogcatMonitor` object for capturing logcat output from the device. + + :param regexps: An optional list of uncompiled regex strings to filter log entries. + :type regexps: list of str + :return: A new LogcatMonitor instance referencing this AndroidTarget. + :rtype: LogcatMonitor + """ return LogcatMonitor(self, regexps) @call_conn - def wait_for_device(self, timeout=30): - self.conn.wait_for_device() + def wait_for_device(self, timeout: int = 30) -> None: + """ + Instruct ADB to wait until the device is present (``adb wait-for-device``). + + :param timeout: Seconds to wait before failing. + :type timeout: int + :raises TargetStableError: If waiting times out or if the connection is not ADB. + """ + if isinstance(self.conn, AdbConnection): + self.conn.wait_for_device() @call_conn - def reboot_bootloader(self, timeout=30): - self.conn.reboot_bootloader() + def reboot_bootloader(self, timeout: int = 30) -> None: + """ + Reboot the device into fastboot/bootloader mode. + + :param timeout: Time in seconds to allow for device to transition. + :type timeout: int + :raises TargetStableError: If not using ADB or the command fails. + """ + if isinstance(self.conn, AdbConnection): + self.conn.reboot_bootloader() @asyn.asyncf - async def is_screen_locked(self): + async def is_screen_locked(self) -> bool: + """ + Determine if the lock screen is active (e.g., phone is locked). + + :return: True if the screen is locked, False otherwise. + :rtype: bool + """ screen_state = await self.execute.asyn('dumpsys window') return 'mDreamingLockscreen=true' in screen_state @asyn.asyncf - async def is_screen_on(self): - output = await self.execute.asyn('dumpsys power') - match = ANDROID_SCREEN_STATE_REGEX.search(output) + async def is_screen_on(self) -> bool: + """ + Check if the device screen is currently on. + + :return: + - True if the screen is on or in certain "doze" states. + - False if the screen is off or fully asleep. + :rtype: bool + :raises TargetStableError: If unable to parse display power state. + """ + output: str = await self.execute.asyn('dumpsys power') + match: Optional[Match[str]] = ANDROID_SCREEN_STATE_REGEX.search(output) if match: if 'DOZE' in match.group(1).upper(): return True @@ -2424,19 +4152,55 @@ async def is_screen_on(self): raise TargetStableError('Could not establish screen state.') @asyn.asyncf - async def ensure_screen_is_on(self, verify=True): + async def ensure_screen_is_on(self, verify: bool = True) -> None: + """ + If the screen is off, press the power button (keyevent 26) to wake it. + Optionally verify the screen is on afterwards. + + :param verify: If True, raise an error if the screen doesn't turn on. + :type verify: bool + :raises TargetStableError: If the screen is still off after the attempt. + """ if not await self.is_screen_on.asyn(): + # The adb shell input keyevent 26 command is used to + # simulate pressing the power button on an Android device. self.execute('input keyevent 26') if verify and not await self.is_screen_on.asyn(): raise TargetStableError('Display cannot be turned on.') @asyn.asyncf - async def ensure_screen_is_on_and_stays(self, verify=True, mode=7): + async def ensure_screen_is_on_and_stays(self, verify: bool = True, mode: int = 7) -> None: + """ + Calls ``AndroidTarget.ensure_screen_is_on(verify)`` then additionally + sets the screen stay on mode to ``mode``. + mode options - + 0: Never stay on while plugged in. + 1: Stay on while plugged into an AC charger. + 2: Stay on while plugged into a USB charger. + 4: Stay on while on a wireless charger. + You can combine these values using bitwise OR. + For example, 3 (1 | 2) will stay on while plugged into either an AC or USB charger + + :param verify: If True, check that the screen does come on. + :type verify: bool + :param mode: A bitwise combination of (1 for AC, 2 for USB, 4 for wireless). + :type mode: int + """ await self.ensure_screen_is_on.asyn(verify=verify) await self.set_stay_on_mode.asyn(mode) @asyn.asyncf - async def ensure_screen_is_off(self, verify=True): + async def ensure_screen_is_off(self, verify: bool = True) -> None: + """ + Checks if the devices screen is on and if so turns it off. + If ``verify`` is set to ``True`` then a ``TargetStableError`` + will be raise if the display cannot be turned off. E.g. if + always on mode is enabled. + + :param verify: Raise an error if the screen remains on afterwards. + :type verify: bool + :raises TargetStableError: If the display remains on due to always-on or lock states. + """ # Allow 2 attempts to help with cases of ambient display modes # where the first attempt will switch the display fully on. for _ in range(2): @@ -2444,21 +4208,41 @@ async def ensure_screen_is_off(self, verify=True): await self.execute.asyn('input keyevent 26') time.sleep(0.5) if verify and await self.is_screen_on.asyn(): - msg = 'Display cannot be turned off. Is always on display enabled?' - raise TargetStableError(msg) + msg: str = 'Display cannot be turned off. Is always on display enabled?' + raise TargetStableError(msg) @asyn.asyncf - async def set_auto_brightness(self, auto_brightness): + async def set_auto_brightness(self, auto_brightness: bool) -> None: + """ + Enable or disable automatic screen brightness. + + :param auto_brightness: True to enable auto-brightness, False to disable. + :type auto_brightness: bool + """ cmd = 'settings put system screen_brightness_mode {}' await self.execute.asyn(cmd.format(int(boolean(auto_brightness)))) @asyn.asyncf - async def get_auto_brightness(self): + async def get_auto_brightness(self) -> bool: + """ + Check if auto-brightness is enabled. + + :return: True if auto-brightness is on, False otherwise. + :rtype: bool + """ cmd = 'settings get system screen_brightness_mode' return boolean((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def set_brightness(self, value): + async def set_brightness(self, value: int) -> None: + """ + Manually set screen brightness to an integer between 0 and 255. + This also disables auto-brightness first. + + :param value: Desired brightness level (0-255). + :type value: int + :raises ValueError: If the given value is outside [0..255]. + """ if not 0 <= value <= 255: msg = 'Invalid brightness "{}"; Must be between 0 and 255' raise ValueError(msg.format(value)) @@ -2467,69 +4251,148 @@ async def set_brightness(self, value): await self.execute.asyn(cmd.format(int(value))) @asyn.asyncf - async def get_brightness(self): + async def get_brightness(self) -> int: + """ + Return the current screen brightness (0..255). + + :return: The brightness setting. + :rtype: int + """ cmd = 'settings get system screen_brightness' return integer((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def set_screen_timeout(self, timeout_ms): + async def set_screen_timeout(self, timeout_ms: int) -> None: + """ + Set the screen-off timeout in milliseconds. + + :param timeout_ms: Number of ms before the screen turns off when idle. + :type timeout_ms: int + """ cmd = 'settings put system screen_off_timeout {}' await self.execute.asyn(cmd.format(int(timeout_ms))) @asyn.asyncf - async def get_screen_timeout(self): + async def get_screen_timeout(self) -> int: + """ + Get the screen-off timeout (ms). + + :return: Milliseconds before screen turns off. + :rtype: int + """ cmd = 'settings get system screen_off_timeout' return int((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def get_airplane_mode(self): + async def get_airplane_mode(self) -> bool: + """ + Check if airplane mode is active (global setting). + + .. note:: Requires the device to be rooted if the device is running Android 7+. + + :return: True if airplane mode is on, otherwise False. + :rtype: bool + """ cmd = 'settings get global airplane_mode_on' return boolean((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def get_stay_on_mode(self): + async def get_stay_on_mode(self) -> int: + """ + Returns an integer between ``0`` and ``7`` representing the current + stay-on mode of the device. + 0: Never stay on while plugged in. + 1: Stay on while plugged into an AC charger. + 2: Stay on while plugged into a USB charger. + 4: Stay on while on a wireless charger. + Combinations of these values can be used (e.g., 3 for both AC and USB chargers) + + :return: The integer bitmask (0..7). + :rtype: int + """ cmd = 'settings get global stay_on_while_plugged_in' return int((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def set_airplane_mode(self, mode): - root_required = await self.get_sdk_version.asyn() > 23 + async def set_airplane_mode(self, mode: bool) -> None: + """ + Enable or disable airplane mode. On Android 7+, requires root. + + :param mode: True to enable airplane mode, False to disable. + :type mode: bool + :raises TargetStableError: If root is required but the device is not rooted. + """ + root_required: bool = await self.get_sdk_version.asyn() > 23 if root_required and not self.is_rooted: raise TargetStableError('Root is required to toggle airplane mode on Android 7+') - mode = int(boolean(mode)) + modeint = int(boolean(mode)) cmd = 'settings put global airplane_mode_on {}' - await self.execute.asyn(cmd.format(mode)) + await self.execute.asyn(cmd.format(modeint)) await self.execute.asyn('am broadcast -a android.intent.action.AIRPLANE_MODE ' - '--ez state {}'.format(mode), as_root=root_required) + '--ez state {}'.format(mode), as_root=root_required) @asyn.asyncf - async def get_auto_rotation(self): + async def get_auto_rotation(self) -> bool: + """ + Check if auto-rotation is enabled (system setting). + + :return: True if accelerometer-based rotation is enabled, False otherwise. + :rtype: bool + """ cmd = 'settings get system accelerometer_rotation' return boolean((await self.execute.asyn(cmd)).strip()) @asyn.asyncf - async def set_auto_rotation(self, autorotate): + async def set_auto_rotation(self, autorotate: bool) -> None: + """ + Enable or disable auto-rotation of the screen. + + :param autorotate: True to enable, False to disable. + :type autorotate: bool + """ cmd = 'settings put system accelerometer_rotation {}' await self.execute.asyn(cmd.format(int(boolean(autorotate)))) @asyn.asyncf - async def set_natural_rotation(self): + async def set_natural_rotation(self) -> None: + """ + Sets the screen orientation of the device to its natural (0 degrees) + orientation. + """ await self.set_rotation.asyn(0) @asyn.asyncf - async def set_left_rotation(self): + async def set_left_rotation(self) -> None: + """ + Sets the screen orientation of the device to 90 degrees. + """ await self.set_rotation.asyn(1) @asyn.asyncf - async def set_inverted_rotation(self): + async def set_inverted_rotation(self) -> None: + """ + Sets the screen orientation of the device to its inverted (180 degrees) + orientation. + """ await self.set_rotation.asyn(2) @asyn.asyncf - async def set_right_rotation(self): + async def set_right_rotation(self) -> None: + """ + Sets the screen orientation of the device to 270 degrees. + """ await self.set_rotation.asyn(3) @asyn.asyncf - async def get_rotation(self): + async def get_rotation(self) -> Optional[int]: + """ + Returns an integer value representing the orientation of the devices + screen. ``0`` : Natural, ``1`` : Rotated Left, ``2`` : Inverted + and ``3`` : Rotated Right. + + :return: The rotation value or None if not found. + :rtype: int or None + """ output = await self.execute.asyn('dumpsys input') match = ANDROID_SCREEN_ROTATION_REGEX.search(output) if match: @@ -2538,7 +4401,16 @@ async def get_rotation(self): return None @asyn.asyncf - async def set_rotation(self, rotation): + async def set_rotation(self, rotation: int) -> None: + """ + Specify an integer representing the desired screen rotation with the + following mappings: Natural: ``0``, Rotated Left: ``1``, Inverted : ``2`` + and Rotated Right : ``3``. + + :param rotation: Integer in [0..3]. + :type rotation: int + :raises ValueError: If rotation is not within [0..3]. + """ if not 0 <= rotation <= 3: raise ValueError('Rotation value must be between 0 and 3') await self.set_auto_rotation.asyn(False) @@ -2546,54 +4418,87 @@ async def set_rotation(self, rotation): await self.execute.asyn(cmd.format(rotation)) @asyn.asyncf - async def set_stay_on_never(self): + async def set_stay_on_never(self) -> None: + """ + Sets the stay-on mode to ``0``, where the screen will turn off + as standard after the timeout. + """ await self.set_stay_on_mode.asyn(0) @asyn.asyncf - async def set_stay_on_while_powered(self): + async def set_stay_on_while_powered(self) -> None: + """ + Sets the stay-on mode to ``7``, where the screen will stay on + while the device is charging + """ await self.set_stay_on_mode.asyn(7) @asyn.asyncf - async def set_stay_on_mode(self, mode): + async def set_stay_on_mode(self, mode: int) -> None: + """ + 0: Never stay on while plugged in. + 1: Stay on while plugged into an AC charger. + 2: Stay on while plugged into a USB charger. + 4: Stay on while on a wireless charger. + You can combine these values using bitwise OR. + For example, 3 (1 | 2) will stay on while plugged into either an AC or USB charger + + :param mode: Value in [0..7]. + :type mode: int + :raises ValueError: If outside [0..7]. + """ if not 0 <= mode <= 7: raise ValueError('Screen stay on mode must be between 0 and 7') cmd = 'settings put global stay_on_while_plugged_in {}' await self.execute.asyn(cmd.format(mode)) @asyn.asyncf - async def open_url(self, url, force_new=False): + async def open_url(self, url: str, force_new: bool = False) -> None: """ - Start a view activity by specifying an URL + Launch an intent to view a given URL, optionally forcing a new task in + the activity stack. - :param url: URL of the item to display + :param url: URL to open (e.g. "https://www.example.com"). :type url: str - - :param force_new: Force the viewing application to be relaunched - if it is already running + :param force_new: If True, use flags to clear the existing activity stack, + forcing a fresh activity. :type force_new: bool """ cmd = 'am start -a android.intent.action.VIEW -d {}' if force_new: - cmd = cmd + ' -f {}'.format(INTENT_FLAGS['ACTIVITY_NEW_TASK'] | - INTENT_FLAGS['ACTIVITY_CLEAR_TASK']) + cmd = cmd + ' -f {}'.format(INTENT_FLAGS['ACTIVITY_NEW_TASK'] | INTENT_FLAGS['ACTIVITY_CLEAR_TASK']) await self.execute.asyn(cmd.format(quote(url))) @asyn.asyncf - async def homescreen(self): + async def homescreen(self) -> None: + """ + Return to the home screen by launching the MAIN/HOME intent. + """ await self.execute.asyn('am start -a android.intent.action.MAIN -c android.intent.category.HOME') - def _resolve_paths(self): - if self.working_directory is None: - self.working_directory = self.path.join(self.external_storage, 'devlib-target') - self._file_transfer_cache = self.path.join(self.working_directory, '.file-cache') + def _resolve_paths(self) -> None: + """ + Finalize the paths for working directory, executables directory, etc. + If not user-defined, default them to a known location on the device. + """ + if (self.working_directory is None) and self.path: + self.working_directory: str = self.path.join(self.external_storage, 'devlib-target') + if self.path: + self._file_transfer_cache: str = self.path.join(self.working_directory, '.file-cache') if self.executables_directory is None: - self.executables_directory = '/data/local/tmp/bin' + self.executables_directory: str = '/data/local/tmp/bin' @asyn.asyncf - async def _ensure_executables_directory_is_writable(self): - matched = [] + async def _ensure_executables_directory_is_writable(self) -> None: + """ + Check if the executables directory is on a writable mount. If not, attempt + to remount it read/write as root. + + :raises TargetStableError: If the directory cannot be remounted or found in fstab. + """ + matched: List['FstabEntry'] = [] for entry in await self.list_file_systems.asyn(): if self.executables_directory.rstrip('/').startswith(entry.mount_point): matched.append(entry) @@ -2601,8 +4506,8 @@ async def _ensure_executables_directory_is_writable(self): entry = sorted(matched, key=lambda x: len(x.mount_point))[-1] if 'rw' not in entry.options: await self.execute.asyn('mount -o rw,remount {} {}'.format(quote(entry.device), - quote(entry.mount_point)), - as_root=True) + quote(entry.mount_point)), + as_root=True) else: message = 'Could not find mount point for executables directory {}' raise TargetStableError(message.format(self.executables_directory)) @@ -2610,87 +4515,131 @@ async def _ensure_executables_directory_is_writable(self): _charging_enabled_path = '/sys/class/power_supply/battery/charging_enabled' @property - def charging_enabled(self): + def charging_enabled(self) -> Optional[bool]: """ Whether drawing power to charge the battery is enabled Not all devices have the ability to enable/disable battery charging (e.g. because they don't have a battery). In that case, ``charging_enabled`` is None. + + :return: + - True if charging is enabled + - False if disabled + - None if the sysfs entry is absent + :rtype: bool or None """ if not self.file_exists(self._charging_enabled_path): return None return self.read_bool(self._charging_enabled_path) @charging_enabled.setter - def charging_enabled(self, enabled): + def charging_enabled(self, enabled: bool) -> None: """ Enable/disable drawing power to charge the battery Not all devices have this facility. In that case, do nothing. + + :param enabled: True to enable charging, False to disable. + :type enabled: bool """ if not self.file_exists(self._charging_enabled_path): return self.write_value(self._charging_enabled_path, int(bool(enabled))) + FstabEntry = namedtuple('FstabEntry', ['device', 'mount_point', 'fs_type', 'options', 'dump_freq', 'pass_num']) PsEntry = namedtuple('PsEntry', 'user pid tid ppid vsize rss wchan pc state name') LsmodEntry = namedtuple('LsmodEntry', ['name', 'size', 'use_count', 'used_by']) class Cpuinfo(object): - + """ + Represents the parsed contents of ``/proc/cpuinfo`` on the target. + + :param sections: A list of dictionaries, where each dictionary represents a + block of lines corresponding to a CPU. Key-value pairs correspond to + lines like ``CPU part: 0xd03`` or ``model name: Cortex-A53``. + :type sections: list of dict + :param text: The full text of the original ``/proc/cpuinfo`` content. + :type text: str + """ @property @memoized - def architecture(self): - for section in self.sections: - if 'CPU architecture' in section: - return section['CPU architecture'] - if 'architecture' in section: - return section['architecture'] + def architecture(self) -> Optional[str]: + """ + architecture as per cpuinfo + """ + if self.sections: + for section in self.sections: + if 'CPU architecture' in section: + return section['CPU architecture'] + if 'architecture' in section: + return section['architecture'] + return None @property @memoized - def cpu_names(self): - cpu_names = [] - global_name = None - for section in self.sections: - if 'processor' in section: - if 'CPU part' in section: - cpu_names.append(_get_part_name(section)) - elif 'model name' in section: - cpu_names.append(_get_model_name(section)) - else: - cpu_names.append(None) - elif 'CPU part' in section: - global_name = _get_part_name(section) + def cpu_names(self) -> List[caseless_string]: + """ + A list of CPU names derived from fields like ``CPU part`` or ``model name``. + If found globally, that name is reused for each CPU. If found per-CPU, + you get multiple entries. + + :return: List of CPU names, one per processor entry. + :rtype: list of caseless_string + """ + cpu_names: List[Optional[str]] = [] + global_name: Optional[str] = None + if self.sections: + for section in self.sections: + if 'processor' in section: + if 'CPU part' in section: + cpu_names.append(_get_part_name(section)) + elif 'model name' in section: + cpu_names.append(_get_model_name(section)) + else: + cpu_names.append(None) + elif 'CPU part' in section: + global_name = _get_part_name(section) return [caseless_string(c or global_name) for c in cpu_names] - def __init__(self, text): - self.sections = None - self.text = None + def __init__(self, text: str): + self.sections: List[Dict[str, str]] = [] + self.text = '' self.parse(text) @memoized - def get_cpu_features(self, cpuid=0): - global_features = [] - for section in self.sections: - if 'processor' in section: - if int(section.get('processor')) != cpuid: - continue - if 'Features' in section: - return section.get('Features').split() + def get_cpu_features(self, cpuid: int = 0) -> List[str]: + """ + get the Features field of the specified cpu + """ + global_features: List[str] = [] + if self.sections: + for section in self.sections: + if 'processor' in section: + if int(section.get('processor') or -1) != cpuid: + continue + if 'Features' in section: + return section.get('Features', '').split() + elif 'flags' in section: + return section.get('flags', '').split() + elif 'Features' in section: + global_features = section.get('Features', '').split() elif 'flags' in section: - return section.get('flags').split() - elif 'Features' in section: - global_features = section.get('Features').split() - elif 'flags' in section: - global_features = section.get('flags').split() + global_features = section.get('flags', '').split() return global_features - def parse(self, text): + def parse(self, text: str) -> None: + """ + Parse the provided ``/proc/cpuinfo`` text, splitting it into separate + sections for each CPU. + + :param text: The full multiline content of /proc/cpuinfo. + :type text: str + """ self.sections = [] - current_section = {} + current_section: Dict[str, str] = {} self.text = text.strip() for line in self.text.split('\n'): line = line.strip() @@ -2742,11 +4691,11 @@ class KernelVersion(object): lexicographically comparing kernel versions. :type parts: tuple(int) """ - def __init__(self, version_string): + def __init__(self, version_string: str): if ' #' in version_string: release, version = version_string.split(' #') - self.release = release - self.version = version + self.release: str = release + self.version: str = version elif version_string.startswith('#'): self.release = '' self.version = version_string @@ -2754,15 +4703,15 @@ def __init__(self, version_string): self.release = version_string self.version = '' - self.version_number = None - self.major = None - self.minor = None - self.sha1 = None - self.rc = None - self.commits = None - self.gki_abi = None - self.android_version = None - match = KVERSION_REGEX.match(version_string) + self.version_number: Optional[int] = None + self.major: Optional[int] = None + self.minor: Optional[int] = None + self.sha1: Optional[str] = None + self.rc: Optional[int] = None + self.commits: Optional[int] = None + self.gki_abi: Optional[str] = None + self.android_version: Optional[int] = None + match: Optional[Match[str]] = KVERSION_REGEX.match(version_string) if match: groups = match.groupdict() self.version_number = int(groups['version']) @@ -2780,7 +4729,7 @@ def __init__(self, version_string): if groups['android_version'] is not None: self.android_version = int(match.group('android_version')) - self.parts = (self.version_number, self.major, self.minor) + self.parts: Tuple[Optional[int], Optional[int], Optional[int]] = (self.version_number, self.major, self.minor) def __str__(self): return '{} {}'.format(self.release, self.version) @@ -2788,68 +4737,141 @@ def __str__(self): __repr__ = __str__ -class HexInt(long): +class HexInt(int): """ - Subclass of :class:`int` that uses hexadecimal formatting by default. + An int subclass that is displayed in hexadecimal form. + + Example usage: + + .. code-block:: python + + val = HexInt('FF') # Parse hex string as int + print(val) # Prints: 0xff + print(int(val)) # Prints: 255 """ - def __new__(cls, val=0, base=16): + def __new__(cls, val: Union[str, int, bytearray] = 0, base=16): + """ + Construct a HexInt object, interpreting ``val`` as a base-16 value + unless it's already a number or bytearray. + + :param val: The initial value. If str, is parsed as base-16 by default; + if int or bytearray, used directly. + :type val: str or int or bytearray + :param base: Numerical base (defaults to 16). + :type base: int + :raises TypeError: If ``val`` is not a supported type (str, int, or bytearray). + """ super_new = super(HexInt, cls).__new__ if isinstance(val, Number): return super_new(cls, val) + elif isinstance(val, bytearray): + val = int.from_bytes(val, byteorder=sys.byteorder) + return super(HexInt, cls).__new__(cls, val) + elif isinstance(val, str): + return super(HexInt, cls).__new__(cls, int(val, base)) else: - return super_new(cls, val, base=base) + raise TypeError("Unsupported type for HexInt") def __str__(self): + """ + Return a hexadecimal string representation of the integer, stripping + any trailing ``L`` in Python 2.x. + """ return hex(self).strip('L') class KernelConfigTristate(Enum): + """ + Represents a kernel config option that may be ``y``, ``n``, or ``m``. + Commonly seen in kernel ``.config`` files as: + + - ``CONFIG_FOO=y`` + - ``CONFIG_BAR=n`` + - ``CONFIG_BAZ=m`` + + Enum members: + * ``YES`` -> 'y' + * ``NO`` -> 'n' + * ``MODULE`` -> 'm' + """ YES = 'y' NO = 'n' MODULE = 'm' def __bool__(self): """ - Allow using this enum to represent bool Kconfig type, although it is - technically different from tristate. + Allow usage in boolean contexts: + + * True if the config is 'y' or 'm' + * False if the config is 'n' """ return self in (self.YES, self.MODULE) def __nonzero__(self): """ - For Python 2.x compatibility. + Python 2.x compatibility for boolean evaluation. """ return self.__bool__() @classmethod - def from_str(cls, str_): + def from_str(cls, str_: str) -> 'KernelConfigTristate': + """ + Convert a kernel config string ('y', 'n', or 'm') to the corresponding + enum member. + + :param str_: The single-character string from kernel config. + :type str_: str + :return: The enum member that matches the provided string. + :rtype: KernelConfigTristate + :raises ValueError: If the string is not 'y', 'n', or 'm'. + """ for state in cls: if state.value == str_: return state raise ValueError('No kernel config tristate value matches "{}"'.format(str_)) -class TypedKernelConfig(Mapping): +class TypedKernelConfig(Mapping): # type: ignore """ - Mapping-like typed version of :class:`KernelConfig`. + A mapping-like object representing typed kernel config parameters. Keys are + canonicalized config names (e.g. "CONFIG_FOO"), and values may be strings, ints, + :class:`HexInt`, or :class:`KernelConfigTristate`. + + :param not_set_regex: A regex that matches lines in the form ``# CONFIG_ABC is not set``. + :type not_set_regex: Pattern - Values are either :class:`str`, :class:`int`, - :class:`KernelConfigTristate`, or :class:`HexInt`. ``hex`` Kconfig type is - mapped to :class:`HexInt` and ``bool`` to :class:`KernelConfigTristate`. + :param mapping: An optional initial mapping of config keys to string values. + Typically set by parsing a kernel .config file or /proc/config.gz content. + :type mapping: Mapping or None """ not_set_regex = re.compile(r'# (\S+) is not set') @staticmethod - def get_config_name(name): + def get_config_name(name: str) -> str: + """ + Ensure the config name starts with 'CONFIG_', returning + the canonical form. + + :param name: A raw config key name (e.g. 'ABC'). + :type name: str + :return: The canonical name (e.g. 'CONFIG_ABC'). + :rtype: str + """ name = name.upper() if not name.startswith('CONFIG_'): name = 'CONFIG_' + name return name - def __init__(self, mapping=None): + def __init__(self, mapping: Optional[Maptype] = None): + """ + Initialize a typed kernel config from an existing dictionary or None. + + :param mapping: Existing config data (raw strings), keyed by config name. + :type mapping: Mapping, optional + """ mapping = mapping if mapping is not None else {} - self._config = { + self._config: Dict[str, str] = { # Ensure we use the canonical name of the config keys for internal # representation self.get_config_name(k): v @@ -2857,34 +4879,48 @@ def __init__(self, mapping=None): } @classmethod - def from_str(cls, text): + def from_str(cls, text: str) -> 'TypedKernelConfig': """ - Build a :class:`TypedKernelConfig` out of the string content of a - Kconfig file. + Build a typed config by parsing raw text of a kernel config file. + + :param text: Contents of the kernel config, including lines such as + ``CONFIG_ABC=y`` or ``# CONFIG_DEF is not set``. + :type text: str + :return: A :class:`TypedKernelConfig` reflecting typed config values. + :rtype: TypedKernelConfig """ return cls(cls._parse_text(text)) @staticmethod - def _val_to_str(val): + def _val_to_str(val: Optional[Union[KernelConfigTristate, str]]) -> str: "Convert back values to Kconfig-style string value" # Special case the gracefully handle the output of get() if val is None: - return None + return "" elif isinstance(val, KernelConfigTristate): return val.value - elif isinstance(val, basestring): + elif isinstance(val, str): return '"{}"'.format(val.strip('"')) else: return str(val) def __str__(self): + """ + Convert the typed config back to a kernel config-style string, e.g. + "CONFIG_FOO=y\nCONFIG_BAR=\"value\"\n..." + + :return: A multi-line string representation of the typed config. + :rtype: str + """ return '\n'.join( '{}={}'.format(k, self._val_to_str(v)) for k, v in self.items() ) @staticmethod - def _parse_val(k, v): + def _parse_val(k: str, v: Union[str, int, HexInt, + KernelConfigTristate]) -> Optional[Union[KernelConfigTristate, + HexInt, int, str]]: """ Parse a value of types handled by Kconfig: * string @@ -2896,43 +4932,54 @@ def _parse_val(k, v): Since bool cannot be distinguished from tristate, tristate is always used. :meth:`KernelConfigTristate.__bool__` will allow using it as a bool though, so it should not impact user code. + + :param k: The config key name (not used heavily). + :type k: str + :param v: The raw string or typed object. + :type v: str or int or KernelConfigTristate + :return: The typed version of the value. """ if not v: return None - # Handle "string" type - if v.startswith('"'): - # Strip enclosing " - return v[1:-1] + if isinstance(v, str): + # Handle "string" type + if v.startswith('"'): + # Strip enclosing " + return v[1:-1] - else: - try: - # Handles "bool" and "tristate" types - return KernelConfigTristate.from_str(v) - except ValueError: - pass + else: + try: + # Handles "bool" and "tristate" types + return KernelConfigTristate.from_str(v) + except ValueError: + pass - try: - # Handles "int" type - return int(v) - except ValueError: - pass + try: + # Handles "int" type + return int(v) + except ValueError: + pass - try: - # Handles "hex" type - return HexInt(v) - except ValueError: - pass + try: + # Handles "hex" type + return HexInt(v) + except ValueError: + pass - # If no type could be parsed - raise ValueError('Could not parse Kconfig key: {}={}'.format( + # If no type could be parsed + raise ValueError('Could not parse Kconfig key: {}={}'.format( k, v ), k, v - ) + ) + return None @classmethod - def _parse_text(cls, text): - config = {} + def _parse_text(cls, text: str) -> Dict[str, Optional[Union[KernelConfigTristate, HexInt, int, str]]]: + """ + parse the kernel config text and create a dictionary of the configs + """ + config: Dict[str, Optional[Union[KernelConfigTristate, HexInt, int, str]]] = {} for line in text.splitlines(): line = line.strip() @@ -2943,19 +4990,19 @@ def _parse_text(cls, text): if line.startswith('#'): match = cls.not_set_regex.search(line) if match: - value = 'n' - name = match.group(1) + value: str = 'n' + name: str = match.group(1) else: continue else: name, value = line.split('=', 1) name = cls.get_config_name(name.strip()) - value = cls._parse_val(name, value.strip()) - config[name] = value + parsed_value: Optional[Union[KernelConfigTristate, HexInt, int, str]] = cls._parse_val(name, value.strip()) + config[name] = parsed_value return config - def __getitem__(self, name): + def __getitem__(self, name: str) -> str: name = self.get_config_name(name) try: return self._config[name] @@ -2971,27 +5018,43 @@ def __iter__(self): def __len__(self): return len(self._config) +# FIXME - annotating name as str gives some type errors as Mapping superclass expects object def __contains__(self, name): name = self.get_config_name(name) return name in self._config - def like(self, name): + def like(self, name: str) -> Dict[str, str]: + """ + Return a dictionary of key-value pairs where the keys match the given regular expression pattern. + """ regex = re.compile(name, re.I) return { k: v for k, v in self.items() if regex.search(k) } - def is_enabled(self, name): + def is_enabled(self, name: str) -> bool: + """ + true if the config is enabled in kernel + """ return self.get(name) is KernelConfigTristate.YES - def is_module(self, name): + def is_module(self, name: str) -> bool: + """ + true if the config is of Module type + """ return self.get(name) is KernelConfigTristate.MODULE - def is_not_set(self, name): + def is_not_set(self, name: str) -> bool: + """ + true if the config is not enabled + """ return self.get(name) is KernelConfigTristate.NO - def has(self, name): + def has(self, name: str) -> bool: + """ + true if the config is either enabled or it is a module + """ return self.is_enabled(name) or self.is_module(name) @@ -3002,10 +5065,10 @@ class KernelConfig(object): This class does not provide a Mapping API and only return string values. """ @staticmethod - def get_config_name(name): + def get_config_name(name: str) -> str: return TypedKernelConfig.get_config_name(name) - def __init__(self, text): + def __init__(self, text: str): # Expose typed_config as a non-private attribute, so that user code # needing it can get it from any existing producer of KernelConfig. self.typed_config = TypedKernelConfig.from_str(text) @@ -3017,54 +5080,115 @@ def __bool__(self): not_set_regex = TypedKernelConfig.not_set_regex - def iteritems(self): + def iteritems(self) -> Iterator[Tuple[str, str]]: + """ + Iterate over the items in the typed configuration, converting each value to a string. + """ for k, v in self.typed_config.items(): yield (k, self.typed_config._val_to_str(v)) items = iteritems - def get(self, name, strict=False): + def get(self, name: str, strict: bool = False) -> Optional[str]: + """ + Retrieve a value from the typed configuration and convert it to a string. + """ if strict: - val = self.typed_config[name] + val: Optional[str] = self.typed_config[name] else: val = self.typed_config.get(name) return self.typed_config._val_to_str(val) - def like(self, name): + def like(self, name: str) -> Dict[str, str]: + """ + Return a dictionary of key-value pairs where the keys match the given regular expression pattern. + """ return { k: self.typed_config._val_to_str(v) for k, v in self.typed_config.like(name).items() } - def is_enabled(self, name): + def is_enabled(self, name: str) -> bool: + """ + true if the config is enabled in kernel + """ return self.typed_config.is_enabled(name) - def is_module(self, name): + def is_module(self, name: str) -> bool: + """ + true if the config is of Module type + """ return self.typed_config.is_module(name) - def is_not_set(self, name): + def is_not_set(self, name: str) -> bool: + """ + true if the config is not enabled + """ return self.typed_config.is_not_set(name) - def has(self, name): + def has(self, name: str) -> bool: + """ + true if the config is either enabled or it is a module + """ return self.typed_config.has(name) class LocalLinuxTarget(LinuxTarget): + """ + A specialized :class:`Target` subclass representing the local Linux system + (i.e., no remote connection needed). In many respects, this parallels + :class:`LinuxTarget`, but uses :class:`LocalConnection` under the hood. + + :param connection_settings: Dictionary specifying local connection options + (often unused or minimal). + :type connection_settings: dict, optional + :param platform: A ``Platform`` object if you want to specify architecture, + kernel version, etc. If None, a default is inferred from the host system. + :type platform: Platform, optional + :param working_directory: A writable directory on the local machine for devlibs + temporary operations. If None, a subfolder of /tmp or similar is often used. + :type working_directory: str, optional + :param executables_directory: Directory for installing binaries from devlib, + if needed. + :type executables_directory: str, optional + :param connect: Whether to connect (initialize local environment) immediately. + :type connect: bool + :param modules: Additional devlib modules to load at construction time. + :type modules: dict, optional + :param load_default_modules: If True, also load modules listed in + :attr:`default_modules`. + :type load_default_modules: bool + :param shell_prompt: Regex matching the local shell prompt (usually not used + since local commands are run directly). + :type shell_prompt: re.Pattern + :param conn_cls: Connection class to use, typically :class:`LocalConnection`. + :type conn_cls: Type[LocalConnection] + :param is_container: If True, indicates we’re running in a container environment + rather than the full host OS. + :type is_container: bool + :param max_async: Maximum concurrent asynchronous commands allowed. + :type max_async: int + + """ def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - conn_cls=LocalConnection, - is_container=False, - max_async=50, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + conn_cls: 'InitCheckpointMeta' = LocalConnection, + is_container: bool = False, + max_async: int = 50, ): + """ + Initialize a LocalLinuxTarget, representing the local machine as the devlib + target. Optionally connect and load modules immediately. + """ super(LocalLinuxTarget, self).__init__(connection_settings=connection_settings, platform=platform, working_directory=working_directory, @@ -3077,141 +5201,188 @@ def __init__(self, is_container=is_container, max_async=max_async) - def _resolve_paths(self): + def _resolve_paths(self) -> None: + """ + Resolve or finalize local working directories/executables directories. + By default, uses a subfolder of /tmp if none is set. + """ if self.working_directory is None: - self.working_directory = '/tmp/devlib-target' - self._file_transfer_cache = self.path.join(self.working_directory, '.file-cache') + self.working_directory: str = '/tmp/devlib-target' + self._file_transfer_cache: str = self.path.join(self.working_directory, '.file-cache') if self.executables_directory is None: - self.executables_directory = '/tmp/devlib-target/bin' + self.executables_directory: str = '/tmp/devlib-target/bin' -def _get_model_name(section): - name_string = section['model name'] - parts = name_string.split('@')[0].strip().split() +def _get_model_name(section: Dict[str, str]) -> str: + """ + get model name from section of cpu info + """ + name_string: str = section['model name'] + parts: List[str] = name_string.split('@')[0].strip().split() return ' '.join([p for p in parts if '(' not in p and p != 'CPU']) -def _get_part_name(section): - implementer = section.get('CPU implementer', '0x0') - part = section['CPU part'] - variant = section.get('CPU variant', '0x0') +def _get_part_name(section: Dict[str, str]) -> str: + """ + get part name from cpu info + """ + implementer: str = section.get('CPU implementer', '0x0') + part: str = section['CPU part'] + variant: str = section.get('CPU variant', '0x0') name = get_cpu_name(*list(map(integer, [implementer, part, variant]))) if name is None: name = f'{implementer}/{part}/{variant}' return name -def _build_path_tree(path_map, basepath, sep=os.path.sep, dictcls=dict): +Node = Union[str, Dict[str, 'Node']] + + +def _build_path_tree(path_map: Dict[str, str], basepath: str, + sep: str = os.path.sep, dictcls=dict) -> Union[str, Dict[str, 'Node']]: """ Convert a flat mapping of paths to values into a nested structure of - dict-line object (``dict``'s by default), mirroring the directory hierarchy + dict-like object (``dict``'s by default), mirroring the directory hierarchy represented by the paths relative to ``basepath``. """ - def process_node(node, path, value): + def process_node(node: 'Node', path: str, value: str): parts = path.split(sep, 1) - if len(parts) == 1: # leaf + if len(parts) == 1 and not isinstance(node, str): # leaf node[parts[0]] = value else: # branch - if parts[0] not in node: - node[parts[0]] = dictcls() - process_node(node[parts[0]], parts[1], value) + if not isinstance(node, str): + if parts[0] not in node: + node[parts[0]] = dictcls() + process_node(node[parts[0]], parts[1], value) - relpath_map = {os.path.relpath(p, basepath): v - for p, v in path_map.items()} + relpath_map: Dict[str, str] = {os.path.relpath(p, basepath): v + for p, v in path_map.items()} if len(relpath_map) == 1 and list(relpath_map.keys())[0] == '.': - result = list(relpath_map.values())[0] + result: Union[str, Dict[str, Any]] = list(relpath_map.values())[0] else: result = dictcls() for path, value in relpath_map.items(): - process_node(result, path, value) + if not isinstance(result, str): + process_node(result, path, value) return result class ChromeOsTarget(LinuxTarget): """ - Class for interacting with ChromeOS targets. + :class:`ChromeOsTarget` is a subclass of :class:`LinuxTarget` with + additional features specific to a device running ChromeOS for example, + if supported, its own android container which can be accessed via the + ``android_container`` attribute. When making calls to or accessing + properties and attributes of the ChromeOS target, by default they will + be applied to Linux target as this is where the majority of device + configuration will be performed and if not available, will fall back to + using the android container if available. This means that all the + available methods from + :class:`LinuxTarget` and :class:`AndroidTarget` are available for + :class:`ChromeOsTarget` if the device supports android otherwise only the + :class:`LinuxTarget` methods will be available. + + :param working_directory: This is the location of the working directory to + be used for the Linux target container. If not specified will default to + ``"/mnt/stateful_partition/devlib-target"``. + + :param android_working_directory: This is the location of the working + directory to be used for the android container. If not specified it will + use the working directory default for :class:`AndroidTarget.`. + + :param android_executables_directory: This is the location of the + executables directory to be used for the android container. If not + specified will default to a ``bin`` subdirectory in the + ``android_working_directory.`` + + :param package_data_directory: This is the location of the data stored + for installed Android packages on the device. """ - os = 'chromeos' + os: str = 'chromeos' # pylint: disable=too-many-locals,too-many-arguments def __init__(self, - connection_settings=None, - platform=None, - working_directory=None, - executables_directory=None, - android_working_directory=None, - android_executables_directory=None, - connect=True, - modules=None, - load_default_modules=True, - shell_prompt=DEFAULT_SHELL_PROMPT, - package_data_directory="/data/data", - is_container=False, - max_async=50, + connection_settings: Optional[UserConnectionSettings] = None, + platform: Optional[Platform] = None, + working_directory: Optional[str] = None, + executables_directory: Optional[str] = None, + android_working_directory: Optional[str] = None, + android_executables_directory: Optional[str] = None, + connect: bool = True, + modules: Optional[Dict[str, Dict[str, Type[Module]]]] = None, + load_default_modules: bool = True, + shell_prompt: Pattern[str] = DEFAULT_SHELL_PROMPT, + package_data_directory: str = "/data/data", + is_container: bool = False, + max_async: int = 50, ): + """ + Initialize a ChromeOsTarget for interacting with a device running Chrome OS + in developer mode (exposing SSH). + """ - self.supports_android = None - self.android_container = None + self.supports_android: Optional[bool] = None + self.android_container: Optional[AndroidTarget] = None # Pull out ssh connection settings - ssh_conn_params = ['host', 'username', 'password', 'keyfile', - 'port', 'timeout', 'sudo_cmd', - 'strict_host_check', 'use_scp', - 'total_transfer_timeout', 'poll_transfers', - 'start_transfer_poll_delay'] - self.ssh_connection_settings = {} - self.ssh_connection_settings.update( - (key, value) - for key, value in connection_settings.items() - if key in ssh_conn_params - ) + ssh_conn_params: List[str] = ['host', 'username', 'password', 'keyfile', + 'port', 'timeout', 'sudo_cmd', + 'strict_host_check', 'use_scp', + 'total_transfer_timeout', 'poll_transfers', + 'start_transfer_poll_delay'] + self.ssh_connection_settings: SshUserConnectionSettings = {} + if connection_settings: + update_dict = cast(SshUserConnectionSettings, + {key: value for key, value in connection_settings.items() if key in ssh_conn_params}) + self.ssh_connection_settings.update(update_dict) super().__init__(connection_settings=self.ssh_connection_settings, - platform=platform, - working_directory=working_directory, - executables_directory=executables_directory, - connect=False, - modules=modules, - load_default_modules=load_default_modules, - shell_prompt=shell_prompt, - conn_cls=SshConnection, - is_container=is_container, - max_async=max_async) + platform=platform, + working_directory=working_directory, + executables_directory=executables_directory, + connect=False, + modules=modules, + load_default_modules=load_default_modules, + shell_prompt=shell_prompt, + conn_cls=SshConnection, + is_container=is_container, + max_async=max_async) # We can't determine if the target supports android until connected to the linux host so # create unconditionally. # Pull out adb connection settings adb_conn_params = ['device', 'adb_server', 'adb_port', 'timeout'] - self.android_connection_settings = {} - self.android_connection_settings.update( - (key, value) - for key, value in connection_settings.items() - if key in adb_conn_params - ) - - # If adb device is not explicitly specified use same as ssh host - if not connection_settings.get('device', None): - self.android_connection_settings['device'] = connection_settings.get('host', None) - - self.android_container = AndroidTarget(connection_settings=self.android_connection_settings, - platform=platform, - working_directory=android_working_directory, - executables_directory=android_executables_directory, - connect=False, - load_default_modules=False, - shell_prompt=shell_prompt, - conn_cls=AdbConnection, - package_data_directory=package_data_directory, - is_container=True) - if connect: - self.connect() - - def __getattr__(self, attr): + self.android_connection_settings: AdbUserConnectionSettings = {} + if connection_settings: + update_dict_adb = cast(AdbUserConnectionSettings, + {key: value for key, value in connection_settings.items() if key in adb_conn_params}) + self.android_connection_settings.update(update_dict_adb) + + # If adb device is not explicitly specified use same as ssh host + if not connection_settings.get('device', None): + device = connection_settings.get('host', None) + if device: + self.android_connection_settings['device'] = device + + self.android_container = AndroidTarget(connection_settings=self.android_connection_settings, + platform=platform, + working_directory=android_working_directory, + executables_directory=android_executables_directory, + connect=False, + load_default_modules=False, + shell_prompt=shell_prompt, + conn_cls=AdbConnection, + package_data_directory=package_data_directory, + is_container=True) + if connect: + self.connect() + + def __getattr__(self, attr: str): """ By default use the linux target methods and attributes however, if not present, use android implementation if available. @@ -3224,7 +5395,7 @@ def __getattr__(self, attr): raise @asyn.asyncf - async def connect(self, timeout=30, check_boot_completed=True, max_async=None): + async def connect(self, timeout: int = 30, check_boot_completed: bool = True, max_async: Optional[int] = None) -> None: super().connect( timeout=timeout, check_boot_completed=check_boot_completed, @@ -3235,14 +5406,18 @@ async def connect(self, timeout=30, check_boot_completed=True, max_async=None): if self.supports_android is None: self.supports_android = self.directory_exists('/opt/google/containers/android/') - if self.supports_android: + if self.supports_android and self.android_container: self.android_container.connect(timeout) else: self.android_container = None - def _resolve_paths(self): + def _resolve_paths(self) -> None: + """ + Finalize any path logic specific to Chrome OS. Some directories + may be restricted or read-only, depending on dev mode settings. + """ if self.working_directory is None: - self.working_directory = '/mnt/stateful_partition/devlib-target' - self._file_transfer_cache = self.path.join(self.working_directory, '.file-cache') + self.working_directory: str = '/mnt/stateful_partition/devlib-target' + self._file_transfer_cache: str = self.path.join(self.working_directory, '.file-cache') if self.executables_directory is None: - self.executables_directory = self.path.join(self.working_directory, 'bin') + self.executables_directory: str = self.path.join(self.working_directory, 'bin') diff --git a/devlib/utils/android.py b/devlib/utils/android.py index c77a86446..a54f33eaa 100755 --- a/devlib/utils/android.py +++ b/devlib/utils/android.py @@ -1,4 +1,4 @@ -# Copyright 2013-2018 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ """ Utility functions for working with Android devices through adb. - """ # pylint: disable=E1103 import functools @@ -38,20 +37,38 @@ from lxml import etree from shlex import quote -from devlib.exception import TargetTransientError, TargetStableError, HostError, TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError +from devlib.exception import (TargetTransientError, TargetStableError, HostError, + TargetTransientCalledProcessError, TargetStableCalledProcessError, AdbRootError) from devlib.utils.misc import check_output, which, ABI_MAP, redirect_streams, get_subprocess -from devlib.connection import ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, PopenTransferHandle - +from devlib.connection import (ConnectionBase, AdbBackgroundCommand, PopenBackgroundCommand, + PopenTransferHandle, TransferManager) + +from typing import (Optional, TYPE_CHECKING, cast, Tuple, Union, + List, DefaultDict, Pattern, Dict, Iterator, + Match, Callable, Generator) +from typing_extensions import Required, TypedDict, Literal +if TYPE_CHECKING: + from devlib.utils.annotation_helpers import SubprocessCommand + from threading import Lock + from lxml.etree import _ElementTree, _Element, XMLParser + from devlib.platform import Platform + from subprocess import Popen, CompletedProcess + from devlib.target import AndroidTarget + from io import TextIOWrapper + from tempfile import _TemporaryFileWrapper + from pexpect import spawn + +PartsType = Tuple[Union[str, Tuple[str, ...]], ...] logger = logging.getLogger('android') -MAX_ATTEMPTS = 5 -AM_START_ERROR = re.compile(r"Error: Activity.*") -AAPT_BADGING_OUTPUT = re.compile(r"no dump ((file)|(apk)) specified", re.IGNORECASE) +MAX_ATTEMPTS: int = 5 +AM_START_ERROR: Pattern[str] = re.compile(r"Error: Activity.*") +AAPT_BADGING_OUTPUT: Pattern[str] = re.compile(r"no dump ((file)|(apk)) specified", re.IGNORECASE) # See: # http://developer.android.com/guide/topics/manifest/uses-sdk-element.html#ApiLevels -ANDROID_VERSION_MAP = { +ANDROID_VERSION_MAP: Dict[int, str] = { 29: 'Q', 28: 'PIE', 27: 'OREO_MR1', @@ -84,96 +101,230 @@ } # See https://developer.android.com/reference/android/content/Intent.html#setFlags(int) -INTENT_FLAGS = { - 'ACTIVITY_NEW_TASK' : 0x10000000, - 'ACTIVITY_CLEAR_TASK' : 0x00008000 +INTENT_FLAGS: Dict[str, int] = { + 'ACTIVITY_NEW_TASK': 0x10000000, + 'ACTIVITY_CLEAR_TASK': 0x00008000 } + class AndroidProperties(object): + """ + Represents Android system properties as reported by the ``getprop`` command. + Allows easy retrieval of property values. - def __init__(self, text): - self._properties = {} + :param text: Full string output from ``adb shell getprop`` (or similar). + :type text: str + """ + def __init__(self, text: str): + self._properties: Dict[str, str] = {} self.parse(text) - def parse(self, text): + def parse(self, text: str) -> None: + """ + Parse the output text and update the internal property dictionary. + + :param text: String containing the property lines. + :type text: str + """ self._properties = dict(re.findall(r'\[(.*?)\]:\s+\[(.*?)\]', text)) - def iteritems(self): + def iteritems(self) -> Iterator[Tuple[str, str]]: + """ + Return an iterator of (property_key, property_value) pairs. + + :returns: An iterator of tuples like (key, value). + :rtype: Iterator[Tuple[str, str]] + """ return iter(self._properties.items()) def __iter__(self): + """ + Iterate over the property keys. + """ return iter(self._properties) - def __getattr__(self, name): + def __getattr__(self, name: str): + """ + Return a property value by attribute-style lookup. + Defaults to None if the property is missing. + """ return self._properties.get(name) __getitem__ = __getattr__ class AdbDevice(object): + """ + Represents a single device as seen by ``adb devices`` (usually a USB or IP + device). - def __init__(self, name, status): + :param name: The serial number or identifier of the device. + :type name: str + :param status: The device status, e.g. "device", "offline", or "unauthorized". + :type status: str + """ + def __init__(self, name: str, status: str): self.name = name self.status = status - # pylint: disable=undefined-variable - def __cmp__(self, other): + # replace __cmp__ of python 2 with explicit comparison methods + # of python 3 + def __lt__(self, other: Union['AdbDevice', str]) -> bool: + """ + Compare this device's name with another device or string for ordering. + """ if isinstance(other, AdbDevice): - return cmp(self.name, other.name) - else: - return cmp(self.name, other) + return self.name < other.name + return self.name < other - def __str__(self): + def __eq__(self, other: object) -> bool: + """ + Check if this device's name matches another device's name or a string. + """ + if isinstance(other, AdbDevice): + return self.name == other.name + return self.name == other + + def __le__(self, other: Union['AdbDevice', str]) -> bool: + """ + Test if this device's name is <= another device/string. + """ + return self < other or self == other + + def __gt__(self, other: Union['AdbDevice', str]) -> bool: + """ + Test if this device's name is > another device/string. + """ + return not self <= other + + def __ge__(self, other: Union['AdbDevice', str]) -> bool: + """ + Test if this device's name is >= another device/string. + """ + return not self < other + + def __ne__(self, other: object) -> bool: + """ + Invert the __eq__ comparison. + """ + return not self == other + + def __str__(self) -> str: + """ + Return a string representation of this device for debugging. + """ return 'AdbDevice({}, {})'.format(self.name, self.status) __repr__ = __str__ +class BuildToolsInfo(TypedDict, total=False): + """ + Typed dictionary capturing build tools info. + + :param build_tools: The path to the build-tools directory. + :type build_tools: Optional[str] + :param aapt: Path to the aapt or aapt2 binary. + :type aapt: Optional[str] + :param aapt_version: Integer 1 or 2 indicating which aapt is used. + :type aapt_version: Optional[int] + """ + build_tools: Required[Optional[str]] + aapt: Required[Optional[str]] + aapt_version: Required[Optional[int]] + + +class Android_Env_Type(TypedDict, total=False): + """ + Typed dictionary representing environment paths for Android tools. + + :param android_home: ANDROID_HOME path, if set. + :param platform_tools: Path to the 'platform-tools' directory containing adb/fastboot. + :param adb: Path to the 'adb' executable. + :param fastboot: Path to the 'fastboot' executable. + :param build_tools: Path to the 'build-tools' directory if available. + :param aapt: Path to aapt or aapt2, if found. + :param aapt_version: 1 or 2 indicating which aapt variant is used. + """ + android_home: Required[Optional[str]] + platform_tools: Required[str] + adb: Required[str] + fastboot: Required[str] + build_tools: Required[Optional[str]] + aapt: Required[Optional[str]] + aapt_version: Required[Optional[int]] + + +Android_Env_TypeKeys = Union[Literal['android_home'], + Literal['platform_tools'], + Literal['adb'], + Literal['fastboot'], + Literal['build_tools'], + Literal['aapt'], + Literal['aapt_version']] + + class ApkInfo(object): + """ + Extracts and stores metadata about an APK, including package name, version, + supported ABIs, permissions, etc. The parsing relies on the 'aapt' or 'aapt2' + command from Android build-tools. - version_regex = re.compile(r"name='(?P[^']+)' versionCode='(?P[^']+)' versionName='(?P[^']+)'") - name_regex = re.compile(r"name='(?P[^']+)'") - permission_regex = re.compile(r"name='(?P[^']+)'") - activity_regex = re.compile(r'\s*A:\s*android:name\(0x\d+\)=".(?P\w+)"') + :param path: Optional path to the APK file on the host. If provided, it is + immediately parsed. + :type path: str or None + """ + version_regex: Pattern[str] = re.compile(r"name='(?P[^']+)' versionCode='(?P[^']+)' versionName='(?P[^']+)'") + name_regex: Pattern[str] = re.compile(r"name='(?P[^']+)'") + permission_regex: Pattern[str] = re.compile(r"name='(?P[^']+)'") + activity_regex: Pattern[str] = re.compile(r'\s*A:\s*android:name\(0x\d+\)=".(?P\w+)"') - def __init__(self, path=None): + def __init__(self, path: Optional[str] = None): self.path = path - self.package = None - self.activity = None - self.label = None - self.version_name = None - self.version_code = None - self.native_code = None - self.permissions = [] - self._apk_path = None - self._activities = None - self._methods = None - self._aapt = _ANDROID_ENV.get_env('aapt') - self._aapt_version = _ANDROID_ENV.get_env('aapt_version') + self.package: Optional[str] = None + self.activity: Optional[str] = None + self.label: Optional[str] = None + self.version_name: Optional[str] = None + self.version_code: Optional[str] = None + self.native_code: Optional[List[str]] = None + self.permissions: List[str] = [] + self._apk_path: Optional[str] = None + self._activities: Optional[List[str]] = None + self._methods: Optional[List[Tuple[str, str]]] = None + self._aapt: str = cast(str, _ANDROID_ENV.get_env('aapt')) + self._aapt_version: int = cast(int, _ANDROID_ENV.get_env('aapt_version')) if path: self.parse(path) # pylint: disable=too-many-branches - def parse(self, apk_path): - output = self._run([self._aapt, 'dump', 'badging', apk_path]) + def parse(self, apk_path: str) -> None: + """ + Parse the given APK file with the aapt or aapt2 utility, retrieving + metadata such as package name, version, and permissions. + + :param apk_path: The path to the APK file on the host system. + :type apk_path: str + :raises HostError: If aapt fails to run or returns an error message. + """ + output: str = self._run([self._aapt, 'dump', 'badging', apk_path]) for line in output.split('\n'): if line.startswith('application-label:'): self.label = line.split(':')[1].strip().replace('\'', '') elif line.startswith('package:'): - match = self.version_regex.search(line) + match: Optional[Match[str]] = self.version_regex.search(line) if match: self.package = match.group('name') self.version_code = match.group('vcode') self.version_name = match.group('vname') elif line.startswith('launchable-activity:'): match = self.name_regex.search(line) - self.activity = match.group('name') + self.activity = match.group('name') if match else None elif line.startswith('native-code'): - apk_abis = [entry.strip() for entry in line.split(':')[1].split("'") if entry.strip()] - mapped_abis = [] + apk_abis: List[str] = [entry.strip() for entry in line.split(':')[1].split("'") if entry.strip()] + mapped_abis: List[str] = [] for apk_abi in apk_abis: - found = False + found: bool = False for abi, architectures in ABI_MAP.items(): if apk_abi in architectures: mapped_abis.append(abi) @@ -194,37 +345,51 @@ def parse(self, apk_path): self._methods = None @property - def activities(self): + def activities(self) -> List[str]: + """ + Return a list of activity names declared in this APK. + + :returns: A list of activity names found in AndroidManifest.xml. + :rtype: list of str + """ if self._activities is None: - cmd = [self._aapt, 'dump', 'xmltree', self._apk_path] + cmd: List[str] = [self._aapt, 'dump', 'xmltree', self._apk_path if self._apk_path else ''] if self._aapt_version == 2: cmd += ['--file'] cmd += ['AndroidManifest.xml'] - matched_activities = self.activity_regex.finditer(self._run(cmd)) + matched_activities: Iterator[Match[str]] = self.activity_regex.finditer(self._run(cmd)) self._activities = [m.group('name') for m in matched_activities] return self._activities @property - def methods(self): + def methods(self) -> Optional[List[Tuple[str, str]]]: + """ + Return a list of (method_name, class_name) pairs, if any can be extracted + by dexdump. If no classes.dex is found or an error occurs, returns an empty list. + + :returns: A list of (method_name, class_name) tuples, or None if not parsed yet. + :rtype: list of (str, str) or None + """ if self._methods is None: # Only try to extract once self._methods = [] with tempfile.TemporaryDirectory() as tmp_dir: - with zipfile.ZipFile(self._apk_path, 'r') as z: - try: - extracted = z.extract('classes.dex', tmp_dir) - except KeyError: - return [] - dexdump = os.path.join(os.path.dirname(self._aapt), 'dexdump') - command = [dexdump, '-l', 'xml', extracted] - dump = self._run(command) + if self._apk_path: + with zipfile.ZipFile(self._apk_path, 'r') as z: + try: + extracted: str = z.extract('classes.dex', tmp_dir) + except KeyError: + return [] + dexdump: str = os.path.join(os.path.dirname(self._aapt), 'dexdump') + command: List[str] = [dexdump, '-l', 'xml', extracted] + dump: str = self._run(command) # Dexdump from build tools v30.0.X does not seem to produce # valid xml from certain APKs so ignore errors and attempt to recover. - parser = etree.XMLParser(encoding='utf-8', recover=True) - xml_tree = etree.parse(StringIO(dump), parser) + parser: XMLParser = etree.XMLParser(encoding='utf-8', recover=True) + xml_tree: _ElementTree = etree.parse(StringIO(dump), parser) - package = [] + package: List[_Element] = [] for i in xml_tree.iter('package'): if i.attrib['name'] == self.package: package.append(i) @@ -235,11 +400,20 @@ def methods(self): for meth in klass.iter('method')]) return self._methods - def _run(self, command): + def _run(self, command: List[str]) -> str: + """ + Execute a local shell command (e.g., aapt) and return its output as a string. + + :param command: List of command arguments to run. + :type command: list of str + :returns: Combined stdout+stderr as a decoded string. + :rtype: str + :raises HostError: If the command fails or returns a nonzero exit code. + """ logger.debug(' '.join(command)) try: - output = subprocess.check_output(command, stderr=subprocess.STDOUT) - output = output.decode(sys.stdout.encoding or 'utf-8', 'replace') + output_tmp: bytes = subprocess.check_output(command, stderr=subprocess.STDOUT) + output: str = output_tmp.decode(sys.stdout.encoding or 'utf-8', 'replace') except subprocess.CalledProcessError as e: raise HostError('Error while running "{}":\n{}' .format(command, e.output)) @@ -247,46 +421,110 @@ def _run(self, command): class AdbConnection(ConnectionBase): - + """ + A connection to an android device via ``adb`` (Android Debug Bridge). + ``adb`` is part of the Android SDK (though stand-alone versions are also + available). + + :param device: The name of the adb device. This is usually a unique hex + string for USB-connected devices, or an ip address/port + combination. To see connected devices, you can run ``adb + devices`` on the host. + :type device: str or None + :param timeout: Connection timeout in seconds. If a connection to the device + is not established within this period, :class:`HostError` + is raised. + :type timeout: int or None + :param platform: An optional Platform object describing hardware aspects. + :type platform: Platform or None + :param adb_server: Allows specifying the address of the adb server to use. + :type adb_server: str or None + :param adb_port: If specified, connect to a custom adb server port. + :type adb_port: int or None + :param adb_as_root: Specify whether the adb server should be restarted in root mode. + :type adb_as_root: bool + :param connection_attempts: Specify how many connection attempts, 10 seconds + apart, should be attempted to connect to the device. + Defaults to 5. + :type connection_attempts: int + :param poll_transfers: Specify whether file transfers should be polled. Polling + monitors the progress of file transfers and periodically + checks whether they have stalled, attempting to cancel + the transfers prematurely if so. + :type poll_transfers: bool + :param start_transfer_poll_delay: If transfers are polled, specify the length of + time after a transfer has started before polling + should start. + :type start_transfer_poll_delay: int + :param total_transfer_timeout: If transfers are polled, specify the total amount of time + to elapse before the transfer is cancelled, regardless + of its activity. + :type total_transfer_timeout: int + :param transfer_poll_period: If transfers are polled, specify the period at which + the transfers are sampled for activity. Too small values + may cause the destination size to appear the same over + one or more sample periods, causing improper transfer + cancellation. + :type transfer_poll_period: int + + :raises AdbRootError: If root mode is requested but multiple connections are active or device does not allow it. + :raises HostError: If the device fails to connect or is invalid. + """ # maintains the count of parallel active connections to a device, so that # adb disconnect is not invoked untill all connections are closed - active_connections = (threading.Lock(), defaultdict(int)) + active_connections: Tuple['Lock', DefaultDict[str, int]] = (threading.Lock(), defaultdict(int)) # Track connected as root status per device - _connected_as_root = defaultdict(lambda: None) - default_timeout = 10 - ls_command = 'ls' - su_cmd = 'su -c {}' + _connected_as_root: DefaultDict[str, Optional[bool]] = defaultdict(lambda: None) + default_timeout: int = 10 + ls_command: str = 'ls' + su_cmd: str = 'su -c {}' @property - def name(self): + def name(self) -> str: + """ + :returns: The device serial number or IP:port used by this connection. + :rtype: str + """ return self.device @property - def connected_as_root(self): + def connected_as_root(self) -> Optional[bool]: + """ + Check if the current connection is effectively root on the device. + + :returns: True if root, False if not, or None if undetermined. + :rtype: bool or None + """ if self._connected_as_root[self.device] is None: result = self.execute('id') self._connected_as_root[self.device] = 'uid=0(' in result return self._connected_as_root[self.device] @connected_as_root.setter - def connected_as_root(self, state): + def connected_as_root(self, state: Optional[bool]) -> None: + """ + Manually set the known state of root usage on this device connection. + + :param state: True if connected as root, False if not, None to reset. + :type state: bool or None + """ self._connected_as_root[self.device] = state # pylint: disable=unused-argument def __init__( self, - device=None, - timeout=None, - platform=None, - adb_server=None, - adb_port=None, - adb_as_root=False, - connection_attempts=MAX_ATTEMPTS, - - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30, + device: Optional[str] = None, + timeout: Optional[int] = None, + platform: Optional['Platform'] = None, + adb_server: Optional[str] = None, + adb_port: Optional[int] = None, + adb_as_root: bool = False, + connection_attempts: int = MAX_ATTEMPTS, + + poll_transfers: bool = False, + start_transfer_poll_delay: int = 30, + total_transfer_timeout: int = 3600, + transfer_poll_period: int = 30, ): super().__init__( poll_transfers=poll_transfers, @@ -323,25 +561,52 @@ def __init__( self._setup_ls() self._setup_su() - def push(self, sources, dest, timeout=None): + def push(self, sources: Tuple[str, ...], dest: str, + timeout: Optional[int] = None) -> None: + """ + Upload (push) one or more files/directories from the host to the device. + + :param sources: Paths on the host system to be pushed. + :type sources: tuple(str, ...) + :param dest: Target path on the device. If multiple sources, dest should be a dir. + :type dest: str + :param timeout: Max time in seconds for each file push. If exceeded, an error is raised. + :type timeout: int, optional + """ return self._push_pull('push', sources, dest, timeout) - def pull(self, sources, dest, timeout=None): + def pull(self, sources: Tuple[str, ...], dest: str, + timeout: Optional[int] = None) -> None: + """ + Download (pull) one or more files/directories from the device to the host. + + :param sources: Paths on the device to be pulled. + :type sources: tuple(str, ...) + :param dest: Destination path on the host. + :type dest: str + :param timeout: Max time in seconds for each file. If exceeded, an error is raised. + :type timeout: int, optional + """ return self._push_pull('pull', sources, dest, timeout) - def _push_pull(self, action, sources, dest, timeout): - sources = list(sources) - paths = sources + [dest] + def _push_pull(self, action: Union[Literal['push'], Literal['pull']], + sources: Tuple[str, ...], dest: str, timeout: Optional[int]) -> None: + """ + Internal helper that runs 'adb push' or 'adb pull' with optional timeouts + and transfer polling. + """ + sourcesList: List[str] = list(sources) + pathsList: List[str] = sourcesList + [dest] # Quote twice to avoid expansion by host shell, then ADB globbing - do_quote = lambda x: quote(glob.escape(x)) - paths = ' '.join(map(do_quote, paths)) + do_quote: Callable[[str], str] = lambda x: quote(glob.escape(x)) + paths: str = ' '.join(map(do_quote, pathsList)) command = "{} {}".format(action, paths) if timeout: adb_command(self.device, command, timeout=timeout, adb_server=self.adb_server, adb_port=self.adb_port) else: - bg_cmd = adb_command_background( + bg_cmd: PopenBackgroundCommand = adb_command_background( device=self.device, conn=self, command=command, @@ -355,12 +620,35 @@ def _push_pull(self, action, sources, dest, timeout): dest=dest, direction=action ) - with bg_cmd, self.transfer_manager.manage(sources, dest, action, handle): - bg_cmd.communicate() + if isinstance(self.transfer_manager, TransferManager): + with bg_cmd, self.transfer_manager.manage(sources, dest, action, handle): + bg_cmd.communicate() # pylint: disable=unused-argument - def execute(self, command, timeout=None, check_exit_code=False, - as_root=False, strip_colors=True, will_succeed=False): + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + check_exit_code: bool = False, as_root: Optional[bool] = False, + strip_colors: bool = True, will_succeed: bool = False) -> str: + """ + Execute a command on the device via ``adb shell``. + + :param command: The command line to run (string or SubprocessCommand). + :type command: SubprocessCommand + :param timeout: Time in seconds before forcibly terminating the command. None for no limit. + :type timeout: int or None + :param check_exit_code: If True, raise an error if the command's exit code != 0. + :type check_exit_code: bool + :param as_root: If True, attempt to run it as root if available. + :type as_root: bool or None + :param strip_colors: If True, strip any ANSI colors (unused in this method). + :type strip_colors: bool + :param will_succeed: If True, treat an error as transient rather than stable. + :type will_succeed: bool + :returns: The command's output (combined stdout+stderr). + :rtype: str + :raises TargetTransientCalledProcessError: If the command fails but is flagged as transient. + :raises TargetStableCalledProcessError: If the command fails in a stable (non-transient) way. + :raises TargetStableError: If there's a stable device/command error. + """ if as_root and self.connected_as_root: as_root = False try: @@ -380,13 +668,41 @@ def execute(self, command, timeout=None, check_exit_code=False, else: raise - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + def background(self, command: 'SubprocessCommand', stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False) -> AdbBackgroundCommand: + """ + Launch a background command via adb shell and return a handle to manage it. + + :param command: The command to run on the device. + :type command: SubprocessCommand + :param stdout: File descriptor or special value (e.g., subprocess.PIPE) for stdout. + :type stdout: int + :param stderr: File descriptor or special value for stderr. + :type stderr: int + :param as_root: If True, attempt to run the command as root. + :type as_root: bool or None + :returns: A handle to the background command. + :rtype: AdbBackgroundCommand + + .. note:: This **will block the connection** until the command completes. + """ if as_root and self.connected_as_root: as_root = False - bg_cmd = self._background(command, stdout, stderr, as_root) + bg_cmd: AdbBackgroundCommand = self._background(command, stdout, stderr, as_root) return bg_cmd - def _background(self, command, stdout, stderr, as_root): + def _background(self, command: 'SubprocessCommand', stdout: int, + stderr: int, as_root: Optional[bool]) -> AdbBackgroundCommand: + """ + Helper method to run a background shell command via adb. + + :param command: Shell command to run. + :param stdout: Location for stdout writes. + :param stderr: Location for stderr writes. + :param as_root: If True, run as root if possible. + :returns: An AdbBackgroundCommand object. + :raises Exception: If PID detection fails or no valid device is set. + """ adb_shell, pid = adb_background_shell(self, command, stdout, stderr, as_root) bg_cmd = AdbBackgroundCommand( conn=self, @@ -396,7 +712,12 @@ def _background(self, command, stdout, stderr, as_root): ) return bg_cmd - def _close(self): + def _close(self) -> None: + """ + Close the connection to the device. The :class:`Connection` object should not + be used after this method is called. There is no way to reopen a previously + closed connection, a new connection object should be created instead. + """ lock, nr_active = AdbConnection.active_connections with lock: nr_active[self.device] -= 1 @@ -409,13 +730,25 @@ def _close(self): self.adb_root(enable=self._restore_to_adb_root) adb_disconnect(self.device, self.adb_server, self.adb_port) - def cancel_running_command(self): + def cancel_running_command(self) -> None: + """ + Cancel a running command (previously started with :func:`background`) and free up the connection. + It is valid to call this if the command has already terminated (or if no + command was issued), in which case this is a no-op. + """ # adbd multiplexes commands so that they don't interfer with each # other, so there is no need to explicitly cancel a running command # before the next one can be issued. pass def adb_root(self, enable=True): + """ + Enable or disable root mode for this device connection. + + :param enable: True to enable root, False to unroot. + :type enable: bool + :raises AdbRootError: If multiple connections are active or device disallows root. + """ self._adb_root(enable=enable) def _adb_root(self, enable): @@ -445,33 +778,50 @@ def is_rooted(out): AdbConnection._connected_as_root[self.device] = enable return was_rooted - def wait_for_device(self, timeout=30): + def wait_for_device(self, timeout: Optional[int] = 30) -> None: + """ + Block until the device is available for commands, up to a specified timeout. + + :param timeout: Time in seconds before giving up. + :type timeout: int or None + """ adb_command(self.device, 'wait-for-device', timeout, self.adb_server, self.adb_port) - def reboot_bootloader(self, timeout=30): + def reboot_bootloader(self, timeout: int = 30) -> None: + """ + Reboot the device into its bootloader (fastboot) mode. + + :param timeout: Seconds to wait for the reboot command to be accepted. + :type timeout: int + """ adb_command(self.device, 'reboot-bootloader', timeout, self.adb_server, self.adb_port) # Again, we need to handle boards where the default output format from ls is # single column *and* boards where the default output is multi-column. # We need to do this purely because the '-1' option causes errors on older # versions of the ls tool in Android pre-v7. - def _setup_ls(self): + def _setup_ls(self) -> None: + """ + Detect whether 'ls -1' is supported, falling back to plain 'ls' on older devices. + """ command = "shell '(ls -1); echo \"\n$?\"'" try: output = adb_command(self.device, command, timeout=self.timeout, adb_server=self.adb_server, adb_port=self.adb_port) except subprocess.CalledProcessError as e: raise HostError( - 'Failed to set up ls command on Android device. Output:\n' - + e.output) - lines = output.splitlines() - retval = lines[-1].strip() + 'Failed to set up ls command on Android device. Output:\n' + e.output) + lines: List[str] = output.splitlines() + retval: str = lines[-1].strip() if int(retval) == 0: self.ls_command = 'ls -1' else: self.ls_command = 'ls' logger.debug("ls command is set to {}".format(self.ls_command)) - def _setup_su(self): + def _setup_su(self) -> None: + """ + Attempt to confirm if 'su -c' is required or a simpler 'su' approach works. + """ # Already root, nothing to do if self.connected_as_root: return @@ -486,26 +836,59 @@ def _setup_su(self): logger.debug("su command is set to {}".format(quote(self.su_cmd))) -def fastboot_command(command, timeout=None, device=None): - target = '-s {}'.format(quote(device)) if device else '' - bin_ = _ANDROID_ENV.get_env('fastboot') - full_command = f'{bin} {target} {command}' +def fastboot_command(command: str, timeout: Optional[int] = None, + device: Optional[str] = None) -> str: + """ + Execute a fastboot command, optionally targeted at a specific device. + + :param command: The fastboot subcommand (e.g. 'devices', 'flash'). + :type command: str + :param timeout: Time in seconds before the command fails. + :type timeout: int or None + :param device: Fastboot device name. If None, assumes a single device or environment default. + :type device: str or None + :returns: Combined stdout+stderr output from the fastboot command. + :rtype: str + :raises HostError: If the command fails or returns an error. + """ + target: str = '-s {}'.format(quote(device)) if device else '' + bin_: str = cast(str, _ANDROID_ENV.get_env('fastboot')) + full_command: str = f'{bin_} {target} {command}' logger.debug(full_command) output, _ = check_output(full_command, timeout, shell=True) return output -def fastboot_flash_partition(partition, path_to_image): - command = 'flash {} {}'.format(quote(partition), quote(path_to_image)) +def fastboot_flash_partition(partition: str, path_to_image: str) -> None: + """ + Execute 'fastboot flash ' to flash a file + onto a specific partition of the device. + + :param partition: The device partition to flash (e.g. "boot", "system"). + :type partition: str + :param path_to_image: Full path to the image file on the host. + :type path_to_image: str + :raises HostError: If fastboot fails or device is not in fastboot mode. + """ + command: str = 'flash {} {}'.format(quote(partition), quote(path_to_image)) fastboot_command(command) -def adb_get_device(timeout=None, adb_server=None, adb_port=None): +def adb_get_device(timeout: Optional[int] = None, adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> str: """ - Returns the serial number of a connected android device. - - If there are more than one device connected to the machine, or it could not - find any device connected, :class:`devlib.exceptions.HostError` is raised. + Attempt to auto-detect a single connected device. If multiple or none are found, + raise an error. + + :param timeout: Maximum time to wait for device detection, or None for no limit. + :type timeout: int or None + :param adb_server: Optional custom server host. + :type adb_server: str or None + :param adb_port: Optional custom server port. + :type adb_port: int or None + :returns: The device serial number or IP:port if exactly one device is found. + :rtype: str + :raises HostError: If zero or more than one devices are connected. """ # TODO this is a hacky way to issue a adb command to all listed devices @@ -517,67 +900,110 @@ def adb_get_device(timeout=None, adb_server=None, adb_port=None): # a list of the devices sperated by new line # The last line is a blank new line. in otherwords, if there is a device found # then the output length is 2 + (1 for each device) - start = time.time() + start: float = time.time() while True: - output = adb_command(None, "devices", adb_server=adb_server, adb_port=adb_port).splitlines() # pylint: disable=E1103 - output_length = len(output) + output: List[str] = adb_command(None, "devices", adb_server=adb_server, adb_port=adb_port).splitlines() # pylint: disable=E1103 + output_length: int = len(output) if output_length == 3: # output[1] is the 2nd line in the output which has the device name # Splitting the line by '\t' gives a list of two indexes, which has # device serial in 0 number and device type in 1. return output[1].split('\t')[0] elif output_length > 3: - message = '{} Android devices found; either explicitly specify ' +\ - 'the device you want, or make sure only one is connected.' + message: str = '{} Android devices found; either explicitly specify ' +\ + 'the device you want, or make sure only one is connected.' raise HostError(message.format(output_length - 2)) else: - if timeout < time.time() - start: + if timeout is not None and timeout < time.time() - start: raise HostError('No device is connected and available') time.sleep(1) -def adb_connect(device, timeout=None, attempts=MAX_ATTEMPTS, adb_server=None, adb_port=None): - tries = 0 - output = None +def adb_connect(device: Optional[str], timeout: Optional[int] = None, + attempts: int = MAX_ATTEMPTS, adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> None: + """ + Connect to an ADB-over-IP device or ensure a USB device is listed. Re-tries + until success or attempts are exhausted. + + :param device: The device name, if "." in it, assumes IP-based device. + :type device: str or None + :param timeout: Time in seconds for each attempt before giving up. + :type timeout: int or None + :param attempts: Number of times to retry connecting 10 seconds apart. + :type attempts: int + :param adb_server: Optional ADB server host. + :type adb_server: str or None + :param adb_port: Optional ADB server port. + :type adb_port: int or None + :raises HostError: If connection fails after all attempts. + """ + tries: int = 0 + output: Optional[str] = None while tries <= attempts: tries += 1 if device: - if "." in device: # Connect is required only for ADB-over-IP + if "." in device: # Connect is required only for ADB-over-IP # ADB does not automatically remove a network device from it's # devices list when the connection is broken by the remote, so the # adb connection may have gone "stale", resulting in adb blocking # indefinitely when making calls to the device. To avoid this, # always disconnect first. adb_disconnect(device, adb_server, adb_port) - adb_cmd = get_adb_command(None, 'connect', adb_server, adb_port) - command = '{} {}'.format(adb_cmd, quote(device)) + adb_cmd: str = get_adb_command(None, 'connect', adb_server, adb_port) + command: str = '{} {}'.format(adb_cmd, quote(device)) logger.debug(command) output, _ = check_output(command, shell=True, timeout=timeout) if _ping(device, adb_server, adb_port): break time.sleep(10) else: # did not connect to the device - message = f'Could not connect to {device or "a device"} at {adb_server}:{adb_port}' + message: str = f'Could not connect to {device or "a device"} at {adb_server}:{adb_port}' if output: message += f'; got: {output}' raise HostError(message) -def adb_disconnect(device, adb_server=None, adb_port=None): +def adb_disconnect(device: Optional[str], adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> None: + """ + Issue an 'adb disconnect' for the specified device, if relevant. + + :param device: Device serial or IP:port. If None or no IP in the name, no action is taken. + :type device: str or None + :param adb_server: Custom ADB server host if used. + :type adb_server: str or None + :param adb_port: Custom ADB server port if used. + :type adb_port: int or None + """ if not device: return if ":" in device and device in adb_list_devices(adb_server, adb_port): - adb_cmd = get_adb_command(None, 'disconnect', adb_server, adb_port) - command = "{} {}".format(adb_cmd, device) + adb_cmd: str = get_adb_command(None, 'disconnect', adb_server, adb_port) + command: str = "{} {}".format(adb_cmd, device) logger.debug(command) - retval = subprocess.call(command, stdout=subprocess.DEVNULL, shell=True) + retval: int = subprocess.call(command, stdout=subprocess.DEVNULL, shell=True) if retval: raise TargetTransientError('"{}" returned {}'.format(command, retval)) -def _ping(device, adb_server=None, adb_port=None): - adb_cmd = get_adb_command(device, 'shell', adb_server, adb_port) - command = "{} {}".format(adb_cmd, quote('ls /data/local/tmp > /dev/null')) +def _ping(device: Optional[str], adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> bool: + """ + Ping the specified device by issuing a trivial command (ls /data/local/tmp). + If it fails, the device is presumably unreachable or offline. + + :param device: The device name or IP:port. + :type device: str or None + :param adb_server: ADB server host, if any. + :type adb_server: str or None + :param adb_port: ADB server port, if any. + :type adb_port: int or None + :returns: True if the device responded, otherwise False. + :rtype: bool + """ + adb_cmd: str = get_adb_command(device, 'shell', adb_server, adb_port) + command: str = "{} {}".format(adb_cmd, quote('ls /data/local/tmp > /dev/null')) logger.debug(command) try: subprocess.check_output(command, stderr=subprocess.STDOUT, shell=True) @@ -589,23 +1015,48 @@ def _ping(device, adb_server=None, adb_port=None): # pylint: disable=too-many-locals -def adb_shell(device, command, timeout=None, check_exit_code=False, - as_root=False, adb_server=None, adb_port=None, su_cmd='su -c {}'): # NOQA - +def adb_shell(device: str, command: 'SubprocessCommand', timeout: Optional[int] = None, + check_exit_code: bool = False, as_root: Optional[bool] = False, adb_server: Optional[str] = None, + adb_port:Optional[int]=None, su_cmd:str='su -c {}') -> str: # NOQA + """ + Run a command in 'adb shell' mode, capturing both stdout/stderr. Uses a technique + to capture the actual command's exit code so that we can detect non-zero exit + reliably on older ADB combos. + + :param device: The device serial or IP:port. + :type device: str + :param command: The command line to run inside 'adb shell'. + :type command: SubprocessCommand + :param timeout: Time in seconds to wait for the command, or None for no limit. + :type timeout: int or None + :param check_exit_code: If True, raise an error if the command exit code is nonzero. + :type check_exit_code: bool + :param as_root: If True, prepend an su command to run as root if supported. + :type as_root: bool or None + :param adb_server: Optional custom adb server IP/name. + :type adb_server: str or None + :param adb_port: Optional custom adb server port. + :type adb_port: int or None + :param su_cmd: Command template to wrap as root, e.g. 'su -c {}'. + :type su_cmd: str + :returns: The combined stdout from the command (minus the exit code). + :rtype: str + :raises TargetStableError: If there's an error with the command or exit code extraction fails. + """ # On older combinations of ADB/Android versions, the adb host command always # exits with 0 if it was able to run the command on the target, even if the # command failed (https://code.google.com/p/android/issues/detail?id=3254). # Homogenise this behaviour by running the command then echoing the exit # code of the executed command itself. - command = r'({}); echo "\n$?"'.format(command) + command = r'({}); echo "\n$?"'.format(cast(str, command)) command = su_cmd.format(quote(command)) if as_root else command command = ('shell', command) parts, env = _get_adb_parts(command, device, adb_server, adb_port, quote_adb=False) env = {**os.environ, **env} - logger.debug(' '.join(quote(part) for part in parts)) + logger.debug(' '.join(quote(cast(str, part)) for part in parts)) try: - raw_output, error = check_output(parts, timeout, shell=False, env=env) + raw_output, error = check_output(cast('SubprocessCommand', parts), timeout, shell=False, env=env) except subprocess.CalledProcessError as e: raise TargetStableError(str(e)) @@ -623,10 +1074,10 @@ def adb_shell(device, command, timeout=None, check_exit_code=False, exit_code = exit_code.strip() re_search = AM_START_ERROR.findall(output) if exit_code.isdigit(): - exit_code = int(exit_code) - if exit_code: + exit_code_i = int(exit_code) + if exit_code_i: raise subprocess.CalledProcessError( - exit_code, + exit_code_i, command, output, error, @@ -648,11 +1099,29 @@ def adb_shell(device, command, timeout=None, check_exit_code=False, return '\n'.join(x for x in (output, error) if x) -def adb_background_shell(conn, command, +def adb_background_shell(conn: AdbConnection, command: 'SubprocessCommand', stdout=subprocess.PIPE, stderr=subprocess.PIPE, - as_root=False): - """Runs the specified command in a subprocess, returning the the Popen object.""" + as_root: Optional[bool] = False) -> Tuple['Popen', int]: + """ + Run a command in the background on the device via ADB shell, returning a Popen + object and an integer PID. This approach uses SIGSTOP to freeze the shell + while the PID is identified. + + :param conn: The AdbConnection managing the device. + :type conn: AdbConnection + :param command: A shell command to run in the background. + :type command: SubprocessCommand + :param stdout: File descriptor for stdout, default is pipe. + :type stdout: int + :param stderr: File descriptor for stderr, default is pipe. + :type stderr: int + :param as_root: If True, attempt to run under su if root is available. + :type as_root: bool or None + :returns: A tuple of (popen_obj, pid). + :rtype: (subprocess.Popen, int) + :raises TargetTransientError: If the PID cannot be identified after retries. + """ device = conn.device adb_server = conn.adb_server adb_port = conn.adb_port @@ -661,12 +1130,12 @@ def adb_background_shell(conn, command, stdout, stderr, command = redirect_streams(stdout, stderr, command) if as_root: - command = f'{busybox} printf "%s" {quote(command)} | su' + command = f'{busybox} printf "%s" {quote(cast(str,command))} | su' - def with_uuid(cmd): + def with_uuid(cmd: str) -> Tuple[str, str]: # Attach a unique UUID to the command line so it can be looked for # without any ambiguity with ps - uuid_ = uuid.uuid4().hex + uuid_: str = uuid.uuid4().hex # Unset the var, since not all connection types set it. This will avoid # anyone depending on that value. cmd = f'DEVLIB_CMD_UUID={uuid_}; unset DEVLIB_CMD_UUID; {cmd}' @@ -676,16 +1145,16 @@ def with_uuid(cmd): return (uuid_, cmd) # Freeze the command with SIGSTOP to avoid racing with PID detection. - command = f"{busybox} kill -STOP $$ && exec {busybox} sh -c {quote(command)}" + command = f"{busybox} kill -STOP $$ && exec {busybox} sh -c {quote(cast(str,command))}" command_uuid, command = with_uuid(command) - adb_cmd = get_adb_command(device, 'shell', adb_server, adb_port) - full_command = f'{adb_cmd} {quote(command)}' + adb_cmd: str = get_adb_command(device, 'shell', adb_server, adb_port) + full_command: str = f'{adb_cmd} {quote(cast(str,command))}' logger.debug(full_command) - p = subprocess.Popen(full_command, stdout=stdout, stderr=stderr, stdin=subprocess.PIPE, shell=True) + p: 'Popen' = subprocess.Popen(full_command, stdout=stdout, stderr=stderr, stdin=subprocess.PIPE, shell=True) # Out of band PID lookup, to avoid conflicting needs with stdout redirection - grep_cmd = f'{busybox} grep {quote(command_uuid)}' + grep_cmd: str = f'{busybox} grep {quote(command_uuid)}' # Find the PID and release the blocked background command with SIGCONT. # We get multiple PIDs: # * One from the grep command itself, but we remove it with another grep command. @@ -694,15 +1163,15 @@ def with_uuid(cmd): # For each of the parent layer, we issue SIGCONT as it is harmless and # avoids having to rely on PID ordering (which could be misleading if PIDs # got recycled). - find_pid = f'''pids=$({busybox} ps -A -o pid,args | {grep_cmd} | {busybox} grep -v {quote(grep_cmd)} | {busybox} awk '{{print $1}}') && {busybox} printf "%s" "$pids" && {busybox} kill -CONT $pids''' + find_pid: str = f'''pids=$({busybox} ps -A -o pid,args | {grep_cmd} | {busybox} grep -v {quote(grep_cmd)} | {busybox} awk '{{print $1}}') && {busybox} printf "%s" "$pids" && {busybox} kill -CONT $pids''' - excep = None + excep: Optional[Exception] = None for _ in range(5): try: - pids = conn.execute(find_pid, as_root=as_root) + pids: str = conn.execute(find_pid, as_root=as_root) # We choose the highest PID as the "control" PID. It actually does not # really matter which one we pick, as they are all equivalent sh -c layers. - pid = max(map(int, pids.split())) + pid: int = max(map(int, pids.split())) except TargetStableError: raise except Exception as e: @@ -712,72 +1181,167 @@ def with_uuid(cmd): else: break else: - raise TargetTransientError(f'Could not detect PID of background command: {orig_command}') from excep + raise TargetTransientError(f'Could not detect PID of background command: {cast(str,orig_command)}') from excep return (p, pid) -def adb_kill_server(timeout=30, adb_server=None, adb_port=None): + +def adb_kill_server(timeout: Optional[int] = 30, adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> None: + """ + Issue 'adb kill-server' to forcibly shut down the local ADB server. + + :param timeout: Seconds to wait for the command. + :type timeout: int or None + :param adb_server: Optional custom server host. + :type adb_server: str or None + :param adb_port: Optional custom server port. + :type adb_port: int or None + """ adb_command(None, 'kill-server', timeout, adb_server, adb_port) -def adb_list_devices(adb_server=None, adb_port=None): - output = adb_command(None, 'devices', adb_server=adb_server, adb_port=adb_port) - devices = [] + +def adb_list_devices(adb_server: Optional[str] = None, adb_port: Optional[int] = None) -> List[AdbDevice]: + """ + List all devices known to ADB by running 'adb devices'. Each line is parsed + into an :class:`AdbDevice`. + + :param adb_server: Custom ADB server hostname. + :type adb_server: str or None + :param adb_port: Custom ADB server port. + :type adb_port: int or None + :returns: A list of AdbDevice objects describing connected devices. + :rtype: list of AdbDevice + """ + output: str = adb_command(None, 'devices', adb_server=adb_server, adb_port=adb_port) + devices: List[AdbDevice] = [] for line in output.splitlines(): - parts = [p.strip() for p in line.split()] + parts: List[str] = [p.strip() for p in line.split()] if len(parts) == 2: devices.append(AdbDevice(*parts)) return devices -def _get_adb_parts(command, device=None, adb_server=None, adb_port=None, quote_adb=True): +def _get_adb_parts(command: Union[Tuple[str], Tuple[str, str]], device: Optional[str] = None, + adb_server: Optional[str] = None, adb_port: Optional[int] = None, + quote_adb: bool = True) -> Tuple[PartsType, Dict[str, str]]: + """ + Build a tuple of adb command parts, plus environment variables. + + :param command: A tuple of command parts (like ('shell', 'ls')). + :param device: The device name or None if no device param used. + :param adb_server: Host/IP of custom adb server if set. + :param adb_port: Port of custom adb server if set. + :param quote_adb: Whether to quote the server/port args. + :returns: A tuple containing the command parts, plus a dict of env updates. + :rtype: (tuple, dict) + """ _quote = quote if quote_adb else lambda x: x - parts = ( - _ANDROID_ENV.get_env('adb'), + + parts: PartsType = ( + cast(str, _ANDROID_ENV.get_env('adb')), *(('-H', _quote(adb_server)) if adb_server is not None else ()), *(('-P', _quote(str(adb_port))) if adb_port is not None else ()), *(('-s', _quote(device)) if device is not None else ()), *command, ) - env = {'LC_ALL': 'C'} + env: Dict[str, str] = {'LC_ALL': 'C'} return (parts, env) -def get_adb_command(device, command, adb_server=None, adb_port=None): - parts, env = _get_adb_parts((command,), device, adb_server, adb_port, quote_adb=True) - env = [quote(f'{name}={val}') for name, val in sorted(env.items())] - parts = [*env, *parts] - return ' '.join(parts) +def get_adb_command(device: Optional[str], command: str, adb_server: Optional[str] = None, + adb_port: Optional[int] = None) -> str: + """ + Build a single-string 'adb' command that can be run in a host shell. + + :param device: The device serial or IP:port, or None to skip. + :type device: str or None + :param command: The subcommand, e.g. 'shell', 'push', etc. + :type command: str + :param adb_server: Optional custom server address. + :type adb_server: str or None + :param adb_port: Optional custom server port. + :type adb_port: int or None + :returns: A fully expanded command string including environment variables for LC_ALL. + :rtype: str + """ + partstemp, envtemp = _get_adb_parts((command,), device, adb_server, adb_port, quote_adb=True) + env: List[str] = [quote(f'{name}={val}') for name, val in sorted(envtemp.items())] + parts = [*env, *partstemp] + return ' '.join(cast(List[str], parts)) -def adb_command(device, command, timeout=None, adb_server=None, adb_port=None): - full_command = get_adb_command(device, command, adb_server, adb_port) +def adb_command(device: Optional[str], command: str, timeout: Optional[int] = None, + adb_server: Optional[str] = None, adb_port: Optional[int] = None) -> str: + """ + Build and run an 'adb' command synchronously, returning its combined output. + + :param device: Device name, or None if only one or no device is expected. + :type device: str or None + :param command: A subcommand or subcommand + arguments (e.g. 'push file /sdcard/'). + :type command: str + :param timeout: Seconds to wait for completion (None for no limit). + :type timeout: int or None + :param adb_server: Custom ADB server host if needed. + :type adb_server: str or None + :param adb_port: Custom ADB server port if needed. + :type adb_port: int or None + :returns: The command's output as a decoded string. + :rtype: str + :raises HostError: If the command fails or returns non-zero. + """ + full_command: str = get_adb_command(device, command, adb_server, adb_port) logger.debug(full_command) output, _ = check_output(full_command, timeout, shell=True) return output -def adb_command_background(device, conn, command, adb_server=None, adb_port=None): - full_command = get_adb_command(device, command, adb_server, adb_port) +def adb_command_background(device: Optional[str], conn: AdbConnection, command: str, + adb_server: Optional[str] = None, adb_port: Optional[int] = None) -> PopenBackgroundCommand: + """ + Build and run an 'adb' command in the background, returning a handle. + + :param device: The device serial or IP, or None if unspecified. + :type device: str or None + :param conn: The active AdbConnection instance. + :type conn: AdbConnection + :param command: The adb subcommand string. + :type command: str + :param adb_server: Custom adb server address if required. + :type adb_server: str or None + :param adb_port: Custom adb server port if required. + :type adb_port: int or None + :returns: A PopenBackgroundCommand object referencing the running 'adb' command. + :rtype: PopenBackgroundCommand + """ + full_command: str = get_adb_command(device, command, adb_server, adb_port) logger.debug(full_command) - popen = get_subprocess(full_command, shell=True) + popen: 'Popen' = get_subprocess(full_command, shell=True) cmd = PopenBackgroundCommand(conn=conn, popen=popen) return cmd -def grant_app_permissions(target, package): +def grant_app_permissions(target: 'AndroidTarget', package: str) -> None: """ - Grant an app all the permissions it may ask for + Grant all requested permissions to an installed app package by parsing the + 'dumpsys package' output. + + :param target: The Android target on which the package is installed. + :type target: AndroidTarget + :param package: The package name (e.g., "com.example.app"). + :type package: str + :raises TargetStableError: If permission granting fails or the package is invalid. """ - dumpsys = target.execute('dumpsys package {}'.format(package)) + dumpsys: str = target.execute('dumpsys package {}'.format(package)) - permissions = re.search( + permissions: Optional[Match[str]] = re.search( r'requested permissions:\s*(?P(android.permission.+\s*)+)', dumpsys ) if permissions is None: return - permissions = permissions.group('permissions').replace(" ", "").splitlines() + permissions_list: List[str] = permissions.group('permissions').replace(" ", "").splitlines() - for permission in permissions: + for permission in permissions_list: try: target.execute('pm grant {} {}'.format(package, permission)) except TargetStableError: @@ -789,10 +1353,19 @@ class _AndroidEnvironment: # Make the initialization lazy so that we don't trigger an exception if the # user imports the module (directly or indirectly) without actually using # anything from it + """ + Lazy-initialized environment data for Android tools (adb, aapt, etc.), + constructed from ANDROID_HOME or by scanning the system PATH. + """ @property @functools.lru_cache(maxsize=None) - def env(self): - android_home = os.getenv('ANDROID_HOME') + def env(self) -> Android_Env_Type: + """ + :returns: The discovered Android environment mapping with keys like 'adb', 'aapt', etc. + :rtype: Android_Env_Type + :raises HostError: If we cannot find a suitable ANDROID_HOME or 'adb' in PATH. + """ + android_home: Optional[str] = os.getenv('ANDROID_HOME') if android_home: env = self._from_android_home(android_home) else: @@ -800,52 +1373,91 @@ def env(self): return env - def get_env(self, name): + def get_env(self, name: Android_Env_TypeKeys) -> Optional[Union[str, int]]: + """ + Retrieve a specific environment field, such as 'adb', 'aapt', or 'build_tools'. + + :param name: Name of the environment key. + :type name: Android_Env_TypeKeys + :returns: The value if found, else None. + :rtype: str or int or None + """ return self.env[name] @classmethod - def _from_android_home(cls, android_home): + def _from_android_home(cls, android_home: str) -> Android_Env_Type: + """ + Build environment info from ANDROID_HOME. + + :param android_home: Path to Android SDK root. + :type android_home: str + :returns: Dictionary of environment settings. + :rtype: Android_Env_Type + """ logger.debug('Using ANDROID_HOME from the environment.') platform_tools = os.path.join(android_home, 'platform-tools') - return { + return cast(Android_Env_Type, { 'android_home': android_home, 'platform_tools': platform_tools, 'adb': os.path.join(platform_tools, 'adb'), 'fastboot': os.path.join(platform_tools, 'fastboot'), **cls._init_common(android_home) - } + }) @classmethod - def _from_adb(cls): + def _from_adb(cls) -> Android_Env_Type: + """ + Attempt to derive environment info by locating 'adb' on the system PATH. + + :returns: A dictionary of environment settings. + :rtype: Android_Env_Type + :raises HostError: If 'adb' is not found in PATH. + """ adb_path = which('adb') if adb_path: logger.debug('Discovering ANDROID_HOME from adb path.') platform_tools = os.path.dirname(adb_path) android_home = os.path.dirname(platform_tools) - return { + return cast(Android_Env_Type, { 'android_home': android_home, 'platform_tools': platform_tools, 'adb': adb_path, 'fastboot': which('fastboot'), **cls._init_common(android_home) - } + }) else: raise HostError('ANDROID_HOME is not set and adb is not in PATH. ' 'Have you installed Android SDK?') @classmethod - def _init_common(cls, android_home): + def _init_common(cls, android_home: str) -> BuildToolsInfo: + """ + Discover build tools, aapt, etc., from an Android SDK layout. + + :param android_home: Android SDK root path. + :type android_home: str + :returns: Partial dictionary with keys like 'build_tools', 'aapt', 'aapt_version'. + :rtype: BuildToolsInfo + """ logger.debug(f'ANDROID_HOME: {android_home}') build_tools = cls._discover_build_tools(android_home) - return { + return cast(BuildToolsInfo, { 'build_tools': build_tools, **cls._discover_aapt(build_tools) - } + }) @staticmethod - def _discover_build_tools(android_home): + def _discover_build_tools(android_home: str) -> Optional[str]: + """ + Attempt to locate the build-tools directory under android_home. + + :param android_home: Path to the SDK. + :type android_home: str + :returns: Path to build-tools if found, else None. + :rtype: str or None + """ build_tools = os.path.join(android_home, 'build-tools') if os.path.isdir(build_tools): return build_tools @@ -853,7 +1465,15 @@ def _discover_build_tools(android_home): return None @staticmethod - def _check_supported_aapt2(binary): + def _check_supported_aapt2(binary: str) -> bool: + """ + Check if a given 'aapt2' binary supports 'dump badging'. + + :param binary: Path to the aapt2 binary. + :type binary: str + :returns: True if the binary appears to support the 'badging' command, else False. + :rtype: bool + """ # At time of writing the version argument of aapt2 is not helpful as # the output is only a placeholder that does not distinguish between versions # with and without support for badging. Unfortunately aapt has been @@ -862,32 +1482,47 @@ def _check_supported_aapt2(binary): # Try to execute the badging command and check if we get an expected error # message as opposed to an unknown command error to determine if we have a # suitable version. - result = subprocess.run([str(binary), 'dump', 'badging'], stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, universal_newlines=True) + """ + check if aapt2 is supported + """ + result: 'CompletedProcess' = subprocess.run([str(binary), 'dump', 'badging'], + stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, + universal_newlines=True) supported = bool(AAPT_BADGING_OUTPUT.search(result.stderr)) - msg = 'Found a {} aapt2 binary at: {}' + msg: str = 'Found a {} aapt2 binary at: {}' logger.debug(msg.format('supported' if supported else 'unsupported', binary)) return supported @classmethod - def _discover_aapt(cls, build_tools): + def _discover_aapt(cls, build_tools: Optional[str]) -> Dict[str, Optional[Union[str, int]]]: + """ + Attempt to find 'aapt2' or 'aapt' in build-tools (or PATH fallback). + Prefers aapt2 if available. + + :param build_tools: Path to the build-tools directory or None if unknown. + :type build_tools: str or None + :returns: A dictionary with 'aapt' and 'aapt_version' keys. + :rtype: dict + :raises HostError: If neither aapt nor aapt2 is found. + """ if build_tools: - def find_aapt2(version): + def find_aapt2(version: str) -> Tuple[Optional[int], Optional[str]]: path = os.path.join(build_tools, version, 'aapt2') if os.path.isfile(path) and cls._check_supported_aapt2(path): return (2, path) else: return (None, None) - def find_aapt(version): - path = os.path.join(build_tools, version, 'aapt') + def find_aapt(version: str) -> Tuple[Optional[int], Optional[str]]: + path: str = os.path.join(build_tools, version, 'aapt') if os.path.isfile(path): return (1, path) else: return (None, None) - versions = os.listdir(build_tools) - found = ( + versions: List[str] = os.listdir(build_tools) + found: Generator[Tuple[str, Tuple[Optional[int], Optional[str]]]] = ( (version, finder(version)) for version in reversed(sorted(versions)) for finder in (find_aapt2, find_aapt) @@ -902,7 +1537,7 @@ def find_aapt(version): ) # Try detecting aapt2 and aapt from PATH - aapt2_path = which('aapt2') + aapt2_path: Optional[str] = which('aapt2') aapt_path = which('aapt') if aapt2_path and cls._check_supported_aapt2(aapt2_path): return dict( @@ -932,24 +1567,32 @@ class LogcatMonitor(object): """ @property - def logfile(self): + def logfile(self) -> Optional[Union['TextIOWrapper', '_TemporaryFileWrapper[str]']]: + """ + Return the file-like object that logcat is writing to, if any. + + :returns: The log file or None. + :rtype: file-like or None + """ return self._logfile - def __init__(self, target, regexps=None, logcat_format=None): + def __init__(self, target: 'AndroidTarget', regexps: Optional[List[str]] = None, + logcat_format: Optional[str] = None): super(LogcatMonitor, self).__init__() self.target = target self._regexps = regexps self._logcat_format = logcat_format - self._logcat = None - self._logfile = None + self._logcat: Optional[spawn] = None + self._logfile: Optional[Union['TextIOWrapper', '_TemporaryFileWrapper[str]']] = None - def start(self, outfile=None): + def start(self, outfile: Optional[str] = None) -> None: """ - Start logcat and begin monitoring + Begin capturing logcat output. If outfile is given, logcat lines are + appended there; otherwise, a temporary file is used. - :param outfile: Optional path to file to store all logcat entries - :type outfile: str + :param outfile: A path to a file on the host, or None for a temporary file. + :type outfile: str or None """ if outfile: self._logfile = open(outfile, 'w') @@ -958,11 +1601,11 @@ def start(self, outfile=None): self.target.clear_logcat() - logcat_cmd = 'logcat' + logcat_cmd: str = 'logcat' # Join all requested regexps with an 'or' if self._regexps: - regexp = '{}'.format('|'.join(self._regexps)) + regexp: str = '{}'.format('|'.join(self._regexps)) if len(self._regexps) > 1: regexp = '({})'.format(regexp) # Logcat on older version of android do not support the -e argument @@ -975,26 +1618,42 @@ def start(self, outfile=None): if self._logcat_format: logcat_cmd = "{} -v {}".format(logcat_cmd, quote(self._logcat_format)) - logcat_cmd = get_adb_command(self.target.conn.device, logcat_cmd, self.target.adb_server, self.target.adb_port) + logcat_cmd = get_adb_command(self.target.conn.device, + logcat_cmd, self.target.adb_server, + self.target.adb_port) if isinstance(self.target.conn, AdbConnection) else '' logger.debug('logcat command ="{}"'.format(logcat_cmd)) self._logcat = pexpect.spawn(logcat_cmd, logfile=self._logfile, encoding='utf-8') - def stop(self): + def stop(self) -> None: + """ + Stop capturing logcat and close the log file if applicable. + """ self.flush_log() - self._logcat.terminate() - self._logfile.close() + if self._logcat: + self._logcat.terminate() + if self._logfile: + self._logfile.close() - def get_log(self): + def get_log(self) -> List[str]: """ - Return the list of lines found by the monitor + Retrieve all captured lines from the log so far. + + :returns: A list of log lines from the log file. + :rtype: list of str """ self.flush_log() + if self._logfile: + with open(self._logfile.name) as fh: + return [line for line in fh] + else: + return [] - with open(self._logfile.name) as fh: - return [line for line in fh] - - def flush_log(self): + def flush_log(self) -> None: + """ + Force-read all pending data from the logcat pexpect spawn to ensure it's + written to the logfile. Prevents missed lines if pexpect hasn't pulled them yet. + """ # Unless we tell pexect to 'expect' something, it won't read from # logcat's buffer or write into our logfile. We'll need to force it to # read any pending logcat output. @@ -1004,7 +1663,9 @@ def flush_log(self): # This will read up to read_size bytes, but only those that are # already ready (i.e. it won't block). If there aren't any bytes # already available it raises pexpect.TIMEOUT. - buf = self._logcat.read_nonblocking(read_size, timeout=0) + buf: str = '' + if self._logcat: + buf = self._logcat.read_nonblocking(read_size, timeout=0) # We can't just keep calling read_nonblocking until we get a # pexpect.TIMEOUT (i.e. until we don't find any available @@ -1025,18 +1686,27 @@ def flush_log(self): # printed anything since pexpect last read from its buffer. break - def clear_log(self): - with open(self._logfile.name, 'w') as _: - pass + def clear_log(self) -> None: + """ + Erase current content of the log file so subsequent calls to get_log() + won't return older lines. + """ + if self._logfile: + with open(self._logfile.name, 'w') as _: + pass - def search(self, regexp): + def search(self, regexp: str) -> List[str]: """ - Search a line that matches a regexp in the logcat log - Return immediatly + Search the captured lines for matches of the given regexp. + + :param regexp: A regular expression pattern. + :type regexp: str + :returns: All matching lines found so far. + :rtype: list of str """ return [line for line in self.get_log() if re.match(regexp, line)] - def wait_for(self, regexp, timeout=30): + def wait_for(self, regexp: str, timeout: Optional[int] = 30) -> List[str]: """ Search a line that matches a regexp in the logcat log Wait for it to appear if it's not found @@ -1049,9 +1719,11 @@ def wait_for(self, regexp, timeout=30): :type timeout: number :returns: List of matched strings + :rtype: list of str + :raises RuntimeError: If the regex is not found within ``timeout`` seconds. """ - log = self.get_log() - res = [line for line in log if re.match(regexp, line)] + log: List[str] = self.get_log() + res: List[str] = [line for line in log if re.match(regexp, line)] # Found some matches, return them if res: @@ -1059,15 +1731,16 @@ def wait_for(self, regexp, timeout=30): # Store the number of lines we've searched already, so we don't have to # re-grep them after 'expect' returns - next_line_num = len(log) + next_line_num: int = len(log) try: - self._logcat.expect(regexp, timeout=timeout) + if self._logcat: + self._logcat.expect(regexp, timeout=timeout) except pexpect.TIMEOUT: raise RuntimeError('Logcat monitor timeout ({}s)'.format(timeout)) return [line for line in self.get_log()[next_line_num:] if re.match(regexp, line)] -_ANDROID_ENV = _AndroidEnvironment() +_ANDROID_ENV = _AndroidEnvironment() diff --git a/devlib/utils/annotation_helpers.py b/devlib/utils/annotation_helpers.py new file mode 100644 index 000000000..cee651a8c --- /dev/null +++ b/devlib/utils/annotation_helpers.py @@ -0,0 +1,72 @@ +# Copyright 2025 ARM Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Helpers to annotate the code + +""" +import sys +from typing import Union, Sequence, Optional +from typing_extensions import NotRequired, LiteralString, TYPE_CHECKING, TypedDict +if TYPE_CHECKING: + from _typeshed import StrPath, BytesPath + from devlib.platform import Platform + from devlib.utils.android import AdbConnection + from devlib.utils.ssh import SshConnection + from devlib.host import LocalConnection + from devlib.connection import PopenBackgroundCommand, AdbBackgroundCommand, ParamikoBackgroundCommand +else: + StrPath = str + BytesPath = bytes + + +import os +if sys.version_info >= (3, 9): + SubprocessCommand = Union[ + str, bytes, os.PathLike[str], os.PathLike[bytes], + Sequence[Union[str, bytes, os.PathLike[str], os.PathLike[bytes]]]] +else: + SubprocessCommand = Union[str, bytes, os.PathLike, + Sequence[Union[str, bytes, os.PathLike]]] + +BackgroundCommand = Union['AdbBackgroundCommand', 'ParamikoBackgroundCommand', 'PopenBackgroundCommand'] + +SupportedConnections = Union['LocalConnection', 'AdbConnection', 'SshConnection'] + + +class SshUserConnectionSettings(TypedDict, total=False): + username: NotRequired[str] + password: NotRequired[str] + keyfile: NotRequired[Optional[Union[LiteralString, StrPath, BytesPath]]] + host: NotRequired[str] + port: NotRequired[int] + timeout: NotRequired[float] + platform: NotRequired['Platform'] + sudo_cmd: NotRequired[str] + strict_host_check: NotRequired[bool] + use_scp: NotRequired[bool] + poll_transfers: NotRequired[bool] + start_transfer_poll_delay: NotRequired[int] + total_transfer_timeout: NotRequired[int] + transfer_poll_period: NotRequired[int] + + +class AdbUserConnectionSettings(SshUserConnectionSettings): + device: NotRequired[str] + adb_server: NotRequired[str] + adb_port: NotRequired[int] + + +UserConnectionSettings = Union[SshUserConnectionSettings, AdbUserConnectionSettings] diff --git a/devlib/utils/asyn.py b/devlib/utils/asyn.py index dd6d42d59..1df1559fb 100644 --- a/devlib/utils/asyn.py +++ b/devlib/utils/asyn.py @@ -1,4 +1,4 @@ -# Copyright 2013-2018 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,24 +30,49 @@ import inspect import sys import threading -from concurrent.futures import ThreadPoolExecutor -from weakref import WeakSet, WeakKeyDictionary +from concurrent.futures import ThreadPoolExecutor, Future +from weakref import WeakSet from greenlet import greenlet - - -def create_task(awaitable, name=None): +from typing import (AsyncGenerator, Any, Callable, TypeVar, Type, + Optional, Coroutine, Tuple, Dict, cast, Set, + List, Generator, Union, AsyncContextManager, + Iterable, Awaitable) +from asyncio import Task, AbstractEventLoop +from inspect import Signature, BoundArguments +from contextvars import Context +from queue import SimpleQueue +from threading import local + + +def create_task(awaitable: Awaitable, name: Optional[str] = None) -> Task: + """ + Create a new asyncio Task from an awaitable and set its name. + + :param awaitable: A coroutine or awaitable object to schedule. + :type awaitable: Coroutine + :param name: An optional name for the task. If None, attempts to use the awaitable’s __qualname__. + :type name: str or None + :returns: The created asyncio Task. + :rtype: Taskcreate a new task using asyncio create_task and assign the name provided to it + """ if isinstance(awaitable, asyncio.Task): - task = awaitable + task: Task = awaitable else: - task = asyncio.create_task(awaitable) + task = asyncio.create_task(cast(Coroutine, awaitable)) if name is None: name = getattr(awaitable, '__qualname__', None) - task.name = name + task.set_name(name) return task -def _close_loop(loop): +def _close_loop(loop: Optional[AbstractEventLoop]) -> None: + """ + Close an asyncio event loop after shutting down asynchronous generators and the default executor. + + :param loop: The event loop to close, or None. + :type loop: AbstractEventLoop or None + """ if loop is not None: try: loop.run_until_complete(loop.shutdown_asyncgens()) @@ -62,11 +87,21 @@ def _close_loop(loop): class AsyncManager: - def __init__(self): - self.task_tree = dict() - self.resources = dict() + """ + Manages asynchronous operations by tracking tasks and ensuring that concurrently + running asynchronous functions do not interfere with one another. + + This manager maintains a mapping of tasks to resources and allows running tasks + concurrently while checking for overlapping resource usage. + """ + def __init__(self) -> None: + """ + Initialize the AsyncManager with empty task trees and resource maps. + """ + self.task_tree: Dict[Task, Set[Task]] = dict() + self.resources: Dict[Task, Set['ConcurrentAccessBase']] = dict() - def track_access(self, access): + def track_access(self, access: 'ConcurrentAccessBase') -> None: """ Register the given ``access`` to have been handled by the current async task. @@ -78,39 +113,48 @@ def track_access(self, access): step on each other's toes. """ try: - task = asyncio.current_task() + task: Optional[Task] = asyncio.current_task() except RuntimeError: pass else: - self.resources.setdefault(task, set()).add(access) + if task: + self.resources.setdefault(task, set()).add(access) - async def concurrently(self, awaitables): + async def concurrently(self, awaitables: Iterable[Awaitable]) -> List[Any]: """ Await concurrently for the given awaitables, and cancel them as soon as one raises an exception. + + :param awaitables: An iterable of coroutine objects to run concurrently. + :type awaitables: Iterable[Coroutine] + :returns: A list with the results of the awaitables. + :rtype: list + :raises Exception: Propagates the first exception encountered, canceling the others. """ - awaitables = list(awaitables) + awaitables_list: List[Awaitable] = list(awaitables) # Avoid creating asyncio.Tasks when it's not necessary, as it will # disable a the blocking path optimization of Target._execute_async() # that uses blocking calls as long as there is only one asyncio.Task # running on the event loop. - if len(awaitables) == 1: - return [await awaitables[0]] + if len(awaitables_list) == 1: + return [await awaitables_list[0]] - tasks = list(map(create_task, awaitables)) + tasks: List[Task] = list(map(create_task, awaitables_list)) - current_task = asyncio.current_task() - task_tree = self.task_tree + current_task: Optional[Task] = asyncio.current_task() + task_tree: Dict[Task, Set[Task]] = self.task_tree try: - node = task_tree[current_task] + if current_task: + node: Set[Task] = task_tree[current_task] except KeyError: - is_root_task = True + is_root_task: bool = True node = set() else: is_root_task = False - task_tree[current_task] = node + if current_task: + task_tree[current_task] = node task_tree.update({ child: set() @@ -126,8 +170,12 @@ async def concurrently(self, awaitables): raise finally: - def get_children(task): - immediate_children = task_tree[task] + def get_children(task: Task) -> frozenset[Task]: + """ + get the children of the task and their children etc and return as a + single set + """ + immediate_children: Set[Task] = task_tree[task] return frozenset( itertools.chain( [task], @@ -140,7 +188,7 @@ def get_children(task): # Get the resources created during the execution of each subtask # (directly or indirectly) - resources = { + resources: Dict[Task, frozenset['ConcurrentAccessBase']] = { task: frozenset( itertools.chain.from_iterable( self.resources.get(child, []) @@ -153,20 +201,25 @@ def get_children(task): for res1, res2 in itertools.product(resources1, resources2): if issubclass(res2.__class__, res1.__class__) and res1.overlap_with(res2): raise RuntimeError( - 'Overlapping resources manipulated in concurrent async tasks: {} (task {}) and {} (task {})'.format(res1, task1.name, res2, task2.name) + 'Overlapping resources manipulated in concurrent async tasks: {} (task {}) and {} (task {})'.format(res1, task1.get_name(), res2, task2.get_name()) ) if is_root_task: self.resources.clear() task_tree.clear() - async def map_concurrently(self, f, keys): + async def map_concurrently(self, f: Callable, keys: Any) -> Dict: """ Similar to :meth:`concurrently`, but maps the given function ``f`` on the given ``keys``. + :param f: The function to apply to each key. + :type f: Callable + :param keys: An iterable of keys. + :type keys: Any :return: A dictionary with ``keys`` as keys, and function result as values. + :rtype: dict """ keys = list(keys) return dict(zip( @@ -175,13 +228,18 @@ async def map_concurrently(self, f, keys): )) -def compose(*coros): +def compose(*coros: Callable) -> Callable[..., Coroutine]: """ Compose coroutines, feeding the output of each as the input of the next one. ``await compose(f, g)(x)`` is equivalent to ``await f(await g(x))`` + :param coros: A variable number of coroutine functions. + :type coros: Callable + :returns: A callable that, when awaited, composes the coroutines in sequence. + :rtype: Callable[..., Coroutine] + .. note:: In Haskell, ``compose f g h`` would be equivalent to ``f <=< g <=< h`` """ async def f(*args, **kwargs): @@ -205,8 +263,13 @@ class _AsyncPolymorphicFunction: When called, the blocking synchronous operation is called. The ```asyn`` attribute gives access to the asynchronous version of the function, and all the other attribute access will be redirected to the async function. + + :param asyn: The asynchronous version of the function. + :type asyn: Callable + :param blocking: The synchronous (blocking) version of the function. + :type blocking: Callable """ - def __init__(self, asyn, blocking): + def __init__(self, asyn: Callable[..., Awaitable], blocking: Callable[..., Any]): self.asyn = asyn self.blocking = blocking functools.update_wrapper(self, asyn) @@ -240,36 +303,46 @@ class memoized_method: * non-async methods * method already decorated with :func:`devlib.asyn.asyncf`. + :param f: The method to memoize. + :type f: Callable + .. note:: This decorator does not rely on hacks to hash unhashable data. If such input is required, it will either have to be coerced to a hashable first (e.g. converting a list to a tuple), or the code of :func:`devlib.asyn.memoized_method` will have to be updated to do so. """ - def __init__(self, f): - memo = self - - sig = inspect.signature(f) - - def bind(self, *args, **kwargs): - bound = sig.bind(self, *args, **kwargs) + def __init__(self, f: Callable): + memo: 'memoized_method' = self + + sig: Signature = inspect.signature(f) + + def bind(self, *args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], + Tuple[Any, ...], + Dict[str, Any]]: + """ + bind arguments to function signature + """ + bound: BoundArguments = sig.bind(self, *args, **kwargs) bound.apply_defaults() key = (bound.args[1:], tuple(sorted(bound.kwargs.items()))) return (key, bound.args, bound.kwargs) - def get_cache(self): + def get_cache(self) -> Dict[Tuple[Any, ...], Any]: try: - cache = self.__dict__[memo.name] + cache: Dict[Tuple[Any, ...], Any] = self.__dict__[memo.name] except KeyError: cache = {} self.__dict__[memo.name] = cache return cache - if inspect.iscoroutinefunction(f): @functools.wraps(f) - async def wrapper(self, *args, **kwargs): - cache = get_cache(self) + async def async_wrapper(self, *args: Any, **kwargs: Any) -> Any: + """ + wrapper for async functions + """ + cache: Dict[Tuple[Any, ...], Any] = get_cache(self) key, args, kwargs = bind(self, *args, **kwargs) try: return cache[key] @@ -277,9 +350,13 @@ async def wrapper(self, *args, **kwargs): x = await f(*args, **kwargs) cache[key] = x return x + self.f: Callable[..., Coroutine] = async_wrapper else: @functools.wraps(f) - def wrapper(self, *args, **kwargs): + def sync_wrapper(self, *args: Any, **kwargs: Any) -> Any: + """ + wrapper for sync functions + """ cache = get_cache(self) key, args, kwargs = bind(self, *args, **kwargs) try: @@ -288,25 +365,24 @@ def wrapper(self, *args, **kwargs): x = f(*args, **kwargs) cache[key] = x return x + self.f = sync_wrapper - - self.f = wrapper self._name = f.__name__ @property - def name(self): + def name(self) -> str: return '__memoization_cache_of_' + self._name def __call__(self, *args, **kwargs): return self.f(*args, **kwargs) - def __get__(self, obj, owner=None): + def __get__(self, obj: Optional['memoized_method'], owner: Optional[Type['memoized_method']] = None) -> Any: return self.f.__get__(obj, owner) - def __set__(self, obj, value): + def __set__(self, obj: 'memoized_method', value: Any): raise RuntimeError("Cannot monkey-patch a memoized function") - def __set_name__(self, owner, name): + def __set_name__(self, owner: Type['memoized_method'], name: str): self._name = name @@ -325,22 +401,36 @@ def __init__(self, *args, **kwargs): self.gr_context = contextvars.copy_context() @classmethod - def from_coro(cls, coro): + def from_coro(cls, coro: Coroutine) -> '_Genlet': """ Create a :class:`_Genlet` from a given coroutine, treating it as a generator. + + :param coro: The coroutine to wrap. + :type coro: Coroutine + :returns: A _Genlet that wraps the coroutine. + :rtype: _Genlet """ - f = lambda value: self.consume_coro(coro, value) + def f(value: Any) -> Any: + return self.consume_coro(coro, value) self = cls(f) return self - def consume_coro(self, coro, value): + def consume_coro(self, coro: Coroutine, value: Any) -> Any: """ Send ``value`` to ``coro`` then consume the coroutine, passing all its yielded actions to the enclosing :class:`_Genlet`. This allows crossing blocking calls layers as if they were async calls with `await`. + + :param coro: The coroutine to consume. + :type coro: Coroutine + :param value: The initial value to send. + :type value: Any + :returns: The final value returned by the coroutine. + :rtype: Any + :raises StopIteration: When the coroutine is exhausted. """ - excep = None + excep: Optional[BaseException] = None while True: try: if excep is None: @@ -351,11 +441,11 @@ def consume_coro(self, coro, value): except StopIteration as e: return e.value else: - parent = self.parent + parent: Optional[greenlet] = self.parent # Switch back to the consumer that returns the values via # send() try: - value = parent.switch(future) + value = parent.switch(future) if parent else None except BaseException as e: excep = e value = None @@ -363,17 +453,31 @@ def consume_coro(self, coro, value): excep = None @classmethod - def get_enclosing(cls): + def get_enclosing(cls) -> Optional['_Genlet']: """ Get the immediately enclosing :class:`_Genlet` in the callstack or ``None``. + + :returns: The nearest _Genlet instance in the chain, or None if not found. + :rtype: _Genlet or None """ g = greenlet.getcurrent() while not (isinstance(g, cls) or g is None): g = g.parent return g - def _send_throw(self, value, excep): + def _send_throw(self, value: Optional['_Genlet'], excep: Optional[BaseException]) -> Any: + """ + helper function to do switch to another genlet or throw exception + + :param value: The value to send to the parent. + :type value: _Genlet or None + :param excep: The exception to throw in the parent, or None. + :type excep: BaseException or None + :returns: The result returned from the parent's switch. + :rtype: Any + :raises StopIteration: If the parent completes. + """ self.parent = greenlet.getcurrent() # Switch back to the function yielding values @@ -387,55 +491,86 @@ def _send_throw(self, value, excep): else: raise StopIteration(result) - def gen_send(self, x): + def gen_send(self, x: Optional['_Genlet']) -> Any: """ Similar to generators' ``send`` method. + + :param x: The value to send. + :type x: _Genlet or None + :returns: The value received from the parent. + :rtype: Any """ return self._send_throw(x, None) - def gen_throw(self, x): + def gen_throw(self, x: Optional[BaseException]): """ Similar to generators' ``throw`` method. + + :param x: The exception to throw. + :type x: BaseException or None + :returns: The value received from the parent after handling the exception. + :rtype: Any """ return self._send_throw(None, x) class _AwaitableGenlet: """ - Wrap a coroutine with a :class:`_Genlet` and wrap that to be awaitable. + Wraps a coroutine with a :class:`_Genlet` to allow it to be awaited using + the normal 'await' syntax. + + :param coro: The coroutine to wrap. + :type coro: Coroutine """ @classmethod - def wrap_coro(cls, coro): - async def coro_f(): + def wrap_coro(cls, coro: Coroutine) -> Coroutine: + """ + Wrap a coroutine inside an _AwaitableGenlet so that it becomes awaitable. + + :param coro: The coroutine to wrap. + :type coro: Coroutine + :returns: An awaitable version of the coroutine. + :rtype: Coroutine + """ + async def coro_f() -> Any: # Make sure every new task will be instrumented since a task cannot # yield futures on behalf of another task. If that were to happen, # the task B trying to do a nested yield would switch back to task # A, asking to yield on its behalf. Since the event loop would be # currently handling task B, nothing would handle task A trying to # yield on behalf of B, leading to a deadlock. - loop = asyncio.get_running_loop() + loop: AbstractEventLoop = asyncio.get_running_loop() _install_task_factory(loop) # Create a top-level _AwaitableGenlet that all nested runs will use # to yield their futures - _coro = cls(coro) + _coro: '_AwaitableGenlet' = cls(coro) return await _coro return coro_f() - def __init__(self, coro): + def __init__(self, coro: Coroutine): self._coro = coro - def __await__(self): - coro = self._coro - is_started = inspect.iscoroutine(coro) and coro.cr_running + def __await__(self) -> Generator: + """ + Make the _AwaitableGenlet awaitable. + + :returns: A generator that yields from the wrapped coroutine. + :rtype: Generator + """ + coro: Coroutine = self._coro + is_started: bool = inspect.iscoroutine(coro) and coro.cr_running - def genf(): + def genf() -> Generator: + """ + generator function + """ gen = _Genlet.from_coro(coro) - value = None - excep = None + value: Optional[_Genlet] = None + excep: Optional[BaseException] = None # The coroutine is already started, so we need to dispatch the # value from the upcoming send() to the gen without running @@ -468,25 +603,39 @@ def genf(): gen = genf() if is_started: # Start the generator so it waits at the first yield point - gen.gen_send(None) + cast(_Genlet, gen).gen_send(None) return gen -def _allow_nested_run(coro): +def _allow_nested_run(coro: Coroutine) -> Coroutine: + """ + If the current callstack does not have an enclosing _Genlet, wrap the coroutine + using _AwaitableGenlet; otherwise, return the coroutine unchanged. + + :param coro: The coroutine to potentially wrap. + :type coro: Coroutine + :returns: The original coroutine or a wrapped awaitable coroutine. + :rtype: Coroutine + """ if _Genlet.get_enclosing() is None: return _AwaitableGenlet.wrap_coro(coro) else: return coro -def allow_nested_run(coro): +def allow_nested_run(coro: Coroutine) -> Coroutine: """ Wrap the coroutine ``coro`` such that nested calls to :func:`run` will be - allowed. + allowed. This is useful when a coroutine needs to yield control to another layer. .. warning:: The coroutine needs to be consumed in the same OS thread it was created in. + + :param coro: The coroutine to wrap. + :type coro: Coroutine + :returns: A possibly wrapped coroutine that allows nested execution. + :rtype: Coroutine """ return _allow_nested_run(coro) @@ -503,7 +652,15 @@ def allow_nested_run(coro): ) -def _check_executor_alive(executor): +def _check_executor_alive(executor: ThreadPoolExecutor) -> bool: + """ + Check if the given ThreadPoolExecutor is still alive by submitting a no-op job. + + :param executor: The ThreadPoolExecutor to check. + :type executor: ThreadPoolExecutor + :returns: True if the executor accepts new jobs; False otherwise. + :rtype: bool + """ try: executor.submit(lambda: None) except RuntimeError: @@ -513,29 +670,38 @@ def _check_executor_alive(executor): _PATCHED_LOOP_LOCK = threading.Lock() -_PATCHED_LOOP = WeakSet() -def _install_task_factory(loop): +_PATCHED_LOOP: WeakSet = WeakSet() + + +def _install_task_factory(loop: AbstractEventLoop): """ Install a task factory on the given event ``loop`` so that top-level coroutines are wrapped using :func:`allow_nested_run`. This ensures that the nested :func:`run` infrastructure will be available. + + :param loop: The asyncio event loop on which to install the task factory. + :type loop: AbstractEventLoop """ - def install(loop): + def install(loop: AbstractEventLoop) -> None: + """ + install the task factory on the event loop + """ if sys.version_info >= (3, 11): - def default_factory(loop, coro, context=None): + def default_factory(loop: AbstractEventLoop, coro: Coroutine, context: Optional[Context] = None) -> Optional[Task]: return asyncio.Task(coro, loop=loop, context=context) else: - def default_factory(loop, coro, context=None): + def default_factory(loop: AbstractEventLoop, coro: Coroutine, context: Optional[Context] = None) -> Optional[Task]: return asyncio.Task(coro, loop=loop) make_task = loop.get_task_factory() or default_factory - def factory(loop, coro, context=None): + + def factory(loop: AbstractEventLoop, coro: Coroutine, context: Optional[Context] = None) -> Optional[Task]: # Make sure each Task will be able to yield on behalf of its nested # await beneath blocking layers coro = _AwaitableGenlet.wrap_coro(coro) - return make_task(loop, coro, context=context) + return cast(Callable, make_task)(loop, coro, context=context) - loop.set_task_factory(factory) + loop.set_task_factory(cast(Callable, factory)) with _PATCHED_LOOP_LOCK: if loop in _PATCHED_LOOP: @@ -545,13 +711,17 @@ def factory(loop, coro, context=None): _PATCHED_LOOP.add(loop) -def _set_current_context(ctx): +def _set_current_context(ctx: Optional[Context]) -> None: """ Get all the variable from the passed ``ctx`` and set them in the current context. + + :param ctx: A Context object containing variable values to set. + :type ctx: Context or None """ - for var, val in ctx.items(): - var.set(val) + if ctx: + for var, val in ctx.items(): + var.set(val) class _CoroRunner(abc.ABC): @@ -564,10 +734,25 @@ class _CoroRunner(abc.ABC): single event loop. """ @abc.abstractmethod - def _run(self, coro): + def _run(self, coro: Coroutine) -> Any: + """ + Execute the given coroutine using the runner's mechanism. + + :param coro: The coroutine to run. + :type coro: Coroutine + """ pass - def run(self, coro): + def run(self, coro: Coroutine) -> Any: + """ + Run the provided coroutine using the implemented runner. Raises an + assertion error if the coroutine is already running. + + :param coro: The coroutine to run. + :type coro: Coroutine + :returns: The result of the coroutine. + :rtype: Any + """ # Ensure we have a fresh coroutine. inspect.getcoroutinestate() does not # work on all objects that asyncio creates on some version of Python, such # as iterable_coroutine @@ -588,26 +773,43 @@ class _ThreadCoroRunner(_CoroRunner): Critically, this allows running multiple coroutines out of the same thread, which will be reserved until the runner ``__exit__`` method is called. + + :param future: A Future representing the thread running the coroutine loop. + :type future: Future + :param jobq: A SimpleQueue for scheduling coroutine jobs. + :type jobq: SimpleQueue[Optional[Tuple[Context, Coroutine]]] + :param resq: A SimpleQueue to collect results from executed coroutines. + :type resq: SimpleQueue[Tuple[Context, Optional[BaseException], Any]] """ - def __init__(self, future, jobq, resq): + def __init__(self, future: 'Future', jobq: 'SimpleQueue[Optional[Tuple[Context, Coroutine]]]', + resq: 'SimpleQueue[Tuple[Context, Optional[BaseException], Any]]'): self._future = future self._jobq = jobq self._resq = resq @staticmethod - def _thread_f(jobq, resq): - def handle_jobs(runner): + def _thread_f(jobq: 'SimpleQueue[Optional[Tuple[Context, Coroutine]]]', + resq: 'SimpleQueue[Tuple[Context, Optional[BaseException], Any]]') -> None: + """ + Thread function that continuously processes scheduled coroutine jobs. + + :param jobq: Queue of jobs. + :type jobq: SimpleQueue + :param resq: Queue to store results from the jobs. + :type resq: SimpleQueue + """ + def handle_jobs(runner: _LoopCoroRunner) -> None: while True: - job = jobq.get() + job: Optional[Tuple[Context, Coroutine]] = jobq.get() if job is None: return else: ctx, coro = job try: - value = ctx.run(runner.run, coro) + value: Any = ctx.run(runner.run, coro) except BaseException as e: value = None - excep = e + excep: Optional[BaseException] = e else: excep = None @@ -617,12 +819,21 @@ def handle_jobs(runner): handle_jobs(runner) @classmethod - def from_executor(cls, executor): - jobq = queue.SimpleQueue() - resq = queue.SimpleQueue() + def from_executor(cls, executor: ThreadPoolExecutor) -> '_ThreadCoroRunner': + """ + Create a _ThreadCoroRunner by submitting the thread function to an executor. + + :param executor: A ThreadPoolExecutor to run the coroutine loop. + :type executor: ThreadPoolExecutor + :returns: An instance of _ThreadCoroRunner. + :rtype: _ThreadCoroRunner + :raises RuntimeError: If the executor is not alive. + """ + jobq: SimpleQueue[Optional[Tuple[Context, Coroutine]]] = queue.SimpleQueue() + resq: SimpleQueue = queue.SimpleQueue() try: - future = executor.submit(cls._thread_f, jobq, resq) + future: Future = executor.submit(cls._thread_f, jobq, resq) except RuntimeError as e: if _check_executor_alive(executor): raise e @@ -635,7 +846,16 @@ def from_executor(cls, executor): future=future, ) - def _run(self, coro): + def _run(self, coro: Coroutine) -> Any: + """ + Schedule and run a coroutine in the separate thread, waiting for its result. + + :param coro: The coroutine to execute. + :type coro: Coroutine + :returns: The result from running the coroutine. + :rtype: Any + :raises Exception: Propagates any exception raised by the coroutine. + """ ctx = contextvars.copy_context() self._jobq.put((ctx, coro)) ctx, excep, value = self._resq.get() @@ -659,20 +879,32 @@ class _LoopCoroRunner(_CoroRunner): The passed event loop is assumed to not be running. If ``None`` is passed, a new event loop will be created in ``__enter__`` and closed in ``__exit__``. + + :param loop: An event loop to use; if None, a new one is created. + :type loop: AbstractEventLoop or None """ - def __init__(self, loop): + def __init__(self, loop: Optional[AbstractEventLoop]): self.loop = loop - self._owned = False + self._owned: bool = False - def _run(self, coro): + def _run(self, coro: Coroutine) -> Any: + """ + Run the given coroutine to completion on the event loop and return its result. + + :param coro: The coroutine to run. + :type coro: Coroutine + :returns: The result of the coroutine. + :rtype: Any + """ loop = self.loop # Back-propagate the contextvars that could have been modified by the # coroutine. This could be handled by asyncio.Runner().run(..., # context=...) or loop.create_task(..., context=...) but these APIs are # only available since Python 3.11 - ctx = None - async def capture_ctx(): + ctx: Optional[Context] = None + + async def capture_ctx() -> Any: nonlocal ctx try: return await _allow_nested_run(coro) @@ -680,12 +912,13 @@ async def capture_ctx(): ctx = contextvars.copy_context() try: - return loop.run_until_complete(capture_ctx()) + if loop: + return loop.run_until_complete(capture_ctx()) finally: _set_current_context(ctx) - def __enter__(self): - loop = self.loop + def __enter__(self) -> '_LoopCoroRunner': + loop: Optional[AbstractEventLoop] = self.loop if loop is None: owned = True loop = asyncio.new_event_loop() @@ -708,16 +941,37 @@ class _GenletCoroRunner(_CoroRunner): """ Run a coroutine assuming one of the parent coroutines was wrapped with :func:`allow_nested_run`. + + :param g: The enclosing _Genlet instance. + :type g: _Genlet """ - def __init__(self, g): + def __init__(self, g: _Genlet): self._g = g - def _run(self, coro): + def _run(self, coro: Coroutine) -> Any: + """ + Execute the coroutine by delegating to the enclosing _Genlet's consume_coro method. + + :param coro: The coroutine to run. + :type coro: Coroutine + :returns: The result of the coroutine. + :rtype: Any + """ return self._g.consume_coro(coro, None) -def _get_runner(): - executor = _CORO_THREAD_EXECUTOR +def _get_runner() -> Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]: + """ + Determine the appropriate coroutine runner based on the current context. + Returns a _GenletCoroRunner if an enclosing _Genlet is present, a _LoopCoroRunner + if an event loop exists (or can be created), or a _ThreadCoroRunner if an event loop is running. + + :returns: A coroutine runner appropriate for the current execution context. + :rtype: _GenletCoroRunner or _LoopCoroRunner or _ThreadCoroRunner + """ + executor: ThreadPoolExecutor = _CORO_THREAD_EXECUTOR g = _Genlet.get_enclosing() try: loop = asyncio.get_running_loop() @@ -748,7 +1002,7 @@ def _get_runner(): return _ThreadCoroRunner.from_executor(executor) -def run(coro): +def run(coro: Coroutine) -> Any: """ Similar to :func:`asyncio.run` but can be called while an event loop is running if a coroutine higher in the callstack has been wrapped using @@ -759,13 +1013,18 @@ def run(coro): be reflected in the context of the caller. This allows context variable updates to cross an arbitrary number of run layers, as if all those layers were just part of the same coroutine. + + :param coro: The coroutine to execute. + :type coro: Coroutine + :returns: The result of the coroutine. + :rtype: Any """ runner = _get_runner() with runner as runner: return runner.run(coro) -def asyncf(f): +def asyncf(f: Callable): """ Decorator used to turn a coroutine into a blocking function, with an optional asynchronous API. @@ -787,17 +1046,22 @@ async def foo(x): This allows the same implementation to be both used as blocking for ease of use and backward compatibility, or exposed as a corountine for callers that can deal with awaitables. + + :param f: The asynchronous function to decorate. + :type f: Callable + :returns: A callable that runs f synchronously, with an asynchronous version available as .asyn. + :rtype: Callable """ @functools.wraps(f) - def blocking(*args, **kwargs): + def blocking(*args, **kwargs) -> Any: # Since run() needs a corountine, make sure we provide one - async def wrapper(): + async def wrapper() -> Generator: x = f(*args, **kwargs) # Async generators have to be consumed and accumulated in a list # before crossing a blocking boundary. if inspect.isasyncgen(x): - def genf(): + def genf() -> Generator: asyncgen = x.__aiter__() while True: try: @@ -817,18 +1081,22 @@ def genf(): class _AsyncPolymorphicCMState: - def __init__(self): - self.nesting = 0 - self.runner = None + def __init__(self) -> None: + self.nesting: int = 0 + self.runner: Optional[Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]] = None - def _update_nesting(self, n): + def _update_nesting(self, n: int) -> bool: x = self.nesting assert x >= 0 x = x + n self.nesting = x return bool(x) - def _get_runner(self): + def _get_runner(self) -> Optional[Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]]: runner = self.runner if runner is None: assert not self.nesting @@ -837,8 +1105,8 @@ def _get_runner(self): self.runner = runner return runner - def _cleanup_runner(self, force=False): - def cleanup(): + def _cleanup_runner(self, force: bool = False) -> None: + def cleanup() -> None: self.runner = None if runner is not None: runner.__exit__(None, None, None) @@ -856,13 +1124,22 @@ class _AsyncPolymorphicCM: """ Wrap an async context manager such that it exposes a synchronous API as well for backward compatibility. + + :param async_cm: The asynchronous context manager to wrap. + :type async_cm: AsyncContextManager """ - def __init__(self, async_cm): + def __init__(self, async_cm: AsyncContextManager): self.cm = async_cm - self._state = threading.local() + self._state: local = threading.local() def _get_state(self): + """ + Retrieve or initialize the thread-local state for this context manager. + + :returns: The state object. + :rtype: _AsyncPolymorphicCMState + """ try: return self._state.x except AttributeError: @@ -870,7 +1147,10 @@ def _get_state(self): self._state.x = state return state - def _delete_state(self): + def _delete_state(self) -> None: + """ + Delete the thread-local state. + """ try: del self._state.x except AttributeError: @@ -883,33 +1163,39 @@ def __aexit__(self, *args, **kwargs): return self.cm.__aexit__(*args, **kwargs) @staticmethod - def _exit(state): + def _exit(state: _AsyncPolymorphicCMState) -> None: state._update_nesting(-1) state._cleanup_runner() - def __enter__(self, *args, **kwargs): - state = self._get_state() - runner = state._get_runner() + def __enter__(self, *args, **kwargs) -> Any: + state: _AsyncPolymorphicCMState = self._get_state() + runner: Optional[Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]] = state._get_runner() # Increase the nesting count _before_ we start running the # coroutine, in case it is a recursive context manager state._update_nesting(1) try: - coro = self.cm.__aenter__(*args, **kwargs) - return runner.run(coro) + coro: Coroutine = self.cm.__aenter__(*args, **kwargs) + if runner: + return runner.run(coro) except BaseException: self._exit(state) raise - def __exit__(self, *args, **kwargs): - coro = self.cm.__aexit__(*args, **kwargs) + def __exit__(self, *args, **kwargs) -> Any: + coro: Coroutine = self.cm.__aexit__(*args, **kwargs) - state = self._get_state() - runner = state._get_runner() + state: _AsyncPolymorphicCMState = self._get_state() + runner: Optional[Union[_GenletCoroRunner, + _LoopCoroRunner, + _ThreadCoroRunner]] = state._get_runner() try: - return runner.run(coro) + if runner: + return runner.run(coro) finally: self._exit(state) @@ -917,16 +1203,24 @@ def __del__(self): self._get_state()._cleanup_runner(force=True) -def asynccontextmanager(f): +T = TypeVar('T') + + +def asynccontextmanager(f: Callable[..., AsyncGenerator[T, None]]) -> Callable[..., _AsyncPolymorphicCM]: """ Same as :func:`contextlib.asynccontextmanager` except that it can also be used with a regular ``with`` statement for backward compatibility. + + :param f: A callable that returns an asynchronous generator. + :type f: Callable[..., AsyncGenerator[T, None]] + :returns: A context manager supporting both synchronous and asynchronous usage. + :rtype: Callable[..., _AsyncPolymorphicCM] """ - f = contextlib.asynccontextmanager(f) + f_int = contextlib.asynccontextmanager(f) - @functools.wraps(f) - def wrapper(*args, **kwargs): - cm = f(*args, **kwargs) + @functools.wraps(f_int) + def wrapper(*args: Any, **kwargs: Any) -> _AsyncPolymorphicCM: + cm = f_int(*args, **kwargs) return _AsyncPolymorphicCM(cm) return wrapper @@ -935,19 +1229,23 @@ def wrapper(*args, **kwargs): class ConcurrentAccessBase(abc.ABC): """ Abstract Base Class for resources tracked by :func:`concurrently`. + Subclasses must implement the method to determine if two resources overlap. """ @abc.abstractmethod - def overlap_with(self, other): + def overlap_with(self, other: 'ConcurrentAccessBase') -> bool: """ Return ``True`` if the resource overlaps with the given one. :param other: Resources that should not overlap with ``self``. :type other: devlib.utils.asym.ConcurrentAccessBase + :returns: True if the two resources overlap; False otherwise. + :rtype: bool .. note:: It is guaranteed that ``other`` will be a subclass of our class. """ + class PathAccess(ConcurrentAccessBase): """ Concurrent resource representing a file access. @@ -962,19 +1260,29 @@ class PathAccess(ConcurrentAccessBase): for writing. :type mode: str """ - def __init__(self, namespace, path, mode): + def __init__(self, namespace: str, path: str, mode: str): assert namespace in ('host', 'target') self.namespace = namespace assert mode in ('r', 'w') self.mode = mode self.path = os.path.abspath(path) if namespace == 'host' else os.path.normpath(path) - def overlap_with(self, other): + def overlap_with(self, other: ConcurrentAccessBase) -> bool: + """ + Check if this path access overlaps with another access, considering + namespace, mode, and filesystem hierarchy. + + :param other: Another resource access instance. + :type other: ConcurrentAccessBase + :returns: True if the two paths overlap (and one of the accesses is for writing), else False. + :rtype: bool + """ + other_internal = cast('PathAccess', other) path1 = pathlib.Path(self.path).resolve() - path2 = pathlib.Path(other.path).resolve() + path2 = pathlib.Path(other_internal.path).resolve() return ( - self.namespace == other.namespace and - 'w' in (self.mode, other.mode) and + self.namespace == other_internal.namespace and + 'w' in (self.mode, other_internal.mode) and ( path1 == path2 or path1 in path2.parents or @@ -983,6 +1291,12 @@ def overlap_with(self, other): ) def __str__(self): + """ + Return a string representation of the PathAccess, including the path and mode. + + :returns: A string describing the path access. + :rtype: str + """ mode = { 'r': 'read', 'w': 'write', diff --git a/devlib/utils/misc.py b/devlib/utils/misc.py index 1c49d0d0b..9d0e5303c 100644 --- a/devlib/utils/misc.py +++ b/devlib/utils/misc.py @@ -1,4 +1,4 @@ -# Copyright 2013-2024 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ from operator import itemgetter from weakref import WeakSet from ruamel.yaml import YAML +from ruamel.yaml.error import YAMLError, MarkedYAMLError +from devlib.utils.annotation_helpers import SubprocessCommand import ctypes import logging @@ -39,28 +41,34 @@ import warnings import wrapt - try: from contextlib import ExitStack except AttributeError: - from contextlib2 import ExitStack + from contextlib2 import ExitStack # type: ignore from shlex import quote -from past.builtins import basestring # pylint: disable=redefined-builtin from devlib.exception import HostError, TimeoutError +from typing import (Union, List, Optional, Tuple, Set, + Any, Callable, Dict, Generator, TYPE_CHECKING, + Type, cast, Pattern) +from typing_extensions import Literal +if TYPE_CHECKING: + from logging import Logger + from tarfile import TarFile, TarInfo + from devlib import Target # ABI --> architectures list -ABI_MAP = { +ABI_MAP: Dict[str, List[str]] = { 'armeabi': ['armeabi', 'armv7', 'armv7l', 'armv7el', 'armv7lh', 'armeabi-v7a'], 'arm64': ['arm64', 'armv8', 'arm64-v8a', 'aarch64'], } # Vendor ID --> CPU part ID --> CPU variant ID --> Core Name # None means variant is not used. -CPU_PART_MAP = { +CPU_PART_MAP: Dict[int, Dict[int, Dict[Optional[int], str]]] = { 0x41: { # ARM 0x926: {None: 'ARM926'}, 0x946: {None: 'ARM946'}, @@ -127,16 +135,34 @@ } -def get_cpu_name(implementer, part, variant): +def get_cpu_name(implementer: int, part: int, variant: int) -> Optional[str]: + """ + Retrieve the CPU name based on implementer, part, and variant IDs using the CPU_PART_MAP. + + :param implementer: The vendor identifier. + :type implementer: int + :param part: The CPU part identifier. + :type part: int + :param variant: The CPU variant identifier. + :type variant: int + :returns: The CPU name if found; otherwise, None. + :rtype: str or None + """ part_data = CPU_PART_MAP.get(implementer, {}).get(part, {}) if None in part_data: # variant does not determine core Name for this vendor - name = part_data[None] + name: Optional[str] = part_data[None] else: name = part_data.get(variant) return name -def preexec_function(): +def preexec_function() -> None: + """ + Set the process group ID for the current process so that a subprocess and all its children + can later be killed together. This function is Unix-specific. + + :raises OSError: If setting the process group fails. + """ # Change process group in case we have to kill the subprocess and all of # its children later. # TODO: this is Unix-specific; would be good to find an OS-agnostic way @@ -144,22 +170,57 @@ def preexec_function(): os.setpgrp() -check_output_logger = logging.getLogger('check_output') +check_output_logger: 'Logger' = logging.getLogger('check_output') + -def get_subprocess(command, **kwargs): +def get_subprocess(command: SubprocessCommand, **kwargs) -> subprocess.Popen: + """ + Launch a subprocess to run the specified command, overriding stdout to PIPE. + The process is set to a new process group via a preexec function. + + :param command: The command to execute. + :type command: SubprocessCommand + :param kwargs: Additional keyword arguments to pass to subprocess.Popen. + :raises ValueError: If 'stdout' is provided in kwargs. + :returns: A subprocess.Popen object running the command. + :rtype: subprocess.Popen + """ if 'stdout' in kwargs: raise ValueError('stdout argument not allowed, it will be overridden.') return subprocess.Popen(command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - stdin=subprocess.PIPE, - preexec_fn=preexec_function, - **kwargs) - + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + preexec_fn=preexec_function, + **kwargs) + + +def check_subprocess_output( + process: subprocess.Popen, + timeout: Optional[float] = None, + ignore: Optional[Union[int, List[int], Literal['all']]] = None, + inputtext: Union[str, bytes, None] = None) -> Tuple[str, str]: + """ + Communicate with the given subprocess and return its decoded output and error streams. + This function handles timeouts and can ignore specified return codes. + + :param process: The subprocess.Popen instance to interact with. + :type process: subprocess.Popen + :param timeout: The maximum time in seconds to wait for the process to complete. + :type timeout: int or None + :param ignore: A return code (or list of codes) to ignore; use "all" to ignore all nonzero codes. + :type ignore: int, list of int, "all", or None + :param inputtext: Optional text or bytes to send to the process's stdin. + :type inputtext: str, bytes, or None + :returns: A tuple (output, error) with decoded strings. + :rtype: (str, str) + :raises ValueError: If the ignore parameter is improperly formatted. + :raises TimeoutError: If the process does not complete before the timeout expires. + :raises subprocess.CalledProcessError: If the process exits with a nonzero code not in ignore. + """ + output: Union[str, bytes] = '' + error: Union[str, bytes] = '' -def check_subprocess_output(process, timeout=None, ignore=None, inputtext=None): - output = None - error = None # pylint: disable=too-many-branches if ignore is None: ignore = [] @@ -170,39 +231,81 @@ def check_subprocess_output(process, timeout=None, ignore=None, inputtext=None): raise ValueError(message.format(ignore)) with process: + timeout_expired: Optional[subprocess.TimeoutExpired] = None try: output, error = process.communicate(inputtext, timeout=timeout) except subprocess.TimeoutExpired as e: timeout_expired = e - else: - timeout_expired = None # Currently errors=replace is needed as 0x8c throws an error - output = output.decode(sys.stdout.encoding or 'utf-8', "replace") if output else '' - error = error.decode(sys.stderr.encoding or 'utf-8', "replace") if error else '' + output = cast(str, output.decode(sys.stdout.encoding or 'utf-8', "replace") if isinstance(output, bytes) else output) + error = cast(str, error.decode(sys.stderr.encoding or 'utf-8', "replace") if isinstance(error, bytes) else error) if timeout_expired: raise TimeoutError(process.args, output='\n'.join([output, error])) - retcode = process.returncode + retcode: int = process.returncode if retcode and ignore != 'all' and retcode not in ignore: raise subprocess.CalledProcessError(retcode, process.args, output, error) return output, error -def check_output(command, timeout=None, ignore=None, inputtext=None, **kwargs): - """This is a version of subprocess.check_output that adds a timeout parameter to kill - the subprocess if it does not return within the specified time.""" +def check_output(command: SubprocessCommand, timeout: Optional[int] = None, + ignore: Optional[Union[int, List[int], Literal['all']]] = None, + inputtext: Union[str, bytes, None] = None, **kwargs) -> Tuple[str, str]: + """ + This is a version of subprocess.check_output that adds a timeout parameter to kill + the subprocess if it does not return within the specified time. + + :param command: The command to execute. + :type command: SubprocessCommand + :param timeout: Time in seconds to wait for the command to complete. + :type timeout: int or None + :param ignore: A return code or list of return codes to ignore, or "all" to ignore all. + :type ignore: int, list of int, "all", or None + :param inputtext: Optional text or bytes to send to the command's stdin. + :type inputtext: str, bytes, or None + :param kwargs: Additional keyword arguments for subprocess.Popen. + :returns: A tuple (stdout, stderr) of the command's decoded output. + :rtype: (str, str) + :raises TimeoutError: If the command does not complete in time. + :raises subprocess.CalledProcessError: If the command fails and its return code is not ignored. + """ process = get_subprocess(command, **kwargs) return check_subprocess_output(process, timeout=timeout, ignore=ignore, inputtext=inputtext) -def walk_modules(path): +class ExtendedHostError(HostError): + """ + Exception class that extends HostError with additional attributes. + + :param message: The error message. + :type message: str + :param module: The name of the module where the error originated. + :type module: str or None + :param exc_info: Exception information from sys.exc_info(). + :type exc_info: Any + :param orig_exc: The original exception that was caught. + :type orig_exc: Exception or None + """ + def __init__(self, message: str, module: Optional[str] = None, + exc_info: Any = None, orig_exc: Optional[Exception] = None): + super().__init__(message) + self.module = module + self.exc_info = exc_info + self.orig_exc = orig_exc + + +def walk_modules(path: str) -> List[types.ModuleType]: """ Given package name, return a list of all modules (including submodules, etc) in that package. + :param path: The package name to walk (e.g., 'mypackage'). + :type path: str + :returns: A list of module objects. + :rtype: list of ModuleType :raises HostError: if an exception is raised while trying to import one of the modules under ``path``. The exception will have addtional attributes set: ``module`` will be set to the qualified name @@ -211,39 +314,54 @@ def walk_modules(path): """ - def __try_import(path): + def __try_import(path: str) -> types.ModuleType: try: return __import__(path, {}, {}, ['']) except Exception as e: he = HostError('Could not load {}: {}'.format(path, str(e))) - he.module = path - he.exc_info = sys.exc_info() - he.orig_exc = e + cast(ExtendedHostError, he).module = path + cast(ExtendedHostError, he).exc_info = sys.exc_info() + cast(ExtendedHostError, he).orig_exc = e raise he - root_mod = __try_import(path) - mods = [root_mod] + root_mod: types.ModuleType = __try_import(path) + mods: List[types.ModuleType] = [root_mod] if not hasattr(root_mod, '__path__'): # root is a module not a package -- nothing to walk return mods for _, name, ispkg in pkgutil.iter_modules(root_mod.__path__): - submod_path = '.'.join([path, name]) + submod_path: str = '.'.join([path, name]) if ispkg: mods.extend(walk_modules(submod_path)) else: - submod = __try_import(submod_path) + submod: types.ModuleType = __try_import(submod_path) mods.append(submod) return mods -def redirect_streams(stdout, stderr, command): + +def redirect_streams(stdout: int, stderr: int, + command: SubprocessCommand) -> Tuple[int, int, SubprocessCommand]: """ - Update a command to redirect a given stream to /dev/null if it's - ``subprocess.DEVNULL``. + Adjust a command string to redirect output streams to specific targets. + If a stream is set to subprocess.DEVNULL, it replaces it with a redirect + to /dev/null; for subprocess.STDOUT, it merges stderr into stdout. + + :param stdout: The desired stdout value. + :type stdout: int + :param stderr: The desired stderr value. + :type stderr: int + :param command: The original command to run. + :type command: SubprocessCommand :return: A tuple (stdout, stderr, command) with stream set to ``subprocess.PIPE`` if the `stream` parameter was set to ``subprocess.DEVNULL``. + :rtype: (int, int, SubprocessCommand) """ - def redirect(stream, redirection): + + def redirect(stream: int, redirection: str) -> Tuple[int, str]: + """ + redirect output and error streams + """ if stream == subprocess.DEVNULL: suffix = '{}/dev/null'.format(redirection) elif stream == subprocess.STDOUT: @@ -259,47 +377,89 @@ def redirect(stream, redirection): stdout, suffix1 = redirect(stdout, '>') stderr, suffix2 = redirect(stderr, '2>') - command = 'sh -c {} {} {}'.format(quote(command), suffix1, suffix2) + command = 'sh -c {} {} {}'.format(quote(cast(str, command)), suffix1, suffix2) return (stdout, stderr, command) -def ensure_directory_exists(dirpath): + +def ensure_directory_exists(dirpath: str) -> str: """A filter for directory paths to ensure they exist.""" if not os.path.isdir(dirpath): os.makedirs(dirpath) return dirpath -def ensure_file_directory_exists(filepath): +def ensure_file_directory_exists(filepath: str) -> str: """ A filter for file paths to ensure the directory of the file exists and the file can be created there. The file itself is *not* going to be created if it doesn't already exist. + :param dirpath: The directory path to check. + :type dirpath: str + :returns: The directory path. + :rtype: str + :raises OSError: If the directory cannot be created """ ensure_directory_exists(os.path.dirname(filepath)) return filepath -def merge_dicts(*args, **kwargs): +def merge_dicts(*args, **kwargs) -> Dict: + """ + Merge multiple dictionaries together. + + :param args: Two or more dictionaries to merge. + :type args: dict + :param kwargs: Additional keyword arguments to pass to the merging function. + :type kwargs: dict + :returns: A new dictionary containing the merged keys and values. + :rtype: dict + :raises ValueError: If fewer than two dictionaries are provided. + """ if not len(args) >= 2: raise ValueError('Must specify at least two dicts to merge.') - func = partial(_merge_two_dicts, **kwargs) + func: partial[Dict] = partial(_merge_two_dicts, **kwargs) return reduce(func, args) -def _merge_two_dicts(base, other, list_duplicates='all', match_types=False, # pylint: disable=R0912,R0914 - dict_type=dict, should_normalize=True, should_merge_lists=True): - """Merge dicts normalizing their keys.""" +def _merge_two_dicts(base: Dict, other: Dict, list_duplicates: str = 'all', + match_types: bool = False, # pylint: disable=R0912,R0914 + dict_type: Type[Dict] = dict, should_normalize: bool = True, + should_merge_lists: bool = True) -> Dict: + """ + Merge two dictionaries recursively, normalizing their keys. The merging behavior + for lists and duplicate keys can be controlled via parameters. + + :param base: The first dictionary. + :type base: dict + :param other: The second dictionary to merge into the first. + :type other: dict + :param list_duplicates: Strategy for handling duplicate list entries ("all", "first", or "last"). + :type list_duplicates: str + :param match_types: If True, enforce that overlapping keys have the same type. + :type match_types: bool + :param dict_type: The dictionary type to use for constructing merged dictionaries. + :type dict_type: Type[dict] + :param should_normalize: If True, normalize keys/values during merge. + :type should_normalize: bool + :param should_merge_lists: If True, merge lists; otherwise, override base list. + :type should_merge_lists: bool + :returns: A merged dictionary. + :rtype: dict + :raises ValueError: If there is a type mismatch for a key when match_types is True. + :raises AssertionError: If an unexpected merge key is encountered. + """ merged = dict_type() base_keys = list(base.keys()) other_keys = list(other.keys()) - norm = normalize if should_normalize else lambda x, y: x + # FIXME - annotate the lambda. type checker is not able to deduce its type + norm: Callable = normalize if should_normalize else lambda x, y: x # type:ignore - base_only = [] - other_only = [] - both = [] - union = [] + base_only: List = [] + other_only: List = [] + both: List = [] + union: List = [] for k in base_keys: if k in other_keys: both.append(k) @@ -345,50 +505,78 @@ def _merge_two_dicts(base, other, list_duplicates='all', match_types=False, # p return merged -def merge_lists(*args, **kwargs): +def merge_lists(*args, **kwargs) -> List: + """ + Merge multiple lists together. + + :param args: Two or more lists to merge. + :type args: list + :param kwargs: Additional keyword arguments to pass to the merging function. + :type kwargs: dict + :returns: A merged list containing the combined items. + :rtype: list + :raises ValueError: If fewer than two lists are provided. + """ if not len(args) >= 2: raise ValueError('Must specify at least two lists to merge.') func = partial(_merge_two_lists, **kwargs) return reduce(func, args) -def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: disable=R0912 +def _merge_two_lists(base: List, other: List, duplicates: str = 'all', + dict_type: Type[Dict] = dict) -> List: # pylint: disable=R0912 """ Merge lists, normalizing their entries. - parameters: - - :base, other: the two lists to be merged. ``other`` will be merged on - top of base. - :duplicates: Indicates the strategy of handling entries that appear - in both lists. ``all`` will keep occurrences from both - lists; ``first`` will only keep occurrences from - ``base``; ``last`` will only keep occurrences from - ``other``; - - .. note:: duplicate entries that appear in the *same* list + :param base: The base list. + :type base: list + :param other: The list to merge into base. + :type other: list + :param duplicates: Indicates the strategy of handling entries that appear + in both lists. ``all`` will keep occurrences from both + lists; ``first`` will only keep occurrences from + ``base``; ``last`` will only keep occurrences from + ``other``; + :type duplicates: str + + .. note:: duplicate entries that appear in the *same* list will never be removed. - + :param dict_type: The dictionary type to use for normalization if needed. + :type dict_type: Type[dict] + :returns: A merged list with duplicate handling applied. + :rtype: list + :raises ValueError: If an unexpected value is provided for duplicates. """ if not isiterable(base): base = [base] if not isiterable(other): other = [other] if duplicates == 'all': - merged_list = [] - for v in normalize(base, dict_type) + normalize(other, dict_type): + merged_list: List = [] + combined: List = [] + normalized_base = normalize(base, dict_type) + normalized_other = normalize(other, dict_type) + if isinstance(normalized_base, (list, tuple)) and isinstance(normalized_other, (list, tuple)): + combined = list(normalized_base) + list(normalized_other) + elif isinstance(normalized_base, dict) and isinstance(normalized_other, dict): + combined = [normalized_base, normalized_other] + elif isinstance(normalized_base, set) and isinstance(normalized_other, set): + combined = list(normalized_base.union(normalized_other)) + else: + combined = list(normalized_base) + list(normalized_other) + for v in combined: if not _check_remove_item(merged_list, v): merged_list.append(v) return merged_list elif duplicates == 'first': base_norm = normalize(base, dict_type) - merged_list = normalize(base, dict_type) + merged_list = cast(List, normalize(base, dict_type)) for v in base_norm: _check_remove_item(merged_list, v) for v in normalize(other, dict_type): if not _check_remove_item(merged_list, v): if v not in base_norm: - merged_list.append(v) # pylint: disable=no-member + cast(List, merged_list).append(v) # pylint: disable=no-member return merged_list elif duplicates == 'last': other_norm = normalize(other, dict_type) @@ -406,11 +594,19 @@ def _merge_two_lists(base, other, duplicates='all', dict_type=dict): # pylint: 'Must be in {"all", "first", "last"}.') -def _check_remove_item(the_list, item): - """Helper function for merge_lists that implements checking wether an items - should be removed from the list and doing so if needed. Returns ``True`` if - the item has been removed and ``False`` otherwise.""" - if not isinstance(item, basestring): +def _check_remove_item(the_list: List, item: Any) -> bool: + """ + Check whether an item should be removed from a list based on certain criteria. + If the item is a string starting with '~', its unprefixed version is removed from the list. + + :param the_list: The list in which to check for the item. + :type the_list: list + :param item: The item to check. + :type item: Any + :returns: True if the item was removed; False otherwise. + :rtype: bool + """ + if not isinstance(item, str): return False if not item.startswith('~'): return False @@ -420,9 +616,19 @@ def _check_remove_item(the_list, item): return True -def normalize(value, dict_type=dict): - """Normalize values. Recursively normalizes dict keys to be lower case, - no surrounding whitespace, underscore-delimited strings.""" +def normalize(value: Union[Dict, List, Tuple, Set], + dict_type: Type[Dict] = dict) -> Union[Dict, List, Tuple, Set]: + """ + Recursively normalize values by converting dictionary keys to lower-case, + stripping whitespace, and replacing spaces with underscores. + + :param value: A dict, list, tuple, or set to normalize. + :type value: dict, list, tuple, or set + :param dict_type: The dictionary type to use for normalized dictionaries. + :type dict_type: Type[dict] + :returns: A normalized version of the input value. + :rtype: dict, list, tuple, or set + """ if isinstance(value, dict): normalized = dict_type() for k, v in value.items(): @@ -437,12 +643,29 @@ def normalize(value, dict_type=dict): return value -def convert_new_lines(text): - """ Convert new lines to a common format. """ +def convert_new_lines(text: str) -> str: + """ + Convert different newline conventions to a single '\n' format. + + :param text: The input text. + :type text: str + :returns: The text with unified newline characters. + :rtype: str + """ return text.replace('\r\n', '\n').replace('\r', '\n') -def sanitize_cmd_template(cmd): - msg = ( + +def sanitize_cmd_template(cmd: str) -> str: + """ + Replace quoted placeholders with unquoted ones in a command template, + warning the user if quoted placeholders are detected. + + :param cmd: The command template string. + :type cmd: str + :returns: The sanitized command template. + :rtype: str + """ + msg: str = ( '''Quoted placeholder should not be used, as it will result in quoting the text twice. {} should be used instead of '{}' or "{}" in the template: ''' ) for unwanted in ('"{}"', "'{}'"): @@ -452,51 +675,79 @@ def sanitize_cmd_template(cmd): return cmd -def escape_quotes(text): + +def escape_quotes(text: str) -> str: """ - Escape quotes, and escaped quotes, in the specified text. + Escape quotes and escaped quotes in the given text. + + .. note:: It is recommended to use shlex.quote when possible. - .. note:: :func:`shlex.quote` should be favored where possible. + :param text: The text to escape. + :type text: str + :returns: The text with quotes escaped. + :rtype: str """ return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\\\'').replace('\"', '\\\"') -def escape_single_quotes(text): +def escape_single_quotes(text: str) -> str: """ - Escape single quotes, and escaped single quotes, in the specified text. + Escape single quotes in the provided text. - .. note:: :func:`shlex.quote` should be favored where possible. + .. note:: Prefer using shlex.quote when possible. + + :param text: The text to process. + :type text: str + :returns: The text with single quotes escaped. + :rtype: str """ return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\'', '\'\\\'\'') -def escape_double_quotes(text): +def escape_double_quotes(text: str) -> str: """ - Escape double quotes, and escaped double quotes, in the specified text. + Escape double quotes in the given text. + + .. note:: Prefer using shlex.quote when possible. - .. note:: :func:`shlex.quote` should be favored where possible. + :param text: The input text. + :type text: str + :returns: The text with double quotes escaped. + :rtype: str """ return re.sub(r'\\("|\')', r'\\\\\1', text).replace('\"', '\\\"') -def escape_spaces(text): +def escape_spaces(text: str) -> str: """ - Escape spaces in the specified text + Escape spaces in the provided text. + + .. note:: Prefer using shlex.quote when possible. - .. note:: :func:`shlex.quote` should be favored where possible. + :param text: The text to process. + :type text: str + :returns: The text with spaces escaped. + :rtype: str """ return text.replace(' ', '\\ ') -def getch(count=1): - """Read ``count`` characters from standard input.""" +def getch(count: int = 1) -> str: + """ + Read a specified number of characters from standard input. + + :param count: The number of characters to read. + :type count: int + :returns: A string of characters read from stdin. + :rtype: str + """ if os.name == 'nt': import msvcrt # pylint: disable=F0401 - return ''.join([msvcrt.getch() for _ in range(count)]) + return ''.join([msvcrt.getch() for _ in range(count)]) # type:ignore else: # assume Unix import tty # NOQA import termios # NOQA - fd = sys.stdin.fileno() + fd: int = sys.stdin.fileno() old_settings = termios.tcgetattr(fd) try: tty.setraw(sys.stdin.fileno()) @@ -506,45 +757,81 @@ def getch(count=1): return ch -def isiterable(obj): - """Returns ``True`` if the specified object is iterable and - *is not a string type*, ``False`` otherwise.""" - return hasattr(obj, '__iter__') and not isinstance(obj, basestring) +def isiterable(obj: Any) -> bool: + """ + Determine if the provided object is iterable, excluding strings. + + :param obj: The object to test. + :type obj: Any + :returns: True if the object is iterable and is not a string; otherwise, False. + :rtype: bool + """ + return hasattr(obj, '__iter__') and not isinstance(obj, str) -def as_relative(path): - """Convert path to relative by stripping away the leading '/' on UNIX or - the equivant on other platforms.""" +def as_relative(path: str) -> str: + """ + Convert an absolute path to a relative path by removing leading separators. + + :param path: The absolute path. + :type path: str + :returns: A relative path. + :rtype: str + """ path = os.path.splitdrive(path)[1] return path.lstrip(os.sep) -def commonprefix(file_list, sep=os.sep): +def commonprefix(file_list: List[str], sep: str = os.sep) -> str: """ - Find the lowest common base folder of a passed list of files. + Determine the lowest common base folder among a list of file paths. + + :param file_list: A list of file paths. + :type file_list: list of str + :param sep: The path separator to use. + :type sep: str + :returns: The common prefix path. + :rtype: str """ - common_path = os.path.commonprefix(file_list) - cp_split = common_path.split(sep) - other_split = file_list[0].split(sep) - last = len(cp_split) - 1 + common_path: str = os.path.commonprefix(file_list) + cp_split: List[str] = common_path.split(sep) + other_split: List[str] = file_list[0].split(sep) + last: int = len(cp_split) - 1 if cp_split[last] != other_split[last]: cp_split = cp_split[:-1] return sep.join(cp_split) -def get_cpu_mask(cores): - """Return a string with the hex for the cpu mask for the specified core numbers.""" +def get_cpu_mask(cores: List[int]) -> str: + """ + Compute a hexadecimal CPU mask for the specified core indices. + + :param cores: A list of core numbers. + :type cores: list of int + :returns: A hexadecimal string representing the CPU mask. + :rtype: str + """ mask = 0 for i in cores: mask |= 1 << i return '0x{0:x}'.format(mask) -def which(name): - """Platform-independent version of UNIX which utility.""" +def which(name: str) -> Optional[str]: + """ + Find the full path to an executable by searching the system PATH. + Provides a platform-independent implementation of the UNIX 'which' utility. + + :param name: The name of the executable to find. + :type name: str + :returns: The full path to the executable if found, otherwise None. + :rtype: str or None + """ if os.name == 'nt': - paths = os.getenv('PATH').split(os.pathsep) - exts = os.getenv('PATHEXT').split(os.pathsep) + path_env = os.getenv('PATH') + pathext_env = os.getenv('PATHEXT') + paths: List[str] = path_env.split(os.pathsep) if path_env else [] + exts: List[str] = pathext_env.split(os.pathsep) if pathext_env else [] for path in paths: testpath = os.path.join(path, name) if os.path.isfile(testpath): @@ -562,13 +849,22 @@ def which(name): # This matches most ANSI escape sequences, not just colors -_bash_color_regex = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]') +_bash_color_regex: Pattern[str] = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]') + + +def strip_bash_colors(text: str) -> str: + """ + Remove ANSI escape sequences (commonly used for terminal colors) from the given text. -def strip_bash_colors(text): + :param text: The input string potentially containing ANSI escape sequences. + :type text: str + :returns: The input text with all ANSI escape sequences removed. + :rtype: str + """ return _bash_color_regex.sub('', text) -def get_random_string(length): +def get_random_string(length: int) -> str: """Returns a random ASCII string of the specified length).""" return ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(length)) @@ -581,7 +877,7 @@ def message(self): return self.args[0] return str(self) - def __init__(self, message, filepath, lineno): + def __init__(self, message: str, filepath: str, lineno: Optional[int]): super(LoadSyntaxError, self).__init__(message) self.filepath = filepath self.lineno = lineno @@ -591,7 +887,7 @@ def __str__(self): return message.format(self.filepath, self.lineno, self.message) -def load_struct_from_yaml(filepath): +def load_struct_from_yaml(filepath: str) -> Dict: """ Parses a config structure from a YAML file. The structure should be composed of basic Python types. @@ -609,18 +905,18 @@ def load_struct_from_yaml(filepath): yaml = YAML(typ='safe', pure=True) with open(filepath, 'r', encoding='utf-8') as file_handler: return yaml.load(file_handler) - except yaml.YAMLError as ex: - message = ex.message if hasattr(ex, 'message') else '' - lineno = ex.problem_mark.line if hasattr(ex, 'problem_mark') else None + except YAMLError as ex: + message = str(ex) + lineno = cast(MarkedYAMLError, ex).problem_mark.line if hasattr(ex, 'problem_mark') else None raise LoadSyntaxError(message, filepath=filepath, lineno=lineno) from ex -RAND_MOD_NAME_LEN = 30 -BAD_CHARS = string.punctuation + string.whitespace -TRANS_TABLE = str.maketrans(BAD_CHARS, '_' * len(BAD_CHARS)) +RAND_MOD_NAME_LEN: int = 30 +BAD_CHARS: str = string.punctuation + string.whitespace +TRANS_TABLE: Dict[int, int] = str.maketrans(BAD_CHARS, '_' * len(BAD_CHARS)) -def to_identifier(text): +def to_identifier(text: str) -> str: """Converts text to a valid Python identifier by replacing all whitespace and punctuation and adding a prefix if starting with a digit""" if text[:1].isdigit(): @@ -628,7 +924,7 @@ def to_identifier(text): return re.sub('_+', '_', str(text).translate(TRANS_TABLE)) -def unique(alist): +def unique(alist: List) -> List: """ Returns a list containing only unique elements from the input list (but preserves order, unlike sets). @@ -641,9 +937,9 @@ def unique(alist): return result -def ranges_to_list(ranges_string): +def ranges_to_list(ranges_string: str) -> List[int]: """Converts a sysfs-style ranges string, e.g. ``"0,2-4"``, into a list ,e.g ``[0,2,3,4]``""" - values = [] + values: List[int] = [] for rg in ranges_string.split(','): if '-' in rg: first, last = list(map(int, rg.split('-'))) @@ -653,13 +949,13 @@ def ranges_to_list(ranges_string): return values -def list_to_ranges(values): +def list_to_ranges(values: List) -> str: """Converts a list, e.g ``[0,2,3,4]``, into a sysfs-style ranges string, e.g. ``"0,2-4"``""" values = sorted(values) range_groups = [] for _, g in groupby(enumerate(values), lambda i_x: i_x[0] - i_x[1]): range_groups.append(list(map(itemgetter(1), g))) - range_strings = [] + range_strings: List[str] = [] for group in range_groups: if len(group) == 1: range_strings.append(str(group[0])) @@ -668,7 +964,7 @@ def list_to_ranges(values): return ','.join(range_strings) -def list_to_mask(values, base=0x0): +def list_to_mask(values: List[int], base: int = 0x0) -> int: """Converts the specified list of integer values into a bit mask for those values. Optinally, the list can be applied to an existing mask.""" @@ -677,7 +973,7 @@ def list_to_mask(values, base=0x0): return base -def mask_to_list(mask): +def mask_to_list(mask: int) -> List[int]: """Converts the specfied integer bitmask into a list of indexes of bits that are set in the mask.""" size = len(bin(mask)) - 2 # because of "0b" @@ -685,27 +981,33 @@ def mask_to_list(mask): if mask & (1 << size - i - 1)] -__memo_cache = {} +__memo_cache: Dict[str, Any] = {} -def reset_memo_cache(): +def reset_memo_cache() -> None: + """ + Clear the global memoization cache used for caching function results. + + :returns: None + :rtype: None + """ __memo_cache.clear() -def __get_memo_id(obj): +def __get_memo_id(obj: object) -> str: """ An object's id() may be re-used after an object is freed, so it's not sufficiently unique to identify params for the memo cache (two different params may end up with the same id). this attempts to generate a more unique ID string. """ - obj_id = id(obj) + obj_id: int = id(obj) try: return '{}/{}'.format(obj_id, hash(obj)) except TypeError: # obj is not hashable obj_pyobj = ctypes.cast(obj_id, ctypes.py_object) # TODO: Note: there is still a possibility of a clash here. If Two - # different objects get assigned the same ID, an are large and are + # different objects get assigned the same ID, and are large and are # identical in the first thirty two bytes. This shouldn't be much of an # issue in the current application of memoizing Target calls, as it's very # unlikely that a target will get passed large params; but may cause @@ -715,24 +1017,38 @@ def __get_memo_id(obj): # undesirable impact on performance. num_bytes = min(ctypes.sizeof(obj_pyobj), 32) obj_bytes = ctypes.string_at(ctypes.addressof(obj_pyobj), num_bytes) - return '{}/{}'.format(obj_id, obj_bytes) + return '{}/{}'.format(obj_id, cast(str, obj_bytes)) -@wrapt.decorator -def memoized(wrapped, instance, args, kwargs): # pylint: disable=unused-argument +def memoized_decor(wrapped: Callable[..., Any], instance: Optional[Any], + args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: # pylint: disable=unused-argument """ - A decorator for memoizing functions and methods. + Decorator helper function for memoizing the results of a function call. + The result is cached based on a key derived from the function's arguments. + Note that this method does not account for changes to mutable arguments. .. warning:: this may not detect changes to mutable types. As long as the memoized function was used with an object as an argument before, the cached result will be returned, even if the structure of the object (e.g. a list) has changed in the mean time. + :param wrapped: The function to be memoized. + :type wrapped: Callable[..., Any] + :param instance: The instance on which the function is called (if it is a method), or None. + :type instance: Any or None + :param args: Tuple of positional arguments passed to the function. + :type args: Tuple[Any, ...] + :param kwargs: Dictionary of keyword arguments passed to the function. + :type kwargs: Dict[str, Any] + :returns: The cached result if available; otherwise, the result from calling the function. + :rtype: Any + :raises Exception: Any exception raised during the execution of the wrapped function is propagated. + """ - func_id = repr(wrapped) + func_id: str = repr(wrapped) - def memoize_wrapper(*args, **kwargs): - id_string = func_id + ','.join([__get_memo_id(a) for a in args]) + def memoize_wrapper(*args, **kwargs) -> Dict[str, Any]: + id_string: str = func_id + ','.join([__get_memo_id(a) for a in args]) id_string += ','.join('{}={}'.format(k, __get_memo_id(v)) for k, v in kwargs.items()) if id_string not in __memo_cache: @@ -741,13 +1057,19 @@ def memoize_wrapper(*args, **kwargs): return memoize_wrapper(*args, **kwargs) + +# create memoized decorator from memoized_decor function +memoized = wrapt.decorator(memoized_decor) + + @contextmanager -def batch_contextmanager(f, kwargs_list): +def batch_contextmanager(f: Callable, kwargs_list: List[Dict[str, Any]]) -> Generator: """ Return a context manager that will call the ``f`` callable with the keyword arguments dict in the given list, in one go. :param f: Callable expected to return a context manager. + :type f: callable :param kwargs_list: list of kwargs dictionaries to be used to call ``f``. :type kwargs_list: list(dict) @@ -770,7 +1092,8 @@ class nullcontext: statement, or `None` if nothing is specified. :type enter_result: object """ - def __init__(self, enter_result=None): + + def __init__(self, enter_result: Any = None): self.enter_result = enter_result def __enter__(self): @@ -797,21 +1120,51 @@ class tls_property: to that object, like :meth:`_BoundTLSProperty.get_all_values`. Values can be set and deleted as well, which will be a thread-local set. + + :param factory: A callable used to generate the property value. + :type factory: Callable """ @property - def name(self): + def name(self) -> str: + """ + Retrieve the name of the factory function used for this property. + + :returns: The name of the factory function. + :rtype: str + """ return self.factory.__name__ - def __init__(self, factory): + def __init__(self, factory: Callable): self.factory = factory # Lock accesses to shared WeakKeyDictionary and WeakSet self.lock = threading.RLock() - def __get__(self, instance, owner=None): + def __get__(self, instance: 'Target', owner: Optional[Type['Target']] = None) -> '_BoundTLSProperty': + """ + Retrieve the thread-local property proxy for the given instance. + + :param instance: The target instance. + :type instance: Target + :param owner: The class owning the property (optional). + :type owner: Type[Target] or None + :returns: A bound TLS property proxy. + :rtype: _BoundTLSProperty + """ return _BoundTLSProperty(self, instance, owner) - def _get_value(self, instance, owner): + def _get_value(self, instance: 'Target', owner: Optional[Type['Target']]) -> Any: + """ + Retrieve or compute the thread-local value for the given instance. If the value + does not exist, it is created using the factory callable. + + :param instance: The target instance. + :type instance: Target + :param owner: The class owning the property (optional). + :type owner: Type[Target] or None + :returns: The thread-local value. + :rtype: Any + """ tls, values = self._get_tls(instance) try: return tls.value @@ -826,20 +1179,44 @@ def _get_value(self, instance, owner): values.add(obj) return obj - def _get_all_values(self, instance, owner): + def _get_all_values(self, instance: 'Target', owner: Optional[Type['Target']]) -> Set: + """ + Retrieve all thread-local values currently cached for this property in the given instance. + + :param instance: The target instance. + :type instance: Target + :param owner: The class owning the property (optional). + :type owner: Type[Target] or None + :returns: A set containing all cached values. + :rtype: set + """ with self.lock: # Grab a reference to all the objects at the time of the call by # using a regular set tls, values = self._get_tls(instance=instance) return set(values) - def __set__(self, instance, value): + def __set__(self, instance: 'Target', value): + """ + Set the thread-local value for this property on the given instance. + + :param instance: The target instance. + :type instance: Target + :param value: The value to set. + :type value: Any + """ tls, values = self._get_tls(instance) tls.value = value with self.lock: values.add(value) - def __delete__(self, instance): + def __delete__(self, instance: 'Target'): + """ + Delete the thread-local value for this property from the given instance. + + :param instance: The target instance. + :type instance: Target + """ tls, values = self._get_tls(instance) with self.lock: try: @@ -850,7 +1227,16 @@ def __delete__(self, instance): values.discard(value) del tls.value - def _get_tls(self, instance): + def _get_tls(self, instance: 'Target') -> Any: + """ + Retrieve the thread-local storage tuple for this property from the instance. + If not present, a new tuple is created and stored. + + :param instance: The target instance. + :type instance: Target + :returns: A tuple (tls, values) where tls is a thread-local object and values is a WeakSet. + :rtype: tuple + """ dct = instance.__dict__ name = self.name try: @@ -868,40 +1254,62 @@ def _get_tls(self, instance): return tls @property - def basic_property(self): + def basic_property(self) -> property: """ Return a basic property that can be used to access the TLS value without having to call it first. The drawback is that it's not possible to do anything over than getting/setting/deleting. + + :returns: A property object for direct access. + :rtype: property """ + def getter(instance, owner=None): prop = self.__get__(instance, owner) return prop() return property(getter, self.__set__, self.__delete__) + class _BoundTLSProperty: """ Simple proxy object to allow either calling it to get the TLS value, or get some other informations by calling methods. + + :param tls_property: The tls_property descriptor. + :type tls_property: tls_property + :param instance: The target instance to which the property is bound. + :type instance: Target + :param owner: The owning class (optional). + :type owner: Type[Target] or None """ - def __init__(self, tls_property, instance, owner): + + def __init__(self, tls_property: tls_property, instance: 'Target', owner: Optional[Type['Target']]): self.tls_property = tls_property self.instance = instance self.owner = owner def __call__(self): + """ + Retrieve the thread-local value by calling the underlying tls_property. + + :returns: The thread-local value. + :rtype: Any + """ return self.tls_property._get_value( instance=self.instance, owner=self.owner, ) - def get_all_values(self): + def get_all_values(self) -> Set[Any]: """ Returns all the thread-local values currently in use in the process for that property for that instance. + + :returns: A set of all thread-local values. + :rtype: set """ return self.tls_property._get_all_values( instance=self.instance, @@ -920,9 +1328,25 @@ class InitCheckpointMeta(type): ``is_in_use`` is set to ``True`` when an instance method is being called. This allows to detect reentrance. """ - def __new__(metacls, name, bases, dct, **kwargs): + + def __new__(metacls, name: str, bases: Tuple, dct: Dict, **kwargs: Dict) -> Type: + """ + Create a new class with the augmented __init__ and methods for tracking initialization + and usage. + + :param name: The name of the new class. + :type name: str + :param bases: Base classes for the new class. + :type bases: tuple + :param dct: Dictionary of attributes for the new class. + :type dct: dict + :param kwargs: Additional keyword arguments. + :type kwargs: dict + :returns: The newly created class. + :rtype: type + """ cls = super().__new__(metacls, name, bases, dct, **kwargs) - init_f = cls.__init__ + init_f = cls.__init__ # type:ignore @wraps(init_f) def init_wrapper(self, *args, **kwargs): @@ -949,7 +1373,7 @@ def init_wrapper(self, *args, **kwargs): return x - cls.__init__ = init_wrapper + cls.__init__ = init_wrapper # type:ignore # Set the is_in_use attribute to allow external code to detect if the # methods are about to be re-entered. @@ -977,8 +1401,8 @@ def wrapper(self, *args, **kwargs): # Only wrap the methods (exposed as functions), not things like # classmethod or staticmethod if ( - name not in ('__init__', '__new__') and - isinstance(attr, types.FunctionType) + name not in ('__init__', '__new__') and + isinstance(attr, types.FunctionType) ): setattr(cls, name, make_wrapper(attr)) elif isinstance(attr, property): @@ -1000,7 +1424,7 @@ class InitCheckpoint(metaclass=InitCheckpointMeta): pass -def groupby_value(dct): +def groupby_value(dct: Dict[Any, Any]) -> Dict[Tuple[Any, ...], Any]: """ Process the input dict such that all keys sharing the same values are grouped in a tuple, used as key in the returned dict. @@ -1013,7 +1437,8 @@ def groupby_value(dct): } -def safe_extract(tar, path=".", members=None, *, numeric_owner=False): +def safe_extract(tar: 'TarFile', path: str = ".", members: Optional[List['TarInfo']] = None, + *, numeric_owner: bool = False) -> None: """ A wrapper around TarFile.extract all to mitigate CVE-2007-4995 (see https://www.trellix.com/en-us/about/newsroom/stories/research/tarfile-exploiting-the-world.html) @@ -1026,8 +1451,8 @@ def safe_extract(tar, path=".", members=None, *, numeric_owner=False): tar.extractall(path, members, numeric_owner=numeric_owner) -def _is_within_directory(directory, target): +def _is_within_directory(directory: str, target: str) -> bool: abs_directory = os.path.abspath(directory) abs_target = os.path.abspath(target) diff --git a/devlib/utils/rendering.py b/devlib/utils/rendering.py index 52d4f00dc..78961b3bf 100644 --- a/devlib/utils/rendering.py +++ b/devlib/utils/rendering.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,9 +24,11 @@ from shlex import quote # pylint: disable=redefined-builtin -from devlib.exception import WorkerThreadError, TargetNotRespondingError, TimeoutError +from devlib.exception import WorkerThreadError, TargetNotRespondingError, TimeoutError from devlib.utils.csvutil import csvwriter - +from typing import List, Optional, TYPE_CHECKING, cast +if TYPE_CHECKING: + from devlib.target import Target logger = logging.getLogger('rendering') @@ -38,12 +40,12 @@ class FrameCollector(threading.Thread): - def __init__(self, target, period): + def __init__(self, target: 'Target', period: int): super(FrameCollector, self).__init__() self.target = target self.period = period self.stop_signal = threading.Event() - self.frames = [] + self.frames: List = [] self.temp_file = None self.refresh_period = None @@ -51,7 +53,7 @@ def __init__(self, target, period): self.unresponsive_count = 0 self.last_ready_time = 0 self.exc = None - self.header = None + self.header: Optional[List[str]] = None def run(self): logger.debug('Frame data collection started.') @@ -95,17 +97,18 @@ def process_frames(self, outfile=None): os.unlink(self.temp_file) self.temp_file = None - def write_frames(self, outfile, columns=None): + def write_frames(self, outfile, columns: Optional[List[str]] = None): if columns is None: header = self.header frames = self.frames else: - indexes = [] + indexes: List = [] for c in columns: - if c not in self.header: - msg = 'Invalid column "{}"; must be in {}' - raise ValueError(msg.format(c, self.header)) - indexes.append(self.header.index(c)) + if self.header: + if c not in self.header: + msg = 'Invalid column "{}"; must be in {}' + raise ValueError(msg.format(c, self.header)) + indexes.append(self.header.index(c)) frames = [[f[i] for i in indexes] for f in self.frames] header = columns with csvwriter(outfile) as writer: @@ -128,7 +131,7 @@ class SurfaceFlingerFrameCollector(FrameCollector): def __init__(self, target, period, view, header=None): super(SurfaceFlingerFrameCollector, self).__init__(target, period) self.view = view - self.header = header or SurfaceFlingerFrame._fields + self.header = cast(List[str], header or SurfaceFlingerFrame._fields) def collect_frames(self, wfh): activities = [a for a in self.list() if a.startswith(self.view)] @@ -180,7 +183,7 @@ def _process_trace_parts(self, parts): if len(parts) == 3: frame = SurfaceFlingerFrame(*parts) if not frame.frame_ready_time: - return # "null" frame + return # "null" frame if frame.frame_ready_time <= self.last_ready_time: return # duplicate frame if (frame.frame_ready_time - frame.desired_present_time) > self.drop_threshold: @@ -196,8 +199,8 @@ def _process_trace_parts(self, parts): logger.warning(msg) -def read_gfxinfo_columns(target): - output = target.execute('dumpsys gfxinfo --list framestats') +def read_gfxinfo_columns(target: 'Target') -> List[str]: + output: str = target.execute('dumpsys gfxinfo --list framestats') lines = iter(output.split('\n')) for line in lines: if line.startswith('---PROFILEDATA---'): @@ -222,7 +225,7 @@ def collect_frames(self, wfh): def clear(self): pass - def _init_header(self, header): + def _init_header(self, header: Optional[List[str]]): if header is not None: self.header = header else: diff --git a/devlib/utils/serial_port.py b/devlib/utils/serial_port.py index c4915a959..4db00eaad 100644 --- a/devlib/utils/serial_port.py +++ b/devlib/utils/serial_port.py @@ -1,4 +1,4 @@ -# Copyright 2013-2024 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ from pexpect import fdpexpect # pexpect < 4.0.0 does not have fdpexpect module except ImportError: - import fdpexpect + import fdpexpect # type:ignore # Adding pexpect exceptions into this module's namespace @@ -32,6 +32,8 @@ from devlib.exception import HostError +from typing import Optional, TextIO, Union, Tuple, Generator + class SerialLogger(Logger): @@ -41,17 +43,22 @@ def flush(self): pass -def pulse_dtr(conn, state=True, duration=0.1): +def pulse_dtr(conn: serial.Serial, state: bool = True, duration: float = 0.1) -> None: """Set the DTR line of the specified serial connection to the specified state for the specified duration (note: the initial state of the line is *not* checked.""" - conn.setDTR(state) + conn.dtr = state time.sleep(duration) - conn.setDTR(not state) + conn.dtr = not state # pylint: disable=keyword-arg-before-vararg -def get_connection(timeout, init_dtr=None, logcls=SerialLogger, - logfile=None, *args, **kwargs): +def get_connection(timeout: int, init_dtr: Optional[bool] = None, + logcls=SerialLogger, + logfile: Optional[TextIO] = None, *args, **kwargs) -> Tuple[fdpexpect.fdspawn, + serial.Serial]: + """ + get the serial connection + """ if init_dtr is not None: kwargs['dsrdtr'] = True try: @@ -59,10 +66,10 @@ def get_connection(timeout, init_dtr=None, logcls=SerialLogger, except serial.SerialException as e: raise HostError(str(e)) if init_dtr is not None: - conn.setDTR(init_dtr) + conn.dtr = init_dtr conn.nonblocking() - conn.flushOutput() - target = fdpexpect.fdspawn(conn.fileno(), timeout=timeout, logfile=logfile) + conn.reset_output_buffer() + target: fdpexpect.fdspawn = fdpexpect.fdspawn(conn.fileno(), timeout=timeout, logfile=logfile) target.logfile_read = logcls('read') target.logfile_send = logcls('send') @@ -73,15 +80,16 @@ def get_connection(timeout, init_dtr=None, logcls=SerialLogger, # corruption. The delay prevents that. tsln = target.sendline - def sendline(x): - tsln(x) + def sendline(s: Union[str, bytes]) -> int: + ret: int = tsln(s) time.sleep(0.1) + return ret target.sendline = sendline return target, conn -def write_characters(conn, line, delay=0.05): +def write_characters(conn: fdpexpect.fdspawn, line: str, delay: float = 0.05) -> None: """Write a single line out to serial charcter-by-character. This will ensure that nothing will be dropped for longer lines.""" line = line.rstrip('\r\n') @@ -93,8 +101,10 @@ def write_characters(conn, line, delay=0.05): # pylint: disable=keyword-arg-before-vararg @contextmanager -def open_serial_connection(timeout, get_conn=False, init_dtr=None, - logcls=SerialLogger, *args, **kwargs): +def open_serial_connection(timeout: int, get_conn: bool = False, + init_dtr: Optional[bool] = None, + logcls=SerialLogger, *args, **kwargs) -> Generator[Union[Tuple[fdpexpect.fdspawn, serial.Serial], + fdpexpect.fdspawn], None, None]: """ Opens a serial connection to a device. @@ -112,11 +122,11 @@ def open_serial_connection(timeout, get_conn=False, init_dtr=None, See: http://pexpect.sourceforge.net/pexpect.html """ - target, conn = get_connection(timeout, init_dtr=init_dtr, - logcls=logcls, *args, **kwargs) + target, conn = get_connection(timeout, init_dtr, + logcls, *args, **kwargs) if get_conn: - target_and_conn = (target, conn) + target_and_conn: Union[Tuple[fdpexpect.fdspawn, serial.Serial], fdpexpect.fdspawn] = (target, conn) else: target_and_conn = target diff --git a/devlib/utils/ssh.py b/devlib/utils/ssh.py index 9fe4c613a..709e8b277 100644 --- a/devlib/utils/ssh.py +++ b/devlib/utils/ssh.py @@ -1,4 +1,4 @@ -# Copyright 2014-2024 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # limitations under the License. # - import os import stat import logging @@ -31,13 +30,6 @@ import functools import shutil from shlex import quote - -from paramiko.client import SSHClient, AutoAddPolicy, RejectPolicy -import paramiko.ssh_exception -from scp import SCPClient -# By default paramiko is very verbose, including at the INFO level -logging.getLogger("paramiko").setLevel(logging.WARNING) - # pylint: disable=import-error,wrong-import-position,ungrouped-imports,wrong-import-order import pexpect @@ -45,7 +37,11 @@ from pexpect import pxssh # pexpect < 4.0.0 does not have a pxssh module except ImportError: - import pxssh + import pxssh # type: ignore +from paramiko.client import SSHClient, AutoAddPolicy, RejectPolicy, MissingHostKeyPolicy +import paramiko.ssh_exception +from scp import SCPClient + from pexpect import EOF, TIMEOUT, spawn @@ -58,46 +54,106 @@ from devlib.utils.misc import (which, strip_bash_colors, check_output, sanitize_cmd_template, memoized, redirect_streams) from devlib.utils.types import boolean -from devlib.connection import ConnectionBase, ParamikoBackgroundCommand, SSHTransferHandle +from devlib.connection import ConnectionBase, ParamikoBackgroundCommand, SSHTransferHandle, TransferManager +from typing import (Optional, TYPE_CHECKING, Tuple, cast, + Callable, Union, IO, List, Sized, Dict, + Pattern, Generator, Type, Any) +from typing_extensions import Literal +from io import BufferedReader, BufferedWriter +if TYPE_CHECKING: + from devlib.utils.annotation_helpers import SubprocessCommand + from devlib.platform import Platform + from paramiko.transport import Transport + from paramiko.channel import Channel, ChannelStderrFile, ChannelFile, ChannelStdinFile + from paramiko.sftp_client import SFTPClient + from logging import Logger + from subprocess import Popen +# By default paramiko is very verbose, including at the INFO level +logging.getLogger("paramiko").setLevel(logging.WARNING) # Empty prompt with -p '' to avoid adding a leading space to the output. DEFAULT_SSH_SUDO_COMMAND = "sudo -k -p '' -S -- sh -c {}" +""" +Default command template for acquiring sudo privileges over SSH. +""" + +OutStreamType = Tuple[Union[Optional['BufferedReader'], int], Union[Optional['BufferedWriter'], int, bytes]] +""" +Represents a pair of read-end and write-end streams used for background command output. +""" + +ChannelFiles = Tuple['ChannelStdinFile', 'ChannelFile', 'ChannelStderrFile'] +""" +Represents a triple of Paramiko channel file objects for stdin, stdout, stderr. +""" class _SSHEnv: + """ + Provides resolved paths to SSH-related utilities. + + The main usage includes: + - ``ssh`` for connecting to remote hosts, + - ``scp`` for file transfers, + - ``sshpass`` if password authentication is needed. + + The paths are discovered on the host system using :func:`which`. + """ @functools.lru_cache(maxsize=None) - def get_path(self, tool): + def get_path(self, tool: str) -> str: + """ + Return the full path to the specified ``tool`` (one of ``ssh``, ``scp``, or ``sshpass``). + + :param tool: Name of the executable to look for. + :type tool: str + :returns: The full path to the requested tool. + :rtype: str + :raises HostError: If the tool cannot be found in PATH. + """ if tool in {'ssh', 'scp', 'sshpass'}: - path = which(tool) + path: Optional[str] = which(tool) if path: return path else: raise HostError(f'OpenSSH must be installed on the host: could not find {tool} command') else: raise AttributeError(f"Tool '{tool}' is not supported") + + _SSH_ENV = _SSHEnv() -logger = logging.getLogger('ssh') -gem5_logger = logging.getLogger('gem5-connection') +logger: 'Logger' = logging.getLogger('ssh') +gem5_logger: 'Logger' = logging.getLogger('gem5-connection') @contextlib.contextmanager -def _handle_paramiko_exceptions(command=None): +def _handle_paramiko_exceptions(command: Optional['SubprocessCommand'] = None) -> Generator: + """ + A context manager that catches exceptions from Paramiko calls, raising devlib-friendly + exceptions where appropriate. + + :param command: Optional command string for context in exception messages. + :type command: SubprocessCommand or None + :raises TargetNotRespondingError: If connection issues are detected. + :raises TargetStableError: If there is an SSH logic or host key error. + :raises TargetTransientError: If an SSH logic error suggests a transient condition. + :raises TimeoutError: If a socket timeout occurs. + """ try: yield except paramiko.ssh_exception.NoValidConnectionsError as e: raise TargetNotRespondingError('Connection lost: {}'.format(e)) - except paramiko.ssh_exception.AuthenticationException as e: - raise TargetStableError('Could not authenticate: {}'.format(e)) except paramiko.ssh_exception.BadAuthenticationType as e: raise TargetStableError('Bad authentication type: {}'.format(e)) + except paramiko.ssh_exception.PasswordRequiredException as e: + raise TargetStableError('Please unlock the private key file: {}'.format(e)) + except paramiko.ssh_exception.AuthenticationException as e: + raise TargetStableError('Could not authenticate: {}'.format(e)) except paramiko.ssh_exception.BadHostKeyException as e: raise TargetStableError('Bad host key: {}'.format(e)) except paramiko.ssh_exception.ChannelException as e: raise TargetStableError('Could not open an SSH channel: {}'.format(e)) - except paramiko.ssh_exception.PasswordRequiredException as e: - raise TargetStableError('Please unlock the private key file: {}'.format(e)) except paramiko.ssh_exception.ProxyCommandFailure as e: raise TargetStableError('Proxy command failure: {}'.format(e)) except paramiko.ssh_exception.SSHException as e: @@ -106,7 +162,30 @@ def _handle_paramiko_exceptions(command=None): raise TimeoutError(command, output=None) -def _read_paramiko_streams(stdout, stderr, select_timeout, callback, init, chunk_size=int(1e42)): +def _read_paramiko_streams(stdout: 'ChannelFile', stderr: 'ChannelStderrFile', + select_timeout: Optional[float], callback: Callable, + init: List[bytes], chunk_size=int(1e42)) -> Tuple[Optional[List[bytes]], int]: + """ + Read data from Paramiko's stdout/stderr streams until the channel closes. + Applies an optional callback to each chunk read for each stream. + + :param stdout: Paramiko file-like object for stdout. + :type stdout: ChannelFile + :param stderr: Paramiko file-like object for stderr. + :type stderr: ChannelStderrFile + :param select_timeout: Maximum time (seconds) to block when reading from the channel. + :type select_timeout: float or None + :param callback: A function receiving (callback_state, 'stdout' or 'stderr', chunk). + Must return the new callback_state for subsequent calls. + :type callback: Callable + :param init: Initial callback state. + :type init: list + :param chunk_size: Maximum chunk size in bytes for each read. Defaults to a large integer. + :type chunk_size: int + :returns: A tuple of (final_callback_state, exit_code). + :rtype: (list or None, int) + :raises Exception: If the callback itself raises an exception. + """ try: return _read_paramiko_streams_internal(stdout, stderr, select_timeout, callback, init, chunk_size) finally: @@ -118,11 +197,21 @@ def _read_paramiko_streams(stdout, stderr, select_timeout, callback, init, chunk stdout.channel.close() -def _read_paramiko_streams_internal(stdout, stderr, select_timeout, callback, init, chunk_size): +def _read_paramiko_streams_internal(stdout: 'ChannelFile', stderr: 'ChannelStderrFile', + select_timeout: Optional[float], + callback: Callable[[Optional[List[bytes]], str, bytes], List[bytes]], + init: Optional[List[bytes]], chunk_size: int) -> Tuple[Optional[List[bytes]], int]: + """ + Internal helper for :func:`_read_paramiko_streams`. + """ channel = stdout.channel assert stdout.channel is stderr.channel - def read_channel(callback_state): + def read_channel(callback_state: Optional[List[bytes]]) -> Tuple[Optional[Exception], Optional[List[bytes]]]: + """ + read data from the channel, stdout or stderr + """ + read_list: List['Channel'] read_list, _, _ = select.select([channel], [], [], select_timeout) for desc in read_list: for ready, recv, name in ( @@ -130,7 +219,7 @@ def read_channel(callback_state): (desc.recv_stderr_ready(), desc.recv_stderr, 'stderr') ): if ready: - chunk = recv(chunk_size) + chunk: bytes = recv(chunk_size) if chunk: try: callback_state = callback(callback_state, name, chunk) @@ -139,7 +228,11 @@ def read_channel(callback_state): return (None, callback_state) - def read_all_channel(callback=None, callback_state=None): + def read_all_channel(callback: Optional[Callable[[Optional[List[bytes]], str, bytes], List[bytes]]] = None, + callback_state: Optional[List[bytes]] = None) -> Optional[List[bytes]]: + """ + read data from both stdout and stderr + """ for stream, name in ((stdout, 'stdout'), (stderr, 'stderr')): try: chunk = stream.read() @@ -151,7 +244,7 @@ def read_all_channel(callback=None, callback_state=None): return callback_state - callback_excep = None + callback_excep: Optional[Exception] = None try: callback_state = init while not channel.exit_status_ready(): @@ -176,25 +269,52 @@ def read_all_channel(callback=None, callback_state=None): return (callback_state, exit_code) -def _resolve_known_hosts(strict_host_check): +def _resolve_known_hosts(strict_host_check: Optional[Union[bool, str, os.PathLike]]) -> str: + """ + Compute a path to the known_hosts file based on ``strict_host_check``. + + :param strict_host_check: If True, uses ~/.ssh/known_hosts; if a path is given, uses that path; if False, returns /dev/null. + :type strict_host_check: bool or str or os.PathLike + :returns: Absolute path to the known_hosts file (or '/dev/null'). + :rtype: str + """ if strict_host_check: if isinstance(strict_host_check, (str, os.PathLike)): path = Path(strict_host_check) else: - path = Path('~/.ssh/known_hosts').expandvars() + path = Path(os.path.expandvars('~/.ssh/known_hosts')) else: path = Path('/dev/null') return str(path.resolve()) -def telnet_get_shell(host, - username, - password=None, - port=None, - timeout=10, - original_prompt=None): - start_time = time.time() +def telnet_get_shell(host: str, + username: str, + password: Optional[str] = None, + port: Optional[int] = None, + timeout: float = 10, + original_prompt: Optional[str] = None) -> 'TelnetPxssh': + """ + Obtain a Telnet shell by calling :class:`TelnetPxssh`. + + :param host: The host name or IP address for the Telnet connection. + :type host: str + :param username: The username for Telnet login. + :type username: str + :param password: Password for Telnet login, or None if no password is needed. + :type password: str, optional + :param port: TCP port for Telnet. Defaults to 23 if unspecified. + :type port: int or None + :param timeout: Time in seconds to wait for the initial connection. + :type timeout: float + :param original_prompt: Regex for matching the shell prompt if it differs from default. + :type original_prompt: str or None + :returns: A TelnetPxssh object for interacting with the shell. + :rtype: TelnetPxssh + :raises TargetTransientError: If connection fails repeatedly within the timeout period. + """ + start_time: float = time.time() while True: conn = TelnetPxssh(original_prompt=original_prompt) @@ -217,25 +337,56 @@ def telnet_get_shell(host, class TelnetPxssh(pxssh.pxssh): # pylint: disable=arguments-differ + """ + A specialized Telnet-based shell session class, derived from :class:`pxssh.pxssh`. - def __init__(self, original_prompt): + :param original_prompt: A regex pattern for the shell's default prompt. + :type original_prompt: str or None + """ + def __init__(self, original_prompt: Optional[str]): super(TelnetPxssh, self).__init__() self.original_prompt = original_prompt or r'[#$]' - def login(self, server, username, password='', login_timeout=10, - auto_prompt_reset=True, sync_multiplier=1, port=23): - args = ['telnet'] + def login(self, server: str, username: str, password: Optional[str] = '', login_timeout: float = 10, + auto_prompt_reset: bool = True, sync_multiplier: int = 1, port: Optional[int] = 23) -> bool: + """ + Attempt Telnet login, specifying a host, username, and optional password. + + :param server: Host name or IP address. + :type server: str + :param username: Username to log in with. + :type username: str + :param password: Password, if any, or empty string. + :type password: str + :param login_timeout: Time in seconds to wait for login prompts before failing. + :type login_timeout: float + :param auto_prompt_reset: If True, attempt to detect and set a unique prompt. + :type auto_prompt_reset: bool + :param sync_multiplier: Adjust how aggressively pxssh synchronizes prompt detection. + :type sync_multiplier: int + :param port: Telnet port, default 23. + :type port: int + :returns: True if login was successful. + :rtype: bool + :raises pxssh.ExceptionPxssh: If login fails or the password is incorrect. + :raises TIMEOUT: If no password prompt is shown within the timeout. + """ + args: List[str] = ['telnet'] if username is not None: args += ['-l', username] args += [server, str(port)] - cmd = ' '.join(args) - - spawn._spawn(self, cmd) # pylint: disable=protected-access + cmd: str = ' '.join(args) + # FIXME - Modified the access to _spawn protected method and instead use public method of pexpect. + # need to see if there is any issue with the replacement + # Spawn the command + child = pexpect.spawn(cmd) + # Wait for the command to complete + child.expect(pexpect.EOF) try: - i = self.expect('(?i)(?:password)', timeout=login_timeout) + i: int = self.expect('(?i)(?:password)', timeout=login_timeout) if i == 0: - self.sendline(password) + self.sendline(password or '') i = self.expect([self.original_prompt, 'Login incorrect'], timeout=login_timeout) if i: raise pxssh.ExceptionPxssh('could not log in: password was incorrect') @@ -259,18 +410,24 @@ def login(self, server, username, password='', login_timeout=10, return True -def check_keyfile(keyfile): +def check_keyfile(keyfile: str) -> str: """ keyfile must have the right access premissions in order to be useable. If the specified file doesn't, create a temporary copy and set the right permissions for that. Returns either the ``keyfile`` (if the permissions on it are correct) or the path to a temporary copy with the right permissions. + + :param keyfile: The path to the SSH private key file. + :type keyfile: str + :returns: Either the original ``keyfile`` (if it already has 0600 perms) + or a temporary copy path with corrected permissions. + :rtype: str """ - desired_mask = stat.S_IWUSR | stat.S_IRUSR - actual_mask = os.stat(keyfile).st_mode & 0xFF + desired_mask: int = stat.S_IWUSR | stat.S_IRUSR + actual_mask: int = os.stat(keyfile).st_mode & 0xFF if actual_mask != desired_mask: - tmp_file = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile)) + tmp_file: str = os.path.join(tempfile.gettempdir(), os.path.basename(keyfile)) shutil.copy(keyfile, tmp_file) os.chmod(tmp_file, desired_mask) return tmp_file @@ -280,17 +437,87 @@ def check_keyfile(keyfile): class SshConnectionBase(ConnectionBase): """ - Base class for SSH connections. + Base class for SSH-derived connections, providing shared functionality + like verifying keyfile permissions, tracking host info, and more. + + :param host: The SSH target hostname or IP address. + :type host: str + :param username: Username to log in as. + :type username: str + :param password: Password for the SSH connection, or None if key-based auth is used. + :type password: str, optional + :param keyfile: Path to an SSH private key if using key-based auth. + :type keyfile: str, optional + :param port: TCP port for the SSH server. Defaults to 22 if unspecified. + :type port: int, optional + :param platform: A devlib.platform.Platform instance describing the device. + :type platform: Platform, optional + :param sudo_cmd: A template string for granting sudo privileges (e.g. "sudo -S sh -c {}"). + :type sudo_cmd: str + :param strict_host_check: If True, host key checking is enforced using a known_hosts file. + If a string/path is supplied, that path is used as known_hosts. If False, host keys are not checked. + :type strict_host_check: bool or str or os.PathLike + :param poll_transfers: If True, uses :class:`TransferManager` to poll file transfers. + :type poll_transfers: bool + :param start_transfer_poll_delay: Delay in seconds before the first poll of a new file transfer. + :type start_transfer_poll_delay: int + :param total_transfer_timeout: If a file transfer exceeds this many seconds, it is canceled. + :type total_transfer_timeout: int + :param transfer_poll_period: Interval (seconds) between file transfer progress checks. + :type transfer_poll_period: int """ + def __init__(self, + host: str, + username: str, + password: Optional[str] = None, + keyfile: Optional[str] = None, + port: Optional[int] = None, + platform: Optional['Platform'] = None, + sudo_cmd: str = DEFAULT_SSH_SUDO_COMMAND, + strict_host_check: Union[bool, str, os.PathLike] = True, + poll_transfers: bool = False, + start_transfer_poll_delay: int = 30, + total_transfer_timeout: int = 3600, + transfer_poll_period: int = 30, + ): + super().__init__( + poll_transfers=poll_transfers, + start_transfer_poll_delay=start_transfer_poll_delay, + total_transfer_timeout=total_transfer_timeout, + transfer_poll_period=transfer_poll_period, + ) + self._connected_as_root: Optional[bool] = None + self.host = host + self.username = username + self.password = password + self.keyfile = check_keyfile(keyfile) if keyfile else keyfile + self.port = port + self.sudo_cmd = sanitize_cmd_template(sudo_cmd) + self.platform = platform + self.strict_host_check = strict_host_check + logger.debug('Logging in {}@{}'.format(username, host)) - default_timeout = 10 + default_timeout: int = 10 + """ + Default timeout in seconds for SSH operations if not otherwise specified. + """ @property - def name(self): + def name(self) -> str: + """ + :returns: A string identifying the host (e.g. the IP or hostname). + :rtype: str + """ return self.host @property - def connected_as_root(self): + def connected_as_root(self) -> bool: + """ + Indicates if the current user on the remote SSH session is root (uid=0). + + :returns: True if root, else False. + :rtype: bool + """ if self._connected_as_root is None: try: result = self.execute('id', as_root=False) @@ -303,58 +530,99 @@ def connected_as_root(self): @connected_as_root.setter def connected_as_root(self, state): - self._connected_as_root = state + """ + Explicitly set the known state of root usage on this connection. - def __init__(self, - host, - username, - password=None, - keyfile=None, - port=None, - platform=None, - sudo_cmd=DEFAULT_SSH_SUDO_COMMAND, - strict_host_check=True, - - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30, - ): - super().__init__( - poll_transfers=poll_transfers, - start_transfer_poll_delay=start_transfer_poll_delay, - total_transfer_timeout=total_transfer_timeout, - transfer_poll_period=transfer_poll_period, - ) - self._connected_as_root = None - self.host = host - self.username = username - self.password = password - self.keyfile = check_keyfile(keyfile) if keyfile else keyfile - self.port = port - self.sudo_cmd = sanitize_cmd_template(sudo_cmd) - self.platform = platform - self.strict_host_check = strict_host_check - logger.debug('Logging in {}@{}'.format(username, host)) + :param state: True if effectively root, False otherwise. + :type state: bool + """ + self._connected_as_root = state class SshConnection(SshConnectionBase): + """ + A connection to a device on the network over SSH. + + :param host: SSH host to which to connect + :type host: str + :param username: username for SSH login + :type username: str + :param password: password for the SSH connection + :type password: str, optional + + .. note:: To connect to a system without a password this + parameter should be set to an empty string otherwise + ssh key authentication will be attempted. + .. note:: In order to user password-based authentication, + ``sshpass`` utility must be installed on the + system. + + :param keyfile: Path to the SSH private key to be used for the connection. + :type keyfile: str, optional + + .. note:: ``keyfile`` and ``password`` can't be specified + at the same time. + + :param port: TCP port on which SSH server is listening on the remote device. + Omit to use the default port. + :type port: int, optional + :param timeout: Timeout for the connection in seconds. If a connection + cannot be established within this time, an error will be + raised. + :type timeout: int, optional + :param platform: Specify the platform to be used. The generic :class:`~devlib.platform.Platform` + class is used by default. + :type platform: Platform, optional + :param sudo_cmd: Specify the format of the command used to grant sudo access. + :type sudo_cmd: str + :param strict_host_check: Specify the ssh connection parameter + ``StrictHostKeyChecking``. If a path is passed + rather than a boolean, it will be taken for a + ``known_hosts`` file. Otherwise, the default + ``$HOME/.ssh/known_hosts`` will be used. + :type strict_host_check: bool or str or os.PathLike + :param use_scp: If True, prefer using the scp binary for file transfers instead of SFTP. + :type use_scp: bool + :param poll_transfers: Specify whether file transfers should be polled. Polling + monitors the progress of file transfers and periodically + checks whether they have stalled, attempting to cancel + the transfers prematurely if so. + :type poll_transfers: bool + :param start_transfer_poll_delay: If transfers are polled, specify the length of + time after a transfer has started before polling + should start. + :type start_transfer_poll_delay: int + :param total_transfer_timeout: If transfers are polled, specify the total amount of time + to elapse before the transfer is cancelled, regardless + of its activity. + :type total_transfer_timeout: int + :param transfer_poll_period: If transfers are polled, specify the period at which + the transfers are sampled for activity. Too small values + may cause the destination size to appear the same over + one or more sample periods, causing improper transfer + cancellation. + :type transfer_poll_period: int + + :raises TargetNotRespondingError: If the SSH server cannot be reached. + :raises HostError: If the password or keyfile are invalid, or scp/sftp cannot be opened. + :raises TargetStableError: If authentication fails or paramiko encounters an unrecoverable error. + """ # pylint: disable=unused-argument,super-init-not-called def __init__(self, - host, - username, - password=None, - keyfile=None, - port=22, - timeout=None, - platform=None, - sudo_cmd=DEFAULT_SSH_SUDO_COMMAND, - strict_host_check=True, - use_scp=False, - poll_transfers=False, - start_transfer_poll_delay=30, - total_transfer_timeout=3600, - transfer_poll_period=30, + host: str, + username: str, + password: Optional[str] = None, + keyfile: Optional[str] = None, + port: Optional[int] = 22, + timeout: Optional[int] = None, + platform: Optional['Platform'] = None, + sudo_cmd: str = DEFAULT_SSH_SUDO_COMMAND, + strict_host_check: Union[bool, str, os.PathLike] = True, + use_scp: bool = False, + poll_transfers: bool = False, + start_transfer_poll_delay: int = 30, + total_transfer_timeout: int = 3600, + transfer_poll_period: int = 30, ): super().__init__( @@ -381,7 +649,7 @@ def __init__(self, else: logger.debug('Using SFTP for file transfer') - self.client = None + self.client: Optional[SSHClient] = None try: self.client = self._make_client() @@ -391,7 +659,7 @@ def __init__(self, # everything will work as long as we login as root). If sudo is still # needed, it will explode when someone tries to use it. After all, the # user might not be interested in being root at all. - self._sudo_needs_password = ( + self._sudo_needs_password: bool = ( 'NEED_PASSWORD' in self.execute( # sudo -n is broken on some versions on MacOSX, revisit that if @@ -410,13 +678,16 @@ def __init__(self, finally: raise e - def _make_client(self): + def _make_client(self) -> SSHClient: + """ + Create, connect and return a class:SSHClient object + """ if self.strict_host_check: - policy = RejectPolicy + policy: Type[MissingHostKeyPolicy] = RejectPolicy else: policy = AutoAddPolicy # Only try using SSH keys if we're not using a password - check_ssh_keys = self.password is None + check_ssh_keys: bool = self.password is None with _handle_paramiko_exceptions(): client = SSHClient() @@ -427,7 +698,7 @@ def _make_client(self): client.set_missing_host_key_policy(policy) client.connect( hostname=self.host, - port=self.port, + port=self.port or 0, username=self.username, password=self.password, key_filename=self.keyfile, @@ -438,19 +709,28 @@ def _make_client(self): return client - def _make_channel(self): + def _make_channel(self) -> Optional['Channel']: + """ + The Transport class in the Paramiko library is a core component for handling SSH connections. + It attaches to a stream (usually a socket), negotiates an encrypted session, authenticates, + and then creates stream tunnels, called channels, across the session. + Multiple channels can be multiplexed across a single session + """ with _handle_paramiko_exceptions(): - transport = self.client.get_transport() - channel = transport.open_session() + transport: Optional['Transport'] = self.client.get_transport() if self.client else None + channel = transport.open_session() if transport else None return channel # Limit the number of opened channels to a low number, since some servers # will reject more connections request. For OpenSSH, this is controlled by # the MaxSessions config. @functools.lru_cache(maxsize=1) - def _cached_get_sftp(self): + def _cached_get_sftp(self) -> Optional['SFTPClient']: + """ + get the cached sftp channel to avoid opening too many channels to server + """ try: - sftp = self.client.open_sftp() + sftp: Optional['SFTPClient'] = self.client.open_sftp() if self.client else None except paramiko.ssh_exception.SSHException as e: if 'EOF during negotiation' in str(e): raise TargetStableError('The SSH server does not support SFTP. Please install and enable appropriate module.') from e @@ -458,21 +738,41 @@ def _cached_get_sftp(self): raise return sftp - def _get_sftp(self, timeout): - sftp = self._cached_get_sftp() - sftp.get_channel().settimeout(timeout) + def _get_sftp(self, timeout: Optional[float]) -> Optional['SFTPClient']: + """ + get the cached sftp channel and set a channel timeout for read write operations. + returns the channel with the timeout set + """ + sftp: Optional['SFTPClient'] = self._cached_get_sftp() + if sftp: + channel = sftp.get_channel() + if channel: + channel.settimeout(timeout) return sftp @functools.lru_cache() - def _get_scp(self, timeout, callback=lambda *_: None): - cb = lambda _, to_transfer, transferred: callback(to_transfer, transferred) - return SCPClient(self.client.get_transport(), socket_timeout=timeout, progress=cb) - - def _push_file(self, sftp, src, dst, callback): + def _get_scp(self, timeout: float, callback: Callable[..., None] = lambda *_: None) -> Optional[SCPClient]: + """ + get scp client as a class:SCPClient object + """ + cb: Callable[[bytes, int, int], None] = lambda _, to_transfer, transferred: callback(to_transfer, transferred) + if self.client: + transport: Optional['Transport'] = self.client.get_transport() + if transport: + return SCPClient(transport, socket_timeout=timeout, progress=cb) + return None + + def _push_file(self, sftp: 'SFTPClient', src: str, dst: str, callback: Optional[Callable]) -> None: + """ + push file to device via SFTP client + """ sftp.put(src, dst, callback=callback) @classmethod - def _path_exists(cls, sftp, path): + def _path_exists(cls, sftp: 'SFTPClient', path: str) -> bool: + """ + check whether the path exists on the device + """ try: sftp.lstat(path) except FileNotFoundError: @@ -480,12 +780,16 @@ def _path_exists(cls, sftp, path): else: return True - def _push_folder(self, sftp, src, dst, callback): + def _push_folder(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable]) -> None: + """ + push a folder into device via SFTP client + """ sftp.mkdir(dst) for entry in os.scandir(src): - name = entry.name - src_path = os.path.join(src, name) - dst_path = os.path.join(dst, name) + name: str = entry.name + src_path: str = os.path.join(src, name) + dst_path: str = os.path.join(dst, name) if entry.is_dir(): push = self._push_folder else: @@ -493,28 +797,44 @@ def _push_folder(self, sftp, src, dst, callback): push(sftp, src_path, dst_path, callback) - def _push_path(self, sftp, src, dst, callback=None): + def _push_path(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable] = None) -> None: + """ + push a path via sftp client + """ logger.debug('Pushing via sftp: {} -> {}'.format(src, dst)) push = self._push_folder if os.path.isdir(src) else self._push_file push(sftp, src, dst, callback) - def _pull_file(self, sftp, src, dst, callback): + def _pull_file(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable]) -> None: + """ + pull a file via sftp client + """ sftp.get(src, dst, callback=callback) - def _pull_folder(self, sftp, src, dst, callback): + def _pull_folder(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable]) -> None: + """ + pull a folder via sftp client + """ os.makedirs(dst) for fileattr in sftp.listdir_attr(src): filename = fileattr.filename src_path = os.path.join(src, filename) dst_path = os.path.join(dst, filename) - if stat.S_ISDIR(fileattr.st_mode): + if stat.S_ISDIR(fileattr.st_mode or 0): pull = self._pull_folder else: pull = self._pull_file pull(sftp, src_path, dst_path, callback) - def _pull_path(self, sftp, src, dst, callback=None): + def _pull_path(self, sftp: 'SFTPClient', src: str, dst: str, + callback: Optional[Callable] = None) -> None: + """ + pull a path from the device via sftp client + """ logger.debug('Pulling via sftp: {} -> {}'.format(src, dst)) try: self._pull_file(sftp, src, dst, callback) @@ -522,58 +842,122 @@ def _pull_path(self, sftp, src, dst, callback=None): # Maybe that was a directory, so retry as such self._pull_folder(sftp, src, dst, callback) - def push(self, sources, dest, timeout=None): + def push(self, sources: Tuple[str, ...], dest: str, timeout: Optional[int] = None) -> None: + """ + Transfer (push) one or more files from the host to the remote target. + + :param sources: A tuple of paths on the host system to be pushed. + :type sources: tuple(str, ...) + :param dest: Destination path on the remote device. If multiple sources, it should be a directory. + :type dest: str + :param timeout: Optional time limit in seconds for each file transfer. If exceeded, raises an error. + :type timeout: int or None + :raises TargetStableError: If uploading fails or the remote host is not ready. + :raises HostError: If local scp or sftp usage fails. + """ self._push_pull('push', sources, dest, timeout) - def pull(self, sources, dest, timeout=None): + def pull(self, sources: Tuple[str, ...], dest: str, timeout: Optional[int] = None) -> None: + """ + Transfer (pull) one or more files from the remote target to the host. + + :param sources: A tuple of paths on the remote device to be pulled. + :type sources: tuple(str, ...) + :param dest: Destination path on the host. If multiple sources, it should be a directory. + :type dest: str + :param timeout: Optional time limit in seconds for each file transfer. + :type timeout: int or None + :raises TargetStableError: If downloading fails on the remote side. + :raises HostError: If local scp or sftp usage fails. + """ self._push_pull('pull', sources, dest, timeout) - def _push_pull(self, action, sources, dest, timeout): + def _push_pull(self, action: Union[Literal['push'], Literal['pull']], + sources: Tuple[str, ...], dest: str, timeout: Optional[int]) -> None: + """ + Internal helper to handle both push and pull operations, optionally + using SCP or SFTP, with optional timeouts or polling. + + :param action: Either 'push' or 'pull', indicating the transfer direction. + :type action: str + :param sources: Paths to upload/download. + :type sources: tuple(str, ...) + :param dest: The destination path, on host (for pull) or remote (for push). + :type dest: str + :param timeout: If set, a per-file time limit (seconds) for the operation. + :type timeout: int or None + :raises TargetStableError: If the remote side fails or scp/sftp commands fail. + :raises HostError: If local environment or tools are unavailable. + """ if action not in ['push', 'pull']: raise ValueError("Action must be either `push` or `pull`") - def make_handle(obj): + def make_handle(obj: Union[SCPClient, 'SFTPClient']): handle = SSHTransferHandle(obj, manager=self.transfer_manager) - cm = self.transfer_manager.manage(sources, dest, action, handle) + cm = cast(TransferManager, self.transfer_manager).manage(sources, dest, action, handle) return (handle, cm) # If timeout is set if timeout is not None: if self.use_scp: - scp = self._get_scp(timeout) - scp_cmd = getattr(scp, 'put' if action == 'push' else 'get') - scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest) + scp: Optional[SCPClient] = self._get_scp(timeout) + scp_cmd: Callable = getattr(scp, 'put' if action == 'push' else 'get') + scp_msg: str = '{}ing via scp: {} -> {}'.format(action, sources, dest) logger.debug(scp_msg.capitalize()) scp_cmd(sources, dest, recursive=True) else: - sftp = self._get_sftp(timeout) - sftp_cmd = getattr(self, '_' + action + '_path') + sftp: Optional['SFTPClient'] = self._get_sftp(timeout) + sftp_cmd: Callable = getattr(self, '_' + action + '_path') with _handle_paramiko_exceptions(): for source in sources: sftp_cmd(sftp, source, dest) # No timeout elif self.use_scp: - def progress_cb(*args, **kwargs): + def progress_cb(*args, **kwargs) -> None: return handle.progress_cb(*args, **kwargs) scp = self._get_scp(timeout, callback=progress_cb) - handle, cm = make_handle(scp) + if scp: + handle, cm = make_handle(scp) scp_cmd = getattr(scp, 'put' if action == 'push' else 'get') - with _handle_paramiko_exceptions(), cm: + with _handle_paramiko_exceptions(), cast(contextlib._GeneratorContextManager, cm): scp_msg = '{}ing via scp: {} -> {}'.format(action, sources, dest) logger.debug(scp_msg.capitalize()) scp_cmd(sources, dest, recursive=True) else: sftp = self._get_sftp(timeout) - handle, cm = make_handle(sftp) - sftp_cmd = getattr(self, '_' + action + '_path') - with _handle_paramiko_exceptions(), cm: - for source in sources: - sftp_cmd(sftp, source, dest, callback=handle.progress_cb) - - def execute(self, command, timeout=None, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument + if sftp: + handle, cm = make_handle(sftp) + sftp_cmd = getattr(self, '_' + action + '_path') + with _handle_paramiko_exceptions(), cm: + for source in sources: + sftp_cmd(sftp, source, dest, callback=handle.progress_cb) + + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, check_exit_code: bool = True, + as_root: Optional[bool] = False, strip_colors: bool = True, will_succeed: bool = False) -> str: # pylint: disable=unused-argument + """ + Run a command synchronously on the remote machine, capturing its output. + By default, raises an exception if the command returns a non-zero exit code. + + :param command: The shell command to run, as a string or SubprocessCommand object. + :type command: SubprocessCommand + :param timeout: Maximum time in seconds to wait for completion. If None, uses a default or indefinite wait. + :type timeout: int or None + :param check_exit_code: If True, raise an error if the command's exit code is non-zero. + :type check_exit_code: bool + :param as_root: If True, attempt to run the command via sudo unless already connected as root. + :type as_root: bool or None + :param strip_colors: If True, remove ANSI color codes from the captured output. + :type strip_colors: bool + :param will_succeed: If True, treat a non-zero exit code as transient instead of stable. + :type will_succeed: bool + :returns: The combined stdout/stderr of the command. + :rtype: str + :raises TargetTransientCalledProcessError: If `check_exit_code=True` and the command fails while `will_succeed=True`. + :raises TargetStableCalledProcessError: If `check_exit_code=True` and the command fails while `will_succeed=False`. + :raises TargetStableError: If a stable SSH or environment error occurs. + """ if command == '': return '' try: @@ -597,18 +981,63 @@ def execute(self, command, timeout=None, check_exit_code=True, ) return output - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + def background(self, command: 'SubprocessCommand', stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False) -> ParamikoBackgroundCommand: + """ + Execute a command in the background on the remote host, returning a handle + to manage it. The command runs until completion or cancellation. + + :param command: The command to run. + :type command: SubprocessCommand + :param stdout: Where to direct the command's stdout (default: subprocess.PIPE). + :type stdout: int + :param stderr: Where to direct the command's stderr (default: subprocess.PIPE). + :type stderr: int + :param as_root: If True, attempt to run under sudo unless already root. + :type as_root: bool or None + :returns: A :class:`ParamikoBackgroundCommand` instance to manage or query the process. + :rtype: ParamikoBackgroundCommand + :raises TargetStableError: If channel creation fails or paramiko indicates a stable error. + :raises TargetNotRespondingError: If the SSH session is lost unexpectedly. + + .. note:: This **will block the connection** until the command completes. + """ with _handle_paramiko_exceptions(command): return self._background(command, stdout, stderr, as_root) - def _background(self, command, stdout, stderr, as_root): + def _background(self, command: 'SubprocessCommand', stdout: int, + stderr: int, as_root: Optional[bool]) -> ParamikoBackgroundCommand: + """ + Internal helper for :meth:`background` that sets up the paramiko channel, + spawns the command, and wires up redirection threads. + + :param command: The shell command to execute in the background. + :type command: SubprocessCommand + :param stdout: Destination for stdout (int file descriptor or special constant). + :type stdout: int + :param stderr: Destination for stderr (int file descriptor or special constant). + :type stderr: int + :param as_root: If True, run under sudo (if not already root). + :type as_root: bool or None + :returns: The background command object. + :rtype: ParamikoBackgroundCommand + :raises subprocess.CalledProcessError: If we cannot detect a valid PID or if the remote fails immediately. + :raises TargetStableError: If paramiko cannot open a session or other stable error occurs. + """ orig_command = command stdout, stderr, command = redirect_streams(stdout, stderr, command) - command = "printf '%s\n' $$; exec sh -c {}".format(quote(command)) - channel = self._make_channel() + command = "printf '%s\n' $$; exec sh -c {}".format(quote(cast(str, command))) + channel: Optional['Channel'] = self._make_channel() + if channel is None: + raise TargetStableError("channel is None") - def executor(cmd, timeout): + def executor(cmd: str, timeout: Optional[float]) -> ChannelFiles: + """ + executor to run the command via the paramiko channel + """ + if channel is None: + raise TargetStableError("channel is None") channel.exec_command(cmd) # Read are not buffered so we will always get the data as soon as # they arrive @@ -625,36 +1054,38 @@ def executor(cmd, timeout): timeout=None, executor=executor, ) - pid = stdout_in.readline() - if not pid: - stderr = stderr_in.read() - if channel.exit_status_ready(): - ret = channel.recv_exit_status() + pidstr: Optional[str] = stdout_in.readline() + if not pidstr: + stderr_t = stderr_in.read() + if channel and channel.exit_status_ready(): + ret: int = channel.recv_exit_status() else: ret = 126 raise subprocess.CalledProcessError( ret, command, b'', - stderr, + stderr_t, ) - pid = int(pid) + pid = int(pidstr) - def create_out_stream(stream_in, stream_out): + def create_out_stream(stream_in: int, stream_out: Union[int, bytes] + ) -> OutStreamType: """ Create a pair of file-like objects. The first one is used to read data and the second one to write. """ if stream_out == subprocess.DEVNULL: - r, w = None, None + r: Union[Optional['BufferedReader'], int] = None + w: Union[Optional['BufferedWriter'], int, bytes] = None # When asked for a pipe, we just give the file-like object as the # reading end and no writing end, since paramiko already writes to # it elif stream_out == subprocess.PIPE: - r, w = os.pipe() - r = os.fdopen(r, 'rb') - w = os.fdopen(w, 'wb') + r_tmp, w_tmp = os.pipe() + r = os.fdopen(r_tmp, 'rb') + w = os.fdopen(w_tmp, 'wb') # Turn a file descriptor into a file-like object elif isinstance(stream_out, int) and stream_out >= 0: r = os.fdopen(stream_in, 'rb') @@ -666,23 +1097,30 @@ def create_out_stream(stream_in, stream_out): return (r, w) - out_streams = { - name: create_out_stream(stream_in, stream_out) + out_streams: Dict[str, OutStreamType] = { + name: create_out_stream(cast(int, stream_in), stream_out) for stream_in, stream_out, name in ( (stdout_in, stdout, 'stdout'), (stderr_in, stderr, 'stderr'), ) } - def redirect_thread_f(stdout_in, stderr_in, out_streams, select_timeout): - def callback(out_streams, name, chunk): + def redirect_thread_f(stdout_in: 'ChannelFile', stderr_in: 'ChannelStderrFile', + out_streams: Dict[str, OutStreamType], + select_timeout: int) -> None: + """ + the thread that does the background read/write operation + """ + def callback(out_streams: Dict[str, OutStreamType], name: str, + chunk: bytes) -> Dict[str, OutStreamType]: try: r, w = out_streams[name] except KeyError: return out_streams try: - w.write(chunk) + cast(BufferedWriter, w).write(chunk) + # Write failed except ValueError: # Since that stream is now closed, stop trying to write to it @@ -695,7 +1133,7 @@ def callback(out_streams, name, chunk): return out_streams try: - _read_paramiko_streams(stdout_in, stderr_in, select_timeout, callback, copy.copy(out_streams)) + _read_paramiko_streams(stdout_in, stderr_in, select_timeout, callback, copy.copy(cast(List[bytes], out_streams))) # The streams closed while we were writing to it, the job is done here except ValueError: pass @@ -703,14 +1141,14 @@ def callback(out_streams, name, chunk): # Make sure the writing end are closed proper since we are not # going to write anything anymore for r, w in out_streams.values(): - w.flush() + cast(BufferedWriter, w).flush() if r is not w and w is not None: - w.close() + cast(BufferedWriter, w).close() # If there is anything we need to redirect to, spawn a thread taking # care of that - select_timeout = 1 - thread_out_streams = { + select_timeout: int = 1 + thread_out_streams: Dict[str, OutStreamType] = { name: (r, w) for name, (r, w) in out_streams.items() if w is not None @@ -728,23 +1166,43 @@ def callback(out_streams, name, chunk): as_root=as_root, chan=channel, pid=pid, - stdin=stdin, + stdin=cast(IO, stdin), # We give the reading end to the consumer of the data - stdout=out_streams['stdout'][0], - stderr=out_streams['stderr'][0], + stdout=cast(IO, out_streams['stdout'][0]), + stderr=cast(IO, out_streams['stderr'][0]), redirect_thread=redirect_thread, cmd=orig_command, ) - def _close(self): + def _close(self) -> None: + """ + Close the SSH connection, releasing any underlying resources such as paramiko + sessions or sockets. After this call, the SshConnection is no longer usable. + + :raises TargetStableError: If a stable error occurs during disconnection. + """ logger.debug('Logging out {}@{}'.format(self.username, self.host)) with _handle_paramiko_exceptions(): - self.client.close() + if self.client: + self.client.close() - def _execute_command(self, command, as_root, log, timeout, executor): + def _execute_command(self, command: str, as_root: Optional[bool], + log: bool, timeout: Optional[int], + executor: Callable[..., ChannelFiles]) -> ChannelFiles: + """ + execute the command over the channel using the executor and return the channel in, out and err files + """ + def get_logger(log: bool) -> Callable[..., None]: + """ + get the logger + """ + if log: + return logger.debug + else: + return lambda msg: None # As we're already root, there is no need to use sudo. - log_debug = logger.debug if log else lambda msg: None - use_sudo = as_root and not self.connected_as_root + log_debug = get_logger(log) + use_sudo: Optional[bool] = as_root and not self.connected_as_root if use_sudo: if self._sudo_needs_password and not self.password: @@ -753,8 +1211,8 @@ def _execute_command(self, command, as_root, log, timeout, executor): command = self.sudo_cmd.format(quote(command)) log_debug(command) - streams = executor(command, timeout=timeout) - if self._sudo_needs_password: + streams: ChannelFiles = executor(command, timeout=timeout) + if self._sudo_needs_password and streams and self.password: stdin = streams[0] stdin.write(self.password + '\n') stdin.flush() @@ -764,10 +1222,16 @@ def _execute_command(self, command, as_root, log, timeout, executor): return streams - def _execute(self, command, timeout=None, as_root=False, strip_colors=True, log=True): + def _execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + as_root: Optional[bool] = False, strip_colors: bool = True, + log: bool = True) -> Tuple[int, str]: + """ + execute the command and return the exit code and output + """ # Merge stderr into stdout since we are going without a TTY - command = '({}) 2>&1'.format(command) - + command = '({}) 2>&1'.format(cast(str, command)) + if self.client is None: + raise TargetStableError("client is None") stdin, stdout, stderr = self._execute_command( command, as_root=as_root, @@ -779,36 +1243,87 @@ def _execute(self, command, timeout=None, as_root=False, strip_colors=True, log= # Empty the stdout buffer of the command, allowing it to carry on to # completion - def callback(output_chunks, name, chunk): + def callback(output_chunks: List[bytes], name: str, chunk: bytes) -> List[bytes]: + """ + callback for _read_paramiko_streams + """ output_chunks.append(chunk) return output_chunks - select_timeout = 1 + select_timeout: float = 1 output_chunks, exit_code = _read_paramiko_streams(stdout, stderr, select_timeout, callback, []) + if output_chunks is None: + raise TargetStableError("output_chunks is None") # Join in one go to avoid O(N^2) concatenation - output = b''.join(output_chunks) - output = output.decode(sys.stdout.encoding or 'utf-8', 'replace') + output_b = b''.join(output_chunks) + output = output_b.decode(sys.stdout.encoding or 'utf-8', 'replace') return (exit_code, output) class TelnetConnection(SshConnectionBase): + """ + A connection using the Telnet protocol. In practice, this implements minimal + features such as command execution, but leverages local scp if needed for file + transfers (since Telnet does not provide a built-in file transfer mechanism). + + .. note:: Since Telnet protocol is does not support file transfer, scp is + used for that purpose. + + :param host: SSH host to which to connect + :type host: str + :param username: username for SSH login + :type username: str + :param password: password for the SSH connection + :type password: str, optional + + .. note:: In order to user password-based authentication, + ``sshpass`` utility must be installed on the system. + + :param port: TCP port on which SSH server is listening on the remote device. + Omit to use the default port. + :type port: int, optional + :param timeout: Timeout for the connection in seconds. If a connection + cannot be established within this time, an error will be + raised. + :type timeout: int, optional + :param password_prompt: A string with the password prompt used by + ``sshpass``. Set this if your version of ``sshpass`` + uses something other than ``"[sudo] password"``. + :type password_prompt: str, optional + :param original_prompt: A regex for the shell prompted presented in the Telnet + connection (the prompt will be reset to a + randomly-generated pattern for the duration of the + connection to reduce the possibility of clashes). + This parameter is ignored for SSH connections. + :type original_prompt: str, optional + :param sudo_cmd: Template string for running commands with sudo privileges. + :type sudo_cmd: str + :param strict_host_check: Ignored for Telnet connections, included for interface consistency. + :type strict_host_check: bool or str or os.PathLike + :param platform: A devlib Platform describing hardware or OS features. + :type platform: Platform, optional + + :raises TargetNotRespondingError: If the Telnet server is not reachable. + :raises HostError: If local scp usage fails for file transfers. + :raises TargetStableError: If login fails or commands cannot be executed. + """ - default_password_prompt = '[sudo] password' - max_cancel_attempts = 5 + default_password_prompt: str = '[sudo] password' + max_cancel_attempts: int = 5 # pylint: disable=unused-argument,super-init-not-called def __init__(self, - host, - username, - password=None, - port=None, - timeout=None, - password_prompt=None, - original_prompt=None, - sudo_cmd="sudo -- sh -c {}", - strict_host_check=True, - platform=None): + host: str, + username: str, + password: Optional[str] = None, + port: Optional[int] = None, + timeout: Optional[int] = None, + password_prompt: Optional[str] = None, + original_prompt: Optional[str] = None, + sudo_cmd: str = "sudo -- sh -c {}", + strict_host_check: Union[bool, str, os.PathLike] = True, + platform: Optional['Platform'] = None): super().__init__( host=host, @@ -828,12 +1343,18 @@ def __init__(self, logger.debug('Logging in {}@{}'.format(username, host)) timeout = timeout if timeout is not None else self.default_timeout - self.conn = telnet_get_shell(host, username, password, port, timeout, original_prompt) + self.conn: Optional['TelnetPxssh'] = telnet_get_shell(host, username, password, port, timeout, original_prompt) - def fmt_remote_path(self, path): + def fmt_remote_path(self, path: str) -> str: + """ + format remote path + """ return '{}@{}:{}'.format(self.username, self.host, path) - def _get_default_options(self): + def _get_default_options(self) -> Dict[str, str]: + """ + get defaults for stricthostcheck and known hosts + """ check = self.strict_host_check known_hosts = _resolve_known_hosts(check) return { @@ -841,13 +1362,19 @@ def _get_default_options(self): 'UserKnownHostsFile': str(known_hosts), } - def push(self, sources, dest, timeout=30): + def push(self, sources: List[str], dest: str, timeout: int = 30) -> None: + """ + push files to device through the connection + """ # Quote the destination as SCP would apply globbing too dest = self.fmt_remote_path(quote(dest)) paths = list(sources) + [dest] return self._scp(paths, timeout) - def pull(self, sources, dest, timeout=30): + def pull(self, sources: str, dest: str, timeout=30): + """ + pull files from device + """ # First level of escaping for the remote shell sources = ' '.join(map(quote, sources)) # All the sources are merged into one scp parameter @@ -855,22 +1382,22 @@ def pull(self, sources, dest, timeout=30): paths = [sources, dest] self._scp(paths, timeout) - def _scp(self, paths, timeout=30): + def _scp(self, paths: List[str], timeout=30): # NOTE: the version of scp in Ubuntu 12.04 occasionally (and bizarrely) # fails to connect to a device if port is explicitly specified using -P # option, even if it is the default port, 22. To minimize this problem, # only specify -P for scp if the port is *not* the default. - port_string = '-P {}'.format(quote(str(self.port))) if (self.port and self.port != 22) else '' - keyfile_string = '-i {}'.format(quote(self.keyfile)) if self.keyfile else '' - options = " ".join(["-o {}={}".format(key, val) - for key, val in self.options.items()]) - paths = ' '.join(map(quote, paths)) - command = '{} {} -r {} {} {}'.format(_SSH_ENV.get_path('scp'), - options, - keyfile_string, - port_string, - paths) - command_redacted = command + port_string: str = '-P {}'.format(quote(str(self.port))) if (self.port and self.port != 22) else '' + keyfile_string: str = '-i {}'.format(quote(self.keyfile)) if self.keyfile else '' + options: str = " ".join(["-o {}={}".format(key, val) + for key, val in self.options.items()]) + paths_s: str = ' '.join(map(quote, paths)) + command: str = '{} {} -r {} {} {}'.format(_SSH_ENV.get_path('scp'), + options, + keyfile_string, + port_string, + paths_s) + command_redacted: str = command logger.debug(command) if self.password: command, command_redacted = _give_password(self.password, command) @@ -882,21 +1409,20 @@ def _scp(self, paths, timeout=30): except TimeoutError as e: raise TimeoutError(command_redacted, e.output) - - def execute(self, command, timeout=None, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False): #pylint: disable=unused-argument + def execute(self, command: 'SubprocessCommand', timeout: Optional[int] = None, check_exit_code: bool = True, + as_root: Optional[bool] = False, strip_colors: bool = True, will_succeed: bool = False) -> str: # pylint: disable=unused-argument if command == '': # Empty command is valid but the __devlib_ec stuff below will # produce a syntax error with bash. Treat as a special case. return '' try: with self.lock: - _command = '({}); __devlib_ec=$?; echo; echo $__devlib_ec'.format(command) + _command = '({}); __devlib_ec=$?; echo; echo $__devlib_ec'.format(cast(str, command)) full_output = self._execute_and_wait_for_prompt(_command, timeout, as_root, strip_colors) split_output = full_output.rsplit('\r\n', 2) try: output, exit_code_text, _ = split_output - except ValueError as e: + except ValueError: raise TargetStableError( "cannot split reply (target misconfiguration?):\n'{}'".format(full_output)) if check_exit_code: @@ -904,8 +1430,8 @@ def execute(self, command, timeout=None, check_exit_code=True, exit_code = int(exit_code_text) except (ValueError, IndexError): raise ValueError( - 'Could not get exit code for "{}",\ngot: "{}"'\ - .format(command, exit_code_text)) + 'Could not get exit code for "{}",\ngot: "{}"' + .format(cast(str, command), exit_code_text)) if exit_code: cls = TargetTransientCalledProcessError if will_succeed else TargetStableCalledProcessError raise cls( @@ -925,42 +1451,56 @@ def execute(self, command, timeout=None, check_exit_code=True, else: raise - def background(self, command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, as_root=False): + def background(self, command: 'SubprocessCommand', stdout: int = subprocess.PIPE, + stderr: int = subprocess.PIPE, as_root: Optional[bool] = False) -> 'Popen': try: port_string = '-p {}'.format(self.port) if self.port else '' keyfile_string = '-i {}'.format(self.keyfile) if self.keyfile else '' if as_root and not self.connected_as_root: - command = self.sudo_cmd.format(command) - options = " ".join([ "-o {}={}".format(key,val) - for key,val in self.options.items()]) - command = '{} {} {} {} {}@{} {}'.format(_SSH_ENV.get_path('ssh'), - options, - keyfile_string, - port_string, - self.username, - self.host, - command) + commandstr = self.sudo_cmd.format(command) + options = " ".join(["-o {}={}".format(key, val) + for key, val in self.options.items()]) + commandstr = '{} {} {} {} {}@{} {}'.format(_SSH_ENV.get_path('ssh'), + options, + keyfile_string, + port_string, + self.username, + self.host, + commandstr) logger.debug(command) if self.password: - command, _ = _give_password(self.password, command) + command, _ = _give_password(self.password, cast(str, command)) return subprocess.Popen(command, stdout=stdout, stderr=stderr, shell=True) except EOF: raise TargetNotRespondingError('Connection lost.') - def _close(self): + def _close(self) -> None: + """ + Close the connection to the device. The :class:`Connection` object should not + be used after this method is called. There is no way to reopen a previously + closed connection, a new connection object should be created instead. + """ logger.debug('Logging out {}@{}'.format(self.username, self.host)) try: - self.conn.logout() + if self.conn: + self.conn.logout() except: logger.debug('Connection lost.') - self.conn.close(force=True) + if self.conn: + self.conn.close(force=True) - def cancel_running_command(self): + def cancel_running_command(self) -> bool: + """ + Cancel a running command (previously started with :func:`background`) and free up the connection. + It is valid to call this if the command has already terminated (or if no + command was issued), in which case this is a no-op. + """ + # FIXME - other instances of cancel_running_command is just returning None. should this also be changed to do the same? # simulate impatiently hitting ^C until command prompt appears logger.debug('Sending ^C') for _ in range(self.max_cancel_attempts): self._sendline(chr(3)) - if self.conn.prompt(0.1): + if self.conn and self.conn.prompt(0.1): return True return False @@ -970,16 +1510,23 @@ def wait_for_device(self, timeout=30): def reboot_bootloader(self, timeout=30): raise NotImplementedError() - def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, strip_colors=True, log=True): + def _execute_and_wait_for_prompt(self, command: 'SubprocessCommand', timeout: Optional[int] = None, + as_root: Optional[bool] = False, strip_colors: bool = True, + log: bool = True) -> str: + """ + execute command and wait for prompt + """ + if not self.conn: + raise TargetStableError("conn is None") self.conn.prompt(0.1) # clear an existing prompt if there is one. if as_root and self.connected_as_root: # As we're already root, there is no need to use sudo. as_root = False if as_root: - command = self.sudo_cmd.format(quote(command)) + command = self.sudo_cmd.format(quote(cast(str, command))) if log: logger.debug(command) - self._sendline(command) + self._sendline(cast(str, command)) if self.password: index = self.conn.expect_exact([self.password_prompt, TIMEOUT], timeout=0.5) if index == 0: @@ -987,8 +1534,10 @@ def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, str else: # not as_root if log: logger.debug(command) - self._sendline(command) + self._sendline(cast(str, command)) timed_out = self._wait_for_prompt(timeout) + if self.conn.before is None: + raise TargetStableError("conn.before is None") output = process_backspaces(self.conn.before.decode(sys.stdout.encoding or 'utf-8', 'replace')) if timed_out: @@ -998,28 +1547,45 @@ def _execute_and_wait_for_prompt(self, command, timeout=None, as_root=False, str output = strip_bash_colors(output) return output - def _wait_for_prompt(self, timeout=None): - if timeout: - return not self.conn.prompt(timeout) - else: # cannot timeout; wait forever - while not self.conn.prompt(1): - pass + def _wait_for_prompt(self, timeout: Optional[int] = None) -> bool: + """ + wait for prompt + """ + if self.conn: + if timeout: + return not self.conn.prompt(timeout) + else: # cannot timeout; wait forever + while not self.conn.prompt(1): + pass + return False + else: return False - def _sendline(self, command): + def _sendline(self, command: str) -> None: + """ + send a line of string + """ # Workaround for https://github.com/pexpect/pexpect/issues/552 if len(command) == self._get_window_size()[1] - self._get_prompt_length(): command += ' ' - self.conn.sendline(command) + if self.conn: + self.conn.sendline(command) @memoized - def _get_prompt_length(self): + def _get_prompt_length(self) -> int: + """ + get the length of the prompt + """ + if not self.conn: + raise TargetStableError("conn is none") self.conn.sendline() self.conn.prompt() - return len(self.conn.after) + return len(cast(Sized, self.conn.after)) @memoized - def _get_window_size(self): + def _get_window_size(self) -> Tuple[int, int]: + if not self.conn: + raise TargetStableError("conn is none") return self.conn.getwinsize() @@ -1041,9 +1607,9 @@ def __init__(self, host_system = socket.gethostname() if host_system != host: raise TargetStableError("Gem5Connection can only connect to gem5 " - "simulations on your current host {}, which " - "differs from the one given {}!" - .format(host_system, host)) + "simulations on your current host {}, which " + "differs from the one given {}!" + .format(host_system, host)) if username is not None and username != 'root': raise ValueError('User should be root in gem5!') if password is not None and password != '': @@ -1060,7 +1626,7 @@ def __init__(self, if timeout is not None: if timeout > self.default_timeout: logger.info('Overwriting the default timeout of gem5 ({})' - ' to {}'.format(self.default_timeout, timeout)) + ' to {}'.format(self.default_timeout, timeout)) self.default_timeout = timeout else: logger.info('Ignoring the given timeout --> gem5 needs longer timeouts') @@ -1077,7 +1643,7 @@ def __init__(self, # Lock file to prevent multiple connections to same gem5 simulation # (gem5 does not allow this) self.lock_directory = '/tmp/' - self.lock_file_name = None # Will be set once connected to gem5 + self.lock_file_name = None # Will be set once connected to gem5 # These parameters will be set by either the method to connect to the # gem5 platform or directly to the gem5 simulation @@ -1134,7 +1700,7 @@ def push(self, sources, dest, timeout=None): self._gem5_shell("ls -al {}".format(quote(self.gem5_input_dir))) logger.debug("Push complete.") - def pull(self, sources, dest, timeout=0): #pylint: disable=unused-argument + def pull(self, sources, dest, timeout=0): # pylint: disable=unused-argument """ Pull a file from the gem5 device using m5 writefile @@ -1160,30 +1726,32 @@ def pull(self, sources, dest, timeout=0): #pylint: disable=unused-argument # error if the file was not where we expected it to be. if os.path.isabs(source): if os.path.dirname(source) != self.execute('pwd', - check_exit_code=False): + check_exit_code=False): self._gem5_shell("cat {} > {}".format(quote(filename), - quote(dest_file))) + quote(dest_file))) self._gem5_shell("sync") self._gem5_shell("ls -la {}".format(dest_file)) logger.debug('Finished the copy in the simulator') self._gem5_util("writefile {}".format(dest_file)) if 'cpu' not in filename: - while not os.path.exists(os.path.join(self.gem5_out_dir, - dest_file)): - time.sleep(1) + if self.gem5_out_dir: + while not os.path.exists(os.path.join(self.gem5_out_dir, + dest_file)): + time.sleep(1) # Perform the local move if os.path.exists(os.path.join(dest, dest_file)): logger.warning( - 'Destination file {} already exists!'\ - .format(dest_file)) + 'Destination file {} already exists!' + .format(dest_file)) else: - shutil.move(os.path.join(self.gem5_out_dir, dest_file), dest) + if self.gem5_out_dir: + shutil.move(os.path.join(self.gem5_out_dir, dest_file), dest) logger.debug("Pull complete.") def execute(self, command, timeout=1000, check_exit_code=True, - as_root=False, strip_colors=True, will_succeed=False): + as_root: Optional[bool] = False, strip_colors=True, will_succeed=False): """ Execute a command on the gem5 platform """ @@ -1213,7 +1781,7 @@ def background(self, command, stdout=subprocess.PIPE, self._check_ready() # Create the logfile for stderr/stdout redirection - command_name = command.split(' ')[0].split('/')[-1] + command_name = cast(str, command).split(' ')[0].split('/')[-1] redirection_file = 'BACKGROUND_{}.log'.format(command_name) trial = 0 while os.path.isfile(redirection_file): @@ -1242,27 +1810,31 @@ def _close(self): # the end of a simulation! self._unmount_virtio() self._gem5_util("exit") - self.gem5simulation.wait() + if self.gem5simulation: + self.gem5simulation.wait() except EOF: pass gem5_logger.info("Removing the temporary directory") try: - shutil.rmtree(self.gem5_interact_dir) + if self.gem5_interact_dir: + shutil.rmtree(self.gem5_interact_dir) except OSError: gem5_logger.warning("Failed to remove the temporary directory!") # Delete the lock file - os.remove(self.lock_file_name) + if self.lock_file_name: + os.remove(self.lock_file_name) def wait_for_device(self, timeout=30): """ Wait for Gem5 to be ready for interation with a timeout. """ - for _ in attempts(timeout): + # FIXME - attempts function not defined. not sure if it is a library function or this is the right intention + for _ in attempts(timeout): # type:ignore if self.ready: return time.sleep(1) - raise TimeoutError('Gem5 is not ready for interaction') + raise TimeoutError('Gem5 is not ready for interaction', '') def reboot_bootloader(self, timeout=30): raise NotImplementedError() @@ -1293,7 +1865,7 @@ def _gem5_EOF_handler(self, gem5_simulation, gem5_out_dir, err): # This function connects to the gem5 simulation # pylint: disable=too-many-statements def connect_gem5(self, port, gem5_simulation, gem5_interact_dir, - gem5_out_dir): + gem5_out_dir): """ Connect to the telnet port of the gem5 simulation. @@ -1314,8 +1886,8 @@ def connect_gem5(self, port, gem5_simulation, gem5_interact_dir, if os.path.isfile(lock_file_name): # There is already a connection to this gem5 simulation raise TargetStableError('There is already a connection to the gem5 ' - 'simulation using port {} on {}!' - .format(port, host)) + 'simulation using port {} on {}!' + .format(port, host)) # Connect to the gem5 telnet port. Use a short timeout here. attempts = 0 @@ -1338,7 +1910,7 @@ def connect_gem5(self, port, gem5_simulation, gem5_interact_dir, # Create the lock file self.lock_file_name = lock_file_name - open(self.lock_file_name, 'w').close() # Similar to touch + open(self.lock_file_name, 'w').close() # Similar to touch gem5_logger.info("Created lock file {} to prevent reconnecting to " "same simulation".format(self.lock_file_name)) @@ -1394,6 +1966,8 @@ def _login_to_device(self): def _find_prompt(self): prompt = r'\[PEXPECT\][\\\$\#]+ ' synced = False + if self.conn is None: + raise TargetStableError("Conn is None") while not synced: self.conn.send('\n') i = self.conn.expect([prompt, self.conn.UNIQUE_PROMPT, r'[\$\#] '], timeout=self.default_timeout) @@ -1403,10 +1977,11 @@ def _find_prompt(self): prompt = self.conn.UNIQUE_PROMPT synced = True else: - prompt = re.sub(r'\$', r'\\\$', self.conn.before.strip() + self.conn.after.strip()) - prompt = re.sub(r'\#', r'\\\#', prompt) - prompt = re.sub(r'\[', r'\[', prompt) - prompt = re.sub(r'\]', r'\]', prompt) + if self.conn.before and self.conn.after: + prompt = re.sub(r'\$', r'\\\$', self.conn.before.strip() + cast(bytes, self.conn.after).strip()) + prompt = re.sub(r'\#', r'\\\#', prompt) + prompt = re.sub(r'\[', r'\[', prompt) + prompt = re.sub(r'\]', r'\]', prompt) self.conn.PROMPT = prompt @@ -1419,10 +1994,11 @@ def _sync_gem5_shell(self): both of these. """ gem5_logger.debug("Sending Sync") - self.conn.send("echo \\*\\*sync\\*\\*\n") - self.conn.expect(r"\*\*sync\*\*", timeout=self.default_timeout) - self.conn.expect([self.conn.UNIQUE_PROMPT, self.conn.PROMPT], timeout=self.default_timeout) - self.conn.expect([self.conn.UNIQUE_PROMPT, self.conn.PROMPT], timeout=self.default_timeout) + if self.conn: + self.conn.send("echo \\*\\*sync\\*\\*\n") + self.conn.expect(r"\*\*sync\*\*", timeout=self.default_timeout) + self.conn.expect([self.conn.UNIQUE_PROMPT, self.conn.PROMPT], timeout=self.default_timeout) + self.conn.expect([self.conn.UNIQUE_PROMPT, self.conn.PROMPT], timeout=self.default_timeout) def _gem5_util(self, command): """ Execute a gem5 utility command using the m5 binary on the device """ @@ -1430,7 +2006,8 @@ def _gem5_util(self, command): raise TargetStableError('Path to m5 binary on simulated system is not set!') self._gem5_shell('{} {}'.format(self.m5_path, command)) - def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True, sync=True, will_succeed=False): # pylint: disable=R0912 + def _gem5_shell(self, command, as_root: Optional[bool] = False, + timeout=None, check_exit_code=True, sync=True, will_succeed=False): # pylint: disable=R0912 """ Execute a command in the gem5 shell @@ -1450,7 +2027,8 @@ def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True if as_root: command = 'echo {} | su'.format(quote(command)) - + if self.conn is None: + raise TargetStableError("Conn is None") # Send the actual command self.conn.send("{}\n".format(command)) @@ -1460,7 +2038,7 @@ def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True command_index = -1 while command_index == -1: if self.conn.prompt(): - output = re.sub(r' \r([^\n])', r'\1', self.conn.before) + output = re.sub(r' \r([^\n])', r'\1', self.conn.before or '') output = re.sub(r'[\b]', r'', output) # Deal with line wrapping output = re.sub(r'[\r].+?<', r'', output) @@ -1471,7 +2049,7 @@ def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True # warn, and return the whole output. if command_index == -1: gem5_logger.warning("gem5_shell: Unable to match command in " - "command output. Expect parsing errors!") + "command output. Expect parsing errors!") command_index = 0 output = output[command_index + len(command):].strip() @@ -1491,15 +2069,15 @@ def _gem5_shell(self, command, as_root=False, timeout=None, check_exit_code=True if check_exit_code: exit_code_text = self._gem5_shell('echo $?', as_root=as_root, - timeout=timeout, check_exit_code=False, - sync=False) + timeout=timeout, check_exit_code=False, + sync=False) try: exit_code = int(exit_code_text.split()[0]) except (ValueError, IndexError): raise ValueError('Could not get exit code for "{}",\ngot: "{}"'.format(command, exit_code_text)) else: if exit_code: - cls = TragetTransientCalledProcessError if will_succeed else TargetStableCalledProcessError + cls = TargetTransientCalledProcessError if will_succeed else TargetStableCalledProcessError raise cls( exit_code, command, @@ -1579,20 +2157,21 @@ def _login_to_device(self): gem5_logger.info("Trying to log in to gem5 device") login_prompt = ['login:', 'AEL login:', 'username:', 'aarch64-gem5 login:'] login_password_prompt = ['password:'] + if self.conn is None: + raise TargetStableError("Conn is None") # Wait for the login prompt prompt = login_prompt + [self.conn.UNIQUE_PROMPT] - i = self.conn.expect(prompt, timeout=10) + i = self.conn.expect(cast(Pattern[str], prompt), timeout=10) # Check if we are already at a prompt, or if we need to log in. if i < len(prompt) - 1: self.conn.sendline("{}".format(self.username)) password_prompt = login_password_prompt + [r'# ', self.conn.UNIQUE_PROMPT] - j = self.conn.expect(password_prompt, timeout=self.default_timeout) + j = self.conn.expect(cast(Pattern[str], password_prompt), timeout=self.default_timeout) if j < len(password_prompt) - 2: self.conn.sendline("{}".format(self.password)) self.conn.expect([r'# ', self.conn.UNIQUE_PROMPT], timeout=self.default_timeout) - class AndroidGem5Connection(Gem5Connection): def _wait_for_boot(self): @@ -1622,7 +2201,20 @@ def _wait_for_boot(self): gem5_logger.info("Android booted") -def _give_password(password, command): +def _give_password(password: str, command: str) -> Tuple[str, str]: + """ + Insert a password into an ``sshpass``-based command to allow non-interactive + authentication. + + :param password: The password to embed in the command. + :type password: str + :param command: The original shell command that invokes ``sshpass``. + :type command: str + :returns: A tuple of (modified_command, redacted_command). The first string is + safe to execute, while the second omits the password for logging. + :rtype: (str, str) + :raises ValueError: If the command cannot be adjusted or if ``sshpass`` is unavailable. + """ sshpass = _SSH_ENV.get_path('sshpass') if sshpass: pass_template = "{} -p {} " @@ -1633,8 +2225,11 @@ def _give_password(password, command): raise HostError('Must have sshpass installed on the host in order to use password-based auth.') -def process_backspaces(text): - chars = [] +def process_backspaces(text: str) -> str: + """ + process backspace in the command + """ + chars: List[str] = [] for c in text: if c == chr(8) and chars: # backspace chars.pop() diff --git a/devlib/utils/types.py b/devlib/utils/types.py index d7c8864b0..55f0fe5a0 100644 --- a/devlib/utils/types.py +++ b/devlib/utils/types.py @@ -1,4 +1,4 @@ -# Copyright 2014-2018 ARM Limited +# Copyright 2014-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -30,9 +30,9 @@ import sys from functools import total_ordering -from past.builtins import basestring from devlib.utils.misc import isiterable, to_identifier, ranges_to_list, list_to_mask +from typing import List, Union def identifier(text): @@ -49,7 +49,7 @@ def boolean(value): """ false_strings = ['', '0', 'n', 'no', 'off'] - if isinstance(value, basestring): + if isinstance(value, str): value = value.lower() if value in false_strings or 'false'.startswith(value): return False @@ -58,7 +58,7 @@ def boolean(value): def integer(value): """Handles conversions for string respresentations of binary, octal and hex.""" - if isinstance(value, basestring): + if isinstance(value, str): return int(value, 0) else: return int(value) @@ -74,7 +74,7 @@ def numeric(value): if isinstance(value, int): return value - if isinstance(value, basestring): + if isinstance(value, str): value = value.strip() if value.endswith('%'): try: @@ -102,17 +102,17 @@ class caseless_string(str): """ def __eq__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): other = other.lower() return self.lower() == other def __ne__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): other = other.lower() return self.lower() != other def __lt__(self, other): - if isinstance(other, basestring): + if isinstance(other, str): other = other.lower() return self.lower() < other @@ -123,11 +123,14 @@ def format(self, *args, **kwargs): return caseless_string(super(caseless_string, self).format(*args, **kwargs)) -def bitmask(value): - if isinstance(value, basestring): +def bitmask(value: Union[int, List[int], str]) -> int: + if isinstance(value, str): value = ranges_to_list(value) if isiterable(value): - value = list_to_mask(value) + if isinstance(value, list): + value = list_to_mask(value) + else: + raise TypeError("Expected a list of integers") if not isinstance(value, int): raise ValueError(value) return value diff --git a/devlib/utils/version.py b/devlib/utils/version.py index ec6a3f1c5..cb07c267e 100644 --- a/devlib/utils/version.py +++ b/devlib/utils/version.py @@ -1,4 +1,4 @@ -# Copyright 2018 ARM Limited +# Copyright 2018-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,16 +15,21 @@ import os import sys -from collections import namedtuple from subprocess import Popen, PIPE +from typing import NamedTuple, Optional -VersionTuple = namedtuple('Version', ['major', 'minor', 'revision', 'dev']) +class Version(NamedTuple): + major: int + minor: int + revision: int + dev: str -version = VersionTuple(1, 4, 0, 'dev2') +version = Version(1, 4, 0, 'dev2') -def get_devlib_version(): + +def get_devlib_version() -> str: version_string = '{}.{}.{}'.format( version.major, version.minor, version.revision) if version.dev: @@ -32,7 +37,7 @@ def get_devlib_version(): return version_string -def get_commit(): +def get_commit() -> Optional[str]: try: p = Popen(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__), stdout=PIPE, stderr=PIPE) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 000000000..4077efd50 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,6 @@ +[mypy] +ignore_missing_imports = True +python_version = 3.10 + +[mypy-numpy.*] +ignore_errors = True \ No newline at end of file diff --git a/py.typed b/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/setup.py b/setup.py index cba25a26b..9686c6233 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2013-2015 ARM Limited +# Copyright 2013-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -97,17 +97,18 @@ def _load_path(filepath): 'python-dateutil', # converting between UTC and local time. 'pexpect>=3.3', # Send/recieve to/from device 'pyserial', # Serial port interface - 'paramiko', # SSH connection - 'scp', # SSH connection file transfers + 'paramiko', # SSH connection + 'scp', # SSH connection file transfers 'wrapt', # Basic for construction of decorator functions 'numpy', 'pandas', 'pytest', - 'lxml', # More robust xml parsing - 'nest_asyncio', # Allows running nested asyncio loops - 'greenlet', # Allows running nested asyncio loops - 'future', # for the "past" Python package - 'ruamel.yaml >= 0.15.72', # YAML formatted config parsing + 'lxml', # More robust xml parsing + 'nest_asyncio', # Allows running nested asyncio loops + 'greenlet', # Allows running nested asyncio loops + 'future', # for the "past" Python package + 'ruamel.yaml >= 0.15.72', # YAML formatted config parsing + 'typing_extensions' ], extras_require={ 'daq': ['daqpower>=2'], @@ -115,7 +116,7 @@ def _load_path(filepath): 'monsoon': ['python-gflags'], 'acme': ['pandas', 'numpy'], 'dev': [ - 'uvloop', # Test async features under uvloop + 'uvloop', # Test async features under uvloop ] }, # https://pypi.python.org/pypi?%3Aaction=list_classifiers @@ -142,7 +143,6 @@ def initialize_options(self): orig_sdist.initialize_options(self) self.strip_commit = False - def run(self): if self.strip_commit: self.distribution.get_version = lambda : __version__.split('+')[0] diff --git a/tests/test_target.py b/tests/test_target.py index 2d811321f..c576a83e5 100644 --- a/tests/test_target.py +++ b/tests/test_target.py @@ -1,5 +1,5 @@ # -# Copyright 2024 ARM Limited +# Copyright 2024-2025 ARM Limited # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +25,8 @@ import logging import os +from typing import Optional + import pytest from devlib import AndroidTarget, ChromeOsTarget, LinuxTarget, LocalLinuxTarget @@ -36,7 +38,7 @@ logger = logging.getLogger('test_target') -def get_class_object(name): +def get_class_object(name: str) -> Optional[object]: """ Get associated class object from string formatted class name