Today we will retrace the notebook from lessons 13-14 that "builds up" the PyTorch abstractions from scratch. As a first step we'll re-derive everything in hardcore numpy (though maybe "hardcore" should be reserved for C). Then we'll start building the abstractions.
First up we load the MNIST data:
import numpy as np
from pathlib import Path
from fastdownload import FastDownload  # nice helper for caching downloads
import gzip, pickle

MNIST_URL='https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_gz = FastDownload().download(MNIST_URL)
with gzip.open(path_gz, 'rb') as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
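A quick sanity check on what we just loaded; for this particular pickle the split should be 50k training and 10k validation images, each flattened to 28·28 = 784 pixels:
print(x_train.shape, y_train.shape)   # expect (50000, 784) (50000,)
print(x_valid.shape, y_valid.shape)   # expect (10000, 784) (10000,)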
Forward pass
We'll build a simple MLP with one hidden layer:
dim_in = x_train.shape[1]
dim_h = 50
dim_out = max(y_train) + 1

W1 = np.random.randn(dim_h, dim_in)
b1 = np.random.randn(dim_h)
W2 = np.random.randn(dim_out, dim_h)
b2 = np.random.randn(dim_out)

# Linear layer op
def lin(x, W, b):
    return x @ W.T + b

x = x_train[:50]
y = y_train[:50]
bs = len(x)                        # batch size, needed for the loss below
h = lin(x, W1, b1).clip(min=0)     # hidden layer with ReLU
out = lin(h, W2, b2)
y_pred = out.argmax(axis=1)
Weight initialization is something that's easy to forget about, since PyTorch does it behind the scenes whenever you create a layer with parameters. Here we just draw from a standard normal distribution; the net isn't deep, so we won't run into trouble. out contains the unnormalized logits of the predictions, and y_pred the predicted labels.
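For reference, a scaled (Kaiming-style) initialization, roughly in the spirit of what PyTorch does behind the scenes, would look like the sketch below (the *_scaled names are just for illustration; we stick with the plain randn weights for the rest of the post):
# scale by sqrt(2/fan_in) so the variance of the ReLU activations stays roughly constant across layers
W1_scaled = np.random.randn(dim_h, dim_in) * np.sqrt(2 / dim_in)
b1_scaled = np.zeros(dim_h)
W2_scaled = np.random.randn(dim_out, dim_h) * np.sqrt(2 / dim_h)
b2_scaled = np.zeros(dim_out)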
To calculate the loss we use cross-entropy, which for hard labels is the same as the negative log-likelihood (logsumexp here is scipy.special's; keepdims is needed so the subtraction broadcasts per row):
from scipy.special import logsumexp
loss = -(out - logsumexp(out, 1, keepdims=True))[range(bs), y].mean()
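As a reference point for the numbers below: a model that spreads its probability uniformly over the 10 classes has a cross-entropy of \(\log 10 \approx 2.30\), so anything below that means the net has learned something.
print(loss)          # loss of the untrained net on our 50-sample batch
print(np.log(10))    # ~2.303, the loss of a uniform 10-class predictor, for comparison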
Backward pass
After the forward pass we have to backpropagate the gradient. Let's start with dout. To simplify our life we ignore the batches for now. Calculating gradients in neural networks can be genuinely confusing. I think the way to do it is to first mark the right intermediate outputs as variables, and then, when calculating the gradient of the loss with respect to a variable, apply the chain rule separately for each function/variable that depends on it and add everything up.
Loss function
The loss for a sample \((x, y)\) is
\[
L = -\log p_y,
\]
where \(p_y\) is the predicted probability for the label \(y\). Then,
\[
\frac{\partial L}{\partial p_i} = -\frac{\delta_{iy}}{p_y}.
\]
Some care is needed as out holds the unnormalized logits. From the code we see that
\[
\log p_k = out_k - \log\sum_j e^{out_j},
\qquad\text{i.e.}\qquad
p_k = \frac{e^{out_k}}{\sum_j e^{out_j}},
\]
and so
\[
\frac{\partial p_i}{\partial out_k} = p_i\,(\delta_{ik} - p_k),
\]
combined with the chain rule for vector functions
\[
\frac{\partial L}{\partial out_k}
= \sum_i \frac{\partial L}{\partial p_i}\,\frac{\partial p_i}{\partial out_k}
= -\frac{1}{p_y}\,p_y\,(\delta_{yk} - p_k)
= p_k - \delta_{yk},
\]
which we can write in vector notation as
\[
\frac{\partial L}{\partial out} = e^{\,out - \log\sum_j e^{out_j}} - \mathbb 1_y,
\]
or just notice that the exponent is exactly what the code calls logits (the log-softmax), so it's really \(e^{\log p} - \mathbb 1_y = p - \mathbb 1_y\): exp(logits) minus a one-hot vector, which is exactly what the backward code below computes.
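To convince ourselves, here is a quick finite-difference check of \(\partial L/\partial out = p - \mathbb 1_y\) on a single random sample (loss_single, o and yi are small helpers defined just for this check):
def loss_single(o, yi):
    # negative log-likelihood of label yi given unnormalized logits o
    return -(o[yi] - logsumexp(o))

o = np.random.randn(10)
yi = 3
p = np.exp(o - logsumexp(o))
analytic = p.copy(); analytic[yi] -= 1    # p - 1_y
eps = 1e-6
numeric = np.array([
    (loss_single(o + eps*np.eye(10)[k], yi) - loss_single(o - eps*np.eye(10)[k], yi)) / (2*eps)
    for k in range(10)
])
print(np.allclose(analytic, numeric, atol=1e-5))   # should print True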
Linear layer
The linear layer formula is \(out = in \cdot W^T + b\). Continuing with the chain rule we have
\[
\frac{\partial L}{\partial W_{kj}} = \sum_i \frac{\partial L}{\partial out_i}\,\frac{\partial out_i}{\partial W_{kj}}.
\]
The k'th element of \(out\) is the scalar product of \(in\) with the k'th column of \(W^T\), or the k'th row of \(W\), i.e.
\[
out_k = \sum_j W_{kj}\, in_j + b_k,
\]
the derivative of one output element with respect to a weight matrix element is then
\[
\frac{\partial out_i}{\partial W_{kj}} = \delta_{ik}\, in_j,
\]
and the derivative of the loss with respect to the matrix element is
\[
\frac{\partial L}{\partial W_{kj}} = \sum_i \frac{\partial L}{\partial out_i}\,\delta_{ik}\, in_j = \frac{\partial L}{\partial out_k}\, in_j,
\]
which we can write as the outer product of the two vectors
\[
\frac{\partial L}{\partial W} = dout \otimes in = dout \cdot in^T.
\]
The bias is simpler, we just have
\[
\frac{\partial out_i}{\partial b_k} = \delta_{ik}
\quad\Longrightarrow\quad
\frac{\partial L}{\partial b_k} = \frac{\partial L}{\partial out_k},
\]
so it's just dout. Finally, for \(in\),
\[
\frac{\partial L}{\partial in_j}
= \sum_i \frac{\partial L}{\partial out_i}\,\frac{\partial out_i}{\partial in_j}
= \sum_i \frac{\partial L}{\partial out_i}\, W_{ij},
\qquad\text{i.e.}\qquad
\frac{\partial L}{\partial in} = dout \cdot W.
\]
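The backward code further down calls a dlin helper that isn't written out anywhere in this post; here is a minimal sketch consistent with the derivation above. It returns per-sample gradients so that the averaging over the batch can happen later, and its signature matches how backward calls it:
def dlin(dout, inp, W, b):
    # b isn't needed for any gradient, it's only kept to match the call sites
    din = dout @ W                             # dL/d(in):  (bs, n_out) @ (n_out, n_in) -> (bs, n_in)
    dW = dout[:, :, None] * inp[:, None, :]    # per-sample outer products: (bs, n_out, n_in)
    db = dout                                  # dL/db is just dout, per sample
    return din, dW, db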
ReLU
ReLU is quite easy, and since it's an element-wise op we can skip all the sums. Let's say \(\tilde h\) is the hidden activation before the ReLU, i.e. \(h_i = \max(\tilde h_i, 0)\). What's the derivative?
\[
\frac{\partial h_i}{\partial \tilde h_i} =
\begin{cases}
1 & \tilde h_i > 0,\\
0 & \tilde h_i < 0,
\end{cases}
\qquad\text{so}\qquad
\frac{\partial L}{\partial \tilde h_i} = \frac{\partial L}{\partial h_i}\,\mathbb 1[\tilde h_i > 0].
\]
The derivative doesn't exist at 0, but we don't care about that.
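In code this is just a mask. Since \(h = \max(\tilde h, 0)\), zeroing the gradient where \(\tilde h \le 0\) is the same as zeroing it where \(h = 0\), which is the trick the backward code below uses; a tiny example with made-up numbers:
h_tilde = np.array([-1.5, 0.0, 2.0])
h = h_tilde.clip(min=0)
dh = np.array([0.3, 0.4, 0.5])                  # pretend upstream gradient dL/dh
print(dh * (h_tilde > 0))                       # [0.  0.  0.5]
dh_masked = dh.copy(); dh_masked[h == 0] = 0    # same result via the post-ReLU values
print(dh_masked)                                # [0.  0.  0.5]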
Batches
One last detail regarding batches. The loss for a batch is the mean of the separate losses, so to backpropagate over the batch we just take the mean of the gradients for the different samples. I have a feeling that this is wasteful though, and maybe we can do better?
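Indeed, materializing a (bs, n_out, n_in) stack of per-sample outer products only to average it away is wasteful; the batch mean can be folded into a single matrix multiplication. A sketch of what that would look like for the second layer (equivalent to dW2.mean(0) from the per-sample version, with dout and h as in the backward code below):
dW2 = dout.T @ h / len(h)   # (n_out, bs) @ (bs, n_h) / bs  ->  (n_out, n_h), already batch-averaged
db2 = dout.mean(0)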
In code
The whole process looks like
def backward(x, W1, b1, W2, b2, h, logits, y):
    # dL/dout: logits is the log-softmax from forward, so exp(logits) gives the probabilities
    dout = np.exp(logits)
    dout[range(logits.shape[0]), y] -= 1
    # second linear layer
    dh, dW2, db2 = dlin(dout, h, W2, b2)
    # ReLU: zero the gradient where the activation was clipped
    dh_hat = dh.copy()
    dh_hat[h == 0] = 0
    # first linear layer
    _, dW1, db1 = dlin(dh_hat, x, W1, b1)
    # average the per-sample gradients over the batch
    return dW1.mean(0), db1.mean(0), dW2.mean(0), db2.mean(0)
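The training loop in the next section also uses a forward function and an acc helper that never appear in the post; minimal versions consistent with everything above (note that the variable called logits is really the log-softmax, which is what backward exponentiates) might look like:
def forward(x, y, W1, b1, W2, b2):
    h = lin(x, W1, b1).clip(min=0)                    # hidden activations after ReLU
    out = lin(h, W2, b2)                              # unnormalized scores
    logits = out - logsumexp(out, 1, keepdims=True)   # log-softmax
    loss = -logits[range(len(x)), y].mean()           # cross-entropy
    return h, logits, loss

def acc(logits, y):
    return (logits.argmax(axis=1) == y).mean()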
Optimizing
We can also write an optimization loop
epochs = 10
bs = 50
lr = 1e-1

for e in range(epochs):
    for i in range(0, len(x_train), bs):
        s = slice(i, min(len(x_train), i+bs))
        x = x_train[s]
        y = y_train[s]
        h, logits, loss = forward(x, y, W1, b1, W2, b2)
        dW1, db1, dW2, db2 = backward(x, W1, b1, W2, b2, h, logits, y)
        # SGD step
        W1, b1, W2, b2 = [p - lr*g for p, g in [(W1, dW1), (b1, db1), (W2, dW2), (b2, db2)]]
    # report loss/accuracy on the last mini-batch of the epoch
    print(f"loss={loss}")
    print(f"acc={acc(logits, y)}")
which gets us results along these lines (each pair is the loss and accuracy of the last mini-batch of an epoch):
loss=0.6875153176459544
acc=0.76
loss=0.5454049286257023
acc=0.76
loss=0.4964384126281533
acc=0.78
loss=0.45487948890913216
acc=0.82
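Since the printed accuracy is only over the last 50-sample batch, it's worth running the forward pass over the whole validation set to get a less noisy number (a quick check using the helpers sketched above):
_, val_logits, val_loss = forward(x_valid, y_valid, W1, b1, W2, b2)
print(f"valid loss={val_loss:.4f}, valid acc={acc(val_logits, y_valid):.4f}")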
Tomorrow we'll work on the PyTorch abstractions.