This guy is so epic. He looks like he's enjoying every second of life.
NumPy on steroids
This video maximizes dInsights/dtime, is well written and easy to understand! I want to see more videos from Jake!
JAke X
EduTube needs a like button for specifically this metric 🤜🤛
It maximizes Insights/time, not the derivative
@@nrrgrdn yeah, maybe that's better, but I think he means you gain continuously more insights as you advance through the video
I burst out laughing at the ExpressoMaker that overloads the + operator.
Thank you for this good intro to JAX. Very easy to follow and understand, Jake. Definitely going to add this to my toolkit. 👍🙏
Looks great! I tend to default to NumPy when I want to do something that isn't fully supported in Keras or PyTorch, and if I can get parallelization on GPU very easily from this, that's perfect!
I have a question: what's the purpose of making so many frameworks? Time? Efficiency? Because I don't see it.
This sounds very good, especially the grad and vmap functionality. I think more libraries would have to be released for it to compete with PyTorch.
How are you going to compare PyTorch to TF/JAX when they run on different hardware? There is no way you can argue the two devices are comparable; they will be faster or slower at different types of computation regardless of the software used. You should have compared all three on a common GPU if for some reason PyTorch couldn't be run on the TPUv3.
This looks like the typical kind of problem that could be easily solved with a language that supports multi-stage programming, i.e., meta-programming as a first-class citizen, which is not really the case with Python. Like Rust, or Elixir via the Nx library, which is actually directly inspired by JAX.
Thanks! This helps me a lot! Being a C/C++/Python developer, somehow I had overlooked such an important framework/library.
JAX seems to be more similar to PyTorch, i.e., a dynamic graph instead of a static graph as in TensorFlow.
There's something called AutoGraph in TensorFlow actually
That's Flax. Jax is more like the backbone of that
Ok, this is seriously cool. Is this brand new? Haven't seen it before.
Also, in the first code sample did you mean to import vmap and pmap instead of map, or is that some kind of namespace black magic I don't understand?
It has been around for over 2 years now I believe
ya it's a typo, there's no magic
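For reference, a tiny sketch of what the corrected import looks like; the square function below is just a made-up illustration, not code from the video:

```python
import jax.numpy as jnp
from jax import grad, jit, vmap  # note: vmap, not `map` -- the slide had a typo

def square(x):
    return x * x

batched = jit(vmap(square))      # vectorize over a batch, then JIT-compile
print(batched(jnp.arange(4.0)))  # [0. 1. 4. 9.]
print(grad(square)(3.0))         # 6.0
```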
Awesome intro!
Great content ! BRAVO and THANKS !
nice talk that will be interesting
Why do you need a new lib? TensorFlow can do 90+% of this, can't it? Is it a good idea to make a completely new thing instead of extending the old one?
One more question: do/will you have Keras support?
What is the difference between numerical and automatic differentiation?
Numerical differentiation approximates f'(x) by evaluating the function around x: (f(x+h) - f(x-h)) / (2h) with a small h. Automatic differentiation represents the function's expression or code as a computational graph, i.e., it looks at the actual code of the function. The final derivative is obtained by propagating the values of the local derivatives of simple expressions through the graph via the chain rule. The simple expressions are functions like +, -, cos(x), exp(x), for which we know the derivatives at a given x.
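To make the difference concrete, here is a minimal JAX sketch; the function f and the step size h are arbitrary illustrative choices:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * jnp.exp(x)

# Numerical differentiation: central difference with a small step h.
def numerical_grad(f, x, h=1e-4):
    return (f(x + h) - f(x - h)) / (2 * h)

# Automatic differentiation: JAX traces the code of f and applies the
# chain rule to the known derivatives of sin, exp, and *.
auto_grad = jax.grad(f)

x = 1.5
print(numerical_grad(f, x))  # approximation, sensitive to the choice of h
print(auto_grad(x))          # exact up to floating-point precision
```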
Why not use the Julia language?
I thought JAX was running as the default in TensorFlow; am I missing something here?
Awesome video! Thank you!
2:14 Why, in the predict function, is inputs reassigned but never used? Shouldn't it be outputs = np.tanh(outputs)?
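For context, the predict function is presumably close to the standard JAX MLP example, sketched below with jnp for jax.numpy (not the exact code from the video). In that pattern the reassigned inputs is used: it feeds the next layer on the following loop iteration.

```python
import jax.numpy as jnp

def predict(params, inputs):
    # params: a list of (W, b) pairs, one per layer (illustrative structure)
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)  # consumed as the input of the next layer
    return outputs                  # last layer's pre-activation output
```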
Active: Jax enters Evasion, a defensive stance, for up to 2 seconds, causing all basic attacks against him to miss.
I knew this is going to come up lol
Seems like TensorFlow is fast enough?
Is this much better than SIMD?
This is wild!
Interesting Stuff
Does this support Apple's GPUs in the M1 Max?
I also wonder whether they utilize the neural processors.
Nice framework.
Really interesting!
it says "from jax import map", but it seems it should be vmap?
from jax import map as vmap
Something's wrong with the audio. His voice gets so soft it's hard to hear at the end of some sentences.
Seeing JAX on the TensorFlow channel, now I'm scared they'll mess up this codebase too. Please don't, k thx.
Amazing!
Imagine if it had a real weapon
Amazing👏
Google's Bard sent me here. Anyone know why?
Top Jax OP
Is it just me, or does the backend technology of JAX sound very similar to the one in TensorFlow?
I don't get it. Why do we need this if PyTorch and Keras/TF already exist?
I mean, it is kind of niche, but suppose you are solving a problem that heavily relies on many custom functions, e.g., a very specific algebra like quaternion operations. Then you can write super-fast basic operations and compose them to build a complicated loss function, which as a whole you can then jit-compile and let it get optimized. Or differentiate it, or vectorize it, all with a tiny decorator.
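A minimal sketch of that workflow in JAX; the quat_mul primitive and the toy loss below are made-up illustrations, not anything from the video:

```python
import jax
import jax.numpy as jnp

def quat_mul(q, r):
    # Hamilton product of two quaternions stored as (w, x, y, z).
    w1, x1, y1, z1 = q
    w2, x2, y2, z2 = r
    return jnp.array([
        w1*w2 - x1*x2 - y1*y2 - z1*z2,
        w1*x2 + x1*w2 + y1*z2 - z1*y2,
        w1*y2 - x1*z2 + y1*w2 + z1*x2,
        w1*z2 + x1*y2 - y1*x2 + z1*w2,
    ])

@jax.jit
def loss(q, target):
    # Toy loss built by composing the custom primitive, then taking the
    # squared error against a target quaternion.
    composed = quat_mul(quat_mul(q, target), q)
    return jnp.sum((composed - target) ** 2)

grad_loss = jax.grad(loss)                        # differentiate the whole composition
batched_loss = jax.vmap(loss, in_axes=(0, None))  # vectorize over a batch of q

q = jnp.array([1.0, 0.0, 0.0, 0.0])
target = jnp.array([0.0, 1.0, 0.0, 0.0])
print(loss(q, target), grad_loss(q, target))
print(batched_loss(jnp.stack([q, q]), target))
```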
Torch and Keras are "slow" and are only meant for the development phase. Not sure by how much JAX can outperform them.
Edit: "slow" as in computation/inference time
@@tclf90 So what frameworks are "fast"?
Wow! 🏋️
JAX came out because of torch
My Call Jax Son #AI ?
1:14 lol they compared TPU runtimes with GPU runtimes
😮😮😮😮 0:28
it's tiresome