Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Connection:
should end a that a new one should be started when the next statement is executed.
"""

def __init__(self, instance, database, read_only=False):
def __init__(self, instance, database=None, read_only=False):
self._instance = instance
self._database = database
self._ddl_statements = []
Expand Down Expand Up @@ -242,6 +242,8 @@ def _session_checkout(self):
:rtype: :class:`google.cloud.spanner_v1.session.Session`
:returns: Cloud Spanner session object ready to use.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
if not self._session:
self._session = self.database._pool.get()

Expand All @@ -252,6 +254,8 @@ def _release_session(self):

The session will be returned into the sessions pool.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self.database._pool.put(self._session)
self._session = None

Expand Down Expand Up @@ -368,7 +372,7 @@ def close(self):
if self.inside_transaction:
self._transaction.rollback()

if self._own_pool:
if self._own_pool and self.database:
self.database._pool.clear()

self.is_closed = True
Expand All @@ -378,6 +382,8 @@ def commit(self):

This method is non-operational in autocommit mode.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
self._snapshot = None

if self._autocommit:
Expand Down Expand Up @@ -420,6 +426,8 @@ def cursor(self):

@check_not_closed
def run_prior_DDL_statements(self):
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
if self._ddl_statements:
ddl_statements = self._ddl_statements
self._ddl_statements = []
Expand Down Expand Up @@ -474,6 +482,8 @@ def validate(self):
:raises: :class:`google.cloud.exceptions.NotFound`: if the linked instance
or database doesn't exist.
"""
if self.database is None:
raise ValueError("Database needs to be passed for this operation")
with self.database.snapshot() as snapshot:
result = list(snapshot.execute_sql("SELECT 1"))
if result != [[1]]:
Expand All @@ -492,7 +502,7 @@ def __exit__(self, etype, value, traceback):

def connect(
instance_id,
database_id,
database_id=None,
project=None,
credentials=None,
pool=None,
Expand All @@ -505,7 +515,7 @@ def connect(
:param instance_id: The ID of the instance to connect to.

:type database_id: str
:param database_id: The ID of the database to connect to.
:param database_id: (Optional) The ID of the database to connect to.

:type project: str
:param project: (Optional) The ID of the project which owns the
Expand Down Expand Up @@ -557,7 +567,9 @@ def connect(
raise ValueError("project in url does not match client object project")

instance = client.instance(instance_id)
conn = Connection(instance, instance.database(database_id, pool=pool))
conn = Connection(
instance, instance.database(database_id, pool=pool) if database_id else None
)
if pool is not None:
conn._own_pool = False

Expand Down
8 changes: 8 additions & 0 deletions google/cloud/spanner_dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def execute(self, sql, args=None):
:type args: list
:param args: Additional parameters to supplement the SQL query.
"""
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT
Expand Down Expand Up @@ -301,6 +303,8 @@ def executemany(self, operation, seq_of_params):
:param seq_of_params: Sequence of additional parameters to run
the query with.
"""
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
self._itr = None
self._result_set = None
self._row_count = _UNSET_COUNT
Expand Down Expand Up @@ -444,6 +448,8 @@ def _handle_DQL_with_snapshot(self, snapshot, sql, params):
self._row_count = _UNSET_COUNT

def _handle_DQL(self, sql, params):
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params)
if self.connection.read_only and not self.connection.autocommit:
# initiate or use the existing multi-use snapshot
Expand Down Expand Up @@ -484,6 +490,8 @@ def list_tables(self):
def run_sql_in_snapshot(self, sql, params=None, param_types=None):
# Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions
# hence this method exists to circumvent that limit.
if self.connection.database is None:
raise ValueError("Database needs to be passed for this operation")
self.connection.run_prior_DDL_statements()

with self.connection.database.snapshot() as snapshot:
Expand Down
50 changes: 46 additions & 4 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,14 @@ def test__session_checkout(self, mock_database):
connection._session_checkout()
self.assertEqual(connection._session, "db_session")

def test__session_checkout_database_error(self):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a stupid question, but why is there inconsistency between method names here? Some have single underscore, some have two.

from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)

with pytest.raises(ValueError):
connection._session_checkout()

@mock.patch("google.cloud.spanner_v1.database.Database")
def test__release_session(self, mock_database):
from google.cloud.spanner_dbapi import Connection
Expand All @@ -182,6 +190,13 @@ def test__release_session(self, mock_database):
pool.put.assert_called_once_with("session")
self.assertIsNone(connection._session)

def test__release_session_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
with pytest.raises(ValueError):
connection._release_session()

def test_transaction_checkout(self):
from google.cloud.spanner_dbapi import Connection

Expand Down Expand Up @@ -294,6 +309,14 @@ def test_commit(self, mock_warn):
AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2
)

