diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/model_service.py b/airflow/providers/google/cloud/hooks/vertex_ai/model_service.py index c6e0c69678c84..02fc7378380f2 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/model_service.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/model_service.py @@ -208,6 +208,7 @@ def upload_model( project_id: str, region: str, model: Model | dict, + parent_model: str | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -218,6 +219,7 @@ def upload_model( :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. :param model: Required. The Model to create. + :param parent_model: The name of the parent model to create a new version under. :param retry: Designation of what errors, if any, should be retried. :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. @@ -225,11 +227,16 @@ def upload_model( client = self.get_model_service_client(region) parent = client.common_location_path(project_id, region) + request = { + "parent": parent, + "model": model, + } + + if parent_model: + request["parent_model"] = parent_model + result = client.upload_model( - request={ - "parent": parent, - "model": model, - }, + request=request, retry=retry, timeout=timeout, metadata=metadata, diff --git a/airflow/providers/google/cloud/operators/vertex_ai/model_service.py b/airflow/providers/google/cloud/operators/vertex_ai/model_service.py index c44da5ffd58c5..04d45f253f568 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/model_service.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/model_service.py @@ -362,6 +362,7 @@ class UploadModelOperator(GoogleCloudBaseOperator): :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. :param model: Required. The Model to create. + :param parent_model: The name of the parent model to create a new version under. :param retry: Designation of what errors, if any, should be retried. :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. @@ -385,6 +386,7 @@ def __init__( project_id: str, region: str, model: Model | dict, + parent_model: str | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, metadata: Sequence[tuple[str, str]] = (), @@ -396,6 +398,7 @@ def __init__( self.project_id = project_id self.region = region self.model = model + self.parent_model = parent_model self.retry = retry self.timeout = timeout self.metadata = metadata @@ -412,6 +415,7 @@ def execute(self, context: Context): project_id=self.project_id, region=self.region, model=self.model, + parent_model=self.parent_model, retry=self.retry, timeout=self.timeout, metadata=self.metadata, diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_model_service.py b/tests/providers/google/cloud/hooks/vertex_ai/test_model_service.py index e4431b0cd2f51..1505e8276d077 100644 --- a/tests/providers/google/cloud/hooks/vertex_ai/test_model_service.py +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_model_service.py @@ -36,6 +36,7 @@ TEST_REGION: str = "test-region" TEST_PROJECT_ID: str = "test-project-id" TEST_MODEL = None +TEST_PARENT_MODEL = "test-parent-model" TEST_MODEL_NAME: str = "test_model_name" TEST_OUTPUT_CONFIG: dict = {} @@ -136,6 +137,24 @@ def test_upload_model(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + @mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client")) + def test_upload_model_with_parent_model(self, mock_client) -> None: + self.hook.upload_model( + project_id=TEST_PROJECT_ID, region=TEST_REGION, model=TEST_MODEL, parent_model=TEST_PARENT_MODEL + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.upload_model.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + model=TEST_MODEL, + parent_model=TEST_PARENT_MODEL, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + @mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client")) def test_list_model_versions(self, mock_client) -> None: self.hook.list_model_versions( @@ -322,6 +341,24 @@ def test_upload_model(self, mock_client) -> None: ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + @mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client")) + def test_upload_model_with_parent_model(self, mock_client) -> None: + self.hook.upload_model( + project_id=TEST_PROJECT_ID, region=TEST_REGION, model=TEST_MODEL, parent_model=TEST_PARENT_MODEL + ) + mock_client.assert_called_once_with(TEST_REGION) + mock_client.return_value.upload_model.assert_called_once_with( + request=dict( + parent=mock_client.return_value.common_location_path.return_value, + model=TEST_MODEL, + parent_model=TEST_PARENT_MODEL, + ), + metadata=(), + retry=DEFAULT, + timeout=None, + ) + mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + @mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client")) def test_list_model_versions(self, mock_client) -> None: self.hook.list_model_versions( diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 4957d71ed5739..ef0ca5360873f 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -2849,6 +2849,34 @@ def test_execute(self, mock_hook, to_dict_mock): region=GCP_LOCATION, project_id=GCP_PROJECT, model=TEST_MODEL_OBJ, + parent_model=None, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch(VERTEX_AI_PATH.format("model_service.model_service.UploadModelResponse.to_dict")) + @mock.patch(VERTEX_AI_PATH.format("model_service.ModelServiceHook")) + def test_execute_with_parent_model(self, mock_hook, to_dict_mock): + op = UploadModelOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + model=TEST_MODEL_OBJ, + parent_model=TEST_PARENT_MODEL, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + op.execute(context={"ti": mock.MagicMock()}) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.upload_model.assert_called_once_with( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + model=TEST_MODEL_OBJ, + parent_model=TEST_PARENT_MODEL, retry=RETRY, timeout=TIMEOUT, metadata=METADATA, diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py index 4560d9f54f9f4..b80eaabdd58db 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_model_service.py @@ -231,6 +231,13 @@ model=MODEL_OBJ, ) # [END how_to_cloud_vertex_ai_upload_model_operator] + upload_model_with_parent_model = UploadModelOperator( + task_id="upload_model_with_parent_model", + region=REGION, + project_id=PROJECT_ID, + model=MODEL_OBJ, + parent_model=MODEL_DISPLAY_NAME, + ) # [START how_to_cloud_vertex_ai_export_model_operator] export_model = ExportModelOperator( @@ -251,6 +258,13 @@ trigger_rule=TriggerRule.ALL_DONE, ) # [END how_to_cloud_vertex_ai_delete_model_operator] + delete_model_with_parent_model = DeleteModelOperator( + task_id="delete_model_with_parent_model", + project_id=PROJECT_ID, + region=REGION, + model_id=upload_model_with_parent_model.output["model_id"], + trigger_rule=TriggerRule.ALL_DONE, + ) # [START how_to_cloud_vertex_ai_list_models_operator] list_models = ListModelsOperator( @@ -317,8 +331,10 @@ >> set_default_version >> add_version_alias >> upload_model + >> upload_model_with_parent_model >> export_model >> delete_model + >> delete_model_with_parent_model >> list_models # TEST TEARDOWN >> delete_version_alias