Skip to content

Commit

Permalink
feat(core): enforce return type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
matejcik committed Nov 12, 2024
1 parent 34d97ee commit 8fb41ee
Show file tree
Hide file tree
Showing 34 changed files with 110 additions and 71 deletions.
2 changes: 1 addition & 1 deletion core/src/apps/benchmark/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

# This is a wrapper above the trezor.crypto.curve.ed25519 module that satisfies SignCurve protocol, the modules uses `message` instead of `digest` in `sign()` and `verify()`
class Ed25519:
def __init__(self):
def __init__(self) -> None:
pass

def generate_secret(self) -> bytes:
Expand Down
12 changes: 6 additions & 6 deletions core/src/apps/benchmark/cipher_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ def decrypt(self, data: bytes) -> bytes: ...
class EncryptBenchmark:
def __init__(
self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int
):
) -> None:
self.cipher_ctx_constructor = cipher_ctx_constructor
self.block_size = block_size

def prepare(self):
def prepare(self) -> None:
self.cipher_ctx = self.cipher_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.block_size)

def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.cipher_ctx.encrypt(self.data)

Expand All @@ -44,17 +44,17 @@ def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:
class DecryptBenchmark:
def __init__(
self, cipher_ctx_constructor: Callable[[], CipherCtx], block_size: int
):
) -> None:
self.cipher_ctx_constructor = cipher_ctx_constructor
self.block_size = block_size

def prepare(self):
def prepare(self) -> None:
self.cipher_ctx = self.cipher_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.block_size)

def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.cipher_ctx.decrypt(self.data)

Expand Down
24 changes: 12 additions & 12 deletions core/src/apps/benchmark/curve_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ def multiply(self, secret_key: bytes, public_key: bytes) -> bytes: ...


class SignBenchmark:
def __init__(self, curve: SignCurve):
def __init__(self, curve: SignCurve) -> None:
self.curve = curve

def prepare(self):
def prepare(self) -> None:
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()
self.digest = random_bytes(32)

def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.curve.sign(self.secret_key, self.digest)

Expand All @@ -51,17 +51,17 @@ def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:


class VerifyBenchmark:
def __init__(self, curve: SignCurve):
def __init__(self, curve: SignCurve) -> None:
self.curve = curve

def prepare(self):
def prepare(self) -> None:
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()
self.public_key = self.curve.publickey(self.secret_key)
self.digest = random_bytes(32)
self.signature = self.curve.sign(self.secret_key, self.digest)

def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.curve.verify(self.public_key, self.signature, self.digest)

Expand All @@ -72,15 +72,15 @@ def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:


class MultiplyBenchmark:
def __init__(self, curve: MultiplyCurve):
def __init__(self, curve: MultiplyCurve) -> None:
self.curve = curve

def prepare(self):
def prepare(self) -> None:
self.secret_key = self.curve.generate_secret()
self.public_key = self.curve.publickey(self.curve.generate_secret())
self.iterations_count = 10

def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.curve.multiply(self.secret_key, self.public_key)

Expand All @@ -91,14 +91,14 @@ def get_result(self, duration_us: int, repetitions: int) -> BenchmarkResult:


class PublickeyBenchmark:
def __init__(self, curve: Curve):
def __init__(self, curve: Curve) -> None:
self.curve = curve

def prepare(self):
def prepare(self) -> None:
self.iterations_count = 10
self.secret_key = self.curve.generate_secret()

def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.curve.publickey(self.secret_key)

Expand Down
6 changes: 3 additions & 3 deletions core/src/apps/benchmark/hash_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ def update(self, __buf: bytes) -> None: ...


class HashBenchmark:
def __init__(self, hash_ctx_constructor: Callable[[], HashCtx]):
def __init__(self, hash_ctx_constructor: Callable[[], HashCtx]) -> None:
self.hash_ctx_constructor = hash_ctx_constructor

def prepare(self):
def prepare(self) -> None:
self.hash_ctx = self.hash_ctx_constructor()
self.blocks_count = maximum_used_memory_in_bytes // self.hash_ctx.block_size
self.iterations_count = 100
self.data = random_bytes(self.blocks_count * self.hash_ctx.block_size)

