Transformer from Scratch (in PyTorch)


I implemented the Transformer from scratch in PyTorch. Why would I do that in the first place?

Implementing scientific papers from scratch is something machine learning engineers rarely do these days, at least in my opinion. Most machine learning models are already implemented and optimized, and all you have to do is tweak some code. Usually it’s the data loading code; sometimes it’s other parts of the code, but it’s almost never the case that you have to do heavy coding if you’re using an existing model.

Even though I was aware of the above, I still wanted to try my hand at implementing a paper from scratch. There was something alluring to me about the ability to do that: to read a paper and translate the concept(s) from that paper to (working) code.

I could have picked many different papers to implement from scratch. The reasons I chose to implement the Transformer in particular were:

  • from what I read in some AI safety related posts (such as this one), it was a good test of how well you would do as a machine learning research engineer
  • the Transformer model has been covered by a lot of blog posts and other resources, so I knew that if I got stuck I could read a lot of other material in order to try to understand the thing I’m stuck on better
  • natural language processing (NLP) is very interesting to me

The code I wrote can be found in this GitHub repository.

My goals

  • ensure that the model was training properly by getting to 0 loss (or near 0 loss) on a very small training set
  • ensure that the trained model worked better than the baseline (which was an untrained Transformer model)
  • ensure that a saved trained model worked properly after being loaded and that it worked when generating sequences token-by-token

All of my goals were achieved. The figure below confirms the achievement of the first goal:

The last two goals were confirmed by the outputs of the scripts which tested the trained model.

I picked the goals above because they ensured that my model was implemented correctly. If I can train the model and use it, I can be convinced that my implementation works. I could have put some other goals here as well, but they would be out of scope for what I had in mind – I wanted to test if I could implement concepts from scientific papers and that was it. All of the other things could be considered for future work.

Knowledge I had prior to this project

Prior to this project, I had pretty limited experience with natural language processing (NLP). One of my larger projects was fine-tuning GPT-2 so that it generates movie scripts, but that was done relying on Hugging Face. I did have 2-3 years of machine learning engineering experience (or related), but that was almost exclusively in the field of computer vision.

I’ll do my best to enumerate what I knew at the point of starting the project (which is relevant to the project):

  • I knew that in order to encode tokens (which could be words or characters) I needed to use embeddings and that the output of the model would be a probability distribution over the next token; in general, I was aware of the “pipeline” of how things are done in NLP
  • I was aware of the encoder-decoder architecture, but other than using them within a framework, I had limited experience with them (let alone coding them up from scratch)
  • I had a vague understanding of sequence-to-sequence models, but I didn’t really understand how they worked in detail
  • I didn’t have any experience with implementing models from scratch in PyTorch. I did implement a small neural network from scratch, but I didn’t use PyTorch there and that neural network was rather simple

Rules I had

I had some rules regarding how this project was to be conducted. They were as follows:

  1. I was allowed to read only the original paper and start writing code from there. If I got stuck, I was allowed to go to step number 2.
  2. I was allowed to consult blog posts and/or articles containing explanations of the things I didn’t understand. If a post happened to contain code, I wouldn’t look at the code. If I got stuck here as well, I was allowed to go to step number 3.
  3. I was allowed to ask questions on particular things I didn’t understand or about a particular bug in my code, after I tried to solve the misunderstanding or the bug by myself for some reasonable amount of time. Here I want to extend a big thank you to Louis Bradshaw, who helped me clear up some conceptual misunderstandings and also gave me advice on how to best approach finding the bugs I had in my code and solving some specific ones I found.

    Finally, if none of this worked, I was allowed to go to step 4.

  4. I was allowed to look at existing Transformer implementations. Also, I was allowed to copy/paste certain parts of code. I never got to this step.

Total time spent on this project

During the project, I kept a time log. I want to note that when I worked on this project, I generally worked for 30 minutes, took a 10 minute break, worked for 30 minutes again, took another 10 minute break, and so on. So, for example, if I say I worked for 40 minutes, 30 minutes of that was me sitting at a computer working, while 10 minutes was me walking around the room resting. This work/rest split is something I found optimal for myself, and I’m noting it here so you can keep it in mind when reading the numbers below.

I spent around 3550 minutes on this project in total, which translates to around 60 hours. Skimming my time log, I’d say that around 10 to 15 hours were spent clarifying my conceptual understanding, while the rest was spent writing code, re-writing code, and debugging; around 5 hours of that went into dataset-related code.

If you want to look at the full time log, it can be found in the Appendix.

Useful resources if you’re implementing the Transformer from scratch or trying to understand it

Below I list the resources I found very useful during this project; you can find more by Googling whenever something in particular is unclear. Although the list is unordered, the resources appear roughly in the order I used them, from the beginning of the project to its end.

  • Attention Is All You Need – the original paper (duh)
  • The Illustrated Transformer – a great, simple explanation of how Transformers work and how they are trained
  • How Do Self-Attention Masks Work? – an amazing guide explaining how self-attention masks work; this was very, very useful when I was implementing masking
  • How to understand masked multi-head attention in transformer – Stack Overflow answer which gives another attempt at explaining masking; after reading this and the Medium article in the previous bullet point I think it “clicked” for me
  • A Recipe for Training Neural Networks – a guide on how to implement papers by Andrej Karpathy. I wish I had known about this earlier (I found out about this blog post about ¾ of the way through the project), as it contains some good tips. That said, I am unsure how I would apply some of the advice (for example, I don’t see a good way of testing the encoder-decoder architecture without implementing both parts)

Notes on the things which were unclear to me or tripped me up

I list some things which were unclear to me or tripped me up below. This isn’t an exhaustive list; there were many small and large “aha” moments, but these are the ones I deem most prominent. In the list, I use Python-style indexing to explain certain things:

  • When reading the original paper, I thought that Multi-Head Attention “breaks up” the embedding and feeds the pieces to Scaled Dot-Product Attention. To clarify: with an embedding dimension of 512, key/value/query dimensions of 64, and 8 attention heads, I thought I should take the first 64 dimensions of the embedding (indices 0 to 63) and pass them into one Scaled Dot-Product Attention layer, take the next 64 dimensions (indices 64 to 127) and pass them into another Scaled Dot-Product Attention layer, and so forth. This was wrong, as I later learned from The Illustrated Transformer. The right way is to pass the entire embedding to each Scaled Dot-Product Attention layer, which projects it down and “spits out” lower-dimensional vectors; these are then all concatenated and passed through a final linear layer within the Multi-Head Attention.
  • Let’s say you have 18 tokens in your target sequence and the vocabulary (vocab) size is 50000 (I’m ignoring batches here). During training, the outputs of your decoder (if you implemented everything correctly) will be of dimensions [18, 50000], where each 50000-dimensional vector represents the logits of the next token. So, if I wanted to know the most probable token after the token at index 15, I would take the softmax of [15, :], then the argmax of that, and I’d get the token index of the most probable next token.
  • Related to the previous bullet point: when generating the target sequence token-by-token, the next-token probabilities are always obtained from the last output vector (which has vocab-size dimensions). In other words, to predict the next token in the sequence, I’d take the softmax, then the argmax, of the logits at index [-1, :].
  • If you put layers into lists, don’t use regular Python lists, because when you save the model those weights won’t be saved (even though they get updated during training). Use ModuleList. It took me some time to figure out why my model wasn’t saving and loading properly; it was precisely because I stored multiple layers of the same type in regular Python lists instead of a ModuleList.
  • You should use masking both for training and for inference.
  • I found debugging quite hard, as debugging this model added a new way of being wrong: “I implemented what I understood correctly, but what I understood was wrong”. I found it useful to print out the shapes (and sometimes the values) of everything passing through the model, all the way from the embeddings to the final result. I also tested things one by one. For example, if my trained model didn’t work when generating the target sequence token-by-token, I tested whether it worked on the entire target sequence. If it did, I knew the problem was in the token-generation part of the code. This way, I narrowed down the problem until I eventually fixed it.
  • Be careful when stacking tensors; make sure they come out with the appropriate dimensions. I ended up having errors because I mixed up torch.cat and torch.stack.
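To make the first bullet point concrete, here is a minimal sketch of the correct approach: each head projects the full 512-dimensional embedding down to 64 dimensions rather than slicing the embedding. The dimensions are the ones from the paper; the variable names and the per-head `nn.Linear` layout are my own illustration, not the only way to implement it (real implementations usually fuse the heads into single matrices).

```python
import torch
import torch.nn as nn

d_model, n_heads, d_k = 512, 8, 64     # dimensions from the paper
x = torch.randn(10, d_model)           # 10 tokens, full embeddings

# One Q/K/V projection per head, each mapping the WHOLE embedding to 64 dims
q_proj = nn.ModuleList(nn.Linear(d_model, d_k) for _ in range(n_heads))
k_proj = nn.ModuleList(nn.Linear(d_model, d_k) for _ in range(n_heads))
v_proj = nn.ModuleList(nn.Linear(d_model, d_k) for _ in range(n_heads))
out_proj = nn.Linear(n_heads * d_k, d_model)

head_outputs = []
for qp, kp, vp in zip(q_proj, k_proj, v_proj):
    q, k, v = qp(x), kp(x), vp(x)      # each [10, 64], from the entire embedding
    scores = q @ k.T / d_k ** 0.5      # scaled dot-product attention
    head_outputs.append(scores.softmax(dim=-1) @ v)

# Concatenate the heads and pass through the final linear layer: [10, 512]
out = out_proj(torch.cat(head_outputs, dim=-1))
```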

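The two bullet points about decoder outputs can be illustrated with a toy tensor (the sizes match the example above; the logits here are random rather than coming from a real model):

```python
import torch

seq_len, vocab_size = 18, 50000
logits = torch.randn(seq_len, vocab_size)   # decoder output: [18, 50000]

# Most probable token following the token at index 15:
next_after_15 = logits[15, :].softmax(dim=-1).argmax().item()

# Token-by-token generation: only the last position's logits are used.
next_token = logits[-1, :].softmax(dim=-1).argmax().item()
```

Since softmax is monotonic, taking the argmax of the raw logits gives the same token; the softmax only matters if you need actual probabilities.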

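The ModuleList pitfall from the list above can be demonstrated in a few lines (the layer sizes and class names are just for illustration):

```python
import torch.nn as nn

class BadStack(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the layers' parameters are never registered,
        # so state_dict() is empty and nothing would be saved to disk.
        self.layers = [nn.Linear(8, 8) for _ in range(3)]

class GoodStack(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: parameters are registered and get saved/loaded.
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))

print(len(BadStack().state_dict()))   # 0 -- nothing would be saved
print(len(GoodStack().state_dict()))  # 6 -- 3 weights + 3 biases
```

Training still works with the plain list (the optimizer can be given the parameters directly), which is exactly why the bug only shows up when saving and loading.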
This project was, as far as I can recall, the most intellectually challenging project I have done so far. It was a project where I had to understand things first, implement them, and then repeat the process whenever I discovered that something didn’t work as intended. This differs from almost all the other projects I’ve worked on in that a lot of my mental bandwidth went into understanding the concepts; when programming something other than a scientific paper, I usually understand (almost) all the concepts already.

This project also opened my mind in the sense that I now know that I can implement and, more importantly, understand almost everything (if not everything). I just need enough time and patience and if I have that, who knows where the limits of my abilities lie? I hope to find this out over the course of my career.

Appendix: Time log

Here is the time log in its “raw” format, if anyone is curious:

  • ~2-4 h reading papers
  • 40 min setting up virtual environment
  • ~10 min setting up virtual environment; 30 min re-reading the paper and writing code for scaled dot-product attention
  • 40 min writing code for scaled dot-product attention and the scaffolding code (creating a module out of the layers folder and writing the scaffolding for the test)
  • 20 min debugging and writing code for the Scaled Dot-Product Attention layer
  • ~30 min on writing a test for the Scaled Dot-Product Attention layer and 10 min for Multi-Head Attention
  • 40 min writing code for Multi-Head Attention
  • 40 min writing and debugging code for Multi-Head Attention
  • 40 min debugging code for Multi-Head Attention
  • 40 min reading the paper again (to catch some technical details), some Medium articles and writing scaffolding code for the Encoder block
  • 40 min writing code for the Encoder block
  • ~10 minutes reading some Medium articles
  • 40 min writing code for the Encoder block and the Feed Forward layer and the Feed Forward layer test
  • 20 min for writing the Encoder block test and revisiting the Decoder block architecture from the paper
  • 20 min of reading
  • 40 min reading Visualizing A Neural Machine Translation Model (Mechanics of Seq2seq Models With Attention), then The Illustrated Transformer and changing up the Encoder block code so that the feedforward network layer is shared between all blocks
  • 40 min reading The Illustrated Transformer
  • 5-10 min reading The Illustrated Transformer
  • 20 min reading The Illustrated Transformer
  • 40 min re-implementing the Scaled Dot-Product Attention layer and the Multi-Head Attention layer (haven’t finished the Multi-Head Attention layer)
  • ~10 min re-implementing the Multi-Head Attention layer and moving some code to deprecated folder
  • ~15 min re-implementing the Scaled Dot-Product Attention layer test and Multi-Head Attention layer test
  • ~5 min double-checking the FeedForward layer implementation and its test
  • 20 min moved the Encoder_first_try to deprecated folder and explained why it’s deprecated
  • 40 min re-implementing the Encoder Block, double-checking its test and reading about the Decoder in The Illustrated Transformer
  • ~15 min reading about the decoder in The Illustrated Transformer and The Attention Is All You Need paper
  • 40 min implementing masking in the Scaled Dot-Product Attention layer (based on
  • 40 min reading about the Decoder
  • 40 min implementing Encoder-Decoder Attention and Scaled Dot-Product Attention for the Decoder
  • 40 min implementing the Decoder block and its test
  • 20 min reading up on Encoder-Decoder training procedure
  • 40 min implementing the Transformer Encoder model
  • 40 min implementing the Transformer model
  • ~15 min implementing the Transformer model
  • 40 min implementing the Transformer Encoder model and the Transformer Decoder model and their associated tests
  • ~20 min researching embeddings
  • 40 min looking into embeddings
  • ~40 min reading about the embedding and the dataset used in the paper and BPEmb
  • 40 min writing the Encoder Train Dataset PyTorch class using BPEmb
  • ~50 min looking into BPEmb
  • ~20 min looking into Encoder-Decoder training
  • ~30 min looking into Encoder-Decoder training
  • 40 min looking into Encoder-Decoder training
  • 40 min looking into Encoder-Decoder training and writing the Encoder Dataset class
  • 40 min writing the Encoder Dataset class (I was distracted here by listening to the Lex Fridman podcast)
  • 40 min implementing the positional encoding and testing it
  • 40 min implementing the Decoder Dataset class and the forward pass code (
  • 40 min implementing the forward pass code ( and reviewing The Illustrated Transformer
  • 40 min debugging the forward pass code (
  • 40 min looking into masking and debugging the forward pass code (
  • ~50 min debugging the forward pass code (
  • 10 min implementing the training loop, then realizing I needed to rewrite it
  • 40 min implementing the TransformerDataset class
  • ~15 min writing
  • 40 min looking into dataloader error when batch_size > 1 and writing
  • ~40 min writing
  • 40 min looking into BPEmb
  • ~30 min re-writing
  • 10 min debugging
  • ~40 min looking into training
  • ~40 min looking into training
  • ~10 min looking into training
  • ~40 min looking into training and writing and
  • 40 min writing
  • ~20 min writing the Transformer class and re-writing and
  • ~5 min looking into Cross-Entropy loss
  • ~45 min reading about the training parameters in the paper and writing
  • 40 min implementing
  • ~25 min reading about training in the original paper and implementing that in my model layers and/or code
  • ~15 min looking at cross entropy and learning rates
  • ~25 min reading Andrej Karpathy’s A Recipe for Training Neural Networks (
  • ~40 min debugging the model training
  • ~5 min setting up model hyperparameters (number of epochs etc.) and editing code
  • 40 min loading up the trained Transformer weights and seeing how they perform against the baseline (randomly initialized weights)
  • ~35 min re-reading the testing code, looking for errors
  • ~5-10 min looking into the testing code to find bugs
  • 20 min looking into and seeing if everything is OK
  • ~20 min looking at debug output log and testing code
  • ~40 min debugging testing code
  • 40 min implementing masking for inference
  • ~20 min debugging inference
  • 40 min reading the masking article again and looking for bugs in my inference code
  • 5-10 min looking into masking
  • ~20 min debugging inference
  • ~15 min debugging inference
  • 40 min debugging inference – found one of the bugs; I was training without masking and running inference with masking
  • ~50 min debugging inference
  • ~1 h debugging inference – found one of the bugs; it was the fact that PyTorch didn’t save all the model weights; it saved only the immediate layers in the TransformerEncoder and TransformerDecoder instances in the Transformer class, but it didn’t save the weights of the other layers TransformerEncoder and TransformerDecoder were composed of
  • ~30 min looking at training log output and reading about saving all of the sub-layer weights of a model
  • ~5 min testing inference
  • ~10 min debugging inference – my training loss wasn’t 0, so that’s why some of the predictions were bad
  • ~40 min checking out how the training is going and trying out different learning rates
  • ~40 min debugging
  • 40 min re-writing some code (re-naming variables etc.)
  • 40 min re-writing some code (re-naming variables etc.) and starting the training again due to renaming variables
  • ~40 min adding positional encoding to, installing packages and writing loss visualization code (Jupyter notebook)
  • ~40 min writing README and testing
  • ~1 h 50 min writing anew
  • ~40 min debugging inference
  • ~1 h 30 min debugging inference and re-writing a small part of – the bug in inference was related to the fact that positional encoding got passed a matrix of shape [1, 100] and it iterated over the dimension of 1, not 100 as was expected
  • 1 h tidying up code and wrapping things up
  • 40 min tidying up the repository and starting to write the writeup