refactor(infer.elbo): add type hints to elbo module#2028
refactor(infer.elbo): add type hints to elbo module#2028fehiepsi merged 7 commits intopyro-ppl:masterfrom
Conversation
|
This looks great! (I approved, but we need a true core dev to approve it). I think the error FAILED test/test_example_utils.py::test_mnist_data_load - urllib.error.HTTPError: HTTP Error 429: Too Many RequestsIs not related with these changes. |
|
This is awesome! I'll need to look into details and learn from your designs. :) Please ignore the failing tests. I guess it's a github issue and will be resolved by github devs soon. |
|
Thanks @fehiepsi and @juanitorduz for the reviews here. Still getting the hang of these typing concepts myself as you can see! Hopefully it will get easier with time and practice. |
fehiepsi
left a comment
There was a problem hiding this comment.
Looks awesome, thanks @brendancooley!
| from numpyro.util import find_stack_level, identity | ||
|
|
||
| # Type aliases | ||
| Message = dict[str, Any] |
There was a problem hiding this comment.
could you keep Message = MessageT in case users already use Message in their libraries?
There was a problem hiding this comment.
oops yes, forgot this was pre-existing. reverted.
pyproject.toml
Outdated
| ] | ||
|
|
||
| [tool.mypy] | ||
| python_version = 3.12 |
There was a problem hiding this comment.
does this mean that we only allow mypy to work with python 3.12?
There was a problem hiding this comment.
for python 3.9, mypy complains about modern (3.10+) type annotations on our type aliases
LossT: TypeAlias = jax.Array | dict[str, jax.Array]numpyro/infer/elbo.py:44: error: Invalid type alias: expression is not a valid type [valid-type]
numpyro/infer/elbo.py:44: error: Unsupported left operand type for | ("type[Array]") [operator]
Found 2 errors in 1 file (checked 88 source files)
Since python 3.9 is approaching end of life, I went ahead removed this target python_version but excluded the lint workflow from CI for the 3.9 case, which will let us use X | Y in lieu of typing.Union[X, Y] across the codebase without having to add inline ignores for the typechecker.
Alternatively, I could add an inline ignore or adopt use the legacy Union[...] type hint. Whatever you prefer!
Picking up on the work of @juanitorduz, this MR adds type hints to the
numpyro.infer.elbomodule. Along the way, I've made a few housekeeping changes to try and make it easier to maintain and extend type hints going forward.The most important change warranting discussion is that I've introduced a
ParamSpecfor the parameters of a model function.ELBOs are now generics over the signature of the model/guide functions they operate on. I think that this is a nice way to (softly) enforce consistency in the*argsand**kwargspassed toELBO.loss, but there may be some subtleties that I'm missing here.Summary of Changes
_compute_downstream_costscode and associated tests (replaced in 28e38d8)numpyro.primitiveswithnoqanumpyro._typingmodule for cross-module shared type aliasesjax.Arrayto address typing errors and following guidance in JEP 9263numpyro.handlers, with requisite importfrom __future__ import annotationsto support python 3.9.Previous Typing MRs (for Reference)