Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def plot_head_ugrid(head, cbc, workspace):
outer_maximum=500,
under_relaxation=None,
inner_dvclose=1.0e-4,
inner_rclose=0.001,
rcloserecord=flopy4.mf6.Ims.Rcloserecord(inner_rclose=0.001),
inner_maximum=100,
linear_acceleration="cg",
reordering_method=None,
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/frenchman-flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def plot_head_ugrid(head, cbc, grid, workspace):
under_relaxation_gamma=0.000000,
under_relaxation_momentum=0.000000,
inner_dvclose=0.00001,
inner_rclose=0.1,
rcloserecord=flopy4.mf6.Ims.Rcloserecord(inner_rclose=0.1),
inner_maximum=100,
linear_acceleration="bicgstab",
number_orthogonalizations=0,
Expand Down
5 changes: 3 additions & 2 deletions docs/examples/twri.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def plot_head(head, workspace):
icelltype=icelltype,
k=k,
k33=k33,
cvoptions=flopy4.mf6.gwf.Npf.CvOptions(dewatered=True),
variablecv=True,
dewatered=True,
perched=True,
save_flows=True,
dims=dims,
Expand Down Expand Up @@ -199,7 +200,7 @@ def plot_head(head, workspace):
outer_maximum=500,
under_relaxation=None,
inner_dvclose=1.0e-4,
inner_rclose=0.001,
rcloserecord=flopy4.mf6.Ims.Rcloserecord(inner_rclose=0.001),
inner_maximum=100,
linear_acceleration="cg",
scaling_method=None,
Expand Down
21 changes: 19 additions & 2 deletions flopy4/mf6/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,35 @@
from tomli_w import dump as dump_toml

# Import submodules to make them accessible via flopy4.mf6.*
from flopy4.mf6 import gwf, simulation, solution, utils
from flopy4.mf6 import gwe, gwf, gwt, prt, simulation, solution, utils
from flopy4.mf6.codec import dump as dump_mf6
from flopy4.mf6.codec import load as load_mf6
from flopy4.mf6.component import Component
from flopy4.mf6.converter import structure, unstructure
from flopy4.mf6.ems import Ems
from flopy4.mf6.exchange import GwfGwe, GwfGwt
from flopy4.mf6.ims import Ims
from flopy4.mf6.netcdf import NetCDFModel
from flopy4.mf6.simulation import Simulation
from flopy4.mf6.tdis import Tdis
from flopy4.uio import DEFAULT_REGISTRY

__all__ = ["gwf", "simulation", "solution", "utils", "Ims", "NetCDFModel", "Tdis", "Simulation"]
__all__ = [
"gwf",
"gwt",
"gwe",
"prt",
"simulation",
"solution",
"utils",
"Ems",
"GwfGwe",
"GwfGwt",
"Ims",
"NetCDFModel",
"Tdis",
"Simulation",
]


class WriteError(Exception):
Expand Down
2 changes: 1 addition & 1 deletion flopy4/mf6/binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def from_component(cls, component: Component) -> "Binding":
def _get_binding_type(component: Component) -> str:
cls_name = component.__class__.__name__
if isinstance(component, Exchange):
return f"{'-'.join([cls_name[:2], cls_name[3:]]).upper()}6"
return f"{cls_name[:3].upper()}6-{cls_name[3:].upper()}6"
elif isinstance(component, Solution):
return f"{component.slntype}6"
else:
Expand Down
9 changes: 8 additions & 1 deletion flopy4/mf6/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,14 @@ def _update_maxbound_if_needed(self):

@classmethod
def __attrs_init_subclass__(cls):
    """Register a Component subclass in the global COMPONENTS lookup.

    Each subclass is registered under its lowercased class name, and — when
    it lives in a model subpackage (flopy4.mf6.<model>.<pkg>) — also under a
    model-qualified key such as "gwf-ic", so packages sharing a class name
    across models (e.g. "ic") resolve deterministically per model.
    """
    key = cls.__name__.lower()
    # Single registration; the redundant unkeyed assignment left over from
    # the diff (same key, same value) is removed.
    COMPONENTS[key] = cls
    parts = cls.__module__.split(".")
    # A model subpackage path flopy4.mf6.<model>.<pkg> has >= 4 parts.
    if len(parts) >= 4 and parts[:2] == ["flopy4", "mf6"]:
        COMPONENTS[f"{parts[2]}-{key}"] = cls

