diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 2ba236b065b29..b862ed1469840 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -104,6 +104,7 @@ def get_task_instance( return task_instance_schema.dump(task_instance) +@mark_fastapi_migration_done @security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instance( diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 4295211546f7e..b7e7f62693719 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -1491,6 +1491,69 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}: + get: + tags: + - Task Instance + summary: Get Mapped Task Instance + description: Get task instance. + operationId: get_mapped_task_instance + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: dag_run_id + in: path + required: true + schema: + type: string + title: Dag Run Id + - name: task_id + in: path + required: true + schema: + type: string + title: Task Id + - name: map_index + in: path + required: true + schema: + type: integer + title: Map Index + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/TaskInstanceResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/variables/{variable_key}: delete: tags: diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index 7e0c6fa8941a3..c9458e843afee 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -44,7 +44,6 @@ async def get_task_instance( .join(TI.dag_run) .options(joinedload(TI.rendered_task_instance_fields)) ) - task_instance = session.scalar(query) if task_instance is None: @@ -56,3 +55,31 @@ async def get_task_instance( raise HTTPException(404, "Task instance is mapped, add the map_index value to the URL") return TaskInstanceResponse.model_validate(task_instance, from_attributes=True) + + +@task_instances_router.get( + "/{task_id}/{map_index}", responses=create_openapi_http_exception_doc([401, 403, 404]) +) +async def get_mapped_task_instance( + dag_id: str, + dag_run_id: str, + task_id: str, + map_index: int, + session: Annotated[Session, Depends(get_session)], +) -> TaskInstanceResponse: + """Get task instance.""" + query = ( + select(TI) + .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index == map_index) + .join(TI.dag_run) + .options(joinedload(TI.rendered_task_instance_fields)) + ) + task_instance = session.scalar(query) + + if task_instance is None: + raise HTTPException( + 404, + f"The Mapped Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}`, and map_index: `{map_index}` was not found", + ) + + return TaskInstanceResponse.model_validate(task_instance, from_attributes=True) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 875e2b87f39c0..07edb67a99304 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -423,6 +423,32 @@ export const UseTaskInstanceServiceGetTaskInstanceKeyFn = ( useTaskInstanceServiceGetTaskInstanceKey, ...(queryKey ?? [{ dagId, dagRunId, taskId }]), ]; +export type TaskInstanceServiceGetMappedTaskInstanceDefaultResponse = Awaited< + ReturnType +>; +export type TaskInstanceServiceGetMappedTaskInstanceQueryResult< + TData = TaskInstanceServiceGetMappedTaskInstanceDefaultResponse, + TError = unknown, +> = UseQueryResult; +export const useTaskInstanceServiceGetMappedTaskInstanceKey = + "TaskInstanceServiceGetMappedTaskInstance"; +export const UseTaskInstanceServiceGetMappedTaskInstanceKeyFn = ( + { + dagId, + dagRunId, + mapIndex, + taskId, + }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; + }, + queryKey?: Array, +) => [ + useTaskInstanceServiceGetMappedTaskInstanceKey, + ...(queryKey ?? [{ dagId, dagRunId, mapIndex, taskId }]), +]; export type VariableServiceGetVariableDefaultResponse = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index 795c8770b85ce..db61369e19ff7 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -530,6 +530,46 @@ export const prefetchUseTaskInstanceServiceGetTaskInstance = ( queryFn: () => TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }), }); +/** + * Get Mapped Task Instance + * Get task instance. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @returns TaskInstanceResponse Successful Response + * @throws ApiError + */ +export const prefetchUseTaskInstanceServiceGetMappedTaskInstance = ( + queryClient: QueryClient, + { + dagId, + dagRunId, + mapIndex, + taskId, + }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; + }, +) => + queryClient.prefetchQuery({ + queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstanceKeyFn({ + dagId, + dagRunId, + mapIndex, + taskId, + }), + queryFn: () => + TaskInstanceService.getMappedTaskInstance({ + dagId, + dagRunId, + mapIndex, + taskId, + }), + }); /** * Get Variable * Get a variable entry. diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index afb3fddaefe57..7820656799e5b 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -680,6 +680,50 @@ export const useTaskInstanceServiceGetTaskInstance = < TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }) as TData, ...options, }); +/** + * Get Mapped Task Instance + * Get task instance. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @returns TaskInstanceResponse Successful Response + * @throws ApiError + */ +export const useTaskInstanceServiceGetMappedTaskInstance = < + TData = Common.TaskInstanceServiceGetMappedTaskInstanceDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + dagId, + dagRunId, + mapIndex, + taskId, + }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; + }, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useQuery({ + queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstanceKeyFn( + { dagId, dagRunId, mapIndex, taskId }, + queryKey, + ), + queryFn: () => + TaskInstanceService.getMappedTaskInstance({ + dagId, + dagRunId, + mapIndex, + taskId, + }) as TData, + ...options, + }); /** * Get Variable * Get a variable entry. diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index ab8dfbabcc0e6..2cb0841d71f28 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -668,6 +668,50 @@ export const useTaskInstanceServiceGetTaskInstanceSuspense = < TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }) as TData, ...options, }); +/** + * Get Mapped Task Instance + * Get task instance. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @returns TaskInstanceResponse Successful Response + * @throws ApiError + */ +export const useTaskInstanceServiceGetMappedTaskInstanceSuspense = < + TData = Common.TaskInstanceServiceGetMappedTaskInstanceDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + dagId, + dagRunId, + mapIndex, + taskId, + }: { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; + }, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useSuspenseQuery({ + queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstanceKeyFn( + { dagId, dagRunId, mapIndex, taskId }, + queryKey, + ), + queryFn: () => + TaskInstanceService.getMappedTaskInstance({ + dagId, + dagRunId, + mapIndex, + taskId, + }) as TData, + ...options, + }); /** * Get Variable * Get a variable entry. diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index fd38b2ec31cb9..486e04b056f8a 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -56,6 +56,8 @@ import type { GetProvidersResponse, GetTaskInstanceData, GetTaskInstanceResponse, + GetMappedTaskInstanceData, + GetMappedTaskInstanceResponse, DeleteVariableData, DeleteVariableResponse, GetVariableData, @@ -874,6 +876,38 @@ export class TaskInstanceService { }, }); } + + /** + * Get Mapped Task Instance + * Get task instance. + * @param data The data for the request. + * @param data.dagId + * @param data.dagRunId + * @param data.taskId + * @param data.mapIndex + * @returns TaskInstanceResponse Successful Response + * @throws ApiError + */ + public static getMappedTaskInstance( + data: GetMappedTaskInstanceData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "GET", + url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}", + path: { + dag_id: data.dagId, + dag_run_id: data.dagRunId, + task_id: data.taskId, + map_index: data.mapIndex, + }, + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } } export class VariableService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 064163f4178fe..0580694ba78f0 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -827,6 +827,15 @@ export type GetTaskInstanceData = { export type GetTaskInstanceResponse = TaskInstanceResponse; +export type GetMappedTaskInstanceData = { + dagId: string; + dagRunId: string; + mapIndex: number; + taskId: string; +}; + +export type GetMappedTaskInstanceResponse = TaskInstanceResponse; + export type DeleteVariableData = { variableKey: string; }; @@ -1528,6 +1537,33 @@ export type $OpenApiTs = { }; }; }; + "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}": { + get: { + req: GetMappedTaskInstanceData; + res: { + /** + * Successful Response + */ + 200: TaskInstanceResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/public/variables/{variable_key}": { delete: { req: DeleteVariableData; diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 85b4639d6c0bd..fa9cc0b161d0a 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -394,3 +394,69 @@ def test_raises_404_for_mapped_task_instance_with_one_index(self, test_client, s ) assert response.status_code == 404 assert response.json() == {"detail": "Task instance is mapped, add the map_index value to the URL"} + + +class TestGetMappedTaskInstance(TestTaskInstanceEndpoint): + def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, session): + """Verify we don't duplicate rows through join to RTIF""" + tis = self.create_task_instances(session) + old_ti = tis[0] + for idx in (1, 2): + ti = TaskInstance(task=old_ti.task, run_id=old_ti.run_id, map_index=idx) + ti.rendered_task_instance_fields = RTIF(ti, render_templates=False) + for attr in ["duration", "end_date", "pid", "start_date", "state", "queue", "note"]: + setattr(ti, attr, getattr(old_ti, attr)) + session.add(ti) + session.commit() + + # in each loop, we should get the right mapped TI back + for map_index in (1, 2): + response = test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances" + f"/print_the_context/{map_index}", + ) + assert response.status_code == 200 + + assert response.json() == { + "dag_id": "example_python_operator", + "duration": 10000.0, + "end_date": "2020-01-03T00:00:00Z", + "logical_date": "2020-01-01T00:00:00Z", + "executor": None, + "executor_config": "{}", + "hostname": "", + "map_index": map_index, + "max_tries": 0, + "note": "placeholder-note", + "operator": "PythonOperator", + "pid": 100, + "pool": "default_pool", + "pool_slots": 1, + "priority_weight": 9, + "queue": "default_queue", + "queued_when": None, + "start_date": "2020-01-02T00:00:00Z", + "state": "running", + "task_id": "print_the_context", + "task_display_name": "print_the_context", + "try_number": 0, + "unixname": getuser(), + "dag_run_id": "TEST_DAG_RUN_ID", + "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_map_index": None, + "trigger": None, + "triggerer_job": None, + } + + def test_should_respond_404_wrong_map_index(self, test_client, session): + self.create_task_instances(session) + + response = test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances" + "/print_the_context/10", + ) + assert response.status_code == 404 + + assert response.json() == { + "detail": "The Mapped Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: `print_the_context`, and map_index: `10` was not found" + }