I want to point out that after training, the inner loop in the RNN still needs to be "optimized", since the gradient update rule is the RNN update rule itself. The normal NLL training procedure just trains the model on how to "teach" the internal RNN layers. This isn't normal optimization, though: it doesn't require an optimizer, since we know the functional form of the gradient w.r.t. the hidden state, so it's very efficient.
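For anyone who wants to see that concretely, here's a minimal sketch of the inner-loop update for a TTT-Linear-style layer (the dimensions, learning rate eta, and variable names are my own illustration, not the paper's code). The point is that the reconstruction loss has a closed-form gradient, so the test-time "optimization" is just a fixed arithmetic update per token:

```python
# Minimal sketch of the inner-loop update for a TTT-Linear-style layer.
# All names and sizes here are illustrative assumptions, not the paper's code.
import torch

d = 16                                 # feature dimension (assumed)
eta = 0.1                              # inner-loop learning rate (assumed)
theta_K = torch.randn(d, d) / d**0.5   # outer-loop ("meta") parameters
theta_V = torch.randn(d, d) / d**0.5
theta_Q = torch.randn(d, d) / d**0.5

W = torch.zeros(d, d)                  # hidden state = weights of the inner linear model
xs = torch.randn(10, d)                # a toy token sequence

outputs = []
for x_t in xs:
    k_t = theta_K @ x_t                # inner-loop "training input"
    v_t = theta_V @ x_t                # inner-loop "training label"
    # Reconstruction loss l(W; x_t) = ||W k_t - v_t||^2 has a closed-form gradient,
    # so no optimizer or autograd is needed at test time:
    grad_W = 2 * torch.outer(W @ k_t - v_t, k_t)
    W = W - eta * grad_W               # one gradient step = one RNN state update
    q_t = theta_Q @ x_t
    outputs.append(W @ q_t)            # output uses the freshly updated hidden state
```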
I really like seeing videos like this that are actually willing to explore the math and break it down to make it easier to catch onto how it actually works. Too many just discuss the abstract or skip past the math because it hampers audience maximization.
I understand that the inner loop updates W_t with the modified reconstruction loss, but how are theta_K, theta_V, and theta_Q updated in the outer loop? Specifically, what loss is related to these three parameters?
The "outer loop" uses the normal training strategy where we have the negative log likelihood or cross entropy loss for next token prediction. This gradient of this loss backpropogated to all layers like normal. So the "outer loop" is just the rest of the network.
@@gabrielmongaras But shouldn't the normal cross entropy loss for next token prediction produce no gradient flow to theta_K and theta_V? They are used in the reconstruction loss, so they are not directly involved in token prediction.
The theta params are trained via the next token prediction loss. This can be thought of as the outer model querying the inner model for information by changing the loss function with these theta params. I think the outer loop actually differentiates through the inner loop (including the gradient of the loss), so the theta_K and theta_V params are updated by the outer loop in this way. The only "param" that's trained using the inner loop is the hidden state.
@@gabrielmongaras Maybe a similar question to this; please help clarify. The inner loop updates W using the loss function in (4). Does that mean theta_K and theta_V are *not* updated based on the loss in (4)? I guess they aren't; otherwise, how would theta_Q be updated, since it doesn't appear in (4)? I guess the thetas are updated based on the other loss function, but this isn't written precisely in the paper.
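To make that gradient flow concrete, here's a toy sketch (all names and sizes are illustrative assumptions, not from the paper) showing that the next-token-prediction loss does reach theta_K and theta_V, because the outer loop backpropagates through the whole chain of inner-loop updates:

```python
# Sketch of why theta_K / theta_V / theta_Q receive gradients from the
# next-token-prediction loss: the outer loop backpropagates *through* the
# inner-loop updates. Everything here is an illustrative assumption.
import torch

d, T, vocab = 16, 10, 100
theta_K = torch.randn(d, d, requires_grad=True)
theta_V = torch.randn(d, d, requires_grad=True)
theta_Q = torch.randn(d, d, requires_grad=True)
head = torch.randn(vocab, d, requires_grad=True)   # toy output head
eta = 0.1

xs = torch.randn(T, d)                  # toy token embeddings
targets = torch.randint(0, vocab, (T,)) # toy next-token targets

W = torch.zeros(d, d)                   # hidden state (not a trained parameter)
logits = []
for x_t in xs:
    k_t, v_t, q_t = theta_K @ x_t, theta_V @ x_t, theta_Q @ x_t
    # The inner-loop gradient step is written in differentiable ops, so it
    # becomes part of the outer computation graph:
    W = W - eta * 2 * torch.outer(W @ k_t - v_t, k_t)
    logits.append(head @ (W @ q_t))

loss = torch.nn.functional.cross_entropy(torch.stack(logits), targets)
loss.backward()
# theta_K.grad and theta_V.grad are nonzero: they shaped every W_t, and every
# W_t shaped the logits, so the NTP loss reaches them through the inner loop.
print(theta_K.grad.abs().sum(), theta_V.grad.abs().sum(), theta_Q.grad.abs().sum())
```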
What is the application you're using to annotate or take notes?