69 changes: 50 additions & 19 deletions google/cloud/bigquery/_pandas_helpers.py
@@ -20,6 +20,7 @@
from itertools import islice
import logging
import queue
+import threading
import warnings
from typing import Any, Union, Optional, Callable, Generator, List

@@ -119,6 +120,21 @@ def __init__(self):
# be an atomic operation in the Python language definition (enforced by
# the global interpreter lock).
self.done = False
+        # To assist with testing and understanding the behavior of the
+        # download, use this object as shared state to track how many worker
+        # threads have started and have gracefully shut down.
+        self._started_workers_lock = threading.Lock()
+        self.started_workers = 0
+        self._finished_workers_lock = threading.Lock()
+        self.finished_workers = 0
+
+    def start(self):
+        with self._started_workers_lock:
+            self.started_workers += 1
+
+    def finish(self):
+        with self._finished_workers_lock:
+            self.finished_workers += 1


BQ_FIELD_TYPE_TO_ARROW_FIELD_METADATA = {
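The class above pairs a plain boolean shutdown flag (atomic under the GIL, so unlocked) with lock-guarded counters that exist purely for observability. A minimal self-contained sketch of the same pattern outside the library — the worker function and all names below are illustrative, not part of this codebase:

import threading


class DownloadState:
    def __init__(self):
        # Reads/writes of a simple attribute are atomic under the GIL,
        # so the shutdown flag needs no lock.
        self.done = False
        self._lock = threading.Lock()
        self.started_workers = 0
        self.finished_workers = 0

    def start(self):
        with self._lock:
            self.started_workers += 1

    def finish(self):
        with self._lock:
            self.finished_workers += 1


def worker(state):
    state.start()
    try:
        while not state.done:
            pass  # one unit of work per loop iteration
    finally:
        state.finish()


state = DownloadState()
threads = [threading.Thread(target=worker, args=(state,)) for _ in range(3)]
for t in threads:
    t.start()
state.done = True  # signal every worker to stop
for t in threads:
    t.join()
assert state.started_workers == state.finished_workers == 3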
@@ -786,25 +802,35 @@ def _bqstorage_page_to_dataframe(column_names, dtypes, page):
def _download_table_bqstorage_stream(
    download_state, bqstorage_client, session, stream, worker_queue, page_to_item
):
-    reader = bqstorage_client.read_rows(stream.name)
+    download_state.start()
+    try:
+        reader = bqstorage_client.read_rows(stream.name)

-    # Avoid deprecation warnings for passing in unnecessary read session.
-    # https://github.com/googleapis/python-bigquery-storage/issues/229
-    if _versions_helpers.BQ_STORAGE_VERSIONS.is_read_session_optional:
-        rowstream = reader.rows()
-    else:
-        rowstream = reader.rows(session)
-
-    for page in rowstream.pages:
-        item = page_to_item(page)
-        while True:
-            if download_state.done:
-                return
-            try:
-                worker_queue.put(item, timeout=_PROGRESS_INTERVAL)
-                break
-            except queue.Full:  # pragma: NO COVER
-                continue
+        # Avoid deprecation warnings for passing in unnecessary read session.
+        # https://github.com/googleapis/python-bigquery-storage/issues/229
+        if _versions_helpers.BQ_STORAGE_VERSIONS.is_read_session_optional:
+            rowstream = reader.rows()
+        else:
+            rowstream = reader.rows(session)
+
+        for page in rowstream.pages:
+            item = page_to_item(page)
+
+            # Make sure we set a timeout on put() so that we give the worker
+            # thread opportunities to shut down gracefully, for example if the
+            # parent thread shuts down or the parent generator object which
+            # collects rows from all workers goes out of scope. See:
+            # https://github.com/googleapis/python-bigquery/issues/2032
+            while True:
+                if download_state.done:
+                    return
+                try:
+                    worker_queue.put(item, timeout=_PROGRESS_INTERVAL)
+                    break
+                except queue.Full:
+                    continue
+    finally:
+        download_state.finish()


def _nowait(futures):
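The timeout on worker_queue.put() is the heart of the fix: with a bounded queue and no consumer, a bare put() would block forever, and the worker could never re-check download_state.done. A runnable sketch of the failure mode and the escape hatch, under illustrative names (PROGRESS_INTERVAL stands in for _PROGRESS_INTERVAL):

import queue
import threading
import time

PROGRESS_INTERVAL = 0.1  # seconds


class State:
    def __init__(self):
        self.done = False


def producer(q, state):
    item = object()
    # Wake up every PROGRESS_INTERVAL to re-check the shutdown flag
    # instead of blocking indefinitely on a full queue.
    while True:
        if state.done:
            return
        try:
            q.put(item, timeout=PROGRESS_INTERVAL)
            break
        except queue.Full:
            continue


state = State()
q = queue.Queue(maxsize=1)
q.put(object())  # fill the queue: the consumer has gone away
t = threading.Thread(target=producer, args=(q, state))
t.start()
time.sleep(0.3)    # the worker is cycling through timeouts, not stuck
state.done = True  # ask the worker to quit
t.join(timeout=2)
assert not t.is_alive()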
@@ -830,6 +856,7 @@ def _download_table_bqstorage(
page_to_item: Optional[Callable] = None,
max_queue_size: Any = _MAX_QUEUE_SIZE_DEFAULT,
max_stream_count: Optional[int] = None,
+    download_state: Optional[_DownloadState] = None,
) -> Generator[Any, None, None]:
"""Downloads a BigQuery table using the BigQuery Storage API.

@@ -857,6 +884,9 @@
is True, the requested streams are limited to 1 regardless of the
`max_stream_count` value. If 0 or None, then the number of
requested streams will be unbounded. Defaults to None.
+        download_state (Optional[_DownloadState]):
+            A threadsafe state object which can be used to observe the
+            behavior of the worker threads created by this method.

Yields:
pandas.DataFrame: Pandas DataFrames, one for each chunk of data
@@ -915,7 +945,8 @@

# Use _DownloadState to notify worker threads when to quit.
# See: https://stackoverflow.com/a/29237343/101923
-    download_state = _DownloadState()
+    if download_state is None:
+        download_state = _DownloadState()

# Create a queue to collect frames as they are created in each thread.
#
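Constructing the default only when the caller passes nothing keeps existing behavior intact while letting tests inject their own state object and assert on the worker counters. A hedged sketch of the call shape, mirroring the unit test added below (argument values are illustrative):

download_state = _DownloadState()
frames = _download_table_bqstorage(
    "my-project",          # illustrative project ID
    table_ref,             # a TableReference
    bqstorage_client,      # a BigQueryReadClient (or a mock, as in the test)
    page_to_item=_bqstorage_page_to_arrow,
    download_state=download_state,  # injected for observability
)
next(iter(frames))  # consuming the generator spins up the workers
assert download_state.started_workers > 0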
15 changes: 1 addition & 14 deletions samples/tests/test_download_public_data.py
@@ -12,29 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import logging
-
import pytest

from .. import download_public_data

pytest.importorskip("google.cloud.bigquery_storage_v1")


-def test_download_public_data(
-    caplog: pytest.LogCaptureFixture, capsys: pytest.CaptureFixture[str]
-) -> None:
-    # Enable debug-level logging to verify the BigQuery Storage API is used.
-    caplog.set_level(logging.DEBUG)
-
+def test_download_public_data(capsys: pytest.CaptureFixture[str]) -> None:
download_public_data.download_public_data()
out, _ = capsys.readouterr()
assert "year" in out
assert "gender" in out
assert "name" in out

-    assert any(
-        "Started reading table 'bigquery-public-data.usa_names.usa_1910_current' with BQ Storage API session"
-        in message
-        for message in caplog.messages
-    )
17 changes: 2 additions & 15 deletions samples/tests/test_download_public_data_sandbox.py
@@ -12,29 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import logging
-
import pytest

from .. import download_public_data_sandbox

pytest.importorskip("google.cloud.bigquery_storage_v1")


-def test_download_public_data_sandbox(
-    caplog: pytest.LogCaptureFixture, capsys: pytest.CaptureFixture[str]
-) -> None:
-    # Enable debug-level logging to verify the BigQuery Storage API is used.
-    caplog.set_level(logging.DEBUG)
-
+def test_download_public_data_sandbox(capsys: pytest.CaptureFixture[str]) -> None:
download_public_data_sandbox.download_public_data_sandbox()
-    out, err = capsys.readouterr()
+    out, _ = capsys.readouterr()
assert "year" in out
assert "gender" in out
assert "name" in out

-    assert any(
-        # An anonymous table is used because this sample reads from query results.
-        ("Started reading table" in message and "BQ Storage API session" in message)
-        for message in caplog.messages
-    )
93 changes: 93 additions & 0 deletions tests/unit/test__pandas_helpers.py
@@ -16,6 +16,7 @@
import datetime
import decimal
import functools
+import gc
import operator
import queue
from typing import Union
@@ -1846,6 +1847,98 @@ def fake_download_stream(
assert queue_used.maxsize == expected_maxsize


+@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
+def test__download_table_bqstorage_shuts_down_workers(
+    monkeypatch,
+    module_under_test,
+):
+    """Regression test for https://github.com/googleapis/python-bigquery/issues/2032
+
+    Make sure that when the top-level iterator goes out of scope (is deleted),
+    the child threads are also stopped.
+    """
+    from google.cloud.bigquery import dataset
+    from google.cloud.bigquery import table
+    import google.cloud.bigquery_storage_v1.reader
+    import google.cloud.bigquery_storage_v1.types
+
+    monkeypatch.setattr(
+        _versions_helpers.BQ_STORAGE_VERSIONS, "_installed_version", None
+    )
+    monkeypatch.setattr(bigquery_storage, "__version__", "2.5.0")
+
+    # Create a fake stream with a decent number of rows.
+    arrow_schema = pyarrow.schema(
+        [
+            ("int_col", pyarrow.int64()),
+            ("str_col", pyarrow.string()),
+        ]
+    )
+    arrow_rows = pyarrow.record_batch(
+        [
+            pyarrow.array([0, 1, 2], type=pyarrow.int64()),
+            pyarrow.array(["a", "b", "c"], type=pyarrow.string()),
+        ],
+        schema=arrow_schema,
+    )
+    session = google.cloud.bigquery_storage_v1.types.ReadSession()
+    session.data_format = "ARROW"
+    session.arrow_schema = {"serialized_schema": arrow_schema.serialize().to_pybytes()}
+    session.streams = [
+        google.cloud.bigquery_storage_v1.types.ReadStream(name=name)
+        for name in ("stream/s0", "stream/s1", "stream/s2")
+    ]
+    bqstorage_client = mock.create_autospec(
+        bigquery_storage.BigQueryReadClient, instance=True
+    )
+    reader = mock.create_autospec(
+        google.cloud.bigquery_storage_v1.reader.ReadRowsStream, instance=True
+    )
+    reader.__iter__.return_value = [
+        google.cloud.bigquery_storage_v1.types.ReadRowsResponse(
+            arrow_schema={"serialized_schema": arrow_schema.serialize().to_pybytes()},
+            arrow_record_batch={
+                "serialized_record_batch": arrow_rows.serialize().to_pybytes()
+            },
+        )
+        for _ in range(100)
+    ]
+    reader.rows.return_value = google.cloud.bigquery_storage_v1.reader.ReadRowsIterable(
+        reader, read_session=session
+    )
+    bqstorage_client.read_rows.return_value = reader
+    bqstorage_client.create_read_session.return_value = session
+    table_ref = table.TableReference(
+        dataset.DatasetReference("project-x", "dataset-y"),
+        "table-z",
+    )
+    download_state = module_under_test._DownloadState()
+    assert download_state.started_workers == 0
+    assert download_state.finished_workers == 0
+
+    result_gen = module_under_test._download_table_bqstorage(
+        "some-project",
+        table_ref,
+        bqstorage_client,
+        max_queue_size=1,
+        page_to_item=module_under_test._bqstorage_page_to_arrow,
+        download_state=download_state,
+    )
+
+    result_gen_iter = iter(result_gen)
+    next(result_gen_iter)
+    assert download_state.started_workers == 3
+    assert download_state.finished_workers == 0
+
+    # Stop iteration early and simulate the variables going out of scope
+    # to be doubly sure that the worker threads are supposed to be cleaned up.
+    del result_gen, result_gen_iter
+    gc.collect()
+
+    assert download_state.started_workers == 3
+    assert download_state.finished_workers == 3
+
+
@pytest.mark.skipif(isinstance(pyarrow, mock.Mock), reason="Requires `pyarrow`")
def test_download_arrow_row_iterator_unknown_field_type(module_under_test):
fake_page = api_core.page_iterator.Page(
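The del plus gc.collect() step in the test above leans on standard generator finalization: when the last reference to a suspended generator is dropped, Python throws GeneratorExit into its frame, so any try/finally inside it gets a chance to run cleanup. A self-contained sketch of just that mechanism, independent of the library:

import gc


def producer(state):
    try:
        while True:
            yield "frame"
    finally:
        # Runs when the generator is closed or garbage-collected.
        state["done"] = True


state = {"done": False}
gen = producer(state)
next(gen)     # start the generator so it is suspended at the yield
del gen       # drop the last reference...
gc.collect()  # ...and force finalization (CPython usually finalizes
              # immediately via refcounting; the collect helps elsewhere)
assert state["done"]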