Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions airflow/providers/apache/livy/hooks/livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def run_method(
method: str = 'GET',
data: Optional[Any] = None,
headers: Optional[Dict[str, Any]] = None,
retry_args: Optional[Dict[str, Any]] = None,
) -> Any:
"""
Wrapper for HttpHook, allows to change method on the same HttpHook
Expand All @@ -107,6 +108,8 @@ def run_method(
:param endpoint: endpoint
:param data: request payload
:param headers: headers
:param retry_args: Arguments which define the retry behaviour.
See Tenacity documentation at https://github.com/jd/tenacity
:return: http response
:rtype: requests.Response
"""
Expand All @@ -118,7 +121,17 @@ def run_method(
back_method = self.method
self.method = method
try:
result = self.run(endpoint, data, headers, self.extra_options)
if retry_args:
result = self.run_with_advanced_retry(
endpoint=endpoint,
data=data,
headers=headers,
extra_options=self.extra_options,
_retry_args=retry_args,
)
else:
result = self.run(endpoint, data, headers, self.extra_options)

finally:
self.method = back_method
return result
Expand Down Expand Up @@ -180,18 +193,22 @@ def get_batch(self, session_id: Union[int, str]) -> Any:

return response.json()

def get_batch_state(self, session_id: Union[int, str]) -> BatchState:
def get_batch_state(
self, session_id: Union[int, str], retry_args: Optional[Dict[str, Any]] = None
) -> BatchState:
"""
Fetch the state of the specified batch

:param session_id: identifier of the batch sessions
:param retry_args: Arguments which define the retry behaviour.
See Tenacity documentation at https://github.com/jd/tenacity
:return: batch state
:rtype: BatchState
"""
self._validate_session_id(session_id)

self.log.debug("Fetching info for batch session %d", session_id)
response = self.run_method(endpoint=f'/batches/{session_id}/state')
response = self.run_method(endpoint=f'/batches/{session_id}/state', retry_args=retry_args)

try:
response.raise_for_status()
Expand Down
8 changes: 6 additions & 2 deletions airflow/providers/apache/livy/operators/livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class LivyOperator(BaseOperator):
:param extra_options: A dictionary of options, where key is string and value
depends on the option that's being modified.
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.
:param retry_args: Arguments which define the retry behaviour.
See Tenacity documentation at https://github.com/jd/tenacity
"""

template_fields: Sequence[str] = ('spark_params',)
Expand Down Expand Up @@ -80,6 +82,7 @@ def __init__(
polling_interval: int = 0,
extra_options: Optional[Dict[str, Any]] = None,
extra_headers: Optional[Dict[str, Any]] = None,
retry_args: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> None:

Expand Down Expand Up @@ -111,6 +114,7 @@ def __init__(

self._livy_hook: Optional[LivyHook] = None
self._batch_id: Union[int, str]
self.retry_args = retry_args

def get_hook(self) -> LivyHook:
"""
Expand Down Expand Up @@ -142,11 +146,11 @@ def poll_for_termination(self, batch_id: Union[int, str]) -> None:
:param batch_id: id of the batch session to monitor.
"""
hook = self.get_hook()
state = hook.get_batch_state(batch_id)
state = hook.get_batch_state(batch_id, retry_args=self.retry_args)
while state not in hook.TERMINAL_STATES:
self.log.debug('Batch with id %s is in state: %s', batch_id, state.value)
sleep(self._polling_interval)
state = hook.get_batch_state(batch_id)
state = hook.get_batch_state(batch_id, retry_args=self.retry_args)
self.log.info("Batch with id %s terminated with state: %s", batch_id, state.value)
hook.dump_batch_logs(batch_id)
if state != BatchState.SUCCESS:
Expand Down
12 changes: 6 additions & 6 deletions tests/providers/apache/livy/operators/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_poll_for_termination(self, mock_livy, mock_dump_logs):

state_list = 2 * [BatchState.RUNNING] + [BatchState.SUCCESS]

def side_effect(_):
def side_effect(_, retry_args):
if state_list:
return state_list.pop(0)
# fail if does not stop right before
Expand All @@ -67,7 +67,7 @@ def side_effect(_):
task._livy_hook = task.get_hook()
task.poll_for_termination(BATCH_ID)

mock_livy.assert_called_with(BATCH_ID)
mock_livy.assert_called_with(BATCH_ID, retry_args=None)
mock_dump_logs.assert_called_with(BATCH_ID)
assert mock_livy.call_count == 3

Expand All @@ -80,7 +80,7 @@ def test_poll_for_termination_fail(self, mock_livy, mock_dump_logs):

state_list = 2 * [BatchState.RUNNING] + [BatchState.ERROR]

def side_effect(_):
def side_effect(_, retry_args):
if state_list:
return state_list.pop(0)
# fail if does not stop right before
Expand All @@ -94,7 +94,7 @@ def side_effect(_):
with pytest.raises(AirflowException):
task.poll_for_termination(BATCH_ID)

mock_livy.assert_called_with(BATCH_ID)
mock_livy.assert_called_with(BATCH_ID, retry_args=None)
mock_dump_logs.assert_called_with(BATCH_ID)
assert mock_livy.call_count == 3

Expand All @@ -119,7 +119,7 @@ def test_execution(self, mock_post, mock_get, mock_dump_logs):

call_args = {k: v for k, v in mock_post.call_args[1].items() if v}
assert call_args == {'file': 'sparkapp'}
mock_get.assert_called_once_with(BATCH_ID)
mock_get.assert_called_once_with(BATCH_ID, retry_args=None)
mock_dump_logs.assert_called_once_with(BATCH_ID)

@patch('airflow.providers.apache.livy.operators.livy.LivyHook.post_batch')
Expand Down Expand Up @@ -171,5 +171,5 @@ def test_log_dump(self, mock_post, mock_get_logs, mock_get):
assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:first_line' in cm.output
assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:second_line' in cm.output
assert 'INFO:airflow.providers.apache.livy.hooks.livy.LivyHook:third_line' in cm.output
mock_get.assert_called_once_with(BATCH_ID)
mock_get.assert_called_once_with(BATCH_ID, retry_args=None)
mock_get_logs.assert_called_once_with(BATCH_ID, 0, 100)