From e0bba1f9bdde57dbb56ebb81346146376b076b03 Mon Sep 17 00:00:00 2001
From: Fabio Graetz
Date: Mon, 19 Jun 2023 09:00:26 +0000
Subject: [PATCH] Warn when training locally with nnodes > 1

Signed-off-by: Fabio Graetz
---
 .../flytekitplugins/kfpytorch/task.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
index 3e5178d532..86f70dad4b 100644
--- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
+++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
@@ -17,6 +17,7 @@
 from flytekit.configuration import SerializationSettings
 from flytekit.core.resources import convert_resources_to_resource_model
 from flytekit.extend import IgnoreOutputs, TaskPlugins
+from flytekit.loggers import logger
 
 TORCH_IMPORT_ERROR_MESSAGE = "PyTorch is not installed. Please install `flytekitplugins-kfpytorch['elastic']`."
 
@@ -274,6 +275,15 @@ def _execute(self, **kwargs) -> Any:
         else:
             nproc = self.task_config.nproc_per_node
 
+        dist_env_vars_set = os.environ.get("PET_NNODES") is not None
+        if not dist_env_vars_set and self.min_nodes > 1:
+            logger.warning(
+                (
+                    f"`nnodes` is set to {self.task_config.nnodes} in elastic task but execution appears "
+                    "to not run in a `PyTorchJob`. Rendezvous might timeout."
+                )
+            )
+
         config = LaunchConfig(
             run_id=flytekit.current_context().execution_id.name,
             min_nodes=self.min_nodes,