Efficient and Modular Implicit Differentiation (Machine Learning Research Paper Explained)

Comments • 47

  • @YannicKilcher
    @YannicKilcher  3 years ago +5

    OUTLINE:
    0:00 - Intro & Overview
    2:05 - Automatic Differentiation of Inner Optimizations
    4:30 - Example: Meta-Learning
    7:45 - Unrolling Optimization
    13:00 - Unified Framework Overview & Pseudocode
    21:10 - Implicit Function Theorem
    25:45 - More Technicalities
    28:45 - Experiments
    ERRATA:
    - Dataset Distillation is done with respect to the training set, not the validation or test set.

  • @ChaiTimeDataScience
    @ChaiTimeDataScience 3 years ago +43

    I've been loving the new speed at which Dr. Kilcher is putting all of these videos out! So much to learn!

  • @paxdriver
    @paxdriver 3 years ago +15

    Dr lightspeed, you are a hero

  • @kimchi_taco
    @kimchi_taco 3 years ago +1

    Thank you for introducing this interesting paper!
    18:00 the equation of ridge_solver is the same as the optimality condition F=0. I'm a little bit confused...
    23:45 instead of the optimization procedure having to be differentiable, only the optimality condition now needs to be differentiable. Best summary ever.

  • @tchlux
    @tchlux 3 years ago

    In your expression at 8:22 I think you've got the transposes reversed. Notice that in the code, the terms all have `_tr` at the end. It should be $(X X^t + \lambda I) w = X y^t$. The way it's written in the video, it looks like you're solving for the weights that should be applied to every *point*, but instead we want weights applied to each *component* (i.e., a linear function).
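    A quick shape check of the version in this comment, assuming X is stored with one column per data point (p features x N points, as the equation above implies); the variable names below are just for illustration:

    import jax.numpy as jnp

    p, N = 5, 100
    X = jnp.ones((p, N))             # features x points
    y = jnp.ones((1, N))             # targets as a row vector
    lam = 0.1
    A = X @ X.T + lam * jnp.eye(p)   # (p, p)
    b = X @ y.T                      # (p, 1)
    w = jnp.linalg.solve(A, b)       # (p, 1): one weight per *component*, not per point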

  • @herp_derpingson
    @herp_derpingson 3 years ago +4

    20:00 What does the custom_root decorator do? What does it return?
    .
    20:15 I don't understand, are we using the ridge solver as an approximation for the inner SGD? Why?
    .
    I don't understand what is going on in this paper, but if it works, it's gonna be big.

    • @mgostIH
      @mgostIH 3 years ago +2

      In general, decorators in Python are a way to apply properties to a function by redefining it internally; for example, you could have a decorator that adds a timer to the function you apply it to, so you can benchmark things easily.
      In this case, the custom_root decorator takes in the differentiable optimality condition this paper talks about, i.e. a function F that is differentiable and is zero when we have found the right solution. When applied to a function like the ridge_solver, it redefines its gradient with respect to lambda using this paper's method, rather than JAX's built-in autodiff.
      The ridge solver is just an example and doesn't have anything to do with SGD; this method, however, provides you with the gradient needed to find the optimal value of lambda.
      Essentially, the ridge regression function solves the problem of finding the w minimizing this loss: ||w^T * X - y|| + λ ||w||, but we'd also like to find the value of λ that minimizes this loss even further! This entire paper solves that problem **WITHOUT** us having to look inside the details of the ridge regression solver: internally we could have used whatever crazy method, as long as it satisfies that minimization task, and by specifying a differentiable F we get the gradient with respect to λ for free, which allows us to minimize the loss even further.
      A way to imagine this, as Yannic mentions, is to think of λ as a hyperparameter we have to search over, but instead of doing a black-box search we actually get gradients efficiently. Imagine if instead of minimizing ||w^T * X - y|| + λ ||w||, you were to minimize ||NN(x) - y|| + λ ||w||, where NN is a neural network (again regression, but with a regularization term dependent on lambda). Then the solver would be all the gradient steps we take to train our model for a fixed λ, **which might internally be so complicated that backpropagating through it w.r.t. λ is simply impossible**. But with this method you only care that you have a routine, which internally may use SGD, that optimizes the weights of the network you defined, and you still get the gradient with respect to the hyperparameters by defining the optimization as a root of another differentiable function F!
      Even nicer, this method allows you to directly solve optimization problems by just stating the optimality conditions, which previously required entire papers (like OptNet) to derive.
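      A minimal sketch of what this could look like with jaxopt's custom_root (the data, shapes, and closed-form inner solver here are illustrative assumptions, not the paper's exact code):

      import jax
      import jax.numpy as jnp
      from jaxopt.implicit_diff import custom_root

      X_tr = jnp.ones((20, 5))   # hypothetical training data
      y_tr = jnp.ones(20)
      X_val = jnp.ones((10, 5))  # hypothetical validation data
      y_val = jnp.ones(10)

      def ridge_objective(w, lam):
          # inner loss: 0.5 ||X_tr w - y_tr||^2 + 0.5 lam ||w||^2
          residual = X_tr @ w - y_tr
          return 0.5 * jnp.sum(residual ** 2) + 0.5 * lam * jnp.sum(w ** 2)

      # Optimality condition: F(w, lam) = grad_w(inner loss) = 0 at the inner optimum
      F = jax.grad(ridge_objective, argnums=0)

      @custom_root(F)
      def ridge_solver(init_w, lam):
          del init_w  # unused: the inner problem is solved in closed form here
          A = X_tr.T @ X_tr + lam * jnp.eye(X_tr.shape[1])
          return jnp.linalg.solve(A, X_tr.T @ y_tr)

      def outer_loss(lam):
          # outer objective: validation loss of the inner solution
          w = ridge_solver(None, lam)
          return jnp.mean((X_val @ w - y_val) ** 2)

      print(jax.grad(outer_loss)(0.5))  # d(validation loss)/d(lam), obtained via implicit diff

      The point is the same as above: swapping the closed-form solver for any other routine that satisfies F = 0 would leave the gradient computation w.r.t. lam unchanged.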

  • @theoboyer3812
    @theoboyer3812 3 years ago +17

    Now I want to learn jax

  • @victorrielly4588
    @victorrielly4588 3 years ago +5

    Recurrent neural networks also struggled with the problem of computing gradients by unrolling the recurrence. I wonder if this technique could be applied in that instance as well.

    • @chuby35
      @chuby35 3 years ago +2

      hmm. Interesting idea, but what would be the optimality condition for the inner loop of an RNN layer?

    • @victorrielly4588
      @victorrielly4588 3 years ago +1

      @@chuby35 I’d have to think about it

  • @MrMIB983
    @MrMIB983 3 years ago +3

    Nice to give credit to an underrated area

  • @theoboyer3812
    @theoboyer3812 3 years ago +1

    This is really cool, but if I understood it correctly, the applications are still quite limited. The problem is that you need to compute a d x d Jacobian matrix (and then solve a linear system involving this matrix), d being the dimension of the output of your inner optimization algorithm and of the input of the optimality condition function.
    So, for any application involving neural networks, for example, unless I'm wrong your d would be the number of parameters of your neural network. Before even talking about solving the linear system, you would need to store a matrix of size "the number of parameters of the neural network SQUARED".
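    A toy sketch of the cost described above: differentiating through the optimality condition with the implicit function theorem involves the d x d Jacobian of F with respect to the inner solution (the condition F, the dimension d, and the numbers here are made up for illustration):

    import jax
    import jax.numpy as jnp

    d = 1000  # dimension of the inner solution

    def F(w, theta):
        # toy optimality condition with root w*(theta) = 1 / (1 + theta)
        return w * (1.0 + theta) - jnp.ones(d)

    theta = 0.1
    w_star = jnp.ones(d) / (1.0 + theta)

    A = jax.jacobian(F, argnums=0)(w_star, theta)  # d x d matrix: the memory cost in question
    B = jax.jacobian(F, argnums=1)(w_star, theta)  # d-vector
    dw_dtheta = -jnp.linalg.solve(A, B)            # implicit function theorem: dw*/dtheta = -A^{-1} B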

  • @joedalton77
    @joedalton77 3 years ago +4

    I have the feeling this one is going to be big

  • @paulcurry8383
    @paulcurry8383 3 years ago +1

    I'm a bit confused by the toy example: couldn't you differentiate ||w^T X - y|| + theta ||w|| as your loss function, treating theta as another weight?
    Does this not work because the norm is nonlinear or something?

    • @YannicKilcher
      @YannicKilcher  3 years ago +1

      The two losses are with respect to different datasets. The outer optimization is over the validation set, conditional on having solved the inner problem to completion.
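      In symbols, the bilevel structure being described is roughly: the inner problem gives $w^*(\theta) = \arg\min_w L_{train}(w, \theta)$, and the outer problem is $\min_\theta L_{val}(w^*(\theta))$, so the gradient we need is $\frac{d}{d\theta} L_{val}(w^*(\theta))$, which requires $\frac{\partial w^*}{\partial \theta}$.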

    • @paulcurry8383
      @paulcurry8383 3 years ago

      @@YannicKilcher Ah, so now you'd want 3 datasets so you can validate your hyperparameter training

  • @piotr780
    @piotr780 3 years ago

    Is it now implemented in the JAX library?

  • @JTMoustache
    @JTMoustache 3 years ago

    Powerful stuff 💪🏼

  • @MaeglinLiao
    @MaeglinLiao 3 years ago +6

    Does this framework also apply to training GANs? Or is it a tri-level optimization problem if hyperparameter optimization is involved 🤣

    • @YannicKilcher
      @YannicKilcher  3 years ago +9

      Yes it does. And yes, this actually supports any depth of inner loops 😁

    • @MaeglinLiao
      @MaeglinLiao 3 years ago +4

      @@YannicKilcher it’s so cool to actually optimize the theoretical max-min problem

  • @nx6803
    @nx6803 3 years ago +1

    GradSlam uses unrolling, perhaps it can utilize this!

  • @Phenix66
    @Phenix66 3 years ago +12

    Dude, I wanna go to sleep... Damn it :D

  • @drdca8263
    @drdca8263 3 years ago +4

    This sounds like it could be a big deal. Does this primarily make multi-level optimization things easier to code, or does it also make these things notably faster when running?
    (I guess rather than "notably faster", what I really mean is a better big-O time, compared to how they would usually have been implemented previously.)
    It still has to compute the inner optimizations, I suppose.
    Does this make some computational tasks that were previously infeasible, now feasible?
    (Feasible in the sense of "we know how to write a program within a reasonable amount of time and effort, which will run within a reasonable amount of time, and produce the answer".)
    Not that if the answer is no this wouldn't still be important;
    just trying to tell whether, if I understood it better, it would seem *very* important, or just not quite that important but still quite cool.

    • @aspergale9836
      @aspergale9836 3 years ago +1

      Doing 2-level optim through, say, SGD with enough steps has been practically impossible for methods that do naive auto-diff then GD. This makes it possible.

    • @drdca8263
      @drdca8263 3 years ago

      @@aspergale9836 Thanks!

  • @chuby35
    @chuby35 3 years ago +1

    I'd love to see the optimal dataset for ffhq with some classifier, but I don't want to learn jax just for that. :) I hope someone will create that just for the laughs. :)

  • @XOPOIIIO
    @XOPOIIIO 3 years ago +1

    But why wasn't autodiff used before?

  • @mahirahimi97
    @mahirahimi97 3 years ago +2

    Now we can have meta-meta-...-meta learning

  • @MadlipzMarathi
    @MadlipzMarathi 3 years ago +1

    4:52 ok so, no one noticed it.

  • @scottmiller2591
    @scottmiller2591 3 years ago +3

    I need to define an alias for jax.jacobian so I can just write jax.ocbian.

  • @scottmiller2591
    @scottmiller2591 3 years ago +1

    If X is N x p, where N is the number of data points and p is the dimension of the features in X, and similarly y is N x 1, then you want X w = y, not w' X = y, so that both sides have N rows. Also, it's the L2 norms SQUARED, not just the L2 norms, at least for Tikhonov ridge regression.
    This method seems interesting; I need to look at the proximal gradient stuff.
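    For reference, with the shapes above ($X \in \mathbb{R}^{N \times p}$, $y \in \mathbb{R}^{N}$, $w \in \mathbb{R}^{p}$), the Tikhonov/ridge objective and its solution are $\min_w \|X w - y\|_2^2 + \lambda \|w\|_2^2$, solved by the normal equations $(X^\top X + \lambda I) w = X^\top y$.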

  • @al8-.W
    @al8-.W 3 years ago

    Are we getting closer to the free lunch theorem?

  • @barberb
    @barberb 3 years ago +19

    This is just another example of plagiarizing Schmidhuber

    • @nilsrethmeier8280
      @nilsrethmeier8280 3 years ago +1

      Could you point out the JS paper(s) pls? Much appreciated.

    • @barberb
      @barberb 3 years ago

      For those who don't understand, this is a joke.

    • @scottmiller2591
      @scottmiller2591 3 years ago +2

      frey_squinting.jpg - Not sure if sarcasm or historically accurate, but odds are historically accurate.

    • @G12GilbertProduction
      @G12GilbertProduction 3 years ago

      It hurts my grad(init) headbutt too.

  • @sieyk
    @sieyk 3 years ago

    Doesn't this mean that you can now feasibly backpropagate through spiking neural networks?

    • @YannicKilcher
      @YannicKilcher  3 years ago

      I'm not sure, is the inner loop run to optimum?

    • @sieyk
      @sieyk 3 years ago

      @@YannicKilcher Ah, I did not realise this attempted to optimise two separate problems simultaneously. I was thinking that for the spiking network you could just solve the dictionary example, where the dictionary elements are the connected node pairs. Perhaps this has some good applications in reinforcement learning!

  • @ankitaharwal5886
    @ankitaharwal5886 3 years ago

    In the implementation we pass theta=10.0, so it's like passing a weight initialization in normal deep learning, and we get a new, optimized theta at the end.
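    A tiny sketch of that usage pattern, assuming some outer_loss(theta) whose gradient comes from implicit differentiation (the loss, step size, and iteration count here are made up):

    import jax

    def outer_loss(theta):
        # hypothetical outer objective, e.g. validation loss of the inner solution
        return (theta - 1.0) ** 2

    theta = 10.0  # initial hyperparameter value, analogous to a weight init
    for _ in range(100):
        theta = theta - 0.1 * jax.grad(outer_loss)(theta)  # plain gradient descent on theta
    # theta has now been optimized by the outer loop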

  • @scottmiller2591
    @scottmiller2591 3 years ago

    "I'm not saying inner and outer loop, but 'Inner and Outer Loop'"

  • @Pmaisterify
    @Pmaisterify 3 years ago +1

    First

  • @dhruvpatel4948
    @dhruvpatel4948 3 years ago +1

    Can we officially name this a Bayesian optimization killer?