From a4e5e9d375c38007c0317c106f478c3f15ea011e Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 11 May 2026 09:43:39 +0200 Subject: [PATCH] Add project name interpolation in gateway domains Allow to use the `${{ run.project_name }}` variable in gateway domain names. The variable is interpolated at service submission time, which allows different projects that import and use the same shared gateway to use different domain names. ```yaml type: gateway name: global-gateway backend: aws region: eu-west-1 domain: ${{ run.project_name }}.mycompany.example ``` --- src/dstack/_internal/cli/utils/gateway.py | 18 ++- .../server/services/gateways/__init__.py | 14 +- .../server/services/services/__init__.py | 6 + src/dstack/_internal/utils/common.py | 13 ++ .../_internal/server/routers/test_gateways.py | 52 ++++++++ .../_internal/server/routers/test_runs.py | 122 ++++++++++++++++++ 6 files changed, 221 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/cli/utils/gateway.py b/src/dstack/_internal/cli/utils/gateway.py index 0458798014..897e1e7121 100644 --- a/src/dstack/_internal/cli/utils/gateway.py +++ b/src/dstack/_internal/cli/utils/gateway.py @@ -7,7 +7,7 @@ from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.gateways import Gateway -from dstack._internal.utils.common import DateFormatter, pretty_date +from dstack._internal.utils.common import DateFormatter, interpolate_gateway_domain, pretty_date from dstack.api.server._gateways import GatewaysAPIClient @@ -62,7 +62,7 @@ def get_gateways_table( table.add_column("NAME", no_wrap=True) table.add_column("BACKEND") table.add_column("HOSTNAME", no_wrap=True) - table.add_column("DOMAIN") + table.add_column("DOMAIN", no_wrap=True) table.add_column("DEFAULT") table.add_column("STATUS") if verbose or include_created: @@ -78,11 +78,23 @@ def get_gateways_table( gateway.project_name if gateway.project_name is not None else current_project, current_project, ) + domain = gateway.wildcard_domain + if ( + gateway.project_name is not None + and gateway.project_name != current_project + and domain is not None + ): + domain = interpolate_gateway_domain( + domain=domain, + run_project_name=current_project, + # Ignore errors in case future server versions introduce more interpolation variables + exception_type=None, + ) row = { "NAME": name, "BACKEND": f"{gateway.configuration.backend.value} ({gateway.configuration.region})", "HOSTNAME": gateway.hostname, - "DOMAIN": gateway.wildcard_domain, + "DOMAIN": domain, "DEFAULT": "✓" if gateway.default else "", "STATUS": gateway.status, "CREATED": format_date(gateway.created_at), diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index e007b65b49..287117e2ed 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -67,7 +67,11 @@ from dstack._internal.server.services.pipelines import PipelineHinterProtocol from dstack._internal.server.services.plugins import apply_plugin_policies from dstack._internal.server.utils.common import gather_map_async -from dstack._internal.utils.common import get_current_datetime, run_async +from dstack._internal.utils.common import ( + get_current_datetime, + interpolate_gateway_domain, + run_async, +) from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes from dstack._internal.utils.logging import get_logger @@ -816,6 +820,14 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration): if configuration.name is not None: validate_dstack_resource_name(configuration.name) + if configuration.domain is not None: + # validate that domain can be interpolated + interpolate_gateway_domain( + domain=configuration.domain, + run_project_name="example", + exception_type=ServerClientError, + ) + if ( not configuration.public_ip and configuration.backend not in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 955c7a4865..273054e74f 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -39,6 +39,7 @@ ) from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.services.options import get_service_options +from dstack._internal.utils.common import interpolate_gateway_domain from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -158,6 +159,11 @@ async def _register_service_in_gateway( wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None if wildcard_domain is None: raise ServerClientError("Domain is required for gateway") + wildcard_domain = interpolate_gateway_domain( + domain=wildcard_domain, + run_project_name=run_model.project.name, + exception_type=GatewayError, + ) service_url = f"{service_protocol}://{run_model.run_name}.{wildcard_domain}" if isinstance(run_spec.configuration.model, OpenAIChatModel): model_url = service_url + run_spec.configuration.model.prefix diff --git a/src/dstack/_internal/utils/common.py b/src/dstack/_internal/utils/common.py index f0c30b2a2b..2ee1f337ad 100644 --- a/src/dstack/_internal/utils/common.py +++ b/src/dstack/_internal/utils/common.py @@ -15,6 +15,7 @@ from typing_extensions import ParamSpec from dstack._internal.core.models.common import Duration +from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator class Unset: @@ -336,5 +337,17 @@ def make_proxy_url(server_url: str, proxy_url: str) -> str: return proxy.geturl() +def interpolate_gateway_domain( + domain: str, run_project_name: str, exception_type: Optional[type[Exception]] +) -> str: + interpolator = VariablesInterpolator({"run": {"project_name": run_project_name}}) + try: + return interpolator.interpolate_or_error(domain) + except InterpolatorError as e: + if exception_type is None: + return domain + raise exception_type(f"Cannot interpolate gateway domain name: {e.args[0]}") from e + + def list_enum_values_for_annotation(enum_class: type[enum.Enum]) -> str: return ", ".join(f"`{e.value}`" for e in enum_class) diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index a41fc6de01..5cc6bdd715 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -531,6 +531,58 @@ async def test_create_gateway_missing_backend( ) assert response.status_code == 400 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_create_gateway_with_valid_domain_interpolation( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_backend(session, project.id, backend_type=BackendType.AWS) + response = await client.post( + f"/api/project/{project.name}/gateways/create", + json={ + "configuration": { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "domain": "${{ run.project_name }}.example.com", + }, + }, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_create_gateway_with_invalid_domain_interpolation( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_backend(session, project.id, backend_type=BackendType.AWS) + response = await client.post( + f"/api/project/{project.name}/gateways/create", + json={ + "configuration": { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "domain": "${{ run.unknown_variable }}.example.com", + }, + }, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 400 + class TestDefaultGateway: @pytest.mark.asyncio diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 9e74439b6c..e13e20853e 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -3616,6 +3616,128 @@ async def test_not_submits_to_default_gateway_if_not_imported( ] } + @pytest.mark.asyncio + async def test_interpolates_project_name_in_imported_gateway_domain( + self, test_db, session: AsyncSession, client: AsyncClient + ) -> None: + exporter_user = await create_user( + session=session, global_role=GlobalRole.USER, name="exporter_user" + ) + exporter_project = await create_project( + session=session, owner=exporter_user, name="exporter-project" + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="exported-gateway", + wildcard_domain="${{ run.project_name }}.example.com", + ) + + importer_user = await create_user( + session=session, global_role=GlobalRole.USER, name="importer_user" + ) + importer_project = await create_project( + session=session, owner=importer_user, name="importer-project" + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + importer_repo = await create_repo(session=session, project_id=importer_project.id) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + + run_spec = get_service_run_spec( + repo_id=importer_repo.name, + run_name="test-service", + gateway="exporter-project/exported-gateway", + ) + response = await client.post( + f"/api/project/{importer_project.name}/runs/submit", + headers=get_auth_headers(importer_user.token), + json={"run_spec": run_spec}, + ) + assert response.status_code == 200 + assert ( + response.json()["service"]["url"] + == "https://test-service.importer-project.example.com" + ) + + @pytest.mark.asyncio + async def test_returns_error_if_imported_gateway_domain_has_unknown_variable( + self, test_db, session: AsyncSession, client: AsyncClient + ) -> None: + exporter_user = await create_user( + session=session, global_role=GlobalRole.USER, name="exporter_user" + ) + exporter_project = await create_project( + session=session, owner=exporter_user, name="exporter-project" + ) + backend = await create_backend(session=session, project_id=exporter_project.id) + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, + project_id=exporter_project.id, + backend_id=backend.id, + gateway_compute_id=gateway_compute.id, + status=GatewayStatus.RUNNING, + name="exported-gateway", + wildcard_domain="${{ run.unknown_variable }}.example.com", + ) + + importer_user = await create_user( + session=session, global_role=GlobalRole.USER, name="importer_user" + ) + importer_project = await create_project( + session=session, owner=importer_user, name="importer-project" + ) + await add_project_member( + session=session, + project=importer_project, + user=importer_user, + project_role=ProjectRole.USER, + ) + importer_repo = await create_repo(session=session, project_id=importer_project.id) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[], + exported_gateways=[gateway], + ) + + run_spec = get_service_run_spec( + repo_id=importer_repo.name, + run_name="test-service", + gateway="exporter-project/exported-gateway", + ) + response = await client.post( + f"/api/project/{importer_project.name}/runs/submit", + headers=get_auth_headers(importer_user.token), + json={"run_spec": run_spec}, + ) + assert response.status_code == 400 + assert response.json() == { + "detail": [ + { + "msg": "Cannot interpolate gateway domain name: Failed to interpolate due to missing vars: ['run.unknown_variable']", + "code": "gateway_error", + } + ] + } + @pytest.mark.asyncio async def test_unregister_dangling_service( self,