Taking the derivative of the loss function of a neural network can be quite cumbersome. Even taking the derivative of a single layer in a neural network often results in expressions cluttered with indices. In this post I’d like to show an index-free way to do it.

Consider the map {\sigma(Wx+b)} where {W\in{\mathbb R}^{m\times n}} is the weight matrix, {b\in{\mathbb R}^{m}} is the bias, {x\in{\mathbb R}^{n}} is the input, and {\sigma} is the activation function. Usually {\sigma} denotes both a scalar function (i.e. a map {{\mathbb R}\rightarrow {\mathbb R}}) and the map {{\mathbb R}^{m}\rightarrow{\mathbb R}^{m}} which applies {\sigma} in each coordinate. When training a neural network, we optimize over the parameters {W} and {b}, so we need derivatives with respect to {W} and {b}. Hence we consider the map

\displaystyle  \begin{array}{rcl}  G(W,b) = \sigma(Wx+b). \end{array}

The map {G} is the composition of the map {(W,b)\mapsto Wx+b} and {\sigma}, and since the former map is linear in the joint variable {(W,b)}, the derivative of {G} should be pretty simple. What makes the computation a little less straightforward is that we are used to viewing matrix-vector products {Wx} as linear maps in {x}, not in {W}. So let’s rewrite the thing:

There are two particular notions which come in handy here: The Kronecker product of matrices and the vectorization of matrices. Vectorization takes some {W\in{\mathbb R}^{m\times n}} given columnwise {W = [w_{1}\ \cdots\ w_{n}]} and maps it by

\displaystyle  \begin{array}{rcl}  \mathrm{Vec}:{\mathbb R}^{m\times n}\rightarrow{\mathbb R}^{mn},\quad \mathrm{Vec}(W) = \begin{bmatrix} w_{1}\\\vdots\\w_{n} \end{bmatrix}. \end{array}
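In NumPy, a minimal sketch of {\mathrm{Vec}} is a reshape in column-major (Fortran) order; the helper name `vec` is our own, not a library function:

```python
import numpy as np

def vec(W: np.ndarray) -> np.ndarray:
    """Stack the columns of W on top of each other (column-major flattening)."""
    return W.reshape(-1, order="F")

W = np.array([[1, 2, 3],
              [4, 5, 6]])   # m = 2, n = 3; columns (1,4), (2,5), (3,6)
print(vec(W))               # [1 4 2 5 3 6]
```

NumPy's default row-major order would stack the rows instead, which does not match the definition above.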

The Kronecker product of matrices {A\in{\mathbb R}^{m\times n}} and {B\in{\mathbb R}^{k\times l}} is a matrix in {{\mathbb R}^{mk\times nl}}

\displaystyle  \begin{array}{rcl}  A\otimes B = \begin{bmatrix} a_{11}B & \cdots &a_{1n}B\\ \vdots & & \vdots\\ a_{m1}B & \cdots & a_{mn}B \end{bmatrix}. \end{array}
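NumPy provides the Kronecker product as `np.kron`. A small example with {A\in{\mathbb R}^{2\times 2}} and {B\in{\mathbb R}^{1\times 3}} (values chosen arbitrarily for illustration) shows the {mk\times nl} shape and the block structure:

```python
import numpy as np

A = np.array([[1, 2],
              [3, 4]])        # 2 x 2
B = np.array([[0, 1, 1]])     # 1 x 3
K = np.kron(A, B)             # block (i, j) equals a_ij * B
print(K.shape)                # (2, 6)
print(K)                      # [[0 1 1 0 2 2]
                              #  [0 3 3 0 4 4]]
```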

We will build on the following marvelous identity: for matrices {A}, {B}, {C} of compatible sizes we have

\displaystyle  \begin{array}{rcl}  \mathrm{Vec}(ABC) = (C^{T}\otimes A)\mathrm{Vec}(B). \end{array}
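The identity is easy to sanity-check numerically. This sketch draws random matrices of compatible (but otherwise arbitrary) sizes and compares both sides, with {\mathrm{Vec}} implemented as a column-major reshape:

```python
import numpy as np

def vec(M):
    return M.reshape(-1, order="F")   # column-major stacking, as in the definition of Vec

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 3))
B = rng.standard_normal((3, 5))
C = rng.standard_normal((5, 2))

lhs = vec(A @ B @ C)                  # length 4 * 2 = 8
rhs = np.kron(C.T, A) @ vec(B)        # (8 x 15) times length-15 vector
print(np.allclose(lhs, rhs))          # True
```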

Why is this helpful? It allows us to rewrite

\displaystyle  \begin{array}{rcl}  Wx & = & \mathrm{Vec}(Wx)\\ & = & \mathrm{Vec}(I_{m}Wx)\\ & = & \underbrace{(x^{T}\otimes I_{m})}_{\in{\mathbb R}^{m\times mn}}\underbrace{\mathrm{Vec}(W)}_{\in{\mathbb R}^{mn}}. \end{array}
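The same kind of check for {Wx = (x^{T}\otimes I_{m})\mathrm{Vec}(W)}, with arbitrary sizes {m=3}, {n=4}:

```python
import numpy as np

m, n = 3, 4
rng = np.random.default_rng(1)
W = rng.standard_normal((m, n))
x = rng.standard_normal(n)

vecW = W.reshape(-1, order="F")           # Vec(W), length m*n
M = np.kron(x.reshape(1, n), np.eye(m))   # x^T ⊗ I_m, shape (m, m*n)
print(M.shape)                            # (3, 12)
print(np.allclose(M @ vecW, W @ x))       # True
```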

So we can also rewrite

\displaystyle  \begin{array}{rcl}  Wx +b & = & \mathrm{Vec}(Wx+b )\\ & = & \mathrm{Vec}(I_{m}Wx + b)\\ & = & \underbrace{ \begin{bmatrix} x^{T}\otimes I_{m} & I_{m} \end{bmatrix} }_{\in{\mathbb R}^{m\times (mn+m)}}\underbrace{ \begin{bmatrix} \mathrm{Vec}(W)\\b \end{bmatrix} }_{\in{\mathbb R}^{mn+m}}\\ &=& ( \underbrace{\begin{bmatrix} x^{T} & 1 \end{bmatrix}}_{\in{\mathbb R}^{1\times(n+1)}}\otimes I_{m}) \begin{bmatrix} \mathrm{Vec}(W)\\b \end{bmatrix}. \end{array}
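And for the affine map, stacking {\mathrm{Vec}(W)} and {b} into one parameter vector (the name `theta` is ours):

```python
import numpy as np

m, n = 3, 4
rng = np.random.default_rng(2)
W = rng.standard_normal((m, n))
b = rng.standard_normal(m)
x = rng.standard_normal(n)

theta = np.concatenate([W.reshape(-1, order="F"), b])   # [Vec(W); b], length m*n + m
x1 = np.append(x, 1.0).reshape(1, n + 1)                # row vector [x^T 1]
M = np.kron(x1, np.eye(m))                              # [x^T 1] ⊗ I_m, shape (m, m*n + m)

print(M.shape)                                          # (3, 15)
print(np.allclose(M @ theta, W @ x + b))                # True
# The last m columns of M form I_m, the block that picks out b
print(np.allclose(M[:, m * n:], np.eye(m)))             # True
```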

So our map {G(W,b) = \sigma(Wx+b)} mapping {{\mathbb R}^{m\times n}\times {\mathbb R}^{m}\rightarrow{\mathbb R}^{m}} can be rewritten as

\displaystyle  \begin{array}{rcl}  \bar G( \begin{pmatrix} \mathrm{Vec}(W)\\b \end{pmatrix}) = \sigma( ( \begin{bmatrix} x^{T} & 1 \end{bmatrix}\otimes I_{m}) \begin{bmatrix} \mathrm{Vec}(W)\\b \end{bmatrix}) \end{array}

mapping {{\mathbb R}^{mn+m}\rightarrow{\mathbb R}^{m}}. Since {\bar G} is just the composition of a linear map, now given as a matrix, and {\sigma} applied coordinatewise, the derivative of {\bar G} (i.e. the Jacobian, a matrix in {{\mathbb R}^{m\times (mn+m)}}) is simply

\displaystyle  \begin{array}{rcl}  D\bar G( \begin{pmatrix} \mathrm{Vec}(W)\\b \end{pmatrix}) & = & D\sigma(Wx+b)( \begin{bmatrix} x^{T} & 1 \end{bmatrix}\otimes I_{m})\\ &=& \underbrace{\mathrm{diag}(\sigma'(Wx+b))}_{\in{\mathbb R}^{m\times m}}\underbrace{( \begin{bmatrix} x^{T} & 1 \end{bmatrix}\otimes I_{m})}_{\in{\mathbb R}^{m\times(mn+m)}}\in{\mathbb R}^{m\times(mn+m)}. \end{array}
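A dense NumPy sketch of this Jacobian, compared against central finite differences. It assumes {\sigma=\tanh} (so {\sigma' = 1-\tanh^{2}}) purely for illustration; all function names are ours:

```python
import numpy as np

def sigma(z):
    return np.tanh(z)                  # assumed activation

def dsigma(z):
    return 1.0 - np.tanh(z) ** 2       # its derivative

def unpack(theta, m, n):
    W = theta[: m * n].reshape((m, n), order="F")   # undo Vec
    return W, theta[m * n:]

def G_bar(theta, x, m):
    W, b = unpack(theta, m, x.size)
    return sigma(W @ x + b)

def jacobian_G_bar(theta, x, m):
    n = x.size
    W, b = unpack(theta, m, n)
    M = np.kron(np.append(x, 1.0).reshape(1, n + 1), np.eye(m))   # [x^T 1] ⊗ I_m
    return np.diag(dsigma(W @ x + b)) @ M                           # m x (mn + m)

m, n = 3, 4
rng = np.random.default_rng(3)
x = rng.standard_normal(n)
theta = rng.standard_normal(m * n + m)
J = jacobian_G_bar(theta, x, m)

# Central finite differences, one column per coordinate direction e
eps = 1e-6
J_fd = np.column_stack([
    (G_bar(theta + eps * e, x, m) - G_bar(theta - eps * e, x, m)) / (2 * eps)
    for e in np.eye(theta.size)
])
print(J.shape)                           # (3, 15)
print(np.allclose(J, J_fd, atol=1e-6))   # True
```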

While this representation of the derivative of a single layer of a neural network with respect to its parameters is not particularly simple, it is still index-free and, moreover, straightforward to implement in languages that provide functions for the Kronecker product and vectorization. If you do this, make sure to use sparse matrices for the identity matrix and the diagonal matrix; otherwise the memory of your computer will be flooded with zeros.
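Along these lines, here is a sketch using `scipy.sparse`, assuming SciPy is available and, again, {\sigma=\tanh}; `jacobian_G_bar_sparse` is a hypothetical helper name. Only the diagonal and the nonzero entries of the Kronecker factor are stored:

```python
import numpy as np
import scipy.sparse as sp

def jacobian_G_bar_sparse(W, b, x):
    """Sparse diag(sigma'(Wx+b)) ([x^T 1] ⊗ I_m), assuming sigma = tanh."""
    m, n = W.shape
    z = W @ x + b
    D = sp.diags(1.0 - np.tanh(z) ** 2)                        # m x m, only the diagonal is stored
    row = sp.csr_matrix(np.append(x, 1.0).reshape(1, n + 1))   # [x^T 1] as a sparse row
    M = sp.kron(row, sp.identity(m), format="csr")             # [x^T 1] ⊗ I_m, m x (mn + m)
    return (D @ M).tocsr()

m, n = 500, 200
rng = np.random.default_rng(4)
W = rng.standard_normal((m, n)) / np.sqrt(n)
b = rng.standard_normal(m)
x = rng.standard_normal(n)

J = jacobian_G_bar_sparse(W, b, x)
print(J.shape)   # (500, 100500)
print(J.nnz)     # number of stored entries, at most m * (n + 1)
```

For {m=500}, {n=200} the dense Jacobian would have {m(mn+m)} = 50,250,000 entries, while the sparse one stores at most {m(n+1)} = 100,500.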

Now let’s add a scalar function {L:{\mathbb R}^{m}\rightarrow{\mathbb R}} (e.g. to produce a scalar loss that we can minimize), i.e. we consider the map

\displaystyle  \begin{array}{rcl}  F( \begin{pmatrix} \mathrm{Vec}(W)\\b \end{pmatrix}) = L(G(W,b)) = L(\bar G( \begin{pmatrix} \mathrm{Vec}(W)\\b \end{pmatrix})). \end{array}

The derivative is obtained by just another application of the chain rule:

\displaystyle  \begin{array}{rcl}  DF( \begin{pmatrix} \mathrm{Vec}(W)\\b \end{pmatrix}) = DL(\sigma(Wx+b))D\bar G( \begin{pmatrix} \mathrm{Vec}(W)\\b \end{pmatrix}). \end{array}

If we want gradients, we just transpose the expression and use {(A\otimes B)^{T} = A^{T}\otimes B^{T}} to get

\displaystyle  \begin{array}{rcl}  \nabla F( \begin{pmatrix} \mathrm{Vec}(W)\\b \end{pmatrix}) &=& D\bar G( \begin{pmatrix} \mathrm{Vec}(W)\\b \end{pmatrix})^{T} DL(\sigma(Wx+b))^{T}\\ &=& ([x^{T}\ 1]\otimes I_{m})^{T}\mathrm{diag}(\sigma'(Wx+b))\nabla L(\sigma(Wx+b))\\ &=& \underbrace{( \begin{bmatrix} x\\ 1 \end{bmatrix} \otimes I_{m})}_{\in{\mathbb R}^{(mn+m)\times m}}\underbrace{\mathrm{diag}(\sigma'(Wx+b))}_{\in{\mathbb R}^{m\times m}}\underbrace{\nabla L(\sigma(Wx+b))}_{\in{\mathbb R}^{m}}. \end{array}

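Note that the right-hand side is indeed a vector in {{\mathbb R}^{mn+m}} and hence can be reshaped into a tuple of an {m\times n} matrix and an {m} vector, the gradients with respect to {W} and {b}; for the matrix part, the first {mn} entries have to be filled in column by column, undoing {\mathrm{Vec}}.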
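To check the gradient formula and the reshaping numerically, here is a small NumPy sketch. It assumes, purely for illustration, {\sigma=\tanh} and the squared-error loss {L(z)=\frac12\|z-y\|^{2}} with a hypothetical target `y` (so {\nabla L(z) = z-y}); the helper names `unpack`, `F` and `grad_F` are ours. It shows that the formula agrees with central finite differences, and that reshaping the gradient column by column recovers the familiar outer-product form {\nabla_{W}F = \delta x^{T}}, {\nabla_{b}F=\delta} with {\delta = \mathrm{diag}(\sigma'(Wx+b))\nabla L(\sigma(Wx+b))}:

```python
import numpy as np

# Illustrative assumptions: sigma = tanh and L(z) = 0.5 * ||z - y||^2 for a fixed target y
def unpack(theta, m, n):
    W = theta[: m * n].reshape((m, n), order="F")   # undo Vec (column-major)
    return W, theta[m * n:]

def F(theta, x, y, m):
    W, b = unpack(theta, m, x.size)
    return 0.5 * np.sum((np.tanh(W @ x + b) - y) ** 2)

def grad_F(theta, x, y, m):
    n = x.size
    W, b = unpack(theta, m, n)
    z = W @ x + b
    grad_L = np.tanh(z) - y                                       # ∇L(sigma(Wx+b))
    delta = (1.0 - np.tanh(z) ** 2) * grad_L                      # diag(sigma'(z)) ∇L, done elementwise
    K = np.kron(np.append(x, 1.0).reshape(n + 1, 1), np.eye(m))   # [x; 1] ⊗ I_m, shape (mn + m, m)
    return K @ delta

m, n = 3, 4
rng = np.random.default_rng(5)
x, y = rng.standard_normal(n), rng.standard_normal(m)
theta = rng.standard_normal(m * n + m)
g = grad_F(theta, x, y, m)

# Central finite differences as an independent check
eps = 1e-6
g_fd = np.array([(F(theta + eps * e, x, y, m) - F(theta - eps * e, x, y, m)) / (2 * eps)
                 for e in np.eye(theta.size)])
print(np.allclose(g, g_fd, atol=1e-6))           # True

# Reshape the gradient back into an m x n matrix and an m vector
grad_W, grad_b = unpack(g, m, n)
W, b = unpack(theta, m, n)
delta = (1.0 - np.tanh(W @ x + b) ** 2) * (np.tanh(W @ x + b) - y)
print(np.allclose(grad_W, np.outer(delta, x)))   # True: grad_W = delta x^T
print(np.allclose(grad_b, delta))                # True
```

The diagonal matrix is never formed here; multiplying elementwise is the dense-array counterpart of using a sparse diagonal.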

A final remark: the Kronecker product is related to tensor products. If {A} and {B} represent linear maps {X_{1}\rightarrow Y_{1}} and {X_{2}\rightarrow Y_{2}}, respectively, then {A\otimes B} represents the tensor product of the maps, {X_{1}\otimes X_{2}\rightarrow Y_{1}\otimes Y_{2}}. This relation to tensor products and tensors hints at where the word tensor in TensorFlow comes from.
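In coordinates, representing the tensor {u\otimes v} by `np.kron(u, v)`, this is the mixed-product rule {(A\otimes B)(u\otimes v) = (Au)\otimes(Bv)}. A quick numerical illustration with random matrices of arbitrary sizes:

```python
import numpy as np

rng = np.random.default_rng(7)
A = rng.standard_normal((2, 3))       # X1 = R^3 -> Y1 = R^2
B = rng.standard_normal((4, 5))       # X2 = R^5 -> Y2 = R^4
u = rng.standard_normal(3)
v = rng.standard_normal(5)

lhs = np.kron(A, B) @ np.kron(u, v)   # apply A ⊗ B to the elementary tensor u ⊗ v
rhs = np.kron(A @ u, B @ v)           # (Au) ⊗ (Bv)
print(np.allclose(lhs, rhs))          # True
```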


I can’t claim to be an expert in machine learning; I’d rather say that I am merely a tourist in this area. Anyway, here are a few thoughts on how (supervised) machine learning imitates human learning.

What are some features of human learning? Ideally, humans aim to understand a subject. To achieve this, they study examples, try to make their own deductions, do experiments, make predictions and test them. The overall goal is to get to the heart of things.

What are the features of so-called supervised machine learning? The methods get training data, i.e. matching pairs of inputs and outputs. The foremost goal of the method is to perform well on test data, i.e. to produce the correct output for an input the method hasn’t seen before. In practice, one sets up a fairly general model (such as a neural network or a kernelized support vector machine) and often does as little modeling of the task at hand as possible.

This does not sound as though supervised machine learning and human learning are the same or even related. Their goals and methods are genuinely different.

But let us look at how humans learn for a very specific task: preparing for an exam. Recently I had to prepare several large exams in mathematics for engineers, each with more than a hundred students, and got to thinking about how the students approach the task of learning. As the exam comes closer, the interactions with the students get more frequent. I had a large “ask anything” meeting, I had students coming to office hours, and I had “digital office hours” where the students could ask questions via a messenger in a chat room. So I had quite a few interactions and could get a little insight into their way of learning, their problems, and their progress.

Here are some observations of how the students tried to learn: The questions I got were mostly about the exercises we had handed out (in other words, the students asked for information on the “training data”). They studied these exercises heavily, barely using the textbook or their lecture notes to look up theorems or definitions (in other words, some were working “model free” or with a “general purpose model” which says something like “do computations following general rules”). They worked under the assumption that the exam is made up of questions similar to the exercises (actually a reasonable assumption, since I announced this in class; in other words, the test data comes from the same distribution as the training data).

Viewed like this, human learning (for an exam, that is) and machine learning sound much more similar. And the similarities do not end here. Some known problems of machine learning methods can also be observed in the students: Students get stuck in local minima (they reach a point where revising the seen data brings no further improvement; even though they could, in principle, learn more from the given exercises, they keep following the known paths, practicing the known computations, and not learning new techniques). Students overfit to the training data (on the test data, aka the exam, they face new problems and often apply the learned methods to tasks where they don’t work, getting results which would be correct if the problem were a little different). The trained students are vulnerable to adversarial attacks (for every problem I posed as an exercise, I could make a slight change that would confuse most students). Also, similar to recent observations in machine learning, overparametrization helps to avoid both overfitting and spurious local valleys: when the students have more techniques at hand, which corresponds to a more flexible machine learning method, they do better on unseen data and do not get stuck at bad local minima where no improvement is possible.

Granted, some observations are a bit of a stretch, but still, in conclusion, I’d advocate replacing the term “machine learning” with “machine cramming” (the German version would be maschinelles Büffeln or maschinelles Pauken).