def run(self):
def run(self) -> None:
for _ in range(self.iterations_count):
self.hash_ctx.update(self.data)

Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/bitcoin/keychain.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def __init__(
require_bech32: bool,
require_taproot: bool,
account_level: bool = False,
):
) -> None:
self.account_name = account_name
self.pattern = pattern
self.script_type = script_type
Expand Down
26 changes: 14 additions & 12 deletions core/src/apps/bitcoin/sign_tx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
output_index: int,
chunkify: bool,
address_n: Bip32Path | None,
):
) -> None:
self.output = output
self.coin = coin
self.amount_unit = amount_unit
Expand All @@ -66,7 +66,9 @@ def confirm_dialog(self) -> Awaitable[Any]:


class UiConfirmDecredSSTXSubmission(UiConfirm):
def __init__(self, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit):
def __init__(
self, output: TxOutput, coin: CoinInfo, amount_unit: AmountUnit
) -> None:
self.output = output
self.coin = coin
self.amount_unit = amount_unit
Expand All @@ -83,7 +85,7 @@ def __init__(
payment_req: TxAckPaymentRequest,
coin: CoinInfo,
amount_unit: AmountUnit,
):
) -> None:
self.payment_req = payment_req
self.amount_unit = amount_unit
self.coin = coin
Expand All @@ -97,7 +99,7 @@ def confirm_dialog(self) -> Awaitable[bool]:


class UiConfirmReplacement(UiConfirm):
def __init__(self, title: str, txid: bytes):
def __init__(self, title: str, txid: bytes) -> None:
self.title = title
self.txid = txid

Expand All @@ -112,7 +114,7 @@ def __init__(
orig_txo: TxOutput,
coin: CoinInfo,
amount_unit: AmountUnit,
):
) -> None:
self.txo = txo
self.orig_txo = orig_txo
self.coin = coin
Expand All @@ -133,7 +135,7 @@ def __init__(
fee_rate: float,
coin: CoinInfo,
amount_unit: AmountUnit,
):
) -> None:
self.title = title
self.user_fee_change = user_fee_change
self.total_fee_new = total_fee_new
Expand Down Expand Up @@ -161,7 +163,7 @@ def __init__(
coin: CoinInfo,
amount_unit: AmountUnit,
address_n: Bip32Path | None,
):
) -> None:
self.spending = spending
self.fee = fee
self.fee_rate = fee_rate
Expand All @@ -183,7 +185,7 @@ def confirm_dialog(self) -> Awaitable[Any]:
class UiConfirmJointTotal(UiConfirm):
def __init__(
self, spending: int, total: int, coin: CoinInfo, amount_unit: AmountUnit
):
) -> None:
self.spending = spending
self.total = total
self.coin = coin
Expand All @@ -196,7 +198,7 @@ def confirm_dialog(self) -> Awaitable[Any]:


class UiConfirmFeeOverThreshold(UiConfirm):
def __init__(self, fee: int, coin: CoinInfo, amount_unit: AmountUnit):
def __init__(self, fee: int, coin: CoinInfo, amount_unit: AmountUnit) -> None:
self.fee = fee
self.coin = coin
self.amount_unit = amount_unit
Expand All @@ -206,7 +208,7 @@ def confirm_dialog(self) -> Awaitable[Any]:


class UiConfirmChangeCountOverThreshold(UiConfirm):
def __init__(self, change_count: int):
def __init__(self, change_count: int) -> None:
self.change_count = change_count

def confirm_dialog(self) -> Awaitable[Any]:
Expand All @@ -219,7 +221,7 @@ def confirm_dialog(self) -> Awaitable[Any]:


class UiConfirmForeignAddress(UiConfirm):
def __init__(self, address_n: list):
def __init__(self, address_n: list) -> None:
self.address_n = address_n

def confirm_dialog(self) -> Awaitable[Any]:
Expand All @@ -229,7 +231,7 @@ def confirm_dialog(self) -> Awaitable[Any]:


