diff --git a/README.rst b/README.rst index d6215e7..f3f50e2 100644 --- a/README.rst +++ b/README.rst @@ -306,6 +306,14 @@ Connection Parameters +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ | iam_disable_cache | bool | This option specifies whether the IAM credentials are cached. By default the IAM credentials are cached. This improves performance when requests to the API gateway are throttled. | FALSE | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ +| idc_client_display_name | str | The client display name to be used in user consent in IdC browser auth. This is an optional value. The default value is "Amazon Redshift Python connector". | None | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ +| idc_region | str | The AWS region where IdC instance is located. It is required for the IdC browser auth plugin. | None | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ +| idc_response_timeout | int | The timeout value in seconds for the IdC browser auth plugin. This is an optional value. | 120 | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ +| identity_namespace | str | The identity namespace to be used for the IdC browser auth plugin and IdP token auth plugin. It is an optional value if there is only one IdC instance existing or if default identity namespace is set on the cluster - else it is required. | None | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ | idp_response_timeout | int | The timeout for retrieving SAML assertion from IdP | 120 | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ | idp_tenant | str | The IdP tenant | None | No | @@ -354,8 +362,14 @@ Connection Parameters +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ | sslmode | str | The security of the connection to Amazon Redshift. verify-ca and verify-full are supported. | verify_ca | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ +| start_url | str | The directory or start url for the AWS IdC access portal. It is required for the IdC browser auth plugin. | None | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ | timeout | int | The number of seconds before the connection to the server will timeout. | None | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ +| token | str | The access token required for the IdP token auth plugin. | None | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ +| token_type | str | The token type required for the IdP token auth plugin. | ACCESS_TOKEN | No | ++-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ | user | str | The username to use for authentication | None | No | +-----------------------------------+------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------+----------+ | web_identity_token | str | The OAuth 2.0 access token or OpenID Connect ID token that is provided by the identity provider. Your application must get this token by authenticating the user who is using your application with a web identity provider. This parameter is used by JwtCredentialsProvider. For this provider, this is a mandatory parameter. | None | No | diff --git a/redshift_connector/__init__.py b/redshift_connector/__init__.py index 94264f2..f437488 100644 --- a/redshift_connector/__init__.py +++ b/redshift_connector/__init__.py @@ -56,6 +56,14 @@ logging.getLogger(__name__).addHandler(logging.NullHandler()) _logger: logging.Logger = logging.getLogger(__name__) +IDC_PLUGINS_LIST = ("redshift_connector.plugin.BrowserIdcAuthPlugin", "BrowserIdcAuthPlugin", + "redshift_connector.plugin.IdpTokenAuthPlugin", "IdpTokenAuthPlugin") +IDC_OR_NATIVE_IDP_PLUGINS_LIST = ( + "redshift_connector.plugin.BrowserAzureOAuth2CredentialsProvider", "BrowserAzureOAuth2CredentialsProvider", + "redshift_connector.plugin.BasicJwtCredentialsProvider", "BasicJwtCredentialsProvider", + "redshift_connector.plugin.BrowserIdcAuthPlugin", "BrowserIdcAuthPlugin", + "redshift_connector.plugin.IdpTokenAuthPlugin", "IdpTokenAuthPlugin") + # Copyright (c) 2007-2009, Mathieu Fenniak # Copyright (c) The Contributors # All rights reserved. @@ -143,6 +151,13 @@ def connect( serverless_acct_id: typing.Optional[str] = None, serverless_work_group: typing.Optional[str] = None, group_federation: typing.Optional[bool] = None, + start_url: typing.Optional[str] = None, + idc_region: typing.Optional[str] = None, + idc_response_timeout: typing.Optional[int] = None, + identity_namespace: typing.Optional[str] = None, + idc_client_display_name: typing.Optional[str] = None, + token: typing.Optional[str] = None, + token_type: typing.Optional[str] = None, ) -> Connection: """ Establishes a :class:`Connection` to an Amazon Redshift cluster. This function validates user input, optionally authenticates using an identity provider plugin, then constructs a :class:`Connection` object. @@ -246,6 +261,20 @@ def connect( The name of work group for serverless end point. Default value None. group_federation: Optional[bool] Use the IDP Groups in the Redshift. Default value False. + start_url: Optional[str] + The directory or start url for the AWS IdC access portal. Default value is None. + idc_region: Optional[str] + The AWS region where IdC instance is located. Default value is None. + idc_response_timeout: Optional[int] + The timeout value in seconds for the IdC browser auth. Default value is `120`. + identity_namespace: Optional[str] + The identity namespace to be used with IdC auth plugin. Default value is None. + idc_client_display_name: Optional[str] + The client display name to be used in user consent in IdC browser auth. Default value is `Amazon Redshift Python connector`. + token: Optional[str] + The access token to be used with IdC basic credentials provider plugin. Default value is None. + token_type: Optional[str] + The token type to be used for authentication using IdP token auth plugin. Default value is None. Returns ------- A Connection object associated with the specified Amazon Redshift cluster: :class:`Connection` @@ -273,6 +302,10 @@ def connect( info.put("host", host) info.put("iam", iam) info.put("iam_disable_cache", iam_disable_cache) + info.put("idc_client_display_name", idc_client_display_name) + info.put("idc_region", idc_region) + info.put("idc_response_timeout", idc_response_timeout) + info.put("identity_namespace", identity_namespace) info.put("idp_host", idp_host) info.put("idp_response_timeout", idp_response_timeout) info.put("idp_tenant", idp_tenant) @@ -298,11 +331,14 @@ def connect( info.put("serverless_work_group", serverless_work_group) info.put("session_token", session_token) info.put("source_address", source_address) + info.put("start_url", start_url) info.put("ssl", ssl) info.put("ssl_insecure", ssl_insecure) info.put("sslmode", sslmode) info.put("tcp_keepalive", tcp_keepalive) info.put("timeout", timeout) + info.put("token", token) + info.put("token_type", token_type) info.put("unix_sock", unix_sock) info.put("user_name", user) info.put("web_identity_token", web_identity_token) @@ -313,8 +349,15 @@ def connect( _logger.debug(mask_secure_info_in_props(info).__str__()) _logger.debug(make_divider_block()) - if (info.ssl is False) and (info.iam is True): - raise InterfaceError("Invalid connection property setting. SSL must be enabled when using IAM") + _logger.debug("plugin = {} and iam={}".format(info.credentials_provider, info.iam)) + if (info.credentials_provider in IDC_PLUGINS_LIST) and (info.iam is True): + raise InterfaceError("You can not use this authentication plugin with IAM enabled.") + + if info.ssl is False: + if info.iam is True: + raise InterfaceError("Invalid connection property setting. SSL must be enabled when using IAM") + if info.credentials_provider in IDC_OR_NATIVE_IDP_PLUGINS_LIST: + raise InterfaceError("Authentication must use an SSL connection.") if (info.iam is False) and (info.ssl_insecure is False): raise InterfaceError("Invalid connection property setting. IAM must be enabled when using ssl_insecure") @@ -362,6 +405,9 @@ def connect( provider_name=info.provider_name, web_identity_token=info.web_identity_token, numeric_to_float=info.numeric_to_float, + identity_namespace=info.identity_namespace, + token_type=info.token_type, + idc_client_display_name=info.idc_client_display_name, ) diff --git a/redshift_connector/core.py b/redshift_connector/core.py index 56389db..6c93f4b 100644 --- a/redshift_connector/core.py +++ b/redshift_connector/core.py @@ -431,6 +431,9 @@ def __init__( provider_name: typing.Optional[str] = None, web_identity_token: typing.Optional[str] = None, numeric_to_float: bool = False, + identity_namespace: typing.Optional[str] = None, + token_type: typing.Optional[str] = None, + idc_client_display_name: typing.Optional[str] = None, ): """ Creates a :class:`Connection` to an Amazon Redshift cluster. For more information on establishing a connection to an Amazon Redshift cluster using `federated API access `_ see our examples page. @@ -475,6 +478,12 @@ def __init__( A web identity token used for authentication via Redshift Native IDP Integration numeric_to_float: bool Specifies if NUMERIC datatype values will be converted from ``decimal.Decimal`` to ``float``. By default NUMERIC values are received as ``decimal.Decimal``. + identity_namespace: Optional[str] + The identity namespace to be used with IdC auth plugin. Default value is None. + token_type: Optional[str] + The token type to be used for authentication using IdP Token auth plugin + idc_client_display_name: Optional[str] + The client display name to be used for user consent in IdC browser auth plugin. """ self.merge_socket_read = True @@ -555,8 +564,16 @@ def get_calling_module() -> str: redshift_native_auth = True init_params["idp_type"] = "AzureAD" - if provider_name: - init_params["provider_name"] = provider_name + if credentials_provider.split(".")[-1] in ( + "IdpTokenAuthPlugin", + "BrowserIdcAuthPlugin", + ): + redshift_native_auth = True + self.set_idc_plugins_params(init_params, credentials_provider, identity_namespace, token_type, + idc_client_display_name) + + if redshift_native_auth and provider_name: + init_params["provider_name"] = provider_name if not redshift_native_auth or user: init_params["user"] = user @@ -1434,17 +1451,17 @@ def handle_AUTHENTICATION_REQUEST(self: "Connection", data: bytes, cursor: Curso self.auth.set_server_final(data[4:].decode("utf8")) elif auth_code == 14: # Redshift Native IDP Integration - _logger.debug("BE requested Redshift Native IDP authentication") - aad_token: str = typing.cast(str, self.web_identity_token) + _logger.debug("BE requested Redshift Native IDP or IdC authentication") + nativeidp_or_idcpez_token: str = typing.cast(str, self.web_identity_token) - if not aad_token: + if not nativeidp_or_idcpez_token: raise ConnectionAbortedError( - "The server requested AAD token-based authentication, but no token was provided." + "The server requested Native IdP or IdC token-based authentication, but no token was provided." ) - _logger.debug("Sending IdP token to BE") + _logger.debug("Sending IdP or IdC/PEZ token to BE") - token: bytes = aad_token.encode(encoding="utf-8") + token: bytes = nativeidp_or_idcpez_token.encode(encoding="utf-8") self._write(create_message(b"i", token)) # self._write(NULL_BYTE) self._flush() @@ -2533,3 +2550,24 @@ def tpc_recover(self: "Connection") -> typing.List[typing.Tuple[typing.Any, ...] return [self.xid(0, row[0], "") for row in curs] finally: self.autocommit = previous_autocommit_mode + + def set_idc_plugins_params(self: "Connection", init_params: typing.Dict[str, typing.Optional[typing.Union[str, bytes]]], + credentials_provider: typing.Optional[str] = None, + identity_namespace: typing.Optional[str] = None, + token_type: typing.Optional[str] = None, + idc_client_display_name: typing.Optional[str] = None) -> None: + plugin_name = credentials_provider.split(".")[-1] + init_params["idp_type"] = "AwsIdc" + + if identity_namespace: + init_params["identity_namespace"] = identity_namespace + + if plugin_name == "BrowserIdcAuthPlugin": + init_params["token_type"] = "ACCESS_TOKEN" + elif token_type: + init_params["token_type"] = token_type + + if idc_client_display_name: + init_params["idc_client_display_name"] = idc_client_display_name + + diff --git a/redshift_connector/idp_auth_helper.py b/redshift_connector/idp_auth_helper.py index fd1e116..6abb39f 100644 --- a/redshift_connector/idp_auth_helper.py +++ b/redshift_connector/idp_auth_helper.py @@ -33,6 +33,7 @@ class IdpAuthHelper: # Subtype of plugin SAML_PLUGIN: int = 1 JWT_PLUGIN: int = 2 + IDC_PLUGIN: int = 3 @staticmethod def get_pkg_version(module_name: str) -> Version: diff --git a/redshift_connector/plugin/__init__.py b/redshift_connector/plugin/__init__.py index 35140d5..2c6c63f 100644 --- a/redshift_connector/plugin/__init__.py +++ b/redshift_connector/plugin/__init__.py @@ -1,10 +1,13 @@ from .adfs_credentials_provider import AdfsCredentialsProvider from .azure_credentials_provider import AzureCredentialsProvider +from .idp_token_auth_plugin import IdpTokenAuthPlugin from .browser_azure_credentials_provider import BrowserAzureCredentialsProvider from .browser_azure_oauth2_credentials_provider import ( BrowserAzureOAuth2CredentialsProvider, ) +from .browser_idc_auth_plugin import BrowserIdcAuthPlugin from .browser_saml_credentials_provider import BrowserSamlCredentialsProvider +from .common_credentials_provider import CommonCredentialsProvider from .idp_credentials_provider import IdpCredentialsProvider from .jwt_credentials_provider import ( BasicJwtCredentialsProvider, diff --git a/redshift_connector/plugin/browser_idc_auth_plugin.py b/redshift_connector/plugin/browser_idc_auth_plugin.py new file mode 100644 index 0000000..c2dec3d --- /dev/null +++ b/redshift_connector/plugin/browser_idc_auth_plugin.py @@ -0,0 +1,235 @@ +import boto3 +import logging +import typing +import time +import webbrowser + +from botocore.exceptions import ClientError +from redshift_connector.error import InterfaceError +from redshift_connector.plugin.common_credentials_provider import CommonCredentialsProvider +from redshift_connector.redshift_property import RedshiftProperty + +logging.getLogger(__name__).addHandler(logging.NullHandler()) +_logger: logging.Logger = logging.getLogger(__name__) + + +class BrowserIdcAuthPlugin(CommonCredentialsProvider): + """ + Class to get IdC Token using SSO OIDC APIs + """ + + DEFAULT_IDC_CLIENT_DISPLAY_NAME = 'Amazon Redshift Python connector' + CLIENT_TYPE = 'public' + GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:device_code' + IDC_SCOPE = 'redshift:connect' + DEFAULT_BROWSER_AUTH_VERIFY_TIMEOUT_IN_SEC = 120 + DEFAULT_CREATE_TOKEN_INTERVAL_IN_SEC = 1 + + def __init__(self: "BrowserIdcAuthPlugin") -> None: + super().__init__() + self.idc_response_timeout: int = self.DEFAULT_BROWSER_AUTH_VERIFY_TIMEOUT_IN_SEC + self.idc_client_display_name: str = self.DEFAULT_IDC_CLIENT_DISPLAY_NAME + self.register_client_cache: typing.Dict[str, dict] = {} + self.start_url: typing.Optional[str] = None + self.idc_region: typing.Optional[str] = None + self.sso_oidc_client: "SSOOIDC.Client" = None + + def add_parameter( + self: "BrowserIdcAuthPlugin", + info: RedshiftProperty, + ) -> None: + """ + Adds parameters to the BrowserIdcAuthPlugin + :param info: RedshiftProperty object containing the parameters to be added to the BrowserIdcAuthPlugin. + :return: None. + """ + super().add_parameter(info) + self.start_url = info.start_url + _logger.debug("Setting start_url = {}".format(self.start_url)) + self.idc_region = info.idc_region + _logger.debug("Setting idc_region = {}".format(self.idc_region)) + if info.idc_response_timeout and info.idc_response_timeout > 10: + self.idc_response_timeout = info.idc_response_timeout + _logger.debug("Setting idc_response_timeout = {}".format(self.idc_response_timeout)) + if info.idc_client_display_name: + self.idc_client_display_name = info.idc_client_display_name + _logger.debug("Setting idc_client_display_name = {}".format(self.idc_client_display_name)) + + def check_required_parameters(self: "BrowserIdcAuthPlugin") -> None: + """ + Checks if the required parameters are set. + :return: None. + :raises InterfaceError: Raised when the parameters are not valid. + """ + super().check_required_parameters() + if not self.start_url: + _logger.error("IdC authentication failed: start_url needs to be provided in connection params") + raise InterfaceError( + "IdC authentication failed: The start URL must be included in the connection parameters.") + if not self.idc_region: + _logger.error("IdC authentication failed: idc_region needs to be provided in connection params") + raise InterfaceError( + "IdC authentication failed: The IdC region must be included in the connection parameters.") + + def get_cache_key(self: "BrowserIdcAuthPlugin") -> str: + """ + Returns the cache key for the BrowserIdcAuthPlugin. + :return: str. + """ + return "{}".format(self.start_url if self.start_url else "") + + def get_auth_token(self: "BrowserIdcAuthPlugin") -> str: + """ + Returns the auth token as per plugin specific implementation. + :return: str. + """ + return self.get_idc_token() + + def get_idc_token(self: "BrowserIdcAuthPlugin") -> str: + """ + Returns the IdC token using SSO OIDC APIs. + :return: str. + """ + try: + self.check_required_parameters() + + self.sso_oidc_client = boto3.client('sso-oidc', region_name=self.idc_region) + register_client_cache_key: str = f"{self.idc_client_display_name}:{self.idc_region}" + + register_client_result: typing.Dict[str, typing.Any] = self.register_client(register_client_cache_key, + self.idc_client_display_name, + self.CLIENT_TYPE, + self.IDC_SCOPE) + start_device_auth_result: typing.Dict[str, typing.Any] = self.start_device_authorization( + register_client_result['clientId'], register_client_result['clientSecret'], self.start_url) + self.open_browser(start_device_auth_result['verificationUriComplete']) + + return self.poll_for_create_token(register_client_result, start_device_auth_result, self.GRANT_TYPE) + except InterfaceError as e: + raise + except Exception as e: + _logger.debug("An error occurred while trying to obtain an IdC token : {}".format(str(e))) + raise InterfaceError("There was an error during authentication.") + + def register_client(self: "BrowserIdcAuthPlugin", register_client_cache_key: str, client_name: str, + client_type: str, scope: str) -> typing.Dict[str, typing.Any]: + """ + Registers the client with IdC. + :param register_client_cache_key: str + The cache key used for storing register client result. + :param client_name: str + The client name to be used for registering the client. + :param client_type: str + The client type to be used for registering the client. + :param scope: str + The scope to be used for registering the client. + :return: dict + The register client result from IdC + """ + if register_client_cache_key in self.register_client_cache and \ + self.register_client_cache[register_client_cache_key]['clientSecretExpiresAt'] > time.time(): + _logger.debug("Valid registerClient result found from cache") + return self.register_client_cache[register_client_cache_key] + + try: + register_client_result: typing.Dict[str, typing.Any] = self.sso_oidc_client.register_client( + clientName=client_name, clientType=client_type, scopes=[scope]) + self.register_client_cache[register_client_cache_key] = register_client_result + return register_client_result + except ClientError as e: + self.handle_error(e, "registering client with IdC") + + def start_device_authorization(self: "BrowserIdcAuthPlugin", client_id: str, client_secret: str, + start_url: str) -> typing.Dict[str, typing.Any]: + """ + Starts device authorization flow with IdC. + :param client_id: str + The client id to be used for starting device authorization. + :param client_secret: str + The client secret to be used for starting device authorization. + :param start_url: str + The portal start url to be used for starting device authorization. + :return: dict + The start device authorization result from IdC. + """ + try: + response: typing.Dict[str, typing.Any] = self.sso_oidc_client.start_device_authorization( + clientId=client_id, + clientSecret=client_secret, + startUrl=start_url + ) + return response + except ClientError as e: + self.handle_error(e, "starting device authorization with IdC") + + def open_browser(self: "BrowserIdcAuthPlugin", url: str) -> None: + """ + Opens the default browser with this url to allow user authentication with the IdC + :param url: str + The verification uri obtained from start device auth response + :return: None. + """ + _logger.debug("Opening browser with url: {}".format(url)) + self.validate_url(url) + webbrowser.open(url) + + def poll_for_create_token(self: "BrowserIdcAuthPlugin", + register_client_result: typing.Dict[str, typing.Any], + start_device_auth_result: typing.Dict[str, typing.Any], grant_type: str) -> str: + """ + Polls for IdC access token using SSO OIDC APIs. + :param register_client_result: dict + The register client result from IdC. + :param start_device_auth_result: dict + The start device auth result from IdC. + :param grant_type: str + The grant type to be used for polling for IdC access token. + :return: str + The IdC access token obtained from polling for IdC access token. + :raises InterfaceError: Raised when the IdC access token is not fetched successfully. + """ + polling_end_time: float = time.time() + self.idc_response_timeout + + polling_interval_in_sec: int = self.DEFAULT_CREATE_TOKEN_INTERVAL_IN_SEC + if start_device_auth_result['interval']: + polling_interval_in_sec = start_device_auth_result['interval'] + + while time.time() < polling_end_time: + try: + response: typing.Dict[str, typing.Any] = self.sso_oidc_client.create_token( + clientId=register_client_result['clientId'], + clientSecret=register_client_result['clientSecret'], + grantType=grant_type, + deviceCode=start_device_auth_result['deviceCode'] + ) + if not response['accessToken']: + raise InterfaceError("IdC authentication failed : The credential token couldn't be created.") + return response['accessToken'] + except ClientError as e: + if e.response['Error']['Code'] == 'AuthorizationPendingException': + _logger.debug("Browser authorization pending from user") + time.sleep(polling_interval_in_sec) + else: + self.handle_error(e, "polling for an IdC access token") + + raise InterfaceError("IdC authentication failed : The request timed out. Authentication wasn't completed.") + + def handle_error(self: "BrowserIdcAuthPlugin", e: ClientError, operation: str) -> None: + """ + Handles the client error from SSO OIDC APIs. + :param e: ClientError + The error from SSO OIDC API. + :param operation: str + The operation for which error was encountered. + :return: None. + :raises InterfaceError: A client error to be returned to the user with appropriate error message + """ + _logger.debug("Error response = {} ".format(e.response)) + error_message = e.response['Error']['Message'] + if not error_message: + error_message = e.response['error_description'] if e.response[ + 'error_description'] else "Something unexpected happened" + error_code = e.response['Error']['Code'] + _logger.debug( + "An error occurred while {}: ClientError = {} - {}".format(operation, error_code, error_message)) + raise InterfaceError("IdC authentication failed : An error occurred during the request.") diff --git a/redshift_connector/plugin/common_credentials_provider.py b/redshift_connector/plugin/common_credentials_provider.py new file mode 100644 index 0000000..b3b1827 --- /dev/null +++ b/redshift_connector/plugin/common_credentials_provider.py @@ -0,0 +1,87 @@ +import logging +import typing +from abc import abstractmethod + +from redshift_connector.error import InterfaceError +from redshift_connector.iam_helper import IamHelper +from redshift_connector.plugin.i_native_plugin import INativePlugin +from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider +from redshift_connector.plugin.native_token_holder import NativeTokenHolder +from redshift_connector.redshift_property import RedshiftProperty + +_logger: logging.Logger = logging.getLogger(__name__) + + +class CommonCredentialsProvider(INativePlugin, IdpCredentialsProvider): + """ + Abstract base class for authentication plugins using IdC authentication. + """ + + def __init__(self: "CommonCredentialsProvider") -> None: + super().__init__() + self.last_refreshed_credentials: typing.Optional[NativeTokenHolder] = None + + @abstractmethod + def get_auth_token(self: "CommonCredentialsProvider") -> str: + """ + Returns the auth token retrieved from corresponding plugin + """ + pass # pragma: no cover + + def add_parameter( + self: "CommonCredentialsProvider", + info: RedshiftProperty, + ) -> None: + self.disable_cache = True + + def get_credentials(self: "CommonCredentialsProvider") -> NativeTokenHolder: + credentials: typing.Optional[NativeTokenHolder] = None + + if not self.disable_cache: + key = self.get_cache_key() + credentials = typing.cast(NativeTokenHolder, self.cache.get(key)) + + if not credentials or credentials.is_expired(): + if self.disable_cache: + _logger.debug("Auth token Cache disabled : fetching new token") + else: + _logger.debug("Auth token Cache enabled - No auth token found from cache : fetching new token") + + self.refresh() + + if self.disable_cache: + credentials = self.last_refreshed_credentials + self.last_refreshed_credentials = None + else: + credentials.refresh = False + _logger.debug("Auth token found from cache") + + if not self.disable_cache: + credentials = typing.cast(NativeTokenHolder, self.cache[key]) + return typing.cast(NativeTokenHolder, credentials) + + def refresh(self: "CommonCredentialsProvider") -> None: + auth_token: str = self.get_auth_token() + _logger.debug("auth token: {}".format(auth_token)) + + if auth_token is None: + raise InterfaceError("IdC authentication failed : An error occurred during the request.") + + credentials: NativeTokenHolder = NativeTokenHolder(access_token=auth_token, expiration=None) + credentials.refresh = True + + _logger.debug("disable_cache={}".format(str(self.disable_cache))) + if not self.disable_cache: + self.cache[self.get_cache_key()] = credentials + else: + self.last_refreshed_credentials = credentials + + def get_idp_token(self: "CommonCredentialsProvider") -> str: + auth_token: str = self.get_auth_token() + return auth_token + + def set_group_federation(self: "CommonCredentialsProvider", group_federation: bool): + pass + + def get_sub_type(self: "CommonCredentialsProvider") -> int: + return IamHelper.IDC_PLUGIN \ No newline at end of file diff --git a/redshift_connector/plugin/idp_token_auth_plugin.py b/redshift_connector/plugin/idp_token_auth_plugin.py new file mode 100644 index 0000000..dd0e78e --- /dev/null +++ b/redshift_connector/plugin/idp_token_auth_plugin.py @@ -0,0 +1,46 @@ +import logging +import typing + +from redshift_connector.error import InterfaceError +from redshift_connector.plugin.common_credentials_provider import CommonCredentialsProvider +from redshift_connector.redshift_property import RedshiftProperty + +logging.getLogger(__name__).addHandler(logging.NullHandler()) +_logger: logging.Logger = logging.getLogger(__name__) + + +class IdpTokenAuthPlugin(CommonCredentialsProvider): + """ + A basic IdP Token auth plugin class. This plugin class allows clients to directly provide any auth token that is handled by Redshift. + """ + + def __init__(self: "IdpTokenAuthPlugin") -> None: + super().__init__() + self.token: typing.Optional[str] = None + self.token_type: typing.Optional[str] = None + + def add_parameter( + self: "IdpTokenAuthPlugin", + info: RedshiftProperty, + ) -> None: + super().add_parameter(info) + self.token = info.token + self.token_type = info.token_type + _logger.debug("Setting token_type = {}".format(self.token_type)) + + def check_required_parameters(self: "IdpTokenAuthPlugin") -> None: + super().check_required_parameters() + if not self.token: + _logger.error("IdC authentication failed: token needs to be provided in connection params") + raise InterfaceError("IdC authentication failed: The token must be included in the connection parameters.") + if not self.token_type: + _logger.error("IdC authentication failed: token_type needs to be provided in connection params") + raise InterfaceError( + "IdC authentication failed: The token type must be included in the connection parameters.") + + def get_cache_key(self: "IdpTokenAuthPlugin") -> str: + pass + + def get_auth_token(self: "IdpTokenAuthPlugin") -> str: + self.check_required_parameters() + return typing.cast(str, self.token) diff --git a/redshift_connector/redshift_property.py b/redshift_connector/redshift_property.py index d488911..d067116 100644 --- a/redshift_connector/redshift_property.py +++ b/redshift_connector/redshift_property.py @@ -60,6 +60,10 @@ def __init__(self: "RedshiftProperty", **kwargs): self.host: str = "" self.iam: bool = False self.iam_disable_cache: bool = False + self.idc_client_display_name: typing.Optional[str] = None + self.idc_region: typing.Optional[str] = None + self.idc_response_timeout: int = 120 + self.identity_namespace: typing.Optional[str] = None # The IdP (identity provider) host you are using to authenticate into Redshift. self.idp_host: typing.Optional[str] = None # timeout for authentication via Browser IDP @@ -98,6 +102,7 @@ def __init__(self: "RedshiftProperty", **kwargs): self.session_token: typing.Optional[str] = None # The source IP address which initiates the connection to the Amazon Redshift server. self.source_address: typing.Optional[str] = None + self.start_url: typing.Optional[str] = None # if SSL authentication will be used self.ssl: bool = True # This property indicates whether the IDP hosts server certificate should be verified. @@ -108,6 +113,8 @@ def __init__(self: "RedshiftProperty", **kwargs): self.tcp_keepalive: bool = True # This is the time in seconds before the connection to the server will time out. self.timeout: typing.Optional[int] = None + self.token: typing.Optional[str] = None + self.token_type: typing.Optional[str] = None # The path to the UNIX socket to access the database through self.unix_sock: typing.Optional[str] = None # The user name. diff --git a/redshift_connector/utils/logging_utils.py b/redshift_connector/utils/logging_utils.py index 2fc4166..7a65eba 100644 --- a/redshift_connector/utils/logging_utils.py +++ b/redshift_connector/utils/logging_utils.py @@ -37,6 +37,10 @@ def mask_secure_info_in_props(info: "RedshiftProperty") -> "RedshiftProperty": "host", "iam", "iam_disable_cache", + "idc_client_display_name", + "idc_region", + "idc_response_timeout", + "identity_namespace", "idp_host", "idpPort", "idp_response_timeout", @@ -66,7 +70,9 @@ def mask_secure_info_in_props(info: "RedshiftProperty") -> "RedshiftProperty": "ssl", "ssl_insecure", "sslmode", + "start_url", "tcp_keepalive", + "token_type", "timeout", "unix_sock", "user_name", diff --git a/test/conftest.py b/test/conftest.py index 4799b76..86e2f1b 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -273,6 +273,38 @@ def redshift_native_browser_azure_oauth2_idp() -> typing.Dict[str, typing.Union[ } return db_connect +@pytest.fixture(scope="class") +def redshift_browser_idc() -> typing.Dict[str, typing.Union[str, bool, int]]: + db_connect = { + "host": conf.get("redshift-browser-idc", "host", fallback=None), + "region": conf.get("redshift-browser-idc", "region", fallback=None), + "database": conf.get("redshift-browser-idc", "database", fallback="dev"), + "credentials_provider": conf.get( + "redshift-browser-idc", "credentials_provider", fallback="BrowserIdcAuthPlugin" + ), + "start_url": conf.get("redshift-browser-idc", "start_url", fallback=None), + "idc_region": conf.get("redshift-browser-idc", "idc_region", fallback=None), + "idc_response_timeout": conf.getint("redshift-browser-idc", "idc_response_timeout", fallback=120), + "idc_client_display_name": conf.get("redshift-browser-idc", "idc_client_display_name", + fallback="Amazon Redshift Python connector"), + "identity_namespace": conf.get("redshift-browser-idc", "identity_namespace", fallback=None), + } + return db_connect + +@pytest.fixture(scope="class") +def redshift_idp_token_auth_plugin() -> typing.Dict[str, typing.Union[str, bool, int]]: + db_connect = { + "host": conf.get("redshift-idp-token-auth-plugin", "host", fallback=None), + "region": conf.get("redshift-idp-token-auth-plugin", "region", fallback=None), + "database": conf.get("redshift-idp-token-auth-plugin", "database", fallback="dev"), + "credentials_provider": conf.get( + "redshift-idp-token-auth-plugin", "credentials_provider", fallback="IdpTokenAuthPlugin" + ), + "token": conf.get("redshift-idp-token-auth-plugin", "token", fallback=None), + "token_type": conf.get("redshift-idp-token-auth-plugin", "token_type", fallback=None), + "identity_namespace": conf.get("redshift-idp-token-auth-plugin", "identity_namespace", fallback=None), + } + return db_connect @pytest.fixture def con(request, db_kwargs) -> redshift_connector.Connection: diff --git a/test/manual/plugin/test_browser_credentials_provider.py b/test/manual/plugin/test_browser_credentials_provider.py index b5aa5f2..5fc192a 100644 --- a/test/manual/plugin/test_browser_credentials_provider.py +++ b/test/manual/plugin/test_browser_credentials_provider.py @@ -1,15 +1,6 @@ import configparser import os import typing -from test import ( - azure_browser_idp, - idp_arg, - jumpcloud_browser_idp, - jwt_azure_v2_idp, - jwt_google_idp, - okta_browser_idp, - ping_browser_idp, -) import pytest # type: ignore @@ -27,6 +18,8 @@ # "jwt_google_idp", "ping_browser_idp", "redshift_native_browser_azure_oauth2_idp", + "redshift_browser_idc", + "redshift_idp_token_auth_plugin", ] """ diff --git a/test/unit/plugin/test_browser_idc_auth_plugin.py b/test/unit/plugin/test_browser_idc_auth_plugin.py new file mode 100644 index 0000000..8a8252c --- /dev/null +++ b/test/unit/plugin/test_browser_idc_auth_plugin.py @@ -0,0 +1,134 @@ +import typing + +import pytest +from pytest_mock import mocker # type: ignore + +from redshift_connector.error import InterfaceError +from redshift_connector.plugin.browser_idc_auth_plugin import ( + BrowserIdcAuthPlugin, +) +from redshift_connector.redshift_property import RedshiftProperty + + +def make_valid_browser_idc_provider() -> typing.Tuple[BrowserIdcAuthPlugin, RedshiftProperty]: + rp: RedshiftProperty = RedshiftProperty() + rp.idc_region = "some_region" + rp.start_url = "some_url" + rp.idc_response_timeout = 120 + cp: BrowserIdcAuthPlugin = BrowserIdcAuthPlugin() + cp.add_parameter(rp) + return cp, rp + + +def test_add_parameter_sets_browser_idc_specific(): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + assert idc_credentials_provider.idc_region == rp.idc_region + assert idc_credentials_provider.start_url == rp.start_url + assert idc_credentials_provider.idc_response_timeout == rp.idc_response_timeout + + +@pytest.mark.parametrize("value", [None, ""]) +def test_check_required_parameters_raises_if_start_url_missing(value): + idc_credentials_provider, _ = make_valid_browser_idc_provider() + idc_credentials_provider.start_url = value + + with pytest.raises(InterfaceError, + match="IdC authentication failed: The start URL must be included in the connection parameters."): + idc_credentials_provider.get_auth_token() + + +@pytest.mark.parametrize("value", [None, ""]) +def test_check_required_parameters_raises_if_idc_region_missing(value): + idc_credentials_provider, _ = make_valid_browser_idc_provider() + idc_credentials_provider.idc_region = value + + with pytest.raises(InterfaceError, + match="IdC authentication failed: The IdC region must be included in the connection parameters."): + idc_credentials_provider.get_auth_token() + + +def test_get_auth_token_fetches_idc_token(mocker): + # Mock the dependencies and their return values + idc_credentials_provider, rp = make_valid_browser_idc_provider() + + test_register_client_cache_key: str = f"{idc_credentials_provider.idc_client_display_name}:{idc_credentials_provider.idc_region}" + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret" + } + mocked_start_device_auth_result: typing.Dict[str, typing.Any] = { + "verificationUriComplete": "http://mockedVerificationUriComplete" + } + expected_idc_token: str = "mockedAccessToken" + + mocker.patch("boto3.client") # Mocking boto3.client + + # Mocking the response of internal methods + mocker.patch.object(idc_credentials_provider, "register_client", return_value=mocked_register_client_result) + mocker.patch.object(idc_credentials_provider, "start_device_authorization", + return_value=mocked_start_device_auth_result) + mocker.patch( + "redshift_connector.plugin.browser_idc_auth_plugin.BrowserIdcAuthPlugin.open_browser") + mocker.patch.object(idc_credentials_provider, "poll_for_create_token", return_value=expected_idc_token) + + # Call the method under test + test_result_idc_token: str = idc_credentials_provider.get_auth_token() + + # Assertions + idc_credentials_provider.register_client.assert_called_once_with( + test_register_client_cache_key, idc_credentials_provider.idc_client_display_name, idc_credentials_provider.CLIENT_TYPE, + idc_credentials_provider.IDC_SCOPE + ) + idc_credentials_provider.start_device_authorization.assert_called_once_with( + mocked_register_client_result['clientId'], mocked_register_client_result['clientSecret'], + idc_credentials_provider.start_url + ) + idc_credentials_provider.open_browser.assert_called_once_with( + mocked_start_device_auth_result['verificationUriComplete']) + idc_credentials_provider.poll_for_create_token.assert_called_once_with( + mocked_register_client_result, mocked_start_device_auth_result, idc_credentials_provider.GRANT_TYPE + ) + + assert test_result_idc_token == expected_idc_token + + +def test_register_client_exception_handling(mocker): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + + mocker.patch.object(idc_credentials_provider, "register_client", side_effect=Exception("Some error")) + + with pytest.raises(InterfaceError): + idc_credentials_provider.get_auth_token() + + +def test_start_device_authorization_exception_handling(mocker): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret" + } + + mocker.patch.object(idc_credentials_provider, "register_client", return_value=mocked_register_client_result) + mocker.patch.object(idc_credentials_provider, "start_device_authorization", side_effect=Exception("Some error")) + + with pytest.raises(InterfaceError): + idc_credentials_provider.get_auth_token() + + +def test_poll_for_create_token_exception_handling(mocker): + idc_credentials_provider, rp = make_valid_browser_idc_provider() + mocked_register_client_result: typing.Dict[str, typing.Any] = { + "clientId": "mockedClientId", + "clientSecret": "mockedClientSecret" + } + mocked_start_device_auth_result: typing.Dict[str, typing.Any] = { + "verificationUriComplete": "http://mockedVerificationUriComplete" + } + + mocker.patch.object(idc_credentials_provider, "register_client", return_value=mocked_register_client_result) + mocker.patch.object(idc_credentials_provider, "start_device_authorization", + return_value=mocked_start_device_auth_result) + mocker.patch.object(idc_credentials_provider, "poll_for_create_token", side_effect=Exception("Unexpected error")) + + with pytest.raises(InterfaceError): + idc_credentials_provider.get_auth_token()