Multinomial Logistic Regression: Python Example¶
In this example, we will:

- Fit a multinomial logistic regression model to predict which digit (0 to 9) an image represents.
- Adjust hyperparameters to optimize the model's performance on an approximate leave-one-out cross-validation of the training data set.
- Use approximate leave-one-out cross-validation to identify outliers in the training data set.
Note
We assume the reader already has a basic understanding of logistic regression and approximate leave-one-out cross-validation. If not, see Logistic Regression: The Definitive Guide.
Taking a Look at the Digits Data Set¶
Let’s begin with a quick analysis of the data set.
The digits data set comes packaged with sklearn. To load it, we can run
from sklearn.datasets import load_digits
X, y = load_digits(return_X_y = True)
The data set contains 1797 entries. Each entry consists of the pixel values of an 8x8 hand-written image of a digit along with a label identifying which number it's supposed to represent (0 through 9).
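As a quick sanity check of that structure, we can print the shapes of the feature matrix and label vector.

print(X.shape)  # (1797, 64): one row of 64 pixel values per 8x8 image
print(y.shape)  # (1797,): one label per image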
Let’s count the number of instances of each digit and compute the distribution.
from collections import defaultdict

counts = defaultdict(int)
for yi in y:
    counts[yi] += 1

print('digit', '\t', 'percentage')
for i in range(10):
    percentage = counts[i] / len(y) * 100.0
    print(i, '\t', '%.2f' % percentage)
This gives the table
Digit | Percentage
------|-----------
0 | 9.91
1 | 10.13
2 | 9.85
3 | 10.18
4 | 10.07
5 | 10.13
6 | 10.07
7 | 9.96
8 | 9.68
9 | 10.02
This shows that the digits are close to evenly distributed, with each class making up roughly 10% of the data set.
Next, let’s use matplotlib to take a look at some example entries. We’ll randomly select 25 of the images and display the digits in a grid.
import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 6))
columns = 5
rows = 5
np.random.seed(0)
indexes = np.random.randint(len(y), size=columns*rows)
for i, index in enumerate(indexes):
    ax = fig.add_subplot(rows, columns, i+1)
    ax.matshow(X[index].reshape(8, 8))
    ax.set_title(str(y[index]))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
fig.tight_layout()
This displays
We can see how pixelated and difficult to distinguish some of the digits are.
Now let’s turn to building a model.
Fitting the Model¶
We start by normalizing the feature matrix.
from sklearn.preprocessing import StandardScaler
X_p = StandardScaler().fit_transform(X)
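As a quick check (a sanity test only, not a required step), we can verify that the scaled columns behave as expected.

# Sanity check: after scaling, each pixel column has mean ~0; non-constant
# columns have standard deviation ~1 (constant columns are left at zero).
print(np.allclose(X_p.mean(axis=0), 0))
print(X_p.std(axis=0).max())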
We’ll then use bbai to fit a logistic regression model with hyperparameters adjusted so as to perform best on an approximate leave-one-out cross-validation of the training data set.
Note
To install bbai, see Getting Started.
from bbai.glm import LogisticRegression
model = LogisticRegression()
model.fit(X_p, y)
Note
There’s no need to specify a regularization parameter with bbai.glm.LogisticRegression. It uses second-order optimization to find the value of C that gives the best approximate leave-one-out cross-validation score.
Running
print('C_opt =', model.C_)
print('ALO_opt =', model.aloocv_)
Prints
C_opt = 1.0605577384507674
ALO_opt = 0.09934658483240931
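For comparison, the conventional way to pick C is a grid search with k-fold cross-validation, for example with scikit-learn's LogisticRegressionCV. The sketch below is only illustrative; the grid and fold count are arbitrary choices, not something bbai requires.

# A conventional alternative: k-fold grid search over C with scikit-learn.
# Note this searches a fixed grid of candidate values, whereas bbai optimizes
# C directly against the approximate leave-one-out error.
from sklearn.linear_model import LogisticRegressionCV
sk_model = LogisticRegressionCV(Cs=np.logspace(-2, 2, 20), cv=5, max_iter=5000)
sk_model.fit(X_p, y)
print(sk_model.C_)  # array of the selected C value(s), one per class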
We can plot out the approximate leave-one-out cross-validation to see where the optimal value lies along the curve.
def compute_alo(C):
    model_p = LogisticRegression(C=C)
    model_p.fit(X_p, y)
    return model_p.aloocv_

Cs = np.arange(0.5, 3, 0.1)
alos = [compute_alo(C) for C in Cs]
plt.plot(Cs, alos)
plt.title('Approximate Leave-one-out Cross-validation')
plt.ylabel('Negative Log-likelihood')
plt.xlabel('C')
plt.show()
Displays
The fitted logistic regression model is specified by an intercept vector of 10 values, one for each class, and a 10x64 weight matrix. Each row of the matrix contains 64 weights, one for each of the 8x8 pixels in the feature image.
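We can confirm those dimensions from the fitted model. The coef_ attribute is used in the plotting code below; intercept_ is assumed here to follow the usual scikit-learn naming convention.

print(model.coef_.shape)       # (10, 64): one row of pixel weights per class
print(model.intercept_.shape)  # (10,): per-class intercepts (assumed attribute name)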
Let’s visualize the model weights by plotting out the weights in a heatmap.
import matplotlib.pyplot as plt
import seaborn as sns

fig = plt.figure(figsize=(20, 15))
columns = 4
rows = 3
for i in range(10):
    ax = fig.add_subplot(rows, columns, i+1)
    sns.heatmap(model.coef_[i].reshape(8, 8), annot=True, fmt='.2f', cbar=False, ax=ax)
    ax.set_title(str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
fig.tight_layout()
Displays
Looking at Predictions¶
To see how our model performs, let’s look at some in-sample predictions. We’ll select a few random images and display the predicted probabilities for each digit.
# We sample 5 images; the grid has 2 rows (top: the image, bottom: the
# predicted probability for each digit).
n_cols = 5
n_rows = 2
np.random.seed(1)
indexes = np.random.randint(len(y), size=n_cols)
preds = model.predict_proba(np.array([X_p[index] for index in indexes]))
gridspec = {
    'height_ratios': [6, 1],
}
fig, (ax_imgs, ax_preds) = \
    plt.subplots(n_rows, n_cols, gridspec_kw=gridspec, figsize=(30, 7))
for i, index in enumerate(indexes):
    ax_img = ax_imgs[i]
    ax_pred = ax_preds[i]
    ax_img.matshow(X[index].reshape(8, 8))
    sns.heatmap([preds[i]], annot=True, fmt='.2f',
                cbar=False, ax=ax_pred, square=True)
    ax_img.set_title(str(y[index]))
    ax_img.get_xaxis().set_visible(False)
    ax_img.get_yaxis().set_visible(False)
    ax_pred.set_title('Predicted Probabilities')
    ax_pred.get_yaxis().set_visible(False)
    ax_pred.set_xticks([j + 0.5 for j in range(10)])
    ax_pred.set_xticklabels(list(range(10)))
Displays
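We can also quantify the in-sample fit across the whole training set. This is only a rough sanity check, not a measure of generalization, and it relies only on predict_proba, which we have already used above.

# In-sample accuracy: fraction of training images whose most probable
# predicted digit matches the true label. A sanity check only.
in_sample_labels = np.argmax(model.predict_proba(X_p), axis=1)
print('in-sample accuracy =', np.mean(in_sample_labels == y))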
These in-sample predictions look good; but of course, we judge a model by how well it predicts out-of-sample data, not in-sample data. To get an idea of how the model performs on out-of-sample data, we plot the approximate leave-one-out errors (measured as negative log-likelihood), sorted from largest to smallest.
alos = list(enumerate(model.aloocvs_))
alos = sorted(alos, key=lambda element: -element[1])
plt.scatter(list(range(len(alos))), [alo for _, alo in alos])
plt.title('Approximate Leave-one-out Errors')
plt.ylabel('Negative Log-likelihood')
Displays
Let’s take a look at the 10 images with the largest leave-one-out errors. These may be outliers.
fig = plt.figure(figsize=(6, 3))
columns = 5
rows = 2
indexes = [index for index, _ in alos[:10]]
for i, index in enumerate(indexes):
    ax = fig.add_subplot(rows, columns, i+1)
    ax.matshow(X[index].reshape(8, 8))
    ax.set_title(str(y[index]))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
fig.tight_layout()
Displays
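To dig into why these entries have large errors, we can check the probability the model assigns to each one's labeled digit; a low value suggests the image is ambiguous or possibly mislabeled. Again, only predict_proba is assumed here.

# For each of the 10 largest-error entries, print its index, true label, and
# the in-sample probability the model assigns to that label.
outlier_preds = model.predict_proba(np.array([X_p[index] for index in indexes]))
for i, index in enumerate(indexes):
    label = y[index]
    print(index, label, '%.3f' % outlier_preds[i][label])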
For the full source code, see github.com/rnburn/bbai/example/01-digits.py.