Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion flytekit/common/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,16 @@ def get_serializable_node(
elif isinstance(entity.flyte_entity, LaunchPlan):
lp_spec = get_serializable(entity_mapping, settings, entity.flyte_entity)

# Node's inputs should not contain the data which is fixed input
node_input = []
for b in entity.bindings:
if b.var not in entity.flyte_entity.fixed_inputs.literals:
node_input.append(b)

node_model = workflow_model.Node(
id=_dnsify(entity.id),
metadata=entity.metadata,
inputs=entity.bindings,
inputs=node_input,
upstream_node_ids=[n.id for n in upstream_sdk_nodes],
output_aliases=[],
workflow_node=workflow_model.WorkflowNode(launchplan_ref=lp_spec.id),
Expand Down
35 changes: 35 additions & 0 deletions tests/flytekit/unit/common_tests/test_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,38 @@ def t1(a: int) -> (int, str):
)
task_spec = get_serializable(OrderedDict(), ssettings, t2)
assert "pyflyte" not in task_spec.template.container.args


def test_launch_plan_with_fixed_input():
@task
def greet(day_of_week: str, number: int, am: bool) -> str:
greeting = "Have a great " + day_of_week + " "
greeting += "morning" if am else "evening"
return greeting + "!" * number

@workflow
def go_greet(day_of_week: str, number: int, am: bool = False) -> str:
return greet(day_of_week=day_of_week, number=number, am=am)

morning_greeting = LaunchPlan.create(
"morning_greeting",
go_greet,
fixed_inputs={"am": True},
default_inputs={"number": 1},
)

@workflow
def morning_greeter_caller(day_of_week: str) -> str:
greeting = morning_greeting(day_of_week=day_of_week)
return greeting

settings = (
serialization_settings.new_builder()
.with_fast_serialization_settings(FastSerializationSettings(enabled=True))
.build()
)
task_spec = get_serializable(OrderedDict(), settings, morning_greeter_caller)
assert len(task_spec.template.interface.inputs) == 1
assert len(task_spec.template.interface.outputs) == 1
assert len(task_spec.template.nodes) == 1
assert len(task_spec.template.nodes[0].inputs) == 2