diff --git a/flytekit/core/node.py b/flytekit/core/node.py index 4f7838d2b6..bf5c97ba60 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -128,6 +128,8 @@ def with_overrides(self, *args, **kwargs): if not isinstance(new_task_config, type(self.flyte_entity._task_config)): raise ValueError("can't change the type of the task config") self.flyte_entity._task_config = new_task_config + if "container_image" in kwargs: + self.flyte_entity._container_image = kwargs["container_image"] return self diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index da708a8571..81621ef3fc 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -452,3 +452,16 @@ def my_wf(a: str) -> str: return t1(a=a).with_overrides(task_config=None) my_wf() + + +def test_override_image(): + @task + def bar(): + print("hello") + + @workflow + def wf() -> str: + bar().with_overrides(container_image="hello/world") + return "hi" + + assert wf.nodes[0].flyte_entity.container_image == "hello/world"