diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 3e61872368..6a9f1f48f5 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -344,7 +344,7 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", self._session, trace_attributes, observability_options=observability_options, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 8c28cda7ce..963debdab8 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -699,38 +699,43 @@ def execute_partitioned_dml( ) def execute_pdml(): - with SessionCheckout(self._pool) as session: - txn = api.begin_transaction( - session=session.name, options=txn_options, metadata=metadata - ) + with trace_call( + "CloudSpanner.Database.execute_partitioned_pdml", + observability_options=self.observability_options, + ) as span: + with SessionCheckout(self._pool) as session: + add_span_event(span, "Starting BeginTransaction") + txn = api.begin_transaction( + session=session.name, options=txn_options, metadata=metadata + ) - txn_selector = TransactionSelector(id=txn.id) + txn_selector = TransactionSelector(id=txn.id) - request = ExecuteSqlRequest( - session=session.name, - sql=dml, - params=params_pb, - param_types=param_types, - query_options=query_options, - request_options=request_options, - ) - method = functools.partial( - api.execute_streaming_sql, - metadata=metadata, - ) + request = ExecuteSqlRequest( + session=session.name, + sql=dml, + params=params_pb, + param_types=param_types, + query_options=query_options, + request_options=request_options, + ) + method = functools.partial( + api.execute_streaming_sql, + metadata=metadata, + ) - iterator = _restart_on_unavailable( - method=method, - trace_name="CloudSpanner.ExecuteStreamingSql", - request=request, - transaction_selector=txn_selector, - observability_options=self.observability_options, - ) + iterator = _restart_on_unavailable( + method=method, + trace_name="CloudSpanner.ExecuteStreamingSql", + request=request, + transaction_selector=txn_selector, + observability_options=self.observability_options, + ) - result_set = StreamedResultSet(iterator) - list(result_set) # consume all partials + result_set = StreamedResultSet(iterator) + list(result_set) # consume all partials - return result_set.stats.row_count_lower_bound + return result_set.stats.row_count_lower_bound return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)() @@ -1357,6 +1362,10 @@ def to_dict(self): "transaction_id": snapshot._transaction_id, } + @property + def observability_options(self): + return getattr(self._database, "observability_options", {}) + def _get_session(self): """Create session as needed. @@ -1476,27 +1485,32 @@ def generate_read_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ - partitions = self._get_snapshot().partition_read( - table=table, - columns=columns, - keyset=keyset, - index=index, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_read_batches", + extra_attributes=dict(table=table, columns=columns), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_read( + table=table, + columns=columns, + keyset=keyset, + index=index, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - read_info = { - "table": table, - "columns": columns, - "keyset": keyset._to_dict(), - "index": index, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - for partition in partitions: - yield {"partition": partition, "read": read_info.copy()} + read_info = { + "table": table, + "columns": columns, + "keyset": keyset._to_dict(), + "index": index, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + for partition in partitions: + yield {"partition": partition, "read": read_info.copy()} def process_read_batch( self, @@ -1522,12 +1536,17 @@ def process_read_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - kwargs = copy.deepcopy(batch["read"]) - keyset_dict = kwargs.pop("keyset") - kwargs["keyset"] = KeySet._from_dict(keyset_dict) - return self._get_snapshot().read( - partition=batch["partition"], **kwargs, retry=retry, timeout=timeout - ) + observability_options = self.observability_options + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_read_batch", + observability_options=observability_options, + ): + kwargs = copy.deepcopy(batch["read"]) + keyset_dict = kwargs.pop("keyset") + kwargs["keyset"] = KeySet._from_dict(keyset_dict) + return self._get_snapshot().read( + partition=batch["partition"], **kwargs, retry=retry, timeout=timeout + ) def generate_query_batches( self, @@ -1602,34 +1621,39 @@ def generate_query_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ - partitions = self._get_snapshot().partition_query( - sql=sql, - params=params, - param_types=param_types, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_query_batches", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_query( + sql=sql, + params=params, + param_types=param_types, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - query_info = { - "sql": sql, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - if params: - query_info["params"] = params - query_info["param_types"] = param_types - - # Query-level options have higher precedence than client-level and - # environment-level options - default_query_options = self._database._instance._client._query_options - query_info["query_options"] = _merge_query_options( - default_query_options, query_options - ) + query_info = { + "sql": sql, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + if params: + query_info["params"] = params + query_info["param_types"] = param_types + + # Query-level options have higher precedence than client-level and + # environment-level options + default_query_options = self._database._instance._client._query_options + query_info["query_options"] = _merge_query_options( + default_query_options, query_options + ) - for partition in partitions: - yield {"partition": partition, "query": query_info} + for partition in partitions: + yield {"partition": partition, "query": query_info} def process_query_batch( self, @@ -1654,9 +1678,16 @@ def process_query_batch( :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - return self._get_snapshot().execute_sql( - partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.process_query_batch", + observability_options=self.observability_options, + ): + return self._get_snapshot().execute_sql( + partition=batch["partition"], + **batch["query"], + retry=retry, + timeout=timeout, + ) def run_partitioned_query( self, @@ -1711,18 +1742,23 @@ def run_partitioned_query( :rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet` :returns: a result set instance which can be used to consume rows. """ - partitions = list( - self.generate_query_batches( - sql, - params, - param_types, - partition_size_bytes, - max_partitions, - query_options, - data_boost_enabled, + with trace_call( + f"CloudSpanner.${type(self).__name__}.run_partitioned_query", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = list( + self.generate_query_batches( + sql, + params, + param_types, + partition_size_bytes, + max_partitions, + query_options, + data_boost_enabled, + ) ) - ) - return MergedResultSet(self, partitions, 0) + return MergedResultSet(self, partitions, 0) def process(self, batch): """Process a single, partitioned query or read. diff --git a/google/cloud/spanner_v1/merged_result_set.py b/google/cloud/spanner_v1/merged_result_set.py index 9165af9ee3..bfecad1e46 100644 --- a/google/cloud/spanner_v1/merged_result_set.py +++ b/google/cloud/spanner_v1/merged_result_set.py @@ -17,6 +17,8 @@ from typing import Any, TYPE_CHECKING from threading import Lock, Event +from google.cloud.spanner_v1._opentelemetry_tracing import trace_call + if TYPE_CHECKING: from google.cloud.spanner_v1.database import BatchSnapshot @@ -37,6 +39,16 @@ def __init__(self, batch_snapshot, partition_id, merged_result_set): self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue def run(self): + observability_options = getattr( + self._batch_snapshot, "observability_options", {} + ) + with trace_call( + "CloudSpanner.PartitionExecutor.run", + observability_options=observability_options, + ): + self.__run() + + def __run(self): results = None try: results = self._batch_snapshot.process_query_batch(self._partition_id) diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index 03bff81b52..596f76a1f1 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -523,12 +523,11 @@ def bind(self, database): metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - created_session_count = 0 self._database_role = self._database_role or self._database.database_role request = BatchCreateSessionsRequest( database=database.name, - session_count=self.size - created_session_count, + session_count=self.size, session_template=Session(creator_role=self.database_role), ) @@ -549,38 +548,28 @@ def bind(self, database): span_event_attributes, ) - if created_session_count >= self.size: - add_span_event( - current_span, - "Created no new sessions as sessionPool is full", - span_event_attributes, - ) - return - - add_span_event( - current_span, - f"Creating {request.session_count} sessions", - span_event_attributes, - ) - observability_options = getattr(self._database, "observability_options", None) with trace_call( "CloudSpanner.PingingPool.BatchCreateSessions", observability_options=observability_options, ) as span: returned_session_count = 0 - while created_session_count < self.size: + while returned_session_count < self.size: resp = api.batch_create_sessions( request=request, metadata=metadata, ) + + add_span_event( + span, + f"Created {len(resp.session)} sessions", + ) + for session_pb in resp.session: session = self._new_session() + returned_session_count += 1 session._session_id = session_pb.name.split("/")[-1] self.put(session) - returned_session_count += 1 - - created_session_count += len(resp.session) add_span_event( span, diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index de610e1387..dc28644d6c 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -680,10 +680,14 @@ def partition_read( ) trace_attributes = {"table_id": table, "columns": columns} + can_include_index = (index != "") and (index is not None) + if can_include_index: + trace_attributes["index"] = index + with trace_call( f"CloudSpanner.{type(self).__name__}.partition_read", self._session, - trace_attributes, + extra_attributes=trace_attributes, observability_options=getattr(database, "observability_options", None), ): method = functools.partial( @@ -784,7 +788,7 @@ def partition_query( trace_attributes = {"db.statement": sql} with trace_call( - "CloudSpanner.PartitionReadWriteTransaction", + f"CloudSpanner.{type(self).__name__}.partition_query", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), diff --git a/tests/_helpers.py b/tests/_helpers.py index c7b1665e89..667f9f8be1 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -86,7 +86,7 @@ def assertSpanAttributes( ): if HAS_OPENTELEMETRY_INSTALLED: if not span: - span_list = self.ot_exporter.get_finished_spans() + span_list = self.get_finished_spans() self.assertEqual(len(span_list) > 0, True) span = span_list[0] @@ -132,3 +132,20 @@ def get_finished_spans(self): def reset(self): self.tearDown() + + def finished_spans_events_statuses(self): + span_list = self.get_finished_spans() + # Some event attributes are noisy/highly ephemeral + # and can't be directly compared against. + got_all_events = [] + imprecise_event_attributes = ["exception.stacktrace", "delay_seconds", "cause"] + for span in span_list: + for event in span.events: + evt_attributes = event.attributes.copy() + for attr_name in imprecise_event_attributes: + if attr_name in evt_attributes: + evt_attributes[attr_name] = "EPHEMERAL" + + got_all_events.append((event.name, evt_attributes)) + + return got_all_events diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index 42ce0de7fe..a91955496f 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -16,6 +16,9 @@ from . import _helpers from google.cloud.spanner_v1 import Client +from google.api_core.exceptions import Aborted +from google.auth.credentials import AnonymousCredentials +from google.rpc import code_pb2 HAS_OTEL_INSTALLED = False @@ -37,7 +40,7 @@ not HAS_OTEL_INSTALLED, reason="OpenTelemetry is necessary to test traces." ) @pytest.mark.skipif( - not _helpers.USE_EMULATOR, reason="mulator is necessary to test traces." + not _helpers.USE_EMULATOR, reason="Emulator is necessary to test traces." ) def test_observability_options_propagation(): PROJECT = _helpers.EMULATOR_PROJECT @@ -97,7 +100,8 @@ def test_propagation(enable_extended_tracing): _ = val from_global_spans = global_trace_exporter.get_finished_spans() - from_inject_spans = inject_trace_exporter.get_finished_spans() + target_spans = inject_trace_exporter.get_finished_spans() + from_inject_spans = sorted(target_spans, key=lambda v1: v1.start_time) assert ( len(from_global_spans) == 0 ) # "Expecting no spans from the global trace exporter" @@ -131,23 +135,11 @@ def test_propagation(enable_extended_tracing): test_propagation(False) -@pytest.mark.skipif( - not _helpers.USE_EMULATOR, - reason="Emulator needed to run this tests", -) -@pytest.mark.skipif( - not HAS_OTEL_INSTALLED, - reason="Tracing requires OpenTelemetry", -) -def test_transaction_abort_then_retry_spans(): - from google.auth.credentials import AnonymousCredentials - from google.api_core.exceptions import Aborted - from google.rpc import code_pb2 +def create_db_trace_exporter(): from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( InMemorySpanExporter, ) - from opentelemetry.trace.status import StatusCode from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.sampling import ALWAYS_ON @@ -159,20 +151,6 @@ def test_transaction_abort_then_retry_spans(): NODE_COUNT = 5 LABELS = {"test": "true"} - counters = dict(aborted=0) - - def select_in_txn(txn): - results = txn.execute_sql("SELECT 1") - for row in results: - _ = row - - if counters["aborted"] == 0: - counters["aborted"] = 1 - raise Aborted( - "Thrown from ClientInterceptor for testing", - errors=[_helpers.FauxCall(code_pb2.ABORTED)], - ) - tracer_provider = TracerProvider(sampler=ALWAYS_ON) trace_exporter = InMemorySpanExporter() tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter)) @@ -206,22 +184,72 @@ def select_in_txn(txn): except Exception: pass + return db, trace_exporter + + +@pytest.mark.skipif( + not _helpers.USE_EMULATOR, + reason="Emulator needed to run this test", +) +@pytest.mark.skipif( + not HAS_OTEL_INSTALLED, + reason="Tracing requires OpenTelemetry", +) +def test_transaction_abort_then_retry_spans(): + from opentelemetry.trace.status import StatusCode + + db, trace_exporter = create_db_trace_exporter() + + counters = dict(aborted=0) + + def select_in_txn(txn): + results = txn.execute_sql("SELECT 1") + for row in results: + _ = row + + if counters["aborted"] == 0: + counters["aborted"] = 1 + raise Aborted( + "Thrown from ClientInterceptor for testing", + errors=[_helpers.FauxCall(code_pb2.ABORTED)], + ) + db.run_in_transaction(select_in_txn) + got_statuses, got_events = finished_spans_statuses(trace_exporter) + + # Check for the series of events + want_events = [ + ("Acquiring session", {"kind": "BurstyPool"}), + ("Waiting for a session to become available", {"kind": "BurstyPool"}), + ("No sessions available in pool. Creating session", {"kind": "BurstyPool"}), + ("Creating Session", {}), + ( + "Transaction was aborted in user operation, retrying", + {"delay_seconds": "EPHEMERAL", "cause": "EPHEMERAL", "attempt": 1}, + ), + ("Starting Commit", {}), + ("Commit Done", {}), + ] + assert got_events == want_events + + # Check for the statues. + codes = StatusCode + want_statuses = [ + ("CloudSpanner.Database.run_in_transaction", codes.OK, None), + ("CloudSpanner.CreateSession", codes.OK, None), + ("CloudSpanner.Session.run_in_transaction", codes.OK, None), + ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), + ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), + ("CloudSpanner.Transaction.commit", codes.OK, None), + ] + assert got_statuses == want_statuses + + +def finished_spans_statuses(trace_exporter): span_list = trace_exporter.get_finished_spans() # Sort the spans by their start time in the hierarchy. span_list = sorted(span_list, key=lambda span: span.start_time) - got_span_names = [span.name for span in span_list] - want_span_names = [ - "CloudSpanner.Database.run_in_transaction", - "CloudSpanner.CreateSession", - "CloudSpanner.Session.run_in_transaction", - "CloudSpanner.Transaction.execute_streaming_sql", - "CloudSpanner.Transaction.execute_streaming_sql", - "CloudSpanner.Transaction.commit", - ] - - assert got_span_names == want_span_names got_events = [] got_statuses = [] @@ -233,6 +261,7 @@ def select_in_txn(txn): got_statuses.append( (span.name, span.status.status_code, span.status.description) ) + for event in span.events: evt_attributes = event.attributes.copy() for attr_name in imprecise_event_attributes: @@ -241,30 +270,70 @@ def select_in_txn(txn): got_events.append((event.name, evt_attributes)) + return got_statuses, got_events + + +@pytest.mark.skipif( + not _helpers.USE_EMULATOR, + reason="Emulator needed to run this test", +) +@pytest.mark.skipif( + not HAS_OTEL_INSTALLED, + reason="Tracing requires OpenTelemetry", +) +def test_database_partitioned_error(): + from opentelemetry.trace.status import StatusCode + + db, trace_exporter = create_db_trace_exporter() + + try: + db.execute_partitioned_dml("UPDATE NonExistent SET name = 'foo' WHERE id > 1") + except Exception: + pass + + got_statuses, got_events = finished_spans_statuses(trace_exporter) # Check for the series of events want_events = [ ("Acquiring session", {"kind": "BurstyPool"}), ("Waiting for a session to become available", {"kind": "BurstyPool"}), ("No sessions available in pool. Creating session", {"kind": "BurstyPool"}), ("Creating Session", {}), + ("Starting BeginTransaction", {}), ( - "Transaction was aborted in user operation, retrying", - {"delay_seconds": "EPHEMERAL", "cause": "EPHEMERAL", "attempt": 1}, + "exception", + { + "exception.type": "google.api_core.exceptions.InvalidArgument", + "exception.message": "400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' WHERE id > 1\n ^", + "exception.stacktrace": "EPHEMERAL", + "exception.escaped": "False", + }, + ), + ( + "exception", + { + "exception.type": "google.api_core.exceptions.InvalidArgument", + "exception.message": "400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' WHERE id > 1\n ^", + "exception.stacktrace": "EPHEMERAL", + "exception.escaped": "False", + }, ), - ("Starting Commit", {}), - ("Commit Done", {}), ] assert got_events == want_events # Check for the statues. codes = StatusCode want_statuses = [ - ("CloudSpanner.Database.run_in_transaction", codes.OK, None), + ( + "CloudSpanner.Database.execute_partitioned_pdml", + codes.ERROR, + "InvalidArgument: 400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' WHERE id > 1\n ^", + ), ("CloudSpanner.CreateSession", codes.OK, None), - ("CloudSpanner.Session.run_in_transaction", codes.OK, None), - ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), - ("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None), - ("CloudSpanner.Transaction.commit", codes.OK, None), + ( + "CloudSpanner.ExecuteStreamingSql", + codes.ERROR, + "InvalidArgument: 400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' WHERE id > 1\n ^", + ), ] assert got_statuses == want_statuses diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 4e80657584..d2a86c8ddf 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -437,7 +437,6 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): if ot_exporter is not None: span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 4 assert_span_attributes( ot_exporter, @@ -464,6 +463,8 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): span=span_list[3], ) + assert len(span_list) == 4 + def test_batch_insert_then_read_string_array_of_string(sessions_database, not_postgres): table = "string_plus_array_of_string" @@ -1193,30 +1194,57 @@ def unit_of_work(transaction): with tracer.start_as_current_span("Test Span"): session.run_in_transaction(unit_of_work) - span_list = ot_exporter.get_finished_spans() + span_list = [] + for span in ot_exporter.get_finished_spans(): + if span and span.name: + span_list.append(span) + + span_list = sorted(span_list, key=lambda v1: v1.start_time) got_span_names = [span.name for span in span_list] - want_span_names = [ + expected_span_names = [ "CloudSpanner.CreateSession", "CloudSpanner.Batch.commit", + "Test Span", + "CloudSpanner.Session.run_in_transaction", "CloudSpanner.DMLTransaction", "CloudSpanner.Transaction.commit", - "CloudSpanner.Session.run_in_transaction", - "Test Span", ] - assert got_span_names == want_span_names - - def assert_parent_hierarchy(parent, children): - for child in children: - assert child.context.trace_id == parent.context.trace_id - assert child.parent.span_id == parent.context.span_id - - test_span = span_list[-1] - test_span_children = [span_list[-2]] - assert_parent_hierarchy(test_span, test_span_children) - - session_run_in_txn = span_list[-2] - session_run_in_txn_children = span_list[2:-2] - assert_parent_hierarchy(session_run_in_txn, session_run_in_txn_children) + assert got_span_names == expected_span_names + + # We expect: + # |------CloudSpanner.CreateSession-------- + # + # |---Test Span----------------------------| + # |>--Session.run_in_transaction----------| + # |---------DMLTransaction-------| + # + # |>----Transaction.commit---| + + # CreateSession should have a trace of its own, with no children + # nor being a child of any other span. + session_span = span_list[0] + test_span = span_list[2] + # assert session_span.context.trace_id != test_span.context.trace_id + for span in span_list[1:]: + if span.parent: + assert span.parent.span_id != session_span.context.span_id + + def assert_parent_and_children(parent_span, children): + for span in children: + assert span.context.trace_id == parent_span.context.trace_id + assert span.parent.span_id == parent_span.context.span_id + + # [CreateSession --> Batch] should have their own trace. + session_run_in_txn_span = span_list[3] + children_of_test_span = [session_run_in_txn_span] + assert_parent_and_children(test_span, children_of_test_span) + + dml_txn_span = span_list[4] + batch_commit_txn_span = span_list[5] + children_of_session_run_in_txn_span = [dml_txn_span, batch_commit_txn_span] + assert_parent_and_children( + session_run_in_txn_span, children_of_session_run_in_txn_span + ) def test_execute_partitioned_dml( diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 738bce9529..eb5069b497 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -527,7 +527,7 @@ def test_batch_write_already_committed(self): group.delete(TABLE_NAME, keyset=keyset) groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -553,7 +553,7 @@ def test_batch_write_grpc_error(self): groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -615,7 +615,7 @@ def _test_batch_write_with_request_options( ) self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 89715c741d..9b5d2c9885 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -918,7 +918,11 @@ def test_spans_put_full(self): attributes=attrs, span=span_list[-1], ) - wantEventNames = ["Requested for 4 sessions, returned 4"] + wantEventNames = [ + "Created 2 sessions", + "Created 2 sessions", + "Requested for 4 sessions, returned 4", + ] self.assertSpanEvents( "CloudSpanner.PingingPool.BatchCreateSessions", wantEventNames ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index a4446a0d1e..099bd31bea 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1194,12 +1194,17 @@ def _partition_read_helper( timeout=timeout, ) + want_span_attributes = dict( + BASE_ATTRIBUTES, + table_id=TABLE_NAME, + columns=tuple(COLUMNS), + ) + if index: + want_span_attributes["index"] = index self.assertSpanAttributes( "CloudSpanner._Derived.partition_read", status=StatusCode.OK, - attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) - ), + attributes=want_span_attributes, ) def test_partition_read_single_use_raises(self): @@ -1369,7 +1374,7 @@ def _partition_query_helper( ) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) @@ -1387,7 +1392,7 @@ def test_partition_query_other_error(self): list(derived.partition_query(SQL_QUERY)) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner._Derived.partition_query", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) @@ -1696,6 +1701,14 @@ def test_begin_w_other_error(self): with self.assertRaises(RuntimeError): snapshot.begin() + if not HAS_OPENTELEMETRY_INSTALLED: + return + + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + want_span_names = ["CloudSpanner.Snapshot.begin"] + assert got_span_names == want_span_names + self.assertSpanAttributes( "CloudSpanner.Snapshot.begin", status=StatusCode.ERROR, @@ -1816,6 +1829,10 @@ def __init__(self, directed_read_options=None): self._route_to_leader_enabled = True self._directed_read_options = directed_read_options + @property + def observability_options(self): + return dict(db_name=self.name) + class _Session(object): def __init__(self, database=None, name=TestSnapshot.SESSION_NAME):