feat(gh-299): Type hints in distributions modules#2032
feat(gh-299): Type hints in distributions modules#2032fehiepsi merged 15 commits intopyro-ppl:masterfrom
Conversation
…tion - Added type annotations to various methods and properties in the Distribution class and its subclasses for better type checking and clarity. - Updated the `enable_validation` and `validation_enabled` functions to use type hints. - Enhanced the `arg_constraints`, `support`, and other attributes with appropriate types. - Modified the `__init__` methods of several distribution classes to include type hints for parameters. - Improved the `log_prob`, `sample`, and other methods to specify return types. - Refactored the `clamp_probs` function in the util module to include type annotations.
- Updated the `LeftTruncatedDistribution`, `RightTruncatedDistribution`, and `TwoSidedTruncatedDistribution` classes to include type hints for parameters and return types. - Introduced a new `ConstraintLike` protocol in `numpyro/typing.py` to standardize constraint types across distributions. - Added type hints for the `TruncatedDistribution`, `TruncatedCauchy`, `TruncatedNormal`, `TruncatedPolyaGamma`, `DoublyTruncatedPowerLaw`, and `LowerTruncatedPowerLaw` classes. - Improved type safety and clarity in the distribution methods by specifying expected types for inputs and outputs.
|
Just curious what jaxtyping offers? |
|
I prefer Jaxtyping over native types provided by JAX because of the array annotations. Jaxtyping documentation describes it as,
It also has a good systematic way of typing PyTrees, along with variety of annotated types. I have not utilized array annotations in this PR. |
fehiepsi
left a comment
There was a problem hiding this comment.
@Qazalbash Could we not use jaxtyping? I feel that it is unnecessary to depend on it. Also we might want to allow sample key to be None for some deterministic distributions like Delta or the default Distribution, TransformedDistribution ones.
|
Sure, we can avoid it! |
|
All test cases are passing except for those failing on the master branch. I have figured out the problem and reported it. |
fehiepsi
left a comment
There was a problem hiding this comment.
Beautiful! Thanks for putting lots of efforts on this, @Qazalbash!
fehiepsi
left a comment
There was a problem hiding this comment.
LGTM pending the usage of Optional[jax.dtypes.prng_key] at some places.
…ition" This reverts commit 51e2569.
Hi,
I have added type hints using
jaxtypinginnumpyro.distributions.*.pymodules. I have accordingly updated thesetup.pytoo.All types of protocols have been transferred to
numpyro._typing. I have modified theDistributionLiketype along with two new types,TransformLikeandConstraintLike.This PR is related to #299.