# Import necessary libraries
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, auc, confusion_matrix
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider
import seaborn as sns
Introduction
In this week’s discussion section, we will create some plots to better understand how much class imbalance can affect our classification model. Rather than creating a widget that updates the parameters of the model (like we have done in the past couple of weeks), this week we will create a widget that updates our data — specifically, the class imbalance within our data. To do so, we will use synthetic data made with the intention of better understanding how relationships within data work for logistic regression. It is important to note that your results with real data may look very different: unlike the data in this notebook, the real-world data you will be working with was not made to better understand logistic regression.
Data
While our data is synthetic, we will still give it environmental meaning. Our data will represent the presence/absence of the invasive European green crab, which is often found in California coastal waters. These crabs prefer warmer water temperatures between 64 °F and 79 °F and salinity levels between 26 and 39 ppt. The features for our data will be water temperature and salinity, and our target variable will be the presence (1) or absence (0) of green crabs at our different sampling sites. Import the libraries above and copy the data-generation function below to get started.
Time for some FUN(ctions)!
We will create six different functions for the different parts of our interactive output: one to generate the data, a second to create a barplot to represent the class imbalance, a third to create a confusion matrix, another to create an ROC curve, a function to wrap everything together, and a final function to add interactivity. Let’s get to it!
Function 1
Create a function that generates the species data. The parameters should be the sample size and the ratio of present green crabs.
def generate_species_data(n_samples=1000, presence_ratio=0.3):
    """Generate synthetic green-crab presence/absence data.

    Parameters
    ----------
    n_samples : int
        Total number of sampling sites to generate.
    presence_ratio : float
        Fraction of sites (between 0 and 1) where green crabs are present.

    Returns
    -------
    X : np.ndarray, shape (n_samples, 2)
        Features; column 0 is water temperature (°F), column 1 is salinity (ppt).
    y : np.ndarray, shape (n_samples,)
        Labels: 1.0 = present, 0.0 = absent.

    Raises
    ------
    ValueError
        If ``presence_ratio`` is outside [0, 1] or ``n_samples`` is negative.
    """
    if not 0 <= presence_ratio <= 1:
        raise ValueError(f"presence_ratio must be between 0 and 1, got {presence_ratio}")
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")

    # Number of sites in each class. round() rather than int() so that float
    # products such as 100 * 0.29 == 28.999... give 29 sites, not 28.
    n_present = int(round(n_samples * presence_ratio))
    n_absent = n_samples - n_present

    # Presence sites: green crabs prefer warmer water (between 64 and 79 °F)
    # and salinity between 26 and 39 ppt.
    temp_present = np.random.normal(loc=71, scale=4, size=n_present)
    salinity_present = np.random.normal(loc=32, scale=3, size=n_present)
    X_present = np.column_stack([temp_present, salinity_present])
    y_present = np.ones(n_present)

    # Absence sites: colder water and lower salinity than presence sites.
    temp_absent = np.random.normal(loc=26, scale=3, size=n_absent)
    salinity_absent = np.random.normal(loc=28, scale=2, size=n_absent)
    X_absent = np.column_stack([temp_absent, salinity_absent])
    y_absent = np.zeros(n_absent)

    # Combine both classes.
    X = np.concatenate([X_present, X_absent])
    y = np.concatenate([y_present, y_absent])

    # Shuffle so the classes are not stored as two contiguous runs.
    shuffle_idx = np.random.permutation(n_samples)
    return X[shuffle_idx], y[shuffle_idx]
