Machine Learning with JAX - From Zero to Hero | Tutorial #1

  • Published Nov 10, 2024

Comments • 91

  • @billykotsos4642
    @billykotsos4642 3 years ago +20

    This channel is going on the GOAT level status

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago

      Hahaha thank you! Step by step

  • @matthewkhoo1153
    @matthewkhoo1153 3 years ago +32

    When I started learning JAX, I personally thought it stands for JIT (J), Autograd (A), XLA (X), which is essentially an abbreviation for a bunch of abbreviations. Given that those features are the 'highlights' of JAX, it's very possible. If that's the case, pretty cool naming from DeepMind. Anyways, there aren't many comprehensive resources for JAX right now, so I'm really looking forward to this series! Cheers Aleksa.

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago +7

      Wow great point I totally overlooked it. 😅 That looks like a better hypothesis than the one I had. If you google "Jax meaning" it seems it's a legit name and means something like "God has been gracious; has shown favor". 😂

    • @matthewkhoo1153
      @matthewkhoo1153 3 years ago +2

      @@TheAIEpiphany Probably an alternative in case the name 'Jack' is too boring lmao. Had a similar experience, first time I googled "jax tutorial" it was a guide for a game character haha.

    • @DanielSuo
      @DanielSuo 1 year ago

      Believe it stands for "Just Another XLA" compiler.

  • @not_a_human_being
    @not_a_human_being 1 year ago +1

    50:40 - I would argue that it's not necessary to pass all the parameters into the function: as long as it doesn't change any of them, it's fine to use external globals, e.g. for reference tables. This definition (though academically thorough) makes practical application a bit more cumbersome. I believe rule "2." alone is sufficient to make this work - no need to pass a long list of params. Just make sure not to update/change anything external inside the function, and treat whatever is not passed in as static. Alternatively, you can re-create your jitted function every time you anticipate that your globals might have changed, so it is effectively re-traced with the new globals(). In some cases that feels much preferable to passing everything in. For instance, you can use all sorts of globals inside it, then just re-create it right before your training loop.
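
The re-jit workaround described above can be sketched like this (a minimal, hypothetical example - `SCALE` and `make_scaled` are illustrative names, not from the video):

```python
import jax
import jax.numpy as jnp

SCALE = 2.0  # a global "reference table"-style value the function reads but never mutates

def make_scaled():
    # Re-creating the jitted function snapshots the current value of the global.
    @jax.jit
    def scaled(x):
        return x * SCALE  # SCALE gets baked into the trace as a constant
    return scaled

f = make_scaled()
print(f(jnp.array(3.0)))   # 6.0

SCALE = 10.0
print(f(jnp.array(3.0)))   # still 6.0: the cached trace kept the old SCALE

f = make_scaled()          # re-create just before the training loop
print(f(jnp.array(3.0)))   # 30.0
```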

  • @mariuskombou6729
    @mariuskombou6729 7 months ago +3

    Thanks for this video!! That was really interesting for a new user of JAX like me.

  • @abbashaider7165
    @abbashaider7165 3 months ago +1

    That is a great hands-on tutorial which has a perfect mix of theory and practical usage. Thanks

  • @akashraut3581
    @akashraut3581 3 years ago +13

    Finally some jax tutorial.. Keep them coming

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago +2

      Yup it was about time I started learning it. Will do sir

  • @user-wr4yl7tx3w
    @user-wr4yl7tx3w 2 years ago +5

    this video is such a great service to the community. really great examples to help better understand Jax at a nuanced level.

  • @johanngerberding5956
    @johanngerberding5956 3 years ago +4

    congrats to your deepmind job man (read your post), nice channel, keep going!

  • @mohammedelfatihsalahmohame7288
    @mohammedelfatihsalahmohame7288 3 years ago +1

    I am glad that I found this channel

  • @carlosrondonmoreno9796
    @carlosrondonmoreno9796 3 months ago

    Wonderful tutorial! Really really good examples. Thank you so much!

  • @yulanliu3839
    @yulanliu3839 2 years ago +2

    JAX = Just After eXecution (related to the tracing behaviour)
    JIT = Just In Time (related to the compilation behaviour)

    • @TheAIEpiphany
      @TheAIEpiphany  2 years ago

      Thank you! A couple more people also pointed it out. Makes sense

  • @mikesmith853
    @mikesmith853 3 years ago +3

    Awesome work!! JAX is a fantastic library. This series is the reason I finally subscribed to your channel. Thanks for your work!

  • @vijayanand7270
    @vijayanand7270 3 years ago +5

    Great video!🔥 Would need Paper implementations too💯

  • @bionhoward3159
    @bionhoward3159 3 years ago +2

    i love jax ... thank you for your work!

  • @TheParkitny
    @TheParkitny 3 years ago +2

    great video and content, this channel needs more recognition.

  • @DullPigeon2750
    @DullPigeon2750 1 year ago

    Wonderful explanation about vmap function

  • @sacramentofwilderness6656
    @sacramentofwilderness6656 3 years ago +5

    Thanks for the great tutorial, pointing out the strong and weak points of the JAX framework, with caveats and salient features. What makes me somewhat confused is the behavior where overshooting the index of an array clips the index to the maximum or does nothing. In C/C++, if the displacement is small, doing this usually modifies some memory outside the given data, and for a large index mistake one would get a SEGMENTATION FAULT. Clipping the index makes the program safer, but in addition to being counterintuitive it adds some small extra cost for fetching the index.
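
The clipping behavior being discussed is easy to reproduce (a small sketch of JAX's default out-of-bounds indexing semantics):

```python
import jax.numpy as jnp

x = jnp.array([10, 20, 30])

# Out-of-bounds *reads* clamp the index into the valid range
# (no exception and no segfault, unlike NumPy or C).
print(x[99])               # 30 -- clamped to the last element

# Out-of-bounds *updates* are silently dropped by default.
print(x.at[99].set(-1))    # [10 20 30] -- the write is ignored

# The behavior is configurable per operation:
print(x.at[99].get(mode="fill", fill_value=-1))  # -1
```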

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago

      Thank you! It is confusing. It'd be cool to understand why exactly it is difficult to handle this "correctly" (i.e. by throwing an exception).

  • @alexanderchernyavskiy9538
    @alexanderchernyavskiy9538 3 years ago +1

    oh, quite impressive series with the perfect explanation

  • @jawadmansoor6064
    @jawadmansoor6064 3 years ago +1

    Great video thanks, kindly complete the tutorial series on Flax as well.

  • @sallanmega1
    @sallanmega1 3 years ago +2

    Thank you for the amazing content.
    Greetings from Spain

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago

      Thank you, and greetings from Belgrade! 🙏 I have many friends in Spain.

  • @PhucLe-qs7nx
    @PhucLe-qs7nx 3 years ago +2

    JAX is "Just After eXecution", representing the paradigm of tracing and transforming (grad, vmap, jit, ...) after the first execution.

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago

      Hmmm, source?

    • @PhucLe-qs7nx
      @PhucLe-qs7nx 3 years ago +1

      @@TheAIEpiphany Sorry, I can't remember it now. But it's somewhere in the documentation or a JAX GitHub issue/discussion.

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago

      @@PhucLe-qs7nx Thanks in any case! One of the other comments mentioned it simply stands for Jit Autograd XLA. 😄 That sounds reasonable as well.

  • @sarahel-sherif3318
    @sarahel-sherif3318 3 years ago +2

    Great material and great effort, excited to see FLAX and Haiku, thank you

  • @kaneyxx
    @kaneyxx 3 years ago +2

    Thanks for the tutorial! Love it!

  • @RamithHettiarachchi
    @RamithHettiarachchi 2 years ago +1

    Thank you for the tutorial!
    By the way, according to their paper (Compiling machine learning programs via high-level tracing), JAX stands for just after execution 😃

  • @deoabhijit5935
    @deoabhijit5935 3 years ago

    looking forward to such great videos

  • @kenbobcorn
    @kenbobcorn 3 years ago +6

    Can someone explain to me why at 20:50 jnp.sum() is required and why it returns [0, 2, 4]? I would assume it would return 0 + 2 + 4 = 6, as described in the comment and when using sum(), but it doesn't; it just returns a vector of the original size with all the elements scaled.

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago +3

      I made a mistake. It's going to return a vector (df/dx1, df/dx2, df/dx3) and not the sum.
      f = x1^2 + x2^2 + x3^2 and grad takes derivatives independently for x1, x2 and x3 since they are all bundled into the first argument of the function f.
      Hope that makes sense. You can always consult the docs and experiment yourself.

    • @Gannicus99
      @Gannicus99 3 years ago +1

      Good question. Took me a moment as well, but: the function gives the sum, whereas grad(function) gives the three gradients, one per parameter, since the output of grad is used for SGD parameter updates: w1 = w1 - lr*df/dw1, w2 = w2 - lr*df/dw2, w3 = w3 - lr*df/dw3.
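
The exchange above can be checked directly (a minimal sketch of the 20:50 example):

```python
import jax
import jax.numpy as jnp

# f maps a vector to a scalar: f(x) = x1^2 + x2^2 + x3^2
def f(x):
    return jnp.sum(x ** 2)

x = jnp.array([0., 1., 2.])

print(f(x))            # 5.0 -- the scalar value of f
print(jax.grad(f)(x))  # [0. 2. 4.] -- one partial derivative 2*x_i per component
```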

  • @Khushpich
    @Khushpich 3 years ago +1

    Great stuff, looking forward to the next jax tutorials

  • @adels1388
    @adels1388 3 years ago +6

    Thanks a lot. Keep up the good work.
    Am I wrong, or should the derivative at 20:15 be (x1*2, x2*2, x3*2)? I mean, you take the gradient with respect to a vector, so you should take the derivative with respect to each variable separately.

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago +1

      Of course what did I do? 😂

    • @adels1388
      @adels1388 3 years ago +1

      @@TheAIEpiphany you wrote x1*2+x2*2+x3*2. I replaced + with comma :)

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago +1

      @@adels1388 My bad 😅 Thanks for noticing, the printed result was correct...

  • @maikkschischo1790
    @maikkschischo1790 1 year ago

    Great work. Thank you, Aleksa. I learned a lot. Coming from R, I like the functional approach here. I'd be interested to hear your current opinion of JAX, now that you know it better.

  • @broccoli322
    @broccoli322 1 year ago

    Great job. Very nice tutorial

  • @adityakane5669
    @adityakane5669 2 years ago +1

    At around 1:10:34, you used static_argnums=(0,) for jit. Wouldn't this slow the program down considerably, since it will have to retrace for every new value of x?
    Code to reproduce:

    def switch(x):
        print("traced")
        if x > 10.:
            return 1.
        else:
            return 0.

    jit_switch = jit(switch, static_argnums=(0,))
    x = 5
    jit_switch(x)
    x = 16
    jit_switch(x)

    '''
    Output:
    traced
    DeviceArray(0., dtype=float32, weak_type=True)
    traced
    DeviceArray(1., dtype=float32, weak_type=True)
    '''
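
One common way around the retracing concern above (an illustrative sketch, not from the video) is to keep the branch inside the traced computation, e.g. with jnp.where, so no argument has to be static:

```python
import jax
import jax.numpy as jnp

# jnp.where is traceable, so a single compiled version handles every x.
@jax.jit
def switch(x):
    return jnp.where(x > 10., 1., 0.)

print(switch(5.))   # 0.0
print(switch(16.))  # 1.0 -- same compiled function, no retrace
```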

  • @yagneshm.bhadiyadra4359
    @yagneshm.bhadiyadra4359 2 years ago +1

    Thank you for the great content!

  • @MikeOxmol_
    @MikeOxmol_ 3 years ago +2

    Saw your video retweeted by someone, watched it and subbed, because your content is great :) How often will you be uploading the following Jax vids?

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago +1

      Thanks! Next one tomorrow or Thursday.

  • @user-wr4yl7tx3w
    @user-wr4yl7tx3w 2 years ago +1

    Would it be better to use Julia and not have to worry about the gotchas, and still get the performance?

  • @promethesured
    @promethesured 6 months ago +1

    ty ty ty ty ty for this video

  • @alinajafistudent6906
    @alinajafistudent6906 1 year ago

    Very Cool, Thanks

  • @michaellaskin3407
    @michaellaskin3407 2 years ago +1

    great video!

  • @mrlovalovaa
    @mrlovalovaa 3 years ago +1

    Finally the Jax !!!!!!!

  • @Gannicus99
    @Gannicus99 3 years ago

    Function purity rule #2 means one cannot use closure variables (variables from a wrapping function)? That's good to know, since JAX states that it is functional but doesn't cover closure use - because jit caching only considers explicitly passed function parameters.
    Closure variables are hacky, but they are valid Python code. Just not in JAX.
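
One conventional workaround for the closure caveat above (an illustrative sketch): promote the closure variable to an explicit argument, so jit tracks it as a real input:

```python
import jax
import jax.numpy as jnp

# `factor` is passed in rather than closed over, so changing it
# never leaves a stale constant baked into a cached trace.
@jax.jit
def mul(x, factor):
    return x * factor

print(mul(jnp.array(2.0), 3.0))   # 6.0
print(mul(jnp.array(2.0), 10.0))  # 20.0 -- no manual re-jit needed
```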

  • @zumpitu
    @zumpitu 1 year ago

    Hi! Thank you for your video.
    Isn't that very similar to Numba?

  • @YoungbinLee-z4p
    @YoungbinLee-z4p 2 years ago

    It's a really good tutorial!! thx :)

  • @shyamalchandraofficial
    @shyamalchandraofficial 2 years ago

    Great job! Keep it up!😀

  • @1potdish271
    @1potdish271 2 years ago +1

    Hi Aleksa, first off, thank you very much for sharing great content. I learn a lot from you. Could you please explain some upsides of JAX over other frameworks? I really need motivation to get started with JAX. Thanking you. Cheers :)

  • @gim8377
    @gim8377 8 months ago

    Thank you so much !

  • @peterkanini867
    @peterkanini867 1 year ago

    What font are you using for the Colab notebook?

  • @vaishnav4035
    @vaishnav4035 1 year ago

    Thank you so much ..

  • @nicecomment5411
    @nicecomment5411 2 years ago

    Thanks for the video. I have a question: how do I run TensorFlow/JAX in the browser? (Not in an online notebook)

  • @alvinhew1872
    @alvinhew1872 3 years ago +1

    Hi, are there any resources on how to freeze certain layers of the network for transfer learning?

  • @yagneshm.bhadiyadra4359
    @yagneshm.bhadiyadra4359 2 years ago

    Can we say that if we make all arguments static, then it will be as good as normal code without JAX?
    Thank you for these videos btw

  • @teetanrobotics5363
    @teetanrobotics5363 3 years ago +1

    amazing content

  • @tshegofatsotshego375
    @tshegofatsotshego375 2 years ago

    Is it possible to use JAX with Python statsmodels?

  • @arshsharma8627
    @arshsharma8627 5 months ago

    bru youre great

  • @L4rsTrysToMakeTut
    @L4rsTrysToMakeTut 3 years ago +1

    Why not Julia lang?

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago +2

      Why would this video imply that you shouldn't give Julia a shot?
      I may make a video on Julia in the future.
      I personally wanted to learn JAX since I'll be using it in DeepMind.

  • @adrianstaniec
    @adrianstaniec 3 years ago +4

    great content, horrible font ;)

    • @TheAIEpiphany
      @TheAIEpiphany  3 years ago

      Hahaha thank you and thank you!

  • @TheMazyProduction
    @TheMazyProduction 3 years ago +1

    Noice

  • @alokraj7120
    @alokraj7120 2 years ago

    Can you make a video on how to install JAX with Anaconda, plain Python, and other Python IDEs?

  • @user-wr4yl7tx3w
    @user-wr4yl7tx3w 2 years ago

    at th-cam.com/video/SstuvS-tVc0/w-d-xo.html, what I noticed was that when I tried, print(grad(f_jit)(2.)), even with the static_argnums.

  • @TheAIEpiphany
    @TheAIEpiphany  3 years ago +7

    I also open-sourced an accompanying repo here: github.com/gordicaleksa/get-started-with-JAX
    I recommend opening the notebook in parallel while watching the video so that you can play and tweak the code as well. Just open the notebook, click the Colab button on top of the file, and voila! You'll avoid having to set up the Python env and everything will just work! (you can potentially choose a GPU as an accelerator).

  • @kirtipandya4618
    @kirtipandya4618 3 years ago

    JAX = Just Autograd and Xla

  • @samueltrif5472
    @samueltrif5472 1 year ago

    one hour of nothing lol

  • @heyman620
    @heyman620 2 years ago

    49:00 jnp.reshape(x, (np.prod(x.shape),)) works.
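
For reference, the flattening one-liner above has a couple of equivalent, slightly shorter spellings:

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)

print(jnp.reshape(x, (np.prod(x.shape),)))  # [0 1 2 3 4 5] -- the version from the comment
print(jnp.reshape(x, (-1,)))                # [0 1 2 3 4 5] -- -1 infers the flat length
print(x.ravel())                            # [0 1 2 3 4 5]
```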