def test_commit_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)

with pytest.raises(ValueError):
connection.commit()

@mock.patch.object(warnings, "warn")
def test_rollback(self, mock_warn):
from google.cloud.spanner_dbapi import Connection
Expand Down Expand Up @@ -347,6 +370,13 @@ def test_run_prior_DDL_statements(self, mock_database):
with self.assertRaises(InterfaceError):
connection.run_prior_DDL_statements()

def test_run_prior_DDL_statements_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
with pytest.raises(ValueError):
connection.run_prior_DDL_statements()

def test_as_context_manager(self):
connection = self._make_connection()
with connection as conn:
Expand Down Expand Up @@ -766,6 +796,14 @@ def test_validate_error(self):

snapshot_obj.execute_sql.assert_called_once_with("SELECT 1")

def test_validate_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)

with pytest.raises(ValueError):
connection.validate()

def test_validate_closed(self):
from google.cloud.spanner_dbapi.exceptions import InterfaceError

Expand Down Expand Up @@ -916,16 +954,14 @@ def test_request_priority(self):
sql, params, param_types=param_types, request_options=None
)

@mock.patch("google.cloud.spanner_v1.Client")
def test_custom_client_connection(self, mock_client):
def test_custom_client_connection(self):
from google.cloud.spanner_dbapi import connect

client = _Client()
connection = connect("test-instance", "test-database", client=client)
self.assertTrue(connection.instance._client == client)

@mock.patch("google.cloud.spanner_v1.Client")
def test_invalid_custom_client_connection(self, mock_client):
def test_invalid_custom_client_connection(self):
from google.cloud.spanner_dbapi import connect

client = _Client()
Expand All @@ -937,6 +973,12 @@ def test_invalid_custom_client_connection(self, mock_client):
client=client,
)

def test_connection_wo_database(self):
from google.cloud.spanner_dbapi import connect

connection = connect("test-instance")
self.assertTrue(connection.database is None)


def exit_ctx_func(self, exc_type, exc_value, traceback):
"""Context __exit__ method mock."""
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/spanner_dbapi/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ def test_execute_attribute_error(self):
with self.assertRaises(AttributeError):
cursor.execute(sql="SELECT 1")

def test_execute_database_error(self):
connection = self._make_connection(self.INSTANCE)
cursor = self._make_one(connection)

with self.assertRaises(ValueError):
cursor.execute(sql="SELECT 1")

def test_execute_autocommit_off(self):
from google.cloud.spanner_dbapi.utils import PeekIterator

Expand Down Expand Up @@ -607,6 +614,16 @@ def test_executemany_insert_batch_aborted(self):
)
self.assertIsInstance(connection._statements[0][1], ResultsChecksum)

@mock.patch("google.cloud.spanner_v1.Client")
def test_executemany_database_error(self, mock_client):
from google.cloud.spanner_dbapi import connect

connection = connect("test-instance")
cursor = connection.cursor()

with self.assertRaises(ValueError):
cursor.executemany("""SELECT * FROM table1 WHERE "col1" = @a1""", ())

@unittest.skipIf(
sys.version_info[0] < 3, "Python 2 has an outdated iterator definition"
)
Expand Down Expand Up @@ -754,6 +771,13 @@ def test_handle_dql_priority(self):
sql, None, None, request_options=RequestOptions(priority=1)
)

def test_handle_dql_database_error(self):
connection = self._make_connection(self.INSTANCE)
cursor = self._make_one(connection)

with self.assertRaises(ValueError):
cursor._handle_DQL("sql", params=None)

def test_context(self):
connection = self._make_connection(self.INSTANCE, self.DATABASE)
cursor = self._make_one(connection)
Expand Down Expand Up @@ -814,6 +838,13 @@ def test_run_sql_in_snapshot(self):
mock_snapshot.execute_sql.return_value = results
self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results))

def test_run_sql_in_snapshot_database_error(self):
connection = self._make_connection(self.INSTANCE)
cursor = self._make_one(connection)

with self.assertRaises(ValueError):
cursor.run_sql_in_snapshot("sql")

def test_get_table_column_schema(self):
from google.cloud.spanner_dbapi.cursor import ColumnDetails
from google.cloud.spanner_dbapi import _helpers
Expand Down