Data Science and Artificial Intelligence

Multivariate Multilabel Classification with Logistic Regression

This entry is part 1 of 17 in the series Machine Learning Algorithms

Introduction to Multi-class Classification

This post shows how logistic regression can be applied to multi-class classification, with a focus on building and evaluating such a model in practice.

Logistic regression is one of the most fundamental and widely used machine learning algorithms, and it is usually among the first topics people pick up when learning predictive modeling. Despite its name, logistic regression is not a regression algorithm but a probabilistic classification model.


Classification in machine learning is a form of learning in which each instance is mapped to one of several labels. The machine learns patterns from the data so that the learned representation maps the original input features to the correct label/class without any intervention from a human expert.

Logistic regression is built on the sigmoid (logistic) function, whose graph is an S-shaped curve.

[Figure: graph of the sigmoid function]

The equation for the sigmoid function is:

$$\sigma(x) = \frac{e^{x}}{1 + e^{x}} = \frac{1}{1 + e^{-x}}$$

Because e^x is always positive and the numerator e^x is always smaller than the denominator 1 + e^x by exactly 1, the output is always strictly between 0 and 1.
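The bounded-output property can be checked numerically. The following is a minimal sketch using only the standard library; the function name `sigmoid` and the sample inputs are chosen here for illustration and do not come from the original post.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid e^x / (1 + e^x), evaluated without overflow."""
    if x >= 0:
        # Equivalent form 1 / (1 + e^-x); e^-x cannot overflow for x >= 0.
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, e^x cannot overflow, so use e^x / (1 + e^x) directly.
    ex = math.exp(x)
    return ex / (1.0 + ex)


# Large negative inputs approach 0, large positive inputs approach 1,
# and an input of 0 gives exactly 0.5.
for x in (-10.0, -1.0, 0.0, 1.0, 10.0):
    print(f"sigmoid({x:+.1f}) = {sigmoid(x):.6f}")
```

Splitting on the sign of x is a common way to keep the exponential from overflowing; mathematically both branches are the same function.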

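The idea in logistic regression is to cast the problem in the form of a generalized linear model. The starting point is the linear predictor:

$$\hat{y} = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_n x_n$$

where ŷ is the predicted value, the x_i are the independent variables, and the β_i are the coefficients to be learned.

This can be compactly expressed in vector form, with β = (β_0, β_1, ..., β_n)^T and x = (1, x_1, ..., x_n)^T:

$$\hat{y} = \beta^{T} x$$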

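Thus, passing the linear predictor through the sigmoid turns it into a probability, and the corresponding logistic (logit) link function is what places logistic regression within the family of Generalized Linear Models.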
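Written out, the link works as follows. This is a sketch in the notation used here, where p (a symbol introduced for illustration) stands for the predicted probability of the positive class:

```latex
p = \sigma(\beta^{T}x) = \frac{e^{\beta^{T}x}}{1 + e^{\beta^{T}x}}
\quad\Longleftrightarrow\quad
\operatorname{logit}(p) = \log\frac{p}{1 - p} = \beta^{T}x
```

In other words, the model is linear in the log-odds of the positive class rather than in the probability itself.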

In its vanilla form, logistic regression performs binary classification. It can be extended to multiple classes in two ways: the one-vs-rest scheme, which trains one binary classifier per class to separate that class from all the others, or a multinomial model, which replaces the loss function with the cross-entropy loss over all classes.
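The difference between the two schemes can be sketched for a single sample in plain Python. The scores below are hypothetical linear outputs (one per class) invented for illustration, and the helper names are not from any library.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax(scores):
    # Subtract the maximum score before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, true_class):
    # Negative log-probability assigned to the correct class.
    return -math.log(probs[true_class])


# Hypothetical linear scores for one sample and three classes.
scores = [2.0, 0.5, -1.0]

# One-vs-rest: each class has its own independent binary sigmoid;
# the predicted class is the one with the highest probability.
ovr_probs = [sigmoid(s) for s in scores]
ovr_pred = max(range(len(scores)), key=lambda k: ovr_probs[k])

# Multinomial: a single softmax yields probabilities that sum to 1,
# and training minimizes the cross-entropy of the true class.
soft_probs = softmax(scores)
soft_pred = max(range(len(scores)), key=lambda k: soft_probs[k])

print("OvR probabilities:    ", [round(p, 3) for p in ovr_probs])
print("Softmax probabilities:", [round(p, 3) for p in soft_probs])
print("Cross-entropy if the true class is 0:", round(cross_entropy(soft_probs, 0), 3))
```

Note that the one-vs-rest probabilities need not sum to 1, because each comes from a separate binary model, whereas the softmax probabilities always do.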

In scikit-learn's LogisticRegression class, the multi-class strategy is selected through the 'multi_class' argument of the constructor. In the multiclass case, the training algorithm uses the one-vs-rest (OvR) scheme if 'multi_class' is set to 'ovr' and the cross-entropy loss if it is set to 'multinomial'. (At the time of writing, the 'multinomial' option is supported only by the 'lbfgs', 'sag' and 'newton-cg' solvers.) In the scikit-learn version used here, multi_class defaults to 'ovr'. Releases from 0.22 onward default to 'auto' instead, and recent releases deprecate the argument, so check the documentation for your installed version.
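As a rough, version-tolerant sketch of the two strategies, the example below avoids the 'multi_class' argument altogether: it wraps LogisticRegression in OneVsRestClassifier for an explicit one-vs-rest model and uses a plain LogisticRegression for comparison. It assumes scikit-learn is installed and, to stay small and offline, swaps in the bundled 8x8 `load_digits` dataset rather than MNIST; the variable names and the 25% test split are choices made here, not part of the original post.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multiclass import OneVsRestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Small bundled digits dataset (8x8 images), used here as a stand-in for MNIST.
X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

# Explicit one-vs-rest: one binary logistic regression per class.
ovr = make_pipeline(
    StandardScaler(),
    OneVsRestClassifier(LogisticRegression(solver="lbfgs", max_iter=1000)),
)

# Plain LogisticRegression: on scikit-learn 0.22 and later, with the lbfgs
# solver and more than two classes, this fits a multinomial (softmax) model.
direct = make_pipeline(
    StandardScaler(),
    LogisticRegression(solver="lbfgs", max_iter=1000),
)

ovr.fit(X_train, y_train)
direct.fit(X_train, y_train)

ovr_acc = ovr.score(X_test, y_test)
direct_acc = direct.score(X_test, y_test)
n_ovr_models = len(ovr.named_steps["onevsrestclassifier"].estimators_)

print(f"One-vs-rest: {n_ovr_models} binary models, accuracy {ovr_acc:.3f}")
print(f"Plain LogisticRegression: accuracy {direct_acc:.3f}")
```

On a dataset this simple the two approaches tend to score similarly; the main practical difference is that one-vs-rest fits a separate binary model for each class.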

Problem Statement

Classify a handwritten image of a digit into a label from 0-9. Use multiclass logistic regression for this task.

About the Dataset

The MNIST database of handwritten digits, available from this page, has a training set of 60,000 examples and a test set of 10,000 examples. It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal effort on preprocessing and formatting.

Four files are available on this site:

  • train-images-idx3-ubyte.gz: training set images (9912422 bytes)
  • train-labels-idx1-ubyte.gz: training set labels (28881 bytes)
  • t10k-images-idx3-ubyte.gz: test set images (1648877 bytes)
  • t10k-labels-idx1-ubyte.gz: test set labels (4542 bytes)

Parameter            Value
Classes              10
Samples per class    ~7,000
Samples total        70,000
Dimensionality       784
Features             integer values from 0 to 255
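The dimensionality of 784 comes from flattening each 28 x 28 pixel image into a single feature vector. A minimal pure-Python sketch of that flattening follows; the blank image and the single bright pixel are made up for illustration.

```python
ROWS, COLS = 28, 28

# A made-up blank image: 28 rows of 28 pixel intensities in the range 0..255.
image = [[0] * COLS for _ in range(ROWS)]
image[14][10] = 255  # one fully bright pixel

# Flatten row by row into the 784-dimensional vector the classifier sees.
features = [pixel for row in image for pixel in row]

# The pixel at (row, col) ends up at index row * COLS + col.
bright_index = 14 * COLS + 10

print(len(features), bright_index, features[bright_index])
```

Flattening discards the 2-D layout, so logistic regression treats every pixel as an independent input feature.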

The MNIST database of handwritten digits is available on the following website: MNIST Dataset

Import libraries:

from sklearn.datasets import fetch_mldata
from sklearn.preprocessing import StandardScaler
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import pandas as pd
import numpy as np

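Load data:

# Pass data_home to choose where the data is downloaded.
# Note: fetch_mldata was removed in scikit-learn 0.22; newer versions
# provide fetch_openml('mnist_784') instead (its labels are strings).
mnist = fetch_mldata('MNIST original')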

Check data after loading:

print(mnist.data.shape)
print(mnist.COL_NAMES)
print(np.unique(mnist.target))
(70000, 784)
['label', 'data']
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]

Split data into train/test:

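# test_size: proportion of the original data used for the test set
# (1/7 of 70,000 samples is 10,000 test images, leaving 60,000 for training)
train_img, test_img, train_lbl, test_lbl = train_test_split(mnist.data, mnist.target, test_size=1/7.0, random_state=122)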

Standardize the data:

scaler = StandardScaler()
# Fit on training set only.
scaler.fit(train_img)
# Apply transform to both the training set and the test set.
train_img = scaler.transform(train_img)
test_img = scaler.transform(test_img)
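To make the "fit on the training set only" rule concrete, here is a small standard-library sketch of what standardization computes: per-column mean and population standard deviation learned from the training rows, then applied unchanged to the test rows. The helper names and the tiny two-column dataset are hypothetical, and this is an illustration of the idea rather than scikit-learn's actual implementation.

```python
from statistics import mean, pstdev


def fit_scaler(rows):
    """Return per-column means and standard deviations of the given rows."""
    cols = list(zip(*rows))
    means = [mean(c) for c in cols]
    # Population standard deviation; constant columns fall back to 1.0
    # so that the division below never divides by zero.
    stds = [pstdev(c) or 1.0 for c in cols]
    return means, stds


def transform(rows, means, stds):
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]


# Tiny made-up data: the second column is constant, like an always-blank pixel.
train = [[0.0, 255.0], [128.0, 255.0], [255.0, 255.0]]
test = [[64.0, 255.0]]

means, stds = fit_scaler(train)               # statistics come from train only
train_scaled = transform(train, means, stds)
test_scaled = transform(test, means, stds)    # test reuses the train statistics

print(train_scaled)
print(test_scaled)
```

Reusing the training statistics on the test set keeps information about the test data from leaking into preprocessing.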

Fit the model:

model = LogisticRegression(solver='lbfgs')
model.fit(train_img, train_lbl)

Validate the fitting:

# use the model to make predictions with the test data
y_pred = model.predict(test_img)
# how did our model perform?
count_misclassified = (test_lbl != y_pred).sum()
print('Misclassified samples: {}'.format(count_misclassified))
accuracy = metrics.accuracy_score(test_lbl, y_pred)
print('Accuracy: {:.2f}'.format(accuracy))

Misclassified samples: 829
Accuracy: 0.92
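The two reported numbers are consistent with each other, which a short sketch can confirm. The helpers below are hand-rolled for illustration (not the scikit-learn functions), the toy labels are invented, and the final check assumes the test set holds 10,000 images, that is, 1/7 of the 70,000 samples.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    # Fraction of predictions that match the true labels.
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def confusion_counts(y_true, y_pred):
    # Maps (true label, predicted label) pairs to how often they occur.
    return Counter(zip(y_true, y_pred))


# Invented toy labels: 8 samples, 2 of them misclassified.
y_true = [0, 1, 2, 2, 1, 0, 2, 1]
y_pred = [0, 1, 2, 1, 1, 0, 2, 2]

acc = accuracy(y_true, y_pred)
mistakes = {
    pair: n for pair, n in confusion_counts(y_true, y_pred).items()
    if pair[0] != pair[1]
}
print(f"Toy accuracy: {acc:.2f}, confused (true, predicted) pairs: {mistakes}")

# Consistency check on the figures above, assuming a 10,000-image test set.
reported_accuracy = 1 - 829 / 10_000
print(f"Accuracy implied by 829 errors: {reported_accuracy:.2f}")
```

Looking at which (true, predicted) pairs are confused most often is usually more informative than the accuracy figure alone.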

Way Ahead

The model reaches an accuracy of about 92%, which is plausible for such a simple model because MNIST is a fairly simple dataset with distinctly separable classes. That's how to implement multi-class classification with logistic regression using scikit-learn. Load your favorite data set and give it a try! From here on, all you need is practice.


Abhay Kumar

Abhay Kumar, lead Data Scientist – Computer Vision in a startup, is an experienced data scientist specializing in Deep Learning in Computer vision and has worked with a variety of programming languages like Python, Java, Pig, Hive, R, Shell, Javascript and with frameworks like Tensorflow, MXNet, Hadoop, Spark, MapReduce, Numpy, Scikit-learn, and pandas.
