Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(tracing): ensure nesting of Transaction.begin under commit + fix suggestions from feature review #1287

Merged
merged 3 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ def trace_call(name, session=None, extra_attributes=None, observability_options=
# invoke .record_exception on our own else we shall have 2 exceptions.
raise
else:
if (not span._status) or span._status.status_code == StatusCode.UNSET:
# All spans still have set_status available even if for example
# NonRecordingSpan doesn't have "_status".
absent_span_status = getattr(span, "_status", None) is None
if absent_span_status or span._status.status_code == StatusCode.UNSET:
# OpenTelemetry-Python only allows a status change
# if the current code is UNSET or ERROR. At the end
# of the generator's consumption, only set it to OK
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def _get_streamed_result_set(
iterator = _restart_on_unavailable(
restart,
request,
f"CloudSpanner.{type(self).__name__}.execute_streaming_sql",
f"CloudSpanner.{type(self).__name__}.execute_sql",
self._session,
trace_attributes,
transaction=self,
Expand Down
66 changes: 34 additions & 32 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,39 +242,7 @@ def commit(
:returns: timestamp of the committed changes.
:raises ValueError: if there are no mutations to commit.
"""
self._check_state()
if self._transaction_id is None and len(self._mutations) > 0:
self.begin()
elif self._transaction_id is None and len(self._mutations) == 0:
raise ValueError("Transaction is not begun")

database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
if self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
request_options.request_tag = None

request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
max_commit_delay=max_commit_delay,
request_options=request_options,
)

trace_attributes = {"num_mutations": len(self._mutations)}
observability_options = getattr(database, "observability_options", None)
with trace_call(
Expand All @@ -283,6 +251,40 @@ def commit(
trace_attributes,
observability_options,
) as span:
self._check_state()
if self._transaction_id is None and len(self._mutations) > 0:
self.begin()
elif self._transaction_id is None and len(self._mutations) == 0:
raise ValueError("Transaction is not begun")

api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
database._route_to_leader_enabled
)
)

if request_options is None:
request_options = RequestOptions()
elif type(request_options) is dict:
request_options = RequestOptions(request_options)
if self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

# Request tags are not supported for commit requests.
request_options.request_tag = None

request = CommitRequest(
session=self._session.name,
mutations=self._mutations,
transaction_id=self._transaction_id,
return_commit_stats=return_commit_stats,
max_commit_delay=max_commit_delay,
request_options=request_options,
)

add_span_event(span, "Starting Commit")

method = functools.partial(
Expand Down
116 changes: 113 additions & 3 deletions tests/system/test_observability_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_propagation(enable_extended_tracing):
gotNames = [span.name for span in from_inject_spans]
wantNames = [
"CloudSpanner.CreateSession",
"CloudSpanner.Snapshot.execute_streaming_sql",
"CloudSpanner.Snapshot.execute_sql",
]
assert gotNames == wantNames

Expand Down Expand Up @@ -239,8 +239,8 @@ def select_in_txn(txn):
("CloudSpanner.Database.run_in_transaction", codes.OK, None),
("CloudSpanner.CreateSession", codes.OK, None),
("CloudSpanner.Session.run_in_transaction", codes.OK, None),
("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_streaming_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_sql", codes.OK, None),
("CloudSpanner.Transaction.execute_sql", codes.OK, None),
("CloudSpanner.Transaction.commit", codes.OK, None),
]
assert got_statuses == want_statuses
Expand Down Expand Up @@ -273,6 +273,116 @@ def finished_spans_statuses(trace_exporter):
return got_statuses, got_events


@pytest.mark.skipif(
not _helpers.USE_EMULATOR,
reason="Emulator needed to run this tests",
)
@pytest.mark.skipif(
not HAS_OTEL_INSTALLED,
reason="Tracing requires OpenTelemetry",
)
def test_transaction_update_implicit_begin_nested_inside_commit():
# Tests to ensure that transaction.commit() without a began transaction
# has transaction.begin() inlined and nested under the commit span.
from google.auth.credentials import AnonymousCredentials
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.sampling import ALWAYS_ON

PROJECT = _helpers.EMULATOR_PROJECT
CONFIGURATION_NAME = "config-name"
INSTANCE_ID = _helpers.INSTANCE_ID
DISPLAY_NAME = "display-name"
DATABASE_ID = _helpers.unique_id("temp_db")
NODE_COUNT = 5
LABELS = {"test": "true"}

def tx_update(txn):
txn.insert(
"Singers",
columns=["SingerId", "FirstName"],
values=[["1", "Bryan"], ["2", "Slash"]],
)

tracer_provider = TracerProvider(sampler=ALWAYS_ON)
trace_exporter = InMemorySpanExporter()
tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter))
observability_options = dict(
tracer_provider=tracer_provider,
enable_extended_tracing=True,
)

client = Client(
project=PROJECT,
observability_options=observability_options,
credentials=AnonymousCredentials(),
)

instance = client.instance(
INSTANCE_ID,
CONFIGURATION_NAME,
display_name=DISPLAY_NAME,
node_count=NODE_COUNT,
labels=LABELS,
)

try:
instance.create()
except Exception:
pass

db = instance.database(DATABASE_ID)
try:
db._ddl_statements = [
"""CREATE TABLE Singers (
SingerId INT64 NOT NULL,
FirstName STRING(1024),
LastName STRING(1024),
SingerInfo BYTES(MAX),
FullName STRING(2048) AS (
ARRAY_TO_STRING([FirstName, LastName], " ")
) STORED
) PRIMARY KEY (SingerId)""",
"""CREATE TABLE Albums (
SingerId INT64 NOT NULL,
AlbumId INT64 NOT NULL,
AlbumTitle STRING(MAX),
MarketingBudget INT64,
) PRIMARY KEY (SingerId, AlbumId),
INTERLEAVE IN PARENT Singers ON DELETE CASCADE""",
]
db.create()
except Exception:
pass

try:
db.run_in_transaction(tx_update)
except Exception:
pass

span_list = trace_exporter.get_finished_spans()
# Sort the spans by their start time in the hierarchy.
span_list = sorted(span_list, key=lambda span: span.start_time)
got_span_names = [span.name for span in span_list]
want_span_names = [
"CloudSpanner.Database.run_in_transaction",
"CloudSpanner.CreateSession",
"CloudSpanner.Session.run_in_transaction",
"CloudSpanner.Transaction.commit",
"CloudSpanner.Transaction.begin",
]

assert got_span_names == want_span_names

# Our object is to ensure that .begin() is a child of .commit()
span_tx_begin = span_list[-1]
span_tx_commit = span_list[-2]
assert span_tx_begin.parent.span_id == span_tx_commit.context.span_id


@pytest.mark.skipif(
not _helpers.USE_EMULATOR,
reason="Emulator needed to run this test",
Expand Down
31 changes: 30 additions & 1 deletion tests/unit/test__opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def test_trace_codeless_error(self):
span = span_list[0]
self.assertEqual(span.status.status_code, StatusCode.ERROR)

def test_trace_call_terminal_span_status(self):
def test_trace_call_terminal_span_status_ALWAYS_ON_sampler(self):
# Verify that we don't unconditionally set the terminal span status to
# SpanStatus.OK per https://github.com/googleapis/python-spanner/issues/1246
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
Expand Down Expand Up @@ -195,3 +195,32 @@ def test_trace_call_terminal_span_status(self):
("VerifyTerminalSpanStatus", StatusCode.ERROR, "Our error exhibit"),
]
assert got_statuses == want_statuses

def test_trace_call_terminal_span_status_ALWAYS_OFF_sampler(self):
# Verify that we get the correct status even when using the ALWAYS_OFF
# sampler which produces the NonRecordingSpan per
# https://github.com/googleapis/python-spanner/issues/1286
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
InMemorySpanExporter,
)
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF

tracer_provider = TracerProvider(sampler=ALWAYS_OFF)
trace_exporter = InMemorySpanExporter()
tracer_provider.add_span_processor(SimpleSpanProcessor(trace_exporter))
observability_options = dict(tracer_provider=tracer_provider)

session = _make_session()
used_span = None
with _opentelemetry_tracing.trace_call(
"VerifyWithNonRecordingSpan",
session,
observability_options=observability_options,
) as span:
used_span = span

assert type(used_span).__name__ == "NonRecordingSpan"
span_list = list(trace_exporter.get_finished_spans())
assert span_list == []
4 changes: 2 additions & 2 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ def test_execute_sql_other_error(self):
self.assertEqual(derived._execute_sql_count, 1)

self.assertSpanAttributes(
"CloudSpanner._Derived.execute_streaming_sql",
"CloudSpanner._Derived.execute_sql",
status=StatusCode.ERROR,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}),
)
Expand Down Expand Up @@ -1024,7 +1024,7 @@ def _execute_sql_helper(
self.assertEqual(derived._execute_sql_count, sql_count + 1)

self.assertSpanAttributes(
"CloudSpanner._Derived.execute_streaming_sql",
"CloudSpanner._Derived.execute_sql",
status=StatusCode.OK,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}),
)
Expand Down
Loading
Loading