Skip to content

Commit

Permalink
fix(tracing): ensure nesting of Transaction.begin under commit + fix …
Browse files Browse the repository at this point in the history
…suggestions from feature review

This change ensures that:
* If a transaction was not yet begin, that if .commit() is invoked
the resulting span hierarchy has .begin nested under .commit
* We use "CloudSpanner.Transaction.execute_sql" instead of
  "CloudSpanner.Transaction.execute_streaming_sql"
* If we have a tracer_provider that produces non-recordings spans,
that it won't crash due to lacking `span._status`

Fixes #1286
  • Loading branch information
odeke-em committed Jan 10, 2025
1 parent 592047f commit 6fba449
Show file tree
Hide file tree
Showing 8 changed files with 308 additions and 53 deletions.
5 changes: 4 additions & 1 deletion google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ def trace_call(name, session=None, extra_attributes=None, observability_options=
# invoke .record_exception on our own else we shall have 2 exceptions.
raise
else:
if (not span._status) or span._status.status_code == StatusCode.UNSET:
# All spans still have set_status available even if for example
# NonRecordingSpan doesn't have "_status".
absent_span_status = getattr(span, "_status", None) is None
if absent_span_status or span._status.status_code == StatusCode.UNSET:
# OpenTelemetry-Python only allows a status change
# if the current code is UNSET or ERROR. At the end
# of the generator's consumption, only set it to OK
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def _get_streamed_result_set(
iterator = _restart_on_unavailable(
restart,
request,
f"CloudSpanner.{type(self).__name__}.execute_streaming_sql",
f"CloudSpanner.{type(self).__name__}.execute_sql",
self._session,
trace_attributes,
transaction=self,
Expand Down
66 changes: 34 additions & 32 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,39 +242,7 @@ def commit(
:returns: timestamp of the committed changes.
:raises ValueError: if there are no mutations to commit.
"""
self._check_state()
if self._transaction_id is None and len(self._mutations) > 0:
self.begin()
elif self._transaction_id is None and len(self._mutations) == 0:
raise ValueError("Transaction is not begun")

database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
if self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
request_options.request_tag = None

request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
max_commit_delay=max_commit_delay,
request_options=request_options,
)

trace_attributes = {"num_mutations": len(self._mutations)}
observability_options = getattr(database, "observability_options", None)
with trace_call(
Expand All @@ -283,6 +251,40 @@ def commit(
trace_attributes,
observability_options,
) as span:
self._check_state()
if self._transaction_id is None and len(self._mutations) > 0:
self.begin()
elif self._transaction_id is None and len(self._mutations) == 0:
raise ValueError("Transaction is not begun")

api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
database._route_to_leader_enabled
)
)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
if self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
request_options.request_tag = None

request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
max_commit_delay=max_commit_delay,
request_options=request_options,
)

add_span_event(span, "Starting Commit")

method = functools.partial(
Expand Down
124 changes: 120 additions & 4 deletions tests/system/test_observability_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_propagation(enable_extended_tracing):
gotNames = [span.name for span in from_inject_spans]
wantNames = [
"CloudSpanner.CreateSession",
"CloudSpanner.Snapshot.execute_streaming_sql",
"CloudSpanner.Snapshot.execute_sql",
]
assert gotNames == wantNames

Expand Down Expand Up @@ -239,8 +239,8 @@ def select_in_txn(txn):
("CloudSpanner.Database.run_in_transaction", codes.OK, None),
("CloudSpanner.CreateSession", codes.OK, None),
("CloudSpanner.Session.run_in_transaction", codes.OK, None),
("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_sql", codes.OK, None),
("CloudSpanner.Transaction.commit", codes.OK, None),
]
assert got_statuses == want_statuses
Expand Down Expand Up @@ -273,6 +273,117 @@ def finished_spans_statuses(trace_exporter):
return got_statuses, got_events


@pytest.mark.skipif(
not _helpers.USE_EMULATOR,
reason="Emulator needed to run this tests",
)
@pytest.mark.skipif(
not HAS_OTEL_INSTALLED,
reason="Tracing requires OpenTelemetry",
)
def test_transaction_update_implicit_begin_nested_inside_commit():
# Tests to ensure that transaction.commit() without a began transaction
# has transaction.begin() inlined and nested under the commit span.
from google.auth.credentials import AnonymousCredentials
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.trace.status import StatusCode
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.sampling import ALWAYS_ON

PROJECT = _helpers.EMULATOR_PROJECT
CONFIGURATION_NAME = "config-name"
INSTANCE_ID = _helpers.INSTANCE_ID
DISPLAY_NAME = "display-name"
DATABASE_ID = _helpers.unique_id("temp_db")
NODE_COUNT = 5
LABELS = {"test": "true"}

def tx_update(txn):
txn.update(
"Singers",
columns=["SingerId", "FirstName"],
values=[["1", "Bryan"], ["2", "Slash"]],
)

tracer_provider = TracerProvider(sampler=ALWAYS_ON)
trace_exporter = InMemorySpanExporter()
tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter))
observability_options = dict(
tracer_provider=tracer_provider,
enable_extended_tracing=True,
)

client = Client(
project=PROJECT,
observability_options=observability_options,
credentials=AnonymousCredentials(),
)

instance = client.instance(
INSTANCE_ID,
CONFIGURATION_NAME,
display_name=DISPLAY_NAME,
node_count=NODE_COUNT,
labels=LABELS,
)

try:
instance.create()
except Exception:
pass

db = instance.database(DATABASE_ID)
try:
db._ddl_statements = [
"""CREATE TABLE Singers (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX),
FullName STRING(2048) AS (
ARRAY_TO_STRING([FirstName, LastName], " ")
) STORED
) PRIMARY KEY (SingerId)""",
"""CREATE TABLE Albums (
SingerId INT64 NOT NULL,
AlbumId INT64 NOT NULL,
AlbumTitle STRING(MAX),
MarketingBudget INT64,
) PRIMARY KEY (SingerId, AlbumId),
INTERLEAVE IN PARENT Singers ON DELETE CASCADE""",
]
db.create()
except Exception:
pass

try:
db.run_in_transaction(tx_update)
except Exception:
pass

span_list = trace_exporter.get_finished_spans()
# Sort the spans by their start time in the hierarchy.
span_list = sorted(span_list, key=lambda span: span.start_time)
got_span_names = [span.name for span in span_list]
want_span_names = [
"CloudSpanner.Database.run_in_transaction",
"CloudSpanner.CreateSession",
"CloudSpanner.Session.run_in_transaction",
"CloudSpanner.Transaction.commit",
"CloudSpanner.Transaction.begin",
]

assert got_span_names == want_span_names

# Our object is to ensure that .begin() is a child of .commit()
span_tx_begin = span_list[-1]
span_tx_commit = span_list[-2]
assert span_tx_begin.parent.span_id == span_tx_commit.context.span_id


@pytest.mark.skipif(
not _helpers.USE_EMULATOR,
reason="Emulator needed to run this test",
Expand Down Expand Up @@ -328,13 +439,18 @@ def test_database_partitioned_error():
codes.ERROR,
"InvalidArgument: 400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' WHERE id > 1\n ^",
),
("CloudSpanner.CreateSession", codes.OK, None),
(
"CloudSpanner.CreateSession",
codes.OK,
None,
),
(
"CloudSpanner.ExecuteStreamingSql",
codes.ERROR,
"InvalidArgument: 400 Table not found: NonExistent [at 1:8]\nUPDATE NonExistent SET name = 'foo' WHERE id > 1\n ^",
),
]
print("got_statuses", got_statuses)
assert got_statuses == want_statuses


Expand Down
31 changes: 30 additions & 1 deletion tests/unit/test__opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_trace_codeless_error(self):
span = span_list[0]
self.assertEqual(span.status.status_code, StatusCode.ERROR)

def test_trace_call_terminal_span_status(self):
def test_trace_call_terminal_span_status_ALWAYS_ON_sampler(self):
# Verify that we don't unconditionally set the terminal span status to
# SpanStatus.OK per https://github.com/googleapis/python-spanner/issues/1246
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
Expand Down Expand Up @@ -195,3 +195,32 @@ def test_trace_call_terminal_span_status(self):
("VerifyTerminalSpanStatus", StatusCode.ERROR, "Our error exhibit"),
]
assert got_statuses == want_statuses

def test_trace_call_terminal_span_status_ALWAYS_OFF_sampler(self):
# Verify that we get the correct status even when using the ALWAYS_OFF
# sampler which produces the NonRecordingSpan per
# https://github.com/googleapis/python-spanner/issues/1286
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF

tracer_provider = TracerProvider(sampler=ALWAYS_OFF)
trace_exporter = InMemorySpanExporter()
tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter))
observability_options = dict(tracer_provider=tracer_provider)

session = _make_session()
used_span = None
with _opentelemetry_tracing.trace_call(
"VerifyWithNonRecordingSpan",
session,
observability_options=observability_options,
) as span:
used_span = span

assert type(used_span).__name__ == "NonRecordingSpan"
span_list = list(trace_exporter.get_finished_spans())
assert span_list == []
20 changes: 17 additions & 3 deletions tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
# limitations under the License.


import unittest

import mock
from google.api_core import gapic_v1
from google.cloud.spanner_admin_database_v1 import (
Expand All @@ -24,6 +22,10 @@
from google.cloud.spanner_v1.param_types import INT64
from google.api_core.retry import Retry
from google.protobuf.field_mask_pb2 import FieldMask
from tests._helpers import (
HAS_OPENTELEMETRY_INSTALLED,
OpenTelemetryBase,
)

from google.cloud.spanner_v1 import RequestOptions, DirectedReadOptions

Expand Down Expand Up @@ -62,7 +64,7 @@ class _CredentialsWithScopes(
return mock.Mock(spec=_CredentialsWithScopes)


class _BaseTest(unittest.TestCase):
class _BaseTest(OpenTelemetryBase):
PROJECT_ID = "project-id"
PARENT = "projects/" + PROJECT_ID
INSTANCE_ID = "instance-id"
Expand Down Expand Up @@ -1433,6 +1435,18 @@ def test_run_in_transaction_w_args(self):
self.assertEqual(committed, NOW)
self.assertEqual(session._retried, (_unit_of_work, (SINCE,), {"until": UNTIL}))

if not HAS_OPENTELEMETRY_INSTALLED:
pass

span_list = self.get_finished_spans()
got_span_names = [span.name for span in span_list]
want_span_names = ["CloudSpanner.Database.run_in_transaction"]
assert got_span_names == want_span_names

got_span_events_statuses = self.finished_spans_events_statuses()
want_span_events_statuses = []
assert got_span_events_statuses == want_span_events_statuses

def test_run_in_transaction_nested(self):
from datetime import datetime

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ def test_execute_sql_other_error(self):
self.assertEqual(derived._execute_sql_count, 1)

self.assertSpanAttributes(
"CloudSpanner._Derived.execute_streaming_sql",
"CloudSpanner._Derived.execute_sql",
status=StatusCode.ERROR,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}),
)
Expand Down Expand Up @@ -1024,7 +1024,7 @@ def _execute_sql_helper(
self.assertEqual(derived._execute_sql_count, sql_count + 1)

self.assertSpanAttributes(
"CloudSpanner._Derived.execute_streaming_sql",
"CloudSpanner._Derived.execute_sql",
status=StatusCode.OK,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}),
)
Expand Down
Loading

0 comments on commit 6fba449

Please sign in to comment.