Skip to content

Commit

Permalink
feat: add support for lazy refresh strategy (#1093)
Browse files Browse the repository at this point in the history
Add refresh_strategy argument to Connector() that allows setting
the strategy to "lazy" to use a lazy refresh strategy.

When creating a Connector via Connector(refresh_strategy="lazy"),
the connection info and ephemeral certificate will be refreshed only
when the cached certificate has expired. No background tasks run
periodically with this option, making it ideal for use in serverless
environments such as Cloud Run, Cloud Functions, etc, where the
CPU may be throttled.
  • Loading branch information
jackwotherspoon authored May 30, 2024
1 parent b0c699e commit b9526bb
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 55 deletions.
30 changes: 27 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,15 @@ defaults for each connection to make, you can initialize a
`Connector` object as follows:

```python
from google.cloud.sql.connector import Connector, IPTypes
from google.cloud.sql.connector import Connector

# Note: all parameters below are optional
connector = Connector(
ip_type="public", # can also be "private" or "psc"
enable_iam_auth=False,
timeout=30,
credentials=custom_creds # google.auth.credentials.Credentials
credentials=custom_creds, # google.auth.credentials.Credentials
refresh_strategy="lazy", # can be "lazy" or "background"
)
```

Expand Down Expand Up @@ -254,6 +255,21 @@ with Connector() as connector:
print(row)
```

### Configuring a Lazy Refresh (Cloud Run, Cloud Functions etc.)

The Connector's `refresh_strategy` argument can be set to `"lazy"` to configure
the Python Connector to retrieve connection info lazily and as-needed.
Otherwise, a background refresh cycle runs to retrieve the connection info
periodically. This setting is useful in environments where the CPU may be
throttled outside of a request context, e.g., Cloud Run, Cloud Functions, etc.

To set the refresh strategy, set the `refresh_strategy` keyword argument when
initializing a `Connector`:

```python
connector = Connector(refresh_strategy="lazy")
```

### Specifying IP Address Type

The Cloud SQL Python Connector can be used to connect to Cloud SQL instances
Expand All @@ -277,7 +293,7 @@ conn = connector.connect(
```

> [!IMPORTANT]
>
>
> If specifying Private IP or Private Service Connect (PSC), your application must be
> attached to the proper VPC network to connect to your Cloud SQL instance. For most
> applications this will require the use of a [VPC Connector][vpc-connector].
Expand Down Expand Up @@ -355,6 +371,14 @@ The Python Connector can be used alongside popular Python web frameworks such
as Flask, FastAPI, etc, to integrate Cloud SQL databases within your
web applications.

> [!NOTE]
>
> For serverless environments such as Cloud Functions, Cloud Run, etc, it may be
> beneficial to initialize the `Connector` with the lazy refresh strategy,
> i.e. `Connector(refresh_strategy="lazy")`.
>
> See [Configuring a Lazy Refresh](#configuring-a-lazy-refresh-cloud-run-cloud-functions-etc)
#### Flask-SQLAlchemy

[Flask-SQLAlchemy](https://flask-sqlalchemy.palletsprojects.com/en/2.x/)
Expand Down
9 changes: 8 additions & 1 deletion google/cloud/sql/connector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
from google.cloud.sql.connector.connector import Connector
from google.cloud.sql.connector.connector import create_async_connector
from google.cloud.sql.connector.instance import IPTypes
from google.cloud.sql.connector.instance import RefreshStrategy
from google.cloud.sql.connector.version import __version__

__all__ = ["__version__", "create_async_connector", "Connector", "IPTypes"]
__all__ = [
"__version__",
"create_async_connector",
"Connector",
"IPTypes",
"RefreshStrategy",
]
36 changes: 28 additions & 8 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import socket
from threading import Thread
from types import TracebackType
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Type, Union

import google.auth
from google.auth.credentials import Credentials
Expand All @@ -34,6 +34,8 @@
from google.cloud.sql.connector.exceptions import DnsNameResolutionError
from google.cloud.sql.connector.instance import IPTypes
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.instance import RefreshStrategy
from google.cloud.sql.connector.lazy import LazyRefreshCache
import google.cloud.sql.connector.pg8000 as pg8000
import google.cloud.sql.connector.pymysql as pymysql
import google.cloud.sql.connector.pytds as pytds
Expand Down Expand Up @@ -62,6 +64,7 @@ def __init__(
sqladmin_api_endpoint: Optional[str] = None,
user_agent: Optional[str] = None,
universe_domain: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
) -> None:
"""Initializes a Connector instance.
Expand Down Expand Up @@ -98,6 +101,11 @@ def __init__(
universe_domain (str): The universe domain for Cloud SQL API calls.
Default: "googleapis.com".
refresh_strategy (str | RefreshStrategy): The default refresh strategy
used to refresh SSL/TLS cert and instance metadata. Can be one
of the following: RefreshStrategy.LAZY ("LAZY") or
RefreshStrategy.BACKGROUND ("BACKGROUND").
Default: RefreshStrategy.BACKGROUND
"""
# if event loop is given, use for background tasks
if loop:
Expand All @@ -113,7 +121,7 @@ def __init__(
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
loop=self._loop,
)
self._cache: Dict[str, RefreshAheadCache] = {}
self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
self._client: Optional[CloudSQLClient] = None

# initialize credentials
Expand All @@ -139,6 +147,10 @@ def __init__(
if isinstance(ip_type, str):
ip_type = IPTypes._from_str(ip_type)
self._ip_type = ip_type
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
refresh_strategy = RefreshStrategy._from_str(refresh_strategy)
self._refresh_strategy = refresh_strategy
self._universe_domain = universe_domain
# construct service endpoint for Cloud SQL Admin API calls
if not sqladmin_api_endpoint:
Expand Down Expand Up @@ -265,12 +277,20 @@ async def connect_async(
"connector.Connector object."
)
else:
cache = RefreshAheadCache(
instance_connection_string,
self._client,
self._keys,
enable_iam_auth,
)
if self._refresh_strategy == RefreshStrategy.LAZY:
cache = LazyRefreshCache(
instance_connection_string,
self._client,
self._keys,
enable_iam_auth,
)
else:
cache = RefreshAheadCache(
instance_connection_string,
self._client,
self._keys,
enable_iam_auth,
)
self._cache[instance_connection_string] = cache

connect_func = {
Expand Down
17 changes: 17 additions & 0 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ def _parse_instance_connection_name(connection_name: str) -> Tuple[str, str, str
return connection_name_split[1], connection_name_split[3], connection_name_split[4]


class RefreshStrategy(Enum):
    """Enumerates the supported refresh strategies.

    Every member's value is an upper-case string. Use ``_from_str`` to
    build a member from a string of any letter case.
    """

    LAZY: str = "LAZY"
    BACKGROUND: str = "BACKGROUND"

    @classmethod
    def _missing_(cls, value: object) -> None:
        """Reject a value that does not correspond to any member.

        Args:
            value: The value that failed the normal member lookup.

        Raises:
            ValueError: Always. The message names the rejected value and
                lists the value of every member.
        """
        valid_values = ", ".join(repr(member.value) for member in cls)
        raise ValueError(
            f"Incorrect value for refresh_strategy, got '{value}'. Want one of: "
            f"{valid_values}."
        )

    @classmethod
    def _from_str(cls, refresh_strategy: str) -> RefreshStrategy:
        """Build a RefreshStrategy from its string form.

        The input is upper-cased before the lookup, so "lazy", "Lazy" and
        "LAZY" all resolve to the same member.
        """
        normalized = refresh_strategy.upper()
        return cls(normalized)


class IPTypes(Enum):
PUBLIC: str = "PRIMARY"
PRIVATE: str = "PRIVATE"
Expand Down
132 changes: 132 additions & 0 deletions google/cloud/sql/connector/lazy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from datetime import datetime
from datetime import timedelta
from datetime import timezone
import logging
from typing import Optional

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.instance import _parse_instance_connection_name
from google.cloud.sql.connector.refresh_utils import _refresh_buffer

logger = logging.getLogger(name=__name__)


class LazyRefreshCache:
    """Connection info cache that refreshes on demand.

    No refresh is performed until ``connect_info()`` is awaited. A refresh
    then happens only if nothing is cached yet, if ``force_refresh()`` has
    flagged the cache, or if the cached certificate is close to or already
    past its expiration.

    This is the recommended option for serverless environments.
    """

    def __init__(
        self,
        instance_connection_string: str,
        client: CloudSQLClient,
        keys: asyncio.Future,
        enable_iam_auth: bool = False,
    ) -> None:
        """Initializes a LazyRefreshCache instance.

        Args:
            instance_connection_string (str): The Cloud SQL Instance's
                connection string (also known as an instance connection name).
            client (CloudSQLClient): The Cloud SQL Client instance.
            keys (asyncio.Future): A future to the client's public-private key
                pair.
            enable_iam_auth (bool): Enables automatic IAM database authentication
                (Postgres and MySQL) as the default authentication method for all
                connections.
        """
        # Parse first so that an invalid connection name fails before any
        # state is stored on the instance.
        name_parts = _parse_instance_connection_name(instance_connection_string)
        self._project, self._region, self._instance = name_parts
        self._instance_connection_string = instance_connection_string

        self._client = client
        self._keys = keys
        self._enable_iam_auth = enable_iam_auth

        # The lock serializes connect_info() and force_refresh().
        self._lock = asyncio.Lock()
        self._cached: Optional[ConnectionInfo] = None
        self._needs_refresh = False

    async def force_refresh(self) -> None:
        """Mark the cached connection info as stale.

        No refresh is performed here; the next call to connect_info()
        retrieves a fresh ConnectionInfo instance.
        """
        async with self._lock:
            self._needs_refresh = True

    async def connect_info(self) -> ConnectionInfo:
        """Return the ConnectionInfo used to establish a secure connection
        to the Cloud SQL instance, refreshing it first when required.

        Any exception raised while fetching new connection info is logged
        at debug level and re-raised unchanged.
        """
        async with self._lock:
            current = self._cached
            if current and not self._needs_refresh:
                # Pad the expiration with a buffer so the client has plenty
                # of time to establish a connection with the certificate.
                now = datetime.now(timezone.utc)
                cutoff = current.expiration - timedelta(seconds=_refresh_buffer)
                if now < cutoff:
                    logger.debug(
                        f"['{self._instance_connection_string}']: Connection info "
                        "is still valid, using cached info"
                    )
                    return current

            logger.debug(
                f"['{self._instance_connection_string}']: Connection info "
                "refresh operation started"
            )
            try:
                fresh = await self._client.get_connection_info(
                    self._project,
                    self._region,
                    self._instance,
                    self._keys,
                    self._enable_iam_auth,
                )
            except Exception as e:
                logger.debug(
                    f"['{self._instance_connection_string}']: Connection info "
                    f"refresh operation failed: {str(e)}"
                )
                raise
            else:
                logger.debug(
                    f"['{self._instance_connection_string}']: Connection info "
                    "refresh operation completed successfully"
                )
                logger.debug(
                    f"['{self._instance_connection_string}']: Current certificate "
                    f"expiration = {str(fresh.expiration)}"
                )
                # Store the new info and clear the stale flag only after a
                # successful fetch.
                self._cached = fresh
                self._needs_refresh = False
                return fresh

    async def close(self) -> None:
        """Close is a no-op, provided purely for a consistent interface with
        other cache types.
        """
Loading

0 comments on commit b9526bb

Please sign in to comment.