MLP-Mixer in Flax and PyTorch

  • Published 21 Sep 2024

Comments • 18

  • @clee5653 · 3 years ago · +2

    Conv implementation is a nice touch! Comparing different implementations definitely helps understanding.

  • @yimingqu2403 · 3 years ago · +3

    Great tutorial. I really learned a lot.
    Two comments:
    - It has been suggested that we should avoid using `.data`, so when doing parameter copying,
    `module_linear.weight.data[:] = module_conv.weight.data.reshape(hidden_dim, -1)`
    can be substituted with
    `module_linear.weight = nn.Parameter(module_conv.weight.clone().reshape(hidden_dim, -1))`.
    Using `.data` in such a tutorial might confuse newcomers to PyTorch.
    - I found that when comparing the output from `Linear` and `Conv2d`, `torch.allclose` needs to be more tolerant when you raise the hyperparameters to larger values. For example, if you increase `patch_size` and `image_size`, `atol=1e-6` will not be sufficient; `atol` should then be set to a larger value, e.g. 1e-5. It took me a lot of time to debug this. And I'm not sure what's happening under the hood, but when you initialize `Conv2d` before `Linear`, `atol=1e-6` becomes okay again 😂

    • @mildlyoverfitted · 3 years ago · +1

      Wow! What a great comment! Always glad to see that some people dig into the technical details too!
      1) Great point, you can definitely create a new `Parameter` from scratch and clone the underlying memory storage:) AFAIK changes to `.data` are not tracked by autograd, which could be pretty bad if you want the assignment to be part of the computational graph. It should be fine if you are just using it to initialize/update, but point taken. Interesting links:
      * discuss.pytorch.org/t/how-to-assign-an-arbitrary-tensor-to-models-parameter/44082
      * discuss.pytorch.org/t/leaf-variable-was-used-in-an-inplace-operation/308/2
      2) Actually, I was wondering about the same thing and I eventually concluded that there must be 2 different algorithms in the background, and one cannot get rid of this floating point difference. Some resources I found:
      * discuss.pytorch.org/t/conv2d-and-linear-layer-for-1x1-image/87309/2
      * stackoverflow.com/questions/55576314/conv1d-with-kernel-size-1-vs-linear-layer
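
A minimal sketch tying the two points above together (the concrete sizes and the `unfold`-based patch extraction are my own additions, not code from the video): copy the `Conv2d` weights into a `Linear` without going through `.data`, then compare the two patch embeddings with `torch.allclose`.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen just to make the check concrete.
n_samples, n_channels, image_size, patch_size, hidden_dim = 2, 3, 32, 8, 64

torch.manual_seed(0)
x = torch.rand(n_samples, n_channels, image_size, image_size)

# Patch embedding as a strided convolution...
module_conv = nn.Conv2d(n_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
# ...and the same embedding as a linear layer over flattened patches.
module_linear = nn.Linear(n_channels * patch_size ** 2, hidden_dim)

# Copy the parameters without touching `.data`.
with torch.no_grad():
    module_linear.weight = nn.Parameter(module_conv.weight.clone().reshape(hidden_dim, -1))
    module_linear.bias = nn.Parameter(module_conv.bias.clone())

# (n_samples, n_patches, hidden_dim) from the convolution...
out_conv = module_conv(x).flatten(2).transpose(1, 2)
# ...and from the linear layer; `unfold` extracts the same patches the conv sees.
patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
out_linear = module_linear(patches)

# With larger `patch_size`/`image_size` a looser tolerance (e.g. atol=1e-5) may be needed.
print(torch.allclose(out_conv, out_linear, atol=1e-6))
```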

  • @alexmckinney5761 · 1 year ago · +1

    Hi, this video is old, but I thought I would mention it. I think that to compute the total number of parameters you can simply sum a tree map of `x.size`. No need for the extra `tree_leaves` or `prod`. This is from memory though, so if that doesn't work I can dig up the code.

    • @mildlyoverfitted · 8 months ago

      Ahh, thank you for the tip:)
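
A minimal sketch of the tip above, here realized with `jax.tree_util.tree_reduce` (one possible reading of the suggestion; the toy Flax module is hypothetical, just to have some parameters to count):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# A toy two-layer module standing in for the Mixer.
class Toy(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(4)(nn.Dense(16)(x))

params = Toy().init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# `p.size` already multiplies out the shape, so no explicit `prod` is needed,
# and `tree_reduce` folds the sum without a separate `tree_leaves` call.
n_params = jax.tree_util.tree_reduce(lambda acc, p: acc + p.size, params, 0)
print(n_params)  # (8 * 16 + 16) + (16 * 4 + 4) = 212
```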

  • @BigDudeSuperstar · 3 years ago · +2

    Love your content, love the sound of your typing on the keyboard. May I ask, do you have a script prior to writing the code? Or do you do it as you go along? Very clean and coherent videos, great work.

    • @mildlyoverfitted · 3 years ago · +6

      Comments like this are really motivating:) So thank you! I always spend a lot of time reading the paper + available implementations and then preparing my scripts + IPython examples. When I do the live coding it is more of a live rewriting of the prepared code. I actually tried to film some stuff without preparation; however, I quickly gave up because it is hard and the final footage gets really long.

  • @DiogoSanti · 3 years ago · +1

    Great content! How do you compare Jax/Flax to PyTorch? Do you think it will be a replacement in the near future, or is it just another framework on the block?

    • @mildlyoverfitted · 3 years ago · +1

      Thank you for the comment! I actually do not have a lot of experience with Flax/Jax so I am definitely not the right person to make this kind of comparison. Obviously, when it comes to implementing things like MLP-Mixer one can do it with any framework. Anyway, I hope to learn more about Flax/Jax in the near future and maybe I will come back to this comment then:D

    • @DiogoSanti · 3 years ago · +1

      @mildlyoverfitted Thanks for this, I will start a journey into Jax as well; let's see how it goes... This part: jax.readthedocs.io/en/latest/autodidax.html is a little overwhelming for me right now, trying to understand its internals. So going back to basics for now.

    • @mildlyoverfitted · 3 years ago · +1

      @DiogoSanti Wow, thanks for the link actually. I did not know about it:) Anyway, good luck!

  • @michaelvonreich74 · 3 years ago · +1

    Just a bit off topic, but I really like your docstring format. After a quick google, it seems to be the numpy/scipy style guide? I've personally been writing in the Google style, but I think this looks nicer.
    Looking back at the fact that the mixer layers can also be done with 1D convolution, is there any inherent advantage in using a linear vs a conv layer? Linear is more computationally efficient, right?

    • @mildlyoverfitted · 3 years ago

      Yes, it is indeed the numpydoc format - numpydoc.readthedocs.io/en/latest/format.html. I have been using it for quite a while now and I think I was introduced to it via scikit-learn. Anyway, other formats are fine too!
      Regarding the efficiency, I never really investigated which of the two approaches is "faster". Heh, now I am curious too. Anyway, in this video I talked about it purely to show that one can make an argument that MLP-Mixer is a CNN.
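
A short sketch touching both questions in this thread: a function with a numpydoc-style docstring that checks the claim that token mixing with a `Linear` matches a 1x1 `Conv1d` over the patch dimension (the sizes are arbitrary and the wiring is my reading of the argument, not code from the video):

```python
import torch
import torch.nn as nn

def compare_token_mixing(n_samples=2, n_patches=49, hidden_dim=8):
    """Check that token mixing via ``Linear`` matches a 1x1 ``Conv1d``.

    Parameters
    ----------
    n_samples, n_patches, hidden_dim : int
        Batch size, number of patches and channel dimension. Arbitrary
        values, chosen just for this check.

    Returns
    -------
    bool
        True if the two layers produce (numerically) the same output.
    """
    torch.manual_seed(0)
    x = torch.rand(n_samples, n_patches, hidden_dim)

    # Token mixing as a 1x1 convolution over the patch (channel) dimension...
    module_conv = nn.Conv1d(n_patches, n_patches, kernel_size=1)
    # ...and as a linear layer applied across patches.
    module_linear = nn.Linear(n_patches, n_patches)

    # Share parameters: the 1x1 conv kernel is just the linear weight matrix.
    module_linear.weight = nn.Parameter(module_conv.weight[..., 0].detach().clone())
    module_linear.bias = nn.Parameter(module_conv.bias.detach().clone())

    out_conv = module_conv(x)
    out_linear = module_linear(x.transpose(1, 2)).transpose(1, 2)
    return torch.allclose(out_conv, out_linear, atol=1e-6)

print(compare_token_mixing())
```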

  • @GaneshBhat · 2 years ago

    Hi, thanks for the video, but I keep getting the error below while running the code. This is for the `Conv1dDepthWiseShared` class.
    Can you please share the reason?
    ----> 1 out_conv = module_conv(x).reshape(n_samples, hidden_dim, k)

    in forward(self, x)
         10 def forward(self, x):
    ---> 11     weight = self.weight_shared(self.hidden_dim, 1, 1)
         12     bias = self.bias_shared.repeat(self.hidden_dim)
         13     res = torch.nn.functional.conv1d(x, weight=weight, bias=bias, groups=self.hidden_dim)

    TypeError: 'Parameter' object is not callable

    • @mildlyoverfitted · 2 years ago · +1

      Could you please create an issue on GitHub and post the example there? github.com/jankrepl/mildlyoverfitted/issues
      YouTube is not great when it comes to formatting code!
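
For what it's worth, the traceback itself suggests the bug: line 11 calls the `Parameter` as if it were a function, while line 12 correctly uses `.repeat`. Below is a guess at the intended module with that fix applied; the repository's actual class may differ.

```python
import torch
import torch.nn as nn

class Conv1dDepthWiseShared(nn.Module):
    """Depthwise 1D convolution whose kernel is shared across all channels.

    Reconstructed from the traceback above, not copied from the repository.
    """

    def __init__(self, hidden_dim, kernel_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.weight_shared = nn.Parameter(torch.rand(1, 1, kernel_size))
        self.bias_shared = nn.Parameter(torch.rand(1))

    def forward(self, x):
        # `.repeat(...)`, not a call: Parameters are tensors, not callables.
        weight = self.weight_shared.repeat(self.hidden_dim, 1, 1)
        bias = self.bias_shared.repeat(self.hidden_dim)
        return nn.functional.conv1d(x, weight=weight, bias=bias, groups=self.hidden_dim)
```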

  • @teetanrobotics5363 · 3 years ago · +2

    Amazing content. But please make it a little more noob-friendly.

    • @mildlyoverfitted · 3 years ago · +1

      Check out the IPython sections of the video:) The Flax intro should be relatively simple. Anyway, I will try to make more beginner-friendly content in the future:)