In this tutorial, we'll dive into the fascinating world of the Iris dataset and explore various aspects of it using the powerful Yellowbrick library. Yellowbrick is a Python library that provides a suite of visual diagnostic tools for machine learning, making it easier to interpret and understand our models.
Yellowbrick is an open source, pure Python project that extends the scikit-learn API with visual analysis and diagnostic tools. The Yellowbrick API also wraps matplotlib to create publication-ready figures and interactive data explorations while still allowing developers fine-grain control of figures. For users, Yellowbrick can help evaluate the performance, stability, and predictive value of machine learning models and assist in diagnosing problems throughout the machine learning workflow.
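Since every Yellowbrick visualizer draws onto an ordinary matplotlib Axes, you can supply your own Axes if you want that fine-grained control. A minimal sketch of the pattern, using the RadViz visualizer we'll meet below (the figure size is just an example):
import matplotlib.pyplot as plt
from yellowbrick.features import RadViz
fig, ax = plt.subplots(figsize=(8, 6))  # an Axes we create and control ourselves
viz = RadViz(ax=ax)  # the visualizer will render onto the Axes we pass in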
First, let's load the Iris dataset using scikit-learn's load_iris function. We'll also define the feature names and class names for better readability.
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
feature_names = ['sepal length', 'sepal width', 'petal length', 'petal width']
class_names = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
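As a quick sanity check, the Iris dataset contains 150 samples with 4 features each, and the labels are the integers 0 through 2 corresponding to the three classes:
print(X.shape)  # (150, 4) -- 150 samples, 4 features
print(y.shape)  # (150,)   -- one class label (0, 1, or 2) per sample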
Next, we'll split the data into training and testing sets using scikit-learn's train_test_split function. This lets us train our models on one portion of the data and evaluate their performance on unseen data.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
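With test_size=0.2, 30 of the 150 samples are held out for testing. If you want each split to preserve the class proportions exactly, train_test_split's stratify parameter handles that (an optional refinement, not required for this tutorial):
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)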
To get a better sense of how the four features relate across the three classes, we can use Yellowbrick's RadViz visualizer. RadViz is a radial visualization that places one anchor per feature evenly around a circle and plots each sample inside the circle, pulled toward each anchor in proportion to its normalized value for that feature. This makes class separability easy to spot even when there are many features.
from yellowbrick.features import RadViz
import matplotlib.pyplot as plt
fig1 = plt.figure(figsize=(10, 8))  # create the figure before the visualizer draws on it
viz = RadViz(features=feature_names, classes=class_names)
viz.fit(X_train, y_train)   # compute the layout from the training data
viz.transform(X_train)      # project and draw the points
viz.show()                  # finalize and display; show() replaces the deprecated poof()
RadViz plot of the Iris dataset
The RadViz visualizer gives us a compact view of all four Iris features at once. Each data point is represented by a colored dot, with the color indicating its class, and its position shows how its feature values pull it toward the four feature anchors. Clusters of same-colored points suggest the classes can be distinguished from these features.
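Under the hood, the idea is simple: scale each feature to [0, 1], place one anchor per feature on the unit circle, and plot each sample at the average of the anchors weighted by its feature values. A minimal NumPy sketch of that computation (my own illustration of the idea, not Yellowbrick's internal code):
import numpy as np
d = X_train.shape[1]
theta = 2 * np.pi * np.arange(d) / d                       # one angle per feature
anchors = np.column_stack([np.cos(theta), np.sin(theta)])  # anchors on the unit circle
X_norm = (X_train - X_train.min(axis=0)) / (X_train.max(axis=0) - X_train.min(axis=0))
# weighted average of anchors; assumes no sample is the minimum in every feature
points = (X_norm @ anchors) / X_norm.sum(axis=1, keepdims=True)  # (n_samples, 2) positions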
Now, let's train a K-Nearest Neighbors (KNN) classifier and visualize the decision boundaries it learns. We'll use Yellowbrick's DecisionViz visualizer for this purpose. Since decision boundaries can only be drawn in two dimensions, we'll restrict the model to the first two features (sepal length and sepal width).
from sklearn.neighbors import KNeighborsClassifier
from yellowbrick.contrib.classifier import DecisionViz
model = KNeighborsClassifier(n_neighbors=5)
fig2 = plt.figure(figsize=(10, 8))  # create the figure before the visualizer draws on it
viz = DecisionViz(
    model,
    features=feature_names[:2],  # use the first two features for 2D visualization
    classes=class_names
)
viz.fit(X_train[:, :2], y_train)   # fits the model on the 2D data and shades the decision regions
viz.draw(X_train[:, :2], y_train)  # overlays the training points
viz.show()                         # finalize and display; show() replaces the deprecated poof()
Classification Boundaries with DecisionViz
The DecisionViz visualizer shows us the decision boundaries learned by the KNN classifier. The background is shaded by the class the model would predict at each point in the sepal length vs. sepal width plane, and the training points are plotted on top as colored dots. This visualization helps us understand how the classifier carves up the feature space to make its predictions.
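If you're curious what happens conceptually, a decision-boundary plot like this evaluates the classifier on a dense grid of points and colors each grid point by its prediction. A rough plain-matplotlib equivalent (a sketch of the idea, not Yellowbrick's actual implementation):
import numpy as np
knn2d = KNeighborsClassifier(n_neighbors=5).fit(X_train[:, :2], y_train)
x_min, x_max = X_train[:, 0].min() - 1, X_train[:, 0].max() + 1
y_min, y_max = X_train[:, 1].min() - 1, X_train[:, 1].max() + 1
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200), np.linspace(y_min, y_max, 200))
Z = knn2d.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)   # class for every grid point
plt.contourf(xx, yy, Z, alpha=0.3)                                   # shade the decision regions
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolor='k')  # overlay training points
plt.xlabel(feature_names[0])
plt.ylabel(feature_names[1])
plt.show()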