Welcome back! In my previous posts on machine learning, I talked about univariate linear regression, in which a continuous “target” variable is predicted from a single input feature. More often, though, you will want to predict the target variable from several features. This is called multivariate linear regression.
The idea
In an earlier post, I used the example of predicting housing prices from the size of the house in square feet. In practice, though, you would probably take more features of the house into account, as shown below:
Here, the variables x₁ to x₄ represent four different features of the house. As before, the variable y denotes the target variable, in this case the price. Each row of the table (excluding the last column) represents one training example. To be clear about the notation I will use:
- n = number of input features (in this case 4)
- m = number of training examples (in this case, the number of rows in the table)
- x⁽ⁱ⁾ = features of the i-th training example, as a column vector (in this case, the blue box contains the elements of x⁽²⁾)
- x⁽ⁱ⁾ⱼ = value of feature j in the i-th training example
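To make the notation concrete, here is a minimal sketch that stores a training set as plain Python lists. The feature values and prices are made up for illustration, and the helper names `example` and `feature` are my own. Note that Python counts from 0, while the notation above counts examples and features from 1.

```python
# Hypothetical training set (made-up numbers): each inner list holds the
# n = 4 features of one house, i.e. one row of the table without the price.
X = [
    [1200, 3, 1, 20],
    [850, 2, 1, 35],
    [2000, 4, 2, 10],
    [1500, 3, 2, 15],
]
# Target values (prices), one per training example, also made up.
y = [250, 180, 410, 300]

m = len(X)     # number of training examples
n = len(X[0])  # number of input features

def example(i):
    """Return x^(i), the features of the i-th example (i counts from 1)."""
    return X[i - 1]

def feature(i, j):
    """Return x^(i)_j, the value of feature j in the i-th example (1-based)."""
    return X[i - 1][j - 1]

print(m, n)           # 4 4
print(example(2))     # [850, 2, 1, 35]
print(feature(2, 1))  # 850
```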
Hypothesis function
Recall that the hypothesis function for univariate linear regression was:
\begin{aligned} h_\theta(x) = \theta_0 ~+~ \theta_1x \end{aligned}
But now that we have multiple features, there must be a corresponding coefficient in the hypothesis function for each of them. And as usual, the hypothesis function will be a linear function of the input features. So for multivariate linear regression, the hypothesis function will be:
\begin{aligned} h_\theta(x) = \theta_0 ~+~ \theta_1x_1 ~+~ \theta_2x_2 ~+~ ... ~+~ \theta_nx_n \end{aligned}
As usual, all the θ’s are the parameters of the model. To simplify the notation, let us introduce a variable x₀ that equals 1 for every training example, along with two column vectors:
\begin{aligned} x = \begin{bmatrix} x_0 \\ x_1 \\ x_2 \\ \vdots \\ x_n \end{bmatrix} \in \mathbb{R}^{n+1},~~ \theta = \begin{bmatrix} \theta_0 \\ \theta_1 \\ \theta_2 \\ \vdots \\ \theta_n \end{bmatrix} \in \mathbb{R}^{n+1} \end{aligned}
Then the hypothesis function can be written as:
\begin{aligned} h_\theta(x) &= \theta_0x_0 ~+~ \theta_1x_1 ~+~ \theta_2x_2 ~+~ ... ~+~ \theta_nx_n \\ &= \theta^Tx \end{aligned}
where the superscript T denotes the transpose of a vector/matrix.
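In code, the hypothesis is just a dot product once x₀ = 1 has been prepended. This is a minimal sketch using plain Python lists; the helper names `add_bias` and `hypothesis` are my own, not from any library.

```python
def add_bias(features):
    """Prepend x_0 = 1 to a feature vector, giving a vector in R^(n+1)."""
    return [1.0] + list(features)

def hypothesis(theta, x):
    """h_theta(x) = theta^T x, where x already includes x_0 = 1."""
    if len(theta) != len(x):
        raise ValueError("theta and x must both have n + 1 entries")
    return sum(t * xj for t, xj in zip(theta, x))

# Example with n = 2 features: theta_0 = 1, theta_1 = 2, theta_2 = 3.
theta = [1.0, 2.0, 3.0]
x = add_bias([4.0, 5.0])     # [1.0, 4.0, 5.0]
print(hypothesis(theta, x))  # 1*1 + 2*4 + 3*5 = 24.0
```

Prepending x₀ = 1 is what lets θ₀ ride along in the same dot product as every other parameter, so the intercept needs no special case.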
As in the single-variable case, the cost function (the function the algorithm tries to minimize) can now be written as:
\begin{aligned} J(\theta) &= \frac{1}{2m} \sum_{i=1}^m \left(h_\theta(x^{(i)}) - y^{(i)}\right)^2 \\ \text{where}~h_\theta(x) &= \theta^Tx \end{aligned}
except that now the cost function J is a function of the parameter vector θ (instead of θ₀ and θ₁).
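The cost function translates almost symbol for symbol. Here is a minimal sketch, assuming each row of `X` already includes x₀ = 1; the function names and the tiny data set (generated from y = 1 + 2x₁ + 3x₂ so that the exact answer is known) are my own.

```python
def hypothesis(theta, x):
    """h_theta(x) = theta^T x; x must already include x_0 = 1."""
    return sum(t * xj for t, xj in zip(theta, x))

def cost(theta, X, y):
    """J(theta) = 1/(2m) * sum over i of (h_theta(x^(i)) - y^(i))^2.

    Each row of X is one training example with x_0 = 1 already prepended.
    """
    m = len(X)
    squared_errors = ((hypothesis(theta, xi) - yi) ** 2 for xi, yi in zip(X, y))
    return sum(squared_errors) / (2 * m)

# Made-up data: m = 4 examples, n = 2 features, generated from y = 1 + 2*x_1 + 3*x_2.
X = [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
y = [1.0, 3.0, 4.0, 6.0]

print(cost([1.0, 2.0, 3.0], X, y))  # 0.0: these parameters fit the data exactly
print(cost([0.0, 0.0, 0.0], X, y))  # (1 + 9 + 16 + 36) / 8 = 7.75
```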
Gradient descent
Just like univariate linear regression, multivariate linear regression can use gradient descent to minimize the cost function. Since there are now n + 1 parameters to find instead of two, the gradient descent algorithm reads:
\begin{aligned} &\text{repeat until convergence } \{ \\ &\qquad \theta_j := \theta_j - \alpha \frac{\partial}{\partial\theta_j} J(\theta) \\ &\qquad (j = 0, 1, 2, \dots, n) \\ &\qquad (\text{simultaneously update all } \theta_j) \\ &\} \end{aligned}
where α is the learning rate and the other symbols have their usual meanings. (If you haven’t read about gradient descent yet, read my earlier post on it first.)
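The “simultaneously update all θⱼ” line matters: every new θⱼ must be computed from the old θ, not from a partially updated one. Here is a minimal, generic sketch, assuming the gradient is supplied as a function. The names `gradient_descent`, `grad`, `tol`, and `max_iters` are my own, and “until convergence” is approximated by stopping once no parameter changes by more than `tol`.

```python
def gradient_descent(grad, theta, alpha, tol=1e-9, max_iters=10_000):
    """Minimize a function given its gradient, starting from theta.

    grad(theta) must return the list of partial derivatives dJ/dtheta_j.
    """
    theta = list(theta)
    for _ in range(max_iters):
        g = grad(theta)  # all partial derivatives, evaluated at the OLD theta
        # Build the whole new vector before replacing theta (simultaneous update).
        new_theta = [t - alpha * gj for t, gj in zip(theta, g)]
        converged = max(abs(a - b) for a, b in zip(new_theta, theta)) < tol
        theta = new_theta
        if converged:
            break
    return theta

# Example: J(theta) = ((theta_0 - 3)^2 + (theta_1 + 1)^2) / 2 has its minimum
# at (3, -1), and its partial derivatives are (theta_0 - 3, theta_1 + 1).
result = gradient_descent(lambda th: [th[0] - 3, th[1] + 1], [0.0, 0.0], alpha=0.1)
print([round(v, 6) for v in result])  # [3.0, -1.0]
```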
For the given cost function, the partial derivatives can be shown to be:
\begin{aligned} \frac{\partial}{\partial\theta_j}J(\theta) &= \frac{1}{m} \sum_{i=1}^m \left(h_\theta(x^{(i)}) - y^{(i)}\right) x_j^{(i)} \\ (j &= 0, 1, 2, \dots, n) \end{aligned}
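For reference, here is one way to fill in the calculus, using the chain rule and the fact that h_θ(x⁽ⁱ⁾) = Σₖ θₖx⁽ⁱ⁾ₖ with x⁽ⁱ⁾₀ = 1. The factor 2 from differentiating the square cancels the ½ in the cost, and only the k = j term of the sum depends on θⱼ:

\begin{aligned} \frac{\partial}{\partial\theta_j} J(\theta) &= \frac{1}{2m} \sum_{i=1}^m 2\left(h_\theta(x^{(i)}) - y^{(i)}\right) \frac{\partial}{\partial\theta_j} h_\theta(x^{(i)}) \\ &= \frac{1}{m} \sum_{i=1}^m \left(h_\theta(x^{(i)}) - y^{(i)}\right) \frac{\partial}{\partial\theta_j} \sum_{k=0}^{n} \theta_k x_k^{(i)} \\ &= \frac{1}{m} \sum_{i=1}^m \left(h_\theta(x^{(i)}) - y^{(i)}\right) x_j^{(i)} \end{aligned}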
(If you know calculus, you can verify this yourself.) Keep in mind that ‘i’ here is the training example number, ‘j’ is the feature number, and x⁽ⁱ⁾₀ = 1 for all ‘i’. So now the gradient descent algorithm reads:
\begin{aligned} &\text{repeat until convergence } \{ \\ &\qquad \theta_j := \theta_j - \frac{\alpha}{m} \sum_{i=1}^m \left(h_\theta(x^{(i)}) - y^{(i)}\right) x_j^{(i)} \\ &\qquad (j = 0, 1, 2, \dots, n) \\ &\qquad (\text{simultaneously update all } \theta_j) \\ &\} \end{aligned}
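Putting the pieces together, here is a minimal sketch of batch gradient descent for multivariate linear regression in plain Python, so that each sum mirrors the update rule above. The function name `fit_linear_regression`, the learning rate α = 0.1, and the iteration count are my own assumptions, chosen for this tiny, made-up, noise-free data set; a fixed number of iterations stands in for “repeat until convergence” to keep it short.

```python
def hypothesis(theta, x):
    """h_theta(x) = theta^T x; x must already include x_0 = 1."""
    return sum(t * xj for t, xj in zip(theta, x))

def fit_linear_regression(X, y, alpha=0.1, num_iters=5000):
    """Batch gradient descent for multivariate linear regression.

    X holds m feature vectors WITHOUT x_0; y holds the m targets.
    Returns [theta_0, theta_1, ..., theta_n].
    """
    X = [[1.0] + list(xi) for xi in X]  # prepend x_0 = 1 to every example
    m, num_params = len(X), len(X[0])   # num_params = n + 1
    theta = [0.0] * num_params
    for _ in range(num_iters):
        # Prediction errors h_theta(x^(i)) - y^(i), all at the current theta.
        errors = [hypothesis(theta, xi) - yi for xi, yi in zip(X, y)]
        # Every theta_j is updated from the same old errors: a simultaneous update.
        theta = [
            theta[j] - alpha / m * sum(e * xi[j] for e, xi in zip(errors, X))
            for j in range(num_params)
        ]
    return theta

# Made-up, noise-free data generated from y = 1 + 2*x_1 + 3*x_2.
X = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0], [1.0, 2.0]]
y = [1 + 2 * x1 + 3 * x2 for x1, x2 in X]

theta = fit_linear_regression(X, y)
print([round(t, 4) for t in theta])  # [1.0, 2.0, 3.0]
```

Because every θⱼ in the list comprehension is computed from the same `errors`, which were evaluated at the old θ, the update is simultaneous as the algorithm requires. On real data, α and the number of iterations usually need tuning.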
With this algorithm, you can now solve a multivariate linear regression problem. That wraps up the topic of linear regression; please do stay tuned for my future posts!