From b31e2e7e9f732b52ed69383dabb99b37631066c9 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Thu, 1 Jun 2023 12:03:59 +0200 Subject: [PATCH 1/2] update _parse_from_uri and get_uri methods, and add tests for connection model --- airflow/models/connection.py | 35 +++++- tests/models/test_connection.py | 188 ++++++++++++++++++++++++++++++++ 2 files changed, 220 insertions(+), 3 deletions(-) create mode 100644 tests/models/test_connection.py diff --git a/airflow/models/connection.py b/airflow/models/connection.py index a5653412093e9..9966d8e02d488 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -187,10 +187,22 @@ def _normalize_conn_type(conn_type): return conn_type def _parse_from_uri(self, uri: str): + schemes_count_in_uri = uri.count("://") + if schemes_count_in_uri > 2: + raise AirflowException(f"Invalid connection string: {uri}.") + scheme_in_uri = schemes_count_in_uri == 2 uri_parts = urlsplit(uri) conn_type = uri_parts.scheme self.conn_type = self._normalize_conn_type(conn_type) - self.host = _parse_netloc_to_hostname(uri_parts) + reset_of_the_url = uri.replace(f"{conn_type}://", ("" if scheme_in_uri else "//")) + if scheme_in_uri: + uri_splits = reset_of_the_url.split("://", 1) + if "@" in uri_splits[0] or ":" in uri_splits[0]: + raise AirflowException(f"Invalid connection string: {uri}.") + uri_parts = urlsplit(reset_of_the_url) + protocol = uri_parts.scheme if scheme_in_uri else None + host = _parse_netloc_to_hostname(uri_parts) + self.host = self._create_host(protocol, host) quoted_schema = uri_parts.path[1:] self.schema = unquote(quoted_schema) if quoted_schema else quoted_schema self.login = unquote(uri_parts.username) if uri_parts.username else uri_parts.username @@ -203,6 +215,15 @@ def _parse_from_uri(self, uri: str): else: self.extra = json.dumps(query) + @staticmethod + def _create_host(protocol, host) -> str | None: + """Returns the connection host with the protocol.""" + if not host: + return host + if protocol: + 
return f"{protocol}://{host}" + return host + def get_uri(self) -> str: """Return connection in URI format.""" if self.conn_type and "_" in self.conn_type: @@ -216,6 +237,14 @@ def get_uri(self) -> str: else: uri = "//" + if self.host and "://" in self.host: + protocol, host = self.host.split("://", 1) + else: + protocol, host = None, self.host + + if protocol: + uri += f"{protocol}://" + authority_block = "" if self.login is not None: authority_block += quote(self.login, safe="") @@ -229,8 +258,8 @@ def get_uri(self) -> str: uri += authority_block host_block = "" - if self.host: - host_block += quote(self.host, safe="") + if host: + host_block += quote(host, safe="") if self.port: if host_block == "" and authority_block == "": diff --git a/tests/models/test_connection.py b/tests/models/test_connection.py new file mode 100644 index 0000000000000..0223cffb8db6e --- /dev/null +++ b/tests/models/test_connection.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import re + +import pytest + +from airflow import AirflowException +from airflow.models import Connection + + +class TestConnection: + @pytest.mark.parametrize( + "uri, expected_conn_type, expected_host, expected_login, expected_password," + " expected_port, expected_schema, expected_extra_dict, expected_exception_message", + [ + ( + "type://user:pass@host:100/schema", + "type", + "host", + "user", + "pass", + 100, + "schema", + {}, + None, + ), + ( + "type://user:pass@host/schema", + "type", + "host", + "user", + "pass", + None, + "schema", + {}, + None, + ), + ( + "type://user:pass@host/schema?param1=val1&param2=val2", + "type", + "host", + "user", + "pass", + None, + "schema", + {"param1": "val1", "param2": "val2"}, + None, + ), + ( + "type://host", + "type", + "host", + None, + None, + None, + "", + {}, + None, + ), + ( + "spark://mysparkcluster.com:80?deploy-mode=cluster&spark_binary=command&namespace=kube+namespace", + "spark", + "mysparkcluster.com", + None, + None, + 80, + "", + {"deploy-mode": "cluster", "spark_binary": "command", "namespace": "kube namespace"}, + None, + ), + ( + "spark://k8s://100.68.0.1:443?deploy-mode=cluster", + "spark", + "k8s://100.68.0.1", + None, + None, + 443, + "", + {"deploy-mode": "cluster"}, + None, + ), + ( + "type://protocol://user:pass@host:123?param=value", + "type", + "protocol://host", + "user", + "pass", + 123, + "", + {"param": "value"}, + None, + ), + ( + "type://user:pass@protocol://host:port?param=value", + None, + None, + None, + None, + None, + None, + None, + r"Invalid connection string: type://user:pass@protocol://host:port?param=value.", + ), + ], + ) + def test_parse_from_uri( + self, + uri, + expected_conn_type, + expected_host, + expected_login, + expected_password, + expected_port, + expected_schema, + expected_extra_dict, + expected_exception_message, + ): + if expected_exception_message is not None: + with pytest.raises(AirflowException,
match=re.escape(expected_exception_message)): + Connection(uri=uri) + else: + conn = Connection(uri=uri) + assert conn.conn_type == expected_conn_type + assert conn.login == expected_login + assert conn.password == expected_password + assert conn.host == expected_host + assert conn.port == expected_port + assert conn.schema == expected_schema + assert conn.extra_dejson == expected_extra_dict + + @pytest.mark.parametrize( + "connection, expected_uri", + [ + ( + Connection( + conn_type="type", + login="user", + password="pass", + host="host", + port=100, + schema="schema", + extra={"param1": "val1", "param2": "val2"}, + ), + "type://user:pass@host:100/schema?param1=val1&param2=val2", + ), + ( + Connection( + conn_type="type", + host="protocol://host", + port=100, + schema="schema", + extra={"param1": "val1", "param2": "val2"}, + ), + "type://protocol://host:100/schema?param1=val1&param2=val2", + ), + ( + Connection( + conn_type="type", + login="user", + password="pass", + host="protocol://host", + port=100, + schema="schema", + extra={"param1": "val1", "param2": "val2"}, + ), + "type://protocol://user:pass@host:100/schema?param1=val1&param2=val2", + ), + ], + ) + def test_get_uri(self, connection, expected_uri): + assert connection.get_uri() == expected_uri From 6f165fb2af83fbd04fc643f3247c83c662589ccd Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Wed, 7 Jun 2023 18:50:23 +0200 Subject: [PATCH 2/2] some fixes from review --- airflow/models/connection.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 9966d8e02d488..0bc3ca38d48fb 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -190,17 +190,17 @@ def _parse_from_uri(self, uri: str): schemes_count_in_uri = uri.count("://") if schemes_count_in_uri > 2: raise AirflowException(f"Invalid connection string: {uri}.") - scheme_in_uri = schemes_count_in_uri == 
2 uri_parts = urlsplit(uri) conn_type = uri_parts.scheme self.conn_type = self._normalize_conn_type(conn_type) - reset_of_the_url = uri.replace(f"{conn_type}://", ("" if scheme_in_uri else "//")) - if scheme_in_uri: - uri_splits = reset_of_the_url.split("://", 1) + rest_of_the_url = uri.replace(f"{conn_type}://", ("" if host_with_protocol else "//")) + if host_with_protocol: + uri_splits = rest_of_the_url.split("://", 1) if "@" in uri_splits[0] or ":" in uri_splits[0]: raise AirflowException(f"Invalid connection string: {uri}.") - uri_parts = urlsplit(reset_of_the_url) - protocol = uri_parts.scheme if scheme_in_uri else None + uri_parts = urlsplit(rest_of_the_url) + protocol = uri_parts.scheme if host_with_protocol else None host = _parse_netloc_to_hostname(uri_parts) self.host = self._create_host(protocol, host) quoted_schema = uri_parts.path[1:]