{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0a019af8-b0cf-48a5-8d4a-49cd9589ebd8",
   "metadata": {},
   "source": [
    "# $k$-means clustering\n",
    "\n",
    "Interactive step-by-step run of the $k$-means algorithm on the petal length and petal width of the iris data set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d41e8ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize packages\n",
    "%matplotlib ipympl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import ipywidgets as widgets\n",
    "\n",
    "# load data\n",
    "from sklearn.datasets import load_iris\n",
    "iris = load_iris(as_frame=True)\n",
    "\n",
    "# random generator used to pick the initial centroids;\n",
    "# pass a seed, e.g. np.random.default_rng(42), for reproducible runs\n",
    "rng = np.random.default_rng()\n",
    "\n",
    "# initialize data\n",
    "X_iris = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
    "\n",
    "# set global variables\n",
    "k = 3 # number of clusters\n",
    "n = len(X_iris) # number of data points\n",
    "centroids = None # location of centroids\n",
    "labels = None # assigned cluster of each point\n",
    "iteration = 0 # number of completed centroid updates\n",
    "\n",
    "# define functions\n",
    "\n",
    "# redraw plot\n",
    "def update_plot():\n",
    "    \"\"\"Redraw the figure from the current `labels`, `centroids` and `iteration`.\n",
    "\n",
    "    Uses `fig1`, `ax1`, `points` and `centroid_plot`, which are created in the\n",
    "    plotting cell below, so this must only be called after that cell has run.\n",
    "    \"\"\"\n",
    "    if labels is None:\n",
    "        # no assignment yet: drop the colour mapping and show uniform points\n",
    "        points.set_array(None)\n",
    "        points.set_color('steelblue')\n",
    "    else:\n",
    "        points.set_array(labels)\n",
    "        # pin the colour range to all k cluster ids; otherwise the norm is\n",
    "        # autoscaled to the labels first drawn and a cluster id missing there\n",
    "        # would later share a colour with another cluster\n",
    "        points.set_clim(0, k - 1)\n",
    "    if centroids is not None:\n",
    "        centroid_plot.set_offsets(centroids)\n",
    "    else:\n",
    "        centroid_plot.set_offsets(np.empty((0, 2)))\n",
    "    ax1.set_title(f\"iteration {iteration}\")\n",
    "    fig1.canvas.draw_idle()\n",
    "\n",
    "# iteration steps\n",
    "def assign_clusters():\n",
    "    \"\"\"Assignment step: label each point with the index of its nearest centroid.\"\"\"\n",
    "    global labels\n",
    "    # distances[i, j] = Euclidean distance from centroid i to point j, shape (k, n)\n",
    "    distances = np.sqrt(((X_iris - centroids[:, np.newaxis])**2).sum(axis=2))\n",
    "    labels = np.argmin(distances, axis=0)\n",
    "\n",
    "def update_centroids():\n",
    "    \"\"\"Update step: move each centroid to the mean of its assigned points.\n",
    "\n",
    "    A cluster with no assigned points keeps its previous centroid, because the\n",
    "    mean of an empty selection is NaN and would corrupt the next assignment.\n",
    "    \"\"\"\n",
    "    global centroids\n",
    "    new_centroids = centroids.copy()\n",
    "    for i in range(k):\n",
    "        members = X_iris[labels == i]\n",
    "        if len(members) > 0:\n",
    "            new_centroids[i] = members.mean(axis=0)\n",
    "    centroids = new_centroids\n",
    "\n",
    "# functions called by buttons\n",
    "def initialize(b):\n",
    "    \"\"\"Button callback: reset the run and place k centroids on distinct data points.\"\"\"\n",
    "    global centroids, labels, iteration\n",
    "    iteration = 0\n",
    "    labels = None\n",
    "    # sample among distinct points only: the data can contain duplicate\n",
    "    # points, and two coincident centroids would always leave a cluster empty\n",
    "    distinct_points = np.unique(X_iris, axis=0)\n",
    "    idx = rng.choice(len(distinct_points), k, replace=False)\n",
    "    centroids = distinct_points[idx]\n",
    "    update_plot()\n",
    "\n",
    "def step_assign(b):\n",
    "    \"\"\"Button callback: run one assignment step (does nothing before initialization).\"\"\"\n",
    "    if centroids is None:\n",
    "        return\n",
    "    assign_clusters()\n",
    "    update_plot()\n",
    "\n",
    "def step_update(b):\n",
    "    \"\"\"Button callback: run one centroid update step (does nothing before any assignment).\"\"\"\n",
    "    global iteration\n",
    "    if labels is None:\n",
    "        return\n",
    "    update_centroids()\n",
    "    iteration += 1\n",
    "    update_plot()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15d5e8ac-2b81-42b4-9f86-d66dcf995e91",
   "metadata": {},
   "source": [
    "# Activity 6: iterate the $k$-means algorithm\n",
    "\n",
    "Click **Initialize centroids** to place $k$ centroids on randomly chosen data points. Then alternate **Assign points** (colour each point by its nearest centroid) and **Update centroids** (move each centroid to the mean of its points) until the assignments stop changing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2bc3c58",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up plot figure\n",
    "plt.ioff()\n",
    "fig1, ax1 = plt.subplots(figsize=(6,6))\n",
    "plt.ion()\n",
    "\n",
    "fig1.canvas.header_visible = False\n",
    "fig1.canvas.footer_visible = False\n",
    "fig1.canvas.toolbar_visible = False\n",
    "\n",
    "x1_min, x1_max = 0, 8\n",
    "x2_min, x2_max = 0, 3\n",
    "\n",
    "ax1.set_xlabel(\"petal length (cm)\"); ax1.set_ylabel(\"petal width (cm)\")\n",
    "ax1.set_xlim(x1_min, x1_max)\n",
    "ax1.set_ylim(x2_min, x2_max)\n",
    "\n",
    "# plot data\n",
    "points = ax1.scatter(X_iris[:,0], X_iris[:,1], color='steelblue')\n",
    "centroid_plot = ax1.scatter([], [], marker='X', color='red', s=200)\n",
    "\n",
    "# Create buttons\n",
    "init_btn = widgets.Button(description='Initialize centroids')\n",
    "assign_btn = widgets.Button(description='Assign points')\n",
    "update_btn = widgets.Button(description='Update centroids')\n",
    "\n",
    "init_btn.on_click(initialize)\n",
    "assign_btn.on_click(step_assign)\n",
    "update_btn.on_click(step_update)\n",
    "\n",
    "button_row = widgets.HBox([init_btn, assign_btn, update_btn])\n",
    "ui = widgets.VBox([fig1.canvas, button_row])\n",
    "display(ui)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c00e03-1f5e-43ea-8721-673ca1b102ce",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}