From 861657249612e23fd2c229418298ecf6717c65f9 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Mon, 16 Dec 2024 04:14:52 -0800 Subject: [PATCH] Inject header in more Session using spots plus more tests --- google/cloud/spanner_v1/session.py | 21 ++- .../cloud/spanner_v1/testing/interceptors.py | 9 +- tests/unit/test_request_id_header.py | 130 ++++++++++++++++++ 3 files changed, 153 insertions(+), 7 deletions(-) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 166d5488c6..e5eff2dce7 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -193,7 +193,8 @@ def exists(self): current_span, "Checking if Session exists", {"session.id": self._session_id} ) - api = self._database.spanner_api + database = self._database + api = database.spanner_api metadata = _metadata_with_prefix(self._database.name) if self._database._route_to_leader_enabled: metadata.append( @@ -202,12 +203,16 @@ def exists(self): ) ) + all_metadata = database.metadata_with_request_id( + database._next_nth_request, 1, metadata + ) + observability_options = getattr(self._database, "observability_options", None) with trace_call( "CloudSpanner.GetSession", self, observability_options=observability_options ) as span: try: - api.get_session(name=self.name, metadata=metadata) + api.get_session(name=self.name, metadata=all_metadata) if span: span.set_attribute("session_found", True) except NotFound: @@ -237,8 +242,11 @@ def delete(self): current_span, "Deleting Session", {"session.id": self._session_id} ) - api = self._database.spanner_api - metadata = _metadata_with_prefix(self._database.name) + database = self._database + api = database.spanner_api + metadata = database.metadata_with_request_id( + database._next_nth_request, 1, _metadata_with_prefix(database.name) + ) observability_options = getattr(self._database, "observability_options", None) with trace_call( "CloudSpanner.DeleteSession", @@ -255,7 +263,10 @@ def ping(self): if self._session_id is None: raise ValueError("Session ID not set by back-end") api = self._database.spanner_api - metadata = _metadata_with_prefix(self._database.name) + database = self._database + metadata = database.metadata_with_request_id( + database._next_nth_request, 1, _metadata_with_prefix(database.name) + ) request = ExecuteSqlRequest(session=self.name, sql="SELECT 1") api.execute_sql(request=request, metadata=metadata) self._last_use_time = datetime.now() diff --git a/google/cloud/spanner_v1/testing/interceptors.py b/google/cloud/spanner_v1/testing/interceptors.py index 622f838339..918c7d76b9 100644 --- a/google/cloud/spanner_v1/testing/interceptors.py +++ b/google/cloud/spanner_v1/testing/interceptors.py @@ -67,6 +67,9 @@ def reset(self): self._connection = None +X_GOOG_REQUEST_ID = "x-goog-spanner-request-id" + + class XGoogRequestIDHeaderInterceptor(ClientInterceptor): def __init__(self): self._unary_req_segments = [] @@ -77,12 +80,14 @@ def intercept(self, method, request_or_iterator, call_details): metadata = call_details.metadata x_goog_request_id = None for key, value in metadata: - if key == "x-goog-spanner-request-id": + if key == X_GOOG_REQUEST_ID: x_goog_request_id = value break if not x_goog_request_id: - raise Exception("Missing x_goog_request_id header") + raise Exception( + f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}" + ) response_or_iterator = method(request_or_iterator, call_details) streaming = getattr(response_or_iterator, "__iter__", None) is not None diff --git a/tests/unit/test_request_id_header.py b/tests/unit/test_request_id_header.py index 12b1081960..df282f6356 100644 --- a/tests/unit/test_request_id_header.py +++ b/tests/unit/test_request_id_header.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random +import threading + from tests.mockserver_tests.mock_server_test_base import ( MockServerTestBase, add_select1_result, @@ -63,6 +66,133 @@ def test_snapshot_read(self): assert got_unary_segments == want_unary_segments assert got_stream_segments == want_stream_segments + def test_snapshot_read_concurrent(self): + def select1(): + with self.database.snapshot() as snapshot: + rows = snapshot.execute_sql("select 1") + res_list = [] + for row in rows: + self.assertEqual(1, row[0]) + res_list.append(row) + self.assertEqual(1, len(res_list)) + + n = 10 + threads = [] + for i in range(n): + th = threading.Thread(target=select1, name=f"snapshot-select1-{i}") + th.run() + threads.append(th) + + random.shuffle(threads) + + while True: + n_finished = 0 + for thread in threads: + if thread.is_alive(): + thread.join() + else: + n_finished += 1 + + if n_finished == len(threads): + break + + time.sleep(1) + + requests = self.spanner_service.requests + self.assertEqual(n * 2, len(requests), msg=requests) + + client_id = self.database._nth_client_id + channel_id = self.database._channel_id + got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() + + want_unary_segments = [ + ( + "/google.spanner.v1.Spanner/BatchCreateSessions", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1), + ), + ] + assert got_unary_segments == want_unary_segments + + want_stream_segments = [ + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1), + ), + ] + assert got_stream_segments == want_stream_segments + def canonicalize_request_id_headers(self): src = self.database._x_goog_request_id_interceptor return src._stream_req_segments, src._unary_req_segments