-
Notifications
You must be signed in to change notification settings - Fork 17.3k
PSRP improvements #19806
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PSRP improvements #19806
Changes from all commits
536c8b8
43d39d6
8d0288d
d1b62b4
13f03aa
7f1a894
88aa850
cc65c62
745b4ff
f5a846e
40988c2
a3cf344
e1b182c
f59f112
e26044d
c3e1570
051d85f
8ded311
7049284
e04b6b5
23f1e65
35a6ba7
abde233
d1d769c
f2ec2a5
2e5daa4
63390c3
df55b5c
22351f3
045f000
fa7b0cf
17918e7
2403c6d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,103 +16,245 @@ | |
| # specific language governing permissions and limitations | ||
| # under the License. | ||
|
|
||
| from time import sleep | ||
| from contextlib import contextmanager | ||
| from copy import copy | ||
| from logging import DEBUG, ERROR, INFO, WARNING | ||
| from typing import Any, Callable, Dict, Iterator, Optional | ||
| from weakref import WeakKeyDictionary | ||
|
|
||
| from pypsrp.messages import ErrorRecord, InformationRecord, ProgressRecord | ||
| from pypsrp.messages import MessageType | ||
| from pypsrp.powershell import PowerShell, PSInvocationState, RunspacePool | ||
| from pypsrp.wsman import WSMan | ||
|
|
||
| from airflow.exceptions import AirflowException | ||
| from airflow.hooks.base import BaseHook | ||
|
|
||
| INFORMATIONAL_RECORD_LEVEL_MAP = { | ||
| MessageType.DEBUG_RECORD: DEBUG, | ||
| MessageType.ERROR_RECORD: ERROR, | ||
| MessageType.VERBOSE_RECORD: INFO, | ||
| MessageType.WARNING_RECORD: WARNING, | ||
| } | ||
|
|
||
| class PSRPHook(BaseHook): | ||
| OutputCallback = Callable[[str], None] | ||
|
|
||
|
|
||
| class PsrpHook(BaseHook): | ||
| """ | ||
| Hook for PowerShell Remoting Protocol execution. | ||
|
|
||
| The hook must be used as a context manager. | ||
| When used as a context manager, the runspace pool is reused between shell | ||
| sessions. | ||
|
|
||
| :param psrp_conn_id: Required. The name of the PSRP connection. | ||
| :type psrp_conn_id: str | ||
| :param logging_level: | ||
| Logging level for message streams which are received during remote execution. | ||
| The default is to include all messages in the task log. | ||
| :type logging_level: int | ||
| :param operation_timeout: Override the default WSMan timeout when polling the pipeline. | ||
| :type operation_timeout: float | ||
| :param runspace_options: | ||
| Optional dictionary which is passed when creating the runspace pool. See | ||
| :py:class:`~pypsrp.powershell.RunspacePool` for a description of the | ||
| available options. | ||
| :type runspace_options: dict | ||
| :param wsman_options: | ||
| Optional dictionary which is passed when creating the `WSMan` client. See | ||
| :py:class:`~pypsrp.wsman.WSMan` for a description of the available options. | ||
| :type wsman_options: dict | ||
| :param on_output_callback: | ||
| Optional callback function to be called whenever an output response item is | ||
| received during job status polling. | ||
| :type on_output_callback: OutputCallback | ||
| :param exchange_keys: | ||
| If true (default), automatically initiate a session key exchange when the | ||
| hook is used as a context manager. | ||
| :type exchange_keys: bool | ||
|
|
||
| You can provide an alternative `configuration_name` using either `runspace_options` | ||
| or by setting this key as the extra fields of your connection. | ||
| """ | ||
|
|
||
| _client = None | ||
| _poll_interval = 1 | ||
| _conn = None | ||
| _configuration_name = None | ||
| _wsman_ref: "WeakKeyDictionary[RunspacePool, WSMan]" = WeakKeyDictionary() | ||
|
|
||
| def __init__(self, psrp_conn_id: str): | ||
| def __init__( | ||
| self, | ||
| psrp_conn_id: str, | ||
| logging_level: int = DEBUG, | ||
| operation_timeout: Optional[int] = None, | ||
| runspace_options: Optional[Dict[str, Any]] = None, | ||
| wsman_options: Optional[Dict[str, Any]] = None, | ||
| on_output_callback: Optional[OutputCallback] = None, | ||
| exchange_keys: bool = True, | ||
| ): | ||
| self.conn_id = psrp_conn_id | ||
| self._logging_level = logging_level | ||
| self._operation_timeout = operation_timeout | ||
| self._runspace_options = runspace_options or {} | ||
| self._wsman_options = wsman_options or {} | ||
| self._on_output_callback = on_output_callback | ||
| self._exchange_keys = exchange_keys | ||
|
|
||
| def __enter__(self): | ||
| conn = self.get_connection(self.conn_id) | ||
|
|
||
| self.log.info("Establishing WinRM connection %s to host: %s", self.conn_id, conn.host) | ||
| self._client = WSMan( | ||
| conn.host, | ||
| ssl=True, | ||
| auth="ntlm", | ||
| encryption="never", | ||
| username=conn.login, | ||
| password=conn.password, | ||
| cert_validation=False, | ||
| ) | ||
| self._client.__enter__() | ||
| conn = self.get_conn() | ||
| self._wsman_ref[conn].__enter__() | ||
| conn.__enter__() | ||
| if self._exchange_keys: | ||
| conn.exchange_keys() | ||
| self._conn = conn | ||
| return self | ||
|
|
||
| def __exit__(self, exc_type, exc_value, traceback): | ||
| try: | ||
| self._client.__exit__(exc_type, exc_value, traceback) | ||
| self._conn.__exit__(exc_type, exc_value, traceback) | ||
| self._wsman_ref[self._conn].__exit__(exc_type, exc_value, traceback) | ||
| finally: | ||
| self._client = None | ||
| del self._conn | ||
|
potiuk marked this conversation as resolved.
Outdated
|
||
|
|
||
| def invoke_powershell(self, script: str) -> PowerShell: | ||
| with RunspacePool(self._client) as pool: | ||
| ps = PowerShell(pool) | ||
| ps.add_script(script) | ||
| def get_conn(self) -> RunspacePool: | ||
| """ | ||
| Returns a runspace pool. | ||
|
|
||
| The returned object must be used as a context manager. | ||
| """ | ||
| conn = self.get_connection(self.conn_id) | ||
| self.log.info("Establishing WinRM connection %s to host: %s", self.conn_id, conn.host) | ||
|
|
||
| extra = conn.extra_dejson.copy() | ||
|
|
||
| def apply_extra(d, keys): | ||
| d = d.copy() | ||
| for key in keys: | ||
| value = extra.pop(key, None) | ||
| if value is not None: | ||
| d[key] = value | ||
| return d | ||
|
|
||
| wsman_options = apply_extra( | ||
| self._wsman_options, | ||
| ( | ||
| "auth", | ||
| "cert_validation", | ||
| "connection_timeout", | ||
| "locale", | ||
| "read_timeout", | ||
| "reconnection_retries", | ||
| "reconnection_backoff", | ||
| "ssl", | ||
| ), | ||
| ) | ||
| wsman = WSMan(conn.host, username=conn.login, password=conn.password, **wsman_options) | ||
| runspace_options = apply_extra(self._runspace_options, ("configuration_name",)) | ||
|
|
||
| if extra: | ||
| raise AirflowException(f"Unexpected extra configuration keys: {', '.join(sorted(extra))}") | ||
| pool = RunspacePool(wsman, **runspace_options) | ||
| self._wsman_ref[pool] = wsman | ||
| return pool | ||
|
|
||
| @contextmanager | ||
| def invoke(self) -> Iterator[PowerShell]: | ||
| """ | ||
| Context manager that yields a PowerShell object to which commands can be | ||
| added. Upon exit, the commands will be invoked. | ||
| """ | ||
| logger = copy(self.log) | ||
| logger.setLevel(self._logging_level) | ||
| local_context = self._conn is None | ||
| if local_context: | ||
| self.__enter__() | ||
| try: | ||
| assert self._conn is not None | ||
| ps = PowerShell(self._conn) | ||
| yield ps | ||
| ps.begin_invoke() | ||
|
|
||
| streams = [ | ||
| (ps.output, self._log_output), | ||
| (ps.streams.debug, self._log_record), | ||
| (ps.streams.information, self._log_record), | ||
| (ps.streams.error, self._log_record), | ||
| ps.output, | ||
| ps.streams.debug, | ||
| ps.streams.error, | ||
| ps.streams.information, | ||
| ps.streams.progress, | ||
| ps.streams.verbose, | ||
| ps.streams.warning, | ||
| ] | ||
| offsets = [0 for _ in streams] | ||
|
|
||
| # We're using polling to make sure output and streams are | ||
| # handled while the process is running. | ||
| while ps.state == PSInvocationState.RUNNING: | ||
| sleep(self._poll_interval) | ||
| ps.poll_invoke() | ||
| ps.poll_invoke(timeout=self._operation_timeout) | ||
|
|
||
| for (i, (stream, handler)) in enumerate(streams): | ||
| for i, stream in enumerate(streams): | ||
| offset = offsets[i] | ||
| while len(stream) > offset: | ||
| handler(stream[offset]) | ||
| record = stream[offset] | ||
|
|
||
| # Records received on the output stream during job | ||
| # status polling are handled via an optional callback, | ||
| # while the other streams are simply logged. | ||
| if stream is ps.output: | ||
| if self._on_output_callback is not None: | ||
| self._on_output_callback(record) | ||
| else: | ||
| self._log_record(logger.log, record) | ||
| offset += 1 | ||
| offsets[i] = offset | ||
|
|
||
| # For good measure, we'll make sure the process has | ||
| # stopped running. | ||
| # stopped running in any case. | ||
| ps.end_invoke() | ||
|
|
||
| self.log.info("Invocation state: %s", str(PSInvocationState(ps.state))) | ||
| if ps.streams.error: | ||
| raise AirflowException("Process had one or more errors") | ||
| finally: | ||
| if local_context: | ||
| self.__exit__(None, None, None) | ||
|
|
||
| self.log.info("Invocation state: %s", str(PSInvocationState(ps.state))) | ||
| return ps | ||
| def invoke_cmdlet(self, name: str, use_local_scope=None, **parameters: Dict[str, str]) -> PowerShell: | ||
| """Invoke a PowerShell cmdlet and return session.""" | ||
| with self.invoke() as ps: | ||
| ps.add_cmdlet(name, use_local_scope=use_local_scope) | ||
| ps.add_parameters(parameters) | ||
| return ps | ||
|
|
||
| def _log_output(self, message: str): | ||
| self.log.info("%s", message) | ||
| def invoke_powershell(self, script: str) -> PowerShell: | ||
| """Invoke a PowerShell script and return session.""" | ||
| with self.invoke() as ps: | ||
| ps.add_script(script) | ||
| return ps | ||
|
|
||
| def _log_record(self, record): | ||
| # TODO: Consider translating some or all of these records into | ||
| # normal logging levels, using `log(level, msg, *args)`. | ||
| if isinstance(record, ErrorRecord): | ||
| self.log.info("Error: %s", record) | ||
| return | ||
| def _log_record(self, log, record): | ||
| message_type = record.MESSAGE_TYPE | ||
| if message_type == MessageType.ERROR_RECORD: | ||
| log(INFO, "%s: %s", record.reason, record) | ||
| if record.script_stacktrace: | ||
| for trace in record.script_stacktrace.split('\r\n'): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Actually, I think "\r\n" is the correct thing here – it is Windows after all, so we don't really need the bigger arsenal of line-endings in `str.splitlines()`.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think using `splitlines()` would be clearer (the PSRP <> Windows association is not at all obvious for someone who just reviews the code). |
||
| log(INFO, trace) | ||
|
|
||
| if isinstance(record, InformationRecord): | ||
| self.log.info("Information: %s", record.message_data) | ||
| return | ||
| level = INFORMATIONAL_RECORD_LEVEL_MAP.get(message_type) | ||
| if level is not None: | ||
| try: | ||
| message = str(record.message) | ||
| except BaseException as exc: | ||
| # See https://github.com/jborean93/pypsrp/pull/130 | ||
| message = str(exc) | ||
|
|
||
| if isinstance(record, ProgressRecord): | ||
| self.log.info("Progress: %s (%s)", record.activity, record.description) | ||
| return | ||
| # Sometimes a message will have a trailing \r\n sequence such as | ||
| # the tracing output of the Set-PSDebug cmdlet. | ||
| message = message.rstrip() | ||
|
|
||
| self.log.info("Unsupported record type: %s", type(record).__name__) | ||
| if record.command_name is None: | ||
| log(level, "%s", message) | ||
| else: | ||
| log(level, "%s: %s", record.command_name, message) | ||
| elif message_type == MessageType.INFORMATION_RECORD: | ||
| log(INFO, "%s (%s): %s", record.computer, record.user, record.message_data) | ||
| elif message_type == MessageType.PROGRESS_RECORD: | ||
| log(INFO, "Progress: %s (%s)", record.activity, record.description) | ||
| else: | ||
| log(WARNING, "Unsupported message type: %s", message_type) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How important is it to you to support the non-context manager use case? I feel it’s kind of unnecessarily complicating the implementation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Someone less comfortable with context managers will presumably have less of a challenge using the hook, which is one argument for having it. And I suppose it allows for a one-liner in simple cases.
But I don't feel strongly that we need to support the non-context manager use case. It does complicate the code a little bit – but not a lot.