Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions mkdocs/docs/concepts/backends.md
Comment thread
jvstme marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -403,16 +403,30 @@ There are two ways to configure Azure: using a client secret or using the defaul
- type: azure
creds:
type: default
regions: [westeurope]
vpc_ids:
westeurope: myNetworkResourceGroup/myNetworkName
regions: [westeurope]
vpc_ids:
westeurope: myNetworkResourceGroup/myNetworkName
```

Alternatively, specify `subnet_ids` to target specific subnets:

```yaml
projects:
- name: main
backends:
- type: azure
creds:
type: default
regions: [westeurope]
subnet_ids:
westeurope: myNetworkResourceGroup/myNetworkName/mySubnetName
```


??? info "Private subnets"
By default, `dstack` provisions instances with public IPs and permits inbound SSH traffic.
If you want `dstack` to use private subnets and provision instances without public IPs,
specify custom networks using `vpc_ids` and set `public_ips` to `false`.
specify custom networks using `vpc_ids` or `subnet_ids`, and set `public_ips` to `false`.

```yaml
projects:
Expand Down
36 changes: 36 additions & 0 deletions src/dstack/_internal/core/backends/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def create_instance(
network_client=self._network_client,
resource_group=self.config.resource_group,
vpc_ids=self.config.vpc_ids,
subnet_ids=self.config.subnet_ids,
location=location,
allocate_public_ip=allocate_public_ip,
)
Expand Down Expand Up @@ -252,6 +253,7 @@ def create_gateway(
network_client=self._network_client,
resource_group=self.config.resource_group,
vpc_ids=self.config.vpc_ids,
subnet_ids=self.config.subnet_ids,
location=configuration.region,
allocate_public_ip=True,
)
Expand Down Expand Up @@ -326,9 +328,38 @@ def get_resource_group_network_subnet_or_error(
network_client: network_mgmt.NetworkManagementClient,
resource_group: Optional[str],
vpc_ids: Optional[Dict[str, str]],
subnet_ids: Optional[Dict[str, str]],
location: str,
allocate_public_ip: bool,
) -> Tuple[str, str, str]:
if subnet_ids is not None and location in subnet_ids:
subnet_id = subnet_ids[location]
try:
net_resource_group, network_name, subnet_name = _parse_config_subnet_id(subnet_id)
except Exception:
raise ComputeError(
"Subnet specified in incorrect format."
" Supported format for `subnet_ids` values: 'networkResourceGroupName/networkName/subnetName'"
)
try:
subnet = network_client.subnets.get(net_resource_group, network_name, subnet_name)
except ResourceNotFoundError:
raise ComputeError(
f"Subnet {subnet_name} not found in network {network_name}"
f" in resource group {net_resource_group}"
)
if not allocate_public_ip and not azure_resources.is_eligible_private_subnet(
network_client=network_client,
resource_group=net_resource_group,
network_name=network_name,
subnet=subnet,
):
raise ComputeError(
f"Subnet {subnet_name} in network {network_name} does not have outbound internet connectivity."
" Ensure a NAT Gateway is attached or VNet peering is configured."
)
return net_resource_group, network_name, subnet_name

if vpc_ids is not None:
vpc_id = vpc_ids.get(location)
if vpc_id is None:
Expand Down Expand Up @@ -388,6 +419,11 @@ def _parse_config_vpc_id(vpc_id: str) -> Tuple[str, str]:
return resource_group, network_name


def _parse_config_subnet_id(subnet_id: str) -> Tuple[str, str, str]:
resource_group, network_name, subnet_name = subnet_id.split("/")
return resource_group, network_name, subnet_name


