00:00:00.000 Hi everyone. So by now you have probably heard of ChatGPT. It has taken the world and the AI community

00:00:05.680 by storm, and it is a system that allows you to interact with an AI and give it text-based tasks.

00:00:12.560 So for example, we can ask ChatGPT to write us a small haiku about how important it is that people

00:00:17.200 understand AI so that they can use it to improve the world and make it more prosperous. So when we

00:00:21.920 run this: "AI knowledge brings / Prosperity for all to see / Embrace its power." Okay, not bad. And so

00:00:29.360 you can see that ChatGPT went from left to right and generated all these words sequentially.

00:00:35.200 Now, I already asked it the exact same prompt a little bit earlier, and it generated a slightly

00:00:40.640 different outcome: "AI's power to grow / Ignorance holds us back, learn / Prosperity waits."

00:00:46.400 So pretty good in both cases, and slightly different. So you can see that ChatGPT is a probabilistic

00:00:52.080 system, and for any one prompt it can give us multiple different answers.

00:00:58.160 Now, this is just one example of a prompt. People have come up with many, many examples, and there

00:01:02.640 are entire websites that index interactions with ChatGPT. Many of them are quite humorous:

00:01:09.360 "Explain HTML to me like I'm a dog," "Write release notes for chess 2," "Write a note about Elon Musk

00:01:15.760 buying Twitter," and so on. So as an example: "Please write a breaking news article about a leaf

00:01:22.400 falling from a tree." "In a shocking turn of events, a leaf has fallen from a tree in the local park.

00:01:28.480 Witnesses report that the leaf, which was previously attached to a branch of the tree,

00:01:32.320 detached itself and fell to the ground." Very dramatic. So you can see that this is a pretty

00:01:37.360 remarkable system, and it is what we call a language model, because it models the sequence of words,

00:01:45.040 or characters, or tokens more generally, and it knows how words tend to follow each other in the English

00:01:50.960 language. From its perspective, what it is doing is completing the sequence. I give

00:01:57.520 it the start of a sequence, and it completes the sequence with the outcome. So it's a language

00:02:02.800 model in that sense. Now I would like to focus on the under-the-hood components

00:02:09.360 of what makes ChatGPT work. What is the neural network under the hood that models the

00:02:14.240 sequence of these words? That comes from this paper called "Attention Is All You Need,"

00:02:20.160 a landmark 2017 paper in AI that proposed the transformer architecture.

00:02:27.600 GPT is short for generatively pre-trained transformer. The transformer is the

00:02:35.200 neural net that actually does all the heavy lifting under the hood, and it comes from this paper.

00:02:41.040 Now, if you read this paper, it reads like a pretty random machine translation paper. And that's

00:02:46.000 because I think the authors didn't fully anticipate the impact that the transformer would have on the

00:02:50.000 field. This architecture, which they produced in the context of machine translation,

00:02:55.600 actually ended up taking over the rest of AI in the following five years. This architecture,

00:03:02.320 with minor changes, was copy-pasted into a huge number of AI applications in more recent years.

00:03:09.040 And that includes the core of ChatGPT. Now, what I'd like to do is

00:03:15.360 build out something like ChatGPT. But of course, we're not going to be able to reproduce

00:03:20.560 ChatGPT itself. It is a very serious, production-grade system. It is trained on a good chunk of the

00:03:26.880 internet, and then there are a lot of pre-training and fine-tuning stages to it, so it's very

00:03:32.080 complicated. What I'd like to focus on is just training a transformer-based language model.

00:03:38.480 In our case, it's going to be a character-level language model. I still think that is

00:03:43.680 very educational with respect to how these systems work. I don't want to train on a chunk of the

00:03:48.080 internet; we need a smaller dataset. In this case, I propose that we work with my favorite toy

00:03:53.680 dataset. It's called tiny Shakespeare, and it is basically a concatenation of all of the

00:03:59.280 works of Shakespeare, in my understanding. So this is all of Shakespeare in a single file.

00:04:05.040 This file is about one megabyte, and it's just all of Shakespeare. What we are going to do now

00:04:10.960 is model how these characters follow each other. So for example,

00:04:15.840 given a chunk of these characters like this, that is, given some context of characters in the past,

00:04:22.320 the transformer neural network will look at the characters that I've highlighted, and it's going

00:04:26.320 to predict that "g" is likely to come next in the sequence. It's going to do that because we're

00:04:31.040 going to train that transformer on Shakespeare, and it's just going to try to produce character

00:04:36.960 sequences that look like this. In that process, it's going to model all the patterns inside this

00:04:41.920 data. Once we've trained the system, just to give you a preview, we can generate infinite

00:04:48.720 Shakespeare. Of course, it's fake text that looks kind of like Shakespeare. Apologies,

00:04:57.920 there's some jank in here that I'm not able to resolve. But you can see how this is going character

00:05:05.440 by character, and it's kind of predicting Shakespeare-like language. So, "Verily, my lord,

00:05:11.520 the sights have left the king coming with my curses with precious pale." And then Tranio

00:05:19.520 says something else, etc. This is just coming out of the transformer in a very similar manner

00:05:24.400 as it would come out of ChatGPT. In our case, it's character by character; in ChatGPT,

00:05:29.440 it's coming out at a token-by-token level. Tokens are these little

00:05:34.720 sub-word pieces, so they're not word level; they're more like word-chunk level.

00:05:38.480 Now, I've already written the entire code to train these transformers, and it is in a

00:05:48.320 GitHub repository called nanoGPT. nanoGPT is a repository that

00:05:54.240 you can find on my GitHub, and it's a repository for training transformers on any given text.

00:06:00.400 What I think is interesting about it is that, while there are many ways to train transformers,

00:06:04.640 this is a very simple implementation: it's just two files of 300 lines of code each.

00:06:09.920 One file defines the GPT model, the transformer, and one file trains it on some given text dataset.

00:06:16.320 Here I'm showing that if you train it on the OpenWebText dataset, which is a fairly

00:06:20.160 large dataset of web pages, then I reproduce the performance of GPT-2. GPT-2 is an early version

00:06:28.320 of OpenAI's GPT, from 2019. So far I've only reproduced the smallest,

00:06:35.760 124-million-parameter model. But basically, this is just proving that the code base is correctly

00:06:40.240 arranged, and I'm able to load the neural network weights that OpenAI released later.

00:06:46.960 So you can take a look at the finished code in nanoGPT, but what I would like to do in this

00:06:51.840 lecture is write this repository from scratch. So we're going to begin

00:06:57.840 with an empty file and define a transformer piece by piece. We're going to train

00:07:04.000 it on the tiny Shakespeare dataset, and we'll see how we can then generate infinite Shakespeare.

00:07:09.520 And of course, this can be copy-pasted to any arbitrary text dataset that you like.

00:07:14.800 But my goal here is really just to make you understand and appreciate how ChatGPT works under the

00:07:19.440 hood. All that's required is proficiency in Python and some basic

00:07:27.040 understanding of calculus and statistics. It would also help if you've seen my previous videos on the

00:07:33.440 same YouTube channel, in particular my makemore series, where I define smaller and simpler

00:07:40.960 neural network language models, such as multilayer perceptrons. That series really introduces

00:07:46.000 the language modeling framework. Here in this video, we're going to focus on the transformer

00:07:50.720 neural network itself. Okay, so I created a new Google Colab Jupyter notebook here.

00:07:56.240 This will allow me to easily share the code that we're going to develop together

00:08:01.040 with you later, so you can follow along. It will be in the video description.

00:08:04.720 Now here, I've just done some preliminaries. I downloaded the dataset, the tiny Shakespeare

00:08:10.320 dataset, at this URL, and you can see that it's about a one-megabyte file. Then here I open the

00:08:15.840 input.txt file and just read all the text into a string. We see that we are working with

00:08:20.800 roughly 1 million characters. And the first 1000 characters, if we just print them out,

00:08:26.000 are basically what you would expect: this is the first 1000 characters of the tiny Shakespeare

00:08:30.640 dataset, roughly up to here. So far, so good. Next, we're going to take this text.

00:08:37.280 The text is a sequence of characters in Python, so when I call the set constructor on it,

00:08:42.080 I get the set of all the characters that occur in this text. Then I call list on

00:08:49.600 that to create a list of those characters instead of just a set, so that I have some arbitrary

00:08:54.800 ordering, and then I sort it. So basically, we get all the characters that occur in the

00:09:00.560 entire dataset, sorted. The number of them is going to be our vocabulary size. These

00:09:05.680 are the possible elements of our sequences. When I print the characters here, we see that

00:09:10.800 there are 65 of them in total: a space character, then all kinds of special characters,

00:09:17.200 and then capital and lowercase letters. So that's our vocabulary: the

00:09:23.360 possible characters that the model can see or emit. Okay, so next, we would like to develop

00:09:30.240 some strategy to tokenize the input text. When people say tokenize, they mean convert the raw

00:09:37.360 text, as a string, to some sequence of integers according to some vocabulary

00:09:43.520 of possible elements. As an example, here we are going to be building a character-level

00:09:48.640 language model, so we're simply going to be translating individual characters into integers.

00:09:52.640 Let me show you a chunk of code that does that for us. We're building both the

00:09:58.080 encoder and the decoder, so let me just talk through what's happening here. When we encode an

00:10:04.240 arbitrary text, like "hi there", we receive a list of integers that represents that

00:10:10.640 string, for example 46, 47, etc. We also have the reverse mapping, so we can take

00:10:17.680 this list and decode it to get back the exact same string. So it's really just a translation

00:10:23.760 to integers and back for an arbitrary string, and for us it is done on a character level.
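To make the lookup-table idea concrete, here is a minimal sketch in Python. It assumes a short stand-in string rather than the real input.txt, and the names stoi, itos, encode, and decode are illustrative choices; because the vocabulary comes from the stand-in, the ids will not match the 46, 47 quoted above.

```python
# Stand-in corpus (assumption): the lecture reads the full tiny Shakespeare file instead.
text = "First Citizen:\nBefore we proceed any further, hear me speak."

# Vocabulary: every distinct character, sorted so the ordering is deterministic.
chars = sorted(list(set(text)))
vocab_size = len(chars)

# Lookup tables in both directions.
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer
itos = {i: ch for i, ch in enumerate(chars)}  # integer -> character


def encode(s):
    """Translate each character of s into its integer id."""
    return [stoi[c] for c in s]


def decode(ids):
    """Map integer ids back to characters and concatenate them."""
    return "".join(itos[i] for i in ids)


ids = encode("hi there")
print(vocab_size, ids)
print(decode(ids))
```

The round trip decode(encode(s)) returns s for any string made only of characters in the vocabulary; a character outside it would raise a KeyError in this sketch.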

00:10:28.800 The way this was achieved is that we iterate over all the characters here and create a lookup

00:10:34.720 table from character to integer and vice versa. Then, to encode some string,

00:10:40.000 we simply translate all the characters individually, and to decode it back, we use the reverse mapping

00:10:45.600 and concatenate all of it. Now, this is only one of many possible encodings, or many possible

00:10:51.360 tokenizers, and it's a very simple one. There are many other schemes that people have come up with

00:10:56.560 in practice. For example, Google uses SentencePiece. SentencePiece also encodes text into

00:11:03.280 integers, but in a different scheme and using a different vocabulary. SentencePiece is a

00:11:10.720 sub-word tokenizer. What that means is that you're not encoding entire words, but you're

00:11:17.040 also not encoding individual characters; it works at the sub-word unit level. That's usually what's

00:11:23.040 adopted in practice. As another example, OpenAI has a library called tiktoken that uses a

00:11:28.320 byte pair encoding tokenizer, and that's what GPT uses. You can use it to encode words,

00:11:35.680 like "hello world", into lists of integers. As an example, I'm using the tiktoken library here.

00:11:42.080 I'm getting the encoding that was used for GPT-2. Instead of just having 65

00:11:47.600 possible characters or tokens, it has about 50,000 tokens. So when we encode the exact same

00:11:54.720 string, "hi there", we only get a list of three integers. But those integers are not between 0

00:12:00.480 and 64; they are between 0 and 50,256. So basically, you can trade off the codebook size

00:12:10.320 against the sequence length: you can have very long sequences of integers with very small

00:12:15.040 vocabularies, or short sequences of integers with very large vocabularies. And

00:12:23.520 in practice, people typically use these sub-word encodings. But I'd like to keep our tokenizer very

00:12:29.840 simple, so we're using a character-level tokenizer. That means that we have very small codebooks.

00:12:34.800 We have very simple encode and decode functions, but we do get very long sequences as a result.

00:12:41.760 That's the level we're going to stick with in this lecture, because it's the simplest

00:12:45.120 thing. Okay, so now that we have an encoder and a decoder, effectively a tokenizer, we can

00:12:50.240 tokenize the entire training set of Shakespeare. Here's a chunk of code that does that. I'm

00:12:55.680 going to start to use the PyTorch library, and specifically torch.tensor from the PyTorch library.

00:13:00.960 We're going to take all of the text in tiny Shakespeare, encode it, and then wrap it into a

00:13:06.080 torch.tensor to get the data tensor. Here's what the data tensor looks like when I look at just

00:13:11.440 the first 1000 elements of it. We see that we have a massive sequence

00:13:16.960 of integers, and this sequence of integers is basically an identical translation of the first

00:13:22.400 1000 characters here. I believe, for example, that 0 is a newline character, and maybe 1

00:13:29.280 is a space; I'm not 100% sure. From now on, the entire dataset of text is re-represented as

00:13:35.360 a single, very large, stretched-out sequence of integers. Let me do one more thing

00:13:41.200 before we move on here. I'd like to separate our dataset into a train and a validation split.

00:13:46.400 In particular, we're going to take the first 90% of the dataset and consider that to be the

00:13:52.400 training data for the transformer, and we're going to withhold the last 10%

00:13:57.440 as the validation data. This will help us understand to what extent our model is overfitting.
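A sketch of the encoding and the 90/10 split described above, assuming PyTorch is installed. The corpus here is a repeated stand-in string and the tokenizer is rebuilt inline so the snippet runs on its own; in the lecture, text would be the full tiny Shakespeare file.

```python
import torch

# Stand-in corpus and character tokenizer (assumption): in the lecture, `text`
# is the whole tiny Shakespeare file and encode() is the tokenizer built earlier.
text = "First Citizen:\nBefore we proceed any further, hear me speak.\n" * 20
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}


def encode(s):
    return [stoi[c] for c in s]


# The entire corpus becomes one long 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# First 90% for training, last 10% held out for validation.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(data.shape, train_data.shape, val_data.shape)
```

Taking the split as a contiguous tail, rather than shuffling characters, keeps the validation text as unseen running prose, which is what makes it useful for spotting overfitting.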

00:14:02.480 We're going to hide the validation data and keep it on the side, because we don't

00:14:06.800 want a perfect memorization of this exact Shakespeare; we want a neural network that

00:14:12.080 creates Shakespeare-like text. So it should assign a fairly high likelihood to the actual

00:14:19.040 stowed-away, true Shakespeare text, and we're going to use this to get a sense of the overfitting.

00:14:27.280 Okay, so now we would like to start plugging these text sequences, or integer sequences, into

00:14:32.000 the transformer so that it can train and learn those patterns. Now, the important thing to realize

00:14:37.680 is that we're never going to feed the entire text into the transformer all at once. That

00:14:41.840 would be computationally very expensive and prohibitive. So when we actually train a transformer on

00:14:46.880 these datasets, we only work with chunks of the dataset. When we train the transformer,

00:14:51.760 we sample random little chunks out of the training set and train on just those chunks at a time.

00:14:57.520 These chunks have some maximum length. Now, the maximum

00:15:04.240 length, at least in the code I usually write, is called block size. You can

00:15:09.200 find it under different names, like context length. Let's start with

00:15:13.680 a block size of just eight. Let me look at the first block size

00:15:19.360 plus one characters of the training data; I'll explain why plus one in a second. So these are the first nine characters

00:15:25.600 of the training set. Now, what I'd like to point out is that when you sample a

00:15:31.040 chunk of data like this, say these nine characters out of the training set, it actually

00:15:36.720 has multiple examples packed into it, and that's because all of these characters follow each other.

00:15:42.640 So what this means when we plug it into a transformer is that we're going to

00:15:49.440 simultaneously train it to make a prediction at every one of these positions. In a

00:15:55.680 chunk of nine characters, there are actually eight individual examples packed in there.

00:16:01.280 There's the example that in the context of 18, 47 likely comes next; in the context

00:16:08.480 of 18 and 47, 56 comes next; in the context of 18, 47, 56, 57 can come next; and so on. So those are the

00:16:17.680 eight individual examples. Let me actually spell it out with code. Here's a chunk of code to

00:16:23.680 illustrate. x are the inputs to the transformer; it will just be the first block size characters.

00:16:29.280 y will be the next block size characters, offset by one. That's because y are the

00:16:37.440 targets for each position in the input. Then here I'm iterating over all block size (eight) positions.

00:16:44.080 The context is always all the characters in x up to and including t, and the target is

00:16:51.840 always the t-th character, but in the targets array y. So let me just run this. And basically,

00:16:58.800 it spells out what I said in words: these are the eight examples hidden in a chunk of nine characters

00:17:05.280 that we sampled from the training set. I want to mention one more thing. We train on all

00:17:13.120 eight examples here, with contexts ranging from one all the way up to block size. And we train

00:17:19.440 on that not just for computational reasons, because we happen to have the sequence already or something

00:17:23.200 like that; it's not just for efficiency. It's also done to make the transformer network

00:17:29.440 used to seeing contexts anywhere from as little as one all the way to block size, and we'd

00:17:36.240 like the transformer to be used to seeing everything in between. That's going to be useful later

00:17:40.560 during inference, because when we're sampling, we can start the generation with

00:17:45.280 as little as one character of context, and the transformer knows how to predict the next character

00:17:50.000 from a context of just one. From there, it can predict everything up to block size.
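The eight context/target pairs hidden in one chunk of block size plus one tokens can be spelled out in plain Python. This is a sketch: only the first four ids come from the example above, the rest are placeholders, and the last lines show one simple way (an assumption, not necessarily the lecture's exact code) to keep a sampling context within block size by keeping only its last block size tokens.

```python
block_size = 8

# One sampled chunk of block_size + 1 token ids. The first four values are the
# ones quoted above; the remaining ones are illustrative placeholders.
chunk = [18, 47, 56, 57, 58, 1, 15, 47, 58]

x = chunk[:block_size]         # inputs: the first block_size tokens
y = chunk[1 : block_size + 1]  # targets: the same window shifted by one

examples = []
for t in range(block_size):
    context = x[: t + 1]  # all tokens up to and including position t
    target = y[t]         # the token that follows this context
    examples.append((context, target))
    print(f"when input is {context} the target is {target}")

# Once a sampling context grows past block_size, one simple option is to keep
# only its last block_size tokens before feeding it to the model.
long_context = list(range(20))
cropped = long_context[-block_size:]
```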

00:17:55.280 After block size, we have to start truncating, because the transformer will never

00:17:59.120 receive more than block size inputs when it's predicting the next character. Okay, so we've looked

00:18:05.280 at the time dimension of the tensors that are going to be feeding into the transformer. There's

00:18:09.680 one more dimension to care about, and that is the batch dimension. As we're sampling these

00:18:14.160 chunks of text, every time we feed them into a transformer,

00:18:19.120 we're going to have batches of multiple chunks of text that are all stacked up

00:18:23.120 in a single tensor. That's just done for efficiency, so that we can keep the GPUs busy,

00:18:27.680 because they are very good at parallel processing of data. So we want to process multiple

00:18:35.120 chunks all at the same time. But those chunks are processed completely independently; they don't

00:18:39.120 talk to each other. So let me generalize this and introduce a batch

00:18:44.080 dimension. Here's a chunk of code. Let me just run it, and then I'm going to explain what it does.

00:18:49.440 Because we're going to start sampling random locations in the dataset to pull chunks

00:18:55.760 from, I am setting the seed of the random number generator, so that the numbers I see here

00:19:02.240 are going to be the same numbers you see later if you try to reproduce this. Now, the batch size

00:19:07.040 here is how many independent sequences we are processing in every forward-backward pass of the

00:19:11.680 transformer. The block size, as I explained, is the maximum context length to make those predictions.

00:19:17.600 So let's say batch size four, block size eight. And then here's how we get a batch

00:19:22.160 for any arbitrary split. If the split is the training split, then we're going to look at train_data;

00:19:27.520 otherwise, at val_data. That gets us the data array. Then, when I generate random positions

00:19:35.280 to grab a chunk out of, I actually generate batch size number of random offsets.

00:19:42.400 Because batch size is four, ix is going to be four numbers that are randomly generated

00:19:49.040 between 0 and len(data) minus block size. So it's just random offsets into the training set.

00:19:54.400 Then x is, as I explained, the first block size characters starting at i. The y's

00:20:03.600 are the same chunks offset by one, so we just add plus one. Then we're going to get those chunks

00:20:10.160 for every one of the integers i in ix, and use torch.stack to take all those one-dimensional

00:20:18.240 tensors, as we saw here, and stack them up as rows. So they all become a row

00:20:26.720 in a four-by-eight tensor. Here's where I'm printing them. When I sample a batch xb and yb,

00:20:33.680 the input to the transformer now, x, is the four-by-eight tensor: four rows of eight

00:20:43.600 columns. Each one of these is a chunk of the training set. And then the targets here are

00:20:51.840 in the associated array y, and they will come into the transformer all the way at the end to

00:20:56.800 create the loss function. So they will give us the correct answer for every single position

00:21:02.880 inside x. These are the four independent rows. So, spelled out as we did before,

00:21:10.880 this four-by-eight array contains a total of 32 examples, and they're completely independent

00:21:18.560 as far as the transformer is concerned. So when the input is 24, the target is 43

00:21:26.000 (here in the y array); when the input is 24, 43, the target is 58; when the input is 24, 43,

00:21:34.320 58, the target is 5; etc. Or, when the input is 52, 58, 1, the target is 58.
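A sketch of the batching function walked through above, assuming PyTorch. Random stand-in tensors replace the real train/validation splits, the seed value is arbitrary, and the names get_batch, train_data, and val_data are used for illustration.

```python
import torch

torch.manual_seed(1337)  # fixed seed so the sampled offsets are reproducible
batch_size = 4  # independent sequences processed in parallel
block_size = 8  # maximum context length for predictions

# Stand-in splits (assumption): in the lecture these come from the 90/10 split
# of the encoded tiny Shakespeare tensor.
train_data = torch.randint(0, 65, (1000,))
val_data = torch.randint(0, 65, (100,))


def get_batch(split):
    data = train_data if split == "train" else val_data
    # batch_size random starting offsets, each leaving room for block_size + 1 tokens
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # Stack the 1-D chunks as rows of a (batch_size, block_size) tensor.
    x = torch.stack([data[i : i + block_size] for i in ix])
    # Targets are the same chunks shifted one position to the right.
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


xb, yb = get_batch("train")
print(xb.shape, yb.shape)
```

Because y is x shifted by one, every column of y holds the correct next token for the context ending at the same column of x, which is how a single (4, 8) batch carries 32 training examples.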

00:21:40.400 Right. So you can see this spelled out. These are the 32 independent examples

00:21:46.560 packed into a single batch of the input x, and the desired targets are in y. So now this

00:21:54.640 integer tensor x is going to feed into the transformer, and that transformer is going to

00:22:03.040 simultaneously process all these examples and then look up the correct integers to predict in

00:22:08.640 every one of these positions in the tensor y. Okay, so now that we have our batch of input that

00:22:14.160 we'd like to feed into a transformer, let's start feeding this into neural networks.

00:22:18.240 We're going to start off with the simplest possible neural network, which in the case of

00:22:22.480 language modeling, in my opinion, is the bigram language model. We've covered the

00:22:25.920 bigram language model in a lot of depth in my makemore series, so here I'm going to

00:22:30.880 go faster and just directly implement a PyTorch module for the

00:22:35.520 bigram language model. I'm importing PyTorch's nn module, and I'm setting a seed for reproducibility.

00:22:43.120 Then here I'm constructing a bigram language model, which is a subclass of

00:22:47.200 nn.Module. Then I'm calling it, passing in the inputs and the targets.

00:22:52.640 And I'm just printing. Now, when the inputs and targets come in here, you see that I'm just taking

00:22:58.240 the index, the inputs x here, which I renamed to idx, and passing them into this token

00:23:04.400 embedding table. What's going on here is that in the constructor, we are creating a token

00:23:10.560 embedding table of size vocab size by vocab size. We're using nn.Embedding,

00:23:17.920 which is a very thin wrapper around basically a tensor of shape vocab size by vocab size.

00:23:23.280 What's happening here is that when we pass idx in, every single integer in our input

00:23:29.520 is going to refer to this embedding table and pluck out the row of that embedding table

00:23:34.880 corresponding to its index. So 24 here will go to the embedding table and pluck out the 24th

00:23:40.800 row. Then 43 will go here and pluck out the 43rd row, etc. Then PyTorch is going to arrange

00:23:47.520 all of this into a batch by time by channel tensor. In this case, batch is four, time is eight,

00:23:55.840 and C, the channels, is vocab size, or 65. So we're just going to pluck out all those

00:24:02.320 rows and arrange them in a B by T by C tensor. Now we're going to interpret this as the logits,

00:24:08.160 which are basically the scores for the next character in a sequence. So what's happening

00:24:13.280 here is that we are predicting what comes next based on just the individual identity of a single token.

00:24:19.280 And you can do that because, currently, the tokens are not talking to each other, and

00:24:24.720 they're not seeing any context; they're just seeing themselves. So I'm a token,

00:24:29.680 say, number five, and I can actually make pretty decent predictions about what comes next just by

00:24:34.960 knowing that I'm token five, because some characters tend to follow other characters in typical

00:24:41.520 scenarios. We saw this in a lot more depth in the makemore series. And here, if I just

00:24:46.800 run this, then we currently get the predictions, the scores, the logits, for every one of the

00:24:54.000 four-by-eight positions. Now that we've made predictions about what comes next, we'd like to evaluate the

00:24:58.400 loss function. In the makemore series, we saw that a good way to measure a loss, or the

00:25:03.600 quality of the predictions, is to use the negative log likelihood loss, which is also implemented in

00:25:08.880 PyTorch under the name cross entropy. So what we'd like to do here is set loss to the cross entropy

00:25:16.160 of the predictions and the targets. This measures the quality of the logits with respect

00:25:21.360 to the targets. In other words, we have the identity of the next character, so how well are we predicting

00:25:27.680 the next character based on the logits? Intuitively, the correct dimension of the logits,

00:25:34.640 depending on whatever the target is, should have a very high number, and all the other dimensions

00:25:39.840 should have very low numbers. Right. So this is what we want:

00:25:45.600 we want to output the logits and the loss. But unfortunately,

00:25:52.640 this won't actually run; we get an error message. Intuitively, though, this is what we want to measure. Now,

00:26:00.720 if we go to the PyTorch cross_entropy documentation here: we're calling cross entropy in

00:26:09.440 its functional form, which means we don't have to create a module for it. But when we go to

00:26:15.200 the documentation, we have to look into the details of how PyTorch expects these inputs.

00:26:20.160 Basically, the issue here is that if you have multi-dimensional input, which we do

00:26:25.760 because we have a B by T by C tensor, then PyTorch really wants the channels to be the second

00:26:32.960 dimension. So basically, it wants a B by C by T instead of a B by T by C.

00:26:42.000 That's just a detail of how PyTorch treats these kinds of inputs, and we don't

00:26:49.440 really want to deal with that. So instead, we're going to reshape

00:26:53.120 our logits. Here's what I like to do: I like to give names to the dimensions.

00:26:58.240 So logits.shape is B by T by C, and I unpack those numbers. Then let's say that logits equals

00:27:05.200 logits.view(B*T, C), so just a two-dimensional array.

00:27:12.800 Right, so we're going to take all of these positions here,

00:27:19.600 stretch them out into a one-dimensional sequence, and preserve the channel

00:27:24.800 dimension as the second dimension. So we're just stretching out the array

00:27:29.280 so it's two-dimensional, and in that case it's going to better conform to what PyTorch

00:27:33.840 expects in its dimensions. Now we have to do the same to the targets, because currently the targets

00:27:40.320 are of shape B by T, and we want them to be just B times T, so one-dimensional. Now,

00:27:48.880 alternatively, you could just use -1, because PyTorch will infer what this

00:27:53.520 should be. But let me just be explicit and say B times T. Once we

00:27:58.400 reshape this, it will match the cross entropy case, and then we should be able to evaluate our loss.
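The reshape needed for F.cross_entropy can be sketched with random stand-in tensors, assuming PyTorch. Flattening to (B*T, C) and (B*T,), or alternatively transposing to (B, C, T), should give the same mean loss; the last line computes the uniform-guess baseline -ln(1/C), which for C = 65 is about 4.17.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 65
logits = torch.randn(B, T, C)          # stand-in for the model's (B, T, C) output
targets = torch.randint(0, C, (B, T))  # stand-in (B, T) target ids

# F.cross_entropy wants (N, C) scores and (N,) class indices, so fold the
# batch and time dimensions together.
loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

# Equivalent alternative: keep 3-D input but move channels to dim 1, i.e. (B, C, T).
loss_bct = F.cross_entropy(logits.transpose(1, 2), targets)

# Loss of a perfectly uniform guess over C classes: -ln(1/C) = ln(C).
uniform_loss = -math.log(1 / C)
print(loss.item(), loss_bct.item(), uniform_loss)
```

Flattening is the choice made in the lecture; the transpose variant is shown only to make the (B, C, T) expectation from the documentation concrete.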

00:28:06.880 Okay, so if we run that now, we can print the loss, and currently we see that the loss is 4.87.

00:28:13.520 Now, because we have 65 possible vocabulary elements, we can actually guess at what the loss

00:28:20.480 should be. In particular, we covered negative log likelihood in a lot of detail. We are expecting

00:28:26.640 the negative of ln(1/65). So we're expecting the loss to be about

00:28:35.360 4.17, but we're getting 4.87. That's telling us that the initial predictions are not

00:28:41.920 perfectly diffuse, that is, not uniform, and so we're guessing wrong. But at least

00:28:49.680 we're able to evaluate the loss. Okay, so now that we can evaluate the quality of the model on some

00:28:56.080 data, we'd like to also be able to generate from the model. So let's do the generation. Now,

00:29:01.200 I'm going to go a little bit faster here, because I covered all this already in previous

00:29:04.880 videos. So here's a generate function for the model. We take the same kind of

00:29:15.440 input, idx, here. Basically, this is the current context of some characters in a batch,

00:29:24.080 so it's also B by T. The job of generate is to take this B by T and

00:29:30.800 extend it to be B by T plus one, plus two, plus three. So it continues the

00:29:35.200 generation in all the batch dimensions along the time dimension. That's its job, and it will do

00:29:40.800 that for max_new_tokens steps. You can see there's going to be some stuff in between here.

00:29:45.760 But at the bottom, whatever is predicted is concatenated onto the previous idx

00:29:51.200 along dimension 1, which is the time dimension, to create a B by T plus one. That

00:29:56.880 becomes the new idx. So the job of generate is to take a B by T and make it a B by T plus one, plus

00:30:02.000 two, plus three, for as many steps as max_new_tokens. So this is the generation from the model.

00:30:07.280 Now inside generate, what are we doing? We're taking the current indices,

00:30:11.920 we're getting the predictions. So those are in the logits. And then the loss here is going

00:30:19.280 to be ignored, because we're not using that. And we have no targets, no sort of

00:30:24.080 ground truth targets that we're going to be comparing with. Then once we get the logits,

00:30:29.840 we are only focusing on the last step. So instead of a B by T by C, we're going to pluck out the

00:30:36.720 negative one, the last element in the time dimension, because those are the predictions for what comes

00:30:41.600 next. So that gives us the logits, which we then convert to probabilities via softmax. And then we

00:30:47.760 use torch.multinomial to sample from those probabilities. And we ask PyTorch to give us one

00:30:52.640 sample. And so idx_next will become a B by one, because in each one of the batch dimensions,

00:31:00.080 we're going to have a single prediction for what comes next. So this num_samples=1

00:31:04.560 will make this be a one. And then we're going to take those integers that come from the sampling

00:31:10.400 process, according to the probability distribution given here. And those integers just get concatenated

00:31:15.440 on top of the current running stream of integers. And this gives us a B by T plus one.

00:31:21.600 And then we can return that. Now, one thing here is you see how I'm calling self(idx),

00:31:28.160 which will end up going to the forward function. I'm not providing any targets. So currently,

00:31:33.280 this would give an error because targets is not given. So targets has to be

00:31:39.200 optional. So targets is None by default. And then if targets is None, then there's no loss to create.

00:31:46.880 So loss is just None. But otherwise, all of this happens, and we can create a loss. So this will make it so

00:31:53.920 if we have the targets, we provide them and get a loss; if we have no targets, we'll just get the

00:31:59.760 logits. So this here will generate from the model. And let's take that for a ride now.
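A minimal sketch of a bigram model whose forward treats targets as optional and whose generate extends a (B, T) context one sampled token at a time. The class name BigramLanguageModel and the attribute token_embedding_table are illustrative assumptions, not necessarily the exact code from the video; the sampling steps (last time step, softmax, torch.multinomial with num_samples=1, concatenation along dimension 1) follow the walkthrough.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        # each token reads the logits for the next token straight from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)  # (B, T, C)
        if targets is None:
            loss = None  # nothing to compare against, e.g. during generation
        else:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is a (B, T) tensor of indices in the current context
        for _ in range(max_new_tokens):
            logits, _ = self(idx)                # calls forward without targets
            logits = logits[:, -1, :]            # keep only the last time step: (B, C)
            probs = F.softmax(logits, dim=-1)    # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T + 1)
        return idx

torch.manual_seed(1337)
m = BigramLanguageModel(vocab_size=65)
context = torch.zeros((1, 1), dtype=torch.long)   # start from index 0
out = m.generate(context, max_new_tokens=100)     # (1, 101)
tokens = out[0].tolist()   # drop the batch dimension; a plain list, ready for a decode() function
```

Because forward returns None for the loss when no targets are given, generate can call self(idx) directly.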

00:32:07.200 So I have another code cell here, which will generate from the model.

00:32:14.720 And okay, this is kind of crazy. So let me break this down. So this is the idx,

00:32:20.720 right? I'm creating a batch; the batch will be just one, and time will be just one. So I'm creating a little one by

00:32:31.040 one tensor, and it's holding a zero. And the dtype, the data type, is integer. So zero is going

00:32:37.920 to be how we kick off the generation. And remember that zero is the element standing for a

00:32:44.560 newline character. So it's a reasonable thing to feed in as the very first

00:32:48.640 character in a sequence, the newline. So it's going to be idx, which we're going to feed in

00:32:55.520 here. Then we're going to ask for 100 tokens. And then m.generate will continue that.

00:33:01.680 Now, because generate works on the level of batches, we then have to index into the zeroth row to

00:33:09.280 basically pluck out the single batch dimension that exists. And then that gives us the time steps,

00:33:19.760 which is just a one-dimensional array of all the indices, which we will convert to a simple Python list

00:33:25.360 from a PyTorch tensor, so that it can feed into our decode function and convert those integers

00:33:32.800 into text. So let me bring this back. And we're generating 100 tokens. Let's run. And here's the

00:33:40.160 generation that we achieved. So obviously, it's garbage, and the reason it's garbage is because

00:33:44.640 this is a totally random model. So next up, we're going to want to train this model. Now, one more

00:33:49.520 thing I wanted to point out here is this function is written to be general. But it's kind of like

00:33:55.040 ridiculous right now, because we're feeding in all this, we're building out this context,

00:34:00.800 and we're concatenating it all. And we're always feeding it all into the model. But that's kind of

00:34:07.440 ridiculous, because this is just a simple bigram model. So to make, for example, this prediction

00:34:11.360 about K, we only needed this w. But actually, what we fed into the model is we fed the entire

00:34:16.560 sequence. And then we only looked at the very last piece and predicted K. So the only reason I'm

00:34:23.120 writing it in this way is because right now this is a bigram model. But I'd like to keep this

00:34:27.840 function fixed. And I'd like it to work later, when our characters actually look

00:34:35.520 further in the history. And so right now the history is not used. So this looks silly. But

00:34:40.640 eventually the history will be used. And so that's why we want to do it this way. So just a quick

00:34:46.080 comment on that. So now we see that this is random. So let's train the model so it becomes a bit

00:34:52.320 less random. Okay, let's now train the model. So first what I'm going to do is I'm going to create

00:34:57.040 a PyTorch optimization object. So here we are using the optimizer AdamW. Now in the Makemore

00:35:05.120 series, we've only ever used stochastic gradient descent, the simplest possible optimizer, which

00:35:09.040 you can get using SGD instead. But I want to use Adam, which is a much more advanced and

00:35:13.680 popular optimizer, and it works extremely well. A typical good setting for the learning rate is

00:35:19.840 roughly 3e-4. But for very, very small networks, like is the case here,

00:35:24.560 you can get away with much, much higher learning rates, 1e-3 or even higher

00:35:28.480 probably. But let me create the optimizer object, which will basically take the gradients and update

00:35:34.960 the parameters using the gradients. And then here, our batch size up above was only four. So let me

00:35:41.680 actually use something bigger, let's say 32. And then for some number of steps, we are sampling

00:35:47.200 a new batch of data, we're evaluating the loss, we're zeroing out all the gradients from the previous

00:35:52.960 step, getting the gradients for all the parameters, and then using those gradients to update our

00:35:57.920 parameters. So typical training loop, as we saw in the Makemore series. So let me now run this

00:36:04.320 for say 100 iterations, and let's see what kind of losses we're going to get.
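A minimal training-loop sketch. The data is a hypothetical repeating 0..9 pattern (rather than the Shakespeare text), a bare nn.Embedding stands in for the bigram model, and the step count is arbitrary; the batch size of 32 and the learning rate of 1e-3 follow the numbers mentioned above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
vocab_size, block_size, batch_size = 65, 8, 32

# hypothetical toy data: a repeating pattern that a bigram model can learn
data = torch.arange(10).repeat(100)

def get_batch():
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])  # targets are shifted by one
    return x, y

table = nn.Embedding(vocab_size, vocab_size)  # stand-in for the bigram model
optimizer = torch.optim.AdamW(table.parameters(), lr=1e-3)

losses = []
for step in range(1000):
    xb, yb = get_batch()                      # sample a new batch of data
    logits = table(xb)                        # (B, T, C)
    loss = F.cross_entropy(logits.view(-1, vocab_size), yb.view(-1))
    optimizer.zero_grad(set_to_none=True)     # clear the previous step's gradients
    loss.backward()                           # compute gradients for all parameters
    optimizer.step()                          # update the parameters
    losses.append(loss.item())

print(losses[0], losses[-1])
```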

00:36:10.640 So we started around 4.7. And now we're going down to like 4.6, 4.5, etc. So the optimization

00:36:18.320 is definitely happening. But let's try to increase the number of iterations and only print at

00:36:25.120 the end, because we probably will want to train for longer. Okay, so we're down to 3.6 roughly,

00:36:32.240 roughly down to three.

00:36:36.480 This is the most janky optimization.

00:36:43.040 Okay, it's working. Let's just do 10,000.

00:36:49.040 And then from here, we want to copy this. And hopefully, we're going to get something reasonable.

00:36:57.760 And of course, it's not going to be Shakespeare from a bigram model. But at least we see that

00:37:01.680 the loss is improving. And hopefully, we're expecting something a bit more reasonable.

00:37:06.480 Okay, so we're down at about 2.5 ish. Let's see what we get. Okay, dramatic improvement

00:37:13.520 certainly on what we had here. So let me just increase the number of tokens. Okay, so we see

00:37:19.680 that we're starting to get something at least like reasonable ish. Certainly not Shakespeare,

00:37:28.800 but the model is making progress. So that is the simplest possible model.

00:37:32.640 So now what I'd like to do is,

00:37:35.360 obviously, this is a very simple model because the tokens are not talking to each other. So

00:37:41.760 given the previous context of whatever was generated, we're only looking at the very last

00:37:46.160 character to make the predictions about what comes next. So now these, now these tokens have

00:37:50.880 to start talking to each other and figuring out what is in the context so that they can make

00:37:55.760 better predictions for what comes next. And this is how we're going to kick off the transformer.

00:38:00.320 Okay, so next, I took the code that we developed in this Jupyter notebook, and I converted it to

00:38:04.240 be a script. And I'm doing this because I just want to simplify our intermediate work into

00:38:09.760 just the final product that we have at this point. So in the top here, I put all the

00:38:15.200 parameters that we defined. I introduced a few and I'm going to speak to that in a little bit.

00:38:20.000 Otherwise, a lot of this should be recognizable. Reproducibility, read data, get the encoder and

00:38:25.920 the decoder, create the train and test splits, use the kind of like data loader that gets a batch

00:38:33.520 of the inputs and targets. This is new, and I'll talk about it in a second. Now this is the

00:38:39.360 bigram language model that we developed, and it can forward and give us logits and loss,

00:38:44.000 and it can generate. And then here, we are creating the optimizer and this is the training loop.

00:38:50.320 So everything here should look pretty familiar. Now, some of the small things that I added,

00:38:56.000 number one, I added the ability to run on a GPU if you have it. So if you have a GPU, then you can,

00:39:02.320 this will use CUDA instead of just CPU, and everything will be a lot faster. Now,

00:39:07.280 when device becomes CUDA, then we need to make sure that when we load the data, we move it to device.

00:39:13.920 When we create the model, we want to move the model parameters to device. So as an example,

00:39:20.000 here we have the nn.Embedding table, and it's got a .weight inside it, which stores the

00:39:25.200 sort of lookup table. So that would be moved to the GPU, so that all the calculations here

00:39:30.800 happen on the GPU, and they can be a lot faster. And then finally, here, when I'm creating the

00:39:35.360 context that feeds into generate, I have to make sure that I create it on the device.
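A sketch of the device handling, assuming a plain nn.Embedding(65, 65) as a stand-in model and random placeholder batches. The point is that the parameters, the data, and the generation context all have to live on the same device.

```python
import torch
import torch.nn as nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = nn.Embedding(65, 65).to(device)          # moves the .weight lookup table to the device
xb = torch.randint(0, 65, (4, 8)).to(device)     # each batch must be moved as well
context = torch.zeros((1, 1), dtype=torch.long, device=device)  # created directly on the device

logits = model(xb)   # the computation happens on whichever device holds the tensors
print(logits.device)
```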

00:39:40.320 Number two, what I introduced is the fact that here in the training loop,

00:39:45.680 here I was just printing loss.item() inside the training loop. But this is a very

00:39:53.600 noisy measurement of the current loss, because every batch will be more or less lucky. And so

00:39:58.880 what I usually want to do is have an estimate_loss function, and estimate_loss basically

00:40:06.000 goes up here, and it averages up the loss over multiple batches. So in particular,

00:40:14.640 we're going to iterate eval_iters times, and we're going to basically get our loss,

00:40:19.200 and then we're going to get the average loss for both splits. And so this will be a lot less noisy.
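A sketch of an estimate_loss-style helper that averages the loss over eval_iters batches per split under torch.no_grad, switching to eval mode for the measurement and back to train mode afterwards. The ToyBigram model and the random get_batch are hypothetical stand-ins; the only assumption about the real model is that it returns (logits, loss) when given targets.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()  # no backward pass here, so PyTorch need not keep intermediate activations
def estimate_loss(model, get_batch, eval_iters):
    out = {}
    model.eval()   # layers such as dropout or batch norm switch to inference behavior
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean().item()  # averaging many batches gives a less noisy number
    model.train()  # back to training mode
    return out

class ToyBigram(nn.Module):  # hypothetical stand-in with the same (logits, loss) interface
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.table(idx)
        B, T, C = logits.shape
        return logits, F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

def get_batch(split):  # hypothetical: random data, the split name is ignored
    return torch.randint(0, 65, (32, 8)), torch.randint(0, 65, (32, 8))

torch.manual_seed(0)
model = ToyBigram(65)
print(estimate_loss(model, get_batch, eval_iters=20))
```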

00:40:24.000 So here, when we call estimate_loss, we're going to report a pretty accurate train and

00:40:29.760 validation loss. Now, when we come back up, you'll notice a few things here. I'm setting the model

00:40:35.520 to evaluation mode. And down here, I'm resetting it back to training mode. Now, right now for our

00:40:41.760 model as is, this doesn't actually do anything, because the only thing inside this model is this

00:40:46.800 nn.Embedding. And this network would behave the same in both

00:40:54.400 evaluation mode and training mode. We have no dropout layers, we have no batch norm layers,

00:40:58.640 etc. But it is a good practice to think through what mode your neural network is in, because some

00:41:04.560 layers will have different behavior at inference time or training time. And there's also this

00:41:11.760 context manager torch.no_grad. And this is just telling PyTorch that everything that happens

00:41:16.000 inside this function, we will not call .backward on. And so PyTorch can be a lot more

00:41:21.680 efficient with its memory use, because it doesn't have to store all the intermediate variables,

00:41:26.640 because we're never going to call backward. And so it can be a lot more memory

00:41:30.480 efficient in that way. So it's also good practice to tell PyTorch when we don't intend to do

00:41:35.840 backpropagation. So right now, this script is about 120 lines of code, and that's kind of our starter

00:41:44.000 code. I'm calling it bigram.py, and I'm going to release it later. Now running this script

00:41:49.760 gives us output in the terminal, and it looks something like this. It basically, as I ran this

00:41:56.400 code, it was giving me the train loss and val loss. And we see that we converge to somewhere around

00:42:01.280 2.5 with the bigram model. And then here's the sample that we produced at the end.

00:42:06.800 And so we have everything packaged up in the script, and we're in a good position now to iterate on

00:42:12.640 this. Okay, so we are almost ready to start writing our very first self attention block for

00:42:18.160 processing these tokens. Now, before we actually get there, I want to get you used to a mathematical

00:42:25.520 trick that is used in the self attention inside a transformer. And it's really just like at the

00:42:30.160 heart of an efficient implementation of self attention. And so I want to work with this toy

00:42:35.760 example to just get used to this operation. And then it's going to make it much more clear once we

00:42:40.320 actually get to it in the script again. So let's create a B by T by C, where B, T and C are just

00:42:48.560 4, 8 and 2 in this toy example. And these are basically channels. And we have batches, and we have

00:42:55.520 the time component, and we have some information at each point in the sequence. So C. Now, what we

00:43:02.480 would like to do is we would like these tokens. So we have up to eight tokens here in a batch.

00:43:08.480 And these eight tokens are currently not talking to each other. And we would like them to talk to

00:43:12.080 each other. We'd like to couple them. And in particular, we want to couple them in a

00:43:18.400 very specific way. So the token, for example, at the fifth location, it should not communicate with

00:43:24.400 tokens in the sixth, seventh, and eighth location, because those are future tokens in the sequence.

00:43:30.160 The token on the fifth location should only talk to the one in the fourth, third, second, and first.

00:43:35.200 So information only flows from previous context to the current time step. And we

00:43:41.440 cannot get any information from the future because we are about to try to predict the future.

00:43:46.400 So what is the easiest way for tokens to communicate? Okay, the easiest way, I would say, is: if

00:43:53.840 we're the fifth token, and I'd like to communicate with my past, the simplest way we can

00:43:58.720 do that is to just do an average of all the preceding elements.

00:44:06.000 So for example, if I'm the fifth token, I would like to take the channels that make up the

00:44:12.080 information at my step, but then also the channels from the fourth step, third step, second step,

00:44:16.880 and the first step, I'd like to average those up. And then that would become sort of like a feature

00:44:21.520 vector that summarizes me in the context of my history. Now, of course, just doing a sum, or

00:44:27.520 like an average is an extremely weak form of interaction, like this communication is extremely

00:44:32.160 lossy. We've lost a ton of information about spatial arrangements of all those tokens. But

00:44:37.200 that's okay for now. We'll see how we can bring that information back later. For now, what we would

00:44:41.600 like to do is for every single batch element independently, for every t-th token in that sequence,

00:44:49.040 we'd like to now calculate the average of all the vectors in all the previous tokens,

00:44:55.200 and also at this token. So let's write that out. I have a small snippet here. And instead of just

00:45:02.000 fumbling around, let me just copy paste it and talk to it. So in other words, we're going to create

00:45:08.400 xbow, and bow is short for bag of words, because bag of words is kind of like a term that people

00:45:16.960 use when you are just averaging up things. So this is just a bag of words. Basically, there's a word

00:45:21.760 stored on every one of these eight locations, and we're doing a bag of words, just averaging.

00:45:26.080 So in the beginning, we're going to say that it's just initialized at zero. And then I'm doing a

00:45:31.200 for loop here. So we're not being efficient yet. That's coming. But for now, we're just iterating

00:45:35.760 over all the batch dimensions independently, iterating over time. And then the previous

00:45:41.680 tokens are at this batch dimension, and then everything up to and including the t-th token.

00:45:49.120 So when we slice out x in this way, xprev becomes of shape how many elements there were in the past,

00:45:59.520 and then of course C, so all the two-dimensional information from these little tokens.

00:46:05.040 So that's the previous sort of chunk of tokens from my current sequence. And then I'm just doing

00:46:12.560 the average or the mean over the zero dimension. So I'm averaging out the time here. And I'm just

00:46:18.960 going to get a little C one dimensional vector, which I'm going to store in x bag of words.
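The loop version, sketched with the toy sizes used here (B, T, C = 4, 8, 2). xbow[b, t] holds the mean of x[b, 0..t], so the first time step equals x itself and the last one is the mean over the whole sequence.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2            # batch, time, channels
x = torch.randn(B, T, C)

# version 1: explicit loops, xbow[b, t] = mean of x[b, 0..t]
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]               # (t+1, C): everything up to and including token t
        xbow[b, t] = torch.mean(xprev, 0)  # average over time, giving a (C,) vector
```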

00:46:24.160 So I can run this. And this is not going to be very informative, because

00:46:31.040 let's see. So this is x of zero. So this is the zeroth batch element. And then xbow at zero.

00:46:36.320 Now, you see how at the first location here, you see that the two are equal. And that's because

00:46:44.480 we're just doing an average of this one token. But here, this one is now an average of these two.

00:46:50.560 And now this one is an average of these three. And so on. So, and this last one is the average

00:47:01.040 of all of these elements. So vertical average, just averaging up all the tokens, now gives this

00:47:06.320 outcome here. So this is all well and good. But this is very inefficient. Now, the trick is that we

00:47:13.200 can be very, very efficient about doing this using matrix multiplication. So that's the mathematical

00:47:18.560 trick. And let me show you what I mean. Let's work with the toy example here. Let me run it and I'll

00:47:23.760 explain. I have a simple matrix here that is three by three of all the ones, a matrix B of just random

00:47:31.200 numbers. And it's a three by two, and a matrix C, which will be three by three multiply three by two,

00:47:36.640 which will give out a three by two. So here we're just using matrix multiplication. So A multiply B

00:47:44.480 gives us C. Okay. So how are these numbers in C achieved? Right. So this number in the

00:47:53.600 top left is the first row of A dot product with the first column of B. And since the first row of A

00:48:01.760 right now is all just ones, the dot product here with this column of B is just going to do

00:48:07.600 a sum of this column. So two plus six plus six is 14. The element here in the output of C

00:48:15.360 is the first row of A multiplied now with the second column of B. So

00:48:21.600 seven plus four plus five is 16. Now you see that there's repeating elements here. So this 14

00:48:26.960 again is because this row is again all ones, and it's multiplying the first column of B. So you get

00:48:31.920 14. And this one is, and so on. So this last number here is the last row dot product last column.
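A small check of this arithmetic, using the values of B implied by the walkthrough (columns 2, 6, 6 and 7, 4, 5). With A all ones, every row of C is just the column sums of B.

```python
import torch

a = torch.ones(3, 3)
b = torch.tensor([[2., 7.],
                  [6., 4.],
                  [6., 5.]])
c = a @ b   # (3, 3) @ (3, 2) -> (3, 2)
print(c)    # every row is [14., 16.]
```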

00:48:39.120 Now the trick here is the following. This is just a boring array of all

00:48:48.080 ones. But torch has this function called tril, which is short for triangular

00:48:53.520 lower, something like that. And you can wrap it around torch.ones, and it will just return the

00:48:59.360 lower triangular portion of this. Okay. So now it will basically zero out these guys here. So we

00:49:07.680 just get the lower triangular part. Well, what happens if we do that? So now we'll have A like this

00:49:17.200 and B like this. And now what are we getting here in C? Well, what is this number? Well, this is the

00:49:22.400 first row times the first column. And because this is zeros, these elements here are now ignored. So

00:49:30.720 we just get a two. And then this number here is the first row times the second column. And because

00:49:36.880 these are zeros, they get ignored. And it's just seven, the seven multiplies this one. But look what

00:49:42.720 happened here: because this is one and then zeros, what ended up happening is we're just plucking

00:49:47.760 out this row of B. And that's what we got. Now here, we have 1, 1, 0. So here, 1, 1, 0 dot product

00:49:57.760 with these two columns will now give us two plus six, which is eight, and seven plus four,

00:50:01.840 which is 11. And because this is one one one, we ended up with the addition of all of them.
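Replacing the ones with torch.tril(torch.ones(3, 3)) turns each output row into a running sum of the rows of B; a sketch with the same B values implied by the walkthrough.

```python
import torch

a = torch.tril(torch.ones(3, 3))   # lower-triangular ones: [[1,0,0],[1,1,0],[1,1,1]]
b = torch.tensor([[2., 7.],
                  [6., 4.],
                  [6., 5.]])
c = a @ b
# row t of c is the sum of rows 0..t of b: [2, 7], [8, 11], [14, 16]
print(c)
```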

00:50:08.640 And so basically, depending on how many ones and zeros we have here, we are basically doing a sum

00:50:14.800 currently of a variable number of these rows. And that gets deposited into C. So currently,

00:50:22.080 we're doing sums, because these are ones, but we can also do averages, right? And you can start

00:50:27.120 to see how we could average the rows of B in an incremental fashion, because we

00:50:34.320 can basically normalize these rows so they sum to one, and then we're going to get

00:50:39.040 an average. So if we took a, and then we did a equals a divided by torch.sum of

00:50:47.040 a in the one dimension, with keepdim=True, so that the broadcasting will

00:50:57.600 work out. So if I rerun this, you see now that these rows now sum to one. So this row is one,

00:51:04.720 this row is 0.5, 0.5, 0. And here we get one thirds. And now when we do A multiply B,

00:51:10.880 what are we getting? Here we are just getting the first row, first row. Here now we are getting

00:51:16.560 the average of the first two rows. Okay, so two and six average to four, and seven and four

00:51:23.920 average to 5.5. And on the bottom here, we're now getting the average of these three rows.
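Dividing each row of the lower-triangular matrix by its sum makes the same product produce running averages instead of running sums; a sketch with the same B values. keepdim=True keeps the row sums as a (3, 1) column so the division broadcasts across each row.

```python
import torch

a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)   # rows: [1, 0, 0], [0.5, 0.5, 0], [1/3, 1/3, 1/3]
b = torch.tensor([[2., 7.],
                  [6., 4.],
                  [6., 5.]])
c = a @ b
# rows of c: [2, 7], [4, 5.5], [14/3, 16/3], the running averages of the rows of b
print(c)
```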

00:51:30.480 So the average of all of elements of B are now deposited here. And so you can see that by manipulating

00:51:38.400 these elements of this multiplying matrix, and then multiplying it with any given matrix,

00:51:45.280 we can do these averages in this incremental fashion,

00:51:51.280 and we can manipulate that based on the elements of A. Okay, so that's very convenient. So let's

00:51:56.640 swing back up here and see how we can vectorize this and make it much more efficient using what

00:52:00.880 we've learned. So in particular, we are going to produce an array a, but here I'm going to call

00:52:07.680 it wei, short for weights. But this is our A. And this is how much of every row we want to

00:52:15.440 average up. And it's going to be an average because you can see that these rows sum to one.

00:52:20.880 So this is our A. And then our B in this example, of course, is X. So what's going to happen here now,

00:52:29.040 is that we are going to have an xbow2. And this xbow2 is going to be wei multiplying

00:52:36.560 our x. So let's think this through: wei is T by T. And this is matrix multiplying in PyTorch

00:52:45.200 a B by T by C. And it's giving us what shape? So PyTorch will come here and it will see that

00:52:53.360 these shapes are not the same. So it will create a batch dimension here. And this is a batch

00:52:58.880 matrix multiply. And so it will apply this matrix multiplication in all the batch elements in parallel

00:53:05.520 and individually. And then for each batch element, there will be a T by T multiplying T by C,

00:53:12.560 exactly as we had below. So this will now create B by T by C. And xbow2 will now become identical

00:53:24.480 to xbow. So we can see that torch.allclose(xbow, xbow2) should be true.

00:53:37.280 So this kind of convinces us that these are in fact the same. So xbow and xbow2,

00:53:44.880 if I just print them. Okay, we're not going to be able to just stare it down. But

00:53:52.720 well, let me try xbow basically just at the zero element and xbow2 at the zero element.

00:53:58.880 So just the first batch. And we should see that this and that should be identical, which they are.
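A sketch comparing the loop version with the batched matrix-multiply version on the toy tensor. wei is (T, T) and PyTorch broadcasts it across the batch dimension. A small tolerance is passed to torch.allclose because the two computations can differ by float rounding.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# version 1 (loops), for reference
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xbow[b, t] = x[b, :t + 1].mean(0)

# version 2: one batched matrix multiply
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)   # (T, T), each row sums to 1
xbow2 = wei @ x   # (T, T) @ (B, T, C) is treated as (B, T, T) @ (B, T, C) -> (B, T, C)

print(torch.allclose(xbow, xbow2, atol=1e-5))
```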

00:54:05.280 Right. So what happened here, the trick is we were able to use batch matrix multiply

00:54:10.240 to do this aggregation really. And it's a weighted aggregation. And the weights are specified in this

00:54:18.640 T by T array. And we're basically doing weighted sums. And these weighted sums are according to

00:54:27.360 the weights inside here that take on sort of this triangular form. And so that means that a token

00:54:33.440 at the T dimension will only get sort of information from the tokens preceding it. So that's exactly

00:54:41.120 what we want. And finally, I would like to rewrite it in one more way. And we're going to see why

00:54:46.400 that's useful. So this is the third version. And it's also identical to the first and second.

00:54:52.240 But let me talk through it. It uses softmax. So tril here is this matrix, lower triangular ones.

00:55:02.640 wei begins as all zero. Okay, so if I just print wei in the beginning, it's all zero.

00:55:09.360 Then I used masked_fill. So what this is doing is wei.masked_fill: it's all zeros, and I'm saying

00:55:18.720 for all the elements where tril == 0, make them be negative infinity.

00:55:24.560 So all the elements where tril is zero will become negative infinity. So this is what we get.

00:55:32.080 And then the final line here is softmax. So if I take a softmax along every single,

00:55:39.600 so dim is negative one, so along every single row, if I do softmax, what is that going to do?

00:55:45.120 Well, softmax is, is also like a normalization operation, right? And so spoiler alert, you get

00:55:56.160 the exact same matrix. Let me bring back the softmax. And recall that in softmax, we're going to

00:56:02.960 exponentiate every single one of these. And then we're going to divide by the sum. And so if we

00:56:09.280 exponentiate every single element here, we're going to get a one. And here we're going to get

00:56:13.280 basically zero, zero, zero, zero, everywhere else. And then when we normalize, we just get one.

00:56:19.440 Here we're going to get one one, and then zeros. And then softmax will again divide,

00:56:25.200 and this will give us 0.5, 0.5, and so on. And so this is also another way to produce

00:56:31.760 this mask. Now the reason that this is a bit more interesting, and the reason we're going to end

00:56:36.320 up using it in self attention, is that these weights here begin with zero. And you can think of this

00:56:44.720 as like an interaction strength, or like an affinity. So basically it's telling us how much

00:56:50.640 of each token from the past do we want to aggregate and average up. And then this line is saying

00:56:58.560 tokens from the future cannot communicate: by setting them to negative infinity, we're saying that we

00:57:04.560 will not aggregate anything from those tokens. And so basically this then goes through softmax,

00:57:10.320 and then the weighted aggregation happens through matrix multiplication.
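A sketch of the softmax version: start from zero affinities, fill the future positions with negative infinity, and softmax each row. Since exp(-inf) is 0, the resulting weights match the normalized lower-triangular matrix up to rounding.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))                         # affinities, all zero for now
wei = wei.masked_fill(tril == 0, float('-inf'))   # the future may not contribute
wei = F.softmax(wei, dim=-1)                      # exp(-inf) = 0; each row normalizes to 1
xbow3 = wei @ x                                   # weighted aggregation: (B, T, C)

# the same weights as dividing the lower-triangular ones by their row sums
wei2 = tril / tril.sum(1, keepdim=True)
print(torch.allclose(wei, wei2, atol=1e-6))
```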

00:57:14.880 And so what this is now is you can think of these as these zeros are currently just set by us to be

00:57:21.440 zero. But a quick preview is that these affinities between the tokens are not going to be just constant

00:57:28.320 at zero. They're going to be data dependent. These tokens are going to start looking at each other,

00:57:33.520 and some tokens will find other tokens more or less interesting. And depending on what their

00:57:38.640 values are, they're going to find each other interesting to different amounts. And I'm going

00:57:42.960 to call those affinities, I think. And then here we are saying the future cannot communicate with

00:57:47.760 the past. We're going to clamp them. And then when we normalize and sum, we're going to aggregate

00:57:54.000 sort of their values, depending on how interesting they find each other. And so that's the preview

00:57:59.280 for self attention. And basically, long story short from this entire section is that you can do

00:58:05.920 weighted aggregations of your past elements by using matrix multiplication with a

00:58:13.280 lower triangular matrix. And then the elements here in the lower triangular part are telling you

00:58:18.800 how much of each element fuses into this position. So we're going to use this trick now to develop

00:58:24.960 the self attention block. So first, let's get some quick preliminaries out of the way.

00:58:28.960 First, the thing I'm kind of bothered by is that you see how we're passing in vocab_size into

00:58:33.920 the constructor. There's no need to do that because vocab_size is already defined up top as a global

00:58:38.720 variable. So there's no need to pass this stuff around. Next, what I want to do is

00:58:44.720 create a level of indirection here, where we don't directly go to

00:58:49.040 the embedding for the logits. But instead we go through this intermediate phase, because we're

00:58:54.480 going to start making that bigger. So let me introduce a new variable, n_embd, short for

00:59:01.360 number of embedding dimensions. So n_embd here will be, say, 32. That was the suggestion from

00:59:10.000 GitHub Copilot, by the way. It also suggested 32, which is a good number. So this is an embedding

00:59:16.080 table of only 32-dimensional embeddings. So then here, this is not going to give us logits

00:59:22.080 directly. Instead, this is going to give us token embeddings. That's what I'm going to call it.

00:59:27.040 And then to go from the token embeddings to the logits, we're going to need a linear layer. So

00:59:31.680 self.lm_head, let's call it, short for language modeling head, is an nn.Linear from n_embd up to

00:59:38.080 vocab_size. And then when we swing over here, we're actually going to get the logits by exactly

00:59:43.760 what Copilot says. Now we have to be careful here because this C and this C are not equal.

00:59:51.440 This C is n_embd and this C is vocab_size. So let's just say that n_embd is equal to C.
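A shape sketch of the new level of indirection, assuming vocab_size = 65 and n_embd = 32 as in the discussion. The embedding table now produces n_embd-dimensional token embeddings, and a linear language-modeling head maps them back to vocab_size logits; the standalone variable names are illustrative.

```python
import torch
import torch.nn as nn

vocab_size, n_embd = 65, 32
B, T = 4, 8

token_embedding_table = nn.Embedding(vocab_size, n_embd)  # token id -> 32-dim vector
lm_head = nn.Linear(n_embd, vocab_size)                   # 32-dim vector -> logits over the vocabulary

idx = torch.randint(0, vocab_size, (B, T))
tok_emb = token_embedding_table(idx)   # (B, T, n_embd)
logits = lm_head(tok_emb)              # (B, T, vocab_size)
print(tok_emb.shape, logits.shape)
```

The two "C"s differ: the embeddings carry n_embd channels, while the logits carry vocab_size channels.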

00:59:57.280 And then this just creates one spurious layer of indirection through a linear layer. But this

01:00:04.400 should basically run. So we see that this runs and this currently looks kind of spurious, but

01:00:16.880 we're going to build on top of this. Now next up: so far, we've taken these indices,

01:00:21.920 and we've encoded them based on the identity of the tokens inside idx. The next thing that

01:00:28.640 people very often do is that we're not just encoding the identity of these tokens, but also

01:00:33.280 their position. So we're going to have a second, position embedding table here. So self.position_embedding_table

01:00:39.440 is an embedding of block_size by n_embd. And so each position from zero to block_size

01:00:45.920 minus one will also get its own embedding vector. And then here, first let me decode B by

01:00:52.080 T from idx.shape. And then here we're also going to have pos_emb, which is the

01:00:58.240 positional embedding, and this is torch.arange. So this will be basically just integers

01:01:03.520 from zero to T minus one. And all of those integers from zero to T minus one get embedded

01:01:08.720 through the table to create a T by C tensor. And then here, this gets renamed to just x, and x will be

01:01:16.960 the addition of the token embeddings with the positional embeddings. And here the broadcasting

01:01:22.880 will work out. So B by T by C plus T by C: this gets right-aligned, a new dimension of

01:01:28.720 one gets added, and it gets broadcast across the batch. So at this point, x holds not just the token

01:01:35.360 identities, but the positions at which these tokens occur. And this is currently not that useful,

01:01:41.040 because of course, we just had a simple bigram model. So it doesn't matter if you're in the fifth

01:01:44.480 position, the second position or wherever, it's all translation invariant at this stage. So this

01:01:49.920 information currently wouldn't help. But as we work on the self attention block, we'll see that

01:01:54.640 this starts to matter. Okay, so now we get to the crux of self attention. So this is probably the

01:02:02.960 most important part of this video to understand. We're going to implement a small self attention

01:02:08.000 for a single individual head, as they're called. We start off where we were, so all of this

01:02:13.600 code is familiar. So right now I'm working with an example where I change the number of channels

01:02:18.560 from two to 32. So we have a four by eight arrangement of tokens, and the

01:02:25.120 information at each token is currently 32-dimensional, but we are just working with random numbers.
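
For reference, here is a minimal sketch of the embedding front end described a little earlier in this section, assuming PyTorch. The sizes (vocab_size = 65, block_size = 8, n_embd = 32), the random idx batch, and the name token_embedding_table are illustrative assumptions, not the exact training setup. It shows how the B by T by C token embeddings and the T by C positional embeddings broadcast when added, and how lm_head maps the result to logits.

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)

# Illustrative sizes (assumptions): 65 token ids, context of 8, 32-dim embeddings
vocab_size, block_size, n_embd = 65, 8, 32

token_embedding_table = nn.Embedding(vocab_size, n_embd)     # identity of each token
position_embedding_table = nn.Embedding(block_size, n_embd)  # one vector per position 0..block_size-1
lm_head = nn.Linear(n_embd, vocab_size)                      # embeddings -> logits

B, T = 4, 8
idx = torch.randint(0, vocab_size, (B, T))                   # (B, T) random token ids

tok_emb = token_embedding_table(idx)                         # (B, T, C)
pos_emb = position_embedding_table(torch.arange(T))          # (T, C)
# (T, C) is right-aligned, gets a leading dim of 1, and broadcasts across the batch
x = tok_emb + pos_emb                                        # (B, T, C)
logits = lm_head(x)                                          # (B, T, vocab_size)
print(x.shape, logits.shape)
```

Every batch element receives the same positional vectors; only the token part differs.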

01:02:31.200 Now we saw here that the code as we had it before does a simple average of all the past

01:02:40.080 tokens and the current token. So it's just the previous information and current information is

01:02:44.880 just being mixed together in an average. And that's what this code currently achieves. And it does

01:02:49.760 so by creating this lower triangular structure, which allows us to mask out this wei matrix that

01:02:56.480 we create. So we mask it out, and then we normalize it. And currently, when we initialize the

01:03:03.840 affinities between all the different sort of tokens or nodes, I'm going to use those terms

01:03:09.040 interchangeably. So when we initialize the affinities between all the different tokens to be zero,

01:03:14.160 then we see that wei gives us this structure where every single row has these

01:03:21.040 uniform numbers. And that's what, in this matrix multiply, makes it so that

01:03:26.960 we're doing a simple average. Now, we don't actually want this to be all uniform, because

01:03:35.200 different tokens will find different other tokens more or less interesting. And we want that to be

01:03:40.560 data dependent. So for example, if I'm a vowel, then maybe I'm looking for consonants in my past.

01:03:46.240 And maybe I want to know what those consonants are, and I want that information to flow to me.

01:03:50.000 And so I want to now gather information from the past, but I want to do it in a data dependent way.
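
A minimal sketch, assuming PyTorch, of the uniform version described above: zero affinities, a lower triangular mask, and a softmax, so each position simply averages itself and everything before it. The random x stands in for the 4 by 8 by 32 toy tensor, and out is a hypothetical name for the result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)                          # random stand-in for token information

tril = torch.tril(torch.ones(T, T))              # 1 where position t may read from position s <= t
wei = torch.zeros(T, T)                          # all affinities start at zero
wei = wei.masked_fill(tril == 0, float('-inf'))  # block the future
wei = F.softmax(wei, dim=-1)                     # each row becomes uniform over positions 0..t
out = wei @ x                                    # (T, T) @ (B, T, C) -> (B, T, C): running averages
```

Because every allowed affinity is the same (zero), nothing here depends on the data; that is exactly the limitation self attention removes.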

01:03:56.240 And this is the problem that self attention solves. Now, the way self attention solves this

01:04:01.120 is the following. Every single node or every single token at each position will emit two vectors.

01:04:08.560 It will emit a query, and it will emit a key. Now, the query vector, roughly speaking, is what am I

01:04:16.960 looking for? And the key vector, roughly speaking, is what do I contain? And then the way we get

01:04:23.920 affinities between these tokens now in a sequence is we basically just do a dot product between the

01:04:30.640 keys and the queries. So my query dot products with all the keys of all the other tokens,

01:04:37.360 and that dot product now becomes wei. And so if the key and the query are sort of aligned,

01:04:46.800 they will interact to a very high amount. And then I will get to learn more about that specific

01:04:52.480 token, as opposed to any other token in the sequence. So let's implement this now.

01:05:01.440 We're going to implement a single, what's called, head of self attention. So this is just one head.

01:05:09.440 There's a hyperparameter involved with these heads, which is the head size. And then here,

01:05:13.760 I'm initializing linear modules, and I'm using bias=False. So these are just going to

01:05:18.960 apply a matrix multiply with some fixed weights. And now let me produce a key and a query, k and q,

01:05:27.520 by forwarding these modules on x. So the size of this will now become B by T by 16, because that is

01:05:36.480 the head size, and the same here, B by T by 16. So you see here that when

01:05:48.480 I forward this linear on top of my x, all the tokens in all the positions in the B by T arrangement,

01:05:55.200 all of them in parallel and independently produce a key and a query. So no communication has happened

01:06:00.800 yet. But the communication comes now: all the queries will dot product with all the keys.

01:06:06.960 So basically what we want now is wei, the affinities between these, to be query

01:06:14.480 multiplying key. But we have to be careful: we can't matrix multiply this directly; we actually need

01:06:19.520 to transpose k, and we also have to be careful because we have the batch dimension.

01:06:26.720 So in particular, we want to transpose the last two dimensions, dimension negative one and dimension

01:06:32.640 negative two. So negative two, negative one. And so this matrix multiply now will basically do the

01:06:40.720 following: B by T by 16 matrix multiplies B by 16 by T to give us B by T by T.

01:06:52.080 Right. So for every row of B, we're now going to have a T by T matrix giving us the affinities.
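
A sketch of the key and query step, assuming PyTorch and the toy sizes from the transcript (B = 4, T = 8, C = 32, head size 16), with random x. The transpose(-2, -1) swaps only the last two dimensions, so the batch dimension is left alone and the product comes out B by T by T.

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
B, T, C = 4, 8, 32
head_size = 16
x = torch.randn(B, T, C)

key = nn.Linear(C, head_size, bias=False)     # plain matrix multiply, no bias
query = nn.Linear(C, head_size, bias=False)
k = key(x)     # (B, T, 16): "what do I contain?"
q = query(x)   # (B, T, 16): "what am I looking for?"

# Transpose only the last two dims; the batch dim stays put.
# (B, T, 16) @ (B, 16, T) -> (B, T, T)
wei = q @ k.transpose(-2, -1)
```

Entry wei[b, i, j] is just the dot product of token i's query with token j's key in batch element b.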

01:07:01.840 And these are now wei. So they're not zeros. They are now coming from this dot product between

01:07:07.760 the keys and the queries. So this can now run; I can run this. And the weighted aggregation

01:07:14.320 now is a function, in a data-dependent manner, of the keys and queries of these nodes. So just

01:07:20.560 inspecting what happened here, wei takes on this form. And you see that before, wei was just

01:07:28.960 a constant. So it was applied in the same way to all the batch elements. But now every single batch

01:07:34.000 element will have a different wei, because every single batch element contains different

01:07:39.360 tokens at different positions. And so this is now data dependent. So when we look at just the

01:07:45.200 zeroth row, for example, in the input, these are the weights that came out. And so you can see now

01:07:51.520 that they're not just exactly uniform. And in particular, as an example here for the last row,

01:07:57.520 this was the eighth token. And the eighth token knows what content it has and it knows what position

01:08:02.880 it's in. And now the eighth token, based on that, creates a query: hey, I'm looking for this kind of

01:08:09.520 stuff. I'm a vowel, I'm in the eighth position, I'm looking for any consonants at positions up to four.

01:08:15.120 And then all the nodes get to emit keys. And maybe one of the channels could be, I am a

01:08:22.080 consonant and I am in the position up to four. And that key would have a high number in that

01:08:27.920 specific channel. And that's how the query and the key, when they dot product, can find each

01:08:32.240 other and create a high affinity. And when they have a high affinity, like say, this token was

01:08:38.080 pretty interesting to this eighth token, when they have a high affinity, then through the softmax,

01:08:45.040 I will end up aggregating a lot of its information into my position. And so I'll get to learn a lot

01:08:50.560 about it. Now, we're looking at wei after this has already happened. Let me erase this

01:09:00.000 operation as well. So let me erase the masking and the softmax, just to show you the under-the-hood

01:09:04.480 internals and how that works. So without the masking and the softmax, wei comes out like this,

01:09:10.480 right? These are the outputs of the dot products. These are the raw outputs, and they take on

01:09:15.360 values from, you know, negative two to positive two, etc. So those are the raw interactions and raw

01:09:22.000 affinities between all the nodes. But now if I'm the fifth node, I will not want to

01:09:27.760 aggregate anything from the sixth node, seventh node and the eighth node. So actually, we use the

01:09:32.480 upper triangular masking. So those are not allowed to communicate. And now we actually want to have

01:09:40.160 a nice distribution. So we don't want to aggregate negative 0.11 of this node; that's crazy.

01:09:46.480 So instead we exponentiate and normalize, and now we get a nice distribution that sums to one.
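
To see the masking and normalization on their own, here is a small sketch assuming PyTorch. The raw matrix is random and only stands in for the q·k scores; the point is that filling the upper triangle with -inf and then applying softmax gives rows that are zero above the diagonal and sum to one.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1337)
T = 8
raw = torch.randn(T, T)   # hypothetical raw affinities standing in for q @ k^T

tril = torch.tril(torch.ones(T, T))
wei = raw.masked_fill(tril == 0, float('-inf'))  # node t may not aggregate from nodes after t
wei = F.softmax(wei, dim=-1)                      # exponentiate and normalize each row
```

Using -inf rather than zero matters: exp(-inf) is exactly 0, so the masked nodes get no weight at all, while the remaining weights are all positive.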

01:09:50.720 And this is telling us now, in a data-dependent manner, how much information to aggregate

01:09:55.600 from any of these tokens in the past. So that's wei, and it's not zeros anymore, but it's

01:10:03.200 calculated in this way. Now there's one more part to a single self attention head. And that is that

01:10:10.560 when we do the aggregation, we don't actually aggregate the tokens exactly. We produce one

01:10:16.240 more value here, and we call that the value. So in the same way that we produced key and query, we're

01:10:23.120 also going to create a value. And then here, we don't aggregate x; we calculate a V, which is just

01:10:34.320 achieved by forwarding this linear on top of x again. And then we output wei multiplied by V.

01:10:42.480 So V holds the elements that we aggregate, the vectors that we aggregate instead of the raw x.

01:10:49.680 And now of course, this will make it so that the output here of the single head will be 16

01:10:54.320 dimensional, because that is the head size. So you can think of x as kind of like private

01:11:00.320 information to this token, if you think about it that way. So x is kind of private to this token.

01:11:05.760 So I'm the fifth token, say, and I have some identity, and my information is kept in vector x.

01:11:12.320 And now for the purposes of the single head, here's what I'm interested in. Here's what I have.

01:11:19.440 And if you find me interesting, here's what I will communicate to you. And that's stored in V.

01:11:24.160 And so V is the thing that gets aggregated for the purposes of this single head between the

01:11:29.840 different nodes. And that's basically the self attention mechanism. This is what it does.
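
Putting the pieces together, a minimal single-head sketch in PyTorch might look like this. The sizes follow the toy example, the inputs are random, and no 1/sqrt(head_size) scaling is applied yet, since scaled attention only comes up a bit later in the transcript. The output is 16-dimensional because it aggregates v rather than the raw x.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32
head_size = 16
x = torch.randn(B, T, C)  # each token's "private" information

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k, q = key(x), query(x)                     # (B, T, 16) each
wei = q @ k.transpose(-2, -1)               # (B, T, T) data-dependent affinities
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)                                # (B, T, 16): what each token communicates
out = wei @ v                               # (B, T, 16): aggregate v, not the raw x
```

The first token can only attend to itself, so its output is exactly its own value vector.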

01:11:37.040 There are a few notes that I would like to make about attention. Number one, attention is a

01:11:43.520 communication mechanism. You can really think about it as a communication mechanism, where you

01:11:48.000 have a number of nodes in a directed graph, where basically you have edges pointing between nodes

01:11:52.960 like this. And what happens is every node has some vector of information, and it gets to aggregate

01:11:59.440 information via a weighted sum from all the nodes that point to it. And this is done in a data

01:12:05.680 dependent manner. So depending on whatever data is actually stored at each node at any point in time.

01:12:10.080 Now, our graph doesn't look like this. Our graph has a different structure. We have eight nodes,

01:12:16.880 because the block size is eight, and there are always eight tokens. And the first node is only

01:12:23.040 pointed to by itself. The second node is pointed to by the first node and itself, all the way up to

01:12:28.640 the eighth node, which is pointed to by all the previous nodes and itself. And so that's the

01:12:34.400 structure that our directed graph has or happens to have in an autoregressive sort of scenario like

01:12:39.680 language modeling. But in principle, attention can be applied to any arbitrary directed graph.
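
Since attention only cares about which nodes point to which, the mask can encode any directed graph. The helper below, graph_attention, is a hypothetical function written for illustration (not from the video's code): it takes a boolean adjacency matrix where True at [i, j] means node i may read from node j. The causal lower triangular mask is one special case; passing no mask gives the fully connected case that the transcript later calls an encoder block. It assumes every node has at least one incoming edge.

```python
import torch
import torch.nn.functional as F

def graph_attention(q, k, v, adj=None):
    """Weighted aggregation of v over a directed graph (hypothetical helper).

    adj[i, j] == True means there is an edge j -> i, i.e. node i may read from node j.
    adj=None means every node reads from every node (no mask at all).
    Assumes every node has at least one incoming edge (e.g. a self-loop);
    a row with no edges would be all -inf and softmax would return NaN.
    """
    wei = q @ k.transpose(-2, -1)                    # (..., T, T) raw affinities
    if adj is not None:
        wei = wei.masked_fill(~adj, float('-inf'))   # drop pairs with no edge
    wei = F.softmax(wei, dim=-1)
    return wei @ v

torch.manual_seed(0)
T, d = 5, 4
q, k, v = torch.randn(2, T, d), torch.randn(2, T, d), torch.randn(2, T, d)

causal = torch.tril(torch.ones(T, T)).bool()       # the autoregressive graph used here
self_only = torch.eye(T, dtype=torch.bool)         # each node only points to itself
out_causal = graph_attention(q, k, v, causal)
out_self = graph_attention(q, k, v, self_only)
out_full = graph_attention(q, k, v)                # fully connected
```

With only self-loops, each node just returns its own value vector, which is a quick sanity check that the mask is being honored.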

01:12:44.160 And it's just a communication mechanism between the nodes. The second note is that

01:12:48.400 there's no notion of space. So attention simply acts over, like, a set of vectors in this

01:12:54.480 graph. And so by default, these nodes have no idea where they are positioned in the space. And that's

01:12:59.440 why we need to encode them positionally and sort of give them some information that is anchored to

01:13:04.400 a specific position so that they sort of know where they are. And this is different from, for

01:13:09.600 example, convolution, because if you run, for example, a convolution operation over some input,

01:13:13.840 there is a very specific sort of layout of the information in space and the convolutional filters

01:13:19.760 sort of act in space. And so it's not like that in attention. In attention, there's just a set of vectors

01:13:26.480 out there in space. They communicate. And if you want them to have a notion of space, you need to

01:13:31.120 specifically add it, which is what we've done when we calculated the positional

01:13:36.800 encodings and added that information to the vectors. The next thing that I hope is very clear is that

01:13:42.320 the elements across the batch dimension, which are independent examples, never talk to each other.
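
A quick check of this claim, assuming PyTorch and the toy sizes: run a hypothetical causal head (the function name head is made up here) on a batch, then replace every example except the first and run it again. The first example's output should not change, because the batched matrix multiply never mixes batch elements.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 32, 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
tril = torch.tril(torch.ones(T, T))

def head(x):
    """Hypothetical causal single head used only for this check."""
    q, k, v = query(x), key(x), value(x)
    wei = q @ k.transpose(-2, -1)                  # batched: one (T, T) matrix per example
    wei = wei.masked_fill(tril == 0, float('-inf'))
    return F.softmax(wei, dim=-1) @ v

x = torch.randn(B, T, C)
out1 = head(x)

x2 = x.clone()
x2[1:] = torch.randn(B - 1, T, C)                  # change every example except the first
out2 = head(x2)
```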

01:13:46.720 They're always processed independently. And this is a batched matrix multiply that applies

01:13:50.880 basically a matrix multiplication in parallel across the batch dimension. So maybe it would be

01:13:55.760 more accurate to say that in this analogy of a directed graph, we really have, because the batch

01:14:01.040 size is four, we really have four separate pools of eight nodes. And those eight nodes only talk to

01:14:06.720 each other. In total, there are 32 nodes being processed, but there are sort of four

01:14:12.080 separate pools of eight. You can look at it that way. The next note is that here in the case of

01:14:17.920 language modeling, we have this specific structure of directed graph where the future tokens will

01:14:24.320 not communicate to the past tokens. But this doesn't necessarily have to be constrained in

01:14:29.040 the general case. And in fact, in many cases, you may want to have all of the nodes talk to each

01:14:34.800 other fully. So as an example, if you're doing sentiment analysis or something like that with a

01:14:39.280 transformer, you might have a number of tokens, and you may want to have them all talk to each other

01:14:44.160 fully. Because later you are predicting, for example, the sentiment of the sentence. And so it's

01:14:49.360 okay for these nodes to talk to each other. And so in those cases, you will use an encoder block of

01:14:55.280 self attention. And all it means for it to be an encoder block is that you will delete this line of code,

01:15:01.920 allowing all the nodes to completely talk to each other. What we're implementing here is sometimes

01:15:06.080 called a decoder block. And it's called a decoder because it is sort of like decoding language. And

01:15:14.480 it's got this autoregressive format where you have to mask with the triangular matrix, so that

01:15:19.760 nodes from the future never talk to the past, because they would give away the answer. And so

01:15:25.760 basically, in encoder blocks, you would delete this and allow all the nodes to talk. In decoder blocks,

01:15:30.800 this will always be present, so that you have this triangular structure. But both are allowed

01:15:35.680 and attention doesn't care. Attention supports arbitrary connectivity between nodes. The next

01:15:40.080 thing I wanted to comment on: you keep hearing me say attention, self attention,

01:15:44.400 etc. There's actually also something called cross attention. What is the difference? So

01:15:48.960 basically, the reason this attention is self attention is because the keys, queries, and the

01:15:56.720 values are all coming from the same source, x. So the same source x produces keys, queries, and

01:16:03.440 values. So these nodes are self attending. But in principle, attention is much more general than

01:16:09.200 that. So for example, in encoder-decoder transformers, you can have a case where the queries are produced

01:16:15.360 from x. But the keys and the values come from a whole separate external source, and sometimes

01:16:20.560 from the encoder blocks that encode some context that we'd like to condition on. And so the keys

01:16:26.000 and the values will actually come from a whole separate source. Those are nodes on the side.

01:16:30.640 And here we're just producing queries, and we're reading off information from the side.

01:16:34.400 So cross attention is used when there's a separate source of nodes we'd like to pull

01:16:41.120 information from into our nodes. And it's self attention if we just have nodes that we'd like

01:16:46.080 to look at each other and talk to each other. So this attention here happens to be self attention.
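
A minimal cross-attention sketch, assuming PyTorch: the queries come from x, while the keys and values come from a separate context tensor. The context length of 10 and the random tensors are assumptions for illustration; in an encoder-decoder model the context would be encoder output. There is no causal mask, since every query may look at every context node, and the scores are scaled by 1/sqrt(head_size) as in the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32   # the nodes doing the reading
T_ctx = 10           # assumed length of the separate source (e.g. encoder output)
head_size = 16

x = torch.randn(B, T, C)
context = torch.randn(B, T_ctx, C)

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(x)        # queries come from our own nodes:     (B, T, 16)
k = key(context)    # keys and values come from the side:  (B, T_ctx, 16)
v = value(context)

wei = F.softmax(q @ k.transpose(-2, -1) * head_size ** -0.5, dim=-1)  # (B, T, T_ctx)
out = wei @ v                                                          # (B, T, 16)
```

The only change from self attention is which tensor feeds the key and value projections; the affinity matrix is now T by T_ctx instead of T by T.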

01:16:51.040 But in principle, attention is a lot more general. Okay, and the last note at this stage is,

01:16:58.880 if we come to the Attention Is All You Need paper here, we've already implemented attention. So

01:17:03.200 given query, key, and value, we've multiplied the query and the key, we've softmaxed it, and then we

01:17:09.520 are aggregating the values. There's one more thing that we're missing here, which is the multiplication by

01:17:13.840 one over the square root of the head size; d_k here is the head size. Why are they doing this

01:17:19.120 one? This is important. So they call it scaled attention, and it's kind of like an important

01:17:25.120 normalization. Basically, the problem is: if you have unit Gaussian inputs, so zero mean, unit

01:17:31.120 variance, k and q are unit Gaussian, and if you just compute wei naively, then you see that the

01:17:36.880 variance of wei will actually be on the order of head size, which in our case is 16.
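
A quick numerical check of this variance argument, assuming PyTorch. The batch and time sizes are enlarged beyond the toy example (an assumption, purely so the variance estimates are stable); q and k are unit Gaussian with head_size = 16. The raw scores should have variance on the order of 16, and multiplying by head_size ** -0.5 should bring it back to about 1.

```python
import torch

torch.manual_seed(1337)
# Larger B and T than the toy example, only to make the estimates stable
B, T, head_size = 32, 64, 16
k = torch.randn(B, T, head_size)   # unit Gaussian keys
q = torch.randn(B, T, head_size)   # unit Gaussian queries

wei = q @ k.transpose(-2, -1)                            # each entry sums 16 unit-variance products
wei_scaled = q @ k.transpose(-2, -1) * head_size ** -0.5  # scaled attention

print(wei.var().item())          # on the order of head_size
print(wei_scaled.var().item())   # close to 1
```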

01:17:41.200 But if you multiply by one over head size square root, so this is square root, and this is one over,

01:17:47.040 then the variance of wei will be one, so it will be preserved. Now, why is this important? You'll

01:17:54.640 notice that wei here will feed into softmax. And so it's really important, especially at

01:18:01.120 initialization, that wei be fairly diffuse. So in our case here, we sort of lucked out, and wei

01:18:08.720 had fairly diffuse numbers here, like this. Now the problem is that because of softmax, if

01:18:16.560 wei takes on very positive and very negative numbers inside it, softmax will actually converge

01:18:22.000 towards one-hot vectors. And so I can illustrate that here. Say we are applying softmax to a

01:18:29.920 tensor of values that are very close to zero, then we're going to get a diffuse thing out of

01:18:33.920 softmax. But the moment I take the exact same thing and I start sharpening it, making it bigger

01:18:39.200 by multiplying these numbers by eight, for example, you'll see that the softmax will start to sharpen.
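
A small sketch of the sharpening effect, assuming PyTorch; the five values are arbitrary small numbers chosen for illustration. Multiplying them by 8 before the softmax concentrates most of the mass on the largest entry.

```python
import torch

scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])   # arbitrary small example values
diffuse = torch.softmax(scores, dim=-1)               # near zero in -> fairly uniform out
peaky = torch.softmax(scores * 8, dim=-1)             # same values scaled up -> sharpens toward the max
print(diffuse)
print(peaky)
```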

01:18:44.080 And in fact, it will sharpen towards the max. So it will sharpen towards whatever number here is

01:18:48.720 the highest. And so basically, we don't want these values to be too extreme, especially at

01:18:53.680 initialization. Otherwise, softmax will be way too peaky. And you're basically aggregating

01:18:58.960 information from like a single node. Every node just aggregates information from a single other

01:19:04.160 node. That's not what we want, especially at initialization. And so the scaling is used just

01:19:09.280 to control the variance at initialization. Okay, so having said all that, let's now take our

01:19:14.320 self attention knowledge and let's take it for a spin. So here in the code, I've created this

01:19:19.840 Head module, and it implements a single head of self attention. So you give it a head size. And then

01:19:25.920 here it creates the key, query, and value linear layers. Typically people don't use biases in these.

01:19:30.960 So those are the linear projections that we're going to apply to all of our nodes.

01:19:35.280 Now here, I'm creating this tril variable. tril is not a parameter of the module. So in sort of

01:19:41.680 PyTorch naming conventions, this is called a buffer. It's not a parameter. And you have to

01:19:46.800 assign it to the module using register_buffer. So that creates the tril,

01:19:50.560 the lower triangular matrix. And given the input x, this should look very familiar now.

01:19:57.120 We calculate the keys and the queries, and we calculate the attention scores inside wei. We normalize it,

01:20:03.600 so we're using scaled attention here. Then we make sure that the future doesn't communicate with the past.

01:20:09.360 So this makes it a decoder block. And then softmax, and then aggregate the values and output.
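
A sketch of a Head module along these lines, assuming PyTorch and assumed hyperparameters block_size = 8 and n_embd = 32. It scales the scores by 1/sqrt(head_size), as described for scaled attention, and slices tril[:T, :T] so shorter inputs also work. Treat it as a sketch of the idea rather than the exact code from the video.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
block_size, n_embd = 8, 32  # assumed hyperparameters

class Head(nn.Module):
    """One head of causal (decoder-style) self-attention."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # tril is not a learnable parameter, so it is registered as a buffer
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)     # (B, T, head_size)
        q = self.query(x)   # (B, T, head_size)
        # scaled scores: (B, T, hs) @ (B, hs, T) -> (B, T, T), times 1/sqrt(hs)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # decoder block: the future does not communicate with the past
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)   # (B, T, head_size)
        return wei @ v      # (B, T, head_size)

head = Head(16)
out = head(torch.randn(4, 8, n_embd))        # (4, 8, 16)
short = head(torch.randn(2, 5, n_embd))      # T = 5 < block_size also works
```

Registering tril as a buffer means it moves with the module (for example to a GPU) and is saved in its state dict, but the optimizer never sees it.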

01:20:14.800 Then here in the language model, I'm creating a head in the constructor, and I'm calling it self.sa_head,

01:20:21.120 the self attention head. And the head size I'm going to keep the same as n_embd, just for now.

01:20:26.720 And then here, once we've encoded the information with the token embeddings and the position embeddings,

01:20:34.160 we're simply going to feed it into the self attention head. And then the output of that is going to go

01:20:38.960 into the decoder language modeling head and create the logits. So this is sort of the simplest way

01:20:46.240 to plug in a self attention component into our network right now. I had to make one more change,

01:20:52.640 which is that here in generate, we have to crop the idx that we feed into the model,

01:21:00.160 because now that we're using positional embeddings, we can never have more than block_size tokens coming in:

01:21:06.000 if idx is longer than block_size, then our position embedding table is going to run out of

01:21:10.960 scope, because it only has embeddings for up to block_size positions. And so I added some code

01:21:16.080 here to crop the context that we're going to feed into self, so that we never pass in more than

01:21:23.440 block_size elements. So those are the changes, and let's now train the network. Okay, so I also came

01:21:28.880 up to the script here and I decreased the learning rate because the self attention can't tolerate

01:21:33.680 very, very high learning rates. And then I also increased the number of iterations because the

01:21:37.840 learning rate is lower. And then I trained it, and previously we were only able to get down to about 2.5,

01:21:43.040 and now we are down to 2.4. So we definitely see a little bit of improvement from 2.5 to 2.4

01:21:48.640 roughly, but the text is still not amazing. So clearly the self attention head is doing some

01:21:54.720 useful communication, but we still have a long way to go. Okay, so now we've implemented the

01:22:00.400 scaled dot-product attention. Now next up in the Attention Is All You Need paper, there's something

01:22:05.280 called multi-head attention. And what is multi-head attention? It's just applying multiple

01:22:10.080 attentions in parallel and concatenating the results. So they have a little diagram here.

01:22:16.000 I don't know if this is super clear. It's really just multiple attentions in parallel.

01:22:20.800 So let's implement that; it's fairly straightforward. If we want multi-head attention, then we want

01:22:27.520 multiple heads of self attention running in parallel. So in PyTorch, we can do this by simply creating

01:22:33.920 multiple heads. So however many heads you want, and then what is the head size of each? And then

01:22:42.160 we run all of them in parallel into a list and simply concatenate all of the outputs. And we're

01:22:48.560 concatenating over the channel dimension. So the way this looks now is we don't have just a single

01:22:54.160 attention that has a head size of 32, because remember, n_embd is 32. Instead of having one

01:23:02.320 communication channel, we now have four communication channels in parallel. And each one of these

01:23:08.720 communication channels typically will be smaller correspondingly. So because we have four communication

01:23:15.840 channels, we want eight dimensional self attention. And so from each communication channel, we're

01:23:20.880 going to gather eight-dimensional vectors. And then we have four of them, and that

01:23:24.960 concatenates to give us 32, which is the original n_embd. And so this is kind of similar to,

01:23:30.880 if you're familiar with convolutions, this is kind of like a group convolution, because basically,

01:23:35.520 instead of having one large convolution, we do convolution in groups. And that's multi headed

01:23:40.720 self attention. And so then here, we just use sa_heads, the self attention heads, instead.
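
A minimal multi-head sketch, assuming PyTorch: num_heads independent heads run on the same input and their outputs are concatenated along the channel (last) dimension, so 4 heads of size 8 give back 32 channels. block_size = 8 and n_embd = 32 are assumptions, and a compact Head class is declared so the block runs on its own.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
block_size, n_embd = 8, 32  # assumed hyperparameters

class Head(nn.Module):
    """One head of causal self-attention with scaled scores."""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        T = x.shape[1]
        k, q, v = self.key(x), self.query(x), self.value(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        return F.softmax(wei, dim=-1) @ v

class MultiHeadAttention(nn.Module):
    """Several heads in parallel; outputs concatenated along the channel dim."""
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

sa_heads = MultiHeadAttention(4, n_embd // 4)   # 4 communication channels, 8 dims each
x = torch.randn(4, 8, n_embd)
out = sa_heads(x)                                # (4, 8, 32): four 8-dim outputs side by side
```

Each head has its own key, query, and value weights, so each one can learn to look for a different kind of information.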

01:23:47.600 Now I actually ran it; scrolling down, I ran the same thing, and then we now get this down to

01:23:54.480 2.28 roughly. And the output, the generation, is still not amazing. But clearly,

01:24:00.240 the validation loss is improving, because we were at 2.4 just now. And so it helps to have

01:24:05.840 multiple communication channels, because obviously these tokens have a lot to talk about. They want

01:24:10.880 to find the consonants, the vowels; they want to find the vowels just from certain positions.

01:24:15.360 They want to find all kinds of different things. And so it helps to create multiple independent

01:24:20.240 channels of communication, gather lots of different types of data, and then decode the output.

01:24:25.920 Now going back to the paper for a second, of course, I didn't explain this figure in full detail,

01:24:29.760 but we are starting to see some components of what we've already implemented. We have the

01:24:33.520 positional encodings, the token encodings that add, we have the masked multi headed attention

01:24:38.560 implemented. Now, here's another multi-headed attention, which is a cross attention to an

01:24:43.680 encoder, which we haven't implemented and aren't going to implement in this case. I'm going to come back to

01:24:47.760 that later. But I want you to notice that there's a feed forward part here. And then this is grouped

01:24:53.040 into a block that gets repeated again and again. Now the feed forward part here is just a simple

01:24:58.000 multi-layer perceptron. So here, the position-wise feed-forward network is just a

01:25:06.000 simple little MLP. So I want to start basically in a similar fashion, also adding computation into

01:25:12.000 the network. And this computation is on a per node level. So I've already implemented it, and you can

01:25:18.640 see the diff highlighted on the left here when I've added or changed things. Now before we had the

01:25:24.240 multi-headed self attention that did the communication, but we went way too fast to calculate

01:25:29.440 the logits. So the tokens looked at each other, but didn't really have a lot of time to think on

01:25:34.560 what they found from the other tokens. And so what I've implemented here is a little feed forward

01:25:40.960 single layer. And this little layer is just a linear followed by a ReLU nonlinearity. And

01:25:46.160 that's it. So it's just a little layer. And then I instantiate it as FeedForward(n_embd). And

01:25:54.960 then this feed forward is just called sequentially right after the self attention. So we self attend,

01:26:00.080 then we feed forward. And you'll notice that the feed forward here, when it's applying linear,

01:26:04.720 this is on a per token level. All the tokens do this independently. So the self attention is the

01:26:10.160 communication. And then once they've gathered all the data, now they need to think on that data

01:26:14.560 individually. And so that's what feed forward is doing. And that's why I've added it here. Now,

01:26:20.320 when I train this, the validation loss actually continues to go down now to 2.24, which is down

01:26:25.600 from 2.28. The outputs still look kind of terrible, but at least we've improved the situation. And so

01:26:32.240 as a preview, we're going to now start to intersperse the communication with the computation. And that's

01:26:39.920 also what the transformer does, when it has blocks that communicate and then compute,

01:26:45.200 and it groups them and replicates them. Okay, so let me show you what we'd like to do.

01:26:50.560 We'd like to do something like this. We have a block. And this block is basically this part here,

01:26:55.360 except for the cross attention. Now the block basically intersperses communication and the

01:27:01.280 computation. The communication is done using multi-headed self attention, and

01:27:06.480 then the computation is done using a feed forward network on all the tokens independently.
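
A minimal sketch of the feed-forward layer as described (a linear layer followed by a ReLU), assuming PyTorch and n_embd = 32. Because nn.Linear acts only on the last dimension, running a single position through it gives the same result as slicing that position out of the full output; that is what "on all the tokens independently" means in practice.

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
n_embd = 32  # assumed embedding size

class FeedForward(nn.Module):
    """A linear layer followed by a ReLU nonlinearity, applied to each token on its own."""
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())

    def forward(self, x):
        return self.net(x)

ffwd = FeedForward(n_embd)
x = torch.randn(4, 8, n_embd)
out = ffwd(x)                      # (4, 8, 32)
one_token = ffwd(x[:, 3:4, :])     # feed only position 3
```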

01:27:10.800 Now, what I've also added here is, you'll notice, this takes the number of embeddings in the

01:27:18.800 embedding dimension and the number of heads that we would like, which is kind of like the group size in a

01:27:22.480 group convolution. And I'm saying that the number of heads we'd like is four. And so because this is

01:27:27.600 32 and the number of heads is four, we calculate that

01:27:33.920 the head size should be eight, so that everything sort of works out channel wise.
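
A sketch of a Block that does communication and then computation, with head_size = n_embd // n_head, stacked three times with nn.Sequential, assuming PyTorch. The sizes are assumptions, and there are deliberately no residual connections yet, matching this point in the transcript, where the deeper network starts to have optimization trouble.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
block_size, n_embd = 8, 32  # assumed hyperparameters

class Head(nn.Module):
    """One head of causal self-attention with scaled scores."""
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        T = x.shape[1]
        k, q, v = self.key(x), self.query(x), self.value(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        return F.softmax(wei, dim=-1) @ v

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_embd, n_embd), nn.ReLU())

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Communication (multi-head self-attention) followed by computation (feed-forward)."""
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head   # 32 // 4 = 8, so the concatenated heads give back n_embd
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embd)

    def forward(self, x):
        x = self.sa(x)     # tokens communicate
        x = self.ffwd(x)   # each token computes on what it gathered
        return x

blocks = nn.Sequential(
    Block(n_embd, n_head=4),
    Block(n_embd, n_head=4),
    Block(n_embd, n_head=4),
)
x = torch.randn(4, 8, n_embd)
out = blocks(x)   # (4, 8, 32)
```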

01:27:37.520 So this is how the transformer typically structures the sizes. So the head size will become

01:27:45.120 eight. And then this is how we want to intersperse them. And then here, I'm trying to create blocks,

01:27:49.920 which is just a sequential application of block, block, block, so that we're interspersing

01:27:54.720 communication and feed-forward many, many times. And then finally we decode. Now, I actually tried to run

01:28:00.880 this, and the problem is this doesn't actually give a very good answer, a very good

01:28:06.000 result. And the reason for that is we're starting to actually get like a pretty deep neural net.

01:28:10.880 And deep neural nets suffer from optimization issues. And I think that's what we're kind of

01:28:14.640 like slightly starting to run into. So we need one more idea that we can borrow from the

01:28:19.200 transformer paper to resolve those difficulties. Now there are two optimizations that dramatically

01:28:25.280 help with the depth of these networks, and make sure that the networks remain optimizable.

01:28:30.000 Let's talk about the first one. The first one in this diagram: you see this arrow here,

01:28:34.320 and then this arrow and this arrow. Those are skip connections, sometimes called residual

01:28:40.080 connections. They come from this paper, "Deep Residual Learning for Image Recognition," from

01:28:45.600 about 2015, which introduced the concept. Basically, what it means is you transform the

01:28:53.760 data, but then you have a skip connection with addition from the previous features. Now, the way

01:28:59.600 I prefer to visualize it is the following. Here the computation happens from the top to

01:29:06.400 bottom. And basically you have this residual pathway, and you are free to fork off from the

01:29:12.400 residual pathway, perform some computation, and then project back to the residual pathway via

01:29:17.120 addition. And so you go from the inputs to the targets via just plus, plus, and plus.
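A small autograd check of the "gradient superhighway" idea. For illustration only, each residual branch is assumed to be a single zero-initialized `nn.Linear`, so it contributes nothing at initialization. The gradient of the loss with respect to the input then equals the gradient with respect to the output: it passes straight through the chain of additions, while the branches still receive gradient of their own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 16
# Hypothetical residual branches, zero-initialized so they start out "not there".
branches = nn.ModuleList([nn.Linear(C, C) for _ in range(4)])
for layer in branches:
    nn.init.zeros_(layer.weight)
    nn.init.zeros_(layer.bias)

x = torch.randn(3, C, requires_grad=True)
h = x
for f in branches:
    h = h + f(h)  # fork off, compute, add back into the residual pathway

loss = h.sum()
loss.backward()
# d(loss)/d(h) is all ones, and the additions route it to x unchanged.
print(torch.allclose(x.grad, torch.ones_like(x)))  # True
# The branches still get gradient, so they can "come online" during training.
print(bool(branches[0].weight.grad.abs().sum() > 0))  # True
```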

01:29:24.320 And the reason this is useful is because during backpropagation, remember from our

01:29:29.040 micrograd video earlier, addition distributes gradients equally to both of its branches

01:29:34.480 that fed in as the input. And so the supervision, or the gradients from the loss,

01:29:40.640 basically hop through every addition node all the way to the input, and then also fork off

01:29:48.160 into the residual blocks. But basically you have this gradient superhighway that goes directly from

01:29:54.240 the supervision all the way to the input, unimpeded. And then these residual blocks are usually

01:29:59.680 initialized in the beginning so that they contribute very, very little, if anything, to the residual

01:30:03.760 pathway. So in the beginning, they are almost kind of like not there,

01:30:09.840 but then during the optimization, they come online over time and start to contribute.

01:30:15.520 But at least at initialization, the gradient can go directly from the supervision to the input,

01:30:21.120 unimpeded; it just flows, and then the blocks kick in over time. And so that dramatically helps

01:30:27.280 with the optimization. So let's implement this. So coming back to our block here, basically what

01:30:31.840 we want to do is x = x + self.sa(x) for the self-attention, and x = x + self.ffwd(x) for the

01:30:38.880 feed-forward. So this is x, and then we fork off and do some communication and come back.

01:30:45.600 And we fork off and we do some computation and come back. So those are residual connections.
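Putting the pieces together, here is a minimal sketch of a block with residual connections. Assumptions: the names `Head`, `MultiHeadAttention`, `FeedForward` and `Block` mirror the lecture's style, but this is not its exact code; n_embd = 32, n_head = 4 and block_size = 8; each branch ends in a linear projection back into the residual pathway; and the feed-forward uses a 4x inner width following the paper's 512-to-2048 ratio.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_embd, n_head, block_size = 32, 4, 8

class Head(nn.Module):
    """One causal self-attention head."""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k, q = self.key(x), self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T) affinities
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        return wei @ self.value(x)  # (B, T, head_size)

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # Projection back into the residual pathway.
        self.proj = nn.Linear(num_heads * head_size, n_embd)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),  # projection back into the residual pathway
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Communication, then computation, each as a residual branch."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)

    def forward(self, x):
        x = x + self.sa(x)    # fork off, communicate, add back
        x = x + self.ffwd(x)  # fork off, compute, add back
        return x

blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(3)])
y = blocks(torch.randn(2, block_size, n_embd))
print(y.shape)  # torch.Size([2, 8, 32])
```

Because every branch adds back into `x`, the stack keeps the (B, T, n_embd) shape end to end, which is what lets the blocks be chained with `nn.Sequential`.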

01:30:50.720 And then swinging back up here, we also have to introduce this projection, so an nn.Linear.

01:30:56.880 And this is going to be from after we concatenate this; this is of size n_embd. So this is the

01:31:05.360 output of the self-attention itself. But then we actually want to apply the projection,

01:31:10.960 and that's the result. So the projection is just a linear transformation of the outcome of this

01:31:16.800 layer. So that's the projection back into the residual pathway. And then here in the feed-forward,

01:31:23.120 it's going to be the same thing. I could have a self.proj here as well, but let me

01:31:28.000 just simplify it and couple it inside the same sequential container. And so this is the

01:31:35.280 projection layer going back into the residual pathway. And so, well, that's it. So now

01:31:42.800 we can train this. So I implemented one more small change. When you look into the paper again,

01:31:48.320 you see that the dimensionality of input and output is 512 for them, and they're saying that the

01:31:53.360 inner layer here in the feed-forward has dimensionality of 2048. So there's a multiplier of four, and so

01:31:59.920 the inner layer of the feed-forward network should be multiplied by four in terms of channel sizes.
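For the paper's sizes, the expansion is easy to check by counting parameters. A small sketch, assuming bias terms on both linear layers (as `nn.Linear` has by default): 512 × 2048 + 2048 + 2048 × 512 + 512 = 2,099,712 parameters.

```python
import torch.nn as nn

def feed_forward(d_model, d_ff):
    """Position-wise feed-forward: expand to d_ff, nonlinearity, project back."""
    return nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

ffn = feed_forward(512, 4 * 512)  # the paper's sizes: 512 -> 2048 -> 512
n_params = sum(p.numel() for p in ffn.parameters())
print(ffn[0].out_features, n_params)  # 2048 2099712
```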

01:32:04.960 So I came here and multiplied n_embd by four for the feed-forward. And then from four

01:32:10.320 times n_embd, we come back down to n_embd when we go back to the projection. So we're adding a

01:32:15.760 bit of computation here and growing that layer that is in the residual block on the side of the

01:32:21.040 residual pathway. And then I trained this, and we actually get down all the way to 2.08 validation

01:32:27.440 loss. And we also see that the network is starting to get big enough that our train loss is getting

01:32:31.760 ahead of the validation loss, so we started to see a little bit of overfitting. And our

01:32:39.920 generations here are still not amazing, but at least we can see, for example, "is here this now

01:32:44.720 grief sync"; this starts to almost look like English. So yeah, we're starting to really get

01:32:50.800 there. Okay, and the second innovation that is very helpful for optimizing very deep neural

01:32:55.120 networks is right here. So we have this addition now; that's the residual part. But this Norm is referring

01:33:00.400 to something called layer norm. Layer norm is implemented in PyTorch. It's a paper that came out

01:33:05.920 a while back here. And layer norm is very, very similar to batch norm. So remember back to

01:33:13.600 our makemore series part three, where we implemented batch normalization. And batch normalization

01:33:19.760 basically just made sure that across the batch dimension, any individual neuron had a unit

01:33:27.360 Gaussian distribution. So it had zero mean and unit standard deviation.

01:33:34.320 So what I did here is I'm copy-pasting the BatchNorm1d that we developed in our makemore

01:33:39.680 series. And see, here we can initialize, for example, this module, and we can have a batch of 32

01:33:46.640 100-dimensional vectors feeding through the batch norm layer. So what this does is it guarantees

01:33:53.360 that when we look at just the zeroth column, it has zero mean and one standard deviation. So it's

01:34:00.080 normalizing every single column of this input. Now the rows are not going to be normalized by

01:34:06.720 default, because we're just normalizing columns. So let's now implement layer norm. It's very

01:34:12.320 complicated. Look, we come here, we change this from zero to one. So we don't normalize the columns,

01:34:19.440 we normalize the rows. And now we've implemented layer norm. So now the columns are not going to

01:34:27.280 be normalized, but the rows are: for every individual example, its

01:34:34.160 100-dimensional vector is normalized in this way. And because our computation now does not span

01:34:39.680 across examples, we can delete all of this buffer stuff, because we can always apply this

01:34:47.040 operation and don't need to maintain any running buffers. So we don't need the buffers, and

01:34:54.000 there's no distinction between training and test time.

01:35:01.600 We do keep gamma and beta. We don't need the momentum. We don't care if it's training or not.
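A sketch of the change just described: the same structure as the makemore BatchNorm1d, but with statistics taken over dim 1 (each row) instead of dim 0, and with no running buffers, momentum, or training flag. One assumption to flag: this version uses the biased (population) variance, which matches `torch.nn.LayerNorm`, rather than `torch.var`'s unbiased default, so the two can be compared directly.

```python
import torch
import torch.nn.functional as F

class LayerNorm1d:
    """Layer norm sketch: BatchNorm1d-style code with statistics over dim 1."""

    def __init__(self, dim, eps=1e-5):
        self.eps = eps
        self.gamma = torch.ones(dim)   # trainable scale
        self.beta = torch.zeros(dim)   # trainable shift
        # No running buffers and no momentum: nothing depends on other examples.

    def __call__(self, x):
        xmean = x.mean(1, keepdim=True)                   # per-row mean
        xvar = ((x - xmean) ** 2).mean(1, keepdim=True)   # per-row (biased) variance
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize each row
        return self.gamma * xhat + self.beta

    def parameters(self):
        return [self.gamma, self.beta]

torch.manual_seed(1337)
module = LayerNorm1d(100)
x = torch.randn(32, 100)  # a batch of 32 100-dimensional vectors
out = module(x)

row_mean = out.mean(dim=1)                              # one value per example
row_var = ((out - row_mean[:, None]) ** 2).mean(dim=1)  # one value per example
print(row_mean.abs().max().item(), row_var.mean().item())  # ~0 and ~1
# Same result as PyTorch's functional layer norm with default gamma/beta.
print(torch.allclose(out, F.layer_norm(x, (100,)), atol=1e-5))  # True
```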

01:35:06.320 And this is now a layer norm, and it normalizes the rows instead of the columns. And this here

01:35:15.600 is basically identical to this here. So let's now implement layer norm in our transformer.

01:35:22.320 Before I incorporate the layer norm, I just wanted to note that, as I said, very few details about

01:35:26.960 the transformer have changed in the last five years. But this is actually something that slightly

01:35:30.640 departs from the original paper. You see that the Add & Norm is applied after the transformation.

01:35:36.160 But now it is a bit more common to apply the layer norm before the transformation.
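The difference between the two placements fits in two lines. In this sketch, a single `nn.Linear` is a hypothetical stand-in for either sublayer (self-attention or feed-forward). With post-norm, the layer norm sits on the residual stream itself, so at initialization every token's output has zero mean; with pre-norm, only the input to the branch is normalized and the residual pathway passes through untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C = 32
sublayer = nn.Linear(C, C)  # hypothetical stand-in for attention or feed-forward
ln = nn.LayerNorm(C)

def post_norm(x):
    # Original paper ("Add & Norm"): add the residual, then normalize.
    return ln(x + sublayer(x))

def pre_norm(x):
    # Pre-norm: normalize only what goes into the branch; the sum stays unnormalized.
    return x + sublayer(ln(x))

x = torch.randn(2, 8, C) * 3 + 1
a, b = post_norm(x), pre_norm(x)
print(a.mean(-1).abs().max().item())  # ~0: post-norm re-normalizes the stream
print(b.mean(-1).abs().max().item())  # clearly nonzero: pre-norm leaves it alone
```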

01:35:44.320 So there's a reshuffling of the layer norms. This is called the pre-norm formulation,

01:35:48.800 and that's the one that we're going to implement as well. So, a slight deviation from the original

01:35:52.000 paper. Basically, we need two layer norms. Layer norm one is an nn.LayerNorm,

01:35:58.160 and we tell it what the embedding dimension is. And we need the second layer norm.

01:36:04.160 And then here, the layer norms are applied immediately on x. So self.ln1 is

01:36:10.960 applied on x, and self.ln2 is applied on x, before x goes into the self-attention and the

01:36:17.120 feed-forward. And the size of the layer norm here is n_embd, which is 32. So when the layer norm is normalizing

01:36:25.040 our features, the normalization happens here: the mean and the variance are taken over 32 numbers.
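A quick check of which dimensions `nn.LayerNorm(n_embd)` normalizes over, as a sketch with assumed sizes B=4, T=8, n_embd=32. Statistics are taken per (batch, time) position over the 32 channels, so normalizing the (B, T, C) tensor gives the same result as flattening it into B * T independent rows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, n_embd = 4, 8, 32
ln = nn.LayerNorm(n_embd)              # statistics over the last dimension only
x = torch.randn(B, T, n_embd) * 5 + 2  # features far from zero mean, unit variance
y = ln(x)

# Each (batch, time) position is normalized over its own 32 channels ...
token_mean = y.mean(dim=-1)                                  # (B, T)
token_var = ((y - token_mean[..., None]) ** 2).mean(dim=-1)  # (B, T)
# ... so B and T act like one big batch dimension of B * T tokens.
y_flat = ln(x.view(B * T, n_embd)).view(B, T, n_embd)
print(token_mean.abs().max().item(), token_var.mean().item())  # ~0 and ~1
print(torch.allclose(y, y_flat, atol=1e-6))  # True
```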

01:36:34.000 So the batch and the time act as batch dimensions, both of them. So this is kind of like a per-token

01:36:40.800 transformation that just normalizes the features and makes them zero-mean, unit Gaussian at

01:36:47.600 initialization. But of course, because these layer norms have these gamma and beta

01:36:53.200 trainable parameters, the layer norm may eventually create outputs that are not unit Gaussian;

01:37:00.640 the optimization will determine that. So for now, this is incorporating the

01:37:06.160 layer norms, and let's train them up. Okay, so I let it run, and we see that we get down to 2.06,

01:37:11.840 which is better than the previous 2.08. So a slight improvement by adding the layer norms.

01:37:16.560 And I'd expect that they help even more if we had a bigger and deeper network. One more thing I

01:37:21.520 forgot to add is that typically there should also be a layer norm here, at the end of the

01:37:27.040 transformer and right before the final linear layer that decodes into the vocabulary. So I added

01:37:33.120 that as well. So at this stage, we actually have a pretty complete transformer according to the

01:37:38.000 original paper, and it's a decoder-only transformer. I'll talk about that in a second. But at this

01:37:43.600 stage, the major pieces are in place. So we can try to scale this up and see how well we can push

01:37:48.000 this number. Now, in order to scale up the model, I had to perform some cosmetic changes here to

01:37:53.440 make it nicer. So I introduced this variable called n_layer, which just specifies how many layers

01:37:58.480 of blocks we're going to have. I create a bunch of blocks, and we have a new variable, the number of

01:38:03.520 heads, as well. I pulled out the layer norm here. And so this is identical. Now, one thing that I did

01:38:10.080 briefly change is I added dropout. So dropout is something that you can add right before the

01:38:16.720 residual connection back, right before the connection back into the residual pathway.

01:38:20.800 So we can add dropout as the last layer here. We can add dropout here at the end of the multi-headed

01:38:27.760 attention as well. And we can also add dropout here when we calculate the

01:38:34.640 affinities: after the softmax, we can drop out some of those, so we can randomly prevent

01:38:39.760 some of the nodes from communicating. And dropout comes from this paper from 2014 or so.

01:38:46.720 And basically it takes your neural net, and randomly, on every forward-backward pass,

01:38:54.080 it shuts off some subset of neurons: it randomly drops them to zero and trains without them.

01:39:01.440 And what this does effectively is, because the mask of what's being dropped out changes

01:39:06.880 every single forward-backward pass, it ends up kind of training an ensemble of subnetworks.
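A quick look at how `nn.Dropout` behaves, as a sketch with p = 0.2. One detail not mentioned in the lecture: PyTorch uses "inverted" dropout, scaling the surviving activations by 1 / (1 - p) during training so that nothing needs rescaling at test time, where the module becomes the identity after `.eval()`. The fraction of zeros is random, so it is only approximately p.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.2)
x = torch.ones(10_000)

drop.train()
y_train = drop(x)
frac_zeroed = (y_train == 0).float().mean().item()  # roughly 0.2, varies by seed
# Survivors are scaled by 1 / (1 - p), so the expected value is unchanged.
survivor_value = y_train[y_train != 0][0].item()

drop.eval()
y_eval = drop(x)  # at test time dropout does nothing
print(frac_zeroed, survivor_value, torch.equal(y_eval, x))
```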

01:39:12.400 And then at test time, everything is fully enabled and all of those subnetworks are merged

01:39:17.200 into a single ensemble, if you want to think about it that way. So I would read the paper

01:39:21.920 to get the full details. For now, we're just going to stay at the level of: this is a regularization

01:39:26.640 technique. And I added it because I'm about to scale up the model quite a bit, and I was concerned

01:39:31.200 about overfitting. So now when we scroll up to the top, we'll see that I changed a number of

01:39:37.360 hyperparameters of our neural net. So I made the batch size much larger; it's now 64.

01:39:42.240 I changed the block size to be 256. Previously it was just eight characters of context. Now it

01:39:48.560 is 256 characters of context to predict the 257th. I brought down the learning rate a little bit

01:39:56.080 because the neural net is now much bigger. The embedding

01:40:00.480 dimension is now 384, and there are six heads. So 384 divided by six means that every head is

01:40:08.000 64-dimensional, as is standard. And then there are going to be six layers of that. And the dropout

01:40:14.400 will be 0.2, so on every forward-backward pass, 20% of all of these intermediate calculations

01:40:21.280 are disabled and dropped to zero. And then I already trained this and ran it. So, drumroll,

01:40:27.600 how does it perform? So let me just scroll up here. We get a validation loss of 1.48, which is

01:40:36.000 actually quite a bit of an improvement on what we had before, which I think was 2.07. So we went

01:40:40.720 from 2.07 all the way down to 1.48 just by scaling up this neural net with the code that we have.

01:40:45.520 And this of course ran for a lot longer. This trained for, I want to say, about 15 minutes

01:40:51.120 on my A100 GPU, so that's a pretty good GPU. And if you don't have a GPU, you're not going to be

01:40:55.920 able to reproduce this on a CPU. I would not run this on a CPU or a MacBook or

01:41:01.840 something like that. You'll have to bring down the number of layers and the embedding dimension

01:41:06.560 and so on. But in about 15 minutes, we can get this kind of a result. And I'm printing

01:41:12.960 some of the Shakespeare here. But what I did also is I printed 10,000 characters, so a lot more.

01:41:18.400 And I wrote them to a file. And so here we see some of the outputs.

01:41:21.760 So it looks a lot more recognizably like the input text file. The input text file, just for reference,

01:41:29.920 looked like this. So there's always someone speaking in this manner, and our predictions now

01:41:37.600 take on that form, except of course they're nonsensical when you actually read them. So

01:41:42.960 it is every crimty be house. Oh, those preparation. We give heed.

01:41:50.320 You know, oh, oh, sent me you mighty lord.

01:41:58.720 [laughs]

01:42:00.640 Anyway, so you can read through this. It's nonsensical, of course, but this is just a

01:42:05.200 transformer trained on the character level for 1 million characters that come from Shakespeare.

01:42:10.160 So it sort of blabbers on in a Shakespeare-like manner, but it doesn't, of course, make sense

01:42:15.200 at this scale. But I think it's still a pretty good demonstration of what's possible.

01:42:21.600 So now I think that kind of like concludes the programming section of this video.

01:42:28.400 We basically did a pretty good job of implementing this transformer, but the picture

01:42:34.880 doesn't exactly match up to what we've done. So what's going on with all these additional parts

01:42:39.120 here? So let me finish explaining this architecture and why it looks so funky. Basically, what's happening

01:42:44.800 here is that what we implemented is a decoder-only transformer. So we don't have this component here.

01:42:51.040 This part is called the encoder, and there's no cross-attention block here. Our block only has a

01:42:57.040 self-attention and the feed-forward. So it is missing this third, in-between piece here. This piece

01:43:03.200 does cross-attention. So we don't have it, and we don't have the encoder. We just have the decoder.

01:43:08.080 And the reason we have only a decoder is because we are just generating text, and it's

01:43:13.360 unconditioned on anything. We're just blabbering on according to a given dataset. What makes it a

01:43:19.040 decoder is that we are using the triangular mask in our transformer. So it has this autoregressive

01:43:25.040 property where we can just go and sample from it. So the fact that it's using the triangular mask

01:43:31.440 to mask out the attention makes it a decoder and it can be used for language modeling.
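A small sketch of that difference for a single head, using random numbers as a stand-in for the query-key scores. The decoder fills future positions with -inf before the softmax, so they get a weight of exactly zero; an encoder simply skips the mask, so every token can attend to every other token.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 5
scores = torch.randn(T, T)  # stand-in for raw query-key affinities of one head

# Decoder: the lower-triangular mask stops token t from seeing tokens after t.
tril = torch.tril(torch.ones(T, T))
causal = F.softmax(scores.masked_fill(tril == 0, float("-inf")), dim=-1)

# Encoder: no mask, all tokens may talk to each other.
full = F.softmax(scores, dim=-1)

print(causal)  # upper triangle is all zeros; each row sums to 1
```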

01:43:35.840 Now, the reason that the original paper had an encoder-decoder architecture is because it is a

01:43:41.520 machine translation paper, so it is concerned with a different setting. In particular, it expects some

01:43:48.960 tokens that encode, for example, French, and then it is expected to decode the translation

01:43:54.560 in English. So typically these here are special tokens. So you are expected to read in this and

01:44:01.760 condition on it. And then you start off the generation with a special token called Start.

01:44:06.560 So this is a special new token that you introduce and always place at the beginning.

01:44:12.240 And then the network is expected to output "neural networks are awesome" and then a special end token

01:44:18.400 to finish the generation. So this part here will be decoded exactly as we have done it: "neural

01:44:25.680 networks are awesome" will be identical to what we did. But unlike what we did, they want to

01:44:31.120 condition the generation on some additional information. And in that case, this additional

01:44:36.880 information is the French sentence that they should be translating. So what they do now

01:44:42.400 is they bring in the encoder. The encoder reads this part here. So we're going to take the

01:44:49.040 French part, and we're going to create tokens from it exactly as we've seen in our video. And

01:44:54.720 we're going to put a transformer on it, but there's going to be no triangular mask. And so all the

01:44:59.920 tokens are allowed to talk to each other as much as they want, and they're just encoding

01:45:04.240 whatever the content of this French sentence is. Once they've encoded it, it basically comes out

01:45:11.840 at the top here. And then what happens here is, in our decoder, which does the language modeling,

01:45:18.000 there's an additional connection here to the outputs of the encoder. And that is brought in through

01:45:24.640 cross-attention. So the queries are still generated from x, but now the keys and the values are coming

01:45:31.200 from the side: the keys and the values are coming from the top, generated by the nodes that came

01:45:36.800 out of the encoder. And those tops, the keys and the values there at the top of it, feed in on the

01:45:44.320 side into every single block of the decoder. And so that's why there's an additional

01:45:49.120 cross-attention. And really what it's doing is it's conditioning the decoding not just on the past of

01:45:54.960 this current decoding, but also on having seen the full, fully encoded French prompt, sort of.
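A minimal sketch of one cross-attention head, assuming arbitrary sizes (encoder length 7, decoder length 5, n_embd 32, head size 16) and random tensors standing in for the encoder output and the decoder's residual stream. The queries come from the decoder, the keys and values from the encoder output, and no triangular mask is applied, since the whole source sentence is visible to every decoder position.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttentionHead(nn.Module):
    """Queries from the decoder stream, keys/values from the encoder output (sketch)."""

    def __init__(self, n_embd, head_size):
        super().__init__()
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

    def forward(self, x, enc_out):
        q = self.query(x)        # (B, T_dec, hs) from the decoder
        k = self.key(enc_out)    # (B, T_enc, hs) from the encoder
        v = self.value(enc_out)  # (B, T_enc, hs) from the encoder
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T_dec, T_enc)
        wei = F.softmax(wei, dim=-1)  # no causal mask here
        return wei @ v           # (B, T_dec, hs)

torch.manual_seed(0)
B, T_enc, T_dec, n_embd = 2, 7, 5, 32
enc_out = torch.randn(B, T_enc, n_embd)  # stand-in for the encoder's final outputs
x = torch.randn(B, T_dec, n_embd)        # stand-in for the decoder's residual stream
out = CrossAttentionHead(n_embd, 16)(x, enc_out)
print(out.shape)  # torch.Size([2, 5, 16])
```

The output length follows the decoder (5 positions), while the information mixed in comes from all 7 encoder positions.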

01:46:04.000 And so it's an encoder-decoder model, which is why we have those two transformers, an additional

01:46:08.800 block, and so on. So we did not do this because we have nothing to encode; there's no

01:46:13.600 conditioning. We just have a text file, and we just want to imitate it. And that's why we are using a

01:46:18.160 decoder-only transformer, exactly as done in GPT. Okay, so now I wanted to do a very brief walk

01:46:24.720 through of nanoGPT, which you can find on my GitHub. And nanoGPT is basically two files of interest.

01:46:31.040 There's train.py and model.py. train.py is all the boilerplate code for training the

01:46:36.240 network. It is basically all the stuff that we had here as the training loop. It's just that it's

01:46:42.480 a lot more complicated because we're saving and loading checkpoints and pretrained weights, and

01:46:46.720 we are decaying the learning rate, compiling the model, and using distributed training across

01:46:51.520 multiple nodes or GPUs. So train.py gets a little bit more hairy and complicated;

01:46:57.280 there are more options, etc. But model.py should look very, very similar to what we've done

01:47:03.520 here. In fact, the model is almost identical. So first, here we have the causal self-attention

01:47:10.160 block, and all of this should look very, very recognizable to you. We're producing queries, keys,

01:47:15.040 values, we're doing dot products, we're masking, applying softmax, optionally dropping out. And

01:47:21.600 here we are doing the weighted aggregation of the values. What is different here is that in our code, I have

01:47:28.080 separated out the multi-headed attention into just a single individual head. And then here,

01:47:34.400 I have multiple heads and I explicitly concatenate them. Whereas here, all of it is implemented in

01:47:40.240 a batched manner inside a single causal self-attention. And so we don't just have a B and a T and a C

01:47:45.760 dimension; we also end up with a fourth dimension, which is the heads. And so it just gets a lot more

01:47:51.360 sort of hairy because we have four-dimensional tensors now, but it is mathematically equivalent.
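The equivalence can be checked numerically. This is a sketch, not nanoGPT's code: it assumes random q, k, v tensors and folds the heads into a fourth dimension with a `view`/`transpose` pattern, (B, T, C) to (B, n_head, T, head_size), then compares the result against a per-head loop followed by concatenation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, n_head = 2, 6, 32, 4
hs = C // n_head
q, k, v = (torch.randn(B, T, C) for _ in range(3))
mask = torch.tril(torch.ones(T, T, dtype=torch.bool))

def attend(q, k, v):
    """Causal attention over the last two dims; works for 3-D or 4-D inputs."""
    wei = q @ k.transpose(-2, -1) * q.shape[-1] ** -0.5
    wei = wei.masked_fill(~mask, float("-inf"))
    return F.softmax(wei, dim=-1) @ v

# (1) One head at a time, then concatenate along channels.
heads = []
for h in range(n_head):
    sl = slice(h * hs, (h + 1) * hs)  # channels belonging to head h
    heads.append(attend(q[..., sl], k[..., sl], v[..., sl]))
loop = torch.cat(heads, dim=-1)  # (B, T, C)

# (2) All heads at once: move heads into a batch-like 4th dimension.
def split(t):
    return t.view(B, T, n_head, hs).transpose(1, 2)  # (B, n_head, T, hs)

y = attend(split(q), split(k), split(v))                # (B, n_head, T, hs)
batched = y.transpose(1, 2).contiguous().view(B, T, C)  # heads side by side again

print(torch.allclose(loop, batched, atol=1e-5))  # True
```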

01:47:58.320 So the exact same thing is happening as what we have. It's just a bit more efficient because all

01:48:02.720 the heads are now treated as a batch dimension as well. Then we have the multilayer perceptron.

01:48:08.080 It's using the GELU nonlinearity, which is defined here, instead of ReLU. And this is done

01:48:13.760 just because OpenAI used it and I want to be able to load their checkpoints. The blocks of the

01:48:18.880 transformer are identical: the communicate and the compute phases, as we saw. And then the GPT will

01:48:24.080 be identical. We have the position encodings, token encodings, the blocks, the layer norm at the end,

01:48:29.280 and the final linear layer. And this should all look very recognizable. And there's a bit more here

01:48:36.080 because I'm loading checkpoints and stuff like that. I'm separating out the parameters into

01:48:40.080 those that should be weight-decayed and those that shouldn't. But the generate function

01:48:44.880 should also be very, very similar. So a few details are different, but you should definitely be able

01:48:49.360 to look at this file and be able to understand a lot of the pieces now. So let's now bring things

01:48:54.880 back to ChatGPT. What would it look like if we wanted to train ChatGPT ourselves, and how does

01:48:59.840 it relate to what we learned today? Well, to train ChatGPT, there are roughly two stages.

01:49:04.960 First is the pre-training stage and then the fine-tuning stage. In the pre-training stage,

01:49:10.080 we are training on a large chunk of the internet and just trying to get a first decoder-only

01:49:16.000 transformer to babble text. So it's very, very similar to what we've done ourselves.

01:49:21.360 Except we've done a tiny little baby pre-training step. And so in our case,

01:49:27.920 this is how you print the number of parameters. I printed it and it's about 10 million. So

01:49:34.000 this transformer that I created here to create a little Shakespeare transformer was about 10

01:49:40.320 million parameters. Our dataset is roughly 1 million characters, so roughly 1 million tokens.
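The "about 10 million" can be reproduced by hand. This pure-Python sketch assumes a lecture-style layout: a 65-character vocabulary, block_size 256, n_embd 384, n_layer 6, bias-free key/query/value projections, a biased output projection, a 4x feed-forward, two layer norms per block plus a final one, and a biased output layer (the function name `gpt_param_count` is hypothetical). The number of heads does not change the count, because the heads only split the same 384 channels.

```python
def gpt_param_count(vocab_size, block_size, n_embd, n_layer):
    """Parameter count for the assumed lecture-style GPT layout."""
    def linear(fan_in, fan_out, bias=True):
        return fan_in * fan_out + (fan_out if bias else 0)

    layer_norm = 2 * n_embd  # gamma and beta

    per_block = (
        3 * linear(n_embd, n_embd, bias=False)  # key, query, value across all heads
        + linear(n_embd, n_embd)                # attention output projection
        + linear(n_embd, 4 * n_embd)            # feed-forward expand
        + linear(4 * n_embd, n_embd)            # feed-forward project back
        + 2 * layer_norm                        # the block's two layer norms
    )
    return (
        vocab_size * n_embd             # token embedding table
        + block_size * n_embd           # position embedding table
        + n_layer * per_block
        + layer_norm                    # final layer norm
        + linear(n_embd, vocab_size)    # output layer into the vocabulary
    )

total = gpt_param_count(vocab_size=65, block_size=256, n_embd=384, n_layer=6)
print(f"{total:,}")  # 10,788,929
```

Nearly all of it sits in the six blocks (about 1.77 million parameters each); the embeddings and output layer are small at this vocabulary size.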

01:49:46.800 But you have to remember that OpenAI uses a different vocabulary. They're not on the

01:49:50.000 character level. They use these sub-word chunks of words, and so they have a vocabulary of roughly 50,000

01:49:56.480 elements. And so their sequences are a bit more condensed. So our dataset, the Shakespeare

01:50:02.880 dataset, would probably be around 300,000 tokens in the OpenAI vocabulary, roughly.

01:50:08.080 So we trained about a 10 million parameter model on roughly 300,000 tokens. Now when you go to the

01:50:14.800 GPT-3 paper and you look at the transformers that they train, they train a number of

01:50:22.400 transformers of different sizes, but the biggest transformer here has 175 billion parameters.

01:50:28.400 So ours is again 10 million. They used this number of layers in the transformer. This is the n_embd.

01:50:34.160 This is the number of heads, and this is the head size. And then this is the batch size. So

01:50:41.440 ours was 64. And the learning rate is similar. Now when they train this transformer, they

01:50:47.840 train on 300 billion tokens. So again, remember ours is about 300,000, so this is about a

01:50:55.520 millionfold increase. And this number would not even be that large by today's standards; you'd be going up

01:50:59.920 to one trillion and above. So they are training a significantly larger model

01:51:05.920 on a good chunk of the internet. And that is the pre-training stage. But otherwise,

01:51:12.400 these hyperparameters should be fairly recognizable to you, and the architecture is actually

01:51:16.640 nearly identical to what we implemented ourselves. But of course, it's a massive infrastructure

01:51:21.040 challenge to train this. You're talking about typically thousands of GPUs having to talk to

01:51:26.880 each other to train models of this size. So that's just the pre-training stage. Now, after you complete

01:51:32.480 the pre-training stage, you don't get something that responds to your questions with answers,

01:51:37.920 and is helpful, etc. You get a document completer, right? So it babbles, but it doesn't

01:51:44.720 babble Shakespeare; it babbles internet. It will create arbitrary news articles and documents,

01:51:49.440 and it will try to complete documents because that's what it's trained for. It's trying to

01:51:52.480 complete the sequence. So when you give it a question, it would potentially just give you

01:51:58.000 more questions. It would follow with more questions. It will do whatever some similar

01:52:02.880 document would do in the training data on the internet. And so who knows, you're getting kind of

01:52:07.680 like undefined behavior. It might basically answer questions with other questions. It might

01:52:13.280 ignore your question. It might just try to complete some news article. It's totally

01:52:17.520 unaligned, as we say. So the second, fine-tuning stage is to actually align it to be an assistant.

01:52:23.520 And this is the second stage. And so this ChatGPT blog post from OpenAI talks a little bit about

01:52:30.320 how this stage is achieved. Basically, there are roughly three steps to this stage.

01:52:37.200 So what they do here is they start to collect training data that looks specifically like

01:52:42.480 what an assistant would do. So there are documents that have the format where the question is on top

01:52:46.880 and then an answer is below. And they have a large number of these, but probably not on the order of

01:52:51.840 the internet. This is probably on the order of maybe thousands of examples. And so they then

01:52:58.800 fine-tune the model to basically only focus on documents that look like that. And so you're

01:53:04.800 starting to slowly align it. So it's going to expect a question at the top, and it's going to expect

01:53:09.040 to complete the answer. And these very, very large models are very sample-efficient during their

01:53:14.880 fine-tuning, so this actually somehow works. But that's just step one. That's just fine-tuning.

01:53:20.160 So then they actually have more steps where, okay, the second step is you let the model respond,

01:53:25.200 and then different raters look at the different responses and rank them by preference, as to

01:53:30.000 which one is better than the other. They use that to train a reward model, so they can predict,

01:53:35.040 basically using a different network, how desirable any candidate response would be.

01:53:42.720 And then once they have a reward model, they run PPO, which is a form of policy-gradient

01:53:47.840 reinforcement learning optimizer, to fine-tune this sampling policy so that the answers that

01:53:55.600 GPT now generates are expected to score a high reward according to the reward model.
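The lecture does not show the reward model's training objective. As an assumption, here is a sketch of the pairwise ranking loss described in OpenAI's InstructGPT paper: for each compared pair, push the reward of the preferred response above that of the rejected one via -log sigmoid(r_chosen - r_rejected). The reward values below are made-up numbers, and the PPO step is not shown.

```python
import torch
import torch.nn.functional as F

def pairwise_reward_loss(r_chosen, r_rejected):
    """Mean of -log sigmoid(r_chosen - r_rejected) over comparison pairs."""
    return -F.logsigmoid(r_chosen - r_rejected).mean()

# Hypothetical scalar rewards for (preferred, rejected) response pairs.
r_chosen = torch.tensor([2.0, 0.5, 1.0])
r_rejected = torch.tensor([-1.0, 0.5, 3.0])
loss = pairwise_reward_loss(r_chosen, r_rejected)
print(loss.item())  # about 0.956: small for correct rankings, large for wrong ones
```

A pair ranked correctly with a wide margin contributes almost nothing; a tie contributes log 2; a pair ranked the wrong way contributes the most.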

01:54:02.640 And so basically there's a whole aligning stage here, or fine-tuning stage. It's got multiple steps

01:54:08.560 in between there as well. And it takes the model from being a document completer to a question

01:54:15.200 answerer. And that's like a whole separate stage. A lot of this data is not available publicly.

01:54:20.640 It is internal to OpenAI, and it's much harder to replicate this stage. And so that's roughly what

01:54:27.680 would give you ChatGPT. And nanoGPT focuses on the pre-training stage. Okay, and that's everything

01:54:33.120 that I wanted to cover today. So, to summarize, we trained a decoder-only transformer following this

01:54:40.720 famous paper, "Attention Is All You Need," from 2017. And so that's basically a GPT. We trained it on

01:54:47.920 Tiny Shakespeare and got sensible results. All of the training code is roughly 200 lines of code.

01:54:57.040 I will be releasing this codebase. It also comes with all the Git log commits along the way as we

01:55:04.720 built it up. In addition to this code, I'm going to release the notebook, of course, the Google

01:55:11.120 Colab. And I hope that gave you a sense for how you can train these models, like, say, GPT-3,

01:55:18.000 which will be architecturally basically identical to what we have, but they are somewhere between 10,000

01:55:23.200 and 1 million times bigger depending on how you count. And so that's all I have for now. We did not

01:55:30.720 talk about any of the fine-tuning stages that would typically go on top of this. So if you're

01:55:34.880 interested in something that's not just language modeling, but you actually want to, say, perform

01:55:38.880 tasks, or you want them to be aligned in a specific way, or you want to detect sentiment or anything

01:55:45.760 like that. Basically, anytime you don't want something that's just a document completer,

01:55:49.440 you have to complete further stages of fine-tuning, which we did not cover. And that could be simple

01:55:55.120 supervised fine-tuning, or it can be something more fancy, like we see in ChatGPT, where we

01:55:59.200 actually train a reward model and then do rounds of PPO to align it with respect to the reward model.

01:56:04.560 So there's a lot more that can be done on top of it. I think for now we're starting to get to

01:56:08.800 about the two-hour mark, so I'm going to kind of finish here. I hope you enjoyed the lecture.

01:56:15.920 And yeah, go forth and transform. See you later.