class UiConfirmNonDefaultLocktime(UiConfirm):
def __init__(self, lock_time: int, lock_time_disabled: bool):
def __init__(self, lock_time: int, lock_time_disabled: bool) -> None:
self.lock_time = lock_time
self.lock_time_disabled = lock_time_disabled

Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/bitcoin/sign_tx/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


class Progress:
def __init__(self):
def __init__(self) -> None:
self.progress = 0
self.steps = 0
self.signing = False
Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/bitcoin/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(
script_sig: bytes | None,
witness: bytes | None,
coin: CoinInfo,
):
) -> None:
from trezor import utils
from trezor.crypto.hashlib import sha256
from trezor.wire import DataError # local_cache_global
Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/cardano/helpers/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
key_hash: bytes | None,
script_hash: bytes | None,
pointer: messages.CardanoBlockchainPointerType | None,
):
) -> None:
self.type_name = type_name
self.address_type = address_type
self.path = path
Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/cardano/helpers/hash_builder_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
key_order_error: wire.ProcessError
previous_encoded_key: bytes

def __init__(self, size: int, key_order_error: wire.ProcessError):
def __init__(self, size: int, key_order_error: wire.ProcessError) -> None:
super().__init__(size)
self.key_order_error = key_order_error
self.previous_encoded_key = b""
Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/cardano/sign_tx/ordinary_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
self,
msg: messages.CardanoSignTxInit,
keychain: seed.Keychain,
):
) -> None:
super().__init__(msg, keychain)
self.suite_tx_type: SuiteTxType = self._suite_tx_type()

Expand Down
4 changes: 2 additions & 2 deletions core/src/apps/common/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ def repeated_backup_enabled() -> bool:
return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)


def activate_repeated_backup():
def activate_repeated_backup() -> None:
storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
wire.filters.append(_repeated_backup_filter)


def deactivate_repeated_backup():
def deactivate_repeated_backup() -> None:
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
wire.remove_filter(_repeated_backup_filter)

Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/common/cbor.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __eq__(self, other: object) -> bool:

# TODO: this seems to be unused - is checked against, but is never created???
class Raw:
def __init__(self, value: Value):
def __init__(self, value: Value) -> None:
self.value = value


Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/ethereum/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def require_confirm_claim(
)


async def require_confirm_unknown_token(address_bytes: bytes):
async def require_confirm_unknown_token(address_bytes: bytes) -> None:
from ubinascii import hexlify

from trezor.ui.layouts import confirm_address, show_warning
Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/management/reset_device/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _get_slip39_mnemonics(
group_threshold: int,
groups: Sequence[tuple[int, int]],
extendable: bool,
):
) -> list[list[str]]:
if extendable:
identifier = slip39.generate_random_identifier()
else:
Expand Down
2 changes: 1 addition & 1 deletion core/src/apps/monero/signing/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self) -> None:
from apps.monero.xmr.mlsag_hasher import PreMlsagHasher

# Account credentials
# type: AccountCreds
# - type: AccountCreds
# - view private/public key
# - spend private/public key
# - and its corresponding address
Expand Down
8 changes: 6 additions & 2 deletions core/src/apps/solana/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from .types import AddressType

if TYPE_CHECKING:
from typing import Sequence

from .transaction.instructions import Instruction, SystemProgramTransferInstruction
from .types import AddressReference

Expand All @@ -33,7 +35,9 @@ def _format_path(path: list[int]) -> str:
return f"Solana #{unharden(account_index) + 1}"


def _get_address_reference_props(address: AddressReference, display_name: str):
def _get_address_reference_props(
address: AddressReference, display_name: str
) -> Sequence[tuple[str, str]]:
return (
(TR.solana__is_provided_via_lookup_table_template.format(display_name), ""),
(f"{TR.solana__lookup_table_address}:", base58.encode(address[0])),
Expand Down Expand Up @@ -293,7 +297,7 @@ async def confirm_token_transfer(
fee: int,
signer_path: list[int],
blockhash: bytes,
):
) -> None:
await confirm_value(
title=TR.words__recipient,
value=base58.encode(destination_account),
Expand Down
Loading

0 comments on commit 8fb41ee

Please sign in to comment.