class VMImageVariant(enum.Enum):
GRID = enum.auto()
CUDA = enum.auto()
Expand Down
36 changes: 26 additions & 10 deletions src/dstack/_internal/core/backends/azure/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def create_backend(
subscription_id=config.subscription_id,
resource_group=config.resource_group,
locations=config.regions,
create_default_network=config.vpc_ids is None,
create_default_network=config.vpc_ids is None and config.subnet_ids is None,
)
return BackendRecord(
config=AzureStoredConfig(
Expand Down Expand Up @@ -226,23 +226,38 @@ def _check_config_vpc(
if config.subscription_id is None:
return None
allocate_public_ip = config.public_ips if config.public_ips is not None else True
if config.public_ips is False and config.vpc_ids is None:
raise ServerClientError(msg="`vpc_ids` must be specified if `public_ips: false`.")
if config.public_ips is False and config.vpc_ids is None and config.subnet_ids is None:
raise ServerClientError(
msg="`vpc_ids` or `subnet_ids` must be specified if `public_ips: false`."
)
if config.vpc_ids is not None and config.subnet_ids is not None:
overlap = sorted(set(config.vpc_ids.keys()) & set(config.subnet_ids.keys()))
if overlap:
raise ServerClientError(
f"Regions {overlap} are configured in both `vpc_ids` and `subnet_ids`."
" Each region must be specified in only one of them."
)
locations = config.regions
if locations is None:
locations = DEFAULT_LOCATIONS
if config.vpc_ids is not None:
vpc_ids_locations = list(config.vpc_ids.keys())
not_configured_locations = [loc for loc in locations if loc not in vpc_ids_locations]
if config.vpc_ids is not None or config.subnet_ids is not None:
configured_locations = set()
if config.vpc_ids is not None:
configured_locations |= set(config.vpc_ids.keys())
if config.subnet_ids is not None:
configured_locations |= set(config.subnet_ids.keys())
not_configured_locations = [
loc for loc in locations if loc not in configured_locations
]
if len(not_configured_locations) > 0:
if config.regions is None:
raise ServerClientError(
f"`vpc_ids` not configured for regions {not_configured_locations}. "
"Configure `vpc_ids` for all regions or specify `regions`."
f"Networking not configured for regions {not_configured_locations}. "
"Configure either `vpc_ids` or `subnet_ids` for all regions or specify `regions`."
)
raise ServerClientError(
f"`vpc_ids` not configured for regions {not_configured_locations}. "
"Configure `vpc_ids` for all regions specified in `regions`."
f"Networking not configured for regions {not_configured_locations}. "
"Configure either `vpc_ids` or `subnet_ids` for all regions specified in `regions`."
)
network_client = network_mgmt.NetworkManagementClient(
credential=credential,
Expand All @@ -256,6 +271,7 @@ def _check_config_vpc(
network_client=network_client,
resource_group=None,
vpc_ids=config.vpc_ids,
subnet_ids=config.subnet_ids,
location=location,
allocate_public_ip=allocate_public_ip,
)
Expand Down
14 changes: 12 additions & 2 deletions src/dstack/_internal/core/backends/azure/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,23 @@ class AzureBackendConfig(CoreModel):
)
),
] = None
subnet_ids: Annotated[
Optional[Dict[str, str]],
Field(
description=(
"The mapping from configured Azure locations to subnet IDs."
" A subnet ID must have a format `networkResourceGroup/networkName/subnetName`."
" Cannot be configured for the same region as `vpc_ids`"
)
),
] = None
public_ips: Annotated[
Optional[bool],
Field(
description=(
"A flag to enable/disable public IP assigning on instances."
" `public_ips: false` requires `vpc_ids` that specifies custom networks with outbound internet connectivity"
" provided by NAT Gateway or other mechanism."
" `public_ips: false` requires `vpc_ids` or `subnet_ids` that specifies custom networks"
" with outbound internet connectivity provided by NAT Gateway or other mechanism."
" Defaults to `true`"
)
),
Expand Down
4 changes: 2 additions & 2 deletions src/dstack/_internal/core/backends/azure/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_network_subnets(
)
for subnet in subnets:
if private:
if _is_eligible_private_subnet(
if is_eligible_private_subnet(
network_client=network_client,
resource_group=resource_group,
network_name=network_name,
Expand Down Expand Up @@ -54,7 +54,7 @@ def _is_eligible_public_subnet(
return True


def _is_eligible_private_subnet(
def is_eligible_private_subnet(
network_client: network_mgmt.NetworkManagementClient,
resource_group: str,
network_name: str,
Expand Down
70 changes: 70 additions & 0 deletions src/tests/_internal/core/backends/azure/test_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dstack._internal.core.errors import (
BackendAuthError,
BackendInvalidCredentialsError,
ServerClientError,
)


Expand Down Expand Up @@ -59,3 +60,72 @@ def test_validate_config_invalid_creds(self):
["creds", "client_id"],
["creds", "client_secret"],
]


class TestCheckConfigVpc:
def _make_config(self, **kwargs):
return AzureBackendConfigWithCreds(
creds=AzureClientCreds(tenant_id="t", client_id="c", client_secret="s"),
tenant_id="ten1",
subscription_id="sub1",
**kwargs,
)

def _check(self, config):
with (
patch("azure.mgmt.network.NetworkManagementClient"),
patch(
"dstack._internal.core.backends.azure.compute.get_resource_group_network_subnet_or_error"
),
):
AzureConfigurator()._check_config_vpc(config, Mock())

def test_public_ips_false_requires_network_config(self):
config = self._make_config(regions=["westeurope"], public_ips=False)
with pytest.raises(ServerClientError, match="`vpc_ids` or `subnet_ids` must be specified"):
AzureConfigurator()._check_config_vpc(config, Mock())

def test_public_ips_false_with_vpc_ids_ok(self):
config = self._make_config(
regions=["westeurope"], public_ips=False, vpc_ids={"westeurope": "rg/net"}
)
self._check(config)

def test_public_ips_false_with_subnet_ids_ok(self):
config = self._make_config(
regions=["westeurope"], public_ips=False, subnet_ids={"westeurope": "rg/net/subnet"}
)
self._check(config)

def test_overlap_raises(self):
config = self._make_config(
regions=["westeurope", "eastus"],
vpc_ids={"westeurope": "rg/net", "eastus": "rg/net2"},
subnet_ids={"westeurope": "rg/net/subnet"},
)
with pytest.raises(ServerClientError, match="westeurope"):
AzureConfigurator()._check_config_vpc(config, Mock())

def test_uncovered_region_raises_with_vpc_ids(self):
config = self._make_config(
regions=["westeurope", "eastus"],
vpc_ids={"westeurope": "rg/net"},
)
with pytest.raises(ServerClientError, match="eastus"):
AzureConfigurator()._check_config_vpc(config, Mock())

def test_uncovered_region_raises_with_subnet_ids(self):
config = self._make_config(
regions=["westeurope", "eastus"],
subnet_ids={"westeurope": "rg/net/subnet"},
)
with pytest.raises(ServerClientError, match="eastus"):
AzureConfigurator()._check_config_vpc(config, Mock())

def test_mixed_vpc_and_subnet_ids_covers_all_regions(self):
config = self._make_config(
regions=["westeurope", "eastus"],
vpc_ids={"westeurope": "rg/net"},
subnet_ids={"eastus": "rg/net/subnet"},
)
self._check(config)
Loading