Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/dstack/_internal/cli/utils/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.gateways import Gateway
from dstack._internal.utils.common import DateFormatter, pretty_date
from dstack._internal.utils.common import DateFormatter, interpolate_gateway_domain, pretty_date
from dstack.api.server._gateways import GatewaysAPIClient


Expand Down Expand Up @@ -62,7 +62,7 @@ def get_gateways_table(
table.add_column("NAME", no_wrap=True)
table.add_column("BACKEND")
table.add_column("HOSTNAME", no_wrap=True)
table.add_column("DOMAIN")
table.add_column("DOMAIN", no_wrap=True)
table.add_column("DEFAULT")
table.add_column("STATUS")
if verbose or include_created:
Expand All @@ -78,11 +78,23 @@ def get_gateways_table(
gateway.project_name if gateway.project_name is not None else current_project,
current_project,
)
domain = gateway.wildcard_domain
if (
gateway.project_name is not None
and gateway.project_name != current_project
and domain is not None
):
domain = interpolate_gateway_domain(
domain=domain,
run_project_name=current_project,
# Ignore errors in case future server versions introduce more interpolation variables
exception_type=None,
)
row = {
"NAME": name,
"BACKEND": f"{gateway.configuration.backend.value} ({gateway.configuration.region})",
"HOSTNAME": gateway.hostname,
"DOMAIN": gateway.wildcard_domain,
"DOMAIN": domain,
"DEFAULT": "✓" if gateway.default else "",
"STATUS": gateway.status,
"CREATED": format_date(gateway.created_at),
Expand Down
14 changes: 13 additions & 1 deletion src/dstack/_internal/server/services/gateways/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
from dstack._internal.server.services.plugins import apply_plugin_policies
from dstack._internal.server.utils.common import gather_map_async
from dstack._internal.utils.common import get_current_datetime, run_async
from dstack._internal.utils.common import (
get_current_datetime,
interpolate_gateway_domain,
run_async,
)
from dstack._internal.utils.crypto import generate_rsa_key_pair_bytes
from dstack._internal.utils.logging import get_logger

Expand Down Expand Up @@ -816,6 +820,14 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration):
if configuration.name is not None:
validate_dstack_resource_name(configuration.name)

if configuration.domain is not None:
# validate that domain can be interpolated
interpolate_gateway_domain(
domain=configuration.domain,
run_project_name="example",
exception_type=ServerClientError,
)

if (
not configuration.public_ip
and configuration.backend not in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT
Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/server/services/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from dstack._internal.server.services.logging import fmt
from dstack._internal.server.services.services.options import get_service_options
from dstack._internal.utils.common import interpolate_gateway_domain
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -158,6 +159,11 @@ async def _register_service_in_gateway(
wildcard_domain = gateway.wildcard_domain.lstrip("*.") if gateway.wildcard_domain else None
if wildcard_domain is None:
raise ServerClientError("Domain is required for gateway")
wildcard_domain = interpolate_gateway_domain(
domain=wildcard_domain,
run_project_name=run_model.project.name,
exception_type=GatewayError,
)
service_url = f"{service_protocol}://{run_model.run_name}.{wildcard_domain}"
if isinstance(run_spec.configuration.model, OpenAIChatModel):
model_url = service_url + run_spec.configuration.model.prefix
Expand Down
13 changes: 13 additions & 0 deletions src/dstack/_internal/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing_extensions import ParamSpec

from dstack._internal.core.models.common import Duration
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator


class Unset:
Expand Down Expand Up @@ -336,5 +337,17 @@ def make_proxy_url(server_url: str, proxy_url: str) -> str:
return proxy.geturl()


def interpolate_gateway_domain(
domain: str, run_project_name: str, exception_type: Optional[type[Exception]]
) -> str:
interpolator = VariablesInterpolator({"run": {"project_name": run_project_name}})
try:
return interpolator.interpolate_or_error(domain)
except InterpolatorError as e:
if exception_type is None:
return domain
raise exception_type(f"Cannot interpolate gateway domain name: {e.args[0]}") from e


def list_enum_values_for_annotation(enum_class: type[enum.Enum]) -> str:
return ", ".join(f"`{e.value}`" for e in enum_class)
52 changes: 52 additions & 0 deletions src/tests/_internal/server/routers/test_gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,58 @@ async def test_create_gateway_missing_backend(
)
assert response.status_code == 400

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_create_gateway_with_valid_domain_interpolation(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.ADMIN
)
await create_backend(session, project.id, backend_type=BackendType.AWS)
response = await client.post(
f"/api/project/{project.name}/gateways/create",
json={
"configuration": {
"type": "gateway",
"name": "test",
"backend": "aws",
"region": "us",
"domain": "${{ run.project_name }}.example.com",
},
},
headers=get_auth_headers(user.token),
)
assert response.status_code == 200

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_create_gateway_with_invalid_domain_interpolation(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.ADMIN
)
await create_backend(session, project.id, backend_type=BackendType.AWS)
response = await client.post(
f"/api/project/{project.name}/gateways/create",
json={
"configuration": {
"type": "gateway",
"name": "test",
"backend": "aws",
"region": "us",
"domain": "${{ run.unknown_variable }}.example.com",
},
},
headers=get_auth_headers(user.token),
)
assert response.status_code == 400


class TestDefaultGateway:
@pytest.mark.asyncio
Expand Down
122 changes: 122 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3616,6 +3616,128 @@ async def test_not_submits_to_default_gateway_if_not_imported(
]
}

@pytest.mark.asyncio
async def test_interpolates_project_name_in_imported_gateway_domain(
self, test_db, session: AsyncSession, client: AsyncClient
) -> None:
exporter_user = await create_user(
session=session, global_role=GlobalRole.USER, name="exporter_user"
)
exporter_project = await create_project(
session=session, owner=exporter_user, name="exporter-project"
)
backend = await create_backend(session=session, project_id=exporter_project.id)
gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name="exported-gateway",
wildcard_domain="${{ run.project_name }}.example.com",
)

importer_user = await create_user(
session=session, global_role=GlobalRole.USER, name="importer_user"
)
importer_project = await create_project(
session=session, owner=importer_user, name="importer-project"
)
await add_project_member(
session=session,
project=importer_project,
user=importer_user,
project_role=ProjectRole.USER,
)
importer_repo = await create_repo(session=session, project_id=importer_project.id)
await create_export(
session=session,
exporter_project=exporter_project,
importer_projects=[importer_project],
exported_fleets=[],
exported_gateways=[gateway],
)

run_spec = get_service_run_spec(
repo_id=importer_repo.name,
run_name="test-service",
gateway="exporter-project/exported-gateway",
)
response = await client.post(
f"/api/project/{importer_project.name}/runs/submit",
headers=get_auth_headers(importer_user.token),
json={"run_spec": run_spec},
)
assert response.status_code == 200
assert (
response.json()["service"]["url"]
== "https://test-service.importer-project.example.com"
)

@pytest.mark.asyncio
async def test_returns_error_if_imported_gateway_domain_has_unknown_variable(
self, test_db, session: AsyncSession, client: AsyncClient
) -> None:
exporter_user = await create_user(
session=session, global_role=GlobalRole.USER, name="exporter_user"
)
exporter_project = await create_project(
session=session, owner=exporter_user, name="exporter-project"
)
backend = await create_backend(session=session, project_id=exporter_project.id)
gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name="exported-gateway",
wildcard_domain="${{ run.unknown_variable }}.example.com",
)

importer_user = await create_user(
session=session, global_role=GlobalRole.USER, name="importer_user"
)
importer_project = await create_project(
session=session, owner=importer_user, name="importer-project"
)
await add_project_member(
session=session,
project=importer_project,
user=importer_user,
project_role=ProjectRole.USER,
)
importer_repo = await create_repo(session=session, project_id=importer_project.id)
await create_export(
session=session,
exporter_project=exporter_project,
importer_projects=[importer_project],
exported_fleets=[],
exported_gateways=[gateway],
)

run_spec = get_service_run_spec(
repo_id=importer_repo.name,
run_name="test-service",
gateway="exporter-project/exported-gateway",
)
response = await client.post(
f"/api/project/{importer_project.name}/runs/submit",
headers=get_auth_headers(importer_user.token),
json={"run_spec": run_spec},
)
assert response.status_code == 400
assert response.json() == {
"detail": [
{
"msg": "Cannot interpolate gateway domain name: Failed to interpolate due to missing vars: ['run.unknown_variable']",
"code": "gateway_error",
}
]
}

@pytest.mark.asyncio
async def test_unregister_dangling_service(
self,
Expand Down
Loading