The information technology industry is in the middle of a powerful trend toward machine learning and artificial intelligence. These are difficult skills to master, but if you embrace them and put in the practice, you’ll be taking a significant step toward advancing your career. As with any learning curve, it’s useful to start simple. The K-Means clustering algorithm is intuitive and easy to understand, so in this post I’m going to describe what K-Means does, show you how to experiment with it using Spark and Python, and visualize its results in a Jupyter notebook.
What is K-Means?
K-means clustering aims to group a set of objects so that objects in the same group (or cluster) are more similar to each other than to those in other groups. It operates on a table of values in which every cell is a number, because K-Means only supports numeric columns. In Spark, such tables are usually expressed as a dataframe. A dataframe with two columns is easy to visualize on a graph where the x-axis is the first column and the y-axis is the second. For example, here’s a two-dimensional graph for a dataframe with two columns.
If you were to manually group the data in the above graph, how would you do it? You might draw two circles, like this:
In this case, that is pretty close to what k-means produces. The following figure shows how the data is segmented when k-means is run on our two-dimensional dataset.
Charting feature columns like that can help you make intuitive sense of how k-means is segmenting your data.
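The segmentation shown in these figures comes from the classic k-means procedure (Lloyd's algorithm), which alternates two steps until the cluster centers stop moving: assign each row to its nearest center, then move each center to the mean of the rows assigned to it. Each round can only lower the within-cluster sum of squared distances, which is the precise sense in which points in a cluster are "more similar" to each other. Below is a minimal pure-Python sketch, assuming every row is a tuple of numbers of the same length; the names `kmeans`, `assign` and `squared_distance` are hypothetical and are not Spark's API. For determinism it seeds the centers with the first k rows, whereas Spark's `pyspark.ml.clustering.KMeans` defaults to the k-means|| initialization.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def squared_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two equal-length rows."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def assign(points: Sequence[Point], centroids: Sequence[Point]) -> List[int]:
    """Label each point with the index of its nearest centroid."""
    return [
        min(range(len(centroids)), key=lambda c: squared_distance(p, centroids[c]))
        for p in points
    ]


def kmeans(points: List[Point], k: int, max_iter: int = 100) -> Tuple[List[Point], List[int]]:
    """Hypothetical Lloyd's-algorithm sketch: returns (centroids, labels).

    Assumption: the first k points are used as starting centroids, a simple
    deterministic choice, not the initialization Spark uses.
    """
    if not 1 <= k <= len(points):
        raise ValueError("k must be between 1 and the number of points")
    centroids = [tuple(p) for p in points[:k]]
    for _ in range(max_iter):
        labels = assign(points, centroids)  # assignment step
        new_centroids = []
        for c in range(k):  # update step: mean of each cluster's members
            members = [p for p, lbl in zip(points, labels) if lbl == c]
            if members:
                dim = len(members[0])
                new_centroids.append(
                    tuple(sum(m[d] for m in members) / len(members) for d in range(dim))
                )
            else:
                new_centroids.append(centroids[c])  # leave an empty cluster's center in place
        if new_centroids == centroids:
            break  # converged: centers (and therefore labels) no longer change
        centroids = new_centroids
    return centroids, assign(points, centroids)


# Two well-separated groups in a two-column dataset
data = [(1.0, 1.0), (1.5, 2.0), (1.2, 1.4), (8.0, 8.0), (8.5, 9.0), (9.0, 8.2)]
centers, labels = kmeans(data, k=2)
print(labels)  # [0, 0, 0, 1, 1, 1]
```

Because the sketch never assumes a fixed number of columns, the same function clusters rows with three (or more) feature columns. In Spark, the equivalent is to combine the feature columns into one vector column with `pyspark.ml.feature.VectorAssembler` and fit `pyspark.ml.clustering.KMeans(k=2)`; the fitted model's `transform` adds a `prediction` column holding each row's cluster index.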
Visualizing K-Means Clusters in 3D
The plots above were created by clustering on two feature columns. Our dataset could have contained other columns, but we used only two. If we want to use an additional column as a clustering feature, we need to visualize the clusters in three dimensions. Here’s an example that shows how to visualize cluster shapes with a 3D scatter/mesh plot in a Jupyter notebook using Python 3:
```python
# Initialize the plotting library and helper functions for 3D scatter plots
from sklearn.datasets import make_blobs
from sklearn.datasets import make_gaussian_quantiles
from sklearn.datasets import make_classification, make_regression
# sklearn.externals.six was removed in scikit-learn 0.23; the standalone six package replaces it
import six
import pandas as pd
import numpy as np
import argparse
import json
import re
import os
import sys
import plotly
import plotly.graph_objs as go

# Enable offline plotly rendering inside the Jupyter notebook
plotly.offline.init_notebook_mode()


def rename_columns(df, prefix='x'):
    """
    Rename the columns of a dataframe to have a prefix in front of them
    :param df: data frame we're operating on
    :param prefix: the prefix string
    """
    df = df.copy()  # work on a copy so the caller's dataframe is not modified
    df.columns = [prefix + str(i) for i in df.columns]
    return df
```
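Once every row has a cluster label, a common way to draw the 3D picture is one trace per cluster, which means splitting the three feature columns into separate x/y/z lists for each label. The sketch below does that with the standard library only. `split_by_cluster` is a hypothetical helper, the column names `x0`, `x1`, `x2` are an assumption that mirrors what `rename_columns` produces with its default `x` prefix, and the labels are assumed to come from whatever k-means run you performed.

```python
from collections import defaultdict
from typing import Dict, List, Mapping, Sequence, Tuple


def split_by_cluster(
    rows: Sequence[Mapping[str, float]],
    labels: Sequence[int],
    features: Tuple[str, str, str],
) -> Dict[int, Dict[str, List[float]]]:
    """Hypothetical helper: group three feature columns into per-cluster x/y/z lists.

    rows[i] is a record such as {"x0": 1.0, "x1": 2.0, "x2": 3.0} and
    labels[i] is the cluster index assigned to that row.
    """
    if len(rows) != len(labels):
        raise ValueError("need exactly one cluster label per row")
    fx, fy, fz = features
    # Each new cluster label gets its own fresh set of coordinate lists
    traces: Dict[int, Dict[str, List[float]]] = defaultdict(lambda: {"x": [], "y": [], "z": []})
    for row, label in zip(rows, labels):
        traces[label]["x"].append(row[fx])
        traces[label]["y"].append(row[fy])
        traces[label]["z"].append(row[fz])
    return dict(traces)


rows = [
    {"x0": 1.0, "x1": 2.0, "x2": 3.0},
    {"x0": 4.0, "x1": 5.0, "x2": 6.0},
    {"x0": 7.0, "x1": 8.0, "x2": 9.0},
]
per_cluster = split_by_cluster(rows, [0, 1, 0], ("x0", "x1", "x2"))
print(per_cluster[0])  # {'x': [1.0, 7.0], 'y': [2.0, 8.0], 'z': [3.0, 9.0]}
```

Each entry of the result can then be handed to plotly as `go.Scatter3d(x=..., y=..., z=..., mode='markers')`, giving every cluster its own color and legend entry; `go.Mesh3d` with its `alphahull` parameter is one way to draw a translucent shell around each cluster's points.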