Skip to content
Merged
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- `ot.gaussian.bures_wasserstein_distance` can be batched (PR #680)
- Backend implementation of `ot.dist` for (PR #701)
- Updated documentation Quickstart guide and User guide with new API (PR #726)
- Fix jax version for auto-grad (PR #732)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
Expand Down
2 changes: 1 addition & 1 deletion ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def set_gradients(self, val, inputs, grads):
aux = jnp.sum(ravelled_inputs * ravelled_grads) / 2
aux = aux - jax.lax.stop_gradient(aux)

(val,) = jax.tree_map(lambda z: z + aux, (val,))
(val,) = jax.tree_util.tree_map(lambda z: z + aux, (val,))
return val

def _detach(self, a):
Expand Down
Loading