Multinomial Logistic Regression: Python Example

In this example, we will

  • Fit a multinomial logistic regression model to predict which digit (0 to 9) an image represents.

  • Tune the model’s hyperparameters so that it performs best on an approximate leave-one-out cross-validation of the training data set.

  • Use approximate leave-one-out cross-validation to identify outliers in the training data set.

Note

We assume the reader already has a basic understanding of logistic regression and approximate leave-one-out cross-validation. If not, see Logistic Regression: The Definitive Guide.

Taking a Look at the Digits Data Set

Let’s begin with a quick analysis of the data set.

The digits data set comes packaged with sklearn. To load it, we can run

from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)

The data set contains 1797 entries. Each entry consists of the 64 pixel values (integers from 0 to 16) of an 8x8 handwritten image of a digit, along with a label identifying which digit (0 through 9) it is supposed to represent.

Let’s count the number of instances of each digit and compute the distribution.

from collections import defaultdict

# Count how many entries carry each label
counts = defaultdict(int)
for yi in y:
    counts[yi] += 1

# Print each digit's share of the data set
print('digit', '\t', 'percentage')
for i in range(10):
    percentage = counts[i] / len(y) * 100.0
    print(i, '\t', '%.2f' % percentage)

This prints the following table:

Digit   Percentage
0       9.91
1       10.13
2       9.85
3       10.18
4       10.07
5       10.13
6       10.07
7       9.96
8       9.68
9       10.02

This shows that the data set is close to balanced: each digit makes up between 9.68% and 10.18% of the entries.
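The counting loop can also be written with NumPy’s `np.bincount`, and the near-uniform balance gives a useful reference point for the scores that follow. The sketch below (the helper name `baseline_scores` is hypothetical, not part of bbai or sklearn) computes the per-digit percentages plus the scores of two trivial classifiers: one that always predicts the most common digit, and one that assigns probability 1/10 to every digit. The per-digit counts are recovered from the percentages in the table (1797 entries in total).

```python
import numpy as np

def baseline_scores(labels, n_classes=10):
    """Per-class percentages and the scores of two trivial classifiers."""
    # bincount tallies each label; minlength keeps a slot for absent classes
    counts = np.bincount(np.asarray(labels), minlength=n_classes)
    percentages = counts / counts.sum() * 100.0
    # Always predicting the most common class is right counts.max() times
    majority_accuracy = counts.max() / counts.sum()
    # Predicting 1/n_classes for every class costs -log(1/n_classes) per entry
    uniform_nll = np.log(n_classes)
    return percentages, majority_accuracy, uniform_nll

# Labels with the same per-digit counts as the digits data set
digit_counts = [178, 182, 177, 183, 181, 182, 181, 179, 174, 180]
labels = np.repeat(np.arange(10), digit_counts)
percentages, acc, nll = baseline_scores(labels)
print(np.round(percentages, 2))
print('majority-class accuracy = %.4f' % acc)
print('uniform negative log-likelihood = %.4f' % nll)
```

With the `y` loaded above, `baseline_scores(y)` should give the same numbers directly.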

Next, let’s use matplotlib to take a look at some example entries. We’ll randomly select 25 of the images and display the digits in a grid.

import numpy as np
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(6, 6))
columns = 5
rows = 5
np.random.seed(0)
indexes = np.random.randint(len(y), size=columns*rows)
for i, index in enumerate(indexes):
    ax = fig.add_subplot(rows, columns, i+1)
    ax.matshow(X[index].reshape(8, 8))
    ax.set_title(str(y[index]))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
fig.tight_layout()

This displays

[Figure: 5x5 grid of randomly selected digit images, each titled with its label]

Some of the digits are so pixelated that they are hard to tell apart even by eye.

Now let’s turn to building a model.

Fitting the Model

We start by standardizing the feature matrix so that each pixel column has zero mean and unit variance.

from sklearn.preprocessing import StandardScaler

X_p = StandardScaler().fit_transform(X)
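To show what that call does, here is a NumPy sketch of the same transformation (the name `standardize` is hypothetical). Like `StandardScaler`, it uses the population standard deviation and leaves a constant column, such as a pixel that is blank in every image, at zero instead of dividing by zero.

```python
import numpy as np

def standardize(X):
    """Center each column and scale it to unit variance (population std, ddof=0)."""
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # Constant columns have std 0; use 1 so they simply become all zeros
    std[std == 0.0] = 1.0
    return (X - mean) / std

# Toy matrix: the first column is constant, the others vary
X_toy = np.array([[0.0, 1.0, 4.0],
                  [0.0, 3.0, 8.0]])
print(standardize(X_toy))
```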

We’ll then use bbai to fit a logistic regression model whose hyperparameters are chosen to perform best on an approximate leave-one-out cross-validation of the training data set.

Note

To install bbai, see Getting Started.

from bbai.glm import LogisticRegression

model = LogisticRegression()
model.fit(X_p, y)

Note

There’s no need to specify a regularization parameter with bbai.glm.LogisticRegression. It uses second-order optimization to find the value of C that gives the best (lowest) approximate leave-one-out cross-validation error.

Running

print('C_opt =', model.C_)
print('ALO_opt =', model.aloocv_)

Prints

C_opt = 1.0605577384507674
ALO_opt = 0.09934658483240931
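Approximate leave-one-out cross-validation avoids refitting the model once per entry. The NumPy sketch below illustrates the idea under simplifying assumptions: a *binary* logistic regression with no intercept and the objective C·Σ loss + ½‖β‖² (C as an inverse regularization strength, in the style of sklearn). It is not bbai’s implementation, which handles the multinomial case, and the helper names (`fit_logistic`, `approximate_loo`, `exact_loo`) are hypothetical. From the full-data fit, ALO takes one Newton step toward the fit with entry i removed, using the Sherman-Morrison formula so that a single Hessian serves every entry. The sketch compares it with brute-force leave-one-out on synthetic data.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def logistic_nll(eta, y):
    """Per-entry negative log-likelihood log(1 + e^eta) - y * eta, with y in {0, 1}."""
    return np.logaddexp(0.0, eta) - y * eta

def fit_logistic(X, y, C, iters=30):
    """Minimize C * sum(loss) + 0.5 * ||beta||^2 by Newton's method (binary, no intercept)."""
    beta = np.zeros(X.shape[1])
    for _ in range(iters):
        p = sigmoid(X @ beta)
        grad = C * X.T @ (p - y) + beta
        hess = C * (X.T * (p * (1.0 - p))) @ X + np.eye(X.shape[1])
        beta -= np.linalg.solve(hess, grad)
    return beta

def approximate_loo(X, y, C):
    """Mean ALO negative log-likelihood computed from a single full-data fit."""
    beta = fit_logistic(X, y, C)
    eta = X @ beta
    p = sigmoid(eta)
    d1 = p - y            # first derivative of each loss with respect to eta
    d2 = p * (1.0 - p)    # second derivative of each loss with respect to eta
    hess = C * (X.T * d2) @ X + np.eye(X.shape[1])
    # h_i = x_i^T H^{-1} x_i for every entry, using the full-data Hessian H
    h = np.einsum('ij,ji->i', X, np.linalg.solve(hess, X.T))
    # One Newton step toward the fit without entry i; Sherman-Morrison turns
    # the downdated Hessian into the scalar factor 1 / (1 - C * d2 * h)
    eta_loo = eta + C * d1 * h / (1.0 - C * d2 * h)
    return logistic_nll(eta_loo, y).mean()

def exact_loo(X, y, C):
    """Mean leave-one-out negative log-likelihood by refitting once per entry."""
    losses = []
    for i in range(len(y)):
        keep = np.arange(len(y)) != i
        beta = fit_logistic(X[keep], y[keep], C)
        losses.append(logistic_nll(X[i] @ beta, y[i]))
    return np.mean(losses)

# Synthetic binary data (not the digits data set)
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(100, 3))
y_toy = (rng.random(100) < sigmoid(X_toy @ np.array([1.5, -2.0, 0.5]))).astype(float)
print('ALO =', approximate_loo(X_toy, y_toy, C=1.0))
print('LOO =', exact_loo(X_toy, y_toy, C=1.0))
```

The two estimates should agree closely here, but `exact_loo` needs one fit per entry while `approximate_loo` needs a single fit and one linear solve. That cost difference is what makes it practical to evaluate, and optimize, the ALO error for many candidate values of C.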

We can plot the approximate leave-one-out cross-validation error over a range of C values to see where the optimum lies along the curve.

def compute_alo(C):
    # Fit with a fixed C and return its approximate leave-one-out error
    model_p = LogisticRegression(C=C)
    model_p.fit(X_p, y)
    return model_p.aloocv_

# Evaluate the error on a grid of C values between 0.5 and 3
Cs = np.arange(0.5, 3, 0.1)
alos = [compute_alo(C) for C in Cs]
plt.plot(Cs, alos)
plt.title('Approximate Leave-one-out Cross-validation')
plt.ylabel('Negative Log-likelihood')
plt.xlabel('C')
plt.show()

Displays

[Figure: approximate leave-one-out cross-validation (negative log-likelihood) plotted against C]
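A grid with a step of 0.1 can only locate the minimum to within about that step. As a rough illustration of how curvature information sharpens the estimate, the sketch below fits a parabola through the lowest grid point and its two neighbors and returns the parabola’s vertex. This is a hypothetical refinement (the name `refine_minimum` is illustrative), not how bbai optimizes C, and the toy curve merely stands in for the ALO values.

```python
import numpy as np

def refine_minimum(xs, fs):
    """Vertex of the parabola through the lowest grid point and its two neighbors."""
    xs = np.asarray(xs, dtype=float)
    fs = np.asarray(fs, dtype=float)
    k = int(np.argmin(fs))
    k = min(max(k, 1), len(xs) - 2)   # keep one neighbor on each side
    x3 = xs[k - 1:k + 2]
    f3 = fs[k - 1:k + 2]
    # Fit f(x) = a x^2 + b x + c exactly through the three points
    a, b, _ = np.polyfit(x3, f3, 2)
    if a <= 0:                        # not locally convex: fall back to the grid point
        return xs[int(np.argmin(fs))]
    return -b / (2.0 * a)

grid = np.arange(0.5, 3, 0.1)
values = (grid - 1.06) ** 2 + 0.1     # toy curve with its minimum at 1.06
print(refine_minimum(grid, values))
```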

The fitted logistic regression model is specified by an intercept vector with 10 values, one for each class, and a 10x64 weight matrix. Each row of the matrix holds one class’s weights for the 64 pixels (8x8) of the input image.
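These parameters turn a 64-pixel image into class probabilities through the softmax function. The score for digit k is the dot product of row k of the weight matrix with the standardized pixel vector, plus the k-th intercept. The scores are then exponentiated and normalized to sum to 1. The minimal NumPy sketch below assumes bbai’s `predict_proba` returns these standard multinomial probabilities; we have not confirmed its internals, and the function name is hypothetical. The weights are random placeholders shaped like `model.coef_` (10x64), not the fitted values.

```python
import numpy as np

def softmax_proba(X, W, b):
    """Multinomial logistic probabilities for each row of X.

    X: (n, 64) standardized pixels, W: (10, 64) weights, b: (10,) intercepts.
    """
    scores = X @ W.T + b                          # (n, 10) class scores
    scores -= scores.max(axis=1, keepdims=True)   # shift for numerical stability
    expd = np.exp(scores)
    return expd / expd.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
W = rng.normal(size=(10, 64))   # placeholder weights, not the fitted model
b = rng.normal(size=10)         # placeholder intercepts
X_demo = rng.normal(size=(3, 64))
proba = softmax_proba(X_demo, W, b)
print(proba.sum(axis=1))        # each row sums to 1
print(proba.argmax(axis=1))     # predicted digit for each row
```

Subtracting the row maximum does not change the result, because softmax is unchanged when the same constant is added to every score. It only keeps `np.exp` from overflowing.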

Let’s visualize the model weights by plotting out the weights in a heatmap.

import matplotlib.pyplot as plt
import seaborn as sns

fig = plt.figure(figsize=(20, 15))
# Arrange the 10 weight maps in a 3x4 grid (the last two cells stay empty)
columns = 4
rows = 3
for i in range(10):
    ax = fig.add_subplot(rows, columns, i+1)
    sns.heatmap(model.coef_[i].reshape(8, 8), annot=True, fmt='.2f', cbar=False, ax=ax)
    ax.set_title(str(i))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
fig.tight_layout()

Displays

[Figure: heatmaps of the fitted weights, one 8x8 panel for each digit 0 through 9]

Looking at Predictions

To see how our model performs, let’s look at some in-sample predictions. We’ll select five random images and display the predicted probability of each digit.

rows = 2     # top row: images; bottom row: predicted probabilities
columns = 5  # number of randomly selected images
np.random.seed(1)
indexes = np.random.randint(len(y), size=columns)
preds = model.predict_proba(X_p[indexes])
gridspec = {
    'height_ratios' : [6,  1],
}
fig, (ax_imgs, ax_preds) = \
    plt.subplots(rows, columns, gridspec_kw=gridspec, figsize=(30, 7))
for i, index in enumerate(indexes):
    ax_img = ax_imgs[i]
    ax_pred = ax_preds[i]
    # Show the image, then its predicted probabilities as a one-row heatmap
    ax_img.matshow(X[index].reshape(8, 8))
    sns.heatmap([preds[i]], annot=True, fmt='.2f',
                cbar=False, ax=ax_pred, square=True)
    ax_img.set_title(str(y[index]))
    ax_img.get_xaxis().set_visible(False)
    ax_img.get_yaxis().set_visible(False)
    ax_pred.set_title('Predicted Probabilities')
    ax_pred.get_yaxis().set_visible(False)
    ax_pred.set_xticks([j + 0.5  for j in range(10)])
    ax_pred.set_xticklabels(list(range(10)))

Displays

[Figure: five sample images, each above a heatmap of the predicted probabilities for digits 0 through 9]

Those predictions look good, but we judge a model by how well it predicts out-of-sample data, not in-sample data. To get an idea of how the model performs on out-of-sample data, we plot the approximate leave-one-out errors (measured as negative log-likelihood), sorted from largest to smallest.

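# Pair each training entry's index with its approximate leave-one-out error,
# then sort from largest to smallest error
alos = list(enumerate(model.aloocvs_))
alos = sorted(alos, key=lambda element: -element[1])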
plt.scatter(list(range(len(alos))), [alo for _, alo in alos])
plt.title('Approximate Leave-one-out Errors')
plt.ylabel('Negative Log-likelihood')

Displays

[Figure: approximate leave-one-out errors (negative log-likelihood), sorted from largest to smallest]
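Each point in that plot is a negative log-likelihood. For a single entry i with true label y_i and predicted probability vector p_i, that is -log p_i[y_i]. For the approximate leave-one-out errors, p_i is the model’s approximate prediction with entry i held out. The sketch below computes the per-entry values and ranks them from largest to smallest with `np.argsort`, the same ordering used for the plot and the outlier selection. The helper name `per_sample_nll` is hypothetical, and the probabilities are made up for illustration.

```python
import numpy as np

def per_sample_nll(proba, labels):
    """Negative log-likelihood of each entry's true label: -log proba[i, labels[i]]."""
    proba = np.asarray(proba)
    labels = np.asarray(labels)
    return -np.log(proba[np.arange(len(labels)), labels])

# Made-up probabilities for three entries over three classes
proba = np.array([[0.90, 0.05, 0.05],
                  [0.20, 0.70, 0.10],
                  [0.60, 0.30, 0.10]])
labels = np.array([0, 1, 2])
errors = per_sample_nll(proba, labels)
# Indices ordered from largest to smallest error
worst_first = np.argsort(-errors)
print(errors)
print(worst_first)
```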

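Let’s take a look at the 10 images with the largest approximate leave-one-out errors. These are the entries the model has the hardest time predicting when they are held out, so they are candidate outliers.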

fig = plt.figure(figsize=(6, 3))
columns = 5
rows = 2
# Indices of the 10 entries with the largest errors
indexes = [index for index, _ in alos[:10]]
for i, index in enumerate(indexes):
    ax = fig.add_subplot(rows, columns, i+1)
    ax.matshow(X[index].reshape(8, 8))
    ax.set_title(str(y[index]))
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
fig.tight_layout()

Displays

[Figure: the 10 training images with the largest approximate leave-one-out errors, each titled with its label]

For the full source code, see github.com/rnburn/bbai/example/01-digits.py.
