Support vector lengthscales for RBF and Matern kernels#1819
Support vector lengthscales for RBF and Matern kernels#1819fehiepsi merged 6 commits intopyro-ppl:masterfrom
Conversation
juanitorduz
left a comment
There was a problem hiding this comment.
Thanks! At first glance, this looks good! But let's see that all tests pass :) @brendancooley do you want to take a look :) ?
| if isinstance(length, float | int): | ||
| exact = _exact_matern(length) | ||
| elif length.ndim == 1: | ||
| exact = _exact_matern(length) |
There was a problem hiding this comment.
why do you need two conditions for exact = _exact_matern(length)?
There was a problem hiding this comment.
(this is just a question, no need to use the OR statement in view of readability)
There was a problem hiding this comment.
I was copying the code from test_kernel_approx_squared_exponential. Yes, I believe this is in service to readability. If we use an 'or' statement here, we should probably use it there as well.
There was a problem hiding this comment.
either way fine by me
|
|
||
| import jax | ||
|
|
||
| ARRAY_TYPE = Union[jax.Array, np.ndarray] # jax.Array covers tracers |
There was a problem hiding this comment.
I wonder if this type can be used in other NumPyro modules.
There was a problem hiding this comment.
Using ArrayImpl is only an issue when a model gets compiled and the arrays turn into tracers. isinstance(X, jax.Array) will work for both jax arrays and tracers.
There was a problem hiding this comment.
Some details here https://jax.readthedocs.io/en/latest/jax_array_migration.html
I believe this is best practice for typing jax arrays (as of last year), but I am not sure
There was a problem hiding this comment.
I am definitively not an expert in type hints, so following the recommendation from the docs seems the safest path :)
There was a problem hiding this comment.
Should I mark this thread as resolved, as this seems to be in line with the recommendation?
brendancooley
left a comment
There was a problem hiding this comment.
Looks good and works on my offline example. We can see what CI says. Thanks @samanklesaria! Great that we have a group working on this and checking one another's work.
I think next step (maybe another PR) would be to support vector-valued alpha for the batch dims. I think it may "just work" with the code as is but it would be helpful to update type annotations, and maybe generalize @juanitorduz's code in contrib.hsgp.approximation to sample matrix/tensor-valued betas in linear_approximation when batch dimensions are detected.
|
@samanklesaria can you please rebase or sync with the master branch? Today we merged some fixes on the CI, see #1817 |
|
Yes! The alpha vectorization + docs we can do in another PR :) |
Done! |
|
Should sampling matrix/tensor-valued betas in |
Personally, I think the scope of this PR is fine. I like working on small iterations so any additional feature can be done in a different PR (at least from my side) |
|
@samanklesaria it seems there are other more places where you need to change the syntax (similarly as last commit) 😄 |
|
The current version might have fixed things, but I should probably install a copy of python3.9 locally to test it for sure. |
|
There is one test failing because the last change if isinstance(length, Union[float, int])Union is a type hint so this won't work. I suggest you make the change as I suggested above 😄 |
Co-authored-by: Juan Orduz <juanitorduz@gmail.com>
|
Thanks for contributing @samanklesaria ! Looking hsgp is having a great momentum. |
Resolves #1805
This allows vector lengthscales in the HSGP approximations to RBF and Matern kernels. Extends brendancooley/numpyro@ef4a24b