
K-Means Clustering Case Study

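A simple case study of implementing K-Means clustering in Python with scikit-learn. The walkthrough clusters scikit-learn's built-in handwritten digits dataset and also shows how to load the IRIS dataset.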

Import tools and libraries:
from time import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale


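Create a helper function that compares the true labels with the K-Means cluster labels and prints three clustering-quality scores (homogeneity, completeness, and V-measure):
def get_cluster_metric(y_train, km_labels_):
    # Each score lies between 0 and 1; higher means better agreement with the true labels
    print("Homogeneity: %0.3f" % metrics.homogeneity_score(y_train, km_labels_))
    print("Completeness: %0.3f" % metrics.completeness_score(y_train, km_labels_))
    print("V-measure: %0.3f" % metrics.v_measure_score(y_train, km_labels_))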
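All three scores are built from conditional entropies of the true classes C and the clusters K: homogeneity is 1 - H(C|K)/H(C) (each cluster holds a single class), completeness is 1 - H(K|C)/H(K) (each class lands in a single cluster), and V-measure is their harmonic mean. The plain-Python sketch below spells out those formulas; the helper names are hypothetical, not scikit-learn functions. It also shows that the scores ignore how clusters happen to be numbered.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (natural log) of a sequence of labels."""
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def conditional_entropy(a, b):
    """H(A | B): entropy of labels `a` inside each group defined by labels `b`, weighted by group size."""
    n = len(a)
    total = 0.0
    for b_value, count in Counter(b).items():
        subset = [x for x, y in zip(a, b) if y == b_value]
        total += (count / n) * entropy(subset)
    return total


def homogeneity_completeness_v(labels_true, labels_pred):
    """Hypothetical helper returning (homogeneity, completeness, V-measure)."""
    h_c = entropy(labels_true)
    h_k = entropy(labels_pred)
    # A labeling with a single value has zero entropy; treat that score as perfect (1.0)
    homogeneity = 1.0 if h_c == 0 else 1.0 - conditional_entropy(labels_true, labels_pred) / h_c
    completeness = 1.0 if h_k == 0 else 1.0 - conditional_entropy(labels_pred, labels_true) / h_k
    if homogeneity + completeness == 0:
        v_measure = 0.0
    else:
        v_measure = 2 * homogeneity * completeness / (homogeneity + completeness)
    return homogeneity, completeness, v_measure


# Same grouping, swapped cluster IDs: still a perfect score
print(homogeneity_completeness_v([0, 0, 1, 1], [1, 1, 0, 0]))
# Everything in one cluster: complete but not homogeneous
print(homogeneity_completeness_v([0, 0, 1, 1], [0, 0, 0, 0]))
```

Because of this permutation invariance, these scores remain meaningful even though K-Means assigns arbitrary cluster IDs.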
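Load scikit-learn's built-in handwritten digits dataset and standardize its features:
digits = load_digits()
data = scale(digits.data)                  # standardize each of the 64 pixel features
n_samples, n_features = data.shape
n_digits = len(np.unique(digits.target))   # number of distinct digits (classes)
labels = digits.target                     # true digit (0-9) for each image
sample_size = 300
print("n_digits: %d, \t n_samples %d, \t n_features %d"
      % (n_digits, n_samples, n_features))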



Output:
n_digits: 10,            n_samples 1797,  n_features 64



Output: (1797, )
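The scale call standardizes every column to mean 0 and unit variance, so no single pixel dominates the Euclidean distances K-Means relies on. A minimal NumPy sketch of the same idea follows (the function name standardize is hypothetical). It uses the population standard deviation (ddof=0) and, as scikit-learn's scale does, divides zero-variance columns by 1 instead of 0; the digits data has such columns, since some border pixels are always blank.

```python
import numpy as np


def standardize(X):
    """Center each column to mean 0 and scale it to unit variance (ddof=0).

    Constant columns have a standard deviation of 0; they are divided by 1
    instead, so they become all zeros rather than NaN.
    """
    X = np.asarray(X, dtype=float)
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    std[std == 0.0] = 1.0
    return (X - mean) / std


# Toy matrix: column 0 and column 2 are constant, column 1 varies
X = np.array([[0.0, 1.0, 5.0],
              [0.0, 3.0, 5.0],
              [0.0, 5.0, 5.0]])
Z = standardize(X)
print(Z)
```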


Loading the built-in IRIS dataset in Python:
from sklearn.datasets import load_iris
iris = load_iris()   # iris.data: 150 samples x 4 features; iris.target: 3 species
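The same workflow carries over to IRIS. As a sketch (the variable names and the fixed random_state are assumptions made for repeatability), one can standardize the four measurements, fit K-Means with one cluster per species, and score the result with the same metrics used for the digits:

```python
import numpy as np
from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.preprocessing import scale

iris = load_iris()
iris_data = scale(iris.data)             # 150 samples x 4 standardized features
iris_labels = iris.target                # species encoded as 0, 1, 2
n_species = len(np.unique(iris_labels))  # 3

# One cluster per species; n_init=10 runs K-Means from 10 starts and keeps the best
km = KMeans(n_clusters=n_species, n_init=10, random_state=0).fit(iris_data)
iris_v = metrics.v_measure_score(iris_labels, km.labels_)
print("V-measure on IRIS: %0.3f" % iris_v)
```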


Fit K-Means for every k from 1 to 19, recording the sum of squared errors (SSE, exposed by scikit-learn as inertia_) and the percentage of samples whose cluster ID equals the true digit:
y = labels
sse = {}
accuracy = []
for k in range(1, 20):
    kmeans = KMeans(n_clusters=k, max_iter=1000).fit(data)
    sse[k] = kmeans.inertia_  # Inertia: sum of squared distances of samples to their closest cluster center
    labels_pred = kmeans.labels_
    # Count samples whose cluster ID equals the true digit. K-Means numbers its clusters
    # arbitrarily, so this is only a rough match rate, not a true classification accuracy.
    correct_labels = np.sum(labels == labels_pred)
    accuracy.append(correct_labels / float(y.size) * 100)
    print("correct %.02f percent classification at k = %d" % (accuracy[-1], k))
    get_cluster_metric(y, kmeans.labels_)
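Under the hood, each KMeans fit alternates two steps: assign every sample to its nearest center, then move each center to the mean of its samples. The sketch below implements that loop (Lloyd's algorithm) in NumPy. It is a simplified illustration, not scikit-learn's implementation: it starts from k randomly chosen samples, whereas KMeans uses k-means++ initialization by default. The name lloyd_kmeans and its parameters are hypothetical. The returned inertia is the same quantity stored in sse[k] above.

```python
import numpy as np


def lloyd_kmeans(X, k, max_iter=100, seed=0):
    """Plain Lloyd iteration from a random start; returns (labels, centers, inertia)."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    # Random initialization: k distinct samples become the first centers
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(max_iter):
        # Squared distance from every sample to every center, shape (n_samples, k)
        dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each center to the mean of its samples; keep it in place if its cluster is empty
        new_centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    # Final assignment and inertia (sum of squared distances to the assigned centers)
    dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    labels = dists.argmin(axis=1)
    inertia = float(dists[np.arange(len(X)), labels].sum())
    return labels, centers, inertia


# Two well-separated pairs of points
X_toy = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
toy_labels, toy_centers, toy_inertia = lloyd_kmeans(X_toy, k=2)
print(toy_labels, toy_centers, toy_inertia)
```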



# No. of clusters vs. SSE
plt.plot(list(sse.keys()), list(sse.values()))
plt.xlabel("Number of clusters")
plt.ylabel("SSE (inertia)")
plt.show()
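SSE always falls as k grows (with k equal to the number of samples it reaches 0), so the curve is read for its "elbow": the point after which extra clusters buy little reduction. One rough way to locate it, sketched below as a hypothetical helper rather than a library function, picks the k with the largest second difference SSE[k-1] - 2*SSE[k] + SSE[k+1]. The SSE values in the example are made-up toy numbers, not results from the digits run.

```python
def elbow_k(sse):
    """Return the k where the SSE curve bends most sharply.

    Uses the largest second difference SSE[k-1] - 2*SSE[k] + SSE[k+1]; a rough
    heuristic that assumes the keys are consecutive integers.
    """
    ks = sorted(sse)
    values = [sse[k] for k in ks]
    bends = [values[i - 1] - 2 * values[i] + values[i + 1] for i in range(1, len(values) - 1)]
    best = max(range(len(bends)), key=lambda i: bends[i])
    return ks[best + 1]


example_sse = {1: 100.0, 2: 40.0, 3: 20.0, 4: 15.0, 5: 12.0}  # toy values
print(elbow_k(example_sse))
```

Applied to the sse dictionary built in the loop, elbow_k(sse) gives one candidate k to compare against the plot; the visual reading of the curve remains the primary check.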







# No. of clusters vs. accuracy
plt.plot(range(1, 20), accuracy)
plt.xlabel("Number of clusters")
plt.ylabel("Correctly labeled samples (%)")
plt.show()
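Comparing labels == labels_pred directly only counts a sample as correct when K-Means happens to give its cluster the same number as the digit, which is arbitrary. A fairer score first matches clusters to classes one-to-one so that the number of agreeing samples is maximized (the Hungarian algorithm), then counts agreements. The sketch below does this with SciPy's linear_sum_assignment; the helper name matched_accuracy is hypothetical. Clusters left unmatched when k exceeds the number of classes count as wrong, so unlike a majority-vote mapping (which would score the first example below as 100%), this score does not climb automatically as k grows.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def matched_accuracy(y_true, y_pred):
    """Fraction of samples correct after an optimal one-to-one cluster-to-class matching."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    classes, true_idx = np.unique(y_true, return_inverse=True)
    clusters, pred_idx = np.unique(y_pred, return_inverse=True)
    # contingency[i, j] = number of samples in cluster i whose true class is j
    contingency = np.zeros((clusters.size, classes.size), dtype=int)
    np.add.at(contingency, (pred_idx, true_idx), 1)
    # linear_sum_assignment minimizes cost, so negate the counts to maximize agreement
    rows, cols = linear_sum_assignment(-contingency)
    return contingency[rows, cols].sum() / y_true.size


# Four clusters for two classes: at most two clusters can be matched
print(matched_accuracy([0, 0, 1, 1], [0, 1, 2, 3]))
# Swapped cluster IDs are matched back to the right classes
print(matched_accuracy([0, 0, 1, 1], [1, 1, 0, 0]))
```

Inside the loop, matched_accuracy(labels, labels_pred) * 100 could be appended in place of the raw match rate.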









An alumnus of the NIE-Institute Of Technology, Mysore, Prateek is an ardent Data Science enthusiast. He has been working at Acadgild as a Data Engineer for the past 3 years and is a subject-matter expert in Big Data, the Hadoop ecosystem, and Spark.
