Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 66 additions & 1 deletion spanner/google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,32 @@

"""User friendly container for Cloud Spanner Database."""

import copy
import functools
import re
import threading
import copy

from google.api_core.gapic_v1 import client_info
import google.auth.credentials
from google.protobuf.struct_pb2 import Struct
from google.cloud.exceptions import NotFound
import six

# pylint: disable=ungrouped-imports
from google.cloud.spanner_v1 import __version__
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.gapic.spanner_client import SpannerClient
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
from google.cloud.spanner_v1.session import Session
from google.cloud.spanner_v1.snapshot import _restart_on_unavailable
from google.cloud.spanner_v1.snapshot import Snapshot
from google.cloud.spanner_v1.streamed import StreamedResultSet
from google.cloud.spanner_v1.proto.transaction_pb2 import (
TransactionSelector, TransactionOptions)
# pylint: enable=ungrouped-imports


Expand Down Expand Up @@ -272,6 +279,64 @@ def drop(self):
metadata = _metadata_with_prefix(self.name)
api.drop_database(self.name, metadata=metadata)

def execute_partitioned_dml(
        self, dml, params=None, param_types=None):
    """Execute a partitionable DML statement.

    :type dml: str
    :param dml: DML statement

    :type params: dict, {str -> column value}
    :param params: values for parameter replacement.  Keys must match
                   the names used in ``dml``.

    :type param_types: dict[str -> Union[dict, .types.Type]]
    :param param_types:
        (Optional) maps explicit types for one or more param values;
        required if parameters are passed.

    :rtype: int
    :returns: Count of rows affected by the DML statement.
    """
    if params is None:
        params_pb = None
    else:
        if param_types is None:
            raise ValueError(
                "Specify 'param_types' when passing 'params'.")
        fields = {
            name: _make_value_pb(value)
            for name, value in params.items()}
        params_pb = Struct(fields=fields)

    api = self.spanner_api
    metadata = _metadata_with_prefix(self.name)

    # Partitioned DML runs in its own dedicated transaction type.
    txn_options = TransactionOptions(
        partitioned_dml=TransactionOptions.PartitionedDml())

    with SessionCheckout(self._pool) as session:

        txn = api.begin_transaction(
            session.name, txn_options, metadata=metadata)

        # Bind the streaming request to the transaction we just began;
        # ``functools.partial`` lets the stream be restarted on UNAVAILABLE.
        restart = functools.partial(
            api.execute_streaming_sql,
            session.name,
            dml,
            transaction=TransactionSelector(id=txn.id),
            params=params_pb,
            param_types=param_types,
            metadata=metadata)

        result_set = StreamedResultSet(_restart_on_unavailable(restart))
        list(result_set)  # consume all partials

        return result_set.stats.row_count_lower_bound

