This channel is going on the GOAT level status
Hahaha thank you! Step by step
When I started learning JAX, I personally thought it stood for JIT (J), Autograd (A), XLA (X), which is essentially an abbreviation for a bunch of abbreviations. Given that those features are the 'highlights' of JAX, it's very possible. If that's the case, pretty cool naming from DeepMind. Anyways, there aren't many comprehensive resources for JAX right now, so I'm really looking forward to this series! Cheers Aleksa.
Wow great point I totally overlooked it. 😅 That looks like a better hypothesis than the one I had. If you google "Jax meaning" it seems it's a legit name and means something like "God has been gracious; has shown favor". 😂
@@TheAIEpiphany Probably an alternative in case the name 'Jack' is too boring lmao. Had a similar experience, first time I googled "jax tutorial" it was a guide for a game character haha.
Believe it stands for "Just Another XLA" compiler.
50:40 - I would argue here that it's not necessary to pass all the parameters into the function; as long as it's not changing any of the params, it's OK to use external globals(), e.g. for some reference tables. This definition (though academically thorough) makes practical application a bit more cumbersome. I believe the better way to think about it is that rule "2." is sufficient to make this work. No need to pass a long list of params, just make sure not to update/change anything external inside the function, and treat whatever is not passed in as static. Alternatively, you can re-create the jitted function every time you anticipate that your globals might have changed, so you are effectively re-creating your jit function with the new globals(). In some cases that feels much preferable to passing everything in. For instance, you can use all sorts of globals inside it and then just re-create it right before your training loop.
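A minimal sketch of that "re-create the jit function before the training loop" pattern (the names LOOKUP and make_scale_fn are made up for illustration, not from the video):

import jax
import jax.numpy as jnp

# Hypothetical module-level "reference table" that the function only reads.
LOOKUP = jnp.array([1.0, 2.0, 3.0])

def make_scale_fn():
    # Re-creating the jitted function bakes in the *current* LOOKUP at trace time.
    @jax.jit
    def scale(x):
        return x * LOOKUP  # reads the global, never mutates it
    return scale

scale = make_scale_fn()
print(scale(jnp.ones(3)))  # [1. 2. 3.]

LOOKUP = jnp.array([10.0, 20.0, 30.0])
print(scale(jnp.ones(3)))  # still [1. 2. 3.] - the cached trace uses the old global

scale = make_scale_fn()    # re-create just before the training loop
print(scale(jnp.ones(3)))  # [10. 20. 30.]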
Thanks for this video!! that was really interesting for a new user of JAX like me
That is a great hands-on tutorial which has a perfect mix of theory and practical usage. Thanks
Finally some jax tutorial.. Keep them coming
Yup it was about time I started learning it. Will do sir
this video is such a great service to the community. really great examples to help better understand Jax at a nuanced level.
congrats to your deepmind job man (read your post), nice channel, keep going!
Thank you!! 🙏😄
I am glad that I found this channel
Welcome 🚀
Wonderful tutorial! Really really good examples. Thank you so much!
JAX = Just After eXecution (related to the tracing behaviour)
JIT = Just In Time (related to the compilation behaviour)
Thank you! A couple more people also pointed it out. Makes sense
Awesome work!! JAX is a fantastic library. This series is the reason I finally subscribed to your channel. Thanks for your work!
Thank you! 😄
Great video!🔥 Would need Paper implementations too💯
i love jax ... thank you for your work!
You're welcome!
great video and content, this channel needs more recognition.
Thank you! 🥳
Wonderful explanation about vmap function
Thanks for the great tutorial pointing out the strong and weak points of the JAX framework, with caveats and salient features. What makes me somewhat confused is the behavior where overshooting an array index clips the index to the maximum or does nothing. In C/C++, if one does this, then usually for a small displacement some memory outside the given data gets modified, and for a grossly wrong index one receives a SEGMENTATION FAULT. Clipping the index makes the program safer, but on top of being counterintuitive it adds some small extra cost for fetching the index.
Thank you! It is confusing. It'd be cool to understand why exactly it is difficult to handle this "correctly" (by throwing an exception).
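For anyone curious, here is a small sketch of that clipping behavior (defaults as I understand them; exact behavior may vary across JAX versions):

import jax.numpy as jnp

x = jnp.array([10, 20, 30])

# Out-of-bounds read: the index is clamped to the last valid position.
print(x[99])  # 30, no exception

# Out-of-bounds update via .at[]: silently dropped by default.
print(x.at[99].set(-1))  # [10 20 30]

# You can ask for a different out-of-bounds policy explicitly.
print(x.at[99].get(mode="fill", fill_value=0))  # 0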
oh, quite impressive series with the perfect explanation
Thank you man! 🙏
Great video thanks, kindly complete the tutorial series on Flax as well.
Thank you for the amazing content.
Greetings from Spain
Gracias y saludos desde Belgrado! 🙏 Tengo muchos amigos en España.
JAX is Just After eXecution; it represents the paradigm of tracing and transforming (grad, vmap, jit, ...) after the first execution.
Hmmm, source?
@@TheAIEpiphany Sorry, I can't remember it now. But it's somewhere in the documentation or a JAX GitHub issue/discussion.
@@PhucLe-qs7nx Thanks in any case! One of the other comments mentioned it simply stands for Jit Autograd XLA. 😄 That sounds reasonable as well.
Great material and great efforts , excited to see FLAX and Haiku ,thank you
Thanks for the tutorial! Love it!
Glad to hear that ^^
Thank you for the tutorial!
By the way, according to their paper (Compiling machine learning programs via high-level tracing), JAX stands for just after execution 😃
Looking forward to more such great videos.
Can someone explain to me why at 20:50 jnp.sum() is required and why it returns [0, 2, 4]? I would assume it would return 0 + 2 + 4 = 6, like it's described in the comment and when using sum(), but it doesn't; it just returns a vector of the original size with all the elements doubled.
I made a mistake. It's going to return a vector (df/dx1, df/dx2, df/dx3) and not the sum.
f = x1^2 + x2^2 + x3^2 and grad takes derivatives independently for x1, x2 and x3 since they are all bundled into the first argument of the function f.
Hope that makes sense. You can always consult the docs and experiment yourself.
Good question. Took me a moment as well, but: the function gives the sum, whereas grad(function) gives the three gradients, one per parameter, since the output of grad is used for SGD parameter updates: w1, w2, w3 = w1 - lr*df/dx1, w2 - lr*df/dx2, w3 - lr*df/dx3.
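In code, roughly what happens at that timestamp (the exact input vector is my assumption, reconstructed from the printed [0, 2, 4]):

import jax
import jax.numpy as jnp

def f(x):
    # f(x) = x1^2 + x2^2 + x3^2; jnp.sum makes the output a scalar,
    # which is what jax.grad requires.
    return jnp.sum(x ** 2)

x = jnp.array([0., 1., 2.])
print(f(x))            # 5.0  (0 + 1 + 4)
print(jax.grad(f)(x))  # [0. 2. 4.]  i.e. df/dx_i = 2*x_i, one gradient per element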
Great stuff, looking forward to the next jax tutorials
Thanks!
Thanks a lot. Keep up the good work.
Am I wrong, or should the derivative at 20:15 be (x1*2, x2*2, x3*2)? I mean, you take the gradient with respect to a vector, so you should take the derivative with respect to each variable separately.
Of course what did I do? 😂
@@TheAIEpiphany you wrote x1*2+x2*2+x3*2. I replaced + with comma :)
@@adels1388 My bad 😅 Thanks for noticing, the printed result was correct...
Great work. Thank you, Aleksa. I learned a lot. Coming from R, I like the functional approach here. Would be interested to hear about your current opinion about jax, after knowing it better.
Great job. Very nice tutorial
At around 1:10:34, you used static_argnums=(0,) for jit. Wouldn't this slow the program down a lot, since it will have to retrace for every new value of x?
Code to reproduce:
def switch(x):
    print("traced")
    if x > 10.:
        return 1.
    else:
        return 0.

jit_switch = jit(switch, static_argnums=(0,))

x = 5
jit_switch(x)
x = 16
jit_switch(x)

'''
Output:
traced
DeviceArray(0., dtype=float32, weak_type=True)
traced
DeviceArray(1., dtype=float32, weak_type=True)
'''
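For what it's worth, one way to avoid both the tracer error and the per-value retrace is to express the branch with array ops instead of Python control flow (just a sketch, not what the video does):

import jax
import jax.numpy as jnp

@jax.jit
def switch(x):
    # jnp.where works on traced values, so no static_argnums and no retracing per value
    return jnp.where(x > 10., 1., 0.)

print(switch(5.))   # 0.0
print(switch(16.))  # 1.0 - same compiled function, traced only once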
Thank you for the great content!
Thanks!
Saw your video retweeted by someone, watched it and subbed, because your content is great :) How often will you be uploading the following Jax vids?
Thanks! Next one tomorrow or Thursday.
Would it be better to use Julia and not have to worry about the gotchas? And still get the performance.
ty ty ty ty ty for this video
Very Cool, Thanks
great video!
Finally the Jax !!!!!!!
Function pureness rule #2 means one cannot use closure variables (variables from a wrapping function)? That's good to know, since JAX states that it is functional but doesn't cover closure use - because jit caching only considers explicitly passed function parameters.
Closure variables are hacky, but they are valid Python code. Just not in JAX.
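A tiny sketch of why closed-over values are risky under jit - they get captured as constants at trace time (the names here are hypothetical):

import jax
import jax.numpy as jnp

scale = 2.0

@jax.jit
def times_scale(x):
    return x * scale  # closes over `scale`

print(times_scale(jnp.array(3.0)))  # 6.0 - `scale` captured during the first trace

scale = 10.0
print(times_scale(jnp.array(3.0)))  # still 6.0 - the cached trace ignores the rebinding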
Hi ! Thank you for your video.
Isn't that very similar to Numba?
It's a really good tutorial!! thx :)
Great job! Keep it up!😀
Hi Aleksa, first off, thank you very much for sharing great content. I learn a lot from you. Could you please explain some upsides of JAX over other frameworks? I really need motivation to get started with JAX. Thank you. Cheers :)
Thank you so much !
What font are you using for the Colab notebook?
Thank you so much ..
Thanks for the video. I have a question: how do I run TensorFlow/JAX in the browser? (Not in an online notebook)
Hi, are there any resources on how to freeze certain layers of the network for transfer learning?
jax.lax.stop_gradient
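A minimal sketch of that idea - freezing one "layer" by cutting its gradient with stop_gradient (the two-matrix params structure is made up for illustration):

import jax
import jax.numpy as jnp

def loss(params, x):
    w1 = jax.lax.stop_gradient(params["w1"])  # frozen layer: no gradient flows here
    w2 = params["w2"]                         # trainable layer
    h = jnp.tanh(x @ w1)
    return jnp.mean((h @ w2) ** 2)

params = {"w1": jnp.ones((3, 4)), "w2": jnp.ones((4, 1))}
grads = jax.grad(loss)(params, jnp.ones((2, 3)))
print(grads["w1"])  # all zeros - frozen
print(grads["w2"])  # non-zero - still trains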
Can we say that if we made all arguments static, then it would be as good as normal code without JAX?
Thank you for these videos btw
amazing content
Thank you!
Is it possible to use JAX with Python statsmodels?
bru youre great
Why not julia lang?
Why would this video imply that you shouldn't give Julia a shot?
I may make a video on Julia in the future.
I personally wanted to learn JAX since I'll be using it in DeepMind.
great content, horrible font ;)
Hahaha thank you and thank you!
Noice
Can you make a video on how to install JAX in Anaconda, plain Python, and other Python IDEs?
at th-cam.com/video/SstuvS-tVc0/w-d-xo.html, what I noticed was that when I tried, print(grad(f_jit)(2.)), even with the static_argnums.
I also open-sourced an accompanying repo here: github.com/gordicaleksa/get-started-with-JAX
I recommend opening the notebook in parallel while watching the video so that you can play and tweak the code as well. Just open the notebook, click the Colab button on top of the file, and voila! You'll avoid having to set up the Python env and everything will just work! (you can potentially choose a GPU as an accelerator).
JAX = Just Autograd and Xla
one hour of nothing lol
49:00 jnp.reshape(x, (np.prod(x.shape),)) works.
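A quick check of that (plus jnp.ravel, which I believe does the same thing):

import numpy as np
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
flat = jnp.reshape(x, (np.prod(x.shape),))  # shapes are static under jit, so this is fine
print(flat.shape)          # (6,)
print(jnp.ravel(x).shape)  # (6,) as well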