Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions airflow/providers/amazon/aws/hooks/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,13 @@ def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = self.client_type
super().__init__(*args, **kwargs)

def create_cluster(self, name: str, roleArn: str, resourcesVpcConfig: Dict, **kwargs) -> Dict:
def create_cluster(
self,
name: str,
roleArn: str,
resourcesVpcConfig: Dict,
**kwargs,
) -> Dict:
"""
Creates an Amazon EKS control plane.

Expand All @@ -117,7 +123,14 @@ def create_cluster(self, name: str, roleArn: str, resourcesVpcConfig: Dict, **kw
return response

def create_nodegroup(
self, clusterName: str, nodegroupName: str, subnets: List[str], nodeRole: str, **kwargs
self,
clusterName: str,
nodegroupName: str,
subnets: List[str],
nodeRole: str,
*,
tags: Optional[Dict] = None,
**kwargs,
) -> Dict:
"""
Creates an Amazon EKS managed node group for an Amazon EKS Cluster.
Expand All @@ -130,25 +143,28 @@ def create_nodegroup(
:type subnets: List[str]
:param nodeRole: The Amazon Resource Name (ARN) of the IAM role to associate with your nodegroup.
:type nodeRole: str
:param tags: Optional tags to apply to your nodegroup.
:type tags: Dict

:return: Returns descriptive information about the created EKS Managed Nodegroup.
:rtype: Dict
"""
eks_client = self.conn

# The below tag is mandatory and must have a value of either 'owned' or 'shared'
# A value of 'owned' denotes that the subnets are exclusive to the nodegroup.
# The 'shared' value allows more than one resource to use the subnet.
tags = {'kubernetes.io/cluster/' + clusterName: 'owned'}
if "tags" in kwargs:
tags = {**tags, **kwargs["tags"]}
kwargs.pop("tags")
cluster_tag_key = f'kubernetes.io/cluster/{clusterName}'
resolved_tags = tags or {}
if cluster_tag_key not in resolved_tags:
resolved_tags[cluster_tag_key] = 'owned'

response = eks_client.create_nodegroup(
clusterName=clusterName,
nodegroupName=nodegroupName,
subnets=subnets,
nodeRole=nodeRole,
tags=tags,
tags=resolved_tags,
**kwargs,
)

Expand All @@ -160,7 +176,12 @@ def create_nodegroup(
return response

def create_fargate_profile(
self, clusterName: str, fargateProfileName: str, podExecutionRoleArn: str, selectors: List, **kwargs
self,
clusterName: str,
fargateProfileName: str,
podExecutionRoleArn: str,
selectors: List,
**kwargs,
) -> Dict:
"""
Creates an AWS Fargate profile for an Amazon EKS cluster.
Expand Down
53 changes: 43 additions & 10 deletions airflow/providers/amazon/aws/operators/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class EksCreateClusterOperator(BaseOperator):
:param compute: The type of compute architecture to generate along with the cluster. (templated)
Defaults to 'nodegroup' to generate an EKS Managed Nodegroup.
:type compute: str
:param create_cluster_kwargs: Optional parameters to pass to the CreateCluster API (templated)
:type: Dict
:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
Expand All @@ -89,36 +91,45 @@ class EksCreateClusterOperator(BaseOperator):
If this is None or empty then the default boto3 behaviour is used.
:type region: str

If compute is assigned the value of 'nodegroup', the following are required:
If compute is assigned the value of 'nodegroup':
Comment thread
ferruzzi marked this conversation as resolved.

:param nodegroup_name: The unique name to give your Amazon EKS managed node group. (templated)
:param nodegroup_name: *REQUIRED* The unique name to give your Amazon EKS managed node group. (templated)
:type nodegroup_name: str
:param nodegroup_role_arn: The Amazon Resource Name (ARN) of the IAM role to associate with the
Amazon EKS managed node group. (templated)
:param nodegroup_role_arn: *REQUIRED* The Amazon Resource Name (ARN) of the IAM role to associate with
the Amazon EKS managed node group. (templated)
:type nodegroup_role_arn: str
:param create_nodegroup_kwargs: Optional parameters to pass to the CreateNodegroup API (templated)
:type: Dict

If compute is assigned the value of 'fargate', the following are required:

:param fargate_profile_name: The unique name to give your AWS Fargate profile. (templated)
If compute is assigned the value of 'fargate':

:param fargate_profile_name: *REQUIRED* The unique name to give your AWS Fargate profile. (templated)
:type fargate_profile_name: str
:param fargate_pod_execution_role_arn: The Amazon Resource Name (ARN) of the pod execution role to
use for pods that match the selectors in the AWS Fargate profile. (templated)
:param fargate_pod_execution_role_arn: *REQUIRED* The Amazon Resource Name (ARN) of the pod execution
role to use for pods that match the selectors in the AWS Fargate profile. (templated)
:type podExecutionRoleArn: str
:param selectors: The selectors to match for pods to use this AWS Fargate profile. (templated)
:type selectors: List
:param fargate_selectors: The selectors to match for pods to use this AWS Fargate profile. (templated)
:type fargate_selectors: List
:param create_fargate_profile_kwargs: Optional parameters to pass to the CreateFargateProfile API
(templated)
:type: Dict

"""

template_fields: Sequence[str] = (
"cluster_name",
"cluster_role_arn",
"resources_vpc_config",
"create_cluster_kwargs",
"compute",
"nodegroup_name",
"nodegroup_role_arn",
"create_nodegroup_kwargs",
"fargate_profile_name",
"fargate_pod_execution_role_arn",
"fargate_selectors",
"create_fargate_profile_kwargs",
Comment thread
ferruzzi marked this conversation as resolved.
"aws_conn_id",
"region",
)
Expand All @@ -129,11 +140,14 @@ def __init__(
cluster_role_arn: str,
resources_vpc_config: Dict,
compute: Optional[str] = DEFAULT_COMPUTE_TYPE,
create_cluster_kwargs: Optional[Dict] = None,
nodegroup_name: Optional[str] = DEFAULT_NODEGROUP_NAME,
nodegroup_role_arn: Optional[str] = None,
create_nodegroup_kwargs: Optional[Dict] = None,
fargate_profile_name: Optional[str] = DEFAULT_FARGATE_PROFILE_NAME,
fargate_pod_execution_role_arn: Optional[str] = None,
fargate_selectors: Optional[List] = None,
create_fargate_profile_kwargs: Optional[Dict] = None,
aws_conn_id: str = DEFAULT_CONN_ID,
region: Optional[str] = None,
**kwargs,
Expand All @@ -156,11 +170,14 @@ def __init__(
self.cluster_name = cluster_name
self.cluster_role_arn = cluster_role_arn
self.resources_vpc_config = resources_vpc_config
self.create_cluster_kwargs = create_cluster_kwargs or {}
self.nodegroup_name = nodegroup_name
self.nodegroup_role_arn = nodegroup_role_arn
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
self.fargate_profile_name = fargate_profile_name
self.fargate_pod_execution_role_arn = fargate_pod_execution_role_arn
self.fargate_selectors = fargate_selectors or [{"namespace": DEFAULT_NAMESPACE_NAME}]
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
Expand All @@ -175,6 +192,7 @@ def execute(self, context: 'Context'):
name=self.cluster_name,
roleArn=self.cluster_role_arn,
resourcesVpcConfig=self.resources_vpc_config,
**self.create_cluster_kwargs,
)

if not self.compute:
Expand Down Expand Up @@ -206,13 +224,15 @@ def execute(self, context: 'Context'):
nodegroupName=self.nodegroup_name,
subnets=self.resources_vpc_config.get('subnetIds'),
nodeRole=self.nodegroup_role_arn,
**self.create_nodegroup_kwargs,
)
elif self.compute == 'fargate':
eks_hook.create_fargate_profile(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
podExecutionRoleArn=self.fargate_pod_execution_role_arn,
selectors=self.fargate_selectors,
**self.create_fargate_profile_kwargs,
)


Expand All @@ -234,6 +254,8 @@ class EksCreateNodegroupOperator(BaseOperator):
:param nodegroup_role_arn:
The Amazon Resource Name (ARN) of the IAM role to associate with the managed nodegroup. (templated)
:type nodegroup_role_arn: str
:param create_nodegroup_kwargs: Optional parameters to pass to the Create Nodegroup API (templated)
:type: Dict
:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
Expand All @@ -251,6 +273,7 @@ class EksCreateNodegroupOperator(BaseOperator):
"nodegroup_subnets",
"nodegroup_role_arn",
"nodegroup_name",
"create_nodegroup_kwargs",
"aws_conn_id",
"region",
)
Expand All @@ -261,6 +284,7 @@ def __init__(
nodegroup_subnets: List[str],
nodegroup_role_arn: str,
nodegroup_name: Optional[str] = DEFAULT_NODEGROUP_NAME,
create_nodegroup_kwargs: Optional[Dict] = None,
aws_conn_id: str = DEFAULT_CONN_ID,
region: Optional[str] = None,
**kwargs,
Expand All @@ -269,6 +293,7 @@ def __init__(
self.nodegroup_subnets = nodegroup_subnets
self.nodegroup_role_arn = nodegroup_role_arn
self.nodegroup_name = nodegroup_name
self.create_nodegroup_kwargs = create_nodegroup_kwargs or {}
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
Expand All @@ -284,6 +309,7 @@ def execute(self, context: 'Context'):
nodegroupName=self.nodegroup_name,
subnets=self.nodegroup_subnets,
nodeRole=self.nodegroup_role_arn,
**self.create_nodegroup_kwargs,
)


Expand All @@ -304,6 +330,9 @@ class EksCreateFargateProfileOperator(BaseOperator):
:type selectors: List
:param fargate_profile_name: The unique name to give your AWS Fargate profile. (templated)
:type fargate_profile_name: str
:param create_fargate_profile_kwargs: Optional parameters to pass to the CreateFargate Profile API
(templated)
:type: Dict

:param aws_conn_id: The Airflow connection used for AWS credentials. (templated)
If this is None or empty then the default boto3 behaviour is used. If
Expand All @@ -321,6 +350,7 @@ class EksCreateFargateProfileOperator(BaseOperator):
"pod_execution_role_arn",
"selectors",
"fargate_profile_name",
"create_fargate_profile_kwargs",
"aws_conn_id",
"region",
)
Expand All @@ -331,6 +361,7 @@ def __init__(
pod_execution_role_arn: str,
selectors: List,
fargate_profile_name: Optional[str] = DEFAULT_FARGATE_PROFILE_NAME,
create_fargate_profile_kwargs: Optional[Dict] = None,
aws_conn_id: str = DEFAULT_CONN_ID,
region: Optional[str] = None,
**kwargs,
Expand All @@ -339,6 +370,7 @@ def __init__(
self.pod_execution_role_arn = pod_execution_role_arn
self.selectors = selectors
self.fargate_profile_name = fargate_profile_name
self.create_fargate_profile_kwargs = create_fargate_profile_kwargs or {}
self.aws_conn_id = aws_conn_id
self.region = region
super().__init__(**kwargs)
Expand All @@ -354,6 +386,7 @@ def execute(self, context: 'Context'):
fargateProfileName=self.fargate_profile_name,
podExecutionRoleArn=self.pod_execution_role_arn,
selectors=self.selectors,
**self.create_fargate_profile_kwargs,
)


Expand Down
35 changes: 34 additions & 1 deletion tests/providers/amazon/aws/hooks/test_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
INSTANCE_TYPES,
LAUNCH_TEMPLATE,
MAX_FARGATE_LABELS,
NODEGROUP_OWNERSHIP_TAG_DEFAULT_VALUE,
NODEGROUP_OWNERSHIP_TAG_KEY,
NON_EXISTING_CLUSTER_NAME,
NON_EXISTING_FARGATE_PROFILE_NAME,
NON_EXISTING_NODEGROUP_NAME,
Expand Down Expand Up @@ -559,6 +561,30 @@ def test_create_nodegroup_saves_provided_parameters(self, nodegroup_builder) ->
for key, expected_value in generated_test_data.attributes_to_test:
assert generated_test_data.nodegroup_describe_output[key] == expected_value

def test_create_nodegroup_without_tags_uses_default(self, nodegroup_builder) -> None:
_, generated_test_data = nodegroup_builder()
tag_list: Dict = generated_test_data.nodegroup_describe_output[NodegroupAttributes.TAGS]
ownership_tag_key: str = NODEGROUP_OWNERSHIP_TAG_KEY.format(
cluster_name=generated_test_data.cluster_name
)

assert tag_list.get(ownership_tag_key) == NODEGROUP_OWNERSHIP_TAG_DEFAULT_VALUE

def test_create_nodegroup_with_ownership_tag_uses_provided_value(self, cluster_builder) -> None:
eks_hook, generated_test_data = cluster_builder()
cluster_name: str = generated_test_data.existing_cluster_name
ownership_tag_key: str = NODEGROUP_OWNERSHIP_TAG_KEY.format(cluster_name=cluster_name)
provided_tag_value: str = "shared"

created_nodegroup: Dict = eks_hook.create_nodegroup(
clusterName=cluster_name,
nodegroupName="nodegroup",
tags={ownership_tag_key: provided_tag_value},
**dict(deepcopy(NodegroupInputs.REQUIRED)),
)[ResponseAttributes.NODEGROUP]

assert created_nodegroup.get(NodegroupAttributes.TAGS).get(ownership_tag_key) == provided_tag_value

def test_describe_nodegroup_throws_exception_when_cluster_not_found(self, nodegroup_builder) -> None:
eks_hook, generated_test_data = nodegroup_builder()
expected_exception: Type[AWSError] = ResourceNotFoundException
Expand Down Expand Up @@ -746,7 +772,14 @@ def test_create_nodegroup_handles_launch_template_combinations(
if expected_result == PossibleTestResults.SUCCESS:
result: Dict = eks_hook.create_nodegroup(**test_inputs)[ResponseAttributes.NODEGROUP]

for key, expected_value in test_inputs.items():
expected_output = deepcopy(test_inputs)
# The Create Nodegroup hook magically adds the required
# cluster/owned tag, so add that to the expected outputs.
expected_output['tags'] = {
f'kubernetes.io/cluster/{generated_test_data.existing_cluster_name}': 'owned'
}

for key, expected_value in expected_output.items():
assert result[key] == expected_value
else:
if launch_template and disk_size:
Expand Down
Loading