Skip to content

feat(gh-299): Type hints in distributions modules#2032

Merged
fehiepsi merged 15 commits intopyro-ppl:masterfrom
Qazalbash:type-hint-distribution
Jun 3, 2025
Merged

feat(gh-299): Type hints in distributions modules#2032
fehiepsi merged 15 commits intopyro-ppl:masterfrom
Qazalbash:type-hint-distribution

Conversation

@Qazalbash
Copy link
Collaborator

@Qazalbash Qazalbash commented May 24, 2025

Hi,

I have added type hints using jaxtyping in numpyro.distributions.*.py modules. I have accordingly updated the setup.py too.

All types of protocols have been transferred to numpyro._typing. I have modified the DistributionLike type along with two new types, TransformLike and ConstraintLike.


This PR is related to #299.

Qazalbash added 4 commits May 23, 2025 22:33
…tion

- Added type annotations to various methods and properties in the Distribution class and its subclasses for better type checking and clarity.
- Updated the `enable_validation` and `validation_enabled` functions to use type hints.
- Enhanced the `arg_constraints`, `support`, and other attributes with appropriate types.
- Modified the `__init__` methods of several distribution classes to include type hints for parameters.
- Improved the `log_prob`, `sample`, and other methods to specify return types.
- Refactored the `clamp_probs` function in the util module to include type annotations.
- Updated the `LeftTruncatedDistribution`, `RightTruncatedDistribution`, and `TwoSidedTruncatedDistribution` classes to include type hints for parameters and return types.
- Introduced a new `ConstraintLike` protocol in `numpyro/typing.py` to standardize constraint types across distributions.
- Added type hints for the `TruncatedDistribution`, `TruncatedCauchy`, `TruncatedNormal`, `TruncatedPolyaGamma`, `DoublyTruncatedPowerLaw`, and `LowerTruncatedPowerLaw` classes.
- Improved type safety and clarity in the distribution methods by specifying expected types for inputs and outputs.
@fehiepsi
Copy link
Member

Just curious what jaxtyping offers?

@Qazalbash
Copy link
Collaborator Author

Qazalbash commented May 26, 2025

I prefer Jaxtyping over native types provided by JAX because of the array annotations. Jaxtyping documentation describes it as,

The shape and dtypes of arrays can be annotated in the form dtype[array, shape], such as Float[Array, "batch channels"].

It also has a good systematic way of typing PyTrees, along with variety of annotated types.


I have not utilized array annotations in this PR.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Qazalbash Could we not use jaxtyping? I feel that it is unnecessary to depend on it. Also we might want to allow sample key to be None for some deterministic distributions like Delta or the default Distribution, TransformedDistribution ones.

@Qazalbash
Copy link
Collaborator Author

Sure, we can avoid it!

@Qazalbash
Copy link
Collaborator Author

All test cases are passing except for those failing on the master branch. I have figured out the problem and reported it.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beautiful! Thanks for putting lots of efforts on this, @Qazalbash!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending the usage of Optional[jax.dtypes.prng_key] at some places.

@fehiepsi fehiepsi merged commit 4c505c1 into pyro-ppl:master Jun 3, 2025
9 checks passed
@Qazalbash Qazalbash deleted the type-hint-distribution branch June 3, 2025 15:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants