Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions airflow/providers/amazon/aws/operators/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,141 @@ def execute(self, context: Context):
target_state="stopped",
check_interval=self.check_interval,
)


class EC2CreateInstanceOperator(BaseOperator):
"""
Create and start a specified number of EC2 Instances using boto3

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EC2CreateInstanceOperator`

:param image_id: ID of the AMI used to create the instance.
:param max_count: Maximum number of instances to launch. Defaults to 1.
:param min_count: Minimum number of instances to launch. Defaults to 1.
:param aws_conn_id: AWS connection to use
:param region_name: AWS region name associated with the client.
:param poll_interval: Number of seconds to wait before attempting to
check state of instance. Only used if wait_for_completion is True. Default is 20.
:param max_attempts: Maximum number of attempts when checking state of instance.
Only used if wait_for_completion is True. Default is 20.
:param config: Dictionary for arbitrary parameters to the boto3 run_instances call.
:param wait_for_completion: If True, the operator will wait for the instance to be
in the `running` state before returning.
"""

template_fields: Sequence[str] = (
"image_id",
"max_count",
"min_count",
"aws_conn_id",
"region_name",
"config",
"wait_for_completion",
Comment thread
josh-fell marked this conversation as resolved.
Outdated
)

def __init__(
self,
image_id: str,
max_count: int = 1,
min_count: int = 1,
aws_conn_id: str = "aws_default",
region_name: str | None = None,
poll_interval: int = 20,
max_attempts: int = 20,
config: dict | None = None,
wait_for_completion: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.image_id = image_id
self.max_count = max_count
self.min_count = min_count
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.config = config or {}
self.wait_for_completion = wait_for_completion

def execute(self, context: Context):
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
instances = ec2_hook.conn.run_instances(
ImageId=self.image_id,
MinCount=self.min_count,
MaxCount=self.max_count,
**self.config,
)["Instances"]
instance_ids = []
for instance in instances:
instance_ids.append(instance["InstanceId"])
self.log.info("Created EC2 instance %s", instance["InstanceId"])

if self.wait_for_completion:
ec2_hook.get_waiter("instance_running").wait(
InstanceIds=[instance["InstanceId"]],
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
)

return instance_ids


class EC2TerminateInstanceOperator(BaseOperator):
"""
Terminate EC2 Instances using boto3

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EC2TerminateInstanceOperator`

:param instance_id: ID of the instance to be terminated.
:param aws_conn_id: AWS connection to use
:param region_name: AWS region name associated with the client.
:param poll_interval: Number of seconds to wait before attempting to
check state of instance. Only used if wait_for_completion is True. Default is 20.
:param max_attempts: Maximum number of attempts when checking state of instance.
Only used if wait_for_completion is True. Default is 20.
:param wait_for_completion: If True, the operator will wait for the instance to be
in the `terminated` state before returning.
"""

template_fields: Sequence[str] = ("instance_ids", "region_name", "aws_conn_id", "wait_for_completion")

def __init__(
self,
instance_ids: str | list[str],
aws_conn_id: str = "aws_default",
region_name: str | None = None,
poll_interval: int = 20,
max_attempts: int = 20,
wait_for_completion: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.instance_ids = instance_ids
self.aws_conn_id = aws_conn_id
self.region_name = region_name
self.poll_interval = poll_interval
self.max_attempts = max_attempts
self.wait_for_completion = wait_for_completion

def execute(self, context: Context):
if isinstance(self.instance_ids, str):
self.instance_ids = [self.instance_ids]
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name, api_type="client_type")
ec2_hook.conn.terminate_instances(InstanceIds=self.instance_ids)

for instance_id in self.instance_ids:
self.log.info("Terminating EC2 instance %s", instance_id)
if self.wait_for_completion:
ec2_hook.get_waiter("instance_terminated").wait(
InstanceIds=[instance_id],
WaiterConfig={
"Delay": self.poll_interval,
"MaxAttempts": self.max_attempts,
},
)
28 changes: 28 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/ec2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,34 @@ To stop an Amazon EC2 instance you can use
:start-after: [START howto_operator_ec2_stop_instance]
:end-before: [END howto_operator_ec2_stop_instance]

.. _howto/operator:EC2CreateInstanceOperator:

Create and start an Amazon EC2 instance
=======================================

To create and start an Amazon EC2 instance you can use
:class:`~airflow.providers.amazon.aws.operators.ec2.EC2CreateInstanceOperator`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
:language: python
:dedent: 4
:start-after: [START howto_operator_ec2_create_instance]
:end-before: [END howto_operator_ec2_create_instance]

.. _howto/operator:EC2TerminateInstanceOperator:

Terminate an Amazon EC2 instance
================================

To terminate an Amazon EC2 instance you can use
:class:`~airflow.providers.amazon.aws.operators.ec2.EC2TerminateInstanceOperator`.

.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ec2.py
:language: python
:dedent: 4
:start-after: [START howto_operator_ec2_terminate_instance]
:end-before: [END howto_operator_ec2_terminate_instance]

Sensors
-------

Expand Down
130 changes: 118 additions & 12 deletions tests/providers/amazon/aws/operators/test_ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,121 @@
from moto import mock_ec2

from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator
from airflow.providers.amazon.aws.operators.ec2 import (
EC2CreateInstanceOperator,
EC2StartInstanceOperator,
EC2StopInstanceOperator,
EC2TerminateInstanceOperator,
)


class BaseEc2TestClass:
@classmethod
def _create_instance(cls, hook: EC2Hook):
"""Create Instance and return instance id."""
def _get_image_id(cls, hook):
"""Get a valid image id to create an instance."""
conn = hook.get_conn()
try:
ec2_client = conn.meta.client
except AttributeError:
ec2_client = conn

# We need existed AMI Image ID otherwise `moto` will raise DeprecationWarning.
# We need an existing AMI Image ID otherwise `moto` will raise DeprecationWarning.
images = ec2_client.describe_images()["Images"]
response = ec2_client.run_instances(MaxCount=1, MinCount=1, ImageId=images[0]["ImageId"])
return response["Instances"][0]["InstanceId"]
return images[0]["ImageId"]


class TestEC2CreateInstanceOperator(BaseEc2TestClass):
Comment thread
josh-fell marked this conversation as resolved.
Outdated
def test_init(self):
ec2_operator = EC2CreateInstanceOperator(
task_id="test_create_instance",
image_id="test_image_id",
)

assert ec2_operator.task_id == "test_create_instance"
assert ec2_operator.image_id == "test_image_id"
assert ec2_operator.max_count == 1
assert ec2_operator.min_count == 1
assert ec2_operator.max_attempts == 20
assert ec2_operator.poll_interval == 20

@mock_ec2
def test_create_instance(self):
ec2_hook = EC2Hook()
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"

@mock_ec2
def test_create_multiple_instances(self):
ec2_hook = EC2Hook()
create_instances = EC2CreateInstanceOperator(
task_id="test_create_multiple_instances",
image_id=self._get_image_id(hook=ec2_hook),
min_count=5,
max_count=5,
)
instance_ids = create_instances.execute(None)
assert len(instance_ids) == 5

for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"


class TestEC2TerminateInstanceOperator(BaseEc2TestClass):
Comment thread
josh-fell marked this conversation as resolved.
Outdated
def test_init(self):
ec2_operator = EC2TerminateInstanceOperator(
task_id="test_terminate_instance",
instance_ids="test_image_id",
)

assert ec2_operator.task_id == "test_terminate_instance"
assert ec2_operator.max_attempts == 20
assert ec2_operator.poll_interval == 20

@mock_ec2
def test_terminate_instance(self):
ec2_hook = EC2Hook()

create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"

terminate_instance = EC2TerminateInstanceOperator(
task_id="test_terminate_instance", instance_ids=instance_id
)
terminate_instance.execute(None)

assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "terminated"

@mock_ec2
def test_terminate_multiple_instances(self):
ec2_hook = EC2Hook()
create_instances = EC2CreateInstanceOperator(
task_id="test_create_multiple_instances",
image_id=self._get_image_id(hook=ec2_hook),
min_count=5,
max_count=5,
)
instance_ids = create_instances.execute(None)
assert len(instance_ids) == 5

for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "running"

terminate_instance = EC2TerminateInstanceOperator(
task_id="test_terminate_instance", instance_ids=instance_ids
)
terminate_instance.execute(None)
for id in instance_ids:
assert ec2_hook.get_instance_state(instance_id=id) == "terminated"


class TestEC2StartInstanceOperator(BaseEc2TestClass):
Expand All @@ -58,16 +156,20 @@ def test_init(self):
def test_start_instance(self):
# create instance
ec2_hook = EC2Hook()
instance_id = self._create_instance(ec2_hook)
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

# start instance
start_test = EC2StartInstanceOperator(
task_id="start_test",
instance_id=instance_id,
instance_id=instance_id[0],
)
start_test.execute(None)
# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id) == "running"
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "running"


class TestEC2StopInstanceOperator(BaseEc2TestClass):
Expand All @@ -89,13 +191,17 @@ def test_init(self):
def test_stop_instance(self):
# create instance
ec2_hook = EC2Hook()
instance_id = self._create_instance(ec2_hook)
create_instance = EC2CreateInstanceOperator(
image_id=self._get_image_id(ec2_hook),
task_id="test_create_instance",
)
instance_id = create_instance.execute(None)

# stop instance
stop_test = EC2StopInstanceOperator(
task_id="stop_test",
instance_id=instance_id,
instance_id=instance_id[0],
)
stop_test.execute(None)
# assert instance state is running
assert ec2_hook.get_instance_state(instance_id=instance_id) == "stopped"
assert ec2_hook.get_instance_state(instance_id=instance_id[0]) == "stopped"
Loading