diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py index aa6130e3f8016..afb3ac74b6d29 100644 --- a/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/airflow/providers/amazon/aws/operators/sagemaker.py @@ -17,7 +17,9 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Sequence +import time +import warnings +from typing import TYPE_CHECKING, Any, Callable, Sequence from botocore.exceptions import ClientError @@ -106,6 +108,41 @@ def _create_integer_fields(self) -> None: """ self.integer_fields = [] + def _get_unique_job_name( + self, proposed_name: str, fail_if_exists: bool, describe_func: Callable[[str], Any] + ) -> str: + """ + Returns the proposed name if it doesn't already exist, otherwise returns it with a timestamp suffix. + + :param proposed_name: Base name. + :param fail_if_exists: Will throw an error if a job with that name already exists + instead of finding a new name. + :param describe_func: The `describe_` function for that kind of job. + We use it as an O(1) way to check if a job exists. + """ + job_name = proposed_name + while self._check_if_job_exists(job_name, describe_func): + # this while should loop only once in most cases, just setting it this way to regenerate a name + # in case there is collision. + if fail_if_exists: + raise AirflowException(f"A SageMaker job with name {job_name} already exists.") + else: + job_name = f"{proposed_name}-{time.time_ns()//1000000}" + self.log.info("Changed job name to '%s' to avoid collision.", job_name) + return job_name + + def _check_if_job_exists(self, job_name, describe_func: Callable[[str], Any]) -> bool: + """Returns True if job exists, False otherwise.""" + try: + describe_func(job_name) + self.log.info("Found existing job with name '%s'.", job_name) + return True + except ClientError as e: + if e.response["Error"]["Code"] == "ValidationException": + return False # ValidationException is thrown when the job could not be found + else: + raise e + def execute(self, context: Context): raise NotImplementedError("Please implement execute() in sub class!") @@ -137,8 +174,8 @@ class SageMakerProcessingOperator(SageMakerBaseOperator): :param max_ingestion_time: If wait is set to True, the operation fails if the processing job doesn't finish within max_ingestion_time seconds. If you set this parameter to None, the operation does not timeout. - :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" - (default) and "fail". + :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "timestamp" + (default), "increment" (deprecated) and "fail". :return Dict: Returns The ARN of the processing job created in Amazon SageMaker. """ @@ -151,15 +188,22 @@ def __init__( print_log: bool = True, check_interval: int = CHECK_INTERVAL_SECOND, max_ingestion_time: int | None = None, - action_if_job_exists: str = "increment", + action_if_job_exists: str = "timestamp", **kwargs, ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) - if action_if_job_exists not in ("increment", "fail"): + if action_if_job_exists not in ("increment", "fail", "timestamp"): raise AirflowException( - f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \ + f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \ Provided value: '{action_if_job_exists}'." ) + if action_if_job_exists == "increment": + warnings.warn( + "Action 'increment' on job name conflict has been deprecated for performance reasons." + "The alternative to 'fail' is now 'timestamp'.", + DeprecationWarning, + stacklevel=2, + ) self.action_if_job_exists = action_if_job_exists self.wait_for_completion = wait_for_completion self.print_log = print_log @@ -183,21 +227,12 @@ def expand_role(self) -> None: def execute(self, context: Context) -> dict: self.preprocess_config() - processing_job_name = self.config["ProcessingJobName"] - processing_job_dedupe_pattern = "-[0-9]+$" - existing_jobs_found = self.hook.count_processing_jobs_by_name( - processing_job_name, processing_job_dedupe_pattern + + self.config["ProcessingJobName"] = self._get_unique_job_name( + self.config["ProcessingJobName"], + self.action_if_job_exists == "fail", + self.hook.describe_processing_job, ) - if existing_jobs_found: - if self.action_if_job_exists == "fail": - raise AirflowException( - f"A SageMaker processing job with name {processing_job_name} already exists." - ) - elif self.action_if_job_exists == "increment": - self.log.info("Found existing processing job with name '%s'.", processing_job_name) - new_processing_job_name = f"{processing_job_name}-{existing_jobs_found + 1}" - self.config["ProcessingJobName"] = new_processing_job_name - self.log.info("Incremented processing job name to '%s'.", new_processing_job_name) response = self.hook.create_processing_job( self.config, @@ -423,8 +458,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator): set this parameter to None, the operation does not timeout. :param check_if_job_exists: If set to true, then the operator will check whether a transform job already exists for the name in the config. - :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" - (default) and "fail". + :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "timestamp" + (default), "increment" (deprecated) and "fail". This is only relevant if check_if_job_exists is True. :return Dict: Returns The ARN of the model created in Amazon SageMaker. """ @@ -438,7 +473,7 @@ def __init__( check_interval: int = CHECK_INTERVAL_SECOND, max_ingestion_time: int | None = None, check_if_job_exists: bool = True, - action_if_job_exists: str = "increment", + action_if_job_exists: str = "timestamp", **kwargs, ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) @@ -446,11 +481,18 @@ def __init__( self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time self.check_if_job_exists = check_if_job_exists - if action_if_job_exists in ("increment", "fail"): + if action_if_job_exists in ("increment", "fail", "timestamp"): + if action_if_job_exists == "increment": + warnings.warn( + "Action 'increment' on job name conflict has been deprecated for performance reasons." + "The alternative to 'fail' is now 'timestamp'.", + DeprecationWarning, + stacklevel=2, + ) self.action_if_job_exists = action_if_job_exists else: raise AirflowException( - f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \ + f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \ Provided value: '{action_if_job_exists}'." ) @@ -476,13 +518,20 @@ def expand_role(self) -> None: def execute(self, context: Context) -> dict: self.preprocess_config() - model_config = self.config.get("Model") + transform_config = self.config.get("Transform", self.config) if self.check_if_job_exists: - self._check_if_transform_job_exists() + transform_config["TransformJobName"] = self._get_unique_job_name( + transform_config["TransformJobName"], + self.action_if_job_exists == "fail", + self.hook.describe_transform_job, + ) + + model_config = self.config.get("Model") if model_config: self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"]) self.hook.create_model(model_config) + self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"]) response = self.hook.create_transform_job( transform_config, @@ -500,21 +549,6 @@ def execute(self, context: Context) -> dict: ), } - def _check_if_transform_job_exists(self) -> None: - transform_config = self.config.get("Transform", self.config) - transform_job_name = transform_config["TransformJobName"] - transform_jobs = self.hook.list_transform_jobs(name_contains=transform_job_name) - if transform_job_name in [tj["TransformJobName"] for tj in transform_jobs]: - if self.action_if_job_exists == "increment": - self.log.info("Found existing transform job with name '%s'.", transform_job_name) - new_transform_job_name = f"{transform_job_name}-{(len(transform_jobs) + 1)}" - transform_config["TransformJobName"] = new_transform_job_name - self.log.info("Incremented transform job name to '%s'.", new_transform_job_name) - elif self.action_if_job_exists == "fail": - raise AirflowException( - f"A SageMaker transform job with name {transform_job_name} already exists." - ) - class SageMakerTuningOperator(SageMakerBaseOperator): """ @@ -654,8 +688,8 @@ class SageMakerTrainingOperator(SageMakerBaseOperator): the operation does not timeout. :param check_if_job_exists: If set to true, then the operator will check whether a training job already exists for the name in the config. - :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment" - (default) and "fail". + :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "timestamp" + (default), "increment" (deprecated) and "fail". This is only relevant if check_if_job_exists is True. :return Dict: Returns The ARN of the training job created in Amazon SageMaker. """ @@ -670,7 +704,7 @@ def __init__( check_interval: int = CHECK_INTERVAL_SECOND, max_ingestion_time: int | None = None, check_if_job_exists: bool = True, - action_if_job_exists: str = "increment", + action_if_job_exists: str = "timestamp", **kwargs, ): super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs) @@ -679,11 +713,18 @@ def __init__( self.check_interval = check_interval self.max_ingestion_time = max_ingestion_time self.check_if_job_exists = check_if_job_exists - if action_if_job_exists in ("increment", "fail"): + if action_if_job_exists in {"timestamp", "increment", "fail"}: + if action_if_job_exists == "increment": + warnings.warn( + "Action 'increment' on job name conflict has been deprecated for performance reasons." + "The alternative to 'fail' is now 'timestamp'.", + DeprecationWarning, + stacklevel=2, + ) self.action_if_job_exists = action_if_job_exists else: raise AirflowException( - f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \ + f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \ Provided value: '{action_if_job_exists}'." ) @@ -703,8 +744,14 @@ def _create_integer_fields(self) -> None: def execute(self, context: Context) -> dict: self.preprocess_config() + if self.check_if_job_exists: - self._check_if_job_exists() + self.config["TrainingJobName"] = self._get_unique_job_name( + self.config["TrainingJobName"], + self.action_if_job_exists == "fail", + self.hook.describe_training_job, + ) + self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"]) response = self.hook.create_training_job( self.config, @@ -718,20 +765,6 @@ def execute(self, context: Context) -> dict: else: return {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))} - def _check_if_job_exists(self) -> None: - training_job_name = self.config["TrainingJobName"] - training_jobs = self.hook.list_training_jobs(name_contains=training_job_name) - if training_job_name in [tj["TrainingJobName"] for tj in training_jobs]: - if self.action_if_job_exists == "increment": - self.log.info("Found existing training job with name '%s'.", training_job_name) - new_training_job_name = f"{training_job_name}-{(len(training_jobs) + 1)}" - self.config["TrainingJobName"] = new_training_job_name - self.log.info("Incremented training job name to '%s'.", new_training_job_name) - elif self.action_if_job_exists == "fail": - raise AirflowException( - f"A SageMaker training job with name {training_job_name} already exists." - ) - class SageMakerDeleteModelOperator(SageMakerBaseOperator): """ diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_base.py b/tests/providers/amazon/aws/operators/test_sagemaker_base.py index 0ecab180ca099..7e5bb06ffe0ab 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_base.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_base.py @@ -16,11 +16,15 @@ # under the License. from __future__ import annotations +import re from typing import Any from unittest import mock -from unittest.mock import patch +from unittest.mock import MagicMock, patch -from airflow import DAG +import pytest +from botocore.exceptions import ClientError + +from airflow import DAG, AirflowException from airflow.models import DagRun, TaskInstance from airflow.providers.amazon.aws.operators.sagemaker import ( SageMakerBaseOperator, @@ -35,6 +39,8 @@ class TestSageMakerBaseOperator: + ERROR_WHEN_RESOURCE_NOT_FOUND = ClientError({"Error": {"Code": "ValidationException"}}, "op") + def setup_method(self): self.sagemaker = SageMakerBaseOperator(task_id="test_sagemaker_operator", config=CONFIG) self.sagemaker.aws_conn_id = "aws_default" @@ -46,9 +52,33 @@ def test_parse_integer(self): def test_default_integer_fields(self): self.sagemaker.preprocess_config() - assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS + def test_job_exists(self): + exists = self.sagemaker._check_if_job_exists("the name", lambda _: {}) + assert exists + + def test_job_does_not_exists(self): + def raiser(_): + raise self.ERROR_WHEN_RESOURCE_NOT_FOUND + + exists = self.sagemaker._check_if_job_exists("the name", raiser) + assert not exists + + def test_job_renamed(self): + describe_mock = MagicMock() + # scenario : name exists, new proposed name exists as well, second proposal is ok + describe_mock.side_effect = [None, None, self.ERROR_WHEN_RESOURCE_NOT_FOUND] + + name = self.sagemaker._get_unique_job_name("test", False, describe_mock) + + assert describe_mock.call_count == 3 + assert re.match("test-[0-9]+$", name) + + def test_job_not_unique_with_fail(self): + with pytest.raises(AirflowException): + self.sagemaker._get_unique_job_name("test", True, lambda _: None) + class TestSageMakerExperimentOperator: @patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.conn", new_callable=mock.PropertyMock) diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py index 431a91c4cfe31..19be1eabb150e 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_processing.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_processing.py @@ -19,6 +19,7 @@ from unittest import mock import pytest +from botocore.exceptions import ClientError from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook @@ -89,10 +90,12 @@ class TestSageMakerProcessingOperator: def setup_method(self): self.processing_config_kwargs = dict( - task_id="test_sagemaker_operator", wait_for_completion=False, check_interval=5 + task_id="test_sagemaker_operator", + wait_for_completion=False, + check_interval=5, ) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0) @mock.patch.object( SageMakerHook, @@ -100,18 +103,19 @@ def setup_method(self): return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, ) @mock.patch.object(sagemaker, "serialize", return_value="") - def test_integer_fields_without_stopping_condition( - self, serialize, mock_processing, mock_hook, mock_client - ): + def test_integer_fields_without_stopping_condition(self, _, __, ___, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS ) + sagemaker.execute(None) + assert sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS for (key1, key2, key3) in EXPECTED_INTEGER_FIELDS: assert sagemaker.config[key1][key2][key3] == int(sagemaker.config[key1][key2][key3]) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0) @mock.patch.object( SageMakerHook, @@ -119,7 +123,8 @@ def test_integer_fields_without_stopping_condition( return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, ) @mock.patch.object(sagemaker, "serialize", return_value="") - def test_integer_fields_with_stopping_condition(self, serialize, mock_processing, mock_hook, mock_client): + def test_integer_fields_with_stopping_condition(self, _, __, ___, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION ) @@ -134,7 +139,7 @@ def test_integer_fields_with_stopping_condition(self, serialize, mock_processing else: sagemaker.config[key1][key2] == int(sagemaker.config[key1][key2]) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0) @mock.patch.object( SageMakerHook, @@ -142,7 +147,8 @@ def test_integer_fields_with_stopping_condition(self, serialize, mock_processing return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, ) @mock.patch.object(sagemaker, "serialize", return_value="") - def test_execute(self, serialize, mock_processing, mock_hook, mock_client): + def test_execute(self, _, mock_processing, __, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS ) @@ -151,7 +157,7 @@ def test_execute(self, serialize, mock_processing, mock_hook, mock_client): CREATE_PROCESSING_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None ) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0) @mock.patch.object( SageMakerHook, @@ -159,7 +165,8 @@ def test_execute(self, serialize, mock_processing, mock_hook, mock_client): return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}, ) @mock.patch.object(sagemaker, "serialize", return_value="") - def test_execute_with_stopping_condition(self, serialize, mock_processing, mock_hook, mock_client): + def test_execute_with_stopping_condition(self, _, mock_processing, __, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS_WITH_STOPPING_CONDITION ) @@ -171,37 +178,36 @@ def test_execute_with_stopping_condition(self, serialize, mock_processing, mock_ max_ingestion_time=None, ) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") @mock.patch.object( SageMakerHook, "create_processing_job", return_value={"ProcessingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}, ) - def test_execute_with_failure(self, mock_processing, mock_client): + def test_execute_with_failure(self, _, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS ) with pytest.raises(AirflowException): sagemaker.execute(None) - @pytest.mark.skip("Currently, the auto-increment jobname functionality is not missing.") - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=1) @mock.patch.object( SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}} ) - def test_execute_with_existing_job_increment( - self, mock_create_processing_job, count_processing_jobs_by_name, mock_client - ): + def test_execute_with_existing_job_timestamp(self, mock_create_processing_job, _, mock_desc): + mock_desc.side_effect = [None, ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS ) - sagemaker.action_if_job_exists = "increment" + sagemaker.action_if_job_exists = "timestamp" sagemaker.execute(None) expected_config = CREATE_PROCESSING_PARAMS.copy() - # Expect to see ProcessingJobName suffixed with "-2" because we return one existing job - expected_config["ProcessingJobName"] = "job_name-2" + # Expect to see ProcessingJobName suffixed because we return one existing job + expected_config["ProcessingJobName"].startswith("job_name-") mock_create_processing_job.assert_called_once_with( expected_config, wait_for_completion=False, @@ -209,14 +215,12 @@ def test_execute_with_existing_job_increment( max_ingestion_time=None, ) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") @mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=1) @mock.patch.object( SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}} ) - def test_execute_with_existing_job_fail( - self, mock_create_processing_job, mock_list_processing_jobs, mock_client - ): + def test_execute_with_existing_job_fail(self, _, __, ___): sagemaker = SageMakerProcessingOperator( **self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS ) @@ -224,7 +228,7 @@ def test_execute_with_existing_job_fail( with pytest.raises(AirflowException): sagemaker.execute(None) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_processing_job") def test_action_if_job_exists_validation(self, mock_client): with pytest.raises(AirflowException): SageMakerProcessingOperator( diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_training.py b/tests/providers/amazon/aws/operators/test_sagemaker_training.py index 34398c9a78ab7..82b76fdee0422 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_training.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_training.py @@ -19,6 +19,7 @@ from unittest import mock import pytest +from botocore.exceptions import ClientError from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook @@ -65,43 +66,24 @@ def setup_method(self): check_interval=5, ) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_training_job") @mock.patch.object(SageMakerHook, "create_training_job") @mock.patch.object(sagemaker, "serialize", return_value="") - def test_integer_fields(self, serialize, mock_training, mock_client): + def test_integer_fields(self, _, mock_training, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] mock_training.return_value = { "TrainingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}, } - self.sagemaker._check_if_job_exists = mock.MagicMock() self.sagemaker.execute(None) assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS for (key1, key2) in EXPECTED_INTEGER_FIELDS: assert self.sagemaker.config[key1][key2] == int(self.sagemaker.config[key1][key2]) - @mock.patch.object(SageMakerHook, "get_conn") - @mock.patch.object(SageMakerHook, "create_training_job") - @mock.patch.object(sagemaker, "serialize", return_value="") - def test_execute_with_check_if_job_exists(self, serialize, mock_training, mock_client): - mock_training.return_value = { - "TrainingJobArn": "test_arn", - "ResponseMetadata": {"HTTPStatusCode": 200}, - } - self.sagemaker._check_if_job_exists = mock.MagicMock() - self.sagemaker.execute(None) - self.sagemaker._check_if_job_exists.assert_called_once() - mock_training.assert_called_once_with( - CREATE_TRAINING_PARAMS, - wait_for_completion=False, - print_log=True, - check_interval=5, - max_ingestion_time=None, - ) - - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_training_job") @mock.patch.object(SageMakerHook, "create_training_job") @mock.patch.object(sagemaker, "serialize", return_value="") - def test_execute_without_check_if_job_exists(self, serialize, mock_training, mock_client): + def test_execute_without_check_if_job_exists(self, _, mock_training, mock_desc): mock_training.return_value = { "TrainingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -118,34 +100,16 @@ def test_execute_without_check_if_job_exists(self, serialize, mock_training, moc max_ingestion_time=None, ) - @mock.patch.object(SageMakerHook, "get_conn") + a = [] + a.sort() + + @mock.patch.object(SageMakerHook, "describe_training_job") @mock.patch.object(SageMakerHook, "create_training_job") - def test_execute_with_failure(self, mock_training, mock_client): + def test_execute_with_failure(self, mock_training, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op")] mock_training.return_value = { "TrainingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}, } with pytest.raises(AirflowException): self.sagemaker.execute(None) - - @mock.patch.object(SageMakerHook, "get_conn") - @mock.patch.object(SageMakerHook, "list_training_jobs") - def test_check_if_job_exists_increment(self, mock_list_training_jobs, mock_client): - self.sagemaker.check_if_job_exists = True - self.sagemaker.action_if_job_exists = "increment" - mock_list_training_jobs.return_value = [{"TrainingJobName": "job_name"}] - self.sagemaker._check_if_job_exists() - - expected_config = CREATE_TRAINING_PARAMS.copy() - # Expect to see TrainingJobName suffixed with "-2" because we return one existing job - expected_config["TrainingJobName"] = "job_name-2" - assert self.sagemaker.config == expected_config - - @mock.patch.object(SageMakerHook, "get_conn") - @mock.patch.object(SageMakerHook, "list_training_jobs") - def test_check_if_job_exists_fail(self, mock_list_training_jobs, mock_client): - self.sagemaker.check_if_job_exists = True - self.sagemaker.action_if_job_exists = "fail" - mock_list_training_jobs.return_value = [{"TrainingJobName": "job_name"}] - with pytest.raises(AirflowException): - self.sagemaker._check_if_job_exists() diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py index d7bbc5b4b928d..482c7201ccc41 100644 --- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py +++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py @@ -21,6 +21,7 @@ from unittest import mock import pytest +from botocore.exceptions import ClientError from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook @@ -73,12 +74,14 @@ def setup_method(self): check_interval=5, ) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_transform_job") @mock.patch.object(SageMakerHook, "create_model") + @mock.patch.object(SageMakerHook, "describe_model") @mock.patch.object(SageMakerHook, "create_transform_job") @mock.patch.object(sagemaker, "serialize", return_value="") - def test_integer_fields(self, serialize, mock_transform, mock_model, mock_client): - mock_transform.return_value = { + def test_integer_fields(self, _, mock_create_transform, __, ___, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] + mock_create_transform.return_value = { "TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}, } @@ -91,11 +94,13 @@ def test_integer_fields(self, serialize, mock_transform, mock_model, mock_client else: self.sagemaker.config[key1][key2] == int(self.sagemaker.config[key1][key2]) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_transform_job") @mock.patch.object(SageMakerHook, "create_model") + @mock.patch.object(SageMakerHook, "describe_model") @mock.patch.object(SageMakerHook, "create_transform_job") @mock.patch.object(sagemaker, "serialize", return_value="") - def test_execute(self, serialize, mock_transform, mock_model, mock_client): + def test_execute(self, _, mock_transform, __, mock_model, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] mock_transform.return_value = { "TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -109,10 +114,11 @@ def test_execute(self, serialize, mock_transform, mock_model, mock_client): max_ingestion_time=None, ) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_transform_job") @mock.patch.object(SageMakerHook, "create_model") @mock.patch.object(SageMakerHook, "create_transform_job") - def test_execute_with_failure(self, mock_transform, mock_model, mock_client): + def test_execute_with_failure(self, mock_transform, _, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] mock_transform.return_value = { "TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}, @@ -120,17 +126,18 @@ def test_execute_with_failure(self, mock_transform, mock_model, mock_client): with pytest.raises(AirflowException): self.sagemaker.execute(None) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_transform_job") @mock.patch.object(SageMakerHook, "create_transform_job") + @mock.patch.object(SageMakerHook, "create_model") + @mock.patch.object(SageMakerHook, "describe_model") @mock.patch.object(sagemaker, "serialize", return_value="") - def test_execute_with_check_if_job_exists(self, serialize, mock_transform, mock_client): + def test_execute_with_check_if_job_exists(self, _, __, ___, mock_transform, mock_desc): + mock_desc.side_effect = [ClientError({"Error": {"Code": "ValidationException"}}, "op"), None] mock_transform.return_value = { "TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}, } - self.sagemaker._check_if_transform_job_exists = mock.MagicMock() self.sagemaker.execute(None) - self.sagemaker._check_if_transform_job_exists.assert_called_once() mock_transform.assert_called_once_with( CREATE_TRANSFORM_PARAMS_INTEGER_FIELDS, wait_for_completion=False, @@ -138,43 +145,21 @@ def test_execute_with_check_if_job_exists(self, serialize, mock_transform, mock_ max_ingestion_time=None, ) - @mock.patch.object(SageMakerHook, "get_conn") + @mock.patch.object(SageMakerHook, "describe_transform_job") @mock.patch.object(SageMakerHook, "create_transform_job") + @mock.patch.object(SageMakerHook, "create_model") + @mock.patch.object(SageMakerHook, "describe_model") @mock.patch.object(sagemaker, "serialize", return_value="") - def test_execute_without_check_if_job_exists(self, serialize, mock_transform, mock_client): + def test_execute_without_check_if_job_exists(self, _, __, ___, mock_transform, ____): mock_transform.return_value = { "TransformJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}, } self.sagemaker.check_if_job_exists = False - self.sagemaker._check_if_transform_job_exists = mock.MagicMock() self.sagemaker.execute(None) - self.sagemaker._check_if_transform_job_exists.assert_not_called() mock_transform.assert_called_once_with( CREATE_TRANSFORM_PARAMS_INTEGER_FIELDS, wait_for_completion=False, check_interval=5, max_ingestion_time=None, ) - - @mock.patch.object(SageMakerHook, "get_conn") - @mock.patch.object(SageMakerHook, "list_transform_jobs") - def test_check_if_job_exists_increment(self, mock_list_transform_jobs, mock_client): - self.sagemaker.check_if_job_exists = True - self.sagemaker.action_if_job_exists = "increment" - mock_list_transform_jobs.return_value = [{"TransformJobName": "job_name"}] - self.sagemaker._check_if_transform_job_exists() - - expected_config = copy.deepcopy(CONFIG) - # Expect to see TransformJobName suffixed with "-2" because we return one existing job - expected_config["Transform"]["TransformJobName"] = "job_name-2" - assert self.sagemaker.config == expected_config - - @mock.patch.object(SageMakerHook, "get_conn") - @mock.patch.object(SageMakerHook, "list_transform_jobs") - def test_check_if_job_exists_fail(self, mock_list_transform_jobs, mock_client): - self.sagemaker.check_if_job_exists = True - self.sagemaker.action_if_job_exists = "fail" - mock_list_transform_jobs.return_value = [{"TransformJobName": "job_name"}] - with pytest.raises(AirflowException): - self.sagemaker._check_if_transform_job_exists()