Skip to content

Commit

Permalink
chore: lint codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Brooke-white committed Nov 20, 2023
1 parent 6fbf536 commit a19417e
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 47 deletions.
37 changes: 17 additions & 20 deletions redshift_connector/interval.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import typing
from redshift_connector.config import max_int4, max_int8, min_int4, min_int8
from datetime import timedelta as Timedelta

from redshift_connector.config import max_int4, max_int8, min_int4, min_int8


class Interval:
"""An Interval represents a measurement of time. In Amazon Redshift, an
Expand Down Expand Up @@ -90,6 +91,7 @@ def total_seconds(self: "Interval") -> float:
"""Total seconds in the Interval, excluding month field."""
return ((self.days * 86400) * 10**6 + self.microseconds) / 10**6


class IntervalYearToMonth(Interval):
"""An Interval Year To Month represents a measurement of time of the order
of a few months and years. Note the difference with Interval which can
Expand All @@ -99,9 +101,10 @@ class IntervalYearToMonth(Interval):
Note that 1year = 12months.
"""
def __init__(self: "IntervalYearToMonth",
months: int = 0,
year_month: typing.Tuple[int, int] = None) -> None:

def __init__(
self: "IntervalYearToMonth", months: int = 0, year_month: typing.Optional[typing.Tuple[int, int]] = None
) -> None:
if year_month is not None:
year, month = year_month
self.months = year * 12 + month
Expand All @@ -121,7 +124,7 @@ def _setMonths(self: "IntervalYearToMonth", value: int) -> None:
# days = property(lambda self: self._days, _setDays)
months = property(lambda self: self._months, _setMonths)

def getYearMonth(self: "IntervalDayToSecond") -> typing.Tuple[int, int]:
def getYearMonth(self: "IntervalYearToMonth") -> typing.Tuple[int, int]:
years = int(self.months / 12)
months = self.months - 12 * years
return (years, months)
Expand All @@ -130,15 +133,12 @@ def __repr__(self: "IntervalYearToMonth") -> str:
return "<IntervalYearToMonth %s months>" % (self.months)

def __eq__(self: "IntervalYearToMonth", other: object) -> bool:
return (
other is not None
and isinstance(other, IntervalYearToMonth)
and self.months == other.months
)
return other is not None and isinstance(other, IntervalYearToMonth) and self.months == other.months

def __neq__(self: "IntervalYearToMonth", other: "IntervalYearToMonth") -> bool:
def __neq__(self: "IntervalYearToMonth", other: "Interval") -> bool:
return not self.__eq__(other)


class IntervalDayToSecond(Interval):
"""An Interval Day To Second represents a measurement of time of the order
of a few microseconds. Note the difference with Interval which can
Expand All @@ -148,9 +148,10 @@ class IntervalDayToSecond(Interval):
Note that 1day = 24 * 3600 * 1000000 microseconds.
"""
def __init__(self: "IntervalDayToSecond",
microseconds: int = 0,
timedelta: Timedelta = None) -> None:

def __init__(
self: "IntervalDayToSecond", microseconds: int = 0, timedelta: typing.Optional[Timedelta] = None
) -> None:
if timedelta is not None:
self.microseconds = int(timedelta.total_seconds() * (10**6))
else:
Expand All @@ -173,13 +174,9 @@ def __repr__(self: "IntervalDayToSecond") -> str:
return "<IntervalDayToSecond %s microseconds>" % (self.microseconds)

def __eq__(self: "IntervalDayToSecond", other: object) -> bool:
return (
other is not None
and isinstance(other, IntervalDayToSecond)
and self.microseconds == other.microseconds
)
return other is not None and isinstance(other, IntervalDayToSecond) and self.microseconds == other.microseconds

def __neq__(self: "IntervalDayToSecond", other: "IntervalDayToSecond") -> bool:
def __neq__(self: "IntervalDayToSecond", other: "Interval") -> bool:
return not self.__eq__(other)

def getTimedelta(self: "IntervalDayToSecond") -> Timedelta:
Expand Down
15 changes: 12 additions & 3 deletions redshift_connector/utils/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
_client_encoding,
timegm,
)
from redshift_connector.interval import Interval, IntervalYearToMonth, IntervalDayToSecond
from redshift_connector.interval import (
Interval,
IntervalDayToSecond,
IntervalYearToMonth,
)
from redshift_connector.pg_types import (
PGEnum,
PGJson,
Expand Down Expand Up @@ -205,11 +209,13 @@ def interval_send_integer(v: typing.Union[Timedelta, Interval]) -> bytes:

return typing.cast(bytes, qhh_pack(microseconds, 0, months))


def intervaly2m_send_integer(v: IntervalYearToMonth) -> bytes:
months = v.months # type: ignore

return typing.cast(bytes, i_pack(months))


def intervald2s_send_integer(v: IntervalDayToSecond) -> bytes:
microseconds = v.microseconds # type: ignore

Expand Down Expand Up @@ -284,14 +290,17 @@ def interval_recv_integer(data: bytes, offset: int, length: int) -> typing.Union
else:
return Timedelta(days, seconds, micros)


def intervaly2m_recv_integer(data: bytes, offset: int, length: int) -> IntervalYearToMonth:
months, = typing.cast(typing.Tuple[int], i_unpack(data, offset))
(months,) = typing.cast(typing.Tuple[int], i_unpack(data, offset))
return IntervalYearToMonth(months)


def intervald2s_recv_integer(data: bytes, offset: int, length: int) -> IntervalDayToSecond:
microseconds, = typing.cast(typing.Tuple[int], q_unpack(data, offset))
(microseconds,) = typing.cast(typing.Tuple[int], q_unpack(data, offset))
return IntervalDayToSecond(microseconds)


def timetz_recv_binary(data: bytes, offset: int, length: int) -> time:
return time_recv_binary(data, offset, length).replace(tzinfo=Timezone.utc)

Expand Down
1 change: 1 addition & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def serverless_cname_db_kwargs() -> typing.Dict[str, typing.Union[str, bool]]:

return db_connect


@pytest.fixture(scope="class")
def ds_consumer_db_kwargs() -> typing.Dict[str, str]:
db_connect = {
Expand Down
23 changes: 17 additions & 6 deletions test/integration/datatype/_generate_test_datatype_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import typing
from datetime import date, datetime, time, timezone
from decimal import Decimal
from redshift_connector.interval import IntervalYearToMonth, IntervalDayToSecond
from enum import Enum, auto

from redshift_connector.interval import IntervalDayToSecond, IntervalYearToMonth

if typing.TYPE_CHECKING:
from redshift_connector import Connection

Expand Down Expand Up @@ -36,7 +37,17 @@ def list(cls) -> typing.List["RedshiftDatatypes"]:


redshift_test_data: typing.Dict[
str, typing.Union[typing.Tuple[typing.Tuple[str, str], ...], typing.List[typing.Tuple[str, ...]]]
str,
typing.Union[
typing.Tuple[typing.Tuple[str, str], ...],
typing.List[
typing.Union[
typing.Tuple[str, ...],
typing.Tuple[str, IntervalYearToMonth, str],
typing.Tuple[str, IntervalDayToSecond, str],
]
],
],
] = {
RedshiftDatatypes.geometry.name: (
(
Expand Down Expand Up @@ -161,14 +172,14 @@ def list(cls) -> typing.List["RedshiftDatatypes"]:
RedshiftDatatypes.intervaly2m.name: [
("37 months", IntervalYearToMonth(37), "y2m_postgres_format"),
("1-1", IntervalYearToMonth(13), "y2m_sql_standard_format"),
("-178956970-8", IntervalYearToMonth(-2**31), "y2m_min_value"),
("178956970-7", IntervalYearToMonth(2**31 - 1), "y2m_max_value")
("-178956970-8", IntervalYearToMonth(-(2**31)), "y2m_min_value"),
("178956970-7", IntervalYearToMonth(2**31 - 1), "y2m_max_value"),
],
RedshiftDatatypes.intervald2s.name: [
("10 days 48 hours", IntervalDayToSecond(12 * 86400 * 1000000), "d2s_postgres_format"),
("10 23:59:59.999999", IntervalDayToSecond(11 * 86400 * 1000000 - 1), "d2s_sql_standard_format"),
("-106751991 -04:00:54.775808", IntervalDayToSecond(-2**63), "d2s_min_value"),
("106751991 04:00:54.775807", IntervalDayToSecond(2**63 - 1), "d2s_max_value")
("-106751991 -04:00:54.775808", IntervalDayToSecond(-(2**63)), "d2s_min_value"),
("106751991 04:00:54.775807", IntervalDayToSecond(2**63 - 1), "d2s_max_value"),
]
# TODO: re-enable
# RedshiftDatatypes.geography.name: (
Expand Down
27 changes: 14 additions & 13 deletions test/integration/datatype/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import typing
from datetime import datetime as Datetime
from datetime import timezone
from redshift_connector.interval import IntervalYearToMonth, IntervalDayToSecond
from math import isclose
from test.integration.datatype._generate_test_datatype_tables import ( # type: ignore
DATATYPES_WITH_MS,
Expand All @@ -22,6 +21,7 @@

import redshift_connector
from redshift_connector.config import ClientProtocolVersion
from redshift_connector.interval import IntervalDayToSecond, IntervalYearToMonth

conf = configparser.ConfigParser()
root_path = os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
Expand Down Expand Up @@ -120,9 +120,9 @@ def test_redshift_varbyte_insert(db_kwargs, _input, client_protocol) -> None:
assert len(results[0]) == 2
assert results[0][1] == bytes(data, encoding="utf-8").hex()


@pytest.mark.parametrize("client_protocol", ClientProtocolVersion.list())
@pytest.mark.parametrize("datatype", [RedshiftDatatypes.intervaly2m.name,
RedshiftDatatypes.intervald2s.name])
@pytest.mark.parametrize("datatype", [RedshiftDatatypes.intervaly2m.name, RedshiftDatatypes.intervald2s.name])
def test_redshift_interval_insert(db_kwargs, datatype, client_protocol) -> None:
db_kwargs["client_protocol_version"] = client_protocol
data = redshift_test_data[datatype]
Expand All @@ -139,13 +139,13 @@ def test_redshift_interval_insert(db_kwargs, datatype, client_protocol) -> None:
print(results)
for idx, result in enumerate(results):
print(result[1], data[idx][1])
assert(isinstance(result[1], redshift_type))
assert(result[1] == data[idx][1])
assert isinstance(result[1], redshift_type)
assert result[1] == data[idx][1]
cursor.execute("drop table t_interval")


@pytest.mark.parametrize("client_protocol", ClientProtocolVersion.list())
@pytest.mark.parametrize("datatype", [RedshiftDatatypes.intervaly2m.name,
RedshiftDatatypes.intervald2s.name])
@pytest.mark.parametrize("datatype", [RedshiftDatatypes.intervaly2m.name, RedshiftDatatypes.intervald2s.name])
def test_redshift_interval_prep_stmt(db_kwargs, datatype, client_protocol) -> None:
db_kwargs["client_protocol_version"] = client_protocol
data = redshift_test_data[datatype]
Expand All @@ -155,19 +155,20 @@ def test_redshift_interval_prep_stmt(db_kwargs, datatype, client_protocol) -> No
with con.cursor() as cursor:
cursor.execute("create table t_interval_ps(id text, v1 {})".format(datatype))
cursor.paramstyle = "pyformat"
cursor.executemany("insert into t_interval_ps(id, v1) values (%(id_val)s, %(v1_val)s)",
({"id_val": row[-1], "v1_val": row[1]} for row in data[:2]))
cursor.executemany(
"insert into t_interval_ps(id, v1) values (%(id_val)s, %(v1_val)s)",
({"id_val": row[-1], "v1_val": row[1]} for row in data[:2]),
)
cursor.paramstyle = "qmark"
cursor.executemany("insert into t_interval_ps values (?, ?)",
([row[-1], row[1]] for row in data[2:]))
cursor.executemany("insert into t_interval_ps values (?, ?)", ([row[-1], row[1]] for row in data[2:]))
cursor.execute("select id, v1 from t_interval_ps")
results: typing.Tuple = cursor.fetchall()
assert len(results) == len(data)
print(results)
for idx, result in enumerate(results):
print(result[1], data[idx][1])
assert(isinstance(result[1], redshift_type))
assert(result[1] == data[idx][1])
assert isinstance(result[1], redshift_type)
assert result[1] == data[idx][1]
cursor.execute("drop table t_interval_ps")


Expand Down
19 changes: 14 additions & 5 deletions test/unit/datatype/test_data_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
import pytest # type: ignore

import redshift_connector
from redshift_connector.interval import Interval, IntervalYearToMonth, IntervalDayToSecond
from redshift_connector.interval import (
Interval,
IntervalDayToSecond,
IntervalYearToMonth,
)
from redshift_connector.utils import type_utils


Expand Down Expand Up @@ -347,14 +351,19 @@ class Datatypes(Enum):
(b"\x01\x0e\x00\x00\x00*", 2, 4, IntervalYearToMonth(months=42)),
(b"\x01\x0e\x00\x00\x00*", 2, 4, IntervalYearToMonth(year_month=(3, 6))),
(b"\x00\x00\x00\x02", 0, 4, IntervalYearToMonth(months=2)),
(b"\x00\x00\x00\x02", 0, 4, IntervalYearToMonth(year_month=(1, -10)))
(b"\x00\x00\x00\x02", 0, 4, IntervalYearToMonth(year_month=(1, -10))),
],
Datatypes.intervald2s: [
(b"\x00\x00\x0c\x00\x00\x00\x00\x02\xdf\xda\xe8\x00", 4, 8, IntervalDayToSecond(microseconds=12345600000)),
(b"\x00\x00\x0c\x00\x00\x00\x00\x02\xdf\xda\xe8\x00", 4, 8, IntervalDayToSecond(timedelta=timedelta(hours=3, minutes=25, seconds=45.6))),
(
b"\x00\x00\x0c\x00\x00\x00\x00\x02\xdf\xda\xe8\x00",
4,
8,
IntervalDayToSecond(timedelta=timedelta(hours=3, minutes=25, seconds=45.6)),
),
(b"\x00\x00\x00\x00\x00\x00\x00\x02", 0, 8, IntervalDayToSecond(microseconds=2)),
(b"\x00\x00\x00\x00\x00\x00\x00\x02", 0, 8, IntervalDayToSecond(timedelta=timedelta(microseconds=2)))
]
(b"\x00\x00\x00\x00\x00\x00\x00\x02", 0, 8, IntervalDayToSecond(timedelta=timedelta(microseconds=2))),
],
}


Expand Down

0 comments on commit a19417e

Please sign in to comment.