VQ-GAN | PyTorch Implementation

  • Published on Jan 1, 2025

Comments • 69

  • @csoRoBeRt
    @csoRoBeRt 1 year ago +1

    Marvelous implementation. It's much clearer than looking into the original code.

  • @NirDodge
    @NirDodge 1 year ago +3

    @outliier 26:34 I suspect that the colors are off due to decoded_images.add(1).mul(0.5) in the visualization, which maps the colors from [-1, 1] to [0, 1], but is only applied to the decoded images and not the original images for some reason.
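
    A minimal sketch of the fix hinted at above, assuming the loop saves a comparison grid with torchvision; the variable names are illustrative, not the video's exact code:

    import torch
    import torchvision

    def to_unit_range(x: torch.Tensor) -> torch.Tensor:
        """Map images from [-1, 1] to [0, 1] for visualization."""
        return x.add(1).mul(0.5).clamp(0, 1)

    # Dummy stand-ins for a training batch and its reconstructions, both in [-1, 1].
    imgs = torch.empty(8, 3, 64, 64).uniform_(-1, 1)
    decoded_images = torch.empty(8, 3, 64, 64).uniform_(-1, 1)

    # Apply the same mapping to BOTH tensors so the original and decoded colors match.
    grid = torchvision.utils.make_grid(
        torch.cat((to_unit_range(imgs[:4]), to_unit_range(decoded_images[:4])))
    )
    torchvision.utils.save_image(grid, "comparison.png")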

  • @947973
    @947973 1 year ago +1

    Excellent video. Thank you very much for making it.

  • @the-whisper-
    @the-whisper- 2 years ago +4

    Regarding the training part for the VQGAN at 24:24, from what I understand the following is happening:
    1. VQGAN grads are zeroed, grads are then propagated over the Discriminator (because of g_loss) and over the VQGAN for the rest of the losses (and the g_loss); retain_graph = True is added in order to keep the previously computed forward pass values, otherwise calling backward again on the same losses would raise an error;
    2. Discriminator grads are zeroed to remove what was previously added by the g_loss.backward(), and another backward call is done on the gan_loss to propagate the grads for the proper loss function (d_loss_fake and d_loss_real);
    3. The optimizers are called one after another to update the weights with the accumulated values in the leaf-tensors .grad property.
    One possible error might have occurred at step 2, which led to the bad reconstructions seen a few minutes later in the video. The gan_loss.backward() propagates the following:
    - d_loss_real, which was computed by applying the Discriminator over the real images.
    - d_loss_fake, which was computed by using the disc_fake images *generated* by the VQGAN. Here is where the issue might lie. The disc_fake images were obtained by a forward pass through the VQGAN model; as a result, the computational graph retains these forward values, and when gan_loss.backward() is called, d_loss_fake will be propagated over both the Discriminator and the VQGAN. In turn, this adjusts the VQGAN's weights to also minimize the Discriminator's loss, which amounts to something along the lines of "generate images such that the Discriminator can easily tell they are fake".
    A possible reason the VQGAN is still able to reconstruct the images, albeit not very well because of the perturbing loss propagation, might be two factors:
    - the reconstruction loss is still present
    - the discriminator is turned off until the threshold is hit, but after that the perturbation comes into play.
    A solution would be to (both options are sketched in code after this thread):
    - (not as optimal) use two tensors for the fake images: disc_fake_1 = decoded_images and disc_fake_2 = decoded_images.detach(), which will not propagate grads through the VQGAN. Pass both through the Discriminator, where disc_fake_1 is used in g_loss to update the VQGAN and disc_fake_2 is used in gan_loss to update the Discriminator.
    - (better, as only a single pass of the fake images through the Discriminator is required) before doing the gan_loss.backward() call, use self.vqgan.requires_grad_(False) => this disables the accumulation of gradients in the VQGAN, so only the discriminator receives values in its .grad property. After the backward() call, reactivate the grads with self.vqgan.requires_grad_(True).
    I am a beginner in the field so I might be wrong in both my understanding and explanation.
    Source:
    - pytorch.org/docs/stable/notes/autograd.html#setting-requires-grad

    • @csoRoBeRt
      @csoRoBeRt 1 year ago

      I was having the same confusion about the grad propagation, then I saw your answer!

    • @NirDodge
      @NirDodge 1 year ago

      I think that this is not necessary, since opt_disc.step() should only modify the discriminator parameters (and opt_vq.step() should only modify the vqgan parameters).

    • @csoRoBeRt
      @csoRoBeRt 1 year ago

      But loss_fake.backward() should also add grads on the layers of the generator. Not sure whether vq.step() would then use both parts of the grad for the update or not @@NirDodge

    • @NirDodge
      @NirDodge 1 year ago

      @@csoRoBeRt Right, I see... While opt_disc.step() would not affect the vqgan weights, it looks like self.vqgan.requires_grad_(False) is needed so that gan_loss.backward() does not accumulate gradients on the vqgan, which would otherwise affect the update in opt_vq.step().

    • @JJJYmmm
      @JJJYmmm 9 months ago

      Great! But in solution 1, I think just changing line 56 'disc_fake = self.discriminator(decoded_images)' to 'disc_fake = self.discriminator(decoded_images.detach())' is ok.
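
    A compact sketch of the two fixes discussed in this thread, using simplified stand-in modules and hinge-style losses; the real training step also carries the perceptual, codebook and commitment losses, so treat this as an outline rather than the video's exact code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Stand-ins for the VQGAN (generator) and the patch discriminator.
    vqgan = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1))
    discriminator = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1))
    opt_vq = torch.optim.Adam(vqgan.parameters(), lr=2e-4)
    opt_disc = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

    imgs = torch.empty(4, 3, 32, 32).uniform_(-1, 1)
    decoded_images = vqgan(imgs)

    # Generator-side losses: g_loss flows through the discriminator's ops, but only
    # the VQGAN's .grad is consumed by opt_vq.step().
    g_loss = -torch.mean(discriminator(decoded_images))
    rec_loss = torch.abs(imgs - decoded_images).mean()
    vq_loss = rec_loss + g_loss

    # Option 1: detach the reconstructions for the discriminator's own loss, so
    # d_loss_fake never reaches the VQGAN's part of the graph.
    disc_real = discriminator(imgs)
    disc_fake = discriminator(decoded_images.detach())
    d_loss_real = torch.mean(F.relu(1.0 - disc_real))
    d_loss_fake = torch.mean(F.relu(1.0 + disc_fake))
    gan_loss = 0.5 * (d_loss_real + d_loss_fake)

    opt_vq.zero_grad()
    vq_loss.backward()

    opt_disc.zero_grad()   # clears the grads that g_loss.backward() left on the discriminator
    gan_loss.backward()

    opt_vq.step()
    opt_disc.step()

    # Option 2 (instead of the detach): freeze the VQGAN's parameters around the
    # discriminator's backward pass, as suggested above.
    # vqgan.requires_grad_(False)
    # gan_loss.backward()
    # vqgan.requires_grad_(True)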

  • @tiln8455
    @tiln8455 2 years ago +2

    Nice one 👍🏼👍🏼

  • @hassenzaayra5419
    @hassenzaayra5419 1 year ago

    Hello, I trained my model and got a good result. I want to load an image and use the trained model (the .pt file) to see the reconstructed image.

    • @hassenzaayra5419
      @hassenzaayra5419 1 year ago

      @Outlier can you help me test a VQGAN model?
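
      What the question above asks for roughly comes down to the following sketch; the model class, checkpoint path, and the shape of the forward return value are assumptions to adapt to the repo's actual code:

      import torch
      import torchvision
      import torchvision.transforms as T
      from PIL import Image

      @torch.no_grad()
      def reconstruct(model: torch.nn.Module, image_path: str, out_path: str, size: int = 256):
          """Push one image through a trained autoencoder-style model and save the result."""
          preprocess = T.Compose([
              T.Resize(size), T.CenterCrop(size), T.ToTensor(),
              T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1], matching training
          ])
          img = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0)
          model.eval()
          out = model(img)
          # Assumption: the forward pass returns (reconstruction, codebook_indices, q_loss);
          # keep only the reconstruction if a tuple comes back.
          decoded = out[0] if isinstance(out, tuple) else out
          torchvision.utils.save_image(decoded.add(1).mul(0.5).clamp(0, 1), out_path)

      # Usage with hypothetical names: rebuild the model exactly as in training,
      # load the saved .pt weights, then reconstruct a single image.
      # model = VQGAN(args)
      # model.load_state_dict(torch.load("checkpoints/vqgan_last.pt", map_location="cpu"))
      # reconstruct(model, "flower.jpg", "reconstruction.png")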

  • @vinc6966
    @vinc6966 4 months ago

    Dude, you are making a YT video, not a class presentation, you have all the time in the world to take your time and explain each module step by step. Especially since your implementation has quite a few bugs…
    But overall, you did a decent job.

    • @outliier
      @outliier  4 months ago

      @@vinc6966 :(

  • @decreer4567
    @decreer4567 10 months ago

    Hey, I looked through your codebook code. VQ-VAEs perform a one-hot encoding. Is that something from the paper or just something you personally included? Nice video.

  • @神楽坂雫月
    @神楽坂雫月 2 years ago

    Awesome! Thanks for the video!

  • @miumiu5224
    @miumiu5224 1 year ago

    Hi, your video is great! I can't find a second one as good as yours. I have a question I would like to ask: how do I add conditions in the form of pictures when training the second-stage transformer?

    • @outliier
      @outliier  1 year ago

      Hey, thank you! I answered your question on GitHub.

  • @mkamp
    @mkamp 2 years ago

    Very nice. Thanks for the video. QQ: 10:30 Why do you use the expanded version and not just (a-b)**2?

    • @Kostarion1
      @Kostarion1 2 years ago +1

      I wonder about the same question. I can only suppose that it is needed to avoid losing precision when calculating the square of the difference, since the (a-b) values can be very small.

    • @MrXboy3x
      @MrXboy3x 2 years ago +4

      z_flat.shape = [1024, 256], embed.shape = [1024, 256]. When you do (a-b)**2 you get shape [1024, 256], which means each of the 1024 latent vectors is only compared element-wise against a single code vector. However, if you use the long form (a**2 + b**2 - 2ab) you get [1024, 1024], i.e. for each of the 1024 latent vectors you get the distance to every one of the 1024 code vectors (because of the dot-product operation). So you might think the element-wise version is still fine since it still yields nearby features, but that is not entirely correct, because the model would then only ever compare against (and back-propagate through) that limited set of code vectors instead of the whole codebook. (A short sketch comparing the two forms follows after this thread.)
      Try it yourself:
      add in the __init__ function ==> self.l2 = nn.MSELoss(reduction='none')
      add in the forward function ==> d = self.l2(z_flattened, self.embedding.weight)

    • @rikki146
      @rikki146 1 year ago

      @@MrXboy3x how about (a-b)*(a-b)^T

    • @rikki146
      @rikki146 1 year ago +1

      Never mind I was being stupid. However, there does indeed exist a way to do it more elegantly:
      import torch
      embedding = torch.rand((256, 512))  # (embedding_size, latent_dim)
      z_flattened = torch.randn((10, 512, 16, 16)).permute(0, 2, 3, 1).reshape(2560, 512)  # (N*h*w, channel)
      diff = z_flattened.repeat(256, 1, 1) - embedding[:, None, :]  # (256, 2560, 512)
      diff_squared = torch.sum(diff**2, dim=2)                      # (256, 2560)
      min_index = diff_squared.argmin(dim=0)                        # nearest embedding per latent vector

    • @kalisticmodiani2613
      @kalisticmodiani2613 1 year ago

      @@MrXboy3x this still does not make sense to me, because mathematically the function f(a, b) that gives the loss from the inputs a and b is the same no matter how you decompose it. So the gradients on the inputs, or the minimum index, should be the same... Am I missing something? I suppose you could argue one is more numerically stable than the other, but I heard the (a-b)^2 version is the more numerically stable one.
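
    For reference, the point of the expanded form is that it yields the full pairwise distance matrix between every flattened latent vector and every codebook entry in one matrix multiply, whereas an element-wise (a-b)**2 on two same-shaped tensors never compares each latent vector against the whole codebook. A small sketch with illustrative shapes (not the video's exact dimensions):

    import torch

    N, K, D = 1024, 512, 256        # latent vectors, codebook entries, embedding dim
    z_flat = torch.randn(N, D)
    codebook = torch.randn(K, D)

    # Expanded form used in the video: ||a||^2 + ||b||^2 - 2*a.b^T  ->  shape [N, K]
    d_expanded = (
        z_flat.pow(2).sum(dim=1, keepdim=True)
        + codebook.pow(2).sum(dim=1)
        - 2 * z_flat @ codebook.t()
    )

    # The same pairwise distances via broadcasting (costs an [N, K, D] intermediate).
    d_broadcast = ((z_flat[:, None, :] - codebook[None, :, :]) ** 2).sum(dim=-1)

    # Or with the built-in helper.
    d_cdist = torch.cdist(z_flat, codebook).pow(2)

    print(torch.allclose(d_expanded, d_broadcast, atol=1e-3))   # True up to float error
    print(d_expanded.argmin(dim=1)[:5])                         # nearest codebook index per latent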

  • @loko818r
    @loko818r 1 year ago

    It is great code, but do you or does anybody have a link to download the flowers dataset?

    • @outliier
      @outliier  1 year ago

      Thanks. Just look for the Oxford flowers dataset and you should find it.

  • @erwann_millon
    @erwann_millon 2 years ago

    great video

  • @xiaolongye-y4g
    @xiaolongye-y4g 1 year ago

    Really great, very detailed.

  • @fatihhaslak962
    @fatihhaslak962 2 years ago

    Hello. Thank you for this tutorial. Can you add VQGAN+CLIP, please?

    • @fatihhaslak962
      @fatihhaslak962 2 years ago

      How can CLIP be added to this code?

  • @ganeshb8683
    @ganeshb8683 1 year ago

    Thank you for the great explanation! Out of curiosity - what is the purpose of implementing the blocks (GroupNorm) as a separate class instead of using the predefined class in the torch (torch.nn.GroupNorm) ?
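
    For context, wrappers like this in VQGAN-style code usually just pin the GroupNorm hyperparameters so every block shares the same settings; they add nothing torch.nn.GroupNorm cannot do. A sketch of what such a wrapper commonly looks like (the specific values are an assumption, not necessarily this repo's):

    import torch.nn as nn

    class GroupNorm(nn.Module):
        """Thin wrapper that fixes the normalization hyperparameters used everywhere."""
        def __init__(self, channels: int):
            super().__init__()
            # 32 groups and eps=1e-6 are common choices in VQGAN implementations.
            self.gn = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)

        def forward(self, x):
            return self.gn(x)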

  • @uladzimirtumanau6240
    @uladzimirtumanau6240 1 year ago

    Hi! Do you have a profile on kaggle?

    • @outliier
      @outliier  1 year ago

      No I don’t :c

  • @davita6379
    @davita6379 2 years ago +4

    I should have learned pytorch

  • @hassenzaayra5419
    @hassenzaayra5419 2 years ago

    Hello. Thank you so much for the video and source files. Can you please add the test code, or can you help me create it?

    • @outliier
      @outliier  2 years ago

      What kind of test code are you talking about? All the code is on github. Did you see that?

  • @Paul-wk7rp
    @Paul-wk7rp 2 years ago +1

    Cool

  • @helplearncenter198
    @helplearncenter198 2 years ago

    thank you

  • @batuhanbayraktar337
    @batuhanbayraktar337 2 years ago

    Hello, I trained for 500 epochs. But I only want to use the 500th-epoch checkpoint to generate 5000 images. How can I do that?
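
    A rough outline of what the question above asks for, assuming the final (epoch-500) checkpoint of the sampling-capable model is used; the model class, checkpoint path, and sampling call are hypothetical and need to be adapted to the repo's actual API:

    import torch
    import torchvision

    @torch.no_grad()
    def generate(model, n_images: int, out_dir: str, batch_size: int = 16):
        """Sample n_images from a trained generative model and save them one by one."""
        saved = 0
        while saved < n_images:
            n = min(batch_size, n_images - saved)
            imgs = model.sample(n)               # hypothetical sampling routine, images in [-1, 1]
            imgs = imgs.add(1).mul(0.5).clamp(0, 1)
            for img in imgs:
                torchvision.utils.save_image(img, f"{out_dir}/{saved:05d}.png")
                saved += 1

    # Usage with hypothetical names: rebuild the model as in training and load only
    # the final (epoch-500) checkpoint, then generate 5000 images.
    # model = VQGANTransformer(args).eval()
    # model.load_state_dict(torch.load("checkpoints/transformer_epoch_500.pt", map_location="cpu"))
    # generate(model, n_images=5000, out_dir="samples")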

  • @sefadogantr
    @sefadogantr 2 years ago

    Hello. First of all, many thanks for the video and source files. I want to develop a Midjourney-like system to improve my skills, and I would like to ask you a few questions for guidance.
    With the process in the video, we redraw an existing image. When we build a Midjourney-like system, at which stage does this come into play?
    I have seen projects written with VQGAN and CLIP on Colab, but I want to write a system myself. What would you recommend? Which approaches do you think I should use?
    Another question: I remember working through tutorials faster with TensorFlow. Would you suggest using TensorFlow instead of PyTorch?
    Thank you.

    • @outliier
      @outliier  2 years ago

      Hey there, first of all I would recommend using PyTorch. There is a much larger community in the generative field that is using PyTorch. Second of all, a VQGAN usually represents the first stage, which compresses the data and removes redundancies. You would then need a model that learns in this compressed space. That's what the transformer in the second stage is doing. I don't know exactly how Midjourney is doing it, but for example Stable Diffusion uses the same approach of first learning a VQGAN and then learning a diffusion model in the latent space. So text-to-image tasks are usually done using transformers or diffusion models. You can watch my videos on diffusion models, maybe train them, and eventually combine them with a VQGAN, which gives you latent diffusion (the method that Stable Diffusion is using). Let me know if you have further questions.
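
      A schematic of the two-stage setup described above, mostly as commented pseudocode with hypothetical names, just to fix the ideas: stage one is the frozen VQGAN, stage two is whatever generative model is trained on its latent codes.

      import torch

      # Stage 1: the VQGAN is trained first, then frozen. It maps images to a grid of
      # discrete codebook indices and back. (Hypothetical class/method names.)
      # vqgan = VQGAN(args); vqgan.load_state_dict(torch.load("vqgan.pt")); vqgan.eval()

      @torch.no_grad()
      def encode_to_codes(vqgan, images):
          # images in [-1, 1]  ->  integer code indices of shape [batch, h*w]
          _, indices, _ = vqgan.encode(images)
          return indices.view(images.size(0), -1)

      # Stage 2: a transformer (or a diffusion model, as in latent/stable diffusion) is
      # trained purely on those compressed codes, e.g. with next-token prediction:
      # for images, _ in dataloader:
      #     codes = encode_to_codes(vqgan, images)
      #     logits = transformer(codes[:, :-1])
      #     loss = torch.nn.functional.cross_entropy(
      #         logits.reshape(-1, logits.size(-1)), codes[:, 1:].reshape(-1))
      #     loss.backward(); optimizer.step(); optimizer.zero_grad()

      # Generation samples codes from stage 2 and decodes them with the frozen VQGAN:
      # sampled_codes = transformer.sample(...)
      # images = vqgan.decode(sampled_codes)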

  • @flieskao9161
    @flieskao9161 1 year ago

    This foreign guy is really awesome; not a single explanation on Bilibili is half as good as yours.
