Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 95 additions & 62 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Sequence
import time
import warnings
from typing import TYPE_CHECKING, Any, Callable, Sequence

from botocore.exceptions import ClientError

Expand Down Expand Up @@ -106,6 +108,41 @@ def _create_integer_fields(self) -> None:
"""
self.integer_fields = []

def _get_unique_job_name(
self, proposed_name: str, fail_if_exists: bool, describe_func: Callable[[str], Any]
) -> str:
"""
Returns the proposed name if it doesn't already exist, otherwise returns it with a timestamp suffix.

:param proposed_name: Base name.
:param fail_if_exists: Will throw an error if a job with that name already exists
instead of finding a new name.
:param describe_func: The `describe_` function for that kind of job.
We use it as an O(1) way to check if a job exists.
"""
job_name = proposed_name
while self._check_if_job_exists(job_name, describe_func):
# this while should loop only once in most cases, just setting it this way to regenerate a name
# in case there is collision.
if fail_if_exists:
raise AirflowException(f"A SageMaker job with name {job_name} already exists.")
else:
job_name = f"{proposed_name}-{time.time_ns()//1000000}"
self.log.info("Changed job name to '%s' to avoid collision.", job_name)
return job_name

def _check_if_job_exists(self, job_name, describe_func: Callable[[str], Any]) -> bool:
"""Returns True if job exists, False otherwise."""
try:
describe_func(job_name)
self.log.info("Found existing job with name '%s'.", job_name)
return True
except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException":
return False # ValidationException is thrown when the job could not be found
else:
raise e

def execute(self, context: Context):
raise NotImplementedError("Please implement execute() in sub class!")

Expand Down Expand Up @@ -137,8 +174,8 @@ class SageMakerProcessingOperator(SageMakerBaseOperator):
:param max_ingestion_time: If wait is set to True, the operation fails if the processing job
doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
the operation does not timeout.
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
(default) and "fail".
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "timestamp"
(default), "increment" (deprecated) and "fail".
:return Dict: Returns The ARN of the processing job created in Amazon SageMaker.
"""

