After playing with JAX, I don’t feel comfortable linking the notebook I showed in the video. Most of the video content is still valid because it shows the differences between PyG and Jraph. However, in the notebook I used Haiku, which Google DeepMind no longer recommends; they recommend using Flax instead. So I linked a new notebook showing GCN code in JAX/Flax:
github.com/mashaan14/TH-cam-channel/blob/main/notebooks/2024_03_21_jraph_GCN.ipynb
Here’s another video where I explained graph attention code in JAX/Flax:
th-cam.com/video/O1zGWMEgW7A/w-d-xo.html
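If you just want the gist of the new notebook, the heart of it is a plain GCN layer written in Flax. Here is a minimal sketch of that idea (class and variable names are mine, not necessarily the ones used in the notebook):

import jax
import jax.numpy as jnp
import flax.linen as nn

class GCNLayer(nn.Module):
    """One graph convolution: A_hat @ X @ W (no bias)."""
    features: int

    @nn.compact
    def __call__(self, a_hat, x):
        x = nn.Dense(self.features, use_bias=False)(x)  # X @ W
        return a_hat @ x                                 # propagate over the graph

# toy usage with a made-up 4-node graph
key = jax.random.PRNGKey(0)
a_hat = jnp.eye(4)                    # stand-in for the normalized adjacency
x = jax.random.normal(key, (4, 8))    # stand-in for node features
params = GCNLayer(features=16).init(key, a_hat, x)
out = GCNLayer(features=16).apply(params, a_hat, x)   # shape (4, 16)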
Hi man, beautiful video explaining both libraries! Loved your explanation; clear and on point.
About the issue of the test results of PyG and Jraph being different: I think it is because, even though both obtained 100% training accuracy (which also means they overfitted the data), the decision boundaries they draw for the training set are not necessarily the same. One reason they might differ is that PyG's and Jraph's GNN weights are probably initialized randomly.
Therefore, their different decision boundaries can easily lead to two different results on the test set.
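Just to illustrate the point: to rule that out you would fix the random seeds on both sides, something like the sketch below (where exactly these calls go in the training code is up to the notebook):

import torch
import jax

torch.manual_seed(0)           # fixes the random init of the PyG/PyTorch model
key = jax.random.PRNGKey(0)    # the key used to initialize the Jraph-side model's parameters

# Even with matching seeds the two libraries use different default
# initializers, so the decision boundaries (and the test accuracy)
# can still come out slightly different.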
I loved your explanation, and it totally makes sense. But the cause was far simpler than that: I was training on two different feature matrices. If you look at the Jraph part, I passed this argument:
nodes=jnp.eye(data_Cora.x.shape[0])  # identity matrix instead of the actual feature matrix data_Cora.x
So I was training Jraph on the identity matrix while training PyG on the feature matrix. I know, it’s crazy how Jraph got so close with only the identity matrix.
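For reference, here is roughly how the Jraph graph should have been built so that both models see the same features (a sketch; data_Cora is the PyG Data object from the video, and the fields are jraph's GraphsTuple fields):

import jax.numpy as jnp
import jraph
# data_Cora is the PyG Data object, e.g.:
# from torch_geometric.datasets import Planetoid
# data_Cora = Planetoid(root='.', name='Cora')[0]

graph = jraph.GraphsTuple(
    nodes=jnp.asarray(data_Cora.x.numpy()),        # the real feature matrix, not jnp.eye(...)
    edges=None,                                     # Cora has no edge features
    senders=jnp.asarray(data_Cora.edge_index[0].numpy()),
    receivers=jnp.asarray(data_Cora.edge_index[1].numpy()),
    n_node=jnp.asarray([data_Cora.num_nodes]),
    n_edge=jnp.asarray([data_Cora.edge_index.shape[1]]),
    globals=None,
)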
Anyway, I couldn’t fix the notebook in the video because it was written in Haiku. So I took it down and wrote a new one with JAX/Flax:
github.com/mashaan14/TH-cam-channel/blob/main/notebooks/2024_03_21_jraph_GCN.ipynb
I’d love it if you could take a look at the new code.