00:00:00.000 Hi everyone.

00:00:01.200 Today we are continuing our implementation of Makemore.

00:00:04.200 Now in the last lecture, we implemented the Multilayer

00:00:06.080 Perceptron along the lines of Bengio et al. 2003

00:00:08.880 for character level language modeling.

00:00:10.720 So we followed this paper, took in a few characters

00:00:13.160 in the past, and used an MLP to predict the next character

00:00:15.800 in a sequence.

00:00:17.320 So what we'd like to do now is we'd

00:00:18.520 like to move on to more complex and larger neural networks,

00:00:21.480 like recurrent neural networks, and their variations

00:00:23.640 like the GRU, LSTM, and so on.

00:00:26.440 Now before we do that though, we have to stick around the level

00:00:29.000 of the multilayer perceptron for a bit longer.

00:00:31.600 And I'd like to do this because I would like us

00:00:33.280 to have a very good intuitive understanding

00:00:35.400 of the activations in the neural net during training,

00:00:38.360 and especially the gradients that are flowing backwards,

00:00:40.880 and how they behave, and what they look like.

00:00:43.360 And this is going to be very important

00:00:44.720 to understand the history of the development

00:00:46.480 of these architectures.

00:00:48.040 Because we'll see that recurrent neural networks,

00:00:49.800 while they are very expressive, in that they

00:00:52.400 are a universal approximator, and can in principle implement

00:00:55.840 all the algorithms, we'll see that they are not

00:00:58.760 very easily optimizable with the first-order gradient-based

00:01:01.680 techniques that we have available to us,

00:01:03.080 and that we use all the time.

00:01:04.680 And the key to understanding why they are not

00:01:06.920 optimizable easily is to understand the activations

00:01:10.520 and the gradients and how they behave during training.

00:01:12.680 And we'll see that a lot of the variants,

00:01:14.120 since recurrent neural networks, have

00:01:16.400 tried to improve that situation.

00:01:19.120 And so that's the path that we have to take,

00:01:21.560 and let's get started.

00:01:22.880 So the starting code for this lecture

00:01:24.280 is largely the code from before, but I've cleaned it up

00:01:27.160 a little bit.

00:00:28.320 So you'll see that we are importing all the torch

00:00:32.280 and matplotlib utilities.

00:01:33.560 We're reading in the words just like before.

00:01:35.600 These are eight example words.

00:01:37.240 There's a total of 32,000 of them.

00:01:39.320 Here's a vocabulary of all the lowercase letters

00:01:41.680 and the special dot token.

00:01:44.360 Here we are reading the data set and processing it

00:01:47.760 and creating three splits, the train, dev, and the test

00:01:52.440 split.

00:01:53.680 Now the MLP. This is the identical MLP,

00:01:56.400 except you see that I removed a bunch of magic numbers

00:01:58.840 that we had here.

00:01:59.760 And instead we have the dimensionality

00:02:01.600 of the embedding space of the characters

00:02:03.520 and the number of hidden units in the hidden layer.

00:02:06.120 And so I've pulled them outside here

00:02:07.880 so that we don't have to go and change all these magic numbers

00:02:10.400 all the time.

00:02:11.680 It's the same neural net with 11,000 parameters

00:02:14.200 that we optimize now over 200,000 steps

00:02:16.600 with batch size of 32.

00:02:18.320 And you'll see that I refactored the code here

00:02:21.320 a little bit, but there are no functional changes.

00:02:23.640 I just created a few extra variables, a few more comments,

00:02:27.240 and I removed all the magic numbers

00:02:29.280 and otherwise it's the exact same thing.

00:02:32.040 Then when we optimize, we saw that our loss looked

00:02:34.600 something like this.

00:02:36.000 We saw that the train and val loss were about 2.16

00:02:40.120 and so on.

00:02:41.720 Here I refactored the code a little bit

00:02:44.280 for the evaluation of arbitrary splits.

00:02:47.080 So you pass in the string of which split you'd

00:02:48.920 like to evaluate.

00:02:50.120 And then here, depending on train, val or test,

00:02:53.360 I index in and I get the correct split.

00:02:55.560 And then this is the forward pass of the network

00:02:57.280 and evaluation of the loss and printing it.

00:03:00.040 So just making it nicer.

00:03:02.720 One thing that you'll notice here is

00:03:05.360 I'm using a decorator, torch.no_grad,

00:03:07.640 which you can also look up and read documentation of.

00:03:11.360 Basically what this decorator does on top of a function

00:03:14.400 is that whatever happens in this function

00:03:17.640 is assumed by torch to never require a gradient.

00:03:22.000 So it will not do any of the bookkeeping

00:03:24.240 that it does to keep track of all the gradients

00:03:26.760 in anticipation of an eventual backward pass.

00:03:29.640 It's almost as if all the tensors that get created here

00:03:31.960 have a requires_grad of False.

00:03:34.520 And so it just makes everything much more efficient

00:03:36.440 because you're telling torch that I will not call

00:03:38.560 .backward() on any of this computation

00:03:40.680 and you don't need to maintain the graph under the hood.

00:03:43.720 So that's what this does.

00:03:45.600 And you can also use a context manager

00:03:48.520 with torch.no_grad, and you can look those up.
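The decorator form and the context-manager form can be sketched like this; the function and tensors here are toy stand-ins, not the lecture's evaluation code:

```python
import torch

# Minimal sketch of both ways to disable gradient tracking.

@torch.no_grad()
def evaluate(x):
    # everything computed inside is untracked by autograd
    return (x * 2.0).sum()

x = torch.ones(3, requires_grad=True)
y = evaluate(x)
print(y.requires_grad)  # False: no graph was built

# equivalent context-manager form
with torch.no_grad():
    z = x * 2.0
print(z.requires_grad)  # False
```

Either form tells torch it will never need to call `.backward()` on this computation, so it skips building the graph.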

00:03:52.040 Then here we have the sampling from a model.

00:03:54.960 Just as before: a forward pass of the neural net,

00:03:58.280 getting the distribution, sampling from it,

00:04:00.640 adjusting the context window and repeating

00:04:02.800 until we get the special end token.

00:04:04.840 And we see that we are starting to get

00:04:06.840 much nicer-looking words sampled from the model.

00:04:09.800 It's still not amazing and they're still not fully name-like,

00:04:13.320 but it's much better than when we had to work

00:04:16.000 with the bigram model.

00:04:16.000 So that's our starting point.

00:04:19.160 Now the first thing I would like to scrutinize

00:04:20.680 is the initialization.

00:04:22.640 I can tell that our network is very improperly configured

00:04:26.160 at initialization and there's multiple things wrong with it

00:04:29.080 but let's just start with the first one.

00:04:31.200 Look here at the zeroth iteration, the very first iteration.

00:04:34.880 We are recording a loss of 27 and this rapidly comes down

00:04:38.280 to roughly one or two or so.

00:04:40.320 So I can tell that the initialization is messed up

00:04:42.200 because this is way too high.

00:04:44.400 In training of neural nets, it is almost always the case

00:04:46.920 that you will have a rough idea for what loss to expect

00:04:49.400 at initialization and that just depends on the loss function

00:04:52.840 and the problem setup.

00:04:54.760 In this case, I do not expect 27.

00:04:57.120 I expect a much lower number and we can calculate it together.

00:05:00.600 Basically at initialization, what we'd like is that

00:05:04.920 there's 27 characters that could come next

00:05:07.200 for any one training example.

00:05:09.040 At initialization, we have no reason to believe any characters

00:05:11.720 to be much more likely than others.

00:05:13.680 And so we'd expect that the probability distribution

00:05:15.800 that comes out initially is a uniform distribution

00:05:19.160 assigning about equal probability to all the 27 characters.

00:05:22.520 So basically what we'd like is the probability

00:05:25.720 for any character would be roughly one over 27.

00:05:30.160 That is the probability we should record

00:05:33.880 and then the loss is the negative log probability.

00:05:36.640 So let's wrap this in a tensor

00:05:38.240 and then we can take the log of it.

00:05:42.080 And then the negative log probability

00:05:44.040 is the loss we would expect, which is 3.29,

00:05:47.600 much, much lower than 27.
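That calculation can be written out directly as a short sketch of what was just narrated:

```python
import torch

# Expected loss at initialization: uniform probability 1/27 per
# character, and the loss is the negative log probability.
prob = torch.tensor(1 / 27.0)
loss = -prob.log()
print(loss.item())  # roughly 3.2958
```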

00:05:49.920 And so what's happening right now is that at initialization,

00:05:52.880 the neural net is creating probability distributions

00:05:55.080 that are all messed up.

00:05:56.280 Some characters are very confident

00:05:58.120 and some characters are very not confident.

00:06:00.640 And then basically what's happening

00:06:01.880 is that the network is very confidently wrong

00:06:05.240 and that's what makes it record very high loss.

00:06:10.240 So here's a smaller four dimensional example of the issue.

00:06:13.400 Let's say we only have four characters

00:06:15.960 and then we have logits that come out of the neural net

00:06:18.560 and they are very, very close to zero.

00:06:20.920 Then when we take the softmax of all zeros,

00:06:23.800 we get probabilities that are a diffuse distribution.

00:06:27.400 It sums to one and is exactly uniform.

00:06:31.120 And then in this case, if the label is say two,

00:06:33.760 it doesn't actually matter if the label is two or three

00:06:37.160 or one or zero because it's a uniform distribution,

00:06:39.920 we're recording the exact same loss,

00:06:41.320 in this case 1.38.

00:06:43.160 So this is the loss we would expect for a four dimensional example.
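The four-dimensional example can be sketched as follows; with all-zero logits the softmax is exactly uniform and the loss is log(4) no matter which label we pick:

```python
import torch

# All-zero logits -> exactly uniform softmax -> loss of log(4),
# regardless of the label.
logits = torch.zeros(4)
probs = torch.softmax(logits, dim=0)
print(probs)            # 0.25 each
loss = -probs[2].log()  # label 2; any label gives the same loss
print(loss.item())      # ~1.3863
```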

00:06:46.120 And now you can see of course that as we start

00:06:48.080 to manipulate these logits,

00:06:50.560 we're going to be changing the loss here.

00:06:52.480 So it could be that we luck out and by chance,

00:06:55.800 this could be a very high number like five or something like that.

00:06:59.320 Then in that case, we'll record a very low loss

00:07:01.080 because we're assigning the correct probability

00:07:02.840 at initialization by chance to the correct label.

00:07:06.720 Much more likely it is that some other dimension

00:07:10.400 will have a high logit.

00:07:14.040 And then what will happen is we start to record

00:07:15.520 much higher loss.

00:07:17.160 And what can happen is basically the logits come out

00:07:20.320 like something like this,

00:07:22.200 and they take on extreme values

00:07:24.400 and we record really high loss.

00:07:26.560 For example, if we have torch.randn of four,

00:07:31.680 these are normally distributed numbers, four of them.

00:07:36.680 And here we can also print the logits,

00:07:41.680 probabilities that come out of it and loss.

00:07:45.680 And so because these logits are near zero,

00:07:49.680 for the most part, the loss that comes out is okay.

00:07:52.680 But suppose this is like times 10 now.

00:07:55.680 You see how because these are more extreme values,

00:08:00.680 it's very unlikely that you're going to be guessing

00:08:03.680 the correct bucket.

00:08:05.680 And then you're confidently wrong

00:08:07.560 and recording very high loss.

00:08:09.720 If your logits are coming up even more extreme,

00:08:12.720 you might get extremely, you know,

00:08:14.680 insane losses like infinity even at initialization.
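We can sketch this effect by averaging the loss over many random draws of extreme logits; the seed and the fixed label are arbitrary choices for the demonstration:

```python
import torch

# Extreme random logits make the "network" confidently wrong on
# average: the mean loss over many draws is far above the
# uniform-distribution loss of log(4) ~ 1.386.
g = torch.Generator().manual_seed(2147483647)
losses = []
for _ in range(1000):
    logits = torch.randn(4, generator=g) * 10  # extreme values
    probs = torch.softmax(logits, dim=0)
    losses.append(-probs[2].log().item())      # fixed label 2
mean_loss = sum(losses) / len(losses)
# mean_loss comes out much larger than log(4)
```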

00:08:17.600 So basically this is not good

00:08:21.680 and we want the logits to be roughly zero

00:08:23.960 when the network is initialized.

00:08:27.880 In fact, the logits don't have to be just zero,

00:08:30.040 they just have to be equal.

00:08:31.240 So for example, if all the logits are one,

00:08:34.160 then because of the normalization inside the softmax,

00:08:36.760 this will actually come out okay.

00:08:38.560 But by symmetry, we don't want it to be any arbitrary,

00:08:40.640 positive or negative number.

00:08:42.120 We just want it to be all zeros

00:08:43.600 and record the loss that we expect at initialization.

00:08:46.400 So let's now quickly see

00:08:47.520 where things go wrong in our example.

00:08:49.800 Here we have the initialization.

00:08:51.640 Let me reinitialize the neural net.

00:08:53.520 And here let me break after the very first iteration.

00:08:56.040 So we only see the initial loss, which is 27.

00:09:00.160 So that's way too high.

00:09:01.400 And intuitively, now we can inspect the variables involved.

00:09:04.280 And we see that the logits here,

00:09:05.920 if we just print some of these,

00:09:07.880 if we just print the first row,

00:09:10.960 we see that the logits take on quite extreme values.

00:09:13.840 And that's what's creating the fake confidence

00:09:16.440 in incorrect answers and makes the loss

00:09:19.160 get very, very high.

00:09:22.120 So these logits should be much, much closer to zero.

00:09:25.360 So now let's think through how we can achieve

00:09:27.560 logits coming out of this neural net

00:09:29.960 being much closer to zero.

00:09:32.560 You see here that logits are calculated

00:09:34.160 as the hidden states multiplied by W2 plus B2.

00:09:37.680 So first of all, currently we're initializing B2

00:09:40.480 as random values of the right size.

00:09:44.280 But because we want roughly zero,

00:09:46.680 we don't actually want to be adding a bias of random numbers.

00:09:49.280 So in fact, I'm going to add a times zero here

00:09:51.920 to make sure that B2 is just basically zero

00:09:55.680 at initialization.

00:09:57.480 And second, this is H multiplied by W2.

00:10:00.360 So if we want logits to be very, very small,

00:10:03.000 then we would be multiplying W2 and making that smaller.

00:10:06.000 So for example, if we scale down W2 by 0.1,

00:10:09.960 all the elements, then if I do again,

00:10:13.120 just a very first situation,

00:10:14.440 you see that we are getting much closer to what we expect.

00:10:17.360 So roughly what we want is about 3.29, this is 4.2.

00:10:22.360 I can make this maybe even smaller, 3.32.

00:10:26.400 Okay, so we're getting closer and closer.

00:10:28.640 Now, you're probably wondering,

00:10:30.880 can we just set this to zero?

00:10:32.680 Then we get, of course, exactly what we're looking for

00:10:35.440 at initialization.

00:10:38.120 And the reason I don't usually do this

00:10:40.240 is because I'm very nervous,

00:10:42.360 and I'll show you in a second why you don't want to be

00:10:44.880 setting W's or weights of a neural net exactly to zero.

00:10:49.120 You usually want it to be small numbers

00:10:51.120 instead of exactly zero.

00:10:52.480 For this output layer in this specific case,

00:10:55.600 I think it would be fine,

00:10:57.320 but I'll show you in a second where things go wrong

00:10:58.960 very quickly if you do that.

00:11:00.720 So let's just go with 0.01.

00:11:02.960 In that case, our loss is close enough,

00:11:05.200 but has some entropy.

00:11:06.600 It's not exactly zero.

00:11:08.360 It's got some little entropy,

00:11:10.320 and that's used for symmetry breaking,

00:11:11.800 as we'll see in a second.

00:11:13.640 Logits are now coming out much closer to zero,

00:11:16.200 and everything is well and good.
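The scaled-down output-layer initialization discussed above can be sketched like this; the sizes (200 hidden units, 27 characters) follow the lecture, but the hidden states here are random stand-ins, not the real network's:

```python
import torch

# Output layer init: W2 scaled way down, b2 zeroed, so the logits
# come out near zero and the softmax is near-uniform.
g = torch.Generator().manual_seed(2147483647)
n_hidden, vocab_size = 200, 27

W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01  # small, not zero
b2 = torch.randn(vocab_size, generator=g) * 0                 # zero at init

# stand-in hidden states (the real ones are tanh outputs in [-1, 1])
h = torch.tanh(torch.randn((32, n_hidden), generator=g))
logits = h @ W2 + b2
# logits are now close to zero, giving roughly the expected loss
```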

00:11:18.160 So if I just erase these,

00:11:21.160 and I now take away the break statement,

00:11:24.960 we can run the optimization with this new initialization,

00:11:28.360 and let's just see what losses we record.

00:11:32.560 Okay, so I'll let it run,

00:11:33.960 and you see that we started off good,

00:11:35.720 and then we came down a bit.

00:11:37.160 The plot of the loss now doesn't have

00:11:40.440 this hockey stick appearance,

00:11:42.240 because basically what's happening in the hockey stick,

00:11:45.720 the very first few iterations of the loss,

00:11:48.040 what's happening during the optimization,

00:11:50.000 is the optimization is just squashing down the logits,

00:11:52.920 and then it's rearranging the logits.

00:11:55.040 So basically we took away this easy part

00:11:57.480 of the loss function,

00:11:58.680 where the weights were just being shrunk down.

00:12:01.840 And so therefore, we don't get these easy gains

00:12:04.880 in the beginning,

00:12:05.720 and we're just getting some of the hard gains

00:12:07.360 of training the actual neural net,

00:12:08.880 and so there's no hockey stick appearance.

00:12:11.440 So good things are happening in that both, number one,

00:12:14.560 loss at initialization is what we expect,

00:12:17.000 and the loss doesn't look like a hockey stick,

00:12:20.640 and this is true for any neural net you might train,

00:12:23.720 and something to look out for.

00:12:25.600 And second, the loss that came out

00:12:27.600 is actually quite a bit improved.

00:12:29.560 Unfortunately, I erased what we had here before.

00:12:31.920 I believe this was 2.12,

00:12:34.960 and this was 2.16.

00:12:37.280 So we get a slightly improved result,

00:12:40.120 and the reason for that is,

00:12:41.560 because we're spending more cycles, more time,

00:12:44.360 actually optimizing the neural net,

00:12:46.480 instead of just spending the first

00:12:49.000 several thousand iterations, probably just squashing down the weights,

00:12:53.200 because they were way too high

00:12:54.800 at the beginning of the initialization.

00:12:56.840 So something to look out for, and that's number one.

00:13:00.040 Now let's look at the second problem.

00:13:01.760 Let me reinitialize our neural net,

00:13:03.480 and let me reintroduce the break statement.

00:13:06.000 So we have a reasonable initial loss.

00:13:08.560 So even though everything is looking good

00:13:09.920 on the level of the loss,

00:13:10.960 and we get something that we expect,

00:13:12.640 there's still a deeper problem

00:13:14.160 lurking inside this neural net and its initialization.

00:13:17.440 So the logits are now okay.

00:13:19.880 The problem now is with the values of h,

00:13:23.000 the activations of the hidden states.

00:13:25.360 Now if we just visualize this vector, sorry, this tensor h,

00:13:29.080 it's kind of hard to see,

00:13:29.920 but the problem here, roughly speaking,

00:13:31.720 is you see how many of the elements are one or negative one.

00:13:36.000 Now recall that torch.tanh,

00:13:38.040 the tanh function, is a squashing function.

00:13:40.520 It takes arbitrary numbers,

00:13:41.760 and it squashes them into a range of negative one and one,

00:13:44.360 and it does so smoothly.

00:13:46.160 So let's look at the histogram of h

00:13:47.920 to get a better idea of the distribution

00:13:49.960 of the values inside this tensor.

00:13:52.360 We can do this first.

00:13:53.600 Well, we can see that h is 32 examples

00:13:58.040 and 200 activations in each example.

00:14:00.840 We can view it as negative one

00:14:02.520 to stretch it out into one large vector,

00:14:05.520 and we can then call tolist to convert this

00:14:09.560 into one large Python list of floats.

00:14:13.680 And then we can pass this into plt.hist for histogram,

00:14:17.640 and we say we want 50 bins,

00:14:20.040 and use a semicolon to suppress a bunch of output we don't want.
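The histogram computation just described can be sketched without the plot; here h is simulated from broad pre-activations like the ones in the lecture, not the real hidden states:

```python
import torch

# Flatten h into one long list of floats and bin it.
g = torch.Generator().manual_seed(2147483647)
h = torch.tanh(torch.randn((32, 200), generator=g) * 5)  # stand-in h

values = h.view(-1).tolist()          # 32 * 200 = 6400 floats
# import matplotlib.pyplot as plt; plt.hist(values, 50);
hist = torch.histc(h, bins=50, min=-1, max=1)
# most of the mass lands in the outermost bins, near -1 and +1
```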

00:14:23.000 So we see this histogram,

00:14:25.440 and we see that most of the values by far

00:14:27.480 take on value of negative one and one.

00:14:30.040 So this tanh is very, very active.

00:14:33.160 And we can also look at basically why that is,

00:14:37.840 we can look at the pre-activations

00:14:39.200 that feed into the tanh.

00:14:42.760 And we can see that the distribution

00:14:44.120 of the pre-activations is very, very broad.

00:14:47.360 These take numbers between negative 15 and 15,

00:14:50.040 and that's why in the torch.tanh,

00:14:51.920 everything is being squashed and capped

00:14:53.760 to be in the range of negative one and one,

00:14:55.680 and lots of numbers here take on very extreme values.

00:14:59.080 Now, if you are new to neural networks,

00:15:01.040 you might not actually see this as an issue,

00:15:03.360 but if you're well versed in the dark arts

00:15:05.320 of backpropagation and have an intuitive sense

00:15:07.840 of how these gradients flow through a neural net,

00:15:10.240 you are looking at your distribution

00:15:11.680 of tanh activations here and you are sweating.

00:15:14.880 So let me show you why.

00:15:16.320 We have to keep in mind that during backpropagation,

00:15:18.280 just like we saw in micrograd,

00:15:19.880 we are doing backward pass starting at the loss

00:15:22.080 and flowing through the network backwards.

00:15:24.680 In particular, we're going to back propagate

00:15:26.200 through this torch.tanh.

00:15:28.640 And this layer here is made up of 200 neurons

00:15:31.680 for each one of these examples,

00:15:33.640 and it implements an element-wise tanh.

00:15:36.560 So let's look at what happens in tanh in the backward pass.

00:15:39.760 We can actually go back to our previous micrograd code

00:15:42.920 in the very first lecture and see how we implemented tanh.

00:15:46.000 We saw that the inputs here was x,

00:15:49.200 and then we calculate t, which is the tanh of x.

00:15:52.320 So that's t, and t is between negative one and one.

00:15:54.840 It's the output of the tanh.

00:15:56.360 And then in the backward pass,

00:15:57.440 how do we back propagate through a tanh?

00:16:00.040 We take out.grad and then we multiply it,

00:16:03.920 this is the chain rule with the local gradient,

00:16:06.160 which took the form of one minus t squared.
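We can check this local gradient numerically against autograd; the inputs here are arbitrary probe values chosen to hit the pass-through and saturated regimes:

```python
import torch

# Numeric check of the tanh local gradient (1 - t**2): at t near
# +/-1 the gradient is killed; at t = 0 it passes through.
x = torch.tensor([0.0, 2.0, 5.0], requires_grad=True)
t = torch.tanh(x)
t.sum().backward()               # out.grad is 1 for every element

local = 1 - t.detach() ** 2      # the micrograd-style local gradient
print(x.grad)                    # matches `local` elementwise
# x.grad[0] is 1.0 (pass-through); x.grad[2] is nearly zero
```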

00:16:09.000 So what happens if the outputs of your tanh

00:16:11.360 are very close to negative one or one?

00:16:14.040 If you plug in t equals one here,

00:16:16.000 you're going to get zero, multiplying out.grad.

00:16:19.680 No matter what out.grad is,

00:16:21.160 we are killing the gradient,

00:16:22.880 and we're stopping effectively the backward propagation

00:16:25.560 through this tanh unit.

00:16:27.400 Similarly, when t is negative one,

00:16:29.160 this will again become zero,

00:16:30.520 and out.grad just stops.

00:16:32.920 And intuitively, this makes sense because

00:16:35.280 this is a tanh neuron,

00:16:37.080 and what's happening is,

00:16:39.000 if its output is very close to one,

00:16:41.280 then we are in the tail of this tanh.

00:16:43.920 And so changing basically the input

00:16:48.000 is not going to impact the output of the tanh too much,

00:16:52.120 because it's in a flat region of the tanh.

00:16:55.680 And so therefore, there's no impact on the loss.

00:16:58.520 And so indeed, the weights and the biases

00:17:02.480 along with this tanh neuron do not impact the loss,

00:17:05.520 because the output of this tanh unit

00:17:07.120 is in a flat region of the tanh,

00:17:08.720 and there's no influence.

00:17:09.640 We can be changing them whatever we want,

00:17:12.280 however we want, and the loss is not impacted.

00:17:14.600 That's another way to justify that indeed,

00:17:17.200 the gradient would be basically zero, it vanishes.

00:17:19.960 Indeed, when t equals zero,

00:17:24.400 we get one times out.grad.

00:17:27.320 So when the tanh takes on exactly the value of zero,

00:17:31.200 then out.grad is just passed through.

00:17:34.920 So basically what this is doing, right,

00:17:36.400 is if t is equal to zero,

00:17:38.240 then the tanh unit is sort of inactive,

00:17:42.360 and gradient just passes through.

00:17:44.800 But the more you are in the flat tails,

00:17:47.240 the more the gradient is squashed.

00:17:49.480 So in fact, you'll see that the gradient flowing

00:17:52.040 through tanh can only ever decrease.

00:17:54.560 And the amount that it decreases

00:17:56.560 is proportional, through the square here,

00:18:00.440 depending on how far you are in the flat tails

00:18:03.200 of this tanh.

00:18:05.040 And so that's kind of what's happening here.

00:18:07.040 And so the concern here is that if all of these outputs h

00:18:12.040 are in the flat regions of negative one and one,

00:18:15.000 then the gradients that are flowing through the network

00:18:17.320 will just get destroyed at this layer.

00:18:19.800 Now there is some redeeming quality here,

00:18:25.400 in that we can actually get a sense

00:18:25.400 of the problem here as follows.

00:18:27.760 I wrote some code here.

00:18:29.240 And basically what we want to do here

00:18:30.720 is we want to take a look at h,

00:18:32.920 take the absolute value and see how often it is

00:18:36.440 in the flat region.

00:18:38.520 So say greater than 0.99.

00:18:41.400 And what you get is the following.

00:18:44.360 And this is a Boolean tensor.

00:18:45.760 So in the Boolean tensor, you get a white,

00:18:49.240 if this is true and a black, if this is false.

00:18:52.400 And so basically what we have here is the 32 examples

00:18:55.000 and the 200 hidden neurons.

00:18:57.240 And we see that a lot of this is white.

00:19:00.320 And what that's telling us is that all these

00:19:02.720 tanh neurons were very, very active

00:19:06.120 and they're in a flat tail.

00:19:08.880 And so in all these cases,

00:19:11.960 the backward gradient would get destroyed.
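A sketch of that saturation check: simulate broad pre-activations, squash with tanh, and flag activations in the flat tails. The shapes (32 examples, 200 neurons) follow the lecture, but the inputs are random stand-ins, not the real hidden states:

```python
import torch

# Count activations sitting in the flat tails (|h| > 0.99).
g = torch.Generator().manual_seed(2147483647)
h = torch.tanh(torch.randn((32, 200), generator=g) * 5)  # stand-in h

saturated = h.abs() > 0.99             # the white pixels in the plot
frac = saturated.float().mean().item()
dead = saturated.all(dim=0)            # an all-True column = dead neuron
# frac is large (a big chunk of activations saturate), and
# dead.any() would flag a fully-white column
```

Visualizing `saturated` with `plt.imshow` gives exactly the white/black picture described above.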

00:19:15.000 Now we would be in a lot of trouble

00:19:18.520 if for any one of these 200 neurons,

00:19:22.520 if it was the case that the entire column is white.

00:19:26.200 Because in that case we have what's called the dead neuron.

00:19:28.680 And this could be a tanh neuron

00:19:30.040 where the initialization of the weights and the biases

00:19:32.240 could be such that no single example

00:19:34.400 ever activates this tanh in the sort of active part

00:19:38.880 of the tanh.

00:19:39.960 If all the examples land in the tail,

00:19:43.000 then this neuron will never learn.

00:19:44.760 It is a dead neuron.

00:19:46.560 And so just scrutinizing this and looking for columns

00:19:50.360 of completely white, we see that this is not the case.

00:19:54.160 So I don't see a single neuron that is all of white.

00:19:59.400 And so therefore it is the case that for every one of these

00:20:02.200 tanh neurons, we do have some examples

00:20:05.400 that activate them in the active part of the tanh.

00:20:08.960 And so some gradients will flow through

00:20:10.560 and this neuron will learn

00:20:12.280 and the neuron will change and it will move

00:20:14.320 and it will do something.

00:20:16.360 But you can sometimes get yourself in cases

00:20:18.440 where you have dead neurons.

00:20:20.280 And the way this manifests is that

00:20:22.600 for tanh neurons this would be

00:20:24.320 when no matter what inputs you plug in from your data set,

00:20:27.200 this tanh neuron always fires as completely one

00:20:29.920 or completely negative one.

00:20:31.320 And then it will just not learn

00:20:33.400 because all the gradients will be just zeroed out.

00:20:36.680 This is true not just for tanh,

00:20:37.800 but for a lot of other nonlinearities

00:20:39.640 that people use in neural networks.

00:20:41.160 So we certainly use tanh a lot,

00:20:43.200 but sigmoid will have the exact same issue

00:20:45.080 because it is also a squashing function.

00:20:47.480 And so the same will actually apply to sigmoid.

00:20:57.080 The same will also apply to a relu.

00:20:59.120 So relu has a completely flat region here below zero.

00:21:03.440 So if you have a relu neuron,

00:21:04.960 then it is a pass-through if it is positive

00:21:08.640 and if the pre-activation is negative,

00:21:11.040 it will just shut it off.

00:21:12.600 Since the region here is completely flat,

00:21:15.040 then during back propagation,

00:21:17.200 this would be exactly zeroing out the gradient.

00:21:19.880 Like all of the gradients would be set exactly to zero

00:21:22.920 instead of just like a very, very small number

00:21:24.600 depending on how positive or negative t is.
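That binary behavior is easy to demonstrate with a few probe values; the numbers here are arbitrary illustrations:

```python
import torch

# ReLU's flat region zeroes the gradient exactly; the positive side
# passes it through unchanged.
x = torch.tensor([-3.0, -0.1, 0.5, 2.0], requires_grad=True)
y = torch.relu(x)
y.sum().backward()
print(x.grad)  # [0., 0., 1., 1.]
```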

00:21:28.400 And so you can get, for example, a dead relu neuron

00:21:31.480 and a dead relu neuron would basically look like this.

00:21:35.280 Basically, what it is is a neuron

00:21:37.520 with a relu nonlinearity that never activates,

00:21:41.120 so for any examples that you plug in from the dataset,

00:21:44.000 it never turns on, it's always in this flat region,

00:21:47.400 then this relu neuron is a dead neuron.

00:21:49.480 Its weights and bias will never learn.

00:21:52.080 They will never get a gradient

00:21:53.240 because the neuron never activated.

00:21:55.720 And this can sometimes happen at initialization

00:21:57.960 because the weights and the biases just make it

00:21:59.600 so that by chance some neurons are just forever dead.

00:22:02.840 But it can also happen during optimization.

00:22:04.880 If you have too high of a learning rate, for example,

00:22:07.560 sometimes you have these neurons

00:22:08.840 that get too much of a gradient

00:22:10.360 and they get knocked off of the data manifold.

00:22:13.680 And what happens is that from then on,

00:22:15.600 no example ever activates this neuron.

00:22:17.840 So this neuron remains dead forever.

00:22:19.480 So it's kind of like a permanent brain damage

00:22:21.080 in the mind of a network.

00:22:23.840 And so sometimes what can happen is

00:22:25.400 if your learning rate is very high, for example,

00:22:27.360 and you have a neural net with relu neurons,

00:22:29.680 you train the neural net and you get some loss.

00:22:32.640 But then actually what you do is

00:22:34.440 you go through the entire training set

00:22:36.320 and you forward your examples

00:22:39.400 and you can find neurons that never activate:

00:22:42.080 they are dead neurons in your network.

00:22:44.040 And so those neurons will never turn on.

00:22:46.400 And usually what happens is that during training,

00:22:48.400 these relu neurons are changing, moving, et cetera.

00:22:50.680 And then because of a high gradient somewhere by chance,

00:22:53.080 they get knocked off and then nothing ever activates them.

00:22:56.400 And from then on, they are just dead.

00:22:59.000 So that's kind of like a permanent brain damage

00:23:00.600 that can happen to some of these neurons.

00:23:03.160 These other nonlinearities like leaky relu

00:23:05.440 will not suffer from this issue as much

00:23:07.360 because you can see that it doesn't have flat tails.

00:23:10.560 You'll almost always get gradients.

00:23:12.920 And elu is also fairly frequently used.

00:23:16.480 It also might suffer from this issue

00:23:17.880 because it has flat parts.

00:23:19.840 So that's just something to be aware of

00:23:22.600 and something to be concerned about.

00:23:24.160 And in this case, we have way too many activations H

00:23:28.720 that take on extreme values.

00:23:30.520 And because there's no column of white,

00:23:33.000 I think we will be okay.

00:23:34.360 And indeed the network optimizes

00:23:35.720 and gives us a pretty decent loss.

00:23:37.640 But it's just not optimal.

00:23:38.920 And this is not something you want,

00:23:40.480 especially during initialization.

00:23:42.280 And so basically what's happening is that

00:23:45.240 this h pre-activation that's flowing into the tanh,

00:23:48.640 it's too extreme.

00:23:50.000 It's too large.

00:23:51.040 It's creating a very,

00:23:52.760 it's creating a distribution that is too saturated

00:23:55.840 on both sides of the tanh.

00:23:57.280 And it's not something you want

00:23:58.400 because it means that there's less training

00:24:01.280 for these neurons because they update less frequently.

00:24:05.760 So how do we fix this?

00:24:07.280 Well, the h pre-activation is embcat,

00:24:11.240 which comes from C.

00:24:12.720 So these are unit Gaussian.

00:24:15.000 But then it's multiplied by W1 plus B1.

00:24:17.560 And H pre-act is too far off from zero.

00:24:20.240 And that's causing the issue.

00:24:21.560 So we want this pre-activation to be closer to zero,

00:24:24.760 very similar to what we had with logits.

00:24:27.360 So here we want actually something very, very similar.

00:24:30.600 Now it's okay to set the biases to a very small number.

00:24:35.080 We can, for example, multiply b1 by 0.01

00:24:36.840 to get a little bit of entropy.

00:24:38.440 I sometimes like to do that

00:24:41.680 just so that there's a little bit of variation

00:24:45.120 and diversity in the original initialization

00:24:48.120 of these tanh neurons.

00:24:49.440 And I find in practice that that can help

00:24:51.600 optimization a little bit.

00:24:53.720 And then the weights we can also just squash down.

00:24:56.240 So let's multiply everything by 0.1.

00:24:59.240 Let's rerun the first batch.

00:25:01.560 And now let's look at this.

00:25:03.160 And well, first let's look at here.

00:25:05.440 You see now, because we multiplied W1 by 0.1,

00:25:09.560 we have a much better histogram.

00:25:11.160 And that's because the pre-activations

00:25:12.560 are now between negative 1.5 and 1.5.

00:25:14.920 And with this we expect much, much less white.

00:25:18.120 Okay, there's no white.

00:25:19.880 So basically that's because there are no neurons

00:25:23.760 that are saturated above 0.99 in either direction.

00:25:27.960 So it's actually a pretty decent place to be.
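The saturation check described here can be sketched in PyTorch. The names and the 32-by-200 batch shape follow the lecture's code; the scale factor of 5.0 is just an assumed "bad" initialization for illustration (the lecture visualizes the boolean grid with plt.imshow, where a fully white column is a dead neuron):

```python
import torch

torch.manual_seed(42)
# hypothetical badly-scaled pre-activations: batch of 32 examples, 200 neurons
hpreact = torch.randn(32, 200) * 5.0
h = torch.tanh(hpreact)

# boolean grid: True where the tanh is in a flat tail (|h| > 0.99)
saturated = h.abs() > 0.99
print(saturated.float().mean().item())   # fraction of saturated activations

# a neuron is "dead" for this batch if it is saturated on every example
dead = saturated.all(dim=0)
```

With well-scaled pre-activations (multiply by 0.2 instead of 5.0), the saturated fraction drops to near zero.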

00:25:30.200 Maybe we can go up a little bit.

00:25:34.560 Sorry, am I changing W1 here?

00:25:39.280 So maybe we can go to 0.2.

00:25:40.560 Okay, so maybe something like this

00:25:44.280 is a nice distribution.

00:25:46.440 So maybe this is what our initialization should be.
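The squashed initialization arrived at here, as a minimal sketch; the variable names and hyperparameters (n_embd=10, block_size=3, n_hidden=200) are assumed from the lecture's code:

```python
import torch

g = torch.Generator().manual_seed(2147483647)
n_embd, block_size, n_hidden = 10, 3, 200

# squash the weights down and nearly zero the biases at initialization
W1 = torch.randn((n_embd * block_size, n_hidden), generator=g) * 0.2
b1 = torch.randn(n_hidden, generator=g) * 0.01
```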

00:25:49.160 So let me now erase these.

00:25:52.480 And let me, starting with initialization,

00:25:56.640 let me run the full optimization without the break.

00:26:00.400 And let's see what we get.

00:26:03.120 Okay, so the optimization finished.

00:26:04.800 And I rerun the loss.

00:26:06.160 And this is the result that we get.

00:26:08.040 And then just as a reminder,

00:26:09.160 I put down all the losses

00:26:10.240 that we saw previously in this lecture.

00:26:12.400 So we see that we actually do get an improvement here.

00:26:15.160 And just as a reminder,

00:26:16.400 we started off with a validation loss of 2.17 when we started.

00:26:20.080 By fixing the softmax being confidently wrong,

00:26:22.440 we came down to 2.13.

00:26:24.080 And by fixing the tanh layer being way too saturated,

00:26:26.640 we came down to 2.10.

00:26:28.800 And the reason this is happening, of course,

00:26:30.320 is because our initialization is better.

00:26:31.920 And so we're spending more time being productive training

00:26:34.640 instead of not very productive training

00:26:38.000 because our gradients are set to zero.

00:26:40.120 And we have to learn very simple things

00:26:42.480 like the overconfidence of the softmax in the beginning.

00:26:45.440 And we're spending cycles

00:26:46.400 just like squashing down the weight matrix.

00:26:49.000 So this is illustrating basically initialization

00:26:53.160 and its impacts on performance

00:26:55.720 just by being aware of the internals of these neural nets

00:26:58.560 and their activations and their gradients.

00:27:00.520 Now, we're working with a very small network.

00:27:02.920 This is just a one-hidden-layer multilayer perceptron.

00:27:05.520 So because the network is so shallow,

00:27:07.760 the optimization problem is actually quite easy.

00:27:10.040 And very forgiving.

00:27:11.520 So even though our initialization was terrible,

00:27:13.520 the network still learned eventually.

00:27:15.440 It just got a slightly worse result.

00:27:17.400 This is not the case in general though.

00:27:19.440 Once we actually start working with much deeper networks

00:27:22.800 that have say 50 layers,

00:27:24.720 things can get much more complicated.

00:27:27.120 And these problems stack up.

00:27:30.400 And so you can actually get into a place

00:27:33.040 where the network is basically not training at all

00:27:35.120 if your initialization is bad enough.

00:27:37.320 And the deeper your network is and the more complex it is,

00:27:39.960 the less forgiving it is to some of these errors.

00:27:43.080 And so something to be definitely aware of

00:27:46.480 and something to scrutinize, something to plot

00:27:49.520 and something to be careful with.

00:27:51.080 And yeah.

00:27:53.640 Okay, so that's great that that worked for us.

00:27:55.720 But what we have here now is all these magic numbers

00:27:58.280 like 0.2. Where did I come up with this?

00:28:00.600 And how am I supposed to set these

00:28:02.040 if I have a large neural net with lots and lots of layers?

00:28:05.280 And so obviously no one does this by hand.

00:28:07.560 There's actually some relatively principled ways

00:28:09.560 of setting these scales that I would like to introduce to you now.

00:28:14.040 So let me paste some code here that I prepared

00:28:16.480 just to motivate the discussion of this.

00:28:18.480 So what I'm doing here is we have some random input here,

00:28:22.760 X that is drawn from a Gaussian.

00:28:25.320 And there's 1000 examples that are 10 dimensional.

00:28:28.680 And then we have a hidden layer here

00:28:30.680 whose weights are also initialized using a Gaussian,

00:28:33.080 just like we did here.

00:28:34.760 And the neurons in this hidden layer look at 10 inputs,

00:28:38.520 and there are 200 neurons in this hidden layer.

00:28:41.640 And then we have here, just like up there,

00:28:44.280 the multiplication, X multiplied by W,

00:28:47.160 to get the pre-activations of these neurons.

00:28:50.840 And basically the analysis here looks at,

00:28:53.160 okay, suppose these inputs are unit Gaussian

00:28:55.240 and these weights are unit Gaussian.

00:28:57.160 If I do X times W and we forget for now the bias

00:29:00.760 and the nonlinearity, then what is the mean

00:29:04.440 and the standard deviation of these Gaussians?

00:29:06.960 So in the beginning here, the input is just

00:29:09.600 a normal Gaussian distribution, mean zero

00:29:11.840 and the standard deviation is one.

00:29:13.600 And the standard deviation again is just the measure

00:29:15.440 of the spread of the Gaussian.

00:29:17.200 But then once we multiply here and we look at the

00:29:21.200 histogram of Y, we see that the mean, of course,

00:29:24.760 stays the same, it's about zero,

00:29:27.040 because this is a symmetric operation.

00:29:28.880 But we see here that the standard deviation

00:29:30.440 has expanded to three.

00:29:32.560 So the input standard deviation was one,

00:29:34.280 but now it has grown to three.

00:29:36.400 And so what you're seeing in the histogram

00:29:37.640 is that this Gaussian is expanding.

00:29:39.800 And so we're expanding this Gaussian from the input.

00:29:45.720 And we don't want that; we want most of the neural net

00:29:47.800 to have relatively similar activations,

00:29:50.520 So unit Gaussian roughly throughout the neural net.

00:29:53.960 So the question is, how do we scale these W's

00:29:56.240 so that this distribution remains a unit Gaussian?

00:30:03.720 And so intuitively, if I multiply here

00:30:06.720 these elements of W by a large number, let's say by five,

00:30:10.440 then this Gaussian grows and grows in standard deviation.

00:30:15.920 So now we're at 15.

00:30:17.360 So basically these numbers here in the output Y

00:30:20.160 take on more and more extreme values.

00:30:22.560 But if we scale it down, like say point two,

00:30:25.200 then conversely, this Gaussian is getting smaller

00:30:28.800 and smaller and it's shrinking.

00:30:31.080 And you can see that the standard deviation is 0.6.

00:30:33.920 And so the question is what do I multiply by here

00:30:36.640 to exactly preserve the standard deviation to be one?

00:30:40.960 And it turns out that the correct answer mathematically

00:30:42.840 when you work out through the variance

00:30:44.840 of this multiplication here is that you are supposed

00:30:48.560 to divide by the square root of the fan in.

00:30:52.840 The fan-in is basically the number of input

00:30:56.520 elements here, 10.

00:30:58.040 So we are supposed to divide by the square root of 10.

00:31:00.800 And this is one way to do the square root.

00:31:02.480 You raise it to a power of 0.5.

00:31:04.440 That's the same as doing a square root.

00:31:07.160 So when you divide by the square root of 10,

00:31:10.520 then we see that the output Gaussian

00:31:14.080 has a standard deviation of exactly one.

00:31:17.440 Now unsurprisingly, a number of papers have looked

00:31:19.760 into how to best initialize neural networks.

00:31:23.360 And in the case of multilayer perceptrons,

00:31:25.240 we can have fairly deep networks

00:31:26.880 that have these nonlinearities in between.

00:31:29.120 And we want to make sure that the activations are well behaved

00:31:31.840 and they don't expand to infinity or shrink all the way to 0.

00:31:35.240 And the question is how do we initialize the weights

00:31:37.000 so that these activations take on reasonable values

00:31:39.280 throughout the network.

00:31:40.920 Now, one paper that has studied this in quite a bit of detail

00:31:43.720 that is often referenced is this paper by Kaiming He et al.

00:31:47.000 called Delving Deep into Rectifiers.

00:31:49.360 Now in this case, they actually study convolutional neural networks

00:31:52.360 and they study especially the ReLU nonlinearity

00:31:56.040 and the PReLU nonlinearity,

00:31:57.880 instead of a tanh nonlinearity.

00:31:59.840 But the analysis is very similar

00:32:01.600 and basically what happens here is for them,

00:32:06.000 the ReLU nonlinearity that they care about quite a bit here

00:32:09.280 is a squashing function where all the negative numbers

00:32:13.280 are simply clamped to 0.

00:32:15.840 So the positive numbers pass through,

00:32:17.760 but everything negative is just set to 0.

00:32:20.520 And because you are basically throwing away

00:32:23.000 half of the distribution,

00:32:24.640 they find in their analysis of the forward activations

00:32:27.440 in the neural net that you have to compensate for that

00:32:29.800 with a gain.

00:32:30.800 And so here, they find that basically when they initialize

00:32:36.680 their weights, they have to do it with a zero-mean Gaussian

00:32:39.200 whose standard deviation is the square root of two over the fan-in.

00:32:43.440 What we have here is we are initializing the Gaussian

00:32:46.160 by dividing by the square root of the fan-in.

00:32:49.000 This n_l here is the fan-in.

00:32:50.520 So what we have is the square root of one over the fan-in,

00:32:55.400 because we have the division here.

00:32:58.080 Now, they have to add this factor of two

00:33:00.040 because of the ReLU, which basically discards

00:33:02.680 half of the distribution and clamps it at zero.

00:33:05.480 And so that's where you get this additional factor of two.
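This factor of two can be checked empirically with a short sketch (the sample size is arbitrary):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1_000_000)         # unit Gaussian input
y = torch.relu(x)                  # negative half clamped to zero

# ReLU discards half of the distribution, halving the second moment,
# so the weights need a compensating gain of sqrt(2)
print((x ** 2).mean().item())      # ~ 1.0
print((y ** 2).mean().item())      # ~ 0.5
```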

00:33:07.960 Now, in addition to that, this paper also studies

00:33:10.680 not just the sort of behavior of the activations

00:33:13.520 in the forward pass of the neural net,

00:33:15.240 but it also studies the back propagation.

00:33:17.720 And we have to make sure that the gradients also

00:33:19.720 are well-behaved.

00:33:20.920 And so, because ultimately they end up updating our parameters.

00:33:25.800 And what they find here through a lot of the analysis

00:33:28.280 that I invite you to read through,

00:33:29.520 but it's not exactly approachable.

00:33:31.880 What they find is basically,

00:33:33.880 if you properly initialize the forward pass,

00:33:36.120 the backward pass is also approximately initialized

00:33:39.280 up to a constant factor that has to do with the size

00:33:42.640 of the number of hidden neurons in an early and late layer.

00:33:47.640 But basically they find empirically

00:33:51.080 that this is not a choice that matters too much.

00:33:54.000 Now, this Kaiming initialization is also implemented

00:33:57.080 in PyTorch.

00:33:58.080 So if you go to the torch.nn.init documentation,

00:34:00.600 you'll find kaiming_normal_.

00:34:02.520 And in my opinion, this is probably the most common way

00:34:05.040 of initializing neural networks now.

00:34:07.480 And it takes a few keyword arguments here.

00:34:09.800 So number one, it wants to know the mode.

00:34:12.920 Would you like to normalize the activations

00:34:14.800 or would you like to normalize the gradients

00:34:16.560 to be always Gaussian with zero mean

00:34:19.640 and a unit or one standard deviation?

00:34:22.640 And because they find in the paper

00:34:23.800 that this doesn't matter too much,

00:34:25.120 most people just leave it as the default,

00:34:26.880 which is fan_in.

00:34:28.120 And then second, pass in the nonlinearity that you are using

00:34:31.440 because depending on the nonlinearity,

00:34:33.520 we need to calculate a slightly different gain.

00:34:35.920 And so if your nonlinearity is just linear,

00:34:39.280 so there's no nonlinearity,

00:34:40.840 then the gain here will be one

00:34:42.480 and we have the exact same kind of formula

00:34:44.800 that we've got up here.

00:34:46.360 But if the nonlinearity is something else,

00:34:47.920 we're going to get a slightly different gain.

00:34:49.840 And so if we come up here to the top,

00:34:52.160 we see that, for example, in the case of ReLU,

00:34:54.320 this gain is the square root of two.

00:34:56.360 And the reason it's the square root of two is because in this paper,

00:34:59.280 you see how the two is inside of the square root.

00:35:06.120 So the gain is the square root of two.

00:35:08.240 In the case of linear or identity,

00:35:11.920 we just get a gain of one.

00:35:13.840 In the case of tanh, which is what we're using here,

00:35:16.200 the advised gain is five over three.
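These gains can be read straight out of PyTorch; a small sketch:

```python
import torch

# the recommended gains from torch.nn.init, per nonlinearity
print(torch.nn.init.calculate_gain('linear'))  # 1.0
print(torch.nn.init.calculate_gain('relu'))    # sqrt(2) ~ 1.414
print(torch.nn.init.calculate_gain('tanh'))    # 5/3 ~ 1.667

# kaiming_normal_ fills a weight in place with std = gain / sqrt(fan);
# for a (out_features, in_features) tensor, fan_in is the second dimension
w = torch.empty(200, 30)
torch.nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')
```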

00:35:18.920 And intuitively, why do we need a gain

00:35:21.040 on top of the initialization?

00:35:22.680 It's because tanh, just like ReLU,

00:35:24.720 is a contractive transformation.

00:35:27.480 So what that means is you're taking the output distribution

00:35:29.920 from this matrix multiplication,

00:35:31.520 and then you are squashing it in some way.

00:35:33.680 Now ReLU squashes it by taking everything below zero

00:35:36.200 and clamping it to zero.

00:35:37.800 tanh also squashes it because it's a contractive operation.

00:35:40.600 It will take the tails and it will squeeze them in.

00:35:44.240 And so in order to fight the squeezing in,

00:35:46.680 we need to boost the weights a little bit

00:35:49.040 so that we renormalize everything back to standard,

00:35:51.560 unit standard deviation.

00:35:53.440 So that's why there's a little bit of a gain that comes out.

00:35:56.600 Now I'm skipping through this section a little bit quickly

00:35:58.800 and I'm doing that actually intentionally.

00:36:00.960 And the reason for that is because

00:36:03.200 about seven years ago when this paper was written,

00:36:06.080 you had to actually be extremely careful

00:36:07.800 with the activations and gradients

00:36:09.520 and their ranges and their histograms.

00:36:11.800 And you had to be very careful

00:36:12.760 with the precise setting of gains

00:36:14.280 and the scrutinizing of the nonlinearities used, and so on.

00:36:17.040 And everything was very finicky and very fragile,

00:36:19.560 and had to be very properly arranged for the neural network to train,

00:36:22.560 especially if your neural network was very deep.

00:36:24.880 But there are a number of modern innovations

00:36:26.320 that have made everything significantly more stable

00:36:28.320 and more well behaved.

00:36:29.640 And it's become less important to initialize

00:36:31.480 these networks exactly right.

00:36:33.960 And some of those modern innovations for example

00:36:35.760 are residual connections which we will cover in the future.

00:36:39.120 The use of a number of normalization layers,

00:36:42.840 like for example, batch normalization,

00:36:44.920 layer normalization, group normalization,

00:36:47.080 we're going to go into a lot of these as well.

00:36:48.920 And number three, much better optimizers.

00:36:50.880 Not just stochastic gradient descent,

00:36:52.440 the simple optimizer we're basically using here,

00:36:55.160 but slightly more complex optimizers

00:36:57.160 like RMS prop and especially Adam.

00:36:59.640 And so all of these modern innovations

00:37:01.080 make it less important for you to precisely calibrate

00:37:04.040 the initialization of the neural net.

00:37:06.080 All that being said in practice, what should we do?

00:37:09.680 In practice when I initialize these neural nets,

00:37:11.720 I basically just normalize my weights

00:37:13.840 by the square root of the fan in.

00:37:15.880 So basically, roughly what we did here is what I do.

00:37:20.800 Now if we want to be exactly accurate here,

00:37:22.960 we can go back to the docs of kaiming_normal_;

00:37:27.720 this is how it's implemented.

00:37:29.680 We want to set the standard deviation

00:37:31.120 to be gain over the square root of fan in.

00:37:34.040 So to set the standard deviation of our weights,

00:37:39.000 we will proceed as follows.

00:37:41.320 Basically, when we have torch.randn,

00:37:43.720 and let's say I just create a thousand numbers,

00:37:46.000 we can look at the standard deviation of this.

00:37:47.400 And of course that's one, that's the amount of spread.

00:37:50.080 Let's make this a bit bigger so it's closer to one.

00:37:52.440 So that's the spread of the Gaussian of zero mean

00:37:55.960 and unit standard deviation.

00:37:58.120 Now basically when you take these

00:37:59.720 and you multiply by say point two,

00:38:02.440 that basically scales down the Gaussian

00:38:04.480 and that makes its standard deviation point two.

00:38:07.080 So basically the number that you multiply by here

00:38:09.040 ends up being the standard deviation of this Gaussian.

00:38:12.200 So here, this is a Gaussian with standard deviation 0.2

00:38:17.000 when we sample our W1.

00:38:19.280 But we want to set the standard deviation

00:38:20.960 to gain over the square root of fan_mode, which is fan-in.

00:38:25.160 So in other words, we want to multiply by the gain,

00:38:29.720 which for tanh is five over three.

00:38:33.080 Five over three is the gain.

00:38:35.800 And then times,

00:38:38.960 (keyboard clicking)

00:38:41.960 I guess I'll divide by

00:38:45.760 the square root of the fan-in.

00:38:51.320 And in this example here, the fan in was 10.

00:38:53.840 And I just noticed actually here,

00:38:55.880 the fan-in for W1 is actually n_embd times block_size,

00:38:59.640 which, as you'll recall, is actually 30.

00:39:01.840 And that's because each character is 10 dimensional,

00:39:03.960 but then we have three of them and we concatenate them.

00:39:06.000 So actually the fan in here was 30.

00:39:08.000 And I should have used 30 here probably.

00:39:10.200 But basically we want the square root of 30.

00:39:13.320 So this is the number

00:39:14.560 that we want our standard deviation to be.

00:39:17.120 And this number turns out to be about 0.3.

00:39:19.640 Whereas here just by fiddling with it

00:39:21.440 and looking at the distribution

00:39:22.520 and making sure it looks okay, we came up with 0.2.

00:39:25.960 And so instead what we want to do here

00:39:27.920 is we want to make the standard deviation be

00:39:33.320 five over three, which is our gain, divided by

00:39:35.640 the square root of this amount, the fan-in of 30.

00:39:41.200 And these brackets here are not that necessary,

00:39:44.240 but I'll just put them here for clarity.

00:39:46.120 This is basically what we want.

00:39:47.520 This is the Kaiming init in our case,

00:39:50.160 for a tanh nonlinearity.

00:39:52.200 And this is how we would initialize the neural net.

00:39:54.720 And so we're multiplying by 0.3

00:39:58.080 instead of multiplying by 0.2.
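Putting these numbers together, as a sketch with the lecture's hyperparameters:

```python
n_embd, block_size = 10, 3
fan_in = n_embd * block_size   # 30: each of 3 characters is 10-dimensional, concatenated
gain = 5 / 3                   # recommended gain for tanh
std = gain / fan_in ** 0.5
print(std)                     # ~ 0.3043, close to the hand-tuned 0.2
```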

00:40:01.000 And so we can initialize this way

00:40:05.160 and then we can train the neural net and see what we get.

00:40:08.080 Okay, so I trained the neural net,

00:40:09.240 and we end up in roughly the same spot.

00:40:12.240 So looking at the validation loss,

00:40:13.560 we now get 2.10.

00:40:15.240 And previously we also had 2.10.

00:40:17.520 There's a little bit of a difference,

00:40:18.720 but that's just the randomness of the process, I suspect.

00:40:21.520 But the big deal of course is we get to the same spot,

00:40:24.400 but we did not have to introduce any magic numbers

00:40:29.000 that we got from just looking at histograms

00:40:31.200 and guess-and-checking.

00:40:32.480 We have something that is semi-principled

00:40:34.040 and will scale us to much bigger networks

00:40:37.160 and something that we can sort of use as a guide.

00:40:40.160 So I mentioned that the precise setting

00:40:41.640 of these initializations is not as important today

00:40:44.480 due to some modern innovations.

00:40:46.040 And I think now is a pretty good time

00:40:47.240 to introduce one of those modern innovations

00:40:49.300 and that is batch normalization.

00:40:51.280 So batch normalization came out in 2015

00:40:54.400 from a team at Google.

00:40:55.880 And it was an extremely impactful paper

00:40:57.840 because it made it possible to train very deep neural nets

00:41:01.520 quite reliably and basically just worked.

00:41:05.160 So here's what batch normalization does

00:41:06.480 and how it's implemented.

00:41:07.480 Basically we have these hidden states, hpreact, right?

00:41:13.720 And we were talking about how we don't want

00:41:15.520 these pre-activation states to be way too small,

00:41:20.440 because then the tanh is not doing anything.

00:41:23.920 But we don't want them to be too large,

00:41:25.200 because then the tanh is saturated.

00:41:27.520 In fact, we want them to be roughly, roughly Gaussian.

00:41:30.480 So zero mean and a unit or a one standard deviation,

00:41:34.080 at least at initialization.

00:41:36.040 So the insight from the batch normalization paper is,

00:41:39.440 okay, you have these hidden states

00:41:41.160 and you'd like them to be roughly Gaussian.

00:41:43.600 Then why not take the hidden states

00:41:45.560 and just normalize them to be Gaussian?

00:41:47.840 And it sounds kind of crazy, but you can just do that

00:41:51.240 because standardizing hidden states

00:41:55.280 so that they're unit Gaussian

00:41:56.800 is a perfectly differentiable operation as we'll see.

00:41:59.680 And so that was kind of like the big insight in this paper.

00:42:02.160 And when I first read it, my mind was blown

00:42:04.520 because you can just normalize these hidden states

00:42:06.280 and if you'd like unit Gaussian states in your network,

00:42:09.840 at least at initialization, you can just normalize them

00:42:12.720 to be unit Gaussian.

00:42:14.240 So let's see how that works.

00:42:16.520 So we're going to scroll to our pre-activations here

00:42:18.800 just before they enter into the 10H.

00:42:21.480 Now the idea again is remember,

00:42:22.720 we're trying to make these roughly Gaussian

00:42:24.960 and that's because if these are way too small numbers,

00:42:27.320 then the tanh here is kind of inactive.

00:42:30.400 But if these are very large numbers,

00:42:32.920 then the tanh is way too saturated and gradients won't flow.

00:42:36.520 So we'd like this to be roughly Gaussian.

00:42:39.160 So the insight in batch normalization, again,

00:42:41.520 is that we can just standardize these activations

00:42:44.280 so that they are exactly Gaussian.

00:42:46.880 So here hpreact has a shape of 32 by 200:

00:42:52.040 32 examples by 200 neurons in the hidden layer.

00:42:56.040 So basically what we can do is we can take hpreact

00:42:58.280 and we can just calculate the mean.

00:43:00.200 And the mean we want to calculate across the zeroth dimension,

00:43:05.400 and we want keepdim to be True

00:43:08.040 so that we can easily broadcast this.

00:43:11.640 So the shape of this is one by 200.

00:43:14.800 In other words, we are doing the mean

00:43:16.720 over all the elements in the batch.

00:43:20.880 And similarly, we can calculate the standard deviation

00:43:24.000 of these activations.

00:43:25.280 And that will also be one by 200.

00:43:29.480 Now in this paper, they have the sort of prescription here.

00:43:34.480 And see here we are calculating the mean,

00:43:36.880 which is just taking the average value

00:43:41.120 of any neuron's activation.

00:43:43.840 And then the standard deviation is basically kind of like

00:43:46.440 the measure of the spread that we've been using,

00:43:50.200 which is the distance of every one of these values

00:43:53.840 away from the mean and that squared and averaged.

00:43:57.760 That's the variance.

00:44:01.600 And then if you want to take the standard deviation,

00:44:03.240 you will square root the variance

00:44:05.360 to get the standard deviation.

00:44:06.760 So these are the two that we're calculating.

00:44:10.080 And now we're going to normalize or standardize

00:44:12.640 these Xs by subtracting the mean

00:44:14.320 and dividing by the standard deviation.

00:44:17.840 So basically we're taking hpreact

00:44:20.680 and we subtract the mean.

00:44:23.920 And then we divide by the standard deviation.

00:44:32.640 This is exactly what these two, std and mean, are calculating.

00:44:38.440 Oops, sorry.

00:44:41.000 This is the mean and this is the variance.

00:44:43.080 You see, sigma is usually the standard deviation,

00:44:45.480 so this sigma squared,

00:44:46.800 the variance, is the square of the standard deviation.

00:44:49.440 So this is how you standardize these values.

00:44:53.160 And what this will do is that every single neuron now

00:44:55.880 and its firing rate will be exactly unit Gaussian

00:44:58.880 on these 32 examples at least of this batch.

00:45:01.760 That's why it's called batch normalization.

00:45:03.360 We are normalizing these batches.
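The standardization on its own can be sketched like this; the 32-by-200 shape and names like bnmean follow the lecture's code, and the scale and offset of the fake pre-activations are arbitrary:

```python
import torch

torch.manual_seed(0)
hpreact = torch.randn(32, 200) * 4.0 + 1.0   # badly scaled pre-activations

bnmean = hpreact.mean(0, keepdim=True)       # (1, 200): per-neuron mean over the batch
bnstd = hpreact.std(0, keepdim=True)         # (1, 200): per-neuron std over the batch
hnorm = (hpreact - bnmean) / bnstd           # each neuron is now unit Gaussian over the batch

print(hnorm.mean().item())                   # ~ 0
```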

00:45:05.640 And then we could in principle train this.

00:45:09.560 Notice that calculating the mean and the standard deviation,

00:45:12.320 these are just mathematical formulas.

00:45:13.720 They're perfectly differentiable.

00:45:15.240 All this is perfectly differentiable

00:45:16.720 and we can just train this.

00:45:18.800 The problem is you actually won't achieve

00:45:20.800 a very good result with this.

00:45:23.080 And the reason for that is

00:45:24.680 we want these to be roughly Gaussian,

00:45:27.480 but only at initialization.

00:45:29.720 But we don't want these to be forced to be Gaussian always.

00:45:34.120 We'd like to allow the neural net to move this around

00:45:37.520 to potentially make it more diffuse,

00:45:39.280 to make it more sharp,

00:45:40.560 to make some tanh neurons maybe more trigger-happy

00:45:44.000 or less trigger happy.

00:45:45.440 So we'd like this distribution to move around

00:45:47.560 and we'd like the back propagation to tell us

00:45:49.640 how the distribution should move around.

00:45:52.360 And so in addition to this idea of standardizing

00:45:55.800 the activations at any point in the network,

00:45:59.320 we have to also introduce this additional component

00:46:01.600 in the paper here described as scale and shift.

00:46:05.880 And so basically what we're doing is

00:46:06.960 we're taking these normalized inputs

00:46:09.240 and we are additionally scaling them by some gain

00:46:12.240 and offsetting them by some bias

00:46:14.320 to get our final output from this layer.

00:46:17.800 And so what that amounts to is the following.

00:46:20.440 We are going to allow a batch normalization gain, bngain,

00:46:23.880 to be initialized at just ones,

00:46:27.520 and the ones will be in the shape of one by n_hidden.

00:46:30.760 And then we will also have a bnbias,

00:46:35.280 which will be torch.zeros,

00:46:37.760 and it will also be of the shape one by n_hidden.

00:46:42.280 And then here the bngain will multiply this

00:46:46.360 and the bnbias will offset it here.

00:46:49.760 So because this is initialized to one and this to zero,

00:46:53.920 at initialization, each neuron's firing values

00:46:58.840 in this batch will be exactly unit Gaussian

00:47:01.960 and will have nice numbers,

00:47:03.560 no matter what the distribution of the H-preact is coming in,

00:47:07.160 coming out it will be unit Gaussian for each neuron.

00:47:09.760 And that's roughly what we want, at least at initialization.

00:47:12.560 And then during optimization,

00:47:15.440 we'll be able to backpropagate to bngain

00:47:17.440 and bnbias and change them.

00:47:19.480 So the network is given the full ability

00:47:21.200 to do with this whatever it wants internally.

00:47:24.840 Here we just have to make sure that we include these

00:47:30.720 in the parameters of the neural net,

00:47:32.160 because they will be trained with back propagation.
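The full layer with the trainable scale and shift, as a sketch; the names bngain and bnbias follow the lecture, while the fake pre-activations are arbitrary stand-ins:

```python
import torch

torch.manual_seed(0)
n_hidden = 200
hpreact = torch.randn(32, n_hidden) * 4.0 + 1.0  # badly scaled pre-activations

# trainable scale and shift, initialized so the layer starts as a pure standardization
bngain = torch.ones((1, n_hidden))
bnbias = torch.zeros((1, n_hidden))

hpreact = bngain * (hpreact - hpreact.mean(0, keepdim=True)) / hpreact.std(0, keepdim=True) + bnbias
# at init (gain = 1, bias = 0) every neuron is unit Gaussian over the batch;
# during training, backprop would update bngain and bnbias, so both
# would be included in the parameter list alongside the weights
```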

00:47:34.760 So let's initialize this.

00:47:37.520 And then we should be able to train.

00:47:39.480 And then we're going to also copy this line,

00:47:49.720 which is the batch normalization layer,

00:47:51.920 here on a single line of code.

00:47:53.680 And we're going to swing down here

00:47:54.880 and we're also going to do the exact same thing

00:47:57.440 at test time here.

00:47:58.400 So similar to train time, we're going to normalize

00:48:05.000 and then scale and that's going to give us

00:48:07.560 our train and validation loss.

00:48:09.320 And we'll see in a second that we're actually going

00:48:12.160 to change this a little bit,

00:48:13.040 but for now I'm going to keep it this way.

00:48:15.600 So I'm just going to wait for this to converge.

00:48:17.640 Okay, so I allowed the neural net to converge here.

00:48:20.040 And when we scroll down, we see that our validation loss here

00:48:22.600 is 2.10 roughly, which I wrote down here.

00:48:26.320 And we see that this is actually kind of comparable

00:48:28.160 to some of the results that we've achieved previously.

00:48:31.200 Now, I'm not actually expecting an improvement in this case.

00:48:34.800 And that's because we are dealing with a very simple neural net

00:48:37.000 that has just a single hidden layer.

00:48:39.360 So in fact, in this very simple case of just one hidden layer,

00:48:43.040 we were able to actually calculate what the scale of W

00:48:45.480 should be to make these pre-activations already

00:48:48.680 have a roughly Gaussian shape.

00:48:50.240 So the batch normalization is not doing much here.

00:48:53.240 But you might imagine that once you have a much deeper

00:48:55.520 neural net that has lots of different types of operations,

00:48:59.000 and there's also, for example, residual connections

00:49:01.000 which we'll cover and so on,

00:49:02.800 it will become basically very, very difficult

00:49:05.480 to tune the scales of your weight matrices

00:49:08.800 such that all the activations throughout the neural net

00:49:11.040 are roughly Gaussian.

00:49:12.840 And so that's going to become very quickly intractable.

00:49:15.920 But compared to that, it's going to be much, much easier

00:49:18.720 to sprinkle batch normalization layers

00:49:20.600 throughout the neural net.

00:49:22.120 So in particular, it's common to look at every single

00:49:25.480 linear layer like this one.

00:49:26.960 This is a linear layer: multiply by a weight matrix

00:49:29.040 and add the bias.

00:49:30.840 Or, for example, convolutions,

00:49:32.280 which we'll cover later and also perform basically

00:49:34.960 a multiplication with weight matrix,

00:49:37.440 but in a more spatially structured format.

00:49:39.880 It's customary to take this linear layer

00:49:42.400 or convolutional layer and append a batch normalization layer

00:49:46.360 right after it to control the scale of these activations

00:49:50.000 at every point in the neural net.

00:49:51.800 So we'd be adding these batch norm layers

00:49:53.440 throughout the neural net, and then this controls

00:49:55.760 the scale of these activations throughout the neural net.

00:49:58.560 It doesn't require us to do perfect mathematics

00:50:01.720 and care about the activation distributions

00:50:04.080 for all these different types of neural

00:50:06.520 lego building blocks that you might want to introduce

00:50:08.120 into your neural net.

00:50:09.320 And it significantly stabilizes the training.

00:50:12.200 And that's why these layers are quite popular.

00:50:14.840 Now, the stability offered by batch normalization

00:50:16.840 actually comes at a terrible cost.

00:50:18.880 And that cost is that if you think about what's happening here,

00:50:22.320 something terribly strange and unnatural is happening.

00:50:26.440 It used to be that we have a single example

00:50:28.880 feeding into a neural net,

00:50:30.320 and then we calculate its activations and its logits.

00:50:34.200 And this is a deterministic sort of process.

00:50:37.360 So you arrive at some logits for this example.

00:50:40.040 And then because of efficiency of training,

00:50:42.320 we suddenly started to use batches of examples.

00:50:44.800 But those batches of examples were processed independently,

00:50:47.520 and it was just an efficiency thing.

00:50:49.840 But now suddenly in batch normalization,

00:50:51.520 because of the normalization through the batch,

00:50:53.600 we are coupling these examples mathematically

00:50:56.520 and in the forward pass and the backward pass

00:50:58.080 of the neural net.

00:50:59.440 So now the hidden state activations, hpreact,

00:51:02.720 and your logits for any one input example

00:51:05.520 are not just a function of that example and its input,

00:51:08.320 but they're also a function of all the other examples

00:51:10.840 that happen to come for a ride in that batch.

00:51:14.400 And these examples are sampled randomly.

00:51:16.400 And so what's happening is, for example,

00:51:17.760 when you look at hpreact that's going to feed into h,

00:51:20.600 the hidden state activations, for example,

00:51:22.920 for any one of these input examples,

00:51:25.320 is going to actually change slightly,

00:51:27.680 depending on what other examples there are

00:51:29.120 in the batch.

00:51:30.320 And depending on what other examples happen to come for a ride,

00:51:34.080 h is going to change suddenly and is going to like jitter,

00:51:37.400 if you imagine sampling different examples,

00:51:39.520 because the statistics of the mean and standard deviation

00:51:42.120 are going to be impacted.

00:51:44.000 And so you'll get a jitter for h,

00:51:45.680 and you'll get a jitter for the logits.

00:51:48.560 And you'd think that this would be a bug or something undesirable,

00:51:52.440 but in a very strange way,

00:51:54.360 this actually turns out to be good in neural network training,

00:51:57.960 as a side effect.

00:51:59.640 And the reason for that is that you can think of this

00:52:01.720 as kind of like a regularizer.

00:52:03.920 Because what's happening is you have your input and you get your H.

00:52:06.400 And then depending on the other examples,

00:52:08.320 this is jittering a bit.

00:52:09.880 And so what that does is that it's effectively padding out

00:52:12.720 any one of these input examples.

00:52:14.280 And it's introducing a little bit of entropy.

00:52:16.400 And because of the padding out,

00:52:18.680 it's actually kind of like a form of a data augmentation,

00:52:21.320 which we'll cover in the future.

00:52:22.960 And it's kind of like augmenting the input a little bit

00:52:25.760 and it's jittering it.

00:52:26.840 And that makes it harder for the neural nets to overfit

00:52:29.360 these concrete specific examples.

00:52:32.000 So by introducing all this noise,

00:52:33.800 it actually like pads out the examples

00:52:35.800 and it regularizes the neural net.

00:52:37.760 And that's one of the reasons why,

00:52:40.080 somewhat deceivingly, as a second-order effect,

00:52:42.160 this is actually a regularizer.

00:52:43.720 And that has made it harder for us

00:52:45.720 to remove the use of batch normalization.

00:52:48.760 Because basically no one likes this property

00:52:50.360 that the examples in the batch are coupled mathematically

00:52:54.200 and in the forward pass.

00:52:55.720 And it leads to all kinds of strange results.

00:52:58.760 We'll go into some of that in a second as well.

00:53:01.800 And it leads to a lot of bugs and so on.

00:53:04.880 And so no one likes this property.

00:53:07.040 And so people have tried to deprecate the use

00:53:10.240 of batch normalization and move to other normalization

00:53:12.400 techniques that do not couple the examples of a batch.

00:53:14.840 Examples are layer normalization,

00:53:16.880 instance normalization, group normalization, and so on.

00:53:20.040 And we'll cover some of these later.

00:53:22.320 But basically long story short,

00:53:25.480 batch normalization was the first kind of normalization layer

00:53:28.160 to be introduced.

00:53:29.160 It worked extremely well.

00:53:30.920 It happened to have this regularizing effect.

00:53:33.480 It stabilized training.

00:53:35.920 And people have been trying to remove it

00:53:38.200 and move to some of the other normalization techniques.

00:53:40.920 But it's been hard because it just works quite well.

00:53:44.320 And some of the reason that it works quite well

00:53:46.240 is again because of this regularizing effect

00:53:48.160 and because it is quite effective

00:53:50.600 at controlling the activations and their distributions.

00:53:54.520 So that's kind of like the brief story of batch normalization.

00:53:57.480 And I'd like to show you one of the other weird sort

00:54:00.920 of outcomes of this coupling.

00:54:03.720 So here's one of the strange outcomes

00:54:05.040 that I only glossed over previously.

00:54:07.680 When I was evaluating the loss on the validation set.

00:54:10.840 Basically, once we've trained a neural net,

00:54:13.280 we'd like to deploy it in some kind of a setting.

00:54:15.920 And we'd like to be able to feed in a single

00:54:17.560 individual example and get a prediction out

00:54:19.800 from our neural net.

00:54:21.440 But how do we do that when our neural net now

00:54:23.440 in a forward pass estimates the statistics

00:54:25.880 of the mean and standard deviation of a batch?

00:54:27.960 The neural net expects batches as an input now.

00:54:30.520 So how do we feed in a single example

00:54:32.120 and get sensible results out?

00:54:34.400 And so the proposal in the batch normalization paper

00:54:37.400 is the following.

00:54:38.960 What we would like to do here is we would like to basically

00:54:41.800 have a step after training that calculates and sets

00:54:47.280 the batch norm mean and standard deviation

00:54:49.560 a single time over the training set.

00:54:52.280 And so I wrote this code here in interest of time.

00:54:55.360 And we're going to call this calibrating

00:54:57.240 the batch norm statistics.

00:54:59.160 And basically what we do is use torch.no_grad,

00:55:02.560 telling PyTorch that none of this

00:55:04.640 will have .backward() called on it,

00:55:06.600 and it's going to be a bit more efficient.

00:55:08.920 We're going to take the training set,

00:55:10.720 get the pre-activations for every single training example.

00:55:13.640 And then one single time estimate the mean and standard

00:55:15.760 deviation over the entire training set.

00:55:18.240 And then we're going to get bnmean

00:55:19.560 and bnstd.

00:55:21.000 And now these are fixed numbers

00:55:22.920 estimated over the entire training set.

00:55:25.280 And here instead of estimating it dynamically,

00:55:28.800 we are going to instead here use bnmean.

00:55:33.440 And here we're just going to use bnstd.

00:55:37.160 And so at test time, we are going to fix these,

00:55:40.720 clamp them and use them during inference.

00:55:43.200 And now you see that we get basically identical result.

00:55:48.200 But the benefit that we've gained

00:55:50.760 is that we can now also forward a single example

00:55:53.280 because the mean and standard deviation

00:55:54.600 are now fixed sort of tensors.
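As a concrete sketch of this calibration stage: the tensors below (Xtr_emb, W1, b1, and all shapes) are hypothetical stand-ins for the real embedded training set and first layer, just to show the mechanics of estimating the statistics a single time.

```python
import torch

torch.manual_seed(42)

# Hypothetical stand-ins for the trained first layer (shapes are illustrative,
# not the actual makemore network).
Xtr_emb = torch.randn(1000, 30)            # "embedded" training inputs
W1 = torch.randn(30, 200) / 30**0.5
b1 = torch.zeros(200)

# Calibration stage: a single pass over the training set, no gradients needed.
with torch.no_grad():
    hpreact = Xtr_emb @ W1 + b1            # pre-activations for every training example
    bnmean = hpreact.mean(0, keepdim=True) # fixed mean, shape (1, 200)
    bnstd = hpreact.std(0, keepdim=True)   # fixed std, shape (1, 200)

# At inference these fixed tensors replace the per-batch statistics,
# so forwarding a single example now works:
x = Xtr_emb[:1]
h = torch.tanh(((x @ W1 + b1) - bnmean) / bnstd)
```

Because bnmean and bnstd are ordinary fixed tensors after this step, the forward pass no longer depends on a batch being present.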

00:55:57.440 That said, nobody actually wants to estimate

00:55:59.440 this mean and standard deviation as a second stage

00:56:02.480 after neural network training

00:56:04.280 because everyone is lazy.

00:56:05.680 And so this batch normalization paper

00:56:07.800 actually introduced one more idea,

00:56:09.600 which is that we can estimate the mean and standard deviation

00:56:12.680 in a running manner

00:56:15.120 during training of the neural network.

00:56:17.040 And then we can simply just have a single stage of training

00:56:20.120 and on the side of that training,

00:56:21.720 we are estimating the running mean and standard deviation.

00:56:24.560 So let's see what that would look like.

00:56:26.720 Let me basically take the mean here

00:56:28.760 that we are estimating on the batch.

00:56:30.120 And let me call this bnmeani, for the i-th iteration.

00:56:33.320 And then here this is bnstd,

00:56:38.040 or bnstdi.

00:56:43.080 Okay.

00:56:47.120 And the mean comes here and the STD comes here.

00:56:52.120 So far I've done nothing.

00:56:54.200 I've just moved around and I created these extra variables

00:56:56.880 for the mean and standard deviation.

00:56:58.480 And I've put them here.

00:56:59.760 So far nothing has changed.

00:57:01.720 But what we're going to do now is we're going to keep

00:57:03.520 a running mean of both of these values during training.

00:57:06.760 So let me swing up here and let me create a

00:57:10.000 bnmean_running.

00:57:11.560 And I'm going to initialize it at zeros.

00:57:16.040 And then bnstd_running,

00:57:18.720 which is initialized at ones.

00:57:21.080 Because in the beginning,

00:57:25.400 because of the way we initialized W1 and B1,

00:57:29.200 hpreact will be roughly unit Gaussian.

00:57:31.160 So the mean will be roughly zero

00:57:32.560 and the standard deviation roughly one.

00:57:34.520 So I'm going to initialize these that way.

00:57:37.160 But then here I'm going to update these.

00:57:39.440 And in PyTorch,

00:57:40.400 these mean and standard deviation that are running,

00:57:45.400 they're not actually part of the gradient based optimization.

00:57:47.760 We're never going to derive gradients with respect to them.

00:57:50.240 They're updated on the side of training.

00:57:53.640 And so what we're going to do here is we're going to say

00:57:55.960 with torch.no_grad,

00:57:58.080 telling PyTorch that the update here is not supposed

00:58:01.840 to be building out a graph

00:58:03.040 because there will be no .backward().

00:58:05.400 But this running mean is basically going to be 0.999

00:58:09.600 times the current value,

00:58:13.600 plus 0.001 times this new mean.

00:58:18.600 And in the same way, bnstd_running will be

00:58:24.480 mostly what it used to be.

00:58:26.000 But it will receive a small update in the direction

00:58:31.040 of what the current standard deviation is.

00:58:33.320 And as you're seeing here,

00:58:36.120 this update is outside and on the side

00:58:38.640 of the gradient based optimization.

00:58:41.280 And it's simply being updated not using gradient descent.

00:58:43.840 It's just being updated using a janky like smooth,

00:58:47.480 sort of running mean manner.
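A minimal sketch of this running update, assuming a hypothetical stream of pre-activation batches (the random hpreact below stands in for the real forward pass; the names bnmean_running, bnstd_running, bnmeani, bnstdi follow the ones used above):

```python
import torch

torch.manual_seed(0)

n_hidden = 200
bnmean_running = torch.zeros(1, n_hidden)  # pre-activations start roughly zero-mean
bnstd_running = torch.ones(1, n_hidden)    # ...and roughly unit-std, by initialization

for _ in range(1000):  # stand-in for the training loop
    hpreact = torch.randn(32, n_hidden)    # hypothetical batch of pre-activations
    bnmeani = hpreact.mean(0, keepdim=True)
    bnstdi = hpreact.std(0, keepdim=True)
    # updated on the side of the optimization, not by gradient descent
    with torch.no_grad():
        bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
        bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi
```

Because each batch here really is unit Gaussian, the running estimates stay close to zero mean and unit standard deviation, which is what we hope to see for the calibrated statistics too.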

00:58:52.280 And so while the network is training

00:58:55.400 and these pre-activations are sort of changing

00:58:57.720 and shifting around during back propagation,

00:59:00.640 we are keeping track of the typical mean

00:59:02.440 and standard deviation and estimating them once.

00:59:05.400 And when I run this,

00:59:09.360 now I'm keeping track of this in a running manner.

00:59:12.000 And what we're hoping for, of course,

00:59:13.320 is that bnmean_running and bnstd_running

00:59:17.560 are going to be very similar to the ones

00:59:20.040 that we've calculated here before.

00:59:22.280 And that way we don't need a second stage

00:59:24.640 because we've sort of combined the two stages

00:59:26.840 and we've put them on the side of each other

00:59:28.640 if you wanna look at it that way.

00:59:30.600 And this is how this is also implemented

00:59:32.320 in the batch normalization layer in PyTorch.

00:59:35.000 So during training, the exact same thing will happen.

00:59:38.960 And then later when you're using inference,

00:59:41.120 it will use the estimated running mean

00:59:43.520 of both the mean and standard deviation

00:59:45.800 of those hidden states.

00:59:46.880 So let's wait for the optimization to converge.

00:59:50.160 And hopefully the running mean and standard deviation

00:59:52.320 are roughly equal to these two.

00:59:53.880 And then we can simply use it here

00:59:55.840 and we don't need this stage

00:59:57.400 of explicit calibration at the end.

00:59:59.200 Okay, so the optimization finished.

01:00:01.320 I'll rerun the explicit estimation

01:00:03.840 and then the bnmean from the explicit estimation is here.

01:00:07.760 And bnmean_running from the running estimation

01:00:11.040 during the optimization,

01:00:13.800 you can see it's very, very similar.

01:00:16.240 It's not identical, but it's pretty close.

01:00:18.640 And in the same way bnstd is this

01:00:22.600 and bnstd_running is this.

01:00:25.480 As you can see that once again,

01:00:27.800 they are fairly similar values, not identical,

01:00:30.280 but pretty close.

01:00:31.880 And so then here instead of bnmean,

01:00:33.720 we can use bnmean_running.

01:00:35.960 Instead of bnstd, we can use bnstd_running.

01:00:38.560 And hopefully the validation loss

01:00:42.040 will not be impacted too much.

01:00:43.520 Okay, so it's basically identical.

01:00:46.680 And this way we've eliminated the need

01:00:49.280 for this explicit stage of calibration

01:00:51.600 because we are doing it in line over here.

01:00:54.200 Okay, so we're almost done with batch normalization.

01:00:56.080 There are only two more notes that I'd like to make.

01:00:58.440 Number one, I've skipped a discussion

01:00:59.960 over what is this + epsilon doing here.

01:01:02.200 This epsilon is usually like some small fixed number.

01:01:04.920 For example, 1e-5 by default.

01:01:07.240 And what it's doing is that it's basically

01:01:08.720 preventing a division by zero.

01:01:10.680 In the case that the variance over your batch

01:01:14.400 is exactly zero.

01:01:15.880 In that case, here we normally have a division by zero,

01:01:19.040 but because of the + epsilon, this is going to become

01:01:21.520 a small number in the denominator instead,

01:01:23.560 and things will be more well behaved.

01:01:25.600 So feel free to also add a + epsilon here of a very small number.

01:01:29.120 It doesn't actually substantially change the result.

01:01:31.200 I'm going to skip it in our case

01:01:32.400 just because this is unlikely to happen

01:01:34.160 in our very simple example here.
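A small sketch of the failure mode and the fix. The constant column in hpreact below is contrived on purpose to force a zero variance; everything else is illustrative:

```python
import torch

torch.manual_seed(0)

# A batch where the first feature is constant, so its variance is exactly zero.
hpreact = torch.zeros(32, 3)
hpreact[:, 1] = torch.randn(32)
hpreact[:, 2] = torch.randn(32)

eps = 1e-5
mean = hpreact.mean(0, keepdim=True)
var = hpreact.var(0, keepdim=True)

bad = (hpreact - mean) / var.sqrt()            # 0/0 for the constant column -> nan
good = (hpreact - mean) / (var + eps).sqrt()   # epsilon keeps the denominator positive
```

The epsilon barely changes the result for well-behaved columns, but it turns the 0/0 in the degenerate column into a harmless 0 / sqrt(eps).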

01:01:36.320 And the second thing I want you to notice

01:01:38.160 is that we're being wasteful here,

01:01:39.560 and it's very subtle,

01:01:41.280 but right here where we are adding the bias into hpreact,

01:01:45.480 these biases now are actually useless

01:01:48.080 because we're adding them to the hpreact,

01:01:50.520 but then we are calculating the mean

01:01:52.920 for every one of these neurons and subtracting it.

01:01:56.000 So whatever bias you add here

01:01:58.040 is going to get subtracted right here.

01:02:00.800 And so these biases are not doing anything.

01:02:02.880 In fact, they're being subtracted out

01:02:04.600 and they don't impact the rest of the calculation.

01:02:07.280 So if you look at b1.grad, it's actually going to be zero

01:02:10.320 because it's being subtracted out

01:02:11.680 and doesn't actually have any effect.
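This claim is easy to check numerically. A small sketch, with a random h standing in for the pre-bias activations; only the mean subtraction of batch norm is reproduced, which is the part that cancels the bias:

```python
import torch

torch.manual_seed(1)
h = torch.randn(32, 200)                   # hypothetical pre-bias activations
b1 = torch.randn(200, requires_grad=True)  # the "useless" bias

hpreact = h + b1
# batch norm's mean subtraction removes exactly what b1 added to each column
out = hpreact - hpreact.mean(0, keepdim=True)
out.sum().backward()

# b1.grad is exactly zero: the bias cancels out of the rest of the calculation
```

Since each column's mean goes up by exactly b1 when b1 is added, the subtraction removes it term for term, and the gradient flowing into b1 is identically zero.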

01:02:13.600 And so whenever you're using batch normalization layers,

01:02:16.120 then if you have any weight layers

01:02:17.680 before like a linear or a conv or something like that,

01:02:20.600 you're better off coming here and just like not using bias.

01:02:24.280 So you don't want to use bias,

01:02:26.280 and then here you don't want to add it

01:02:29.000 because that's spurious.

01:02:30.640 Instead, we have this batch normalization bias here

01:02:33.720 and that batch normalization bias is now in charge of

01:02:36.640 the biasing of this distribution

01:02:38.920 instead of this b1 that we had here originally.

01:02:42.200 And so basically, the batch normalization layer has its own bias

01:02:45.920 and there's no need to have a bias in the layer before it

01:02:49.360 because that bias is going to be subtracted out anyway.

01:02:52.000 So that's the other small detail to be careful with.

01:02:53.880 Sometimes it's not going to do anything catastrophic.

01:02:56.720 This b1 will just be useless.

01:02:58.560 It will never get any gradient.

01:03:00.440 It will not learn.

01:03:01.280 It will stay constant and it's just wasteful,

01:03:03.120 but it doesn't actually really impact anything otherwise.

01:03:07.200 Okay, so I rearranged the code a little bit with comments

01:03:09.920 and I just wanted to give a very quick summary

01:03:11.680 of the batch normalization layer.

01:03:13.760 We are using batch normalization to control the statistics

01:03:16.960 of activations in the neural net.

01:03:19.720 It is common to sprinkle batch normalization layers

01:03:22.160 across the neural net, and usually we will place them

01:03:24.720 after layers that have multiplications,

01:03:27.760 like for example, a linear layer or a convolutional layer

01:03:30.720 which we may cover in the future.

01:03:32.760 Now, the batch normalization layer internally has parameters

01:03:37.680 for the gain and the bias and these are trained

01:03:40.080 using back propagation.

01:03:41.800 It also has two buffers.

01:03:44.480 The buffers are the mean and the standard deviation,

01:03:47.120 the running mean and the running standard deviation.

01:03:51.000 And these are not trained using back propagation.

01:03:53.000 These are trained using this janky update

01:03:55.920 of kind of like a running mean update.

01:03:57.840 So these are sort of the parameters

01:04:02.880 and the buffers of the batch norm layer.

01:04:05.240 And then really what it's doing is it's calculating

01:04:07.400 the mean and the standard deviation of the activations

01:04:10.520 that are feeding into the batch norm layer over that batch.

01:04:14.000 Then it's centering that batch to be unit Gaussian

01:04:18.520 and then it's offsetting and scaling it

01:04:20.520 by the learned bias and gain.

01:04:24.120 And then on top of that, it's keeping track

01:04:26.040 of the mean and standard deviation of the inputs.

01:04:28.920 And it's maintaining this running mean and standard deviation.

01:04:32.760 And this will later be used at inference

01:04:35.000 so that we don't have to re-estimate

01:04:36.760 the mean and standard deviation all the time.

01:04:38.960 And in addition, that allows us

01:04:40.520 to basically forward individual examples at test time.

01:04:44.280 So that's the batch normalization layer.

01:04:45.800 It's a fairly complicated layer,

01:04:48.440 but this is what it's doing internally.
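The summary above can be condensed into a small sketch of a batch-norm-like class. This is an illustration of the mechanics under the assumptions stated in the comments, not the real nn.BatchNorm1d, and the shapes and names are made up for the example:

```python
import torch

class BatchNorm1dSketch:
    # Minimal sketch of what a batch norm layer does internally;
    # not the real nn.BatchNorm1d (which tracks more state, handles 3D inputs, etc.)
    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps, self.momentum = eps, momentum
        self.training = True
        # parameters, trained by backpropagation
        self.gamma = torch.ones(dim)   # gain
        self.beta = torch.zeros(dim)   # bias
        # buffers, updated with the running-mean update, not by backprop
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        if self.training:
            mean = x.mean(0)
            var = x.var(0)
            with torch.no_grad():
                m = self.momentum
                self.running_mean = (1 - m) * self.running_mean + m * mean
                self.running_var = (1 - m) * self.running_var + m * var
        else:
            # inference: fixed running statistics, so single examples work
            mean, var = self.running_mean, self.running_var
        xhat = (x - mean) / torch.sqrt(var + self.eps)  # center/scale to unit Gaussian
        return self.gamma * xhat + self.beta            # learned offset and gain

torch.manual_seed(0)
bn = BatchNorm1dSketch(200)
out = bn(torch.randn(32, 200))    # training mode: batch statistics
bn.training = False
single = bn(torch.randn(1, 200))  # inference mode: a single example is fine
```

Note how the two modes differ only in where the mean and variance come from; the normalization and the learned affine transform are identical.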

01:04:50.440 Now, I wanted to show you a little bit of a real example.

01:04:53.200 So you can search for ResNet,

01:04:55.240 which is a residual neural network.

01:04:57.720 And these are commonly used neural nets

01:04:59.840 used for image classification.

01:05:02.120 And of course, we haven't covered ResNets in detail.

01:05:04.680 So I'm not going to explain all the pieces of it.

01:05:07.800 But for now, just note that the image feeds

01:05:10.360 into a ResNet on the top here.

01:05:12.200 And there's many, many layers with repeating structure

01:05:15.120 all the way to predictions of what's inside that image.

01:05:18.320 This repeating structure is made up of these blocks.

01:05:20.880 And these blocks are just sequentially stacked up

01:05:23.120 in this deep neural network.

01:05:25.640 Now, the code for this, the block basically that's used

01:05:29.680 and repeated sequentially in series

01:05:32.400 is called this bottleneck block.

01:05:36.160 And there's a lot here.

01:05:37.440 This is all PyTorch.

01:05:38.760 And of course, we haven't covered all of it,

01:05:40.240 but I want to point out some small pieces of it.

01:05:43.160 Here in the init is where we initialize the neural net.

01:05:45.640 So this code of block here is basically

01:05:47.720 the kind of stuff we're doing here.

01:05:48.960 We're initializing all the layers.

01:05:51.000 And in the forward, we are specifying

01:05:53.000 how the neural net acts once you actually have the input.

01:05:55.720 So this code here is along the lines of what we're doing here.

01:05:59.120 And now these blocks are replicated and stacked up

01:06:04.720 serially, and that's what a residual network would be.

01:06:08.920 And so notice what's happening here.

01:06:10.920 conv1, these are convolutional layers.

01:06:14.880 And these convolutional layers, basically,

01:06:16.720 they're the same thing as a linear layer,

01:06:19.520 except convolutional layers don't apply--

01:06:22.880 convolutional layers are used for images.

01:06:24.800 And so they have spatial structure.

01:06:26.560 And basically, this linear multiplication and bias

01:06:28.920 offset are done on patches instead of the full input.

01:06:34.760 So because these images have spatial structure,

01:06:37.920 convolutions just basically do wx plus b,

01:06:40.840 but they do it on overlapping patches of the input.

01:06:43.960 But otherwise, it's wx plus b.

01:06:46.720 Then we have the normal layer, which by default here

01:06:49.040 is initialized to be a batch norm in 2D,

01:06:51.320 so a two-dimensional batch normalization layer.

01:06:54.240 And then we have a nonlinearity like relu.

01:06:56.680 So instead of here they use relu,

01:06:59.600 we are using tanh in this case.

01:07:02.560 But both are just nonlinearities,

01:07:04.520 and you can just use them relatively interchangeably.

01:07:07.360 For very deep networks, relus typically, empirically,

01:07:09.960 work a bit better.

01:07:11.720 So see the motif that's being repeated here.

01:07:14.120 We have convolution, batch normalization, relu--

01:07:16.560 convolution, batch normalization, et cetera.

01:07:19.360 And then here, this is residual connection

01:07:21.080 that we haven't covered yet.

01:07:23.000 But basically, that's the exact same pattern we have here.

01:07:25.360 We have a weight layer, like a convolution

01:07:28.640 or like a linear layer, batch normalization,

01:07:32.440 and then tanh, which is a nonlinearity.

01:07:35.560 But basically, a weight layer, a normalization layer,

01:07:38.320 and nonlinearity.

01:07:39.560 And that's the motif that you would be stacking up

01:07:41.520 when you create these deep neural networks.

01:07:43.520 Exactly as it's done here.
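A sketch of this motif in PyTorch, with illustrative layer sizes (30-dimensional input, 200 hidden units, 27 outputs are assumptions for the example, not the ResNet's actual dimensions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# The repeating motif: weight layer -> normalization layer -> nonlinearity.
# bias=False in the Linear layers because BatchNorm1d subtracts the mean
# and contributes its own bias anyway.
model = nn.Sequential(
    nn.Linear(30, 200, bias=False), nn.BatchNorm1d(200), nn.Tanh(),
    nn.Linear(200, 200, bias=False), nn.BatchNorm1d(200), nn.Tanh(),
    nn.Linear(200, 27),  # output layer
)

logits = model(torch.randn(32, 30))  # a batch of 32 examples
```

In a convolutional network the Linear layers would be Conv2d layers and BatchNorm1d would be BatchNorm2d, but the weight/norm/nonlinearity pattern is the same.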

01:07:45.560 And one more thing I'd like you to notice

01:07:46.960 is that here when they are initializing the conv layers,

01:07:50.240 like conv1x1, the definition for that is right here.

01:07:54.480 And so it's initializing an nn.Conv2d,

01:07:56.760 which is a convolutional layer in PyTorch.

01:07:58.920 And there's a number of keyword arguments here

01:08:00.400 that I'm not going to explain yet.

01:08:02.240 But you see how there's bias equals false?

01:08:04.760 The bias equals false is exactly for the same reason

01:08:07.120 as bias is not used in our case.

01:08:10.120 You see how I erase the use of bias.

01:08:12.160 And the use of bias is spurious,

01:08:13.680 because after this weight layer,

01:08:15.160 there's a batch normalization.

01:08:16.760 And the batch normalization subtracts that bias

01:08:19.200 and then has its own bias.

01:08:20.640 So there's no need to introduce these spurious parameters.

01:08:23.160 It wouldn't hurt performance, it's just useless.

01:08:25.800 And so because they have this motif of conv,

01:08:28.680 batch, and relu, they don't need a bias here,

01:08:31.040 because there's a bias inside here.

01:08:33.400 So by the way, this example here is very easy to find,

01:08:36.920 just do a resnet PyTorch.

01:08:38.320 And it's this example here.

01:08:41.760 So this is kind of like the stock implementation

01:08:43.600 of a residual neural network in PyTorch.

01:08:46.320 And you can find that here.

01:08:48.200 But of course, I haven't covered many of these parts yet.

01:08:50.720 And I would also like to briefly descend

01:08:52.400 into the definitions of these PyTorch layers

01:08:55.160 and the parameters that they take.

01:08:57.000 Now, instead of a convolutional layer,

01:08:58.360 we're going to look at a linear layer,

01:09:00.360 because that's the one that we're using here.

01:09:02.800 This is a linear layer,

01:09:03.880 and I haven't covered convolutions yet.

01:09:06.120 But as I mentioned, convolutions

01:09:07.360 are basically linear layers except on patches.

01:09:11.240 So a linear layer performs Wx + b,

01:09:14.480 except here they're calling the W an A transpose.

01:09:17.000 So it computes Wx + b, very much like we did here.

01:09:21.520 To initialize this layer, you need to know the fan in,

01:09:24.160 the fan out, so that they can initialize this W.

01:09:29.160 This is the fan in and the fan out.

01:09:32.000 So they know how big the weight matrix should be.

01:09:35.040 You need to also pass in whether or not you want a bias.

01:09:39.000 And if you set it to false,

01:09:40.280 no bias will be inside this layer.

01:09:43.480 And you may want to do that exactly like in our case,

01:09:47.120 if your layer is followed by a normalization layer,

01:09:49.560 such as BatchNorm.

01:09:50.640 So this allows you to basically disable a bias.

01:09:54.600 Now, in terms of the initialization,

01:09:55.720 if we swing down here,

01:09:57.160 this is reporting the variables used

01:09:58.760 inside this linear layer.

01:10:01.000 And our linear layer here has two parameters,

01:10:04.240 the weight and the bias.

01:10:05.880 In the same way, they have a weight and a bias.

01:10:08.600 And they're talking about how they initialize it by default.

01:10:11.720 So by default, pytorch will initialize your weights

01:10:14.280 by taking the fan in and then doing one over fan in square root.

01:10:19.280 And then instead of a normal distribution,

01:10:23.600 they are using a uniform distribution.

01:10:25.720 So it's very much the same thing,

01:10:27.920 but they are using a one instead of five over three.

01:10:30.480 So there's no gain being calculated here.

01:10:32.560 The gain is just one,

01:10:33.680 but otherwise it's exactly one over the square root of fan in,

01:10:37.800 exactly as we have here.

01:10:39.040 So one over the square root of K is the scale of the weights,

01:10:45.080 but when they are drawing the numbers,

01:10:46.600 they're not using a Gaussian by default,

01:10:48.760 they're using a uniform distribution by default.

01:10:51.400 And so they draw uniformly from negative square root of K

01:10:54.320 to square root of K.

01:10:56.080 But it's the exact same thing and the same motivation

01:10:58.640 with respect to what we've seen in this lecture.

01:11:03.040 And the reason they're doing this is

01:11:04.640 if you have a roughly Gaussian input,

01:11:06.640 this will ensure that out of this layer,

01:11:09.400 you will have a roughly Gaussian output.

01:11:11.760 And you basically achieve that by scaling the weights

01:11:15.280 by one over the square root of fan in.

01:11:17.760 So that's what this is doing.
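This default is easy to verify directly. A quick check, assuming fan_in = 30 just as an example (PyTorch's default Linear init draws from a uniform distribution bounded by 1/sqrt(fan_in), as described above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in = 30
layer = nn.Linear(fan_in, 200)

# By default PyTorch draws the weights from U(-sqrt(k), sqrt(k)) with k = 1/fan_in:
# a uniform rather than Gaussian distribution, with a gain of 1 instead of 5/3.
bound = 1 / fan_in**0.5
```

Every entry of layer.weight lands strictly inside [-bound, bound], matching the formula in the documentation.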

01:11:19.920 And then the second thing is the batch normalization layer.

01:11:23.200 So let's look at what that looks like in PyTorch.

01:11:26.160 So here we have a one-dimensional batch normalization layer,

01:11:28.560 exactly as we are using here.

01:11:30.840 And there are a number of keyword arguments

01:11:32.120 going into it as well.

01:11:33.520 So we need to know the number of features,

01:11:35.720 which for us is 200.

01:11:37.400 And that is needed so that we can initialize

01:11:39.400 these parameters here.

01:11:40.880 The gain, the bias, and the buffers

01:11:43.520 for the running mean and standard deviation.

01:11:45.920 Then they need to know the value of epsilon here.

01:11:49.960 And by default this is 1e-5.

01:11:51.760 You don't typically change this too much.

01:11:54.000 Then they need to know the momentum.

01:11:55.960 And the momentum here, as they explain,

01:11:58.200 is basically used for these running mean

01:12:01.080 and running standard deviation.

01:12:02.800 So by default, the momentum here is 0.1.

01:12:05.120 The momentum we are using here, in this example, is 0.001.

01:12:08.560 And basically you may want to change this sometimes.

01:12:13.720 And roughly speaking, if you have a very large batch size,

01:12:17.360 then typically what you'll see is that

01:12:19.160 when you estimate the mean and standard deviation,

01:12:21.640 for every single batch, if it's large enough,

01:12:23.760 you're going to get roughly the same result.

01:12:26.120 And so therefore you can use slightly higher momentum,

01:12:29.520 like 0.1.

01:12:31.120 But for a batch size as small as 32,

01:12:34.720 the mean and standard deviation here

01:12:36.080 might take on slightly different numbers,

01:12:37.840 because there's only 32 examples we are using

01:12:39.840 to estimate the mean and standard deviation.

01:12:41.880 So the value is changing around a lot.

01:12:44.200 And if your momentum is 0.1,

01:12:46.160 that might not be good enough for this value to settle

01:12:49.680 and converge to the actual mean and standard deviation

01:12:53.240 over the entire training set.

01:12:55.160 And so basically if your batch size is very small,

01:12:57.600 momentum of 0.1 is potentially dangerous,

01:12:59.800 and it might make it so that the running mean and standard deviation

01:13:02.920 is thrashing too much during training,

01:13:05.200 and it's not actually converging properly.

01:13:07.440 affine=True determines whether this

01:13:11.940 batch normalization layer has these learnable

01:13:14.040 affine parameters, the gain and the bias.

01:13:18.480 And this is almost always kept to true.

01:13:20.720 I'm not actually sure why you would want to change this

01:13:23.200 to false.

01:13:24.120 Then track_running_stats determines whether or not

01:13:29.360 the batch normalization layer of PyTorch will be doing this.

01:13:32.800 And one reason you may want to skip the running stats

01:13:37.600 is because you may want to, for example,

01:13:39.200 estimate them at the end as a stage two like this.

01:13:43.320 And in that case, you don't want the batch

01:13:44.520 normalization layer to be doing all this extra compute

01:13:46.560 that you're not going to use.

01:13:48.880 And finally, we need to know which device we're going to run

01:13:52.080 this batch normalization on, a CPU or a GPU,

01:13:55.000 and what the data type should be.

01:13:56.880 Half precision, single precision,

01:13:58.200 double precision, and so on.

01:13:59.600 So that's the batch normalization layer.

01:14:02.560 Otherwise, they link to the paper.

01:14:03.920 It's the same formula we've implemented,

01:14:06.160 and everything is the same exactly as we've done here.

01:14:09.600 Okay, so that's everything that I wanted to cover

01:14:12.440 for this lecture.

01:14:13.840 Really what I wanted to talk about is the importance

01:14:15.840 of understanding the activations and the gradients

01:14:18.160 and their statistics in neural networks.

01:14:20.480 And this becomes increasingly important,

01:14:22.000 especially as you make your neural networks

01:14:23.560 bigger, larger, and deeper.

01:14:24.960 We looked at the distributions basically

01:14:27.400 at the output layer, and we saw that if you have

01:14:30.160 too confident mispredictions, because the activations

01:14:32.800 are too messed up at the last layer,

01:14:35.000 you can end up with these hockey stick losses.

01:14:37.520 And if you fix this, you get a better loss

01:14:39.600 at the end of training, because your training

01:14:41.520 is not doing wasteful work.

01:14:44.000 Then we also saw that we need to control the activations.

01:14:46.000 We don't want them to squash to zero

01:14:49.160 or explode to infinity.

01:14:50.720 Because then you can run into a lot of trouble

01:14:52.920 with all of these nonlinearities in these neural nets.

01:14:55.840 And basically, you want everything to be fairly homogeneous

01:14:58.160 throughout the neural net.

01:14:59.000 You want roughly Gaussian activations

01:15:00.520 throughout the neural net.

01:15:02.560 Let me talk about, okay, if we want roughly Gaussian

01:15:05.720 activations, how do we scale these weight matrices

01:15:08.720 and biases during initialization of the neural net

01:15:11.320 so that we don't get, so everything is as controlled

01:15:14.640 as possible?

01:15:15.480 So that gave us a large boost in improvement.

01:15:20.000 And then I talked about how that strategy is not actually

01:15:23.680 possible for much, much deeper neural nets.

01:15:27.400 Because when you have much deeper neural nets

01:15:30.240 with lots of different types of layers,

01:15:32.480 it becomes really, really hard to precisely set the weights

01:15:35.800 and the biases in such a way that the activations

01:15:38.280 are roughly uniform throughout the neural net.

01:15:41.360 So then I introduced the notion of the normalization layer.

01:15:44.560 Now there are many normalization layers

01:15:45.960 that people use in practice.

01:15:47.960 Batch normalization, layer normalization,

01:15:50.320 instance normalization, group normalization.

01:15:52.720 We haven't covered most of them,

01:15:54.080 but I've introduced the first one.

01:15:55.600 And also the one that I believe came out first

01:15:58.200 and that's called batch normalization.

01:16:00.760 And we saw how batch normalization works.

01:16:03.000 This is a layer that you can sprinkle throughout

01:16:05.000 your deep neural net.

01:16:06.320 And the basic idea is if you want roughly Gaussian

01:16:09.360 activations, well then take your activations

01:16:11.880 and take the mean and the standard deviation

01:16:14.640 and center your data.

01:16:16.640 And you can do that because the centering operation

01:16:19.640 is differentiable.

01:16:21.440 But on top of that, we actually had to add

01:16:23.520 a lot of bells and whistles.

01:16:25.160 And that gave you a sense of the complexities

01:16:27.040 of the batch normalization layer.

01:16:28.640 Because now we're centering the data, that's great.

01:16:31.000 But suddenly we need the gain and the bias.

01:16:33.440 And now those are trainable.

01:16:35.280 And then because we are coupling all the training examples,

01:16:38.640 now suddenly the question is, how do you do the inference?

01:16:41.160 Or to do the inference, we need to now estimate

01:16:44.480 these mean and standard deviation

01:16:47.320 once over the entire training set

01:16:49.880 and then use those at inference.

01:16:51.880 But then no one likes to do stage two.

01:16:53.760 So instead we fold everything into the batch

01:16:56.200 normalization layer during training

01:16:57.800 and try to estimate these in the running manner

01:17:00.280 so that everything is a bit simpler.

01:17:02.760 And that gives us the batch normalization layer.

01:17:06.320 And as I mentioned, no one likes this layer.

01:17:09.440 It causes a huge amount of bugs.

01:17:12.560 And intuitively it's because it is coupling examples

01:17:17.040 in the forward pass of the neural net.

01:17:18.800 And I've shot myself in the foot with this layer

01:17:23.280 over and over again in my life.

01:17:25.160 And I don't want you to suffer the same.

01:17:28.360 So basically try to avoid it as much as possible.

01:17:32.000 Some of the other alternatives to these layers

01:17:33.760 are for example group normalization or layer normalization.

01:17:36.640 And those have become more common in more recent deep learning.

01:17:40.960 But we haven't covered those yet.

01:17:43.240 But definitely batch normalization

01:17:44.440 was very influential at the time when it came out

01:17:46.840 in roughly 2015.

01:17:48.720 Because it was kind of the first time

01:17:50.400 that you could train reliably much deeper neural nets.

01:17:55.400 And fundamentally the reason for that

01:17:56.680 is because this layer was very effective

01:17:59.080 at controlling the statistics of the activations

01:18:01.560 in a neural net.

01:18:03.200 So that's the story so far.

01:18:05.360 And that's all I wanted to cover.

01:18:07.840 And in the future lecture, so hopefully we

01:18:09.440 can start going into recurrent neural nets.

01:18:11.600 And recurrent neural nets as we'll see

01:18:14.240 are just very, very deep networks.

01:18:16.440 Because you unroll the loop.

01:18:18.720 And when you actually optimize these neural nets.

01:18:21.480 And that's where a lot of this analysis

01:18:25.240 around the activation statistics and all these

01:18:28.080 normalization layers will become very, very important

01:18:30.800 for good performance.

01:18:32.760 So we'll see that next time.

01:18:34.160 Bye.

01:18:35.320 OK, so I lied.

01:18:36.480 I would like us to do one more summary here as a bonus.

01:18:39.200 And I think it's useful as to have one more summary

01:18:41.880 of everything I've presented in this lecture.

01:18:43.920 But also I would like us to start by torchifying our code

01:18:46.760 a little bit.

01:18:47.240 So it looks much more like what you would encounter in PyTorch.

01:18:50.360 So you'll see that I will structure our code

01:18:52.120 into these modules, like a Linear module and a BatchNorm

01:18:57.520 module.

01:18:58.680 And I'm putting the code inside these modules

01:19:01.400 so that we can construct neural networks very much like we

01:19:03.560 would construct them in PyTorch.

01:19:04.760 And I will go through this in detail.

01:19:06.720 So we'll create our neural net.

01:19:08.800 Then we will do the optimization loop as we did before.

01:19:12.240 And then the one more thing that I want to do here

01:19:14.240 is I want to look at the activation statistics

01:19:16.160 both in the forward pass and in the backward pass.

01:19:19.320 And then here we have the evaluation and sampling

01:19:21.480 just like before.

01:19:22.960 So let me rewind all the way up here and go a little bit slower.

01:19:26.840 So here I am creating a linear layer.

01:19:29.280 You'll notice that Torch.NN has lots of different types

01:19:31.960 of layers.

01:19:32.720 And one of those layers is the linear layer.

01:19:35.120 Torch.NN.linear takes a number of input features, output

01:19:37.640 features, whether or not we should have bias.

01:19:39.920 And then the device that we want to place this layer on

01:19:42.520 and the data type.

01:19:43.920 So I will omit these two, but otherwise we have the exact

01:19:47.280 same thing.

01:19:48.360 We have the fan in, which is the number of inputs.

01:19:50.720 Fan out, the number of outputs.

01:19:53.440 And whether or not we want to use a bias.

01:19:55.320 And internally, inside this layer, there's a weight and a

01:19:57.880 bias if you'd like it.

01:19:59.880 It is typical to initialize the weight using, say, random

01:20:04.400 numbers drawn from Gaussian.

01:20:05.960 And then here's the Kaiming initialization that we've

01:20:08.920 discussed already in this lecture.

01:20:10.600 And that's a good default.

01:20:12.000 And also the default that I believe PyTorch uses.

01:20:14.720 And by default, the bias is usually initialized to zeros.

01:20:18.360 Now, when you call this module, this will basically

01:20:21.360 calculate W times X plus B, if you have a B.

01:20:24.880 And then when you call .parameters on this module,

01:20:27.440 it will return the tensors that are the parameters of this

01:20:31.080 layer.
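As a sketch of the interface just described (not the lecture's actual code, and in plain Python rather than PyTorch), a minimal Linear module with fan_in, fan_out, an optional bias, Kaiming-style scaling of the Gaussian weights by the square root of fan_in, and a parameters method might look like this:

```python
import math
import random

class Linear:
    def __init__(self, fan_in, fan_out, bias=True):
        # weights ~ N(0, 1) scaled down by sqrt(fan_in); bias starts at zero
        self.weight = [[random.gauss(0, 1) / math.sqrt(fan_in)
                        for _ in range(fan_out)] for _ in range(fan_in)]
        self.bias = [0.0] * fan_out if bias else None

    def __call__(self, x):
        # computes x @ W (+ b) for a single input vector x of length fan_in
        out = [sum(xi * wij for xi, wij in zip(x, col))
               for col in zip(*self.weight)]
        if self.bias is not None:
            out = [o + b for o, b in zip(out, self.bias)]
        self.out = out
        return out

    def parameters(self):
        # returns the tensors (here, flat lists) that are trainable
        params = [w for row in self.weight for w in row]
        if self.bias is not None:
            params += self.bias
        return params

layer = Linear(10, 5)
y = layer([1.0] * 10)
print(len(y))                   # 5 outputs
print(len(layer.parameters()))  # 10*5 weights + 5 biases = 55
```

This mirrors the torch.nn.Linear interface minus the device and dtype arguments, which the lecture omits as well.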

01:20:32.200 Now, next we have the batch normalization layer.

01:20:34.480 So I've written that here.

01:20:37.000 And this is very similar to PyTorch's BatchNorm1d

01:20:41.440 layer as shown here.

01:20:44.400 So I'm kind of taking these three parameters here, the

01:20:48.120 dimensionality, the epsilon that we'll use in the division,

01:20:51.480 and the momentum that we will use in keeping track of these

01:20:54.320 running stats, the running mean and the running variance.

01:20:58.160 Now, PyTorch actually takes quite a few more things, but

01:21:00.640 I'm assuming some of their settings.

01:21:02.280 So for us, affine will be true.

01:21:03.920 That means that we will be using a gamma and beta after

01:21:06.400 the normalization.

01:21:08.040 The track running stats will be true.

01:21:09.560 So we will be keeping track of the running mean and the

01:21:11.640 running variance in the batch norm.

01:21:14.640 Our device by default is the CPU.

01:21:17.080 And the data type by default is float, float 32.

01:21:22.200 So those are the defaults.

01:21:23.440 Otherwise, we are taking all the same parameters in this

01:21:26.360 batch norm layer.

01:21:27.440 So first, I'm just saving them.

01:21:29.960 Now, here's something new.

01:21:31.080 There's a dot training, which by default is true.

01:21:33.560 And PyTorch nn modules also have this attribute,

01:21:36.080 .training.

01:21:37.120 And that's because many modules, and batch norm

01:21:39.480 is included in that, have a different behavior, whether you

01:21:43.560 are training your neural net or whether you are running it in an

01:21:46.360 evaluation mode and calculating your evaluation loss, or using

01:21:49.960 it for inference on some test examples.

01:21:52.800 And batch norm

01:21:53.560 is an example of this, because when we are training, we

01:21:56.160 are going to be using the mean and the variance

01:21:57.800 estimated from the current batch.

01:21:59.680 But during inference, we are using the running mean and

01:22:02.360 running variance.

01:22:03.960 And so also, if we are training, we are updating mean

01:22:07.040 and variance.

01:22:07.760 But if we are testing, then these are not being

01:22:09.640 updated.

01:22:09.960 They're kept fixed.

01:22:11.680 And so this flag is necessary.

01:22:13.360 And by default true, just like in PyTorch.

01:22:16.320 Now, the parameters of BatchNorm1d are the gamma and the

01:22:18.800 beta here.

01:22:21.680 And then the running mean and running variance are called

01:22:24.520 buffers in PyTorch nomenclature.

01:22:27.480 And these buffers are trained using exponential moving

01:22:31.440 average here explicitly.

01:22:33.360 And they are not part of the back propagation and stochastic

01:22:36.000 gradient descent.

01:22:37.000 So they are not sort of parameters of this layer.

01:22:39.800 And that's why when we have parameters here, we only

01:22:43.280 return gamma and beta.

01:22:44.600 We do not return the mean and the variance.

01:22:46.600 This is trained internally here, every forward pass, using

01:22:51.840 exponential moving average.

01:22:54.560 So that's the initialization.

01:22:56.880 Now, in a forward pass, if we are training, then we use the

01:23:00.360 mean and the variance estimated by the batch.

01:23:03.320 I've pulled up the paper here.

01:23:05.840 We calculate the mean and the variance.

01:23:08.800 Now, up above, I was estimating the standard deviation and

01:23:12.280 keeping track of the standard deviation here in the running

01:23:15.680 standard deviation instead of running variance.

01:23:18.040 But let's follow the paper exactly.

01:23:20.120 Here, they calculate the variance, which is the standard

01:23:22.800 deviation squared.

01:23:23.800 And that's what's kept track of in the running variance

01:23:26.640 instead of a running standard deviation.

01:23:29.720 But those two would be very, very similar, I believe.

01:23:33.840 If we are not training, then we use running mean and

01:23:35.680 variance, we normalize.

01:23:39.040 And then here, I am calculating the output of this layer.

01:23:42.040 And I'm also assigning it to an attribute called dot out.

01:23:45.480 Now, dot out is something that I'm using in our modules

01:23:48.800 here.

01:23:49.560 This is not what you would find in PyTorch.

01:23:51.400 We are slightly deviating from it.

01:23:53.240 I'm creating a dot out because I would like to very easily

01:23:57.400 maintain all those variables so that we can create statistics

01:24:00.000 of them and plot them.

01:24:01.360 But PyTorch and modules will not have a dot out attribute.

01:24:05.440 And finally, here, we are updating the buffers using,

01:24:07.880 again, as I mentioned, exponential moving average,

01:24:11.240 given the provided momentum.

01:24:13.040 And importantly, you'll notice that I'm using the torch.no_grad

01:24:15.480 context manager.

01:24:17.360 And I'm doing this because if we don't use this,

01:24:19.920 then PyTorch will start building out an entire

01:24:22.120 computational graph out of these tensors because it is

01:24:25.120 expecting that we will eventually call dot backward.

01:24:28.040 But we are never going to be calling dot backward on

01:24:29.960 anything that includes running mean and running variance.

01:24:32.600 So that's why we need to use this context manager so that

01:24:35.400 we are not maintaining them using all this additional memory.

01:24:40.360 So this will make it more efficient.

01:25:41.880 And it's just telling PyTorch that there will be no

01:25:43.360 backward; we just have a bunch of tensors.

01:24:45.320 We want to update them.

01:24:46.320 That's it.

01:24:47.760 And then we return.
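Putting the pieces just described together, here is a plain-Python sketch of this BatchNorm1d layer (again not the lecture's PyTorch code; the list-based math is a stand-in for tensors): batch statistics and an eps-guarded normalization when training, running (EMA) statistics at inference, gamma and beta as the only returned parameters, and the buffers deliberately kept out of parameters.

```python
import math

class BatchNorm1d:
    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps, self.momentum = eps, momentum
        self.training = True
        # learnable parameters
        self.gamma = [1.0] * dim
        self.beta = [0.0] * dim
        # buffers, updated with an exponential moving average, no gradients
        self.running_mean = [0.0] * dim
        self.running_var = [1.0] * dim

    def __call__(self, xb):  # xb: list of rows, each with `dim` features
        n, dim = len(xb), len(xb[0])
        cols = list(zip(*xb))
        if self.training:
            mean = [sum(c) / n for c in cols]
            var = [sum((v - m) ** 2 for v in c) / n
                   for c, m in zip(cols, mean)]
            for j in range(dim):  # EMA buffer update
                self.running_mean[j] = ((1 - self.momentum) * self.running_mean[j]
                                        + self.momentum * mean[j])
                self.running_var[j] = ((1 - self.momentum) * self.running_var[j]
                                       + self.momentum * var[j])
        else:
            mean, var = self.running_mean, self.running_var
        self.out = [[self.gamma[j] * (row[j] - mean[j])
                     / math.sqrt(var[j] + self.eps) + self.beta[j]
                     for j in range(dim)] for row in xb]
        return self.out

    def parameters(self):  # buffers are deliberately excluded
        return self.gamma + self.beta

bn = BatchNorm1d(2)
out = bn([[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]])
col0 = [row[0] for row in out]
print(sum(col0))  # each normalized column has (near-)zero mean
```

Flipping self.training to False switches the layer to the running estimates, which is exactly the train/eval distinction discussed above.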

01:24:50.200 OK, now scrolling down, we have the Tanh layer.

01:24:52.760 This is very, very similar to torch.tanh.

01:24:55.880 And it doesn't do too much.

01:24:57.720 It just calculates tanh, as you might expect.

01:25:00.400 So that's torch.tanh.

01:25:02.480 And there's no parameters in this layer.

01:25:05.200 But because these are layers, it now becomes very easy to

01:25:08.800 sort of stack them up into basically just a list.

01:25:13.280 And we can do all the initialization that we're used to.

01:25:16.280 So we have the initial sort of embedding matrix.

01:25:19.480 We have our layers, and we can call them sequentially.

01:25:22.240 And then again, with Torch dot no grad, there's some

01:25:24.840 initialization here.

01:25:26.160 So we want to make the output softmax a bit less

01:25:28.520 confident like we saw.

01:25:30.360 And in addition to that, because we are using a six layer

01:25:33.320 multi layer perceptron here, so you see how I'm stacking

01:25:36.200 linear, tanh, linear, tanh, et cetera, I'm going to be using

01:25:40.160 the gain here.

01:25:41.320 And I'm going to play with this in a second, so you'll see

01:25:43.400 how when we change this, what happens to this statistics.

01:25:47.280 Finally, the parameters are basically the embedding matrix

01:25:49.720 and all the parameters in all the layers.

01:25:52.440 And notice here, I'm using a double list comprehension, if

01:25:55.400 you want to call it that.

01:25:56.120 But for every layer in layers and for every parameter in each

01:25:59.880 of those layers, we are just stacking up all those

01:26:02.400 pieces, all those parameters.

01:26:05.040 Now in total, we have 46,000 parameters.

01:26:09.400 And I'm telling PyTorch that all of them require gradient.
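The layers-in-a-list pattern and that double list comprehension can be sketched like this in plain Python (the Tanh class and the Scale stand-in for Linear are illustrative, not the lecture's code):

```python
import math

class Tanh:
    def __call__(self, x):
        self.out = [math.tanh(v) for v in x]
        return self.out
    def parameters(self):
        return []  # tanh has no trainable parameters

class Scale:  # toy stand-in for a Linear layer: one weight per feature
    def __init__(self, dim):
        self.weight = [1.0] * dim
    def __call__(self, x):
        self.out = [w * v for w, v in zip(self.weight, x)]
        return self.out
    def parameters(self):
        return self.weight

layers = [Scale(3), Tanh(), Scale(3), Tanh()]

x = [0.5, -1.0, 2.0]
for layer in layers:  # forward pass: just call the layers in order
    x = layer(x)

# "for every layer in layers, for every parameter in that layer"
parameters = [p for layer in layers for p in layer.parameters()]
print(len(parameters))  # 3 + 0 + 3 + 0 = 6
```

Because every module exposes the same __call__/parameters interface, both the forward pass and the parameter collection stay one-liners no matter how deep the stack gets.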

01:26:12.320 Then here, we have everything we are actually mostly

01:26:19.560 used to.

01:26:20.720 We are sampling batch.

01:26:22.080 We are doing a forward pass.

01:26:23.520 The forward pass now is just the application of all

01:26:25.760 the layers in order, followed by the cross entropy.

01:26:29.400 And then in the backward pass, you'll notice that for

01:26:31.240 every single layer, I now iterate over all the outputs.

01:26:34.160 And I'm telling PyTorch to retain the gradient of them.

01:26:37.440 And then here, we are already used to all the gradients

01:26:40.680 sent to none.

01:26:41.720 Do the backward to fill in the gradients.

01:26:43.920 Do an update using stochastic gradient descent.

01:26:46.360 And then track some statistics.

01:26:48.760 And then I am going to break after a single iteration.

01:26:52.040 Now here in this cell, in this diagram, I'm visualizing the

01:26:54.960 histogram, the histograms of the forward pass activations.

01:26:58.680 And I'm specifically doing it at the tanh layers.

01:27:01.800 So iterating over all the layers, except for the very last

01:27:05.320 one, which is basically just the softmax layer.

01:27:10.160 If it is a tanh layer-- and I'm using a tanh layer

01:27:12.760 just because they have a finite output, negative 1 to 1.

01:27:15.400 And so it's very easy to visualize here.

01:27:17.360 So you see negative 1 to 1.

01:27:18.800 And it's a finite range.

01:27:19.760 And it's easy to work with.

01:27:21.720 I take the out tensor from that layer into T. And then I'm

01:27:25.840 calculating the mean, the standard deviation, and the

01:27:28.360 percent saturation of T. And we define the percent

01:27:31.760 saturation as T dot absolute value greater than 0.97.

01:27:35.440 So that means we are here at the tails of the tanh.

01:27:38.600 And remember that when we are in the tails of the tanh,

01:27:40.920 that will actually stop gradients.

01:27:42.760 So we don't want this to be too high.
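The saturation statistic is simple to compute; here is a self-contained sketch in plain Python (synthetic pre-activations, not the lecture's data) measuring the fraction of tanh outputs sitting in the flat tails where the local gradient is nearly zero:

```python
import math
import random

random.seed(0)
# deliberately wide pre-activations (std 3) to force heavy saturation
preacts = [random.gauss(0, 3.0) for _ in range(10_000)]
t = [math.tanh(v) for v in preacts]

# fraction of units with |tanh| > 0.97, i.e. in the flat tails
saturated = sum(1 for v in t if abs(v) > 0.97) / len(t)
print(f"{saturated:.0%} saturated")  # a large fraction, with std-3 inputs
```

Shrinking the input standard deviation toward 1 drops this fraction sharply, which is the effect the gain tuning below is controlling.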

01:27:45.560 Now here I'm calling torch.histogram.

01:27:48.920 And then I am plotting this histogram.

01:27:51.200 So basically what this is doing is that every different

01:27:53.320 type of layer-- and they all have a different color--

01:27:55.320 we are looking at how many values in these tensors take on

01:28:00.480 any of the values below on this axis here.

01:28:04.120 So the first layer is fairly saturated here at 20%.

01:28:07.960 So you can see that it's got tails here.

01:28:10.440 But then everything sort of stabilizes.

01:28:12.440 And if we had more layers here, it would actually just

01:28:14.480 stabilize around a standard deviation of about 0.65.

01:28:17.960 And the saturation would be roughly 5%.

01:28:20.600 And the reason that this stabilizes and gives us a nice

01:28:23.160 distribution here is because gain is set to 5/3.

01:28:27.680 Now here, this gain, you see that by default we initialize

01:28:33.000 with one over square root of fan in.

01:28:35.320 But then here during initialization, I come in and I

01:28:37.560 iterate over all the layers.

01:28:38.720 And if it's a linear layer, I boost that by the gain.

01:28:42.360 Now we saw that one--

01:28:44.560 so basically if we just do not use a gain, then what happens?

01:28:48.760 If I redraw this, you will see that the standard deviation

01:28:53.000 is shrinking and the saturation is coming to 0.

01:28:57.000 And basically what's happening is the first layer is pretty

01:29:00.000 decent.

01:29:00.840 But then further layers are just kind of like shrinking

01:29:03.600 down to 0.

01:29:04.920 And it's happening slowly, but it's shrinking to 0.

01:29:07.600 And the reason for that is when you just have a sandwich

01:29:11.160 of linear layers alone, then initializing our weights

01:29:16.760 in this manner, we saw previously would have conserved

01:29:20.200 the standard deviation of 1.

01:29:22.080 But because we have these interspersed tanh layers

01:29:24.920 in there, these tanh layers are squashing functions.

01:29:29.520 And so they take your distribution

01:29:31.240 and they slightly squash it.

01:29:32.880 And so some gain is necessary to keep expanding it,

01:29:37.160 to fight the squashing.

01:29:39.920 So it just turns out that 5/3 is a good value.

01:29:43.480 So if we have something too small, like 1, we saw that things

01:29:46.520 will come towards 0.

01:29:49.000 But if it's something too high, let's do 2.

01:29:52.440 Then here we see that--

01:29:53.560 well, let me do something a bit more extreme,

01:29:58.680 because-- so it's a bit more visible.

01:30:00.440 Let's try 3.

01:30:02.240 OK, so we see here that the saturation is

01:30:04.080 going to be way too large.

01:30:07.000 So 3 would create way too saturated activations.

01:30:10.880 So 5/3 is a good setting for a sandwich of linear layers

01:30:16.200 with tanh activations.

01:30:17.840 And it roughly stabilizes the standard deviation

01:30:20.440 at a reasonable point.
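The gain experiment can be reproduced in miniature. This sketch (plain Python, not the lecture's code) collapses each linear layer to one random weight per unit with the appropriate standard deviation, a rough stand-in for a full matrix multiply, and tracks the activation standard deviation through a stack of (linear, tanh) pairs for gain 1 versus 5/3:

```python
import math
import random

def layer_stds(gain, n=500, depth=12, seed=1):
    """Activation std after each (linear, tanh) pair in a deep stack."""
    rng = random.Random(seed)
    x = [rng.gauss(0, 1) for _ in range(n)]
    stds = []
    for _ in range(depth):
        # toy "linear layer": per-unit random weight with std = gain,
        # standing in for a matrix with weight std gain/sqrt(fan_in)
        x = [math.tanh(rng.gauss(0, gain) * v) for v in x]
        stds.append(math.sqrt(sum(v * v for v in x) / n))
    return stds

print(layer_stds(1.0)[-1])    # with gain 1, the std shrinks with depth
print(layer_stds(5 / 3)[-1])  # with gain 5/3, it settles at a larger scale
```

The qualitative picture matches the histograms above: without extra gain the tanh squashing wins and activations drift toward zero, while 5/3 roughly counteracts it.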

01:30:22.600 Now, honestly, I have no idea where 5/3 came from in PyTorch

01:30:27.240 when we were looking at the Kaiming initialization.

01:30:30.040 I see empirically that it stabilizes this sandwich

01:30:32.840 of linear and tanh, and that the saturation is in a good range.

01:30:36.720 But I don't actually know if this came out of some math formula.

01:30:39.560 I tried searching briefly for where this comes from,

01:30:42.760 but I wasn't able to find anything.

01:30:44.520 But certainly we see that empirically,

01:30:46.080 these are very nice ranges.

01:30:47.440 Our saturation is roughly 5%, which is a pretty good number.

01:30:50.920 And this is a good setting of the gain in this context.

01:30:55.160 Similarly, we can do the exact same thing with the gradients.

01:30:58.280 So here is the very same loop: if it's a tanh,

01:31:01.480 but instead of taking the layer dot out, I'm taking the grad.

01:31:04.440 And then I'm also showing the mean and standard deviation

01:31:07.160 and I'm plotting the histogram of these values.

01:31:10.040 And so you'll see that the gradient distribution

01:31:11.760 is fairly reasonable.

01:31:13.480 And in particular, what we're looking for

01:31:14.920 is that all the different layers in this sandwich

01:31:17.720 have roughly the same gradient.

01:31:19.560 Things are not shrinking or exploding.

01:31:21.920 So we can, for example, come here

01:31:23.960 and we can take a look at what happens

01:31:25.400 if this gain was way too small.

01:31:27.440 So this was 0.5.

01:31:28.760 Then you see the, first of all,

01:31:32.560 the activations are shrinking to zero,

01:31:34.320 but also the gradients are doing something weird.

01:31:36.400 The gradients started out here

01:31:38.080 and then now they're like expanding out.

01:31:40.160 And similarly, if we, for example,

01:31:43.480 have a too high of a gain, so like three,

01:31:45.480 then we see that also the gradients have,

01:31:48.320 there's some asymmetry going on

01:31:49.560 where as you go into deeper and deeper layers,

01:31:52.120 the activations are also changing.

01:31:54.160 And so that's not what we want.

01:31:55.480 And in this case, we saw that without the use of batch norm,

01:31:58.400 as we are going through right now,

01:32:00.360 we have to very carefully set those gains

01:32:03.240 to get nice activations in both the forward pass

01:32:06.440 and the backward pass.

01:32:07.600 Now, before we move on to batch normalization,

01:32:10.160 I would also like to take a look at what happens

01:32:12.040 when we have no tanh units here.

01:32:13.960 So erasing all the tanh nonlinearities,

01:32:16.800 but keeping the gain at five over three,

01:32:19.400 we now have just a giant linear sandwich.

01:32:22.120 So let's see what happens to the activations.

01:32:24.360 As we saw before, the correct gain here is one.

01:32:27.480 That is the standard deviation preserving gain.

01:32:29.680 So 1.667 is too high.

01:32:33.680 And so what's gonna happen now is the following.

01:32:36.240 I have to change this to be linear,

01:32:39.000 because there are no more tanh layers.

01:32:43.040 And let me change this to linear as well.

01:32:45.080 So what we're seeing is the activations started out

01:32:50.320 on the blue and have by layer four become very diffuse.

01:32:55.120 So what's happening to the activations is this.

01:32:57.800 And with the gradients on the top layer,

01:33:00.640 the gradient statistics are the purple,

01:33:04.400 and then they diminish as you go down deeper than the layers.

01:33:07.720 And so basically you have an asymmetry like in the neural net.

01:33:10.880 And you might imagine that if you have very deep neural networks,

01:33:13.040 say like 50 layers or something like that,

01:33:15.160 this just, this is not a good place to be.

01:33:18.880 So that's why before batch normalization,

01:33:21.320 this was incredibly tricky to set.

01:33:24.240 In particular, if this is too large of a gain,

01:33:26.640 this happens.

01:33:27.480 And if it's too little of a gain, then this happens.

01:33:31.480 So the opposite of that basically happens.

01:33:33.440 Here we have a shrinking and a diffusion,

01:33:38.440 depending on which direction you look at it from.

01:33:42.400 And so certainly this is not what you want.

01:33:44.200 And in this case, the correct setting of the gain

01:33:46.040 is exactly one, just like we're doing at initialization.

01:33:50.240 And then we see that the statistics for the forward

01:33:54.120 and the backward pass are well-behaved.

01:33:56.160 And so the reason I want to show you this is

01:34:00.000 that basically, getting neural nets to train

01:34:02.640 before these normalization layers.

01:34:04.280 And before the use of advanced optimizers like Adam,

01:34:07.000 which we still have to cover,

01:34:08.320 and residual connections and so on,

01:34:10.560 training neural nets basically looked like this.

01:34:13.400 It's like a total balancing act.

01:34:15.000 You have to make sure that everything is precisely

01:34:17.480 orchestrated and you have to care about the activations

01:34:19.640 and the gradients and their statistics.

01:34:21.320 And then maybe you can train something,

01:34:23.520 but it was basically impossible to train very deep networks.

01:34:25.760 And this is fundamentally the reason for that.

01:34:28.040 You'd have to be very, very careful with your initialization.

01:34:32.240 The other point here is,

01:34:34.480 you might be asking yourself, by the way,

01:34:35.600 I'm not sure if I covered this,

01:34:37.120 why do we need these 10H layers at all?

01:34:40.760 Why do we include them and then have to worry about the gain?

01:34:43.560 And the reason for that, of course,

01:34:45.040 is that if you just have a stack of linear layers,

01:34:47.760 then certainly we're getting very easily nice

01:34:50.600 activations and so on,

01:34:52.320 but this is just a massive linear sandwich

01:34:54.560 and it turns out that it collapses

01:34:55.920 to a single linear layer in terms of its representation power.

01:34:59.680 So if you were to plot the output as a function

01:35:02.200 of the input, you're just getting a linear function.

01:35:04.360 No matter how many linear layers you stack up,

01:35:06.520 you still just end up with a linear transformation.

01:35:09.040 All the WX plus B's just collapse into a large WX plus B

01:35:13.960 with a slightly different W and a slightly different B.
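The collapse argument can be checked numerically: W2(W1 x + b1) + b2 equals a single W x + b with W = W2 W1 and b = W2 b1 + b2. A self-contained sketch with small 2x2 matrices (values chosen arbitrarily for illustration):

```python
def matvec(W, x):
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in W]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

W1, b1 = [[1.0, 2.0], [0.0, 1.0]], [0.5, -0.5]
W2, b2 = [[2.0, 0.0], [1.0, 1.0]], [0.0, 1.0]

x = [3.0, 4.0]

# two linear layers applied in sequence
h = [v + b for v, b in zip(matvec(W1, x), b1)]
y_two = [v + b for v, b in zip(matvec(W2, h), b2)]

# one collapsed layer: W = W2 @ W1, b = W2 @ b1 + b2
W = matmul(W2, W1)
b = [v + c for v, c in zip(matvec(W2, b1), b2)]
y_one = [v + c for v, c in zip(matvec(W, x), b)]

print(y_two == y_one)  # True: identical outputs
```

So the forward pass of any stack of purely linear layers is representable by one linear layer, which is exactly why the nonlinearities are needed.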

01:35:16.080 But interestingly, even though the forward pass collapses

01:35:19.840 to just a linear layer,

01:35:21.120 because of back propagation and the dynamics

01:35:24.200 of the backward pass,

01:35:26.040 the optimization, actually, is not identical.

01:35:28.620 You actually end up with all kinds of interesting

01:35:32.080 dynamics in the backward pass

01:35:34.760 because of the way the chain rule is calculating it.

01:35:37.880 And so optimizing a linear layer by itself

01:35:40.920 and optimizing a sandwich of 10 linear layers.

01:35:43.880 In both cases, those are just a linear transformation

01:35:45.800 in the forward pass,

01:35:46.800 but the training dynamics would be different.

01:35:48.560 And there's entire papers that analyze,

01:35:50.680 in fact, like infinitely layered linear layers and so on.

01:35:54.440 And so there's a lot of things too

01:35:56.240 that you can play with there.

01:35:58.680 But basically the tanh nonlinearities

01:36:00.320 allow us to turn this sandwich

01:36:05.320 from just a linear function

01:36:09.200 into a neural network that can, in principle,

01:36:13.120 approximate any arbitrary function.

01:36:15.640 Okay, so now I've reset the code

01:36:17.000 to use the linear tanh sandwich like before.

01:36:20.600 And I've reset everything,

01:36:21.960 so the gain is five over three.

01:36:24.040 We can run a single step of optimization,

01:36:26.400 and we can look at the activation statistics

01:36:28.280 of the forward pass and the backward pass.

01:36:30.560 But I've added one more plot here

01:36:31.960 that I think is really important to look at

01:36:33.760 when you're training your neural nets and to consider.

01:36:36.240 And ultimately what we're doing is

01:36:37.880 we're updating the parameters of the neural net.

01:36:40.160 So we care about the parameters and their values

01:36:42.960 and their gradients.

01:36:44.480 So here what I'm doing is I'm actually iterating

01:36:46.320 over all the parameters available,

01:36:48.040 and then I'm only restricting it

01:36:50.680 to the two-dimensional parameters,

01:36:52.360 which are basically the weights of these linear layers.

01:36:54.720 And I'm skipping the biases,

01:36:56.320 and I'm skipping the gammas and the betas in the batch norm,

01:37:00.320 just for simplicity.

01:37:02.520 But you can also take a look at those as well.

01:37:04.160 But what's happening with the weights

01:37:05.640 is instructive by itself.

01:37:08.000 So here we have all the different weights, their shapes.

01:37:12.840 So this is the embedding layer, the first linear layer,

01:37:15.640 all the way to the very last linear layer.

01:37:17.600 And then we have the mean,

01:37:18.600 the standard deviation of all these parameters,

01:37:21.080 the histogram.

01:37:22.920 And you can see that it actually doesn't look that amazing.

01:37:24.920 So there's some trouble in paradise,

01:37:26.640 even though these gradients looked okay,

01:37:28.800 there's something weird going on here.

01:37:30.360 I'll get to that in a second.

01:37:32.080 And the last thing here is the gradient to data ratio.

01:37:35.880 So sometimes I like to visualize this as well,

01:37:37.640 because what this gives you a sense of is,

01:37:40.440 what is the scale of the gradient

01:37:42.440 compared to the scale of the actual values?

01:37:45.640 And this is important because we're going to end up

01:37:47.680 taking a step update

01:37:49.160 that is the learning rate times the gradient onto the data.

01:37:54.080 And so if the gradient has too large of a magnitude,

01:37:56.680 if the numbers in there are too large

01:37:58.280 compared to the numbers in data, then you'd be in trouble.

01:38:01.720 But in this case, the gradient-to-data ratios are low numbers.

01:38:05.280 So the values inside grad are 1,000 times smaller

01:38:09.240 than the values inside data in these weights, most of them.
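As a minimal sketch of that ratio, with a hypothetical weight tensor and a made-up gradient:

```python
import torch

# Hypothetical weight with a pretend gradient, to illustrate the ratio.
w = torch.randn(30, 200, requires_grad=True)
w.grad = torch.randn_like(w) * 1e-3

# Scale of the gradient relative to the scale of the values it will update.
ratio = (w.grad.std() / w.data.std()).item()
print(f"grad:data ratio ~ {ratio:.1e}")  # on the order of 1e-3 by construction
```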

01:38:13.840 Now, notably that is not true about the last layer.

01:38:17.200 And so the last layer actually here,

01:38:18.480 the output layer is a bit of a troublemaker

01:38:20.560 in the way that this is currently arranged,

01:38:22.520 because you can see that the last layer here in pink

01:38:27.520 takes on values that are much larger

01:38:30.680 than some of the values inside the neural net.

01:38:35.840 So the standard deviations are roughly 1e-3

01:38:37.520 throughout, except for the last layer,

01:38:41.520 which actually has roughly 1e-2

01:38:43.960 standard deviation of gradients.

01:38:45.760 And so the gradients on the last layer are currently

01:38:48.400 about 10 times greater,

01:38:52.280 than all the other weights inside the neural net.

01:38:55.880 And so that's problematic because in the simplest

01:38:58.560 stochastic gradient descent setup,

01:39:00.200 you would be training this last layer about 10 times faster

01:39:03.720 than you would be training the other layers

01:39:05.480 at initialization.

01:39:07.120 Now this actually kind of fixes itself a little bit

01:39:10.000 if you train for a bit longer.

01:39:11.080 So for example, if i is greater than 1,000,

01:39:14.040 only then do a break.

01:39:16.200 Let me re-initialize and then let me do it 1,000 steps.
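The loop control for this can be sketched as follows; the total step count is an assumption, and the actual forward/backward/update code is elided:

```python
# A minimal sketch of the early stop used here: run the usual training loop,
# but break out after 1,000 steps so we can inspect the statistics.
max_steps = 200000  # assumed total; the exact number doesn't matter here
for i in range(max_steps):
    # ... forward pass, backward pass, and parameter update would go here ...
    if i >= 1000:
        break
print(i)  # stopped at step 1000
```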

01:39:20.000 And after 1,000 steps, we can look at the forward pass.

01:39:24.160 Okay, so you see how the neurons

01:39:26.200 are saturating a bit.

01:39:27.760 And we can also look at the backward pass,

01:39:30.040 but otherwise they look good.

01:39:31.240 They're about equal and there's no shrinking to zero

01:39:34.040 or exploding to infinities.

01:39:36.160 And you can see that here in the weights,

01:39:38.680 things are also stabilizing a little bit.

01:39:40.240 So the tails of the last pink layer

01:39:42.760 are actually coming in during the optimization.

01:39:46.280 But certainly this is like a little bit troubling,

01:39:48.760 especially if you are using a very simple update rule

01:39:51.040 like stochastic gradient descent

01:39:52.560 instead of a modern optimizer like Adam.

01:39:55.240 Now I'd like to show you one more plot

01:39:56.640 that I usually look at when I train neural networks.

01:39:58.960 And basically the gradient to data ratio

01:40:01.760 is not actually that informative,

01:40:03.280 because what matters at the end

01:40:04.360 is not the gradient to data ratio,

01:40:06.200 but the update to the data ratio.

01:40:08.400 Because that is the amount by which we will actually

01:40:10.280 change the data in these tensors.

01:40:12.920 So coming up here, what I'd like to do

01:40:14.920 is I'd like to introduce a new update to data ratio.

01:40:19.760 It's going to be a list, and we're going to build it out

01:40:21.280 every single iteration.

01:40:23.120 And here I'd like to keep track of basically

01:40:25.840 the ratio every single iteration.

01:40:28.920 So without any gradients, I'm comparing the update,

01:40:35.000 which is learning rate times the gradient,

01:40:37.840 that is the update that we're going to apply

01:40:40.320 to every parameter,

01:40:42.120 so, iterating over all of the neural net's parameters.

01:40:44.520 And then I'm taking the basically standard deviation

01:40:46.280 of the update we're going to apply

01:40:48.200 and dividing it by the standard deviation of

01:40:52.520 the actual content, the data of that parameter.

01:40:56.080 So this is basically the ratio of how large

01:40:58.720 the updates are compared to the values in these tensors.

01:41:02.360 Then we're going to take a log of it,

01:41:03.520 and actually I'd like to take a log 10,

01:41:05.280 just so it's a nicer visualization.

01:41:09.880 So we're going to be basically looking

01:41:11.360 at the exponents of this division here,

01:41:16.360 and then .item() to pop out the float.

01:41:19.320 And we're going to be keeping track of this

01:41:20.680 for all the parameters and adding it to this UD tensor.
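Putting that together, the tracking might look roughly like this; the `parameters` list, gradients, and learning rate here are stand-ins for the real training loop's:

```python
import torch

torch.manual_seed(42)
# Hypothetical parameters with made-up gradients standing in for real backprop.
parameters = [
    torch.randn(27, 10, requires_grad=True),
    torch.randn(30, 200, requires_grad=True),
]
for p in parameters:
    p.grad = torch.randn_like(p) * 1e-3
lr = 0.1  # assumed learning rate

ud = []  # one list of log10(update/data) ratios per training iteration
with torch.no_grad():
    ud.append([((lr * p.grad).std() / p.data.std()).log10().item() for p in parameters])

print(ud[-1])  # roughly -4 for each parameter with these made-up scales
```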

01:41:24.200 So now let me re-initialize and run a thousand iterations.

01:41:27.640 We can look at the activations,

01:41:30.640 the gradients and the parameter gradients

01:41:32.880 as we did before.

01:41:34.280 But now I have one more plot here to introduce.

01:41:36.560 Now what's happening here is we're iterating over all

01:41:39.360 of the parameters and I'm constraining it again,

01:41:41.840 like I did here, to just the weights.

01:41:43.760 So the number of dimensions in these tensors is two.

01:41:47.800 And then I'm basically plotting all of these update ratios

01:41:52.200 over time.

01:41:53.040 So when I plot this, I plot those ratios

01:41:57.600 and you can see that they evolve over time:

01:41:59.440 during initialization they take on certain values,

01:42:02.000 and then these updates

01:42:03.120 usually start stabilizing during training.

01:42:05.920 Then the other thing that I'm plotting here

01:42:07.320 is I'm plotting here like an approximate value

01:42:09.360 that is a rough guide for what it roughly should be.

01:42:12.920 And it should be roughly 1e-3.
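The plot itself might be produced along these lines, with a made-up `ud` history standing in for the one accumulated during training:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display
import matplotlib.pyplot as plt
import random

# Made-up update-to-data history: one log10 ratio per parameter per iteration,
# standing in for the `ud` list accumulated during training.
random.seed(0)
n_params, n_iters = 4, 200
ud = [[-3 + 0.4 * random.gauss(0, 1) for _ in range(n_params)] for _ in range(n_iters)]

plt.figure(figsize=(8, 4))
for j in range(n_params):
    plt.plot([ud[i][j] for i in range(n_iters)])
# rough guide: the ratios should sit somewhere around 1e-3, i.e. -3 on this plot
plt.plot([0, n_iters], [-3, -3], "k")
plt.legend([f"param {j}" for j in range(n_params)] + ["-3 guide"])
plt.savefig("ud_ratios.png")
```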

01:42:15.480 And so that means that basically there's some values

01:42:18.280 in this tensor and they take on certain values

01:42:21.880 and the updates to them at every single iteration

01:42:24.280 are no more than roughly 1,000th of the actual magnitude

01:42:29.040 in those tensors.

01:42:30.080 If this was much larger, like for example,

01:42:33.400 if the log of this was, say, negative one,

01:42:37.720 then this is actually updating those values quite a lot.

01:42:40.080 They're undergoing a lot of change.

01:42:42.240 But the reason that the final layer here is an outlier

01:42:46.600 is because this layer was artificially shrunk down

01:42:50.160 to keep the softmax unconfident.

01:42:53.400 So here, you see how we multiply the weight by point one

01:42:58.320 in the initialization to make the last layer prediction

01:43:02.520 less confident.

01:43:04.080 That made, that artificially made the values

01:43:07.480 inside that tensor way too low.

01:43:09.280 And that's why we're getting temporarily a very high ratio,

01:43:12.120 but you see that that stabilizes over time

01:43:14.160 once that weight starts to learn.

01:43:17.960 But basically I like to look at the evolution

01:43:19.680 of this update ratio for all my parameters usually.

01:43:23.160 And I like to make sure that it's not too much

01:43:26.280 above 1e-3 roughly.

01:43:29.640 So around negative three on this log plot,

01:43:33.040 if it's below negative three,

01:43:34.160 usually that means that parameters are not training fast enough.

01:43:37.360 So if our learning rate was very low,

01:43:39.040 let's do that experiment.

01:43:40.360 Let's initialize and then let's actually do a learning rate

01:43:44.600 of say 1e-3 here.

01:43:47.560 So 0.001.

01:43:49.640 If your learning rate is way too low,

01:43:51.320 this plot will typically reveal it.

01:43:56.400 So you see how all of these updates

01:43:58.880 are way too small.

01:44:00.320 So the size of the update is basically 10,000 times

01:44:05.120 smaller in magnitude than the size of the numbers

01:44:09.120 in that tensor in the first place.

01:44:10.680 So this is a symptom of training way too slow.

01:44:13.120 So this is another way to sometimes set the learning rate

01:44:16.880 and to get a sense of what that learning rate should be.

01:44:19.200 And ultimately this is something that you would keep track of.

01:44:22.080 If anything, the learning rate here

01:44:26.920 is a little bit on the higher side

01:44:29.520 because you see that we're above the black line of negative three.

01:44:33.960 We're somewhere around negative 2.5.

01:44:36.000 It's like, okay, but everything is like somewhat stabilizing.

01:44:39.720 And so this looks like a pretty decent setting

01:44:41.400 of learning rates and so on.

01:44:44.080 But this is something to look at.

01:44:45.480 And when things are mis-calibrated,

01:44:46.720 you will see it very quickly.

01:44:48.440 So for example, everything looks pretty well behaved, right?

01:44:52.240 But just as a comparison,

01:44:53.520 when things are not properly calibrated,

01:44:55.000 what does that look like?

01:44:56.440 Let me come up here and let's say that, for example,

01:45:00.480 what do we do?

01:45:01.760 Let's say that we forgot to apply this fan in normalization.

01:45:05.800 So the weights inside the linear layers

01:45:07.200 are just a sample from a Gaussian in all those stages.

01:45:10.720 What happens to our, how do we notice that something's off?

01:45:14.440 Well, the activation plot will tell you,

01:45:16.200 whoa, your neurons are way too saturated.

01:45:18.920 The gradients are gonna be all messed up.

01:45:21.480 The histograms for these weights are gonna be all messed up as well.

01:45:25.280 And there's a lot of asymmetry.

01:45:27.120 And then if we look here,

01:45:28.200 I suspect it's all gonna be also pretty messed up.

01:45:30.680 So you see there's a lot of discrepancy in how fast

01:45:34.920 these layers are learning.

01:45:36.440 And some of them are learning way too fast.

01:45:38.560 So negative one, negative 1.5,

01:45:41.400 those are very large numbers in terms of this ratio.

01:45:44.240 Again, you should be somewhere around negative three

01:45:45.920 and not much more above that.

01:45:47.640 So this is how mis-calibrations of your neural nets

01:45:51.560 are going to manifest.

01:45:52.920 And these kinds of plots here

01:45:54.320 are a good way of bringing those mis-calibrations

01:45:59.320 to your attention and so you can address them.

01:46:04.400 Okay, so far we've seen that

01:46:05.680 when we have this linear tanh sandwich,

01:46:08.200 we can actually precisely calibrate the gains

01:46:10.200 and make the activations, the gradients and the parameters

01:46:13.360 and the updates all look pretty decent.

01:46:15.880 But it definitely feels a little bit like balancing

01:46:19.080 of a pencil on your finger.

01:46:21.240 And that's because this gain has to be very precisely

01:46:24.300 calibrated.

01:46:25.920 So now let's introduce batch normalization layers

01:46:27.620 into the mix.

01:46:30.080 Let's see how that helps fix the problem.

01:46:32.920 So here, I'm going to take the BatchNorm1d class

01:46:38.400 and I'm going to start placing it inside.

01:46:41.120 And as I mentioned before,

01:46:42.880 the standard, typical place you would put it

01:46:45.080 is after the linear layer,

01:46:47.040 so right after it, but before the non-linearity.

01:46:49.240 But people have definitely played with that.

01:46:51.240 And in fact, you can get very similar results,

01:46:54.160 even if you place it after the non-linearity.

01:46:56.340 And the other thing that I wanted to mention

01:46:59.260 is it's totally fine to also place it at the end

01:47:01.740 after the last linear layer and before the loss function.

01:47:04.780 So this is potentially fine as well.

01:47:07.420 And in this case, this would be output,

01:47:11.260 would be vocab size.
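As a sketch of this arrangement, here using the `torch.nn` equivalents of the lecture's hand-built classes, and with assumed dimensions:

```python
import torch
import torch.nn as nn

# Batch norm inserted between each linear layer and its tanh, plus one more
# after the final linear layer, before the loss; all dimensions are assumptions.
n_embd, block_size, n_hidden, vocab_size = 10, 3, 100, 27

layers = nn.Sequential(
    nn.Linear(n_embd * block_size, n_hidden, bias=False), nn.BatchNorm1d(n_hidden), nn.Tanh(),
    nn.Linear(n_hidden, n_hidden, bias=False), nn.BatchNorm1d(n_hidden), nn.Tanh(),
    nn.Linear(n_hidden, vocab_size, bias=False), nn.BatchNorm1d(vocab_size),  # last batch norm before the loss
)

x = torch.randn(32, n_embd * block_size)  # a fake batch of concatenated embeddings
logits = layers(x)
print(logits.shape)  # torch.Size([32, 27])
```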

01:47:12.580 Now because the last layer is a batch norm,

01:47:17.020 we would not be changing the weight

01:47:18.700 to make the softmax less confident.

01:47:20.620 We'd be changing the gamma.

01:47:23.060 Because gamma, remember, in the batch norm

01:47:25.540 is the variable that multiplicatively interacts

01:47:28.180 with the output of that normalization.
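As a sketch of that trick (note that in `torch.nn`'s BatchNorm1d, the lecture's gamma is called `weight`):

```python
import torch

# Since the last layer is now a batch norm, the "less confident softmax" trick
# scales its gain instead of a weight matrix; in torch.nn the gamma parameter
# is named `weight`.
bn_last = torch.nn.BatchNorm1d(27)
with torch.no_grad():
    bn_last.weight *= 0.1  # gamma: multiplies the normalized output

logits = bn_last(torch.randn(32, 27))
print(logits.std().item())  # roughly 0.1 instead of roughly 1
```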

01:47:30.140 So we can initialize this sandwich now.

01:47:35.820 We can train and we can see that the activations

01:47:39.460 are going to of course look very good.

01:47:41.700 And they are necessarily going to look good

01:47:44.020 because now before every single tanh layer,

01:47:46.820 there is a normalization in the batch norm.

01:47:49.260 So, unsurprisingly, this all looks pretty good.

01:47:53.020 It's going to be a standard deviation of roughly 0.65,

01:47:55.740 roughly 2% saturation, and roughly equal standard deviations

01:47:58.140 throughout all the layers.

01:47:59.740 So everything looks very homogeneous.

01:48:02.740 The gradients look good, the weights look good

01:48:06.860 and their distributions.

01:48:09.220 And then the updates also look pretty reasonable.

01:48:14.180 We're going above negative three a little bit,

01:48:16.300 but not by too much.

01:48:17.900 So all the parameters are training

01:48:19.660 roughly the same rate here.

01:48:24.660 But now what we've gained is we are going to be slightly less

01:48:31.260 brittle with respect to the gain of these linear layers.

01:48:34.100 So for example, I can make the gain be say 0.2 here,

01:48:39.100 which is much lower than what we had with the tanh.

01:48:42.860 But as we'll see, the activations will actually

01:48:44.740 be exactly unaffected.

01:48:46.740 And that's because of, again, this explicit normalization.

01:48:49.700 the gradients are going to look OK.

01:48:51.460 The weight gradients are going to look OK.

01:48:53.900 But actually, the updates will change.

01:48:56.980 And so even though the forward and backward pass

01:48:59.980 to a very large extent look OK, because of the backward

01:49:02.500 pass of the batch norm and how the scale of the incoming

01:49:05.380 activations interacts in the batch norm and its backward pass,

01:49:10.340 this is actually changing the scale

01:49:14.100 of the updates on these parameters.

01:49:16.340 So the gradients of these weights are affected.

01:49:19.620 So we still don't get it completely

01:49:21.500 free pass to pass in arbitrary weights here.

01:49:24.980 But everything else is significantly more robust

01:49:28.700 in terms of the forward, backward, and the weight gradients.

01:49:32.900 It's just that you may have to retune your learning rate

01:49:35.500 if you are changing sufficiently the scale of the activations

01:49:39.500 that are coming into the batch norms.

01:49:41.540 So here, for example, we changed the gains

01:49:45.180 of these linear layers to be greater.

01:49:47.180 And we're seeing that the updates are coming out lower

01:49:49.220 as a result.
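A small standalone experiment along these lines, with hypothetical dimensions, illustrates the gain robustness on the forward pass:

```python
import torch

torch.manual_seed(0)
x = torch.randn(256, 100)             # a batch of hypothetical incoming activations
w = torch.randn(100, 100) / 100**0.5  # fan-in-scaled weight matrix
bn = torch.nn.BatchNorm1d(100)

# Because the batch norm renormalizes the pre-activations, the tanh outputs
# barely change even as we scale the weights by very different gains.
stds = []
for gain in (0.2, 1.0, 5 / 3):
    h = torch.tanh(bn(x @ (w * gain)))
    stds.append(h.std().item())
print(stds)  # nearly identical values for every gain
```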

01:49:51.700 And then finally, we can also, if we are using batch

01:49:53.980 norms, we don't actually need to necessarily--

01:49:56.700 let me reset this to 1, so there's no gain.

01:49:59.180 We don't even necessarily have to normalize by fan-in sometimes.

01:50:03.540 So if I take out the fan-in-- so these are just now

01:50:06.460 random Gaussian-- we'll see that because of batch

01:50:09.260 norm, this will actually be relatively well behaved.
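A quick sketch of this effect, with assumed dimensions:

```python
import torch

torch.manual_seed(0)
fan_in, fan_out, batch = 200, 200, 128

# Weights sampled from a plain Gaussian, with no fan-in scaling:
w = torch.randn(fan_in, fan_out)
x = torch.randn(batch, fan_in)
pre = x @ w
print(pre.std().item())   # blows up to roughly sqrt(fan_in), about 14

# ...but a batch norm right after restores unit scale regardless.
bn = torch.nn.BatchNorm1d(fan_out)
post = bn(pre)
print(post.std().item())  # back to roughly 1
```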

01:50:11.780 So the activations here in the forward pass, of course, look good.

01:50:17.580 The gradients look good.

01:50:19.860 The weight updates look OK, a little bit of fat tails

01:50:24.780 in some of the layers.

01:50:26.660 And this looks OK as well.

01:50:29.300 But as you can see, we're significantly below negative 3,

01:50:33.540 so we'd have to bump up the learning rate

01:50:36.540 so that we are training more properly.

01:50:39.060 And in particular, looking at this roughly

01:50:41.100 looks like we have to 10x the learning rate to get

01:50:43.500 to about 1e-3.

01:50:46.700 So we come here, and we would change this

01:50:48.460 to be a learning rate of 1.0.

01:50:51.380 And if I reinitialize--

01:50:52.940 then we'll see that everything still, of course, looks good.

01:51:02.380 And now we are roughly here.

01:51:04.180 And we expect this to be an OK training run.

01:51:07.140 So long story short, we are significantly more robust

01:51:09.500 to the gain of these linear layers,

01:51:11.700 whether or not we normalize by the fan-in.

01:51:13.980 And then we can change the gain.

01:51:16.300 But we actually do have to worry a little bit about the update

01:51:19.780 scales and making sure that the learning rate is properly

01:51:22.860 calibrated here.

01:51:24.060 But the activations of the forward backward pass

01:51:26.700 and the updates are all looking significantly more

01:51:29.540 well behaved, except for the global scale that is potentially

01:51:32.620 being adjusted here.

01:51:34.660 OK, so now let me summarize.

01:51:36.500 There are three things I was hoping

01:51:37.940 to achieve with this section.

01:51:39.420 Number one, I wanted to introduce you

01:51:41.100 to batch normalization, which

01:51:42.420 is one of the first modern innovations

01:51:44.300 that we're looking into that helped stabilize very deep

01:51:47.620 neural networks and their training.

01:51:49.660 And I hope you understand how batch normalization works

01:51:52.580 and how it would be used in a neural network.

01:51:56.100 Number two, I was hoping to pytorchify some of our code

01:51:59.300 and wrap it up into these modules,

01:52:01.820 so like Linear, BatchNorm1d, Tanh, et cetera.

01:52:04.700 These are layers or modules.

01:52:06.780 And they can be stacked up into neural nets,

01:52:09.060 like LEGO building blocks.

01:52:10.940 And these layers actually exist in pytorch.

01:52:14.980 And if you import torch and then torch.nn, you can actually--

01:52:17.940 the way I've constructed it, you can simply just

01:52:20.100 use PyTorch by prepending nn.

01:52:22.940 to all these different layers.

01:52:25.180 And actually, everything will just

01:52:27.100 work because the API that I've developed here

01:52:29.820 is identical to the API that pytorch uses.

01:52:32.740 And the implementation also is basically, as far as I'm aware,

01:52:36.140 identical to the one in pytorch.
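For example, inspecting `torch.nn.BatchNorm1d` shows the lecture's pieces under PyTorch's names:

```python
import torch.nn as nn

# The hand-built layers correspond directly to torch.nn modules of the same
# names; in PyTorch's BatchNorm1d, the lecture's gamma/beta are the
# `weight`/`bias` parameters, and the running statistics are registered buffers.
bn = nn.BatchNorm1d(100)
params = sorted(name for name, _ in bn.named_parameters())
buffers = sorted(name for name, _ in bn.named_buffers())
print(params)   # ['bias', 'weight']
print(buffers)  # includes 'running_mean' and 'running_var'
```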

01:52:38.380 And number three, I try to introduce you

01:52:39.940 to the diagnostic tools that you would use

01:52:42.460 to understand whether your neural network is in a good state

01:52:45.260 dynamically.

01:52:46.300 So we are looking at the statistics and histograms

01:52:48.900 and activation of the forward pass activation,

01:52:52.220 the backward pass gradients.

01:52:54.060 And then also, we're looking at the weights

01:52:55.740 that are going to be updated as part of stochastic

01:52:58.180 gradient descent.

01:52:58.980 And we're looking at their means, standard deviations,

01:53:01.100 and also the ratio of gradients to data,

01:53:04.500 or even better, the updates to data.

01:53:07.780 And we saw that, typically, we don't actually

01:53:09.860 look at it as a single snapshot frozen in time

01:53:12.060 at some particular iteration.

01:53:13.740 Typically, people look at this as over time,

01:53:16.460 just like I've done here.

01:53:17.780 And they look at these update-to-data ratios,

01:53:19.660 and they make sure everything looks OK.

01:53:21.580 And in particular, I said that roughly 1e-3,

01:53:25.340 or basically negative 3 on the log scale,

01:53:27.540 is a good, rough heuristic for what you want this ratio to be.

01:53:31.820 And if it's way too high, then probably the learning rate

01:53:34.380 or the updates are a little too big.

01:53:36.540 And if it's way too small, then the learning rate

01:53:37.940 is probably too small.

01:53:39.540 So that's just some of the things

01:53:40.820 that you may want to play with when you try to get your neural

01:53:43.580 net to work very well.

01:53:46.540 Now, there's a number of things I did not try to achieve.

01:53:49.260 I did not try to beat our previous performance, as an example,

01:53:51.980 by introducing the batch norm layer.

01:53:54.060 Actually, I did try, and I used the learning

01:53:57.620 rate finding mechanism that I've described before.

01:53:59.940 I tried to train the batch-norm neural

01:54:02.460 net.

01:54:03.260 And I actually ended up with results that are very, very

01:54:05.980 similar to what we've obtained before.

01:54:08.260 And that's because our performance now is not bottlenecked

01:54:11.460 by the optimization, which is what batch norm is helping with.

01:54:15.140 The performance at this stage is bottlenecked by--

01:54:17.420 what I suspect is the context length of our model.

01:54:21.860 So currently, we are taking three characters

01:54:23.620 to predict the fourth one.

01:54:24.740 And I think we need to go beyond that.

01:54:26.180 And we need to look at more powerful architectures

01:54:28.380 like recurrent neural networks and transformers

01:54:30.900 in order to further push the log probabilities that we're

01:54:34.020 achieving on this data set.

01:54:36.500 And I also did not try to have a full explanation of all

01:54:40.940 of these activations, the gradients, and the backward pass

01:54:43.540 and the statistics of all these gradients.

01:54:45.540 And so you may have found some of the parts here unintuitive.

01:54:47.940 And maybe you were slightly confused about, OK,

01:54:49.940 if I change the gain here, how come

01:54:52.580 that we need a different learning rate?

01:54:54.140 And I didn't go into the full detail

01:54:55.540 because you'd have to actually look at the backward pass

01:54:57.420 of all these different layers and get

01:54:59.020 an intuitive understanding of how that works.

01:55:01.260 And I did not go into that in this lecture.

01:55:04.020 The purpose really was just to introduce you

01:55:05.740 to the diagnostic tools and what they look like.

01:55:08.260 But there's still a lot of work remaining

01:55:09.860 on the intuitive level to understand the initialization,

01:55:12.700 the backward pass, and how all of that interacts.

01:55:15.780 But you shouldn't feel too bad because, honestly, we

01:55:18.340 are getting to the cutting edge of where the field is.

01:55:22.900 We certainly haven't, I would say, solved initialization.

01:55:25.580 And we haven't solved back propagation.

01:55:28.180 And these are still very much an active area of research.

01:55:30.740 People are still trying to figure out

01:55:31.940 what's the best way to initialize these networks, what

01:55:33.980 is the best update rule to use, and so on.

01:55:37.380 So none of this is really solved.

01:55:38.860 And we don't really have all the answers to all these cases.

01:55:44.020 But at least we're making progress.

01:55:46.220 And at least we have some tools to tell us

01:55:48.260 whether or not things are on the right track for now.

01:55:51.700 So I think we've made positive progress in this lecture.

01:55:54.980 And I hope you enjoyed that.

01:55:56.020 And I will see you next time.