This is exactly what the scientific computing community needs.
It saves time and reduces a lot of errors.
*bug fix* in def predict:
outputs = np.tanh(outputs) # not `inputs`
your current _predict_ function is just a linear projection without a non-linearity.
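For anyone following along, a minimal sketch of what the corrected forward pass might look like (my own reconstruction, not the video's exact code; `params` is assumed to be a list of `(W, b)` pairs):
```python
import jax.numpy as np

def predict(params, inputs):
    for W, b in params:
        outputs = np.dot(inputs, W) + b   # linear projection
        outputs = np.tanh(outputs)        # the fix: tanh of the projection, not of `inputs`
        inputs = outputs                  # feed the activation into the next layer
    return outputs
```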
Thank you!!! I thought I misunderstood something and had just written a comment asking for clarification, then saw yours. Also, thanks for elaborating on what that line does.
Notes / Important Points (a short sketch tying these together follows after the list):
1.) Just replacing numpy with jax.numpy will make your code run on GPU/TPU.
2.) Just passing the function through jit, as g = jit(f), will fuse together components of f and optimize it.
3.) Given a function f, to obtain f', use grad_f = grad(f). grad_f(x) now equals f'(x).
4.) JAX doesn't use finite difference methods etc. for computing gradients. It uses automatic differentiation. Note that it is always possible to find analytic gradients, but not analytic integrals.
5.) vmap takes in a function and returns a batched version of that function, which treats the first dimension as the batch dimension.
6.) How does jit work? JAX converts the code to lax primitives, then passes a tracer value through to build an intermediate representation (IR) [21:00]. This IR is then compiled by XLA.
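To make the notes above concrete, here's a small sketch using the real JAX APIs they refer to (the toy function `f` and the values are mine, purely for illustration):
```python
import jax
import jax.numpy as np            # point 1: drop-in replacement for numpy that runs on GPU/TPU
from jax import jit, grad, vmap

def f(x):
    return np.tanh(x) ** 2        # a toy scalar function

g = jit(f)                        # point 2: XLA-compiled, fused version of f
grad_f = grad(f)                  # point 3: grad_f(x) == f'(x), via autodiff (point 4)
batched_f = vmap(f)               # point 5: maps f over the leading (batch) dimension

x = np.linspace(-1.0, 1.0, 8)
print(g(0.5), grad_f(0.5))        # compiled value and exact derivative at 0.5
print(batched_f(x))               # f applied across the batch dimension

# point 6: inspect the intermediate representation that jit traces out
print(jax.make_jaxpr(f)(0.5))
```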
Can you please explain the 4th point above?
If JAX doesn't use finite difference methods etc. for computing gradients, then which methods or techniques does it use?
@Karthikk-ln9ge It uses algorithmic differentiation (also called automatic differentiation).
@Karthikk-ln9ge It does analytic/symbolic differentiation instead of numerical.
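To illustrate the difference (my own toy example, not from the video): automatic differentiation chain-rules the traced primitives and gives the derivative essentially exactly, while a finite difference only approximates it and depends on the step size:
```python
import jax.numpy as np
from jax import grad

def f(x):
    return np.sin(x) * x ** 2

# Automatic differentiation: JAX chain-rules the primitive ops it traced.
df_auto = grad(f)

# Finite differences: a numerical approximation that depends on the step eps.
def df_fd(x, eps=1e-4):
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = 1.3
analytic = np.cos(x) * x ** 2 + 2 * x * np.sin(x)   # f'(x) derived by hand
print(df_auto(x))   # matches the analytic value (up to float precision)
print(df_fd(x))     # close, but only an approximation
print(analytic)
```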
Hrrmm, I'm having trouble understanding why I would want to ditch the well-established frameworks that TF or PyTorch already provide for the added acceleration you get with JAX. Is the main use case for JAX over other frameworks prototyping? Are JAX's acceleration and the TF or PyTorch frameworks mutually exclusive? So now I have two separate code bases, one implemented in JAX and another in TF or PyTorch?
I don't think JAX is competing with the well-established TF or PyTorch ecosystems. It also seems like you've lumped TF and PyTorch together in your reply, but clearly one has to decide between one or the other; there is little to no interoperability between the two libraries (unless one mixes and matches the dataloaders and model handling). JAX operates with a completely different philosophy: you get a drop-in replacement for numpy for your math operations, along with grad, which lets you compute gradients, and vmap and pmap, which let you completely vectorize your operations, even across multiple compute devices. It is this very simple API, and the "functional" framework (JAX functions are meant to be free of side effects, which lets you do function composition without much fear), that I love about JAX. It is tailored more towards the differentiable programming paradigm and is very suitable for physics-informed networks etc. If you want vmap and grad in torch, there is an implementation named functorch that lets you use these features without using JAX.
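As a small illustration of that composability (my own toy example): because grad, vmap and jit are plain function transformations, they nest freely, e.g. per-example gradients are just vmap of grad:
```python
import jax.numpy as np
from jax import grad, jit, vmap

def loss(w, x):
    # a toy per-example loss: tanh of a dot product
    return np.tanh(np.dot(x, w))

# differentiate w.r.t. w, then map over a batch of x, then compile the whole thing
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0)))

w = np.ones(3)
xs = np.arange(12.0).reshape(4, 3)        # a batch of 4 examples
print(per_example_grads(w, xs).shape)     # (4, 3): one gradient per example
```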
Very promising. Hope to see a unified deep learning library (convs, etc.) for JAX in the future.
Hamza, Flax is a higher-level library for JAX that includes Conv layers: github.com/google/flax
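For anyone curious, a minimal sketch of what a conv model looks like in Flax's linen API (layer sizes here are arbitrary, just to show the shape of the API):
```python
import jax
import jax.numpy as np
import flax.linen as nn

class SmallCNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)   # conv layer from Flax
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))                   # flatten per example
        x = nn.Dense(features=10)(x)
        return x

model = SmallCNN()
dummy = np.ones((1, 28, 28, 1))                  # NHWC batch of one image
params = model.init(jax.random.PRNGKey(0), dummy)  # initialise parameters
print(model.apply(params, dummy).shape)            # (1, 10)
```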
wow, amazingly cool and really well explained.
🙏 🙏 😊
very interesting, thank you !
How can someone dislike this video? What's wrong, folks?
Jake does not like quiet keyboards?
Hey, lots of people love loud mechanical keyboards! :D
In my opinion this was a very shallow overview of JAX. The speaker didn't compare this tool with industry-standard machine learning frameworks like PyTorch and TensorFlow.
Thank you! I thought I was the only one that didn't get the full story. I'm still having a tough time understanding when one would want to use JAX outside of the obvious advantages you get in acceleration and parallelism.