Function 2
Create a function that creates a bar plot of the species presence distribution based on the ratio selected by the user.
def plot_class_distribution(y):
= (8,4))
plt.figure(figsize
# Count the values in each category
= pd.Series(y).value_counts()
class_counts
# Create the barplot of Absent and Present species
= ['Absent', 'Present'], y = class_counts, color = '#005477')
sns.barplot(x 'Distribution of Species Presence/Absence')
plt.title('Number of Sampling sites')
plt.ylabel(
# Add percent over each bar
= len(y)
total for i,count in enumerate(class_counts):
= count/total * 100
percentage f'{percentage:.1f}%', ha = 'center', va = 'bottom')
plt.text(i, count, plt.show()
Function 3
Create a function that plots a confusion matrix of the predicted y values and true y values.
def plot_confusion_matrix(y_true, y_pred):
# Create confusion matrix
= confusion_matrix(y_true, y_pred)
cm
# Create confusion matrix plot
= (8,6))
plt.figure(figsize = 'd', cmap = 'GnBu',annot = True,
sns.heatmap(cm, fmt = ['Absent', 'Present'],
xticklabels = ['Absent', 'Present'])
yticklabels 'Confusion Matrix')
plt.title('True Label')
plt.ylabel('Predicted Label')
plt.xlabel(
plt.show()
# Calculate and display metrics
= cm[1,1]
TP = cm[0,0]
TN = cm[0,1]
FP = cm[1,0]
FN
print(f"True positives (correctly predicted presence): {TP}")
# Calculate accuracy + various metric
= (TP + TN) / (TP + TN + FP + FN)
accuracy = TP/ (TP + FN)
sensitivity = TN / (TN + FP )
specificity
print(f"\nModel Performance Metrics:")
print(f"Accuracy: {accuracy:.3f}")
print(f"Sensitivity ( True positive rate): {sensitivity:.3f}")
print(f"Specificity ( True negative rate:) {specificity:.3f}")
Function 4
Create a function that plots an ROC curve using the predicted y class probabilities and true y values.
def plot_roc_curve(y_test, y_pred_prob):
    """Plot the ROC curve and AUC for predicted presence probabilities.

    A dashed diagonal shows a random classifier (AUC = 0.5) for reference.

    Parameters
    ----------
    y_test : array-like
        True binary labels (0 = absent, 1 = present).
    y_pred_prob : array-like
        Predicted probability of presence (class 1) for each sample.

    If ``y_test`` contains only one class, the ROC curve is undefined; a
    message is printed and nothing is plotted.
    """
    if len(np.unique(y_test)) < 2:
        print("ROC curve is undefined: y_test contains only one class.")
        return

    fpr, tpr, _ = roc_curve(y_test, y_pred_prob)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
             label='Random Classifier (AUC = 0.5)')
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('ROC curve: Species Presence Prediction')
    plt.legend()
    plt.grid(True)
    plt.show()
Function 5
Create a function that runs a logistic regression and outputs the three plots you created above.
def interactive_logistic_regression(presence_ratio=0.3):
    """Fit a logistic regression on data with the given class balance and plot results.

    Generates fresh synthetic data, shows the class distribution, trains a
    logistic regression on 70% of the sites, and shows the test-set
    confusion matrix (with accuracy, sensitivity, specificity) and ROC curve.

    Parameters
    ----------
    presence_ratio : float
        Fraction of sampling sites where green crabs are present.
    """
    # Generate data with the class imbalance chosen by the user.
    X, y = generate_species_data(presence_ratio=presence_ratio)

    print("\nClass Distribution")
    plot_class_distribution(y)

    # stratify=y keeps the train and test sets at the same class balance as
    # the full data set, so the evaluation reflects the chosen ratio.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42, stratify=y
    )

    model = LogisticRegression()
    model.fit(X_train, y_train)

    # Hard class predictions for the confusion matrix; class-1 probabilities
    # for the ROC curve.
    y_pred = model.predict(X_test)
    y_pred_prob = model.predict_proba(X_test)[:, 1]

    print("\nConfusion matrix:")
    plot_confusion_matrix(y_test, y_pred)

    print("\nROC Curve:")
    plot_roc_curve(y_test, y_pred_prob)
Function 6
Create a function that adds interactivity to function 5.
# Create interactive widget
def generate_log_regression():
    """Display a slider that re-runs ``interactive_logistic_regression``.

    The "% Present" slider ranges from 0.1 to 0.9 in steps of 0.1 (default
    0.3); each new value regenerates the data and redraws every plot.
    """
    interact(
        interactive_logistic_regression,
        presence_ratio=FloatSlider(
            min=0.1,
            max=0.9,
            step=0.1,
            value=0.3,
            description="% Present",
            # Show the fraction as a percentage to match the "% Present" label.
            readout_format='.0%',
            # Refit only when the slider is released, not on every drag tick;
            # each update retrains the model and redraws three figures.
            continuous_update=False,
        ),
    )


generate_log_regression()