diff --git a/airflow/providers/mongo/hooks/mongo.py b/airflow/providers/mongo/hooks/mongo.py index 90798d71ced06..63b981df9ef08 100644 --- a/airflow/providers/mongo/hooks/mongo.py +++ b/airflow/providers/mongo/hooks/mongo.py @@ -32,6 +32,8 @@ if TYPE_CHECKING: from types import TracebackType + from pymongo.collection import Collection as MongoCollection + from pymongo.command_cursor import CommandCursor from typing_extensions import Literal from airflow.models import Connection @@ -218,9 +220,7 @@ def _create_uri(self) -> str: path = f"/{self.connection.schema}" return urlunsplit((scheme, netloc, path, "", "")) - def get_collection( - self, mongo_collection: str, mongo_db: str | None = None - ) -> pymongo.collection.Collection: + def get_collection(self, mongo_collection: str, mongo_db: str | None = None) -> MongoCollection: """ Fetch a mongo collection object for querying. @@ -233,7 +233,7 @@ def get_collection( def aggregate( self, mongo_collection: str, aggregate_query: list, mongo_db: str | None = None, **kwargs - ) -> pymongo.command_cursor.CommandCursor: + ) -> CommandCursor: """ Run an aggregation pipeline and returns the results.