Multinomial Logistic Regression Essentials in R

Multinomial logistic regression extends logistic regression (Chapter @ref(logistic-regression)) to multiclass classification tasks. It is used when the outcome variable has more than two classes.

In this chapter, we’ll show you how to compute multinomial logistic regression in R.
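
As background, the following is a minimal sketch of how a fitted multinomial logistic model turns predictor values into class probabilities. It assumes the usual baseline-category setup, in which one class serves as the reference and has all coefficients fixed at zero. The coefficient values, the class labels (a, b, c) and the predictors (x1, x2) are hypothetical and chosen only for illustration.

```r
# Hypothetical coefficients for a 3-class problem with 2 predictors.
# Class "a" is the reference class, so its coefficients are all zero.
coefs <- rbind(
  a = c(intercept =  0.0, x1 = 0.0, x2 =  0.0),
  b = c(intercept =  0.5, x1 = 1.2, x2 = -0.7),
  c = c(intercept = -1.0, x1 = 2.0, x2 =  0.3)
)

# One made-up observation; the leading 1 multiplies the intercept
x <- c(1, x1 = 0.4, x2 = 1.5)

# Linear score of each class for this observation
scores <- drop(coefs %*% x)

# Softmax: exponentiate the scores and normalize so they sum to 1
probs <- exp(scores) / sum(exp(scores))
probs

# The predicted class is the one with the highest probability
predicted <- names(which.max(probs))
predicted
```

The model fitted later in this chapter estimates such coefficients from the training data instead of using made-up values.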

Loading required R packages

  • tidyverse for easy data manipulation
  • caret for easy predictive modeling
  • nnet for computing multinomial logistic regression

library(tidyverse)
library(caret)
library(nnet)

Preparing the data

We’ll use the iris data set, introduced in Chapter @ref(classification-in-r), for predicting iris species based on the predictor variables Sepal.Length, Sepal.Width, Petal.Length, Petal.Width.

We start by randomly splitting the data into a training set (80%, used to build the predictive model) and a test set (20%, used to evaluate the model). Make sure to set the seed for reproducibility.

# Load the data
data("iris")
# Inspect the data
sample_n(iris, 3)
# Split the data into training and test set
set.seed(123)
training.samples <- iris$Species %>% 
  createDataPartition(p = 0.8, list = FALSE)
train.data  <- iris[training.samples, ]
test.data <- iris[-training.samples, ]
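
When the outcome is a factor, createDataPartition() samples within each class, so the split should preserve the class proportions. A quick way to check this is to tabulate the species in each subset. The sketch below repeats the split so that it runs on its own; the variable name idx is only illustrative.

```r
library(caret)

data("iris")
set.seed(123)

# Stratified 80/20 split on the outcome variable
idx <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train.data <- iris[idx, ]
test.data  <- iris[-idx, ]

# Number of observations per species in each subset
table(train.data$Species)
table(test.data$Species)

# Class proportions in the training set
prop.table(table(train.data$Species))
```

If the proportions in the two subsets differ noticeably from those in the full data, the accuracy measured on the test set may not be representative.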

Computing multinomial logistic regression

# Fit the model (the first factor level, setosa, is used as the reference class)
model <- nnet::multinom(Species ~ ., data = train.data)
# Summarize the model
summary(model)
# Make predictions
predicted.classes <- model %>% predict(test.data)
head(predicted.classes)
# Model accuracy
mean(predicted.classes == test.data$Species)

Model accuracy:

mean(predicted.classes == test.data$Species)
## [1] 0.967

Our model is very good at predicting the different categories, with an accuracy of about 97% on the test set.
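
Overall accuracy is a single number and does not show which species are confused with each other. One possible way to look further, sketched below, is to request the predicted class probabilities with type = "probs" and to cross-tabulate the predicted and observed classes. The block repeats the split and the fit so that it runs on its own; the names idx, probs and conf.mat are illustrative, and trace = FALSE only silences the optimizer output.

```r
library(caret)
library(nnet)

data("iris")
set.seed(123)
idx <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
train.data <- iris[idx, ]
test.data  <- iris[-idx, ]

# Fit the model without printing the iteration log
model <- multinom(Species ~ ., data = train.data, trace = FALSE)

# Predicted probability of each species for each test observation
probs <- predict(model, newdata = test.data, type = "probs")
head(round(probs, 3))

# Predicted classes and confusion matrix (rows: predicted, columns: observed)
predicted.classes <- predict(model, newdata = test.data, type = "class")
conf.mat <- table(Predicted = predicted.classes, Actual = test.data$Species)
conf.mat

# Accuracy recomputed from the diagonal of the confusion matrix
sum(diag(conf.mat)) / sum(conf.mat)
```

Off-diagonal cells of the table show which species are mistaken for which.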

Discussion

This chapter described how to compute multinomial logistic regression in R, a method for classification problems with more than two classes. In practice, it is not used very often; discriminant analysis (Chapter @ref(discriminant-analysis)) is more popular for multiclass classification.