Multinomial Logistic Regression Essentials in R
Multinomial logistic regression extends logistic regression (Chapter @ref(logistic-regression)) to multiclass classification tasks. It is used when the outcome has more than two classes.
In this chapter, we’ll show you how to compute multinomial logistic regression in R.
Loading required R packages
tidyverse for easy data manipulation
caret for easy predictive modeling
nnet for computing multinomial logistic regression
library(tidyverse)
library(caret)
library(nnet)
Preparing the data
We’ll use the iris data set, introduced in Chapter @ref(classification-in-r), for predicting iris species based on the predictor variables Sepal.Length, Sepal.Width, Petal.Length, and Petal.Width.
We start by randomly splitting the data into a training set (80%, for building the predictive model) and a test set (20%, for evaluating the model). Set a seed so that the split is reproducible.
# Load the data
data("iris")
# Inspect the data
sample_n(iris, 3)
# Split the data into training and test set
set.seed(123)
training.samples <- iris$Species %>%
createDataPartition(p = 0.8, list = FALSE)
train.data <- iris[training.samples, ]
test.data <- iris[-training.samples, ]
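Before fitting a model, it can help to confirm that the split behaves as intended. The following is a minimal sketch that repeats the split above and then tabulates the species in each subset. It assumes only the caret package and the same seed. Because createDataPartition() samples within each level of a factor outcome, the class proportions in the training set should be close to those in the full data.

```r
library(caret)

# Recreate the 80/20 split used in this chapter
data("iris")
set.seed(123)
training.samples <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train.data <- iris[training.samples, ]
test.data <- iris[-training.samples, ]

# Number of observations in each subset
nrow(train.data)
nrow(test.data)

# Species counts in each subset; the split is stratified on Species
table(train.data$Species)
table(test.data$Species)
```

If one species were badly under-represented in either table, the accuracy computed later on the test set would be harder to interpret.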
Computing multinomial logistic regression
# Fit the model
model <- nnet::multinom(Species ~., data = train.data)
# Summarize the model
summary(model)
# Make predictions
predicted.classes <- model %>% predict(test.data)
head(predicted.classes)
# Model accuracy
mean(predicted.classes == test.data$Species)
## [1] 0.967
Our model is very good at predicting the different species, with an accuracy of about 97% on the test set.
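A single accuracy figure does not show which species are confused with one another. The sketch below, which is an illustration rather than part of the original analysis, refits the same model and then cross-tabulates predicted against observed species. It also requests the predicted class probabilities through the type = "probs" option of predict(). The object names conf.mat and predicted.probs are introduced here for illustration, and trace = FALSE is used only to silence the optimizer output.

```r
library(caret)
library(nnet)

# Recreate the split and the model from this chapter
data("iris")
set.seed(123)
training.samples <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train.data <- iris[training.samples, ]
test.data <- iris[-training.samples, ]
model <- multinom(Species ~ ., data = train.data, trace = FALSE)

# Predicted classes on the test set
predicted.classes <- predict(model, test.data)

# Confusion matrix: rows are predictions, columns are observed species
conf.mat <- table(Predicted = predicted.classes, Actual = test.data$Species)
conf.mat

# Predicted probabilities, one column per species; each row sums to 1
predicted.probs <- predict(model, test.data, type = "probs")
head(round(predicted.probs, 3))
```

The diagonal of conf.mat counts the correct predictions, so sum(diag(conf.mat)) / sum(conf.mat) reproduces the accuracy computed above. Off-diagonal cells show which pairs of species the model mixes up.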
Discussion
This chapter described how to compute multinomial logistic regression in R, a method for classification problems with more than two classes. In practice, it is used less often than discriminant analysis (Chapter @ref(discriminant-analysis)), which is a more popular choice for multiclass classification.