Expand All @@ -151,15 +188,22 @@ def __init__(
print_log: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: int | None = None,
action_if_job_exists: str = "increment",
action_if_job_exists: str = "timestamp",
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
if action_if_job_exists not in ("increment", "fail"):
if action_if_job_exists not in ("increment", "fail", "timestamp"):
raise AirflowException(
f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \
f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
Provided value: '{action_if_job_exists}'."
)
if action_if_job_exists == "increment":
warnings.warn(
"Action 'increment' on job name conflict has been deprecated for performance reasons."
"The alternative to 'fail' is now 'timestamp'.",
DeprecationWarning,
stacklevel=2,
)
self.action_if_job_exists = action_if_job_exists
self.wait_for_completion = wait_for_completion
self.print_log = print_log
Expand All @@ -183,21 +227,12 @@ def expand_role(self) -> None:

def execute(self, context: Context) -> dict:
self.preprocess_config()
processing_job_name = self.config["ProcessingJobName"]
processing_job_dedupe_pattern = "-[0-9]+$"
existing_jobs_found = self.hook.count_processing_jobs_by_name(
processing_job_name, processing_job_dedupe_pattern

self.config["ProcessingJobName"] = self._get_unique_job_name(
self.config["ProcessingJobName"],
self.action_if_job_exists == "fail",
self.hook.describe_processing_job,
)
if existing_jobs_found:
if self.action_if_job_exists == "fail":
raise AirflowException(
f"A SageMaker processing job with name {processing_job_name} already exists."
)
elif self.action_if_job_exists == "increment":
self.log.info("Found existing processing job with name '%s'.", processing_job_name)
new_processing_job_name = f"{processing_job_name}-{existing_jobs_found + 1}"
self.config["ProcessingJobName"] = new_processing_job_name
self.log.info("Incremented processing job name to '%s'.", new_processing_job_name)

response = self.hook.create_processing_job(
self.config,
Expand Down Expand Up @@ -423,8 +458,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
set this parameter to None, the operation does not timeout.
:param check_if_job_exists: If set to true, then the operator will check whether a transform job
already exists for the name in the config.
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
(default) and "fail".
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "timestamp"
(default), "increment" (deprecated) and "fail".
This is only relevant if check_if_job_exists is True.
:return Dict: Returns The ARN of the model created in Amazon SageMaker.
"""
Expand All @@ -438,19 +473,26 @@ def __init__(
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: int | None = None,
check_if_job_exists: bool = True,
action_if_job_exists: str = "increment",
action_if_job_exists: str = "timestamp",
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.check_if_job_exists = check_if_job_exists
if action_if_job_exists in ("increment", "fail"):
if action_if_job_exists in ("increment", "fail", "timestamp"):
if action_if_job_exists == "increment":
warnings.warn(
"Action 'increment' on job name conflict has been deprecated for performance reasons."
"The alternative to 'fail' is now 'timestamp'.",
DeprecationWarning,
stacklevel=2,
)
self.action_if_job_exists = action_if_job_exists
else:
raise AirflowException(
f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \
f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
Provided value: '{action_if_job_exists}'."
)

Expand All @@ -476,13 +518,20 @@ def expand_role(self) -> None:

def execute(self, context: Context) -> dict:
self.preprocess_config()
model_config = self.config.get("Model")

transform_config = self.config.get("Transform", self.config)
if self.check_if_job_exists:
self._check_if_transform_job_exists()
transform_config["TransformJobName"] = self._get_unique_job_name(
transform_config["TransformJobName"],
self.action_if_job_exists == "fail",
self.hook.describe_transform_job,
)

model_config = self.config.get("Model")
if model_config:
self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"])
self.hook.create_model(model_config)

self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"])
response = self.hook.create_transform_job(
transform_config,
Expand All @@ -500,21 +549,6 @@ def execute(self, context: Context) -> dict:
),
}

def _check_if_transform_job_exists(self) -> None:
transform_config = self.config.get("Transform", self.config)
transform_job_name = transform_config["TransformJobName"]
transform_jobs = self.hook.list_transform_jobs(name_contains=transform_job_name)
if transform_job_name in [tj["TransformJobName"] for tj in transform_jobs]:
if self.action_if_job_exists == "increment":
self.log.info("Found existing transform job with name '%s'.", transform_job_name)
new_transform_job_name = f"{transform_job_name}-{(len(transform_jobs) + 1)}"
transform_config["TransformJobName"] = new_transform_job_name
self.log.info("Incremented transform job name to '%s'.", new_transform_job_name)
elif self.action_if_job_exists == "fail":
raise AirflowException(
f"A SageMaker transform job with name {transform_job_name} already exists."
)


class SageMakerTuningOperator(SageMakerBaseOperator):
"""
Expand Down Expand Up @@ -654,8 +688,8 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
the operation does not timeout.
:param check_if_job_exists: If set to true, then the operator will check whether a training job
already exists for the name in the config.
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
(default) and "fail".
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "timestamp"
(default), "increment" (deprecated) and "fail".
Comment thread
vandonr-amz marked this conversation as resolved.
This is only relevant if check_if_job_exists is True.
:return Dict: Returns The ARN of the training job created in Amazon SageMaker.
"""
Expand All @@ -670,7 +704,7 @@ def __init__(
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: int | None = None,
check_if_job_exists: bool = True,
action_if_job_exists: str = "increment",
action_if_job_exists: str = "timestamp",
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
Expand All @@ -679,11 +713,18 @@ def __init__(
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.check_if_job_exists = check_if_job_exists
if action_if_job_exists in ("increment", "fail"):
if action_if_job_exists in {"timestamp", "increment", "fail"}:
if action_if_job_exists == "increment":
warnings.warn(
"Action 'increment' on job name conflict has been deprecated for performance reasons."
"The alternative to 'fail' is now 'timestamp'.",
DeprecationWarning,
stacklevel=2,
)
self.action_if_job_exists = action_if_job_exists
else:
raise AirflowException(
f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \
f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
Provided value: '{action_if_job_exists}'."
)

Expand All @@ -703,8 +744,14 @@ def _create_integer_fields(self) -> None:

def execute(self, context: Context) -> dict:
self.preprocess_config()

if self.check_if_job_exists:
self._check_if_job_exists()
self.config["TrainingJobName"] = self._get_unique_job_name(
self.config["TrainingJobName"],
self.action_if_job_exists == "fail",
self.hook.describe_training_job,
)

self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"])
response = self.hook.create_training_job(
self.config,
Expand All @@ -718,20 +765,6 @@ def execute(self, context: Context) -> dict:
else:
return {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}

def _check_if_job_exists(self) -> None:
training_job_name = self.config["TrainingJobName"]
training_jobs = self.hook.list_training_jobs(name_contains=training_job_name)
if training_job_name in [tj["TrainingJobName"] for tj in training_jobs]:
if self.action_if_job_exists == "increment":
self.log.info("Found existing training job with name '%s'.", training_job_name)
new_training_job_name = f"{training_job_name}-{(len(training_jobs) + 1)}"
self.config["TrainingJobName"] = new_training_job_name
self.log.info("Incremented training job name to '%s'.", new_training_job_name)
elif self.action_if_job_exists == "fail":
raise AirflowException(
f"A SageMaker training job with name {training_job_name} already exists."
)


class SageMakerDeleteModelOperator(SageMakerBaseOperator):
"""
Expand Down
36 changes: 33 additions & 3 deletions tests/providers/amazon/aws/operators/test_sagemaker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
# under the License.
from __future__ import annotations

import re
from typing import Any
from unittest import mock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from airflow import DAG
import pytest
from botocore.exceptions import ClientError

from airflow import DAG, AirflowException
from airflow.models import DagRun, TaskInstance
from airflow.providers.amazon.aws.operators.sagemaker import (
SageMakerBaseOperator,
Expand All @@ -35,6 +39,8 @@


class TestSageMakerBaseOperator:
ERROR_WHEN_RESOURCE_NOT_FOUND = ClientError({"Error": {"Code": "ValidationException"}}, "op")

def setup_method(self):
self.sagemaker = SageMakerBaseOperator(task_id="test_sagemaker_operator", config=CONFIG)
self.sagemaker.aws_conn_id = "aws_default"
Expand All @@ -46,9 +52,33 @@ def test_parse_integer(self):

def test_default_integer_fields(self):
self.sagemaker.preprocess_config()

assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS

def test_job_exists(self):
exists = self.sagemaker._check_if_job_exists("the name", lambda _: {})
assert exists

def test_job_does_not_exists(self):
def raiser(_):
raise self.ERROR_WHEN_RESOURCE_NOT_FOUND

exists = self.sagemaker._check_if_job_exists("the name", raiser)
assert not exists

def test_job_renamed(self):
describe_mock = MagicMock()
# scenario : name exists, new proposed name exists as well, second proposal is ok
describe_mock.side_effect = [None, None, self.ERROR_WHEN_RESOURCE_NOT_FOUND]

name = self.sagemaker._get_unique_job_name("test", False, describe_mock)

assert describe_mock.call_count == 3
assert re.match("test-[0-9]+$", name)

def test_job_not_unique_with_fail(self):
with pytest.raises(AirflowException):
self.sagemaker._get_unique_job_name("test", True, lambda _: None)


class TestSageMakerExperimentOperator:
@patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.conn", new_callable=mock.PropertyMock)
Expand Down
Loading