diff --git a/UPDATING.md b/UPDATING.md
index a09271f909f97..6538a71745f1b 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -52,6 +52,7 @@ Migrated are:
|-----------------------------------------------------------------|----------------------------------------------------------|
| airflow.contrib.hooks.aws_athena_hook.AWSAthenaHook | airflow.providers.aws.hooks.athena.AWSAthenaHook |
| airflow.contrib.operators.aws_athena_operator.AWSAthenaOperator | airflow.providers.aws.operators.athena.AWSAthenaOperator |
+| airflow.contrib.operators.awsbatch_operator.AWSBatchOperator | airflow.providers.aws.operators.batch.AWSBatchOperator |
| airflow.contrib.sensors.aws_athena_sensor.AthenaSensor | airflow.providers.aws.sensors.athena.AthenaSensor |
### Additional arguments passed to BaseOperator cause an exception
diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py
index 22508d17c56a2..b0955d704d784 100644
--- a/airflow/contrib/operators/awsbatch_operator.py
+++ b/airflow/contrib/operators/awsbatch_operator.py
@@ -16,194 +16,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-import sys
-from math import pow
-from random import randint
-from time import sleep
-from typing import Optional
-
-from airflow.contrib.hooks.aws_hook import AwsHook
-from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator
-from airflow.typing import Protocol
-from airflow.utils.decorators import apply_defaults
-
-
-class BatchProtocol(Protocol):
- def submit_job(self, jobName, jobQueue, jobDefinition, containerOverrides):
- ...
-
- def get_waiter(self, x: str):
- ...
-
- def describe_jobs(self, jobs):
- ...
-
- def terminate_job(self, jobId: str, reason: str):
- ...
-
-
-class AWSBatchOperator(BaseOperator):
- """
- Execute a job on AWS Batch Service
-
- .. warning: the queue parameter was renamed to job_queue to segregate the
- internal CeleryExecutor queue from the AWS Batch internal queue.
-
- :param job_name: the name for the job that will run on AWS Batch (templated)
- :type job_name: str
- :param job_definition: the job definition name on AWS Batch
- :type job_definition: str
- :param job_queue: the queue name on AWS Batch
- :type job_queue: str
- :param overrides: the same parameter that boto3 will receive on
- containerOverrides (templated):
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
- :type overrides: dict
- :param array_properties: the same parameter that boto3 will receive on
- arrayProperties:
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
- :type array_properties: dict
- :param max_retries: exponential backoff retries while waiter is not
- merged, 4200 = 48 hours
- :type max_retries: int
- :param aws_conn_id: connection id of AWS credentials / region name. If None,
- credential boto3 strategy will be used
- (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
- :type aws_conn_id: str
- :param region_name: region name to use in AWS Hook.
- Override the region_name in connection (if provided)
- :type region_name: str
- """
-
- ui_color = '#c3dae0'
- client = None # type: Optional[BatchProtocol]
- arn = None # type: Optional[str]
- template_fields = ('job_name', 'overrides',)
-
- @apply_defaults
- def __init__(self, job_name, job_definition, job_queue, overrides, array_properties=None,
- max_retries=4200, aws_conn_id=None, region_name=None, **kwargs):
- super().__init__(**kwargs)
-
- self.job_name = job_name
- self.aws_conn_id = aws_conn_id
- self.region_name = region_name
- self.job_definition = job_definition
- self.job_queue = job_queue
- self.overrides = overrides
- self.array_properties = array_properties
- self.max_retries = max_retries
-
- self.jobId = None # pylint: disable=invalid-name
- self.jobName = None # pylint: disable=invalid-name
-
- self.hook = self.get_hook()
-
- def execute(self, context):
- self.log.info(
- 'Running AWS Batch Job - Job definition: %s - on queue %s',
- self.job_definition, self.job_queue
- )
- self.log.info('AWSBatchOperator overrides: %s', self.overrides)
-
- self.client = self.hook.get_client_type(
- 'batch',
- region_name=self.region_name
- )
-
- try:
- response = self.client.submit_job(
- jobName=self.job_name,
- jobQueue=self.job_queue,
- jobDefinition=self.job_definition,
- arrayProperties=self.array_properties,
- containerOverrides=self.overrides)
-
- self.log.info('AWS Batch Job started: %s', response)
-
- self.jobId = response['jobId']
- self.jobName = response['jobName']
-
- self._wait_for_task_ended()
-
- self._check_success_task()
-
- self.log.info('AWS Batch Job has been successfully executed: %s', response)
- except Exception as e:
- self.log.info('AWS Batch Job has failed executed')
- raise AirflowException(e)
-
- def _wait_for_task_ended(self):
- """
- Try to use a waiter from the below pull request
-
- * https://github.com/boto/botocore/pull/1307
-
- If the waiter is not available apply a exponential backoff
-
- * docs.aws.amazon.com/general/latest/gr/api-retries.html
- """
- try:
- waiter = self.client.get_waiter('job_execution_complete')
- waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
- waiter.wait(jobs=[self.jobId])
- except ValueError:
- # If waiter not available use expo
-
- # Allow a batch job some time to spin up. A random interval
- # decreases the chances of exceeding an AWS API throttle
- # limit when there are many concurrent tasks.
- pause = randint(5, 30)
-
- retries = 1
- while retries <= self.max_retries:
- self.log.info('AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds',
- self.jobId, retries, self.max_retries, pause)
- sleep(pause)
-
- response = self.client.describe_jobs(jobs=[self.jobId])
- status = response['jobs'][-1]['status']
- self.log.info('AWS Batch job (%s) status: %s', self.jobId, status)
- if status in ['SUCCEEDED', 'FAILED']:
- break
-
- retries += 1
- pause = 1 + pow(retries * 0.3, 2)
-
- def _check_success_task(self):
- response = self.client.describe_jobs(
- jobs=[self.jobId],
- )
-
- self.log.info('AWS Batch stopped, check status: %s', response)
- if len(response.get('jobs')) < 1:
- raise AirflowException('No job found for {}'.format(response))
- for job in response['jobs']:
- job_status = job['status']
- if job_status == 'FAILED':
- reason = job['statusReason']
- raise AirflowException('Job failed with status {}'.format(reason))
- elif job_status in [
- 'SUBMITTED',
- 'PENDING',
- 'RUNNABLE',
- 'STARTING',
- 'RUNNING'
- ]:
- raise AirflowException(
- 'This task is still pending {}'.format(job_status))
+"""This module is deprecated. Please use `airflow.providers.aws.operators.batch`."""
- def get_hook(self):
- return AwsHook(
- aws_conn_id=self.aws_conn_id
- )
+import warnings
- def on_kill(self):
- response = self.client.terminate_job(
- jobId=self.jobId,
- reason='Task killed by the user')
+# pylint: disable=unused-import
+from airflow.providers.aws.operators.batch import AWSBatchOperator, BatchProtocol # noqa
- self.log.info(response)
+# Module-level call: it runs when this compatibility module is imported.
+# stacklevel=2 makes the warning point at the caller's frame (the code doing
+# the import) instead of at this line.
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.aws.operators.batch`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/aws/operators/batch.py b/airflow/providers/aws/operators/batch.py
new file mode 100644
index 0000000000000..22508d17c56a2
--- /dev/null
+++ b/airflow/providers/aws/operators/batch.py
@@ -0,0 +1,209 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import sys
+from math import pow
+from random import randint
+from time import sleep
+from typing import Optional
+
+from airflow.contrib.hooks.aws_hook import AwsHook
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.typing import Protocol
+from airflow.utils.decorators import apply_defaults
+
+
+class BatchProtocol(Protocol):
+    """
+    Structural type for the subset of the AWS Batch client used by
+    :class:`AWSBatchOperator`.
+
+    Parameter names keep the camelCase spelling used at the call sites in this
+    module, where every argument is passed by keyword.
+    """
+
+    # pylint: disable=invalid-name
+    def submit_job(self, jobName, jobQueue, jobDefinition, containerOverrides, arrayProperties=None):
+        """
+        Submit a job.
+
+        ``arrayProperties`` is part of the protocol because the operator passes it
+        on every submission; it is declared last, with a default, so the positional
+        order of the existing parameters is unchanged.
+
+        The operator reads the ``jobId`` and ``jobName`` keys of the returned mapping.
+        """
+        ...
+
+    def get_waiter(self, x: str):
+        """
+        Return the waiter named ``x``.
+
+        The operator sets ``config.max_attempts`` on the result and calls
+        ``wait(jobs=[...])``; it treats a ``ValueError`` as "waiter not available".
+        """
+        ...
+
+    def describe_jobs(self, jobs):
+        """
+        Describe the jobs whose ids are listed in ``jobs``.
+
+        The operator reads the ``jobs`` list of the returned mapping and the
+        ``status`` / ``statusReason`` keys of its entries.
+        """
+        ...
+
+    def terminate_job(self, jobId: str, reason: str):
+        """Terminate the job ``jobId``, recording ``reason`` as the explanation."""
+        ...
+
+
+class AWSBatchOperator(BaseOperator):
+    """
+    Execute a job on AWS Batch Service
+
+    .. warning::
+        the queue parameter was renamed to job_queue to segregate the
+        internal CeleryExecutor queue from the AWS Batch internal queue.
+
+    :param job_name: the name for the job that will run on AWS Batch (templated)
+    :type job_name: str
+    :param job_definition: the job definition name on AWS Batch
+    :type job_definition: str
+    :param job_queue: the queue name on AWS Batch
+    :type job_queue: str
+    :param overrides: the same parameter that boto3 will receive on
+        containerOverrides (templated):
+        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
+    :type overrides: dict
+    :param array_properties: the same parameter that boto3 will receive on
+        arrayProperties:
+        http://boto3.readthedocs.io/en/latest/reference/services/batch.html#submit_job
+        When left as None an empty dict is submitted instead.
+    :type array_properties: dict
+    :param max_retries: maximum number of status checks made by the fallback
+        polling loop that is used when the boto3 waiter is not available
+    :type max_retries: int
+    :param aws_conn_id: connection id of AWS credentials / region name. If None,
+        credential boto3 strategy will be used
+        (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
+    :type aws_conn_id: str
+    :param region_name: region name to use in AWS Hook.
+        Override the region_name in connection (if provided)
+    :type region_name: str
+    """
+
+    ui_color = '#c3dae0'
+    # Batch client; stays None until execute() creates it
+    client = None  # type: Optional[BatchProtocol]
+    arn = None  # type: Optional[str]
+    template_fields = ('job_name', 'overrides',)
+
+    @apply_defaults
+    def __init__(self, job_name, job_definition, job_queue, overrides, array_properties=None,
+                 max_retries=4200, aws_conn_id=None, region_name=None, **kwargs):
+        super().__init__(**kwargs)
+
+        self.job_name = job_name
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.job_definition = job_definition
+        self.job_queue = job_queue
+        self.overrides = overrides
+        self.array_properties = array_properties
+        self.max_retries = max_retries
+
+        # Filled in from the submit_job response by execute()
+        self.jobId = None  # pylint: disable=invalid-name
+        self.jobName = None  # pylint: disable=invalid-name
+
+        self.hook = self.get_hook()
+
+    def execute(self, context):
+        """
+        Submit the job to AWS Batch, wait for it to finish and check that it succeeded.
+
+        :param context: task context supplied by Airflow (not used here)
+        :raises AirflowException: if submitting, waiting for or checking the job fails;
+            the original exception is attached as the cause
+        """
+        self.log.info(
+            'Running AWS Batch Job - Job definition: %s - on queue %s',
+            self.job_definition, self.job_queue
+        )
+        self.log.info('AWSBatchOperator overrides: %s', self.overrides)
+
+        self.client = self.hook.get_client_type(
+            'batch',
+            region_name=self.region_name
+        )
+
+        try:
+            response = self.client.submit_job(
+                jobName=self.job_name,
+                jobQueue=self.job_queue,
+                jobDefinition=self.job_definition,
+                # boto3 validates argument types and rejects None, so an unset
+                # array_properties is submitted as an empty dict.
+                arrayProperties=self.array_properties or {},
+                containerOverrides=self.overrides)
+
+            self.log.info('AWS Batch Job started: %s', response)
+
+            self.jobId = response['jobId']
+            self.jobName = response['jobName']
+
+            self._wait_for_task_ended()
+
+            self._check_success_task()
+
+            self.log.info('AWS Batch Job has been successfully executed: %s', response)
+        except Exception as e:
+            # Logged as an error so the failure is visible at the right level
+            self.log.error('AWS Batch Job has failed executed')
+            raise AirflowException(e) from e
+
+    def _wait_for_task_ended(self):
+        """
+        Block until the submitted job has stopped or the polling budget is used up.
+
+        Try to use a waiter from the below pull request
+
+            * https://github.com/boto/botocore/pull/1307
+
+        If the waiter is not available, poll the job status with a growing pause
+
+            * docs.aws.amazon.com/general/latest/gr/api-retries.html
+
+        When ``max_retries`` status checks have been made and the job has still
+        not reached SUCCEEDED or FAILED, this method returns without raising;
+        :meth:`_check_success_task` then reports the job as still pending.
+        """
+        try:
+            waiter = self.client.get_waiter('job_execution_complete')
+            waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
+            waiter.wait(jobs=[self.jobId])
+        except ValueError:
+            # Waiter not available: fall back to polling describe_jobs.
+
+            # Allow a batch job some time to spin up.  A random interval
+            # decreases the chances of exceeding an AWS API throttle
+            # limit when there are many concurrent tasks.
+            pause = randint(5, 30)
+
+            retries = 1
+            while retries <= self.max_retries:
+                self.log.info('AWS Batch job (%s) status check (%d of %d) in the next %.2f seconds',
+                              self.jobId, retries, self.max_retries, pause)
+                sleep(pause)
+
+                response = self.client.describe_jobs(jobs=[self.jobId])
+                status = response['jobs'][-1]['status']
+                self.log.info('AWS Batch job (%s) status: %s', self.jobId, status)
+                if status in ['SUCCEEDED', 'FAILED']:
+                    break
+
+                retries += 1
+                # The pause grows with the square of the attempt number
+                pause = 1 + pow(retries * 0.3, 2)
+
+    def _check_success_task(self):
+        """
+        Describe the submitted job and raise unless it finished successfully.
+
+        :raises AirflowException: if no job is returned, if a job is FAILED, or if a
+            job is still in a non-terminal state
+        """
+        response = self.client.describe_jobs(
+            jobs=[self.jobId],
+        )
+
+        self.log.info('AWS Batch stopped, check status: %s', response)
+        # A missing 'jobs' key is treated the same as an empty list
+        jobs = response.get('jobs') or []
+        if not jobs:
+            raise AirflowException('No job found for {}'.format(response))
+
+        for job in jobs:
+            job_status = job['status']
+            if job_status == 'FAILED':
+                # Read defensively so that a missing reason does not replace the
+                # failure report with a KeyError.
+                reason = job.get('statusReason')
+                raise AirflowException('Job failed with status {}'.format(reason))
+            if job_status in (
+                    'SUBMITTED',
+                    'PENDING',
+                    'RUNNABLE',
+                    'STARTING',
+                    'RUNNING'
+            ):
+                raise AirflowException(
+                    'This task is still pending {}'.format(job_status))
+
+    def get_hook(self):
+        """Build the AwsHook for ``aws_conn_id``; called once from ``__init__``."""
+        return AwsHook(
+            aws_conn_id=self.aws_conn_id
+        )
+
+    def on_kill(self):
+        """Ask AWS Batch to terminate the submitted job when the task is killed."""
+        if self.client is None or self.jobId is None:
+            # Killed before execute() created the client or received a job id
+            self.log.info('No AWS Batch job has been submitted, nothing to terminate')
+            return
+
+        response = self.client.terminate_job(
+            jobId=self.jobId,
+            reason='Task killed by the user')
+
+        self.log.info(response)
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index fc516f4115895..44c0259fa9c5a 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -309,9 +309,9 @@ These integrations allow you to perform various operations within the Amazon Web
- :mod:`airflow.providers.aws.operators.athena`
- :mod:`airflow.providers.aws.sensors.athena`
- * - `AWS Batch `__
+ * - `AWS Batch <https://aws.amazon.com/batch/>`__
-
- - :mod:`airflow.contrib.operators.awsbatch_operator`
+ - :mod:`airflow.providers.aws.operators.batch`
-
* - `Amazon CloudWatch Logs `__
diff --git a/scripts/ci/pylint_todo.txt b/scripts/ci/pylint_todo.txt
index cdbe04939ef47..6252c0c068b94 100644
--- a/scripts/ci/pylint_todo.txt
+++ b/scripts/ci/pylint_todo.txt
@@ -35,7 +35,6 @@
./airflow/contrib/hooks/vertica_hook.py
./airflow/contrib/hooks/wasb_hook.py
./airflow/contrib/operators/adls_list_operator.py
-./airflow/contrib/operators/awsbatch_operator.py
./airflow/contrib/operators/azure_container_instances_operator.py
./airflow/contrib/operators/azure_cosmos_operator.py
./airflow/operators/cassandra_to_gcs.py
@@ -201,6 +200,7 @@
./airflow/operators/subdag_operator.py
./airflow/plugins_manager.py
./airflow/providers/aws/operators/athena.py
+./airflow/providers/aws/operators/batch.py
./airflow/providers/aws/sensors/athena.py
./airflow/sensors/__init__.py
./airflow/sensors/base_sensor_operator.py
diff --git a/tests/contrib/operators/test_awsbatch_operator.py b/tests/providers/aws/operators/test_batch.py
similarity index 99%
rename from tests/contrib/operators/test_awsbatch_operator.py
rename to tests/providers/aws/operators/test_batch.py
index 8814857c14736..5978fa3516f25 100644
--- a/tests/contrib/operators/test_awsbatch_operator.py
+++ b/tests/providers/aws/operators/test_batch.py
@@ -21,8 +21,8 @@
import sys
import unittest
-from airflow.contrib.operators.awsbatch_operator import AWSBatchOperator
from airflow.exceptions import AirflowException
+from airflow.providers.aws.operators.batch import AWSBatchOperator
from tests.compat import mock
RESPONSE_WITHOUT_FAILURES = {
diff --git a/tests/test_core_to_contrib.py b/tests/test_core_to_contrib.py
index 1ddf0dc48569b..127743c379d1e 100644
--- a/tests/test_core_to_contrib.py
+++ b/tests/test_core_to_contrib.py
@@ -741,6 +741,10 @@
"airflow.providers.aws.operators.athena.AWSAthenaOperator",
"airflow.contrib.operators.aws_athena_operator.AWSAthenaOperator",
),
+ (
+ "airflow.providers.aws.operators.batch.AWSBatchOperator",
+ "airflow.contrib.operators.awsbatch_operator.AWSBatchOperator",
+ ),
]
SENSOR = [
(