From 243693bae57b74f1517aaf4cdf792dd57a4cbccc Mon Sep 17 00:00:00 2001 From: Alex Wijnholds Date: Sun, 26 Jan 2025 15:51:46 +0100 Subject: [PATCH] No need to pass the member id when we can retrieve it ourselves --- custom_components/sat/manufacturer.py | 12 ++++++------ tests/test_manufacturer.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/custom_components/sat/manufacturer.py b/custom_components/sat/manufacturer.py index 904116be..7bdac462 100644 --- a/custom_components/sat/manufacturer.py +++ b/custom_components/sat/manufacturer.py @@ -26,8 +26,8 @@ class Manufacturer(ABC): - def __init__(self, member_id: int): - self._member_id = member_id + def __init__(self): + self._member_id = MANUFACTURERS.get(type(self).__name__) @property def member_id(self) -> int: @@ -46,15 +46,15 @@ def resolve_by_name(name: str) -> Optional[Manufacturer]: if not (member_id := MANUFACTURERS.get(name)): return None - return ManufacturerFactory._import_class(snake_case(name), name)(member_id) + return ManufacturerFactory._import_class(snake_case(name), name)() @staticmethod def resolve_by_member_id(member_id: int) -> List[Manufacturer]: """Resolve a list of Manufacturer instances by member ID.""" return [ - ManufacturerFactory._import_class(snake_case(name), name)(identifier) - for name, identifier in MANUFACTURERS.items() - if member_id == identifier + ManufacturerFactory._import_class(snake_case(name), name)() + for name, value in MANUFACTURERS.items() + if member_id == value ] @staticmethod diff --git a/tests/test_manufacturer.py b/tests/test_manufacturer.py index a295aeac..f1689d0a 100644 --- a/tests/test_manufacturer.py +++ b/tests/test_manufacturer.py @@ -25,7 +25,8 @@ def test_resolve_by_member_id(): assert len(manufacturers) == len(names), f"Expected {len(names)} manufacturers for member ID {member_id}" for manufacturer in manufacturers: - assert manufacturer.__class__.__name__ in names, f"Manufacturer name '{manufacturer.name}' not expected for member ID {member_id}" + assert manufacturer.member_id == member_id, f"Expected {manufacturer.member_id} for member ID {member_id}" + assert manufacturer.__class__.__name__ in names, f"Manufacturer name '{manufacturer.friendly_name}' not expected for member ID {member_id}" # Test invalid member ID manufacturers = ManufacturerFactory.resolve_by_member_id(999)