Exploring Gradient Descent

Week 8

Recently, I've started trying to wrap my head around a new-ish rendering technique called "Gaussian Splatting". There's quite a bit of math involved, and I've been struggling to understand it in the context of rendering, so I decided to go back to the basics and make sure my understanding of the underlying math is solid, which is what we'll be talking about this week.

The Basics of Gaussian Splatting

Gaussian splatting is a rendering technique that's used to create realistic-looking 3D scenes from a bunch of points. Normally, when we render a bunch of points like this, we just draw them as dots on the screen (or, more likely, as a bunch of tiny squares or triangles). This works ok, but it mostly just looks like a bunch of dots, and you have to squint to see the 3D shape that the points are supposed to represent. Gaussian splatting adds a bunch of metadata to each point, allowing them to be stretched and squished in a way that, when rendered, gives a much more accurate representation of the 3D shape.

The math behind this is a little over my head for now, but the basic idea is that you start with a bunch of points and render them to the screen as spheres. Then, you compare the result to a ground truth image, and use gradient descent to adjust the parameters of the spheres to make the rendered image look more like the ground truth.

You repeat this process several times and from a bunch of different angles, and eventually the parameters of the spheres converge to a set of values which give a better overall shape to the scene. If you blend these spheres together, you end up with a very realistic 3D scene, allowing you to fly around in real time and get a pretty accurate view from any angle, which we call "novel view synthesis".

Here are a few images to give you an idea of what I'm talking about, taken from Hugging Face's blog post on the topic:

Image: A raw point cloud

Image: Unblended gaussians

Image: Gaussians blended together

So how do you decide how the parameters change on each step? That's where gradient descent comes in.

Gradient Descent

Gradient Descent is a method for finding a local minimum of a function. For a simple function, we might be able to graph it and just look at the graph to find the minimum. For more complex functions or functions which have many inputs, this isn't possible, because we don't have a good way to visualize functions with more than 2 or 3 inputs.

All is not lost, however! Even if we can't visualize the function and see where the minimum is, we can still find the minimum by applying gradient descent. The idea goes like this: we start at some random point, and take small steps "downhill" from that point. If we take enough steps, we end up finding a minimum point. The Wikipedia article gives a good analogy:

Imagine you're standing on a mountain, and you want to get to the bottom. There's really thick fog all around you, so you can only see the ground directly beneath your feet. You look at the ground and decide which direction goes down the steepest. Then you take a small step in that direction. You repeat this process, and eventually you'll find yourself at the bottom, or at least in a valley.

Derivatives

When we want to find the "slope" of a function (which we'll call the "gradient"), we can use the concept of a derivative. The derivative of a function tells us how the function's output changes relative to changes in its input. For example, if a function's derivative is positive at a certain point, that means that if we increase the input, the output will also increase. If the derivative is negative, then increasing the input will decrease the output.
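To make the sign of the derivative concrete, here's a minimal sketch that estimates the slope numerically by nudging the input a tiny bit in each direction. The helper name `numerical_derivative`, the step size `h`, and the example function `square` are my own choices for illustration, not anything from a library.

```python
def numerical_derivative(f, x, h=1e-6):
    """Approximate the slope of f at x using a central difference."""
    return (f(x + h) - f(x - h)) / (2 * h)


def square(x):
    """Example function: f(x) = x^2, whose exact derivative is 2x."""
    return x * x


# To the right of zero the slope is positive: increasing x increases the output.
slope_right = numerical_derivative(square, 3.0)

# To the left of zero the slope is negative: increasing x decreases the output.
slope_left = numerical_derivative(square, -3.0)

print(slope_right, slope_left)
```

The two results should land very close to 6 and -6, matching the exact derivative 2x at x = 3 and x = -3.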

Since derivatives are a fundamental concept in calculus, I won't go into too much detail here, but it's important to understand this concept in order to understand gradient descent.

Partial Derivatives

If a function has more than one input, then we need to think about the derivative in terms of each input separately. That is, we need to isolate a single input variable and ask the question:

"How does the output of the function change as this one input changes, if all other inputs stay the same?".

To do this, we take the derivative "with respect to" a single input. This is called a "partial derivative". It's formally written with a curly d, as ∂f/∂x or ∂f/∂y, but I'll write it as df/dx or df/dy here, where f is the function and x and y are the inputs.

Let's look at the following function:

z = f(x, y) = x * y

Here, z is the result of the function f(x, y). If we take the partial derivative of f with respect to x, we're asking the question: "How does z change as x changes, if y stays the same?" The answer is that a small change in x will cause z to change by y times that small change in x. So the partial derivative of f with respect to x is y. Similarly, the partial derivative of f with respect to y is x, because a small change in y will cause z to change by x times that small change in y.

So the derivative of f with respect to x is:

df/dx = y

and the derivative of f with respect to y is:

df/dy = x

Let's put some numbers in there to make it more concrete:

x = -2
y = 3
f(-2, 3) = -2 * 3 = -6

If we increase x by 1 and keep y the same, then:

x = -1
y = 3
f(-1, 3) = -1 * 3 = -3

In other words, for every 1 unit increase in x, the output of f increases by 3, which is the value of y.

Similarly, if we increase y by 1 and keep x the same, then:

x = -2
y = 4
f(-2, 4) = -2 * 4 = -8

In other words, for every 1 unit increase in y, the output of f decreases by 2, which is the value of x.
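The same numbers can be checked in code. This is a small sketch that nudges one input at a time while holding the other fixed, which is all a partial derivative really is. The helper names `partial_x` and `partial_y` and the step size `h` are assumptions made for this example.

```python
def f(x, y):
    """The example function from above: z = x * y."""
    return x * y


def partial_x(func, x, y, h=1e-6):
    """Approximate df/dx: nudge x, keep y fixed."""
    return (func(x + h, y) - func(x - h, y)) / (2 * h)


def partial_y(func, x, y, h=1e-6):
    """Approximate df/dy: nudge y, keep x fixed."""
    return (func(x, y + h) - func(x, y - h)) / (2 * h)


df_dx = partial_x(f, -2.0, 3.0)  # should be close to y
df_dy = partial_y(f, -2.0, 3.0)  # should be close to x

print(df_dx, df_dy)
```

At x = -2 and y = 3, the estimates come out very close to 3 and -2, which are the values of y and x respectively.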

Putting it in Context

Phew, ok, math is over now. We have all the pieces we need to understand gradient descent, so I'll just be talking conceptually from here on out.

For gradient descent, we need a function that measures how far we are from our "goal". This function is called the "loss function", and it's usually defined as some measure of the difference between the output of our system and the ground truth. The smaller the loss, the closer our output is to the ground truth, so our goal is to minimize the loss function. For example, in the case of gaussian splatting, the loss function might be the difference between the color of our rendered pixels and the pixels in our ground truth image.
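As a rough illustration of a loss function, here's a sketch that compares a handful of "rendered" pixel brightness values against ground truth values. Averaging the squared differences (mean squared error) is an assumption I'm making for this example, and the pixel values are made up. A real gaussian splatting implementation may well use a different loss.

```python
def mean_squared_error(rendered, ground_truth):
    """Average of the squared per-pixel differences (an assumed loss, for illustration)."""
    if len(rendered) != len(ground_truth):
        raise ValueError("both images must have the same number of pixels")
    total = 0.0
    for rendered_value, true_value in zip(rendered, ground_truth):
        total += (rendered_value - true_value) ** 2
    return total / len(rendered)


# Hypothetical grayscale pixel values in the range 0..1
rendered_pixels = [0.2, 0.5, 0.9]
target_pixels = [0.0, 0.5, 1.0]

loss = mean_squared_error(rendered_pixels, target_pixels)
perfect = mean_squared_error(target_pixels, target_pixels)

print(loss, perfect)
```

The loss is a single number that shrinks as the rendered pixels approach the target, and it is exactly zero when the two match.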

To do that, we'll start at some random point (i.e. we give all inputs random values), then we'll take the partial derivative of the loss function with respect to each input. This will give us the slope of the loss function at that point for each input. We can then slightly shift each input in the opposite direction of the slope, which will have the effect of moving us down the slope of the loss function. We repeat this process many times, and eventually we'll end up at a point where the loss function is at a local minimum.
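The loop described above can be sketched in a few lines. Everything here is a toy set up for illustration: the loss function (a simple bowl with its minimum at a = 3, b = -1), the learning rate of 0.1, the 100 steps, and the fixed random seed are all assumptions, not values taken from any real system.

```python
import random


def loss(a, b):
    """A made-up loss: a bowl whose lowest point is at a = 3, b = -1."""
    return (a - 3.0) ** 2 + (b + 1.0) ** 2


def loss_gradient(a, b):
    """Partial derivatives of the loss with respect to a and b."""
    return 2.0 * (a - 3.0), 2.0 * (b + 1.0)


def gradient_descent(a, b, learning_rate=0.1, steps=100):
    """Repeatedly shift each input a little in the opposite direction of its slope."""
    for _ in range(steps):
        grad_a, grad_b = loss_gradient(a, b)
        a -= learning_rate * grad_a
        b -= learning_rate * grad_b
    return a, b


# Start at a random point (seeded so the run is repeatable)
rng = random.Random(0)
start_a = rng.uniform(-5.0, 5.0)
start_b = rng.uniform(-5.0, 5.0)

final_a, final_b = gradient_descent(start_a, start_b)

print(loss(start_a, start_b), loss(final_a, final_b))
```

The learning rate controls the size of each "small step". With this particular loss, each step shrinks the distance to the minimum by a constant factor, so the result ends up very close to a = 3, b = -1 no matter where it starts.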

And that's it! That's the basic idea behind gradient descent. We...descend the gradient of the loss function. For our gaussian splatting example, we'd be adjusting the parameters of the spheres to make the rendered image look more and more like the ground truth image. We do this from a bunch of different angles, and eventually the parameters converge to get the results I showed at the start.

Wrapping Up

I know this wasn't the most exciting topic, but I wanted to make sure I had a solid understanding of the math behind gradient descent before I move on to more complex topics, like implementing gaussian splatting.

If you're interested in learning more about this topic, here are a few resources that I found helpful:

I was specifically looking at gradient descent in the context of machine learning and neural networks, since that's where it's most often used. I went through the process of building a couple of simple neural networks from scratch, which was a great way to solidify my understanding of the topic. I highly recommend those last two videos if you're interested in the nitty-gritty of how neural networks "learn".

Hopefully I'll have something prettier and less math-y to talk about next week, but I'm glad I took the time to make sure I understand this topic well.

See you next week!

Bonus: The Chain Rule

Because the chain rule isn't strictly necessary for understanding gradient descent, I moved this section to the end, but it's a pretty neat concept, so I wanted to include it.

Ok, so as explained above, we know how to take the derivative of a multi-variable function with respect to one of its inputs, but what if the function is made up of other functions? For example, let's look at the following system of equations:

f(x, y) = h(g(x, y))
h(z) = sin(z)
g(x, y) = x * y

Here, we have a function f(x, y) that's defined by two other functions h(z) and g(x, y). Now imagine we want to find out how much the output of f changes as we change x. This is where the chain rule comes in handy.

The chain rule tells us that:

The derivative of f with respect to x is the derivative of h with respect to g times the derivative of g with respect to x.

Symbolically, this is written as:

df/dx = dh/dg * dg/dx

In other words, we can break down the derivative of f into the derivatives of h and g and multiply them together.

Let's look at our example again:

f(x, y) = h(g(x, y))
h(z) = sin(z)
g(x, y) = x * y

We can take the derivative of h with respect to g:

dh/dg = cos(g)

And the derivative of g(x, y) with respect to x:

dg/dx = y

So the derivative of f(x, y) with respect to x is:

df/dx
  = dh/dg        * dg/dx
  = cos(g(x, y)) * y
  = cos(x * y)   * y

What this tells us is that if we increase x by a small amount, the output of f will change by approximately cos(x * y) * y times that amount.
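One way to convince yourself the chain rule result is right is to compare it against a brute-force numerical estimate. In this sketch, the functions f, g, and h are the ones defined above, while the helper names, the test point (x = 0.5, y = 2), and the step size are my own choices.

```python
import math


def g(x, y):
    return x * y


def h(z):
    return math.sin(z)


def f(x, y):
    return h(g(x, y))


def df_dx_chain_rule(x, y):
    """df/dx built from the two pieces: dh/dg * dg/dx."""
    dh_dg = math.cos(g(x, y))
    dg_dx = y
    return dh_dg * dg_dx


def df_dx_numeric(x, y, step=1e-6):
    """df/dx estimated by nudging x and measuring how f changes."""
    return (f(x + step, y) - f(x - step, y)) / (2 * step)


x, y = 0.5, 2.0
analytic = df_dx_chain_rule(x, y)
numeric = df_dx_numeric(x, y)

print(analytic, numeric)
```

The two values should agree to several decimal places, which is a handy sanity check any time you derive a gradient by hand.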

This is a very powerful concept. It's used all over the place in machine learning and other fields, where we have a set of inputs and then a set of functions that transform those inputs into a final output. The chain rule allows us to walk through our system of functions backwards, finding the derivative of the final output with respect to each input by chaining the derivatives of each function from the overall system together. That means that we only ever need to think about the derivative of a single function at a time. We can even automate the process of finding the derivative of a complex system of functions, as long as we can find the derivative of each individual function in that system.

This is exactly what the Micrograd library I mentioned in the "wrapping up" section does, as well as how PyTorch's more robust automatic differentiation system, Autograd, works.
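To show what "walking backwards through the system" might look like, here's a toy sketch of the idea. To be clear, this is not Micrograd's or PyTorch's actual API. The `Value` class and everything in it are hypothetical names I made up, and it only supports multiplication and sine, which is just enough for the example above.

```python
import math


class Value:
    """A number that remembers which values it was computed from (toy example)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0                 # d(final output)/d(this value), filled in by backward()
        self.parents = parents          # the Values this one was computed from
        self.local_grads = local_grads  # d(this value)/d(parent), one per parent

    def __mul__(self, other):
        # d(a*b)/da = b and d(a*b)/db = a
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def sin(self):
        # d(sin(a))/da = cos(a)
        return Value(math.sin(self.data), (self,), (math.cos(self.data),))

    def backward(self):
        # Build an ordering in which every value appears after its parents
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)

        # Walk from the output back to the inputs, applying the chain rule at each step
        self.grad = 1.0
        for node in reversed(order):
            for parent, local_grad in zip(node.parents, node.local_grads):
                parent.grad += local_grad * node.grad


x = Value(0.5)
y = Value(2.0)
out = (x * y).sin()  # f(x, y) = sin(x * y)
out.backward()

print(x.grad, y.grad)
```

Each operation only needs to know its own local derivative. The `backward` method handles the chaining, so `x.grad` ends up equal to cos(x * y) * y without anyone deriving that expression by hand.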