Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,12 @@ def __enter__(self):

def __exit__(self, exc_type, exc_val, exc_tb):
"""End ``with`` block."""
if isinstance(exc_val, NotFound):
# If NotFound exception occurs inside the with block
# then we validate if the session still exists.
if not self._session.exists():
self._session = self._database._pool._new_session()
self._session.create()
self._database._pool.put(self._session)


Expand Down
61 changes: 60 additions & 1 deletion tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import mock
from google.api_core import gapic_v1

from google.cloud.spanner_v1.param_types import INT64
from google.api_core.retry import Retry

Expand Down Expand Up @@ -1792,6 +1791,66 @@ class Testing(Exception):

self.assertIs(pool._session, session)

def test_context_mgr_session_not_found_error(self):
from google.cloud.exceptions import NotFound

database = _Database(self.DATABASE_NAME)
session = _Session(database, name="session-1")
session.exists = mock.MagicMock(return_value=False)
pool = database._pool = _Pool()
new_session = _Session(database, name="session-2")
new_session.create = mock.MagicMock(return_value=[])
pool._new_session = mock.MagicMock(return_value=new_session)

pool.put(session)
checkout = self._make_one(database)

self.assertEqual(pool._session, session)
with self.assertRaises(NotFound):
with checkout as _:
raise NotFound("Session not found")
# Assert that session-1 was removed from pool and new session was added.
self.assertEqual(pool._session, new_session)

def test_context_mgr_table_not_found_error(self):
from google.cloud.exceptions import NotFound

database = _Database(self.DATABASE_NAME)
session = _Session(database, name="session-1")
session.exists = mock.MagicMock(return_value=True)
pool = database._pool = _Pool()
pool._new_session = mock.MagicMock(return_value=[])

pool.put(session)
checkout = self._make_one(database)

self.assertEqual(pool._session, session)
with self.assertRaises(NotFound):
with checkout as _:
raise NotFound("Table not found")
# Assert that session-1 was not removed from pool.
self.assertEqual(pool._session, session)
pool._new_session.assert_not_called()

def test_context_mgr_unknown_error(self):
database = _Database(self.DATABASE_NAME)
session = _Session(database)
pool = database._pool = _Pool()
pool._new_session = mock.MagicMock(return_value=[])
pool.put(session)
checkout = self._make_one(database)

class Testing(Exception):
pass

self.assertEqual(pool._session, session)
with self.assertRaises(Testing):
with checkout as _:
raise Testing("Unknown error.")
# Assert that session-1 was not removed from pool.
self.assertEqual(pool._session, session)
pool._new_session.assert_not_called()


class TestBatchSnapshot(_BaseTest):
TABLE = "table_name"
Expand Down