The information technology industry is in the middle of a powerful trend towards machine learning and artificial intelligence. These are difficult skills to master, but if you embrace them and put in the practice, you’ll be taking a significant step towards advancing your career. As with any learning curve, it’s useful to start simple. The K-Means clustering algorithm is intuitive and easy to understand, so in this post I’m going to describe what K-Means does, show you how to experiment with it using Spark and Python, and visualize its results in a Jupyter notebook.

## What is K-Means?

K-Means clustering aims to group a set of objects so that objects in the same group (or cluster) are more similar to each other than to those in other groups. It operates on a table of values in which every cell is a number, so K-Means supports only numeric columns. In Spark, such tables are usually expressed as a dataframe. A dataframe with two columns is easy to visualize on a graph where the x-axis is the first column and the y-axis is the second. For example, here’s a two-dimensional graph for a dataframe with two columns.
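
K-Means makes "more similar" concrete: it looks for the grouping that minimizes the total squared Euclidean distance between each row and the centroid (the column-wise mean) of its cluster, a quantity often called the within-cluster sum of squares. The sketch below computes that cost for a given labeling of numeric rows, so you can see that a sensible grouping scores lower than a poor one. The function names are illustrative, not part of Spark or any other library.

```python
def centroid(rows):
    """Column-wise mean of a list of equal-length numeric rows."""
    return tuple(sum(col) / len(rows) for col in zip(*rows))


def within_cluster_sse(rows, labels):
    """Sum of squared Euclidean distances from each row to its cluster's centroid."""
    total = 0.0
    for label in set(labels):
        members = [row for row, lab in zip(rows, labels) if lab == label]
        center = centroid(members)
        total += sum(sum((x - m) ** 2 for x, m in zip(row, center)) for row in members)
    return total


# Four points: two on the left (x=1) and two on the right (x=9)
rows = [(1.0, 1.0), (1.0, 3.0), (9.0, 1.0), (9.0, 3.0)]
print(within_cluster_sse(rows, [0, 0, 1, 1]))  # 4.0  (left/right split)
print(within_cluster_sse(rows, [0, 1, 0, 1]))  # 64.0 (bottom/top split)
```

K-Means would prefer the left/right split here because its cost is lower.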

If you were to manually group the data in the above graph, how would you do it? You might draw two circles, like this:

In this case, that is pretty close to what k-means produces. The following figure shows how k-means segments our two-dimensional dataset.
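
How does k-means arrive at a segmentation like this? The standard procedure, Lloyd’s algorithm, alternates two steps: assign every row to its nearest centroid, then move each centroid to the mean of the rows assigned to it. Here is a minimal pure-Python sketch, assuming Euclidean distance, a fixed number of iterations, and (a simplification chosen for illustration) the first k rows as starting centroids. Spark’s `KMeans` uses k-means|| initialization by default and stops once centroids stop moving by more than a tolerance, so treat this as a teaching aid rather than Spark’s implementation.

```python
import math


def nearest(point, centroids):
    """Index of the centroid closest to point (Euclidean distance)."""
    return min(range(len(centroids)), key=lambda i: math.dist(point, centroids[i]))


def kmeans(points, k, iterations=10):
    """Cluster numeric rows into k groups with Lloyd's algorithm.

    Assumption for illustration: the first k rows serve as initial centroids.
    """
    centroids = [tuple(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(iterations):
        # Assignment step: each row joins the cluster of its nearest centroid.
        labels = [nearest(p, centroids) for p in points]
        # Update step: move each centroid to the mean of the rows assigned to it.
        for i in range(k):
            members = [p for p, label in zip(points, labels) if label == i]
            if members:  # keep the previous centroid if a cluster ends up empty
                centroids[i] = tuple(sum(col) / len(members) for col in zip(*members))
    return labels, centroids


points = [(1.0, 1.0), (1.5, 2.0), (1.2, 0.8), (8.0, 8.0), (8.5, 9.0), (9.0, 8.2)]
labels, centroids = kmeans(points, k=2)
print(labels)  # [0, 0, 0, 1, 1, 1]
```

With k=2 the two tight groups end up in separate clusters. On real data the result can depend on the starting centroids, which is why library implementations invest in careful initialization.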

Charting feature columns this way can help you build an intuitive sense of how k-means is segmenting your data.
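
To chart clusters with a different color for each, you typically need one set of coordinate lists per cluster. The sketch below, using a hypothetical helper `columns_by_cluster`, groups numeric rows by their cluster label and transposes each group into columns; it works for any number of feature columns.

```python
from collections import defaultdict


def columns_by_cluster(rows, labels):
    """Group numeric rows by cluster label and transpose each group into columns.

    Returns {label: [column_0_values, column_1_values, ...]}, so that with two
    feature columns each cluster yields one x list and one y list to plot.
    """
    groups = defaultdict(list)
    for row, label in zip(rows, labels):
        groups[label].append(row)
    return {label: [list(col) for col in zip(*members)] for label, members in groups.items()}


rows = [(1.0, 2.0), (8.0, 9.0), (1.5, 2.5)]
labels = [0, 1, 0]
print(columns_by_cluster(rows, labels))  # {0: [[1.0, 1.5], [2.0, 2.5]], 1: [[8.0], [9.0]]}
```

In a notebook, each cluster’s lists could be passed to Plotly as its own trace, for example `go.Scatter(x=xs, y=ys, mode="markers", name=f"cluster {label}")`; with three feature columns, `go.Scatter3d` additionally takes a `z` list.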

## Visualizing K-Means Clusters in 3D

The plots above were created by clustering on two feature columns. Our dataset could have contained other columns, but we used only two. If we add a third column as a clustering feature, we need to visualize the clusters in three dimensions. Here’s an example that shows how to visualize cluster shapes with a 3D scatter/mesh plot in a Jupyter notebook using Python 3:

```python
# Initialize plotting library and functions for 3D scatter plots
from sklearn.datasets import make_blobs
from sklearn.datasets import make_gaussian_quantiles
from sklearn.datasets import make_classification, make_regression
import pandas as pd
import numpy as np
import argparse
import json
import re
import os
import sys
import plotly
import plotly.graph_objs as go

# Enable offline Plotly rendering inside the Jupyter notebook
plotly.offline.init_notebook_mode()


def rename_columns(df, prefix='x'):
    """
    Rename the columns of a dataframe so that each name starts with a prefix

    :param df: data frame we're operating on
    :param prefix: the prefix string (default 'x')
    :return: a copy of df with renamed columns; df itself is left unchanged
    """
    df = df.copy()
    df.columns = [prefix + str(i) for i in df.columns]
    return df
```
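
The imports above pull in scikit-learn’s `make_blobs`, which generates isotropic Gaussian blobs as test data for clustering. If you want a few clustered 3D points to experiment with and don’t have scikit-learn at hand, here is a stdlib-only stand-in. `gaussian_blobs` is a hypothetical helper, not a library function: it draws normally distributed points around centers you choose and returns the true cluster index of each point.

```python
import random


def gaussian_blobs(centers, n_per_center, std=1.0, seed=0):
    """Draw n_per_center points from an isotropic Gaussian around each center.

    Returns (rows, labels), where labels[i] is the index of the center
    that rows[i] was drawn around. The same seed gives the same data.
    """
    rng = random.Random(seed)
    rows, labels = [], []
    for label, center in enumerate(centers):
        for _ in range(n_per_center):
            rows.append(tuple(rng.gauss(c, std) for c in center))
            labels.append(label)
    return rows, labels


rows, labels = gaussian_blobs([(0.0, 0.0, 0.0), (10.0, 10.0, 10.0)], n_per_center=5, std=0.5, seed=42)
```

Because the true labels come back alongside the points, you can compare them with the clusters k-means finds.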