Why is linear algebra essential in machine learning?

Tivadar Danka
Effect of a linear transformation on a grid

Understanding math will make you a better engineer.

So, I am writing the best and most comprehensive book about it.

For every topic in computer science, there is an XKCD comic that summarizes it perfectly. My all-time favorite one is the following.

XKCD: Machine Learning

All jokes aside, linear algebra plays a crucial part in machine learning. From classical algorithms to state-of-the-art, it is everywhere. This post is about why.

Data = vectors

As you probably know, data is represented by vectors.

Data points are just tuples of measurements. In their raw form, they are hardly useful for us. They are just blips in space.

data as a set of points

Without operations and transformations, it is difficult to predict class labels or do anything else.

Vector spaces provide a mathematical structure where operations naturally arise. Instead of a blip, just imagine an arrow pointing to the data point from a fixed origin.

data as vectors

On vectors, we can easily define operations using our geometric intuition. Addition is translation, while scalar multiplication is scaling.

operations on vectors
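To make this concrete, here is a minimal NumPy sketch of these two operations on data points represented as vectors. The numbers are arbitrary, just for illustration.

```python
import numpy as np

# Two data points represented as vectors pointing from the origin
x = np.array([2.0, 1.0])
y = np.array([-1.0, 3.0])

# Addition translates one vector by the other
print(x + y)       # [1. 4.]

# Scalar multiplication scales the vector
print(2.5 * x)     # [5.  2.5]
```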

Why do we even need to add data points together?

To transform raw data into a form that can be used for predictive purposes. Raw data can have a really complicated structure, and we aim to simplify it as much as possible. For instance, raw data is often standardized by subtracting the mean of each feature and dividing by its standard deviation. This way, every feature is on the same scale, making sure that none of them is dominated by the ones with the largest magnitude.

dataset standardization
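Here is a minimal sketch of this standardization step with NumPy; the toy dataset is made up just to show the effect.

```python
import numpy as np

# Toy dataset: rows are data points, columns are features on very different scales
X = np.array([
    [1.0, 2000.0],
    [2.0, 3000.0],
    [3.0, 2500.0],
])

# Standardize each feature: subtract its mean, divide by its standard deviation
X_std = (X - X.mean(axis=0)) / X.std(axis=0)

print(X_std.mean(axis=0))  # ~[0, 0]
print(X_std.std(axis=0))   # [1, 1]
```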

Aside from these operations, vector spaces give rise to linear transformations. They are essentially distortions of the vector space, yielding a new set of features for our dataset. We will take a detailed look at them below.

Machine learning algorithms are functions

In essence, a machine learning model works by doing the following two things.

  1. Find an alternative representation of the data.
  2. Make decisions based on this representation.

Linear algebra plays a role in describing and manipulating those representations, be it the raw data or a high-level feature set.

Regardless of the features, data points are given by vectors. Finding more descriptive representations is the same as finding functions mapping between vector spaces. The simplest ones are the linear transformations given by matrices.

$$f(x) = Ax, \quad x \in \mathbb{R}^n, \quad A \in \mathbb{R}^{m \times n}.$$

Why do we love linear transformations? First, they are easy to work with and fast to compute. Moreover, combined with simple nonlinear functions, they can create expressive models.
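To make the formula concrete, here is a small NumPy sketch of $f(x) = Ax$, with an arbitrary matrix mapping a three-dimensional data point to a two-dimensional representation.

```python
import numpy as np

# A linear transformation from R^3 to R^2, given by a 2x3 matrix
A = np.array([
    [1.0, 0.0, -1.0],
    [0.5, 2.0,  0.0],
])

x = np.array([3.0, 1.0, 2.0])   # a data point in R^3

# f(x) = Ax produces a new, two-dimensional representation of x
f_x = A @ x
print(f_x)   # [1.  3.5]
```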

Linear transformations = transformations of data

How does a linear transformation transform the data? To see this, the only thing we need to notice is that the images of the basis vectors completely determine a given linear transformation.

Linearity means that the order of addition, scalar multiplication, and function application can be changed. So, the image of every vector is a linear combination of the images of the basis vectors.

To be mathematically precise, this is what happens:

$$
\begin{align*}
f(x) &= f\bigg( \sum_{i=1}^{n} x_i e_i \bigg) \\
     &= \sum_{i=1}^{n} x_i f(e_i).
\end{align*}
$$
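We can verify this identity numerically. The sketch below uses an arbitrary 2×2 matrix: its columns are the images of the standard basis vectors, and the image of any vector is the corresponding linear combination of those columns.

```python
import numpy as np

# A 2x2 matrix: its columns are the images of the standard basis vectors
A = np.array([
    [2.0, 1.0],
    [0.0, 3.0],
])

e1, e2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
x = np.array([4.0, -2.0])   # x = 4*e1 + (-2)*e2

# The image of x is the same linear combination of the images of e1 and e2
lhs = A @ x
rhs = x[0] * (A @ e1) + x[1] * (A @ e2)

print(np.allclose(lhs, rhs))   # True
```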

We can visualize this for linear transformations on the two-dimensional plane.

effect of a linear transformation

As you can see, the images of the basis vectors span a parallelogram (whose sides can collapse onto a single line if the transformation is degenerate). From yet another perspective, this is the same as distorting the grid determined by the basis vectors.

linear transformation as grid distortion

Finding more descriptive representations

How can a linear transformation help to find better representations of the data?

Think about PCA, which finds uncorrelated features with no redundancy. This is done by a simple linear transformation, sketched below. (If you are not familiar with how PCA works, check out my recent article about it!)
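For illustration, here is a minimal NumPy sketch of this idea: centering the data and projecting it onto the eigenvectors of its covariance matrix is a linear transformation that produces uncorrelated features. This is a bare-bones version, not the full treatment from the PCA article.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dataset with correlated features: 200 points in R^2
X = rng.normal(size=(200, 2)) @ np.array([[2.0, 1.2], [0.0, 0.5]])

# Center the data, then diagonalize its covariance matrix
X_centered = X - X.mean(axis=0)
cov = np.cov(X_centered, rowvar=False)
eigenvalues, eigenvectors = np.linalg.eigh(cov)

# Projecting onto the eigenvectors is a linear transformation
# that yields uncorrelated (redundancy-free) features
X_pca = X_centered @ eigenvectors

print(np.cov(X_pca, rowvar=False).round(6))  # off-diagonal entries ~ 0
```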

So, linear transformations give rise to new features. How descriptive can these be?

For instance, in classification tasks, we want each high-level feature to represent the probability of belonging to a given class. Are linear transformations enough to express this?

Almost.

Any true underlying relationship between data and class label can be approximated by composing linear transformations with certain nonlinear functions (such as the Sigmoid or ReLU).

This is formally expressed by the Universal Approximation Theorem.
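As a rough illustration of such a composition, here is a tiny two-layer network in NumPy: two linear transformations alternating with the ReLU and sigmoid nonlinearities. The weights are random, so this is only a sketch of the structure, not a trained model.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Randomly initialized weights for a tiny two-layer network
rng = np.random.default_rng(42)
W1, b1 = rng.normal(size=(8, 2)), np.zeros(8)   # linear map R^2 -> R^8
W2, b2 = rng.normal(size=(1, 8)), np.zeros(1)   # linear map R^8 -> R^1

def predict(x):
    # Alternating linear transformations and nonlinear functions
    h = relu(W1 @ x + b1)
    return sigmoid(W2 @ h + b2)   # interpreted as a class probability

print(predict(np.array([0.5, -1.0])))
```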

This is why machine learning is just a pile of linear algebra, stirred until it looks right. (Not just according to XKCD.) In summary, linear transformations are

  • simple to work with,
  • fast to compute,
  • and can be used to build powerful models.

Having a deep understanding of math will make you a better engineer.

I want to help you with this, so I am writing a comprehensive book that takes you from high school math to the advanced stuff.
Join me on this journey and let's do this together!