diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index daf8c3299f7..a2ae5f58b1a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,14 @@ repos: hooks: - id: isort language_version: python3 + - repo: https://github.com/asottile/pyupgrade + # Do not upgrade: there's a bug in Cython that causes sum(... for ...) to fail; + # it needs sum([... for ...]) + rev: v2.13.0 + hooks: + - id: pyupgrade + args: + - --py37-plus - repo: https://github.com/psf/black rev: 21.9b0 hooks: @@ -17,11 +25,22 @@ repos: hooks: - id: flake8 language_version: python3 - - repo: https://github.com/asottile/pyupgrade - # Do not upgrade: there's a bug in Cython that causes sum(... for ...) to fail; - # it needs sum([... for ...]) - rev: v2.13.0 + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.910 hooks: - - id: pyupgrade - args: - - --py37-plus + - id: mypy + additional_dependencies: + # Type stubs + - types-docutils + - types-requests + - types-paramiko + - types-pkg_resources + - types-PyYAML + - types-setuptools + - types-psutil + # Libraries exclusively imported under `if TYPE_CHECKING:` + - typing_extensions # To be reviewed after dropping Python 3.7 + # Typed libraries + - numpy + - dask + - tornado diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index cc6ad8d3166..0bcd83522ca 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -127,12 +127,12 @@ def run_once(self, comm=None) -> None: # populate self.pending self._run_policies() - drop_by_worker: defaultdict[WorkerState, set[TaskState]] = defaultdict( - set - ) - repl_by_worker: defaultdict[ - WorkerState, dict[TaskState, set[WorkerState]] - ] = defaultdict(dict) + drop_by_worker: ( + defaultdict[WorkerState, set[TaskState]] + ) = defaultdict(set) + repl_by_worker: ( + defaultdict[WorkerState, dict[TaskState, set[str]]] + ) = defaultdict(dict) for ts, (pending_repl, pending_drop) in self.pending.items(): if not ts.who_has: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3e755a580bb..12add27f2d8 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -13,14 +13,24 @@ import warnings import weakref from collections import defaultdict, deque -from collections.abc import Hashable, Iterable, Iterator, Mapping, Set +from collections.abc import ( + Callable, + Collection, + Hashable, + Iterable, + Iterator, + Mapping, + Set, +) from contextlib import suppress from datetime import timedelta from functools import partial from numbers import Number +from typing import ClassVar +from typing import cast as pep484_cast import psutil -import sortedcontainers +from sortedcontainers import SortedDict, SortedSet from tlz import ( compose, first, @@ -524,16 +534,17 @@ class WorkerState: def __init__( self, - address: str = None, - pid: Py_ssize_t = 0, - name: object = None, + *, + address: str, + pid: Py_ssize_t, + name: object, nthreads: Py_ssize_t = 0, - memory_limit: Py_ssize_t = 0, - local_directory: str = None, - services: dict = None, - versions: dict = None, - nanny: str = None, - extra: dict = None, + memory_limit: Py_ssize_t, + local_directory: str, + nanny: str, + services: "dict | None" = None, + versions: "dict | None" = None, + extra: "dict | None" = None, ): self._address = address self._pid = pid @@ -584,7 +595,7 @@ def actors(self): return self._actors @property - def address(self): + def address(self) -> str: return self._address @property @@ -781,7 +792,7 @@ class Computation: def 
__init__(self): self._start = time() self._groups = set() - self._code = sortedcontainers.SortedSet() + self._code = SortedSet() self._id = uuid.uuid4() @property @@ -862,7 +873,7 @@ class TaskPrefix: """ _name: str - _all_durations: object + _all_durations: "defaultdict[str, float]" _duration_average: double _suspicious: Py_ssize_t _groups: list @@ -882,19 +893,19 @@ def __init__(self, name: str): self._suspicious = 0 @property - def name(self): + def name(self) -> str: return self._name @property - def all_durations(self): + def all_durations(self) -> "defaultdict[str, float]": return self._all_durations @property - def duration_average(self): + def duration_average(self) -> double: return self._duration_average @property - def suspicious(self): + def suspicious(self) -> Py_ssize_t: return self._suspicious @property @@ -907,7 +918,7 @@ def states(self): return merge_with(sum, [tg._states for tg in self._groups]) @property - def active(self): + def active(self) -> "list[TaskGroup]": tg: TaskGroup return [ tg @@ -1000,7 +1011,7 @@ class TaskGroup: """ _name: str - _prefix: TaskPrefix + _prefix: TaskPrefix # TaskPrefix | None _states: dict _dependencies: set _nbytes_total: Py_ssize_t @@ -1008,13 +1019,13 @@ class TaskGroup: _types: set _start: double _stop: double - _all_durations: object - _last_worker: WorkerState + _all_durations: "defaultdict[str, float]" + _last_worker: WorkerState # WorkerState | None _last_worker_tasks_left: Py_ssize_t def __init__(self, name: str): self._name = name - self._prefix = None + self._prefix = None # type: ignore self._states = {state: 0 for state in ALL_TASK_STATES} self._states["forgotten"] = 0 self._dependencies = set() @@ -1024,23 +1035,23 @@ def __init__(self, name: str): self._start = 0.0 self._stop = 0.0 self._all_durations = defaultdict(float) - self._last_worker = None + self._last_worker = None # type: ignore self._last_worker_tasks_left = 0 @property - def name(self): + def name(self) -> str: return self._name @property - def prefix(self): + def prefix(self) -> "TaskPrefix | None": return self._prefix @property - def states(self): + def states(self) -> dict: return self._states @property - def dependencies(self): + def dependencies(self) -> set: return self._dependencies @property @@ -1048,38 +1059,37 @@ def nbytes_total(self): return self._nbytes_total @property - def duration(self): + def duration(self) -> double: return self._duration @property - def types(self): + def types(self) -> set: return self._types @property - def all_durations(self): + def all_durations(self) -> "defaultdict[str, float]": return self._all_durations @property - def start(self): + def start(self) -> double: return self._start @property - def stop(self): + def stop(self) -> double: return self._stop @property - def last_worker(self): + def last_worker(self) -> "WorkerState | None": return self._last_worker @property - def last_worker_tasks_left(self): + def last_worker_tasks_left(self) -> int: return self._last_worker_tasks_left @ccall - def add(self, o): - ts: TaskState = o - self._states[ts._state] += 1 - ts._group = self + def add(self, other: "TaskState"): + self._states[other._state] += 1 + other._group = self def __repr__(self): return ( @@ -1347,34 +1357,34 @@ class TaskState: _hash: Py_hash_t _prefix: TaskPrefix _run_spec: object - _priority: tuple - _state: str - _dependencies: set - _dependents: set + _priority: tuple # tuple | None + _state: str # str | None + _dependencies: set # set[TaskState] + _dependents: set # set[TaskState] _has_lost_dependencies: 
bint - _waiting_on: set - _waiters: set - _who_wants: set - _who_has: set - _processing_on: WorkerState + _waiting_on: set # set[TaskState] + _waiters: set # set[TaskState] + _who_wants: set # set[ClientState] + _who_has: set # set[WorkerState] + _processing_on: WorkerState # WorkerState | None _retries: Py_ssize_t _nbytes: Py_ssize_t - _type: str + _type: str # str | None _exception: object _exception_text: str _traceback: object _traceback_text: str - _exception_blame: object + _exception_blame: "TaskState" # TaskState | None _erred_on: set _suspicious: Py_ssize_t - _host_restrictions: set - _worker_restrictions: set - _resource_restrictions: dict + _host_restrictions: set # set[str] | None + _worker_restrictions: set # set[str] | None + _resource_restrictions: dict # dict | None _loose_restrictions: bint _metadata: dict _annotations: dict _actor: bint - _group: TaskGroup + _group: TaskGroup # TaskGroup | None _group_key: str __slots__ = ( @@ -1434,28 +1444,32 @@ def __init__(self, key: str, run_spec: object): self._key = key self._hash = hash(key) self._run_spec = run_spec - self._state = None - self._exception = self._traceback = self._exception_blame = None - self._exception_text = self._traceback_text = "" - self._suspicious = self._retries = 0 + self._state = None # type: ignore + self._exception = None + self._exception_blame = None # type: ignore + self._traceback = None + self._exception_text = "" + self._traceback_text = "" + self._suspicious = 0 + self._retries = 0 self._nbytes = -1 - self._priority = None + self._priority = None # type: ignore self._who_wants = set() self._dependencies = set() self._dependents = set() self._waiting_on = set() self._waiters = set() self._who_has = set() - self._processing_on = None + self._processing_on = None # type: ignore self._has_lost_dependencies = False - self._host_restrictions = None - self._worker_restrictions = None - self._resource_restrictions = None + self._host_restrictions = None # type: ignore + self._worker_restrictions = None # type: ignore + self._resource_restrictions = None # type: ignore self._loose_restrictions = False self._actor = False - self._type = None + self._type = None # type: ignore self._group_key = key_split_group(key) - self._group = None + self._group = None # type: ignore self._metadata = {} self._annotations = {} self._erred_on = set() @@ -1485,11 +1499,11 @@ def run_spec(self): return self._run_spec @property - def priority(self): + def priority(self) -> "tuple | None": return self._priority @property - def state(self) -> str: + def state(self) -> "str | None": return self._state @state.setter @@ -1499,11 +1513,11 @@ def state(self, value: str): self._state = value @property - def dependencies(self): + def dependencies(self) -> "set[TaskState]": return self._dependencies @property - def dependents(self): + def dependents(self) -> "set[TaskState]": return self._dependents @property @@ -1511,27 +1525,27 @@ def has_lost_dependencies(self): return self._has_lost_dependencies @property - def waiting_on(self): + def waiting_on(self) -> "set[TaskState]": return self._waiting_on @property - def waiters(self): + def waiters(self) -> "set[TaskState]": return self._waiters @property - def who_wants(self): + def who_wants(self) -> "set[ClientState]": return self._who_wants @property - def who_has(self): + def who_has(self) -> "set[WorkerState]": return self._who_has @property - def processing_on(self): + def processing_on(self) -> "WorkerState | None": return self._processing_on @processing_on.setter - def 
processing_on(self, v: WorkerState): + def processing_on(self, v: WorkerState) -> None: self._processing_on = v @property @@ -1547,7 +1561,7 @@ def nbytes(self, v: Py_ssize_t): self._nbytes = v @property - def type(self): + def type(self) -> "str | None": return self._type @property @@ -1567,7 +1581,7 @@ def traceback_text(self): return self._traceback_text @property - def exception_blame(self): + def exception_blame(self) -> "TaskState | None": return self._exception_blame @property @@ -1575,15 +1589,15 @@ def suspicious(self): return self._suspicious @property - def host_restrictions(self): + def host_restrictions(self) -> "set[str] | None": return self._host_restrictions @property - def worker_restrictions(self): + def worker_restrictions(self) -> "set[str] | None": return self._worker_restrictions @property - def resource_restrictions(self): + def resource_restrictions(self) -> "dict | None": return self._resource_restrictions @property @@ -1603,11 +1617,11 @@ def actor(self): return self._actor @property - def group(self): + def group(self) -> "TaskGroup | None": return self._group @property - def group_key(self): + def group_key(self) -> str: return self._group_key @property @@ -1839,12 +1853,12 @@ class SchedulerState: _aliases: dict _bandwidth: double - _clients: dict + _clients: dict # dict[str, ClientState] _computations: object _extensions: dict _host_info: dict - _idle: object - _idle_dv: dict + _idle: "SortedDict[str, WorkerState]" + _idle_dv: dict # dict[str, WorkerState] _n_tasks: Py_ssize_t _resources: dict _saturated: set @@ -1859,9 +1873,10 @@ class SchedulerState: _unknown_durations: dict _unrunnable: set _validate: bint - _workers: object - _workers_dv: dict + _workers: "SortedDict[str, WorkerState]" + _workers_dv: dict # dict[str, WorkerState] _transition_counter: Py_ssize_t + _plugins: dict # dict[str, SchedulerPlugin] # Variables from dask.config, cached by __init__ for performance UNKNOWN_TASK_DURATION: double @@ -1873,54 +1888,41 @@ class SchedulerState: def __init__( self, - aliases: dict = None, - clients: dict = None, - workers=None, - host_info=None, - resources=None, - tasks: dict = None, - unrunnable: set = None, - validate: bint = False, - **kwargs, + aliases: dict, + clients: "dict[str, ClientState]", + workers: "SortedDict[str, WorkerState]", + host_info: dict, + resources: dict, + tasks: dict, + unrunnable: set, + validate: bint, + plugins: "Iterable[SchedulerPlugin]" = (), + **kwargs, # Passed verbatim to Server.__init__() ): - if aliases is not None: - self._aliases = aliases - else: - self._aliases = dict() + self._aliases = aliases self._bandwidth = parse_bytes( dask.config.get("distributed.scheduler.bandwidth") ) - if clients is not None: - self._clients = clients - else: - self._clients = dict() + self._clients = clients self._clients["fire-and-forget"] = ClientState("fire-and-forget") - self._extensions = dict() - if host_info is not None: - self._host_info = host_info - else: - self._host_info = dict() - self._idle = sortedcontainers.SortedDict() - self._idle_dv: dict = cast(dict, self._idle) + self._extensions = {} + self._host_info = host_info + self._idle = SortedDict() + # Note: cython.cast, not typing.cast! 
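+        # Illustration (assuming the usual compiled/pure-Python shim in this + # module): when compiled, cast() here is cython.cast, which reinterprets + # the SortedDict as a plain dict for fast C-level access; in pure-Python + # mode cython.cast simply returns its argument unchanged.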
+ self._idle_dv = cast(dict, self._idle) self._n_tasks = 0 - if resources is not None: - self._resources = resources - else: - self._resources = dict() + self._resources = resources self._saturated = set() - if tasks is not None: - self._tasks = tasks - else: - self._tasks = dict() + self._tasks = tasks self._replicated_tasks = { ts for ts in self._tasks.values() if len(ts._who_has) > 1 } self._computations = deque( maxlen=dask.config.get("distributed.diagnostics.computations.max-history") ) - self._task_groups = dict() - self._task_prefixes = dict() - self._task_metadata = dict() + self._task_groups = {} + self._task_prefixes = {} + self._task_metadata = {} self._total_nthreads = 0 self._total_occupancy = 0 self._transitions_table = { @@ -1940,17 +1942,13 @@ def __init__( ("memory", "released"): self.transition_memory_released, ("released", "erred"): self.transition_released_erred, } - self._unknown_durations = dict() - if unrunnable is not None: - self._unrunnable = unrunnable - else: - self._unrunnable = set() + self._unknown_durations = {} + self._unrunnable = unrunnable self._validate = validate - if workers is not None: - self._workers = workers - else: - self._workers = sortedcontainers.SortedDict() - self._workers_dv: dict = cast(dict, self._workers) + self._workers = workers + # Note: cython.cast, not typing.cast! + self._workers_dv = cast(dict, self._workers) + self._plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} # Variables from dask.config, cached by __init__ for performance self.UNKNOWN_TASK_DURATION = parse_timedelta( @@ -1974,7 +1972,8 @@ def __init__( ) self._transition_counter = 0 - super().__init__(**kwargs) + # Call Server.__init__() + super().__init__(**kwargs) # type: ignore @property def aliases(self): @@ -2072,6 +2071,10 @@ def validate(self, v: bint): def workers(self): return self._workers + @property + def plugins(self) -> "dict[str, SchedulerPlugin]": + return self._plugins + @property def memory(self) -> MemoryState: return MemoryState.sum(*(w.memory for w in self.workers.values())) @@ -2109,14 +2112,13 @@ def new_task( tp: TaskPrefix prefix_key = key_split(key) - tp = self._task_prefixes.get(prefix_key) + tp = self._task_prefixes.get(prefix_key) # type: ignore if tp is None: self._task_prefixes[prefix_key] = tp = TaskPrefix(prefix_key) ts._prefix = tp - tg: TaskGroup group_key = ts._group_key - tg = self._task_groups.get(group_key) + tg: TaskGroup = self._task_groups.get(group_key) # type: ignore if tg is None: self._task_groups[group_key] = tg = TaskGroup(group_key) if computation: @@ -2166,7 +2168,7 @@ def _transition(self, key, finish: str, *args, **kwargs): worker_msgs = {} client_msgs = {} - ts = parent._tasks.get(key) + ts = parent._tasks.get(key) # type: ignore if ts is None: return recommendations, client_msgs, worker_msgs start = ts._state @@ -2180,9 +2182,8 @@ def _transition(self, key, finish: str, *args, **kwargs): start_finish = (start, finish) func = self._transitions_table.get(start_finish) if func is not None: - a: tuple = func(key, *args, **kwargs) + recommendations, client_msgs, worker_msgs = func(key, *args, **kwargs) self._transition_counter += 1 - recommendations, client_msgs, worker_msgs = a elif "released" not in start_finish: assert not args and not kwargs, (args, kwargs, start_finish) a_recs: dict @@ -2201,13 +2202,13 @@ def _transition(self, key, finish: str, *args, **kwargs): recommendations.update(a_recs) for c, new_msgs in a_cmsgs.items(): - msgs = client_msgs.get(c) + msgs = client_msgs.get(c) # type: 
ignore if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs for w, new_msgs in a_wmsgs.items(): - msgs = worker_msgs.get(w) + msgs = worker_msgs.get(w) # type: ignore if msgs is not None: msgs.extend(new_msgs) else: @@ -2215,13 +2216,13 @@ def _transition(self, key, finish: str, *args, **kwargs): recommendations.update(b_recs) for c, new_msgs in b_cmsgs.items(): - msgs = client_msgs.get(c) + msgs = client_msgs.get(c) # type: ignore if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs for w, new_msgs in b_wmsgs.items(): - msgs = worker_msgs.get(w) + msgs = worker_msgs.get(w) # type: ignore if msgs is not None: msgs.extend(new_msgs) else: @@ -2232,7 +2233,11 @@ def _transition(self, key, finish: str, *args, **kwargs): raise RuntimeError("Impossible transition from %r to %r" % start_finish) finish2 = ts._state - self.transition_log.append((key, start, finish2, recommendations, time())) + # FIXME downcast antipattern + scheduler = pep484_cast(Scheduler, self) + scheduler.transition_log.append( + (key, start, finish2, recommendations, time()) + ) if parent._validate: logger.debug( "Transitioned %r %s->%s (actual: %s). Consequence: %s", @@ -2283,7 +2288,6 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di This includes feedback from previous transitions and continues until we reach a steady state """ - parent: SchedulerState = cast(SchedulerState, self) keys: set = set() recommendations = recommendations.copy() msgs: list @@ -2301,21 +2305,23 @@ def _transitions(self, recommendations: dict, client_msgs: dict, worker_msgs: di recommendations.update(new_recs) for c, new_msgs in new_cmsgs.items(): - msgs = client_msgs.get(c) + msgs = client_msgs.get(c) # type: ignore if msgs is not None: msgs.extend(new_msgs) else: client_msgs[c] = new_msgs for w, new_msgs in new_wmsgs.items(): - msgs = worker_msgs.get(w) + msgs = worker_msgs.get(w) # type: ignore if msgs is not None: msgs.extend(new_msgs) else: worker_msgs[w] = new_msgs - if parent._validate: + if self._validate: + # FIXME downcast antipattern + scheduler = pep484_cast(Scheduler, self) for key in keys: - self.validate_key(key) + scheduler.validate_key(key) def transition_released_waiting(self, key): try: @@ -2457,7 +2463,7 @@ def transition_no_worker_memory( @ccall @exceptval(check=False) - def decide_worker(self, ts: TaskState) -> WorkerState: + def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None """ Decide on a worker for task *ts*. Return a WorkerState. @@ -2471,10 +2477,10 @@ def decide_worker(self, ts: TaskState) -> WorkerState: in a round-robin fashion. """ if not self._workers_dv: - return None + return None # type: ignore - ws: WorkerState = None - group: TaskGroup = ts._group + ws: WorkerState + tg: TaskGroup = ts._group valid_workers: set = self.valid_workers(ts) if ( @@ -2484,34 +2490,35 @@ def decide_worker(self, ts: TaskState) -> WorkerState: ): self._unrunnable.add(ts) ts.state = "no-worker" - return ws + return None # type: ignore - # Group is larger than cluster with few dependencies? Minimize future data transfers. + # Group is larger than cluster with few dependencies? + # Minimize future data transfers. 
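+        # Sketch of the heuristic below, using the names in this function: such + # root-ish tasks stick to tg._last_worker for a quota of roughly + # len(tg) / total_nthreads * ws._nthreads consecutive assignments before + # a new worker is chosen, so that sibling tasks land together and later + # inter-worker transfers are avoided.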
if ( valid_workers is None - and len(group) > self._total_nthreads * 2 - and len(group._dependencies) < 5 - and sum(map(len, group._dependencies)) < 5 + and len(tg) > self._total_nthreads * 2 + and len(tg._dependencies) < 5 + and sum(map(len, tg._dependencies)) < 5 ): - ws: WorkerState = group._last_worker + ws = tg._last_worker if not ( - ws and group._last_worker_tasks_left and ws._address in self._workers_dv + ws and tg._last_worker_tasks_left and ws._address in self._workers_dv ): # Last-used worker is full or unknown; pick a new worker for the next few tasks ws = min( (self._idle_dv or self._workers_dv).values(), key=partial(self.worker_objective, ts), ) - group._last_worker_tasks_left = math.floor( - (len(group) / self._total_nthreads) * ws._nthreads + tg._last_worker_tasks_left = math.floor( + (len(tg) / self._total_nthreads) * ws._nthreads ) # Record `last_worker`, or clear it on the final task - group._last_worker = ( - ws if group.states["released"] + group.states["waiting"] > 1 else None + tg._last_worker = ( + ws if tg.states["released"] + tg.states["waiting"] > 1 else None ) - group._last_worker_tasks_left -= 1 + tg._last_worker_tasks_left -= 1 return ws if ts._dependencies or valid_workers is not None: @@ -2524,6 +2531,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState: else: # Fastpath when there are no related tasks or restrictions worker_pool = self._idle or self._workers + # Note: cython.cast, not typing.cast! worker_pool_dv = cast(dict, worker_pool) wp_vals = worker_pool.values() n_workers: Py_ssize_t = len(worker_pool_dv) @@ -2676,6 +2684,8 @@ def transition_processing_memory( worker_msgs: dict = {} try: ts: TaskState = self._tasks[key] + tg: TaskGroup = ts._group + assert worker assert isinstance(worker, str) @@ -2688,15 +2698,15 @@ def transition_processing_memory( assert not ts._exception_blame assert ts._state == "processing" - ws = self._workers_dv.get(worker) + ws = self._workers_dv.get(worker) # type: ignore if ws is None: recommendations[key] = "released" return recommendations, client_msgs, worker_msgs if ws != ts._processing_on: # someone else has this task logger.info( - "Unexpected worker completed task, likely due to" - " work stealing. Expected: %s, Got: %s, Key: %s", + "Unexpected worker completed task, likely due to " + "work stealing. 
Expected: %s, Got: %s, Key: %s", ts._processing_on, ws, key, @@ -2726,7 +2736,7 @@ def transition_processing_memory( # record timings of all actions -- a cheaper way of # getting timing info compared with get_task_stream() ts._prefix._all_durations[action] += stop - start - ts._group._all_durations[action] += stop - start + tg._all_durations[action] += stop - start ############################# # Update Timing Information # @@ -2742,10 +2752,10 @@ def transition_processing_memory( avg_duration = 0.5 * old_duration + 0.5 * new_duration ts._prefix._duration_average = avg_duration - ts._group._duration += new_duration - ts._group._start = ts._group._start or compute_start - if ts._group._stop < compute_stop: - ts._group._stop = compute_stop + tg._duration += new_duration + tg._start = tg._start or compute_start + if tg._stop < compute_stop: + tg._stop = compute_stop s: set = self._unknown_durations.pop(ts._prefix._name, None) tts: TaskState @@ -3068,15 +3078,15 @@ def transition_processing_erred( ts._erred_on.add(w or worker) if exception is not None: ts._exception = exception - ts._exception_text = exception_text + ts._exception_text = exception_text # type: ignore if traceback is not None: ts._traceback = traceback - ts._traceback_text = traceback_text + ts._traceback_text = traceback_text # type: ignore if cause is not None: failing_ts = self._tasks[cause] ts._exception_blame = failing_ts else: - failing_ts = ts._exception_blame + failing_ts = ts._exception_blame # type: ignore for dts in ts._dependents: dts._exception_blame = failing_ts @@ -3313,7 +3323,7 @@ def get_task_duration(self, ts: TaskState, default: double = -1) -> double: if duration >= 0: return duration - s: set = self._unknown_durations.get(ts._prefix._name) + s: set = self._unknown_durations.get(ts._prefix._name) # type: ignore if s is None: self._unknown_durations[ts._prefix._name] = s = set() s.add(ts) @@ -3331,7 +3341,7 @@ def valid_workers(self, ts: TaskState) -> set: * host_restrictions * resource_restrictions """ - s: set = None + s: set = None # type: ignore if ts._worker_restrictions: s = {w for w in ts._worker_restrictions if w in self._workers_dv} @@ -3343,7 +3353,7 @@ def valid_workers(self, ts: TaskState) -> set: # XXX need HostState? 
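+            # Descriptive note: each entry appended to sl below is the set of + # worker addresses recorded for one allowed host in self._host_info.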
sl: list = [] for h in hr: - dh: dict = self._host_info.get(h) + dh: dict = self._host_info.get(h) # type: ignore if dh is not None: sl.append(dh["addresses"]) @@ -3356,9 +3366,9 @@ def valid_workers(self, ts: TaskState) -> set: if ts._resource_restrictions: dw: dict = {} for resource, required in ts._resource_restrictions.items(): - dr: dict = self._resources.get(resource) + dr: dict = self._resources.get(resource) # type: ignore if dr is None: - self._resources[resource] = dr = dict() + self._resources[resource] = dr = {} sw: set = set() for w, supplied in dr.items(): @@ -3532,7 +3542,7 @@ class Scheduler(SchedulerState, ServerNode): """ default_port = 8786 - _instances = weakref.WeakSet() + _instances: "ClassVar[weakref.WeakSet[Scheduler]]" = weakref.WeakSet() def __init__( self, @@ -3637,13 +3647,13 @@ def __init__( # Communication state self.loop = loop or IOLoop.current() - self.client_comms = dict() - self.stream_comms = dict() + self.client_comms = {} + self.stream_comms = {} self._worker_coroutines = [] self._ipython_kernel = None # Task state - tasks = dict() + tasks = {} for old_attr, new_attr, wrap in [ ("priority", "priority", None), ("dependencies", "dependencies", _legacy_task_key_set), @@ -3688,12 +3698,12 @@ def __init__( self._last_time = 0 unrunnable = set() - self.datasets = dict() + self.datasets = {} # Prefix-keyed containers # Client state - clients = dict() + clients = {} for old_attr, new_attr, wrap in [ ("wants_what", "wants_what", _legacy_task_key_set) ]: @@ -3703,7 +3713,7 @@ def __init__( setattr(self, old_attr, _StateLegacyMapping(clients, func)) # Worker state - workers = sortedcontainers.SortedDict() + workers = SortedDict() for old_attr, new_attr, wrap in [ ("nthreads", "nthreads", None), ("worker_bytes", "nbytes", None), @@ -3719,9 +3729,9 @@ def __init__( func = compose(wrap, func) setattr(self, old_attr, _StateLegacyMapping(workers, func)) - host_info = dict() - resources = dict() - aliases = dict() + host_info = {} + resources = {} + aliases = {} self._task_state_collections = [unrunnable] @@ -3732,7 +3742,6 @@ def __init__( aliases, ] - self.plugins = {} if not plugins else {_get_plugin_name(p): p for p in plugins} self.transition_log = deque( maxlen=dask.config.get("distributed.scheduler.transition-log-length") ) @@ -3746,8 +3755,8 @@ def __init__( ) self.event_counts = defaultdict(int) self.event_subscriber = defaultdict(set) - self.worker_plugins = dict() - self.nanny_plugins = dict() + self.worker_plugins = {} + self.nanny_plugins = {} worker_handlers = { "task-finished": self.handle_task_finished, @@ -3832,13 +3841,8 @@ def __init__( connection_limit = get_fileno_limit() / 2 super().__init__( + # Arguments to SchedulerState aliases=aliases, - handlers=self.handlers, - stream_handlers=merge(worker_handlers, client_handlers), - io_loop=self.loop, - connection_limit=connection_limit, - deserialize=False, - connection_args=self.connection_args, clients=clients, workers=workers, host_info=host_info, @@ -3846,6 +3850,14 @@ def __init__( tasks=tasks, unrunnable=unrunnable, validate=validate, + plugins=plugins, + # Arguments to ServerNode + handlers=self.handlers, + stream_handlers=merge(worker_handlers, client_handlers), + io_loop=self.loop, + connection_limit=connection_limit, + deserialize=False, + connection_args=self.connection_args, **kwargs, ) @@ -4101,7 +4113,7 @@ def heartbeat_worker( parent: SchedulerState = cast(SchedulerState, self) address = self.coerce_address(address, resolve_address) address = normalize_address(address) - ws: WorkerState = 
parent._workers_dv.get(address) + ws: WorkerState = parent._workers_dv.get(address) # type: ignore if ws is None: return {"status": "missing"} @@ -4172,7 +4184,7 @@ def heartbeat_worker( ws._memory_unmanaged_old = size if host_info: - dh: dict = parent._host_info.setdefault(host, {}) + dh = parent._host_info.setdefault(host, {}) dh.update(host_info) if now: @@ -4250,7 +4262,7 @@ async def add_worker( dh: dict = parent._host_info.get(host) if dh is None: - parent._host_info[host] = dh = dict() + parent._host_info[host] = dh = {} dh_addresses: set = dh.get("addresses") if dh_addresses is None: @@ -4296,7 +4308,7 @@ async def add_worker( worker_msgs: dict = {} if nbytes: assert isinstance(nbytes, dict) - already_released_keys = list() + already_released_keys = [] for key in nbytes: ts: TaskState = parent._tasks.get(key) if ts is not None and ts.state != "released": @@ -4319,7 +4331,7 @@ async def add_worker( already_released_keys.append(key) if already_released_keys: if address not in worker_msgs: - worker_msgs[address] = list() + worker_msgs[address] = [] worker_msgs[address].append( { "op": "free-keys", @@ -4857,7 +4869,7 @@ async def remove_worker(self, comm=None, address=None, safe=False, close=True): dh: dict = parent._host_info.get(host) if dh is None: - parent._host_info[host] = dh = dict() + parent._host_info[host] = dh = {} dh_addresses: set = dh["addresses"] dh_addresses.remove(address) @@ -5405,7 +5417,7 @@ def handle_long_running(self, key=None, worker=None, compute_duration=None): def handle_worker_status_change(self, status: str, worker: str): parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker) + ws: WorkerState = parent._workers_dv.get(worker) # type: ignore if not ws: return self.log_event( @@ -5416,7 +5428,7 @@ def handle_worker_status_change(self, status: str, worker: str): "status": status, }, ) - ws._status = Status.lookup[status] + ws._status = Status.lookup[status] # type: ignore async def handle_worker(self, comm=None, worker=None): """ @@ -5518,7 +5530,7 @@ def remove_plugin( category=FutureWarning, ) if hasattr(plugin, "name"): - name = plugin.name + name = plugin.name # type: ignore else: names = [k for k, v in self.plugins.items() if v is plugin] if not names: @@ -5557,7 +5569,7 @@ async def register_scheduler_plugin(self, comm=None, plugin=None, name=None): if inspect.isawaitable(result): await result - self.add_plugin(plugin=plugin, name=name) + self.add_plugin(plugin, name=name) def worker_send(self, worker, msg): """Send message to worker @@ -5903,14 +5915,14 @@ async def gather_on_worker( return set(who_has) parent: SchedulerState = cast(SchedulerState, self) - ws: WorkerState = parent._workers_dv.get(worker_address) + ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore if ws is None: logger.warning(f"Worker {worker_address} lost during replication") return set(who_has) elif result["status"] == "OK": keys_failed = set() - keys_ok = who_has.keys() + keys_ok: Set = who_has.keys() elif result["status"] == "partial-fail": keys_failed = set(result["keys"]) keys_ok = who_has.keys() - keys_failed @@ -5921,7 +5933,7 @@ async def gather_on_worker( raise ValueError(f"Unexpected message from {worker_address}: {result}") for key in keys_ok: - ts: TaskState = parent._tasks.get(key) + ts: TaskState = parent._tasks.get(key) # type: ignore if ts is None or ts._state != "memory": logger.warning(f"Key lost during replication: {key}") continue @@ -5930,7 +5942,9 @@ async def gather_on_worker( return 
keys_failed - async def delete_worker_data(self, worker_address: str, keys: "list[str]") -> None: + async def delete_worker_data( + self, worker_address: str, keys: "Collection[str]" + ) -> None: """Delete data from a worker and update the corresponding worker/task states Parameters @@ -5957,12 +5971,12 @@ async def delete_worker_data(self, worker_address: str, keys: "list[str]") -> No ) return - ws: WorkerState = parent._workers_dv.get(worker_address) + ws: WorkerState = parent._workers_dv.get(worker_address) # type: ignore if ws is None: return for key in keys: - ts: TaskState = parent._tasks.get(key) + ts: TaskState = parent._tasks.get(key) # type: ignore if ts is not None and ws in ts._who_has: assert ts._state == "memory" parent.remove_replica(ts, ws) @@ -6043,14 +6057,15 @@ async def rebalance( All other workers will be ignored. The mean cluster occupancy will be calculated only using the whitelisted workers. """ - parent: SchedulerState = self + parent: SchedulerState = cast(SchedulerState, self) with log_errors(): + wss: "Collection[WorkerState]" if workers is not None: - workers = [parent._workers_dv[w] for w in workers] + wss = [parent._workers_dv[w] for w in workers] else: - workers = parent._workers_dv.values() - if not workers: + wss = parent._workers_dv.values() + if not wss: return {"status": "OK"} if keys is not None: @@ -6066,7 +6081,7 @@ async def rebalance( if missing_data: return {"status": "partial-fail", "keys": missing_data} - msgs = self._rebalance_find_msgs(keys, workers) + msgs = self._rebalance_find_msgs(keys, wss) if not msgs: return {"status": "OK"} @@ -6078,7 +6093,7 @@ async def rebalance( return result def _rebalance_find_msgs( - self: SchedulerState, + self, keys: "Set[Hashable] | None", workers: "Iterable[WorkerState]", ) -> "list[tuple[WorkerState, WorkerState, TaskState]]": @@ -6108,7 +6123,7 @@ def _rebalance_find_msgs( - recipient worker - task to be transferred """ - parent: SchedulerState = self + parent: SchedulerState = cast(SchedulerState, self) ts: TaskState ws: WorkerState @@ -6290,7 +6305,9 @@ async def _rebalance_move_data( rec_ws: WorkerState ts: TaskState - to_recipients = defaultdict(lambda: defaultdict(list)) + to_recipients: "defaultdict[str, defaultdict[str, list[str]]]" = defaultdict( + lambda: defaultdict(list) + ) for snd_ws, rec_ws, ts in msgs: to_recipients[rec_ws.address][ts._key].append(snd_ws.address) failed_keys_by_recipient = dict( @@ -6449,13 +6466,13 @@ async def replicate( def workers_to_close( self, comm=None, - memory_ratio=None, - n=None, - key=None, - minimum=None, - target=None, - attribute="address", - ): + memory_ratio: "int | float | None" = None, + n: "int | None" = None, + key: "Callable[[WorkerState], Hashable] | None" = None, + minimum: "int | None" = None, + target: "int | None" = None, + attribute: str = "address", + ) -> "list[str]": """ Find workers that we can close with low cost @@ -6567,9 +6584,9 @@ def _key(group): limit -= limit_bytes[group] - if (n is not None and n_remain - len(groups[group]) >= target) or ( - memory_ratio is not None and limit >= memory_ratio * total - ): + if ( + n is not None and n_remain - len(groups[group]) >= cast(int, target) + ) or (memory_ratio is not None and limit >= memory_ratio * total): to_close.append(group) n_remain -= len(groups[group]) @@ -6735,8 +6752,9 @@ def add_keys(self, comm=None, worker=None, keys=(), stimulus_id=None): def update_data( self, comm=None, - who_has=None, - nbytes: dict = None, + *, + who_has: dict, + nbytes: dict, client=None, serializers=None, 
): @@ -6755,9 +6773,9 @@ def update_data( logger.debug("Update data %s", who_has) for key, workers in who_has.items(): - ts: TaskState = parent._tasks.get(key) + ts: TaskState = parent._tasks.get(key) # type: ignore if ts is None: - ts: TaskState = parent.new_task(key, None, "memory") + ts = parent.new_task(key, None, "memory") ts.state = "memory" ts_nbytes = nbytes.get(key, -1) if ts_nbytes >= 0: @@ -6984,7 +7002,7 @@ def set_metadata(self, comm=None, keys=None, value=None): metadata = parent._task_metadata for key in keys[:-1]: if key not in metadata or not isinstance(metadata[key], (dict, list)): - metadata[key] = dict() + metadata[key] = {} metadata = metadata[key] metadata[keys[-1]] = value except Exception: @@ -7169,7 +7187,7 @@ def add_resources(self, comm=None, worker=None, resources=None): ws._used_resources[resource] = 0 dr: dict = parent._resources.get(resource, None) if dr is None: - parent._resources[resource] = dr = dict() + parent._resources[resource] = dr = {} dr[worker] = quantity return "OK" @@ -7179,7 +7197,7 @@ def remove_resources(self, worker): for resource, quantity in ws._resources.items(): dr: dict = parent._resources.get(resource, None) if dr is None: - parent._resources[resource] = dr = dict() + parent._resources[resource] = dr = {} del dr[worker] def coerce_address(self, addr, resolve=True): @@ -7690,16 +7708,18 @@ def adaptive_target(self, comm=None, target_duration=None): @cfunc @exceptval(check=False) -def _remove_from_processing(state: SchedulerState, ts: TaskState) -> str: +def _remove_from_processing( + state: SchedulerState, ts: TaskState +) -> str: # -> str | None """ Remove *ts* from the set of processing tasks. """ ws: WorkerState = ts._processing_on - ts._processing_on = None + ts._processing_on = None # type: ignore w: str = ws._address if w not in state._workers_dv: # may have been removed - return None + return None # type: ignore duration: double = ws._processing.pop(ts) if not ws._processing: @@ -7767,7 +7787,7 @@ def _add_to_memory( client_msgs[cs._client_key] = [report_msg] ts.state = "memory" - ts._type = typename + ts._type = typename # type: ignore ts._group._types.add(typename) cs = state._clients["fire-and-forget"] @@ -7831,7 +7851,7 @@ def _client_releases_keys( logger.debug("Client %s releases keys: %s", cs._client_key, keys) ts: TaskState for key in keys: - ts = state._tasks.get(key) + ts = state._tasks.get(key) # type: ignore if ts is not None and ts in cs._wants_what: cs._wants_what.remove(ts) ts._who_wants.remove(cs) @@ -7890,7 +7910,7 @@ def _task_to_msg(state: SchedulerState, ts: TaskState, duration: double = -1) -> @cfunc @exceptval(check=False) -def _task_to_report_msg(state: SchedulerState, ts: TaskState) -> dict: +def _task_to_report_msg(state: SchedulerState, ts: TaskState) -> dict: # -> dict | None if ts._state == "forgotten": return {"op": "cancelled-key", "key": ts._key} elif ts._state == "memory": @@ -7904,7 +7924,7 @@ def _task_to_report_msg(state: SchedulerState, ts: TaskState) -> dict: "traceback": failing_ts._traceback, } else: - return None + return None # type: ignore @cfunc @@ -7949,7 +7969,7 @@ def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState): @exceptval(check=False) def decide_worker( ts: TaskState, all_workers, valid_workers: set, objective -) -> WorkerState: +) -> WorkerState: # -> WorkerState | None """ Decide which worker should take task *ts*. @@ -7965,7 +7985,7 @@ def decide_worker( of bytes sent between workers. This is determined by calling the *objective* function. 
""" - ws: WorkerState = None + ws: WorkerState = None # type: ignore wws: WorkerState dts: TaskState deps: set = ts._dependencies diff --git a/distributed/worker.py b/distributed/worker.py index 031c58a48f5..ee17e115d4e 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2489,7 +2489,7 @@ async def gather_dep( total_nbytes : int Total number of bytes for all the dependencies in to_gather combined """ - cause = None + cause: TaskState | None = None if self.status not in (Status.running, Status.paused): return @@ -2517,6 +2517,8 @@ async def gather_dep( if not to_gather_keys: return + assert cause + # Keep namespace clean since this func is long and has many # dep*, *ts* variables del to_gather, dependency_key, dependency_ts diff --git a/docs/source/develop.rst b/docs/source/develop.rst index 6ea8c750329..b4956d62195 100644 --- a/docs/source/develop.rst +++ b/docs/source/develop.rst @@ -169,9 +169,19 @@ fixture tests test basic interface and resilience. You should avoid ``popen`` style tests unless absolutely necessary, such as if you need to test the command line interface. -Linting -------- -distributed uses several code linters (flake8, black, isort, pyupgrade, mypy), which are -enforced by CI. Developers should run them locally before they submit a PR, through the -single command ``pre-commit run --all-files``. This makes sure that linter versions and -options are aligned for all developers. +Code Formatting +--------------- + +Dask.distributed uses several code linters (flake8, black, isort, pyupgrade, mypy), +which are enforced by CI. Developers should run them locally before they submit a PR, +through the single command ``pre-commit run --all-files``. This makes sure that linter +versions and options are aligned for all developers. + +Optionally, you may wish to setup the `pre-commit hooks `_ to +run automatically when you make a git commit. This can be done by running:: + + pre-commit install + +from the root of the distributed repository. Now the code linters will be run each time +you commit changes. You can skip these checks with ``git commit --no-verify`` or with +the short version ``git commit -n``.