def session(self, labels=None):
"""Factory to create a session for this database.

Expand Down
13 changes: 10 additions & 3 deletions spanner/google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class _SnapshotBase(_SessionWrapper):
_multi_use = False
_transaction_id = None
_read_request_count = 0
_execute_sql_count = 0

def _make_txn_selector(self): # pylint: disable=redundant-returns-doc
"""Helper for :meth:`read` / :meth:`execute_sql`.
Expand Down Expand Up @@ -195,14 +196,20 @@ def execute_sql(self, sql, params=None, param_types=None,

restart = functools.partial(
api.execute_streaming_sql,
self._session.name, sql,
transaction=transaction, params=params_pb, param_types=param_types,
query_mode=query_mode, partition_token=partition,
self._session.name,
sql,
transaction=transaction,
params=params_pb,
param_types=param_types,
query_mode=query_mode,
partition_token=partition,
seqno=self._execute_sql_count,
metadata=metadata)

iterator = _restart_on_unavailable(restart)

self._read_request_count += 1
self._execute_sql_count += 1

if self._multi_use:
return StreamedResultSet(iterator, source=self)
Expand Down
62 changes: 57 additions & 5 deletions spanner/google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

"""Spanner read-write transaction support."""

from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionSelector
from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionOptions
from google.protobuf.struct_pb2 import Struct

from google.cloud._helpers import _pb_timestamp_to_datetime
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionSelector
from google.cloud.spanner_v1.proto.transaction_pb2 import TransactionOptions
from google.cloud.spanner_v1.snapshot import _SnapshotBase
from google.cloud.spanner_v1.batch import _BatchBase

Expand All @@ -35,6 +37,7 @@ class Transaction(_SnapshotBase, _BatchBase):
"""Timestamp at which the transaction was successfully committed."""
_rolled_back = False
_multi_use = True
_execute_sql_count = 0

def __init__(self, session):
if session._transaction is not None:
Expand Down Expand Up @@ -114,9 +117,6 @@ def commit(self):
"""
self._check_state()

if not self._mutations:
raise ValueError("No mutations to commit")

database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
Expand All @@ -128,6 +128,58 @@ def commit(self):
del self._session._transaction
return self.committed

def execute_update(self, dml, params=None, param_types=None,
                   query_mode=None):
    """Perform an ``ExecuteSql`` API request with DML.

    :type dml: str
    :param dml: SQL DML statement

    :type params: dict, {str -> column value}
    :param params: values for parameter replacement.  Keys must match
                   the names used in ``dml``.

    :type param_types: dict[str -> Union[dict, .types.Type]]
    :param param_types:
        (Optional) maps explicit types for one or more param values;
        required if parameters are passed.

    :type query_mode:
        :class:`google.cloud.spanner_v1.proto.ExecuteSqlRequest.QueryMode`
    :param query_mode: Mode governing return of results / query plan. See
        https://cloud.google.com/spanner/reference/rpc/google.spanner.v1#google.spanner.v1.ExecuteSqlRequest.QueryMode

    :rtype: int
    :returns: Count of rows affected by the DML statement.
    """
    if params is not None:
        if param_types is None:
            raise ValueError(
                "Specify 'param_types' when passing 'params'.")
        params_pb = Struct(fields={
            key: _make_value_pb(value) for key, value in params.items()})
    else:
        params_pb = None

    database = self._session._database
    metadata = _metadata_with_prefix(database.name)
    transaction = self._make_txn_selector()
    api = database.spanner_api

    # ``seqno`` numbers the DML statements issued on this transaction
    # (see ``ExecuteSqlRequest.seqno``); incremented after each request.
    response = api.execute_sql(
        self._session.name,
        dml,
        transaction=transaction,
        params=params_pb,
        param_types=param_types,
        query_mode=query_mode,
        seqno=self._execute_sql_count,
        metadata=metadata,
    )

    self._execute_sql_count += 1
    return response.stats.row_count_exact

def __enter__(self):
"""Begin ``with`` block."""
self.begin()
Expand Down
159 changes: 159 additions & 0 deletions spanner/tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,165 @@ def test_transaction_read_and_insert_or_update_then_commit(self):
rows = list(session.read(self.TABLE, self.COLUMNS, self.ALL))
self._check_rows_data(rows)

def _generate_insert_statements(self):
insert_template = (
'INSERT INTO {table} ({column_list}) '
'VALUES ({row_data})'
)
for row in self.ROW_DATA:
yield insert_template.format(
table=self.TABLE,
column_list=', '.join(self.COLUMNS),
row_data='{}, "{}", "{}", "{}"'.format(*row)
)

@RetryErrors(exception=exceptions.ServerError)
@RetryErrors(exception=exceptions.Conflict)
def test_transaction_execute_sql_w_dml_read_rollback(self):
    """DML via ``execute_sql`` is visible in-transaction, undone by rollback."""
    retry = RetryInstanceState(_has_all_ddl)
    retry(self._db.reload)()

    session = self._db.session()
    session.create()
    self.to_delete.append(session)

    # Start from an empty table.
    with session.batch() as batch:
        batch.delete(self.TABLE, self.ALL)

    transaction = session.transaction()
    transaction.begin()

    self.assertEqual(
        list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)), [])

    for statement in self._generate_insert_statements():
        result = transaction.execute_sql(statement)
        list(result)  # iterate to get stats
        self.assertEqual(result.stats.row_count_exact, 1)

    # Rows inserted via DML *can* be read before commit.
    during_rows = list(
        transaction.read(self.TABLE, self.COLUMNS, self.ALL))
    self._check_rows_data(during_rows)

    transaction.rollback()

    # After rollback, none of the inserts were persisted.
    self._check_rows_data(
        list(session.read(self.TABLE, self.COLUMNS, self.ALL)), [])

@RetryErrors(exception=exceptions.ServerError)
@RetryErrors(exception=exceptions.Conflict)
def test_transaction_execute_update_read_commit(self):
    """Rows inserted via ``execute_update`` persist once the txn commits."""
    retry = RetryInstanceState(_has_all_ddl)
    retry(self._db.reload)()

    session = self._db.session()
    session.create()
    self.to_delete.append(session)

    # Start from an empty table.
    with session.batch() as batch:
        batch.delete(self.TABLE, self.ALL)

    with session.transaction() as transaction:
        self.assertEqual(
            list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)),
            [])

        for statement in self._generate_insert_statements():
            self.assertEqual(transaction.execute_update(statement), 1)

        # Rows inserted via DML *can* be read before commit.
        during_rows = list(
            transaction.read(self.TABLE, self.COLUMNS, self.ALL))
        self._check_rows_data(during_rows)

    # The ``with`` block committed the transaction.
    self._check_rows_data(
        list(session.read(self.TABLE, self.COLUMNS, self.ALL)))

@RetryErrors(exception=exceptions.ServerError)
@RetryErrors(exception=exceptions.Conflict)
def test_transaction_execute_update_then_insert_commit(self):
    """DML (``execute_update``) and mutations may mix in one transaction."""
    retry = RetryInstanceState(_has_all_ddl)
    retry(self._db.reload)()

    session = self._db.session()
    session.create()
    self.to_delete.append(session)

    # Start from an empty table.
    with session.batch() as batch:
        batch.delete(self.TABLE, self.ALL)

    insert_statement = list(self._generate_insert_statements())[0]

    with session.transaction() as transaction:
        self.assertEqual(
            list(transaction.read(self.TABLE, self.COLUMNS, self.ALL)),
            [])

        # First row goes in via DML ...
        self.assertEqual(transaction.execute_update(insert_statement), 1)

        # ... remaining rows via mutations on the same transaction.
        transaction.insert(self.TABLE, self.COLUMNS, self.ROW_DATA[1:])

    self._check_rows_data(
        list(session.read(self.TABLE, self.COLUMNS, self.ALL)))

def test_execute_partitioned_dml(self):
    """End-to-end check of ``Database.execute_partitioned_dml``.

    Seeds the table via regular DML, then runs a partitioned UPDATE and
    a partitioned DELETE, verifying affected-row counts and the data.
    """
    retry = RetryInstanceState(_has_all_ddl)
    retry(self._db.reload)()

    delete_statement = 'DELETE FROM {} WHERE true'.format(self.TABLE)

    def _setup_table(txn):
        # Clear the table, then repopulate it via DML inserts.
        txn.execute_update(delete_statement)
        for insert_statement in self._generate_insert_statements():
            txn.execute_update(insert_statement)

    committed = self._db.run_in_transaction(_setup_table)

    # Read at the setup-commit timestamp: all seed rows present.
    with self._db.snapshot(read_timestamp=committed) as snapshot:
        before_pdml = list(snapshot.read(
            self.TABLE, self.COLUMNS, self.ALL))

    self._check_rows_data(before_pdml)

    nonesuch = 'nonesuch@example.com'
    target = 'phred@example.com'
    update_statement = (
        'UPDATE {table} SET {table}.email = @email '
        'WHERE {table}.email = @target').format(
            table=self.TABLE)

    row_count = self._db.execute_partitioned_dml(
        update_statement,
        params={
            'email': nonesuch,
            'target': target,
        },
        param_types={
            'email': Type(code=STRING),
            'target': Type(code=STRING),
        },
    )
    self.assertEqual(row_count, 1)

    row = self.ROW_DATA[0]
    updated = [row[:3] + (nonesuch,)] + list(self.ROW_DATA[1:])

    # BUGFIX: the post-PDML reads must use a strong (current) snapshot.
    # Reading at ``read_timestamp=committed`` (the *setup* commit) cannot
    # observe the later partitioned DML, so the assertions would fail.
    with self._db.snapshot() as snapshot:
        after_update = list(snapshot.read(
            self.TABLE, self.COLUMNS, self.ALL))
    self._check_rows_data(after_update, updated)

    row_count = self._db.execute_partitioned_dml(delete_statement)
    self.assertEqual(row_count, len(self.ROW_DATA))

    with self._db.snapshot() as snapshot:
        after_delete = list(snapshot.read(
            self.TABLE, self.COLUMNS, self.ALL))

    self._check_rows_data(after_delete, [])

def _transaction_concurrency_helper(self, unit_of_work, pkey):
INITIAL_VALUE = 123
NUM_THREADS = 3 # conforms to equivalent Java systest.
Expand Down
Loading