def __getitem__(self, key):
# We use `children` from `xattree` to implement MutableMapping.
Expand Down
91 changes: 75 additions & 16 deletions flopy4/mf6/converter/egress/unstructure.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from collections.abc import Iterable, Mapping
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Any, cast

import attrs
import numpy as np
import xarray as xr
import xattree
Expand All @@ -18,7 +19,9 @@

def _path_to_tuple(name: str, value: Path, inout: FileInOut) -> tuple[str, ...]:
t = [name.upper()]
if name.endswith("_file"):
if name.endswith("_filerecord"):
t[0] = name.replace("_filerecord", "").upper()
elif name.endswith("_file"):
t[0] = name.replace("_file", "").upper()
if inout:
t.append(inout.upper())
Expand Down Expand Up @@ -184,6 +187,31 @@ def _unstructure_block_param(
if child_spec.metadata["block"] == block_name: # type: ignore
return

# xattree.asdict converts inner-class attrs instances (like Rcloserecord) to
# plain dicts before this function sees them. Check the raw component attribute
# first so the attrs match case can fire on the real object.
raw_value = getattr(value, field_name, None)
cls = type(raw_value)
if attrs.has(cls) and "_keyword" in vars(cls):
# Generated inner class record: convert to keyword-prefixed tuple.
# _keyword is "" for records with no leading trigger token (e.g. rcloserecord).
keyword: str = vars(cls)["_keyword"]
tokens: list[Any] = [keyword.upper()] if keyword else []
for a in attrs.fields(cast(type[attrs.AttrsInstance], cls)):
val = getattr(raw_value, a.name)
if val is None:
continue
if a.metadata.get("tagged", False):
tokens.append(a.name.upper())
tokens.append(val)
elif isinstance(val, bool):
if val:
tokens.append(a.name.upper())
else:
tokens.append(val)
blocks[block_name][field_name] = tuple(tokens)
return

# filter out empty values and false keywords, and convert:
# - paths to records
# - datetimes to ISO format
Expand All @@ -209,7 +237,8 @@ def _unstructure_block_param(
case t if (
field_name == "auxiliary" and hasattr(field_value, "values") and field_value is not None
):
blocks[block_name][field_name] = tuple(field_value.values.tolist())
# MF6 OPTIONS format requires the keyword "AUXILIARY" before the variable names.
blocks[block_name][field_name] = ("AUXILIARY",) + tuple(field_value.values.tolist())
case xr.DataArray():
has_spatial_dims = any(
dim in field_value.dims for dim in ["nlay", "nrow", "ncol", "ncpl", "nodes"]
Expand All @@ -220,15 +249,24 @@ def _unstructure_block_param(
structured_grid_dims=value.data.dims, # type: ignore
)
if "nper" in field_value.dims and block_name == "period":
if not np.issubdtype(field_value.dtype, np.number):
dat = _hack_period_non_numeric(field_name, field_value)
for n, v in dat.items():
period_data[n] = v
else:
is_tabular = (
np.issubdtype(field_value.dtype, np.number)
or np.issubdtype(field_value.dtype, np.str_)
or (
field_value.dtype == object
and field_value.size > 0
and isinstance(field_value.values.flat[0], str)
)
)
if is_tabular:
period_data[field_name] = {
kper: field_value.isel(nper=kper) # type: ignore
for kper in range(field_value.sizes["nper"])
}
else:
dat = _hack_period_non_numeric(field_name, field_value)
for n, v in dat.items():
period_data[n] = v
else:
blocks[block_name][field_name] = field_value
case _:
Expand Down Expand Up @@ -286,6 +324,18 @@ def _unstructure_array_component(value: Component) -> dict[str, Any]:
return {name: block for name, block in blocks.items() if name != "period"}


# Block names that MF6 rejects if present but empty.
# These blocks should only be written when they contain data.
_SKIP_IF_EMPTY = frozenset({"dimensions", "tracktimes"})

# Block names whose fields are list columns (one array per column, same dim)
# rather than independent grid arrays. Only these blocks are auto-combined
# into an xr.Dataset for row-per-record output. griddata-style blocks must
# NOT be in this set — their fields are written individually with
# INTERNAL/CONSTANT/NETCDF format.
_LIST_BLOCK_NAMES = frozenset({"packagedata", "packages", "perioddata", "table"})


def _unstructure_component(value: Component) -> dict[str, Any]:
blockspec = blocks_dict(type(value))
blocks: dict[str, dict[str, Any]] = {}
Expand Down Expand Up @@ -343,11 +393,6 @@ def _unstructure_component(value: Component) -> dict[str, Any]:
block, coords=block[arr_name].coords
)

# combine "perioddata" block arrays (tdis, ats) into datasets
# so they render as lists. temp hack TODO do this generically
if perioddata := blocks.get("perioddata", None):
blocks["perioddata"] = {"perioddata": xr.Dataset(perioddata)}

if vertices := blocks.get("vertices", None):
# TODO comes twice once with "vertices" key and once with dataarrays
if "vertices" in vertices:
Expand All @@ -356,8 +401,18 @@ def _unstructure_component(value: Component) -> dict[str, Any]:
vertices["iv"] = vertices["iv"] + 1
blocks["vertices"] = {"vertices": xr.Dataset(vertices)}

# TODO: this fixes out of order blocks (e.g. model namefile) from
# blocks.update() child binding call above
# Combine list-style blocks into a Dataset for row-per-record output.
# Only applies to known list block names — griddata-style blocks (each
# field a separate array) must NOT be combined.
if block_name in _LIST_BLOCK_NAMES:
current_block = blocks.get(block_name, {})
if current_block:
das = [v for v in current_block.values() if isinstance(v, xr.DataArray)]
if das and len(das) == len(current_block):
first_dim = das[0].dims[0] if das[0].dims else None
if first_dim and all(da.dims and da.dims[0] == first_dim for da in das):
blocks[block_name] = {block_name: xr.Dataset(current_block)}

blocks = dict(sorted(blocks.items(), key=block_sort_key))

# total temporary hack! manually set solutiongroup 1.
Expand All @@ -367,4 +422,8 @@ def _unstructure_component(value: Component) -> dict[str, Any]:
blocks["solutiongroup 1"] = sg
del blocks["solutiongroup"]

return {name: block for name, block in blocks.items() if name != "period"}
return {
name: block
for name, block in blocks.items()
if name != "period" and (block or name not in _SKIP_IF_EMPTY)
}
12 changes: 12 additions & 0 deletions flopy4/mf6/ems.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# autogenerated file, do not modify
# ruff: noqa: E501
from typing import ClassVar

from xattree import xattree

from flopy4.mf6.solution import Solution


@xattree(kw_only=True)
class Ems(Solution):
    """EMS solution (MODFLOW 6 solution type "ems")."""

    # Solution-type token; the binding layer derives the mfsim.nam type
    # string from it as f"{slntype}6" -> "ems6".
    slntype: ClassVar[str] = "ems"
19 changes: 15 additions & 4 deletions flopy4/mf6/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,20 @@

@xattree
class Exchange(Package, ABC):
# mypy doesn't understand that kw_only=True on the
# Component means we can have required fields here
exgtype: type = field() # type: ignore
exgfile: Path = field() # type: ignore
exgtype: Optional[type] = field(default=None) # type: ignore
exgfile: Optional[Path] = field(default=None) # type: ignore
exgmnamea: Optional[str] = field(default=None)
exgmnameb: Optional[str] = field(default=None)

def default_filename(self) -> str:
return f"{self.name}.exg" # type: ignore


@xattree
class GwfGwt(Exchange):
    """GWF-GWT flow-transport exchange (declares coupling in mfsim.nam)."""

    # No exchange-specific fields; inherits exgtype/exgfile/exgmnamea/
    # exgmnameb from Exchange.


@xattree
class GwfGwe(Exchange):
    """GWF-GWE flow-energy exchange (declares coupling in mfsim.nam)."""

    # No exchange-specific fields; inherits exgtype/exgfile/exgmnamea/
    # exgmnameb from Exchange.
9 changes: 9 additions & 0 deletions flopy4/mf6/exg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from flopy4.mf6.exg.gwfgwe import Gwfgwe
from flopy4.mf6.exg.gwfgwt import Gwfgwt
from flopy4.mf6.exg.gwfprt import Gwfprt

# Public API of the exg subpackage: autogenerated exchange package stubs.
__all__ = [
    "Gwfgwe",
    "Gwfgwt",
    "Gwfprt",
]
10 changes: 10 additions & 0 deletions flopy4/mf6/exg/gwfgwe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# autogenerated file, do not modify
# ruff: noqa: E501
from xattree import xattree

from flopy4.mf6.package import Package


@xattree(kw_only=True)
class Gwfgwe(Package):
    """Autogenerated GWF-GWE exchange input package stub (no fields yet)."""

    pass
10 changes: 10 additions & 0 deletions flopy4/mf6/exg/gwfgwt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# autogenerated file, do not modify
# ruff: noqa: E501
from xattree import xattree

from flopy4.mf6.package import Package


@xattree(kw_only=True)
class Gwfgwt(Package):
    """Autogenerated GWF-GWT exchange input package stub (no fields yet)."""

    pass
10 changes: 10 additions & 0 deletions flopy4/mf6/exg/gwfprt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# autogenerated file, do not modify
# ruff: noqa: E501
from xattree import xattree

from flopy4.mf6.package import Package


@xattree(kw_only=True)
class Gwfprt(Package):
    """Autogenerated GWF-PRT exchange input package stub (no fields yet)."""

    pass
46 changes: 46 additions & 0 deletions flopy4/mf6/gwe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional

from xattree import xattree

from flopy4.mf6.gwe.adv import Adv
from flopy4.mf6.gwe.cnd import Cnd
from flopy4.mf6.gwe.ctp import Ctp
from flopy4.mf6.gwe.dis import Dis
from flopy4.mf6.gwe.esl import Esl
from flopy4.mf6.gwe.est import Est
from flopy4.mf6.gwe.ic import Ic
from flopy4.mf6.gwe.mve import Mve
from flopy4.mf6.gwe.ssm import Ssm
from flopy4.mf6.model import Model
from flopy4.mf6.spec import field

# Public API of the gwe subpackage: the model class plus its package classes.
__all__ = [
    "Gwe",
    "Dis",
    "Adv",
    "Cnd",
    "Ctp",
    "Esl",
    "Est",
    "Ic",
    "Mve",
    "Ssm",
]


@xattree
class Gwe(Model):
    """GWE model container.

    Option flags map to the model namefile OPTIONS block; package fields
    map to the PACKAGES block (each field declares its block via the
    ``block=`` metadata on ``field``).
    """

    # Trailing underscore avoids shadowing the builtin `list`; presumably
    # maps to the LIST option (listing-file name) — confirm against codec.
    list_: str | None = field(block="options", default=None)
    print_input: bool = field(block="options", default=False)
    print_flows: bool = field(block="options", default=False)
    save_flows: bool = field(block="options", default=False)
    dependent_variable_scaling: bool = field(block="options", default=False)
    # Singleton packages: at most one instance per model (Optional types).
    dis: Dis | None = field(block="packages", default=None)
    ic: Ic | None = field(block="packages", default=None)
    adv: Adv | None = field(block="packages", default=None)
    cnd: Cnd | None = field(block="packages", default=None)
    est: Est | None = field(block="packages", default=None)
    # Repeatable packages: any number of instances (list types).
    # NOTE(review): these list fields have no explicit default although they
    # follow defaulted fields — assumes xattree supplies kw-only handling or
    # a list default; confirm.
    ctp: list[Ctp] = field(block="packages")
    esl: list[Esl] = field(block="packages")
    ssm: Ssm | None = field(block="packages", default=None)
    mve: Mve | None = field(block="packages", default=None)
Loading
Loading