Explainability Blog with Detailed Code
Here is detailed code and a notebook that walk through several explainability methods, including DALEX, LIME, Break Down, and Shapley values. With them, one can describe how individual features and feature interactions affect the target variable. The example predicts wine quality from the wine's physicochemical features.
R Notebook
Goal: Explain which features make a product high quality.
Example question: What properties make a good wine?
Wine has been produced and enjoyed for thousands of years by people of many cultures and age groups. The wine market is worth roughly 400B, and companies across the world compete to produce better wine and gain market share.
However, there is no consensus on what defines a good-quality wine; quality is hard to put into words. To explain what makes a good wine, we study a wine dataset and build an explanatory model.
Dataset
“Wine Quality” dataset from the UC Irvine Machine Learning Repository
#Tutorial
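#install.packages(c("tidyverse", "caret", "randomForest", "gridExtra"))
#install.packages("ggridges")
#install.packages("ggthemes")
#install.packages("ggcorrplot")
#install.packages("kableExtra")
#install.packages("RColorBrewer")
#install.packages("iml")
#install.packages("breakDown")
#install.packages("DALEX")
#install.packages("glmnet")
#install.packages("partykit")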
# data wrangling
library(tidyverse)
library(readr)
# ml
library(caret)
## Loading required package: lattice
##
## Attaching package: 'caret'
## The following objects are masked from 'package:MLmetrics':
##
## MAE, RMSE
## The following object is masked from 'package:purrr':
##
## lift
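# plotting
library(gridExtra)
library(grid)
library(ggridges)
library(ggthemes)
library(ggcorrplot)    # ggcorrplot() for the correlation heatmap (Figure 2)
library(kableExtra)    # kbl(), kable_classic(), kable_styling() for Table 1
library(rbokeh)        # figure(), ly_hist(), ly_density(), grid_plot() for the histograms
library(RColorBrewer)  # brewer.pal()
theme_set(theme_minimal())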
# explaining models
# https://github.com/christophM/iml
library(iml)
# https://pbiecek.github.io/breakDown/
library(breakDown)
# https://pbiecek.github.io/DALEX/
library(DALEX)
## Welcome to DALEX (version: 2.3.0).
## Find examples and detailed introduction at: http://ema.drwhy.ai/
## Additional features will be available after installation of: ggpubr.
## Use 'install_dependencies()' to get all suggested dependencies
##
## Attaching package: 'DALEX'
## The following object is masked from 'package:dplyr':
##
## explain
library(partykit)
## Loading required package: libcoin
## Loading required package: mvtnorm
library(libcoin)
library(mvtnorm)
Overview
We first load, clean, and explore the data (including simple linear trends of each feature against quality), then build a random forest classifier and explain it with:
- Feature importance
- Partial dependence plots
- Individual conditional expectation plots (ICE)
- Tree surrogate
- LocalModel: Local Interpretable Model-agnostic Explanations (similar to lime)
- Shapley value for explaining single predictions
Load the data
# Load and clean data: stack the red and white wine files and tag each row with its type
clean_data <- function(red_path = "data/winequality-red.csv",
                       white_path = "data/winequality-white.csv") {
  red_wine_df <- read_delim(red_path, delim = ";")
  red_wine_df['wine_type'] <- 'red'
  white_wine_df <- read_delim(white_path, delim = ";")
  white_wine_df['wine_type'] <- 'white'
  # keep only valid quality scores (0-10) and drop rows with missing values
  wine_df <- bind_rows(red_wine_df, white_wine_df) %>%
    filter(quality >= 0 & quality <= 10) %>%
    drop_na()
  return(wine_df)
}
wine_df <- clean_data()
## Rows: 1599 Columns: 12
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ";"
## dbl (12): fixed acidity, volatile acidity, citric acid, residual sugar, chlo...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
## Rows: 4898 Columns: 12
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ";"
## dbl (12): fixed acidity, volatile acidity, citric acid, residual sugar, chlo...
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
# Binary target: quality scores below 6 are "qual_low", 6 and above are "qual_high"
wine_df = wine_df %>%
  mutate(quality_cat = as.factor(ifelse(quality < 6, "qual_low", "qual_high")))
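The cleaning and labelling logic above can be sketched in base R on a tiny, made-up pair of tables (the values are illustrative, not rows from the UCI files). It stacks the red and white rows, tags each with wine_type, drops invalid scores and missing values, and applies the same quality < 6 cut-off. Note that factor() sorts levels alphabetically, so qual_high is the first level and qual_low the second; this matters later, when the notebook takes column 2 of the predicted class probabilities.

```r
# Base-R sketch of the cleaning + labelling steps (made-up rows, not the UCI data)
red   <- data.frame(alcohol = c(9.4, 9.8, NA), quality = c(5, 6, 7))
white <- data.frame(alcohol = c(10.5, 12.0),   quality = c(6, 11))
red$wine_type   <- "red"
white$wine_type <- "white"

# Stack the two tables (the tidyverse version uses bind_rows)
wine <- rbind(red, white)

# Keep valid scores (0-10), then drop rows with any missing value
wine <- wine[wine$quality >= 0 & wine$quality <= 10, ]
wine <- wine[complete.cases(wine), ]

# Binary target: below 6 is low quality, 6 and above is high quality
wine$quality_cat <- factor(ifelse(wine$quality < 6, "qual_low", "qual_high"))
print(wine)
print(levels(wine$quality_cat))  # "qual_high" "qual_low"
```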
Data Exploration
“Table 1” shows all the variables in our data along with the first few rows. The variables are defined as follows:
fixed_acidity - acids involved with wine that are fixed (don’t evaporate readily)
volatile_acidity - the amount of acetic acid in wine, which at high levels can lead to an unpleasant, vinegar taste
citric_acid - weak organic acid that occurs naturally in citrus fruits and can add ‘freshness’ and flavor to wines
residual_sugar - the natural grape sugar left over after fermentation stops. It is rare to find wines with less than 1 gram/liter, and wines with more than 45 grams/liter are considered sweet
chlorides - the amount of salt in the wine
free_sulfur_dioxide - the free form of SO2, which exists in equilibrium between molecular SO2 (as a dissolved gas) and bisulfite ion. It has both germicidal and antioxidant properties
total_sulfur_dioxide - amount of free and bound forms of SO2
density - the density of the wine, which is close to that of water and depends on the alcohol and sugar content
pH - from a winemaker’s point of view, a way to measure ripeness in relation to acidity
sulphates - a wine additive which can contribute to sulfur dioxide gas (SO2) levels. It acts as an antimicrobial and antioxidant
alcohol - the percent alcohol content of the wine
quality - output variable (a score between 0 and 10)
All 6,497 rows (1,599 red and 4,898 white wines) pass the cleaning step. We use 70% of them (4,549 rows) for training and the remaining 30% (1,948 rows) for testing.
colnames(wine_df) = gsub(" ", "_", colnames(wine_df))
glimpse(wine_df)
## Rows: 6,497
## Columns: 14
## $ fixed_acidity <dbl> 7.4, 7.8, 7.8, 11.2, 7.4, 7.4, 7.9, 7.3, 7.8, 7.5…
## $ volatile_acidity <dbl> 0.700, 0.880, 0.760, 0.280, 0.700, 0.660, 0.600, …
## $ citric_acid <dbl> 0.00, 0.00, 0.04, 0.56, 0.00, 0.00, 0.06, 0.00, 0…
## $ residual_sugar <dbl> 1.9, 2.6, 2.3, 1.9, 1.9, 1.8, 1.6, 1.2, 2.0, 6.1,…
## $ chlorides <dbl> 0.076, 0.098, 0.092, 0.075, 0.076, 0.075, 0.069, …
## $ free_sulfur_dioxide <dbl> 11, 25, 15, 17, 11, 13, 15, 15, 9, 17, 15, 17, 16…
## $ total_sulfur_dioxide <dbl> 34, 67, 54, 60, 34, 40, 59, 21, 18, 102, 65, 102,…
## $ density <dbl> 0.9978, 0.9968, 0.9970, 0.9980, 0.9978, 0.9978, 0…
## $ pH <dbl> 3.51, 3.20, 3.26, 3.16, 3.51, 3.51, 3.30, 3.39, 3…
## $ sulphates <dbl> 0.56, 0.68, 0.65, 0.58, 0.56, 0.56, 0.46, 0.47, 0…
## $ alcohol <dbl> 9.4, 9.8, 9.8, 9.8, 9.4, 9.4, 9.4, 10.0, 9.5, 10.…
## $ quality <dbl> 5, 5, 5, 6, 5, 5, 5, 7, 7, 5, 5, 5, 5, 5, 5, 5, 7…
## $ wine_type <chr> "red", "red", "red", "red", "red", "red", "red", …
## $ quality_cat <fct> qual_low, qual_low, qual_low, qual_high, qual_low…
p1 = wine_df %>%
  ggplot(aes(x = quality, fill = factor(quality))) +
  geom_bar(alpha = 0.8) +
  scale_fill_tableau() +
  guides(fill = "none")
p1
p2 = wine_df %>%
  ggplot(aes(x = quality_cat, fill = quality_cat)) +
  geom_bar(alpha = 0.8) +
  scale_fill_tableau() +
  guides(fill = "none")
p2
p3 = wine_df %>%
  gather(x, y, fixed_acidity:alcohol) %>%
  ggplot(aes(x = y, y = quality_cat, color = quality_cat, fill = quality_cat)) +
  facet_wrap(~ x, scales = "free", ncol = 4) +
  scale_color_tableau() +
  scale_fill_viridis_d(direction = -1) +
  geom_density_ridges(alpha = 0.7) +
  guides(fill = "none", color = "none") +
  theme(plot.title = element_text(size = 24, hjust = 0.5)) +
  labs(title = "Relationship between Quality and Features", y = "Quality")
p3
## Picking joint bandwidth of 0.182
## Picking joint bandwidth of 0.00375
## Picking joint bandwidth of 0.0211
## Picking joint bandwidth of 0.000499
## Picking joint bandwidth of 0.168
## Picking joint bandwidth of 3.28
## Picking joint bandwidth of 0.0278
## Picking joint bandwidth of 0.855
## Picking joint bandwidth of 0.0214
## Picking joint bandwidth of 10.2
## Picking joint bandwidth of 0.0266
#grid.arrange(p1, p2, ncol = 2, widths = c(0.3, 0.7))
wine_df2 <- wine_df[c ('fixed_acidity' ,'volatile_acidity','citric_acid',
'residual_sugar', 'chlorides','free_sulfur_dioxide',
'total_sulfur_dioxide', 'density',
'pH', 'sulphates','alcohol', 'quality_cat' )]
glimpse(wine_df2)
## Rows: 6,497
## Columns: 12
## $ fixed_acidity <dbl> 7.4, 7.8, 7.8, 11.2, 7.4, 7.4, 7.9, 7.3, 7.8, 7.5…
## $ volatile_acidity <dbl> 0.700, 0.880, 0.760, 0.280, 0.700, 0.660, 0.600, …
## $ citric_acid <dbl> 0.00, 0.00, 0.04, 0.56, 0.00, 0.00, 0.06, 0.00, 0…
## $ residual_sugar <dbl> 1.9, 2.6, 2.3, 1.9, 1.9, 1.8, 1.6, 1.2, 2.0, 6.1,…
## $ chlorides <dbl> 0.076, 0.098, 0.092, 0.075, 0.076, 0.075, 0.069, …
## $ free_sulfur_dioxide <dbl> 11, 25, 15, 17, 11, 13, 15, 15, 9, 17, 15, 17, 16…
## $ total_sulfur_dioxide <dbl> 34, 67, 54, 60, 34, 40, 59, 21, 18, 102, 65, 102,…
## $ density <dbl> 0.9978, 0.9968, 0.9970, 0.9980, 0.9978, 0.9978, 0…
## $ pH <dbl> 3.51, 3.20, 3.26, 3.16, 3.51, 3.51, 3.30, 3.39, 3…
## $ sulphates <dbl> 0.56, 0.68, 0.65, 0.58, 0.56, 0.56, 0.46, 0.47, 0…
## $ alcohol <dbl> 9.4, 9.8, 9.8, 9.8, 9.4, 9.4, 9.4, 10.0, 9.5, 10.…
## $ quality_cat <fct> qual_low, qual_low, qual_low, qual_high, qual_low…
# Remove the categorical target column
wine_df_num = subset(wine_df2, select = -c(quality_cat))
# One rbokeh histogram with a density overlay per numeric feature
histograms <- lapply(names(wine_df_num), function(col) {
  x <- wine_df_num[[col]]
  figure(xlab = col, width = 400, height = 250) %>%
    ly_hist(x, breaks = 40, freq = FALSE,
            color = brewer.pal(9, "GnBu")) %>%
    ly_density(x)
})
grid_plot(histograms, nrow = 6)
Build Model
Build Train and Test Sets
# 70/30 split, sampled within each quality_cat level so both sets keep the class balance
set.seed(42)
idx = createDataPartition(wine_df2$quality_cat,
                          p = 0.7,
                          list = FALSE,
                          times = 1)
wine_train = wine_df2[ idx,]
wine_test = wine_df2[-idx,]
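createDataPartition samples within each level of quality_cat, so the training and test sets keep roughly the same class balance as the full data. A hypothetical base-R helper, stratified_split (not caret's implementation), sketches the idea on a made-up label vector; caret's exact rounding of the per-class sample sizes may differ.

```r
# Hypothetical helper (not caret's code): stratified sampling of row indices
stratified_split <- function(y, p = 0.7, seed = 42) {
  set.seed(seed)
  by_class <- split(seq_along(y), y)   # row indices grouped by class level
  picked <- lapply(by_class, function(rows) {
    # index into `rows` so a class with a single row is still handled correctly
    rows[sample.int(length(rows), size = round(p * length(rows)))]
  })
  sort(unlist(picked, use.names = FALSE))
}

y <- factor(c(rep("qual_high", 60), rep("qual_low", 40)))  # made-up labels, 60/40 balance
train_idx <- stratified_split(y)
print(table(y[train_idx]))   # 42 qual_high, 28 qual_low: the 60/40 balance is kept
print(table(y[-train_idx]))  # 18 qual_high, 12 qual_low
```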
options(knitr.table.format = "latex")
head(wine_df2) %>%
kbl(caption = "Summary Table of Wine Dataset") %>%
kable_classic(html_font = "Cambria", full_width = F) %>%
kable_styling(latex_options = c("striped", "scale_down"))
fixed_acidity | volatile_acidity | citric_acid | residual_sugar | chlorides | free_sulfur_dioxide | total_sulfur_dioxide | density | pH | sulphates | alcohol | quality_cat |
---|---|---|---|---|---|---|---|---|---|---|---|
7.4 | 0.70 | 0.00 | 1.9 | 0.076 | 11 | 34 | 0.9978 | 3.51 | 0.56 | 9.4 | qual_low |
7.8 | 0.88 | 0.00 | 2.6 | 0.098 | 25 | 67 | 0.9968 | 3.20 | 0.68 | 9.8 | qual_low |
7.8 | 0.76 | 0.04 | 2.3 | 0.092 | 15 | 54 | 0.9970 | 3.26 | 0.65 | 9.8 | qual_low |
11.2 | 0.28 | 0.56 | 1.9 | 0.075 | 17 | 60 | 0.9980 | 3.16 | 0.58 | 9.8 | qual_high |
7.4 | 0.70 | 0.00 | 1.9 | 0.076 | 11 | 34 | 0.9978 | 3.51 | 0.56 | 9.4 | qual_low |
7.4 | 0.66 | 0.00 | 1.8 | 0.075 | 13 | 40 | 0.9978 | 3.51 | 0.56 | 9.4 | qual_low |
# Figure 2: Pearson correlation between the numeric features
corr=cor(wine_df_num, method = "pearson")
ggcorrplot(corr, hc.order = TRUE,
lab = TRUE,
lab_size = 3,
method="square",
colors = c("tomato2", "white", "springgreen3"),
title="Figure 2: Correlation of Variables")
# Figure 3: box plot of each feature by quality score, with a linear trend line.
# factor(quality) places the scores at x positions 1, 2, ..., so the trend line uses
# quality - (min(quality) - 1) to line up with the boxes.
exploratory_data_wine <- wine_df
qual_offset <- min(exploratory_data_wine$quality) - 1
box_features <- c("alcohol", "sulphates", "pH", "density", "total_sulfur_dioxide",
                  "free_sulfur_dioxide", "chlorides", "residual_sugar",
                  "citric_acid", "volatile_acidity", "fixed_acidity")
box_plots <- lapply(box_features, function(feature) {
  ggplot(exploratory_data_wine, aes(factor(quality), .data[[feature]])) +
    geom_boxplot() +
    geom_smooth(aes(quality - qual_offset, .data[[feature]]), method = 'lm', color = 'red') +
    labs(x = "quality", y = feature)
})
grid.arrange(grobs = box_plots, nrow = 4, ncol = 3,
             top = "Figure 3: Box plot to show quality with each variable")
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
## `geom_smooth()` using formula 'y ~ x'
Random Forest Model
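# 5-fold cross-validation, repeated 3 times, used to tune mtry
fit_control = trainControl(method = "repeatedcv",
                           number = 5,
                           repeats = 3)
set.seed(42)
# Random forest (method = "rf") predicting the binary quality class from all 11 features
rf_model = caret::train(quality_cat ~ .,
                        data = wine_train,
                        method = "rf",
                        preProcess = c("scale", "center"),
                        trControl = fit_control,
                        verbose = FALSE)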
rf_model
## Random Forest
##
## 4549 samples
## 11 predictor
## 2 classes: 'qual_high', 'qual_low'
##
## Pre-processing: scaled (11), centered (11)
## Resampling: Cross-Validated (5 fold, repeated 3 times)
## Summary of sample sizes: 3639, 3640, 3639, 3639, 3639, 3639, ...
## Resampling results across tuning parameters:
##
## mtry Accuracy Kappa
## 2 0.8125606 0.5871234
## 6 0.8071389 0.5774126
## 11 0.8073591 0.5788010
##
## Accuracy was used to select the optimal model using the largest value.
## The final value used for the model was mtry = 2.
test_predict = predict(rf_model, wine_test)
confusionMatrix(test_predict, as.factor(wine_test$quality_cat))
## Confusion Matrix and Statistics
##
## Reference
## Prediction qual_high qual_low
## qual_high 1101 187
## qual_low 132 528
##
## Accuracy : 0.8362
## 95% CI : (0.819, 0.8524)
## No Information Rate : 0.633
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.6418
##
## Mcnemar's Test P-Value : 0.002499
##
## Sensitivity : 0.8929
## Specificity : 0.7385
## Pos Pred Value : 0.8548
## Neg Pred Value : 0.8000
## Prevalence : 0.6330
## Detection Rate : 0.5652
## Detection Prevalence : 0.6612
## Balanced Accuracy : 0.8157
##
## 'Positive' Class : qual_high
##
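The headline statistics follow directly from the four counts in the table. With qual_high as the positive class, this base-R check reproduces the reported accuracy, sensitivity, specificity, predictive values, prevalence, balanced accuracy, and Cohen's kappa from those counts alone.

```r
# Counts from the confusion matrix above (rows = prediction, columns = reference),
# with qual_high as the positive class
tp <- 1101  # predicted qual_high, actually qual_high
fp <- 187   # predicted qual_high, actually qual_low
fn <- 132   # predicted qual_low,  actually qual_high
tn <- 528   # predicted qual_low,  actually qual_low
n  <- tp + fp + fn + tn

sensitivity <- tp / (tp + fn)   # share of high-quality wines identified
specificity <- tn / (tn + fp)   # share of low-quality wines identified
metrics <- c(
  accuracy          = (tp + tn) / n,
  sensitivity       = sensitivity,
  specificity       = specificity,
  pos_pred_value    = tp / (tp + fp),
  neg_pred_value    = tn / (tn + fn),
  prevalence        = (tp + fn) / n,
  balanced_accuracy = (sensitivity + specificity) / 2
)

# Cohen's kappa: observed agreement corrected for the agreement expected by chance
p_obs <- (tp + tn) / n
p_exp <- ((tp + fp) * (tp + fn) + (fn + tn) * (fp + tn)) / n^2
cohen_kappa <- (p_obs - p_exp) / (1 - p_exp)

print(round(c(metrics, kappa = cohen_kappa), 4))
# 0.8362 0.8929 0.7385 0.8548 0.8000 0.6330 0.8157 0.6418, as in the output above
```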
Feature Importance
Compute the random forest's feature importance with caret's varImp (scaled to 0-100).
rf_model_imp = varImp(rf_model, scale = TRUE)
p1 = rf_model_imp$importance %>%
as.data.frame() %>%
rownames_to_column() %>%
ggplot(aes(x = reorder(rowname, Overall), y = Overall)) +
geom_bar(stat = "identity", fill = "#1F77B4", alpha = 0.8) +
coord_flip()
Plot Feature Importance
According to the graph below, alcohol is the most important feature for predicting wine quality.
p1
Feature Importance Breakdown by Category
Feature importance may differ between low-quality and high-quality wine. The plots below show the overall random forest importance next to a model-free importance score (ROC AUC from caret's filterVarImp) for each category.
# Model-free importance: filterVarImp scores each predictor by its ROC AUC, one column per class
roc_imp = filterVarImp(x = wine_train[, -ncol(wine_train)], y = wine_train$quality_cat)
p2 = roc_imp %>%
  as.data.frame() %>%
  rownames_to_column() %>%
  ggplot(aes(x = reorder(rowname, qual_high), y = qual_high)) +
  geom_bar(stat = "identity", fill = "#1F77B4", alpha = 0.8) +
  coord_flip()
p3 = roc_imp %>%
  as.data.frame() %>%
  rownames_to_column() %>%
  ggplot(aes(x = reorder(rowname, qual_low), y = qual_low)) +
  geom_bar(stat = "identity", fill = "#1F77B4", alpha = 0.8) +
  coord_flip()
grid.arrange(p1, p2, p3, ncol = 3, widths = c(0.5, 0.5, 0.5))
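For classification, caret's filterVarImp scores each predictor by the area under the ROC curve obtained when that predictor alone is used to separate the classes. A minimal base-R sketch of that statistic, using the rank-sum (Mann-Whitney) identity, is below. roc_auc is a hypothetical helper and the toy values are made up; caret's own implementation may handle the direction of the relationship differently.

```r
# AUC = probability that a random positive scores higher than a random negative,
# with ties counted as 1/2. Computed from the ranks of the scores.
roc_auc <- function(score, is_positive) {
  r  <- rank(score)               # average ranks take care of ties
  n1 <- sum(is_positive)
  n0 <- sum(!is_positive)
  (sum(r[is_positive]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Toy example: higher alcohol tends to go with qual_high (made-up values)
alcohol <- c(9.4, 9.8, 10.0, 11.2, 12.5, 13.0)
is_high <- c(FALSE, FALSE, TRUE, FALSE, TRUE, TRUE)
print(roc_auc(alcohol, is_high))  # 8 of the 9 (high, low) pairs are ordered correctly: 0.8889
```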
Iml package - Interpretable Machine Learning
The iml package in R explains and interprets machine learning models. It provides methods for:
- Feature importance
- Partial dependence plots (Feature Effect)
- Individual conditional expectation plots (ICE)
- Tree surrogate
- LocalModel: Local Interpretable Model-agnostic Explanations (similar to lime)
- Shapley value for explaining single predictions
Preparation for explainability
To explain the model using iml, we
a) remove the response variable (quality_cat) from the features, and
b) create a new Predictor object that holds the model, the data, and the class labels.
X = wine_train %>%
dplyr::select(-quality_cat) %>%
as.data.frame()
predictor = Predictor$new(rf_model, data = X, y = wine_train$quality_cat)
str(predictor)
## Classes 'Predictor', 'R6' <Predictor>
## Public:
## batch.size: 1000
## class: NULL
## clone: function (deep = FALSE)
## data: Data, R6
## initialize: function (model = NULL, data = NULL, predict.function = NULL,
## model: train, train.formula
## predict: function (newdata)
## prediction.colnames: NULL
## prediction.function: function (newdata)
## print: function ()
## task: classification
## Private:
## predictionChecked: FALSE
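A. Feature Effects (ALE) and Partial Dependence Plots
Here is a helper function that computes and plots the effect of one feature on the predicted class probabilities. FeatureEffect uses accumulated local effects (ALE) by default, which is why the results below have .type "ale"; pass method = "pdp" to get a partial dependence plot instead.
feature_imp <- function(my_predictor, my_data, my_feature, my_method = "ale") {
  # my_data is kept for the calls below; the Predictor object already holds the data
  pdp_obj = FeatureEffect$new(my_predictor, feature = my_feature, method = my_method)
  glimpse(pdp_obj$results)
  pdp_obj$plot()
}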
feature_imp(predictor, wine_train, "alcohol")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale",…
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, qual_low,…
## $ .value <dbl> -0.006762349, 0.006762349, -0.006762349, 0.006762349, -0.00676…
## $ alcohol <dbl> 8.0, 8.0, 8.9, 8.9, 9.1, 9.1, 9.2, 9.2, 9.4, 9.4, 9.5, 9.5, 9.…
feature_imp(predictor, wine_train, "alcohol")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale",…
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, qual_low,…
## $ .value <dbl> -0.0043493712, 0.0043493712, -0.0043493712, 0.0043493712, -0.0…
## $ alcohol <dbl> 8.0, 8.0, 8.9, 8.9, 9.1, 9.1, 9.2, 9.2, 9.4, 9.4, 9.5, 9.5, 9.…
feature_imp(predictor, wine_train, "volatile_acidity")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale…
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, …
## $ .value <dbl> 0.023142238, -0.023142238, 0.023142238, -0.023142238,…
## $ volatile_acidity <dbl> 0.08, 0.08, 0.16, 0.16, 0.18, 0.18, 0.20, 0.20, 0.21,…
feature_imp(predictor, wine_train, "density")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale",…
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, qual_low,…
## $ .value <dbl> 0.029165590, -0.029165590, 0.033532402, -0.033532402, 0.015833…
## $ density <dbl> 0.98711, 0.98711, 0.98999, 0.98999, 0.99070, 0.99070, 0.99132,…
feature_imp(predictor, wine_train, "total_sulfur_dioxide")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", …
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_hi…
## $ .value <dbl> -0.016019079, 0.016019079, -0.016019079, 0.016019…
## $ total_sulfur_dioxide <dbl> 6, 6, 19, 19, 30, 30, 44, 44, 61, 61, 78, 78, 89,…
feature_imp(predictor, wine_train, "sulphates")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale…
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, qual_lo…
## $ .value <dbl> -0.0043127417, 0.0043127417, -0.0004517378, 0.0004517378, -0…
## $ sulphates <dbl> 0.22, 0.22, 0.35, 0.35, 0.38, 0.38, 0.39, 0.39, 0.41, 0.41, …
feature_imp(predictor, wine_train, "citric_acid")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale", "a…
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, qual_…
## $ .value <dbl> 0.0001099143, -0.0001099143, 0.0001099143, -0.0001099143, …
## $ citric_acid <dbl> 0.00, 0.00, 0.04, 0.04, 0.14, 0.14, 0.20, 0.20, 0.23, 0.23…
feature_imp(predictor, wine_train, "residual_sugar")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale",…
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, qu…
## $ .value <dbl> -0.0104893047, 0.0104893047, 0.0004396571, -0.000439657…
## $ residual_sugar <dbl> 0.60, 0.60, 1.20, 1.20, 1.30, 1.30, 1.50, 1.50, 1.65, 1…
feature_imp(predictor, wine_train, "pH")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale", …
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, qual_low, …
## $ .value <dbl> 0.001836462, -0.001836462, 0.001836462, -0.001836462, 0.0018364…
## $ pH <dbl> 2.74, 2.74, 2.97, 2.97, 3.02, 3.02, 3.06, 3.06, 3.08, 3.08, 3.1…
feature_imp(predictor, wine_train, "fixed_acidity")
## Rows: 42
## Columns: 4
## $ .type <chr> "ale", "ale", "ale", "ale", "ale", "ale", "ale", "ale", …
## $ .class <fct> qual_high, qual_low, qual_high, qual_low, qual_high, qua…
## $ .value <dbl> 0.002523003, -0.002523003, 0.002523003, -0.002523003, 0.…
## $ fixed_acidity <dbl> 3.8, 3.8, 5.7, 5.7, 6.0, 6.0, 6.1, 6.1, 6.3, 6.3, 6.4, 6…
# FeatureEffects computes a separate effect curve (ALE by default) for each listed feature
two_feature_pdp = FeatureEffects$new(predictor, features = c("sulphates", "pH"))
two_feature_pdp$plot()
multiple_feature_pdp = FeatureEffects$new(predictor, features = c("volatile_acidity",
  "density", "total_sulfur_dioxide", "sulphates", "citric_acid",
  "residual_sugar", "pH", "fixed_acidity"))
multiple_feature_pdp$plot()
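The iml results above have .type "ale": FeatureEffect computes accumulated local effects (ALE) by default. Instead of forcing every row to the same feature value, as a partial dependence plot does, ALE averages prediction differences within small intervals of the feature and accumulates them, which avoids evaluating the model on unrealistic feature combinations when features are correlated. A minimal base-R sketch for one numeric feature follows; ale_1d and its predict_fun(data) interface are hypothetical, and iml's grid size and centering details may differ.

```r
# Minimal first-order ALE sketch for one numeric feature.
# predict_fun(data) is assumed to return one numeric prediction per row.
ale_1d <- function(predict_fun, data, feature, n_intervals = 10) {
  x <- data[[feature]]
  # interval edges at quantiles of the feature (duplicate edges dropped)
  z <- unique(quantile(x, probs = seq(0, 1, length.out = n_intervals + 1), names = FALSE))
  n_int <- length(z) - 1
  # interval index per row: (z[j], z[j + 1]], with the minimum placed in interval 1
  k <- findInterval(x, z, left.open = TRUE, rightmost.closed = TRUE)
  k <- pmin(pmax(k, 1L), n_int)   # guard against floating-point edge cases
  lo <- data
  hi <- data
  lo[[feature]] <- z[k]       # lower edge of each row's interval
  hi[[feature]] <- z[k + 1]   # upper edge of each row's interval
  # change in prediction when each row moves across its own interval
  diffs <- predict_fun(hi) - predict_fun(lo)
  counts <- tabulate(k, nbins = n_int)
  local <- vapply(seq_len(n_int), function(j)
    if (counts[j] > 0) mean(diffs[k == j]) else 0, numeric(1))
  # accumulate the average local effects, starting from 0 at the lowest edge
  ale <- c(0, cumsum(local))
  # center: the count-weighted mean of the curve (over interval midpoints) becomes 0
  ale <- ale - sum(counts * (ale[-1] + ale[-(n_int + 1)]) / 2) / sum(counts)
  data.frame(grid = z, ale = ale)
}

# Toy check with a known linear "black box": the ALE slope for x1 should be 2
set.seed(1)
toy <- data.frame(x1 = runif(200), x2 = rnorm(200))
linear_box <- function(d) 2 * d$x1 + 3 * d$x2
res <- ale_1d(linear_box, toy, "x1")
print(res)
```

To try it on the notebook's model, predict_fun could wrap something like function(d) predict(rf_model, d, type = "prob")[, "qual_high"]; that wrapper is an untested assumption.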
Feature Interaction
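Interaction strength between alcohol and every other feature, measured with Friedman's H-statistic (values near 0 mean little interaction)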
interact = Interaction$new(predictor, feature = "alcohol")
#plot(interact)
interact$results %>%
ggplot(aes(x = reorder(.feature, .interaction), y = .interaction, fill = .class)) +
facet_wrap(~ .class, ncol = 2) +
geom_bar(stat = "identity", alpha = 0.8) +
scale_fill_tableau() +
coord_flip() +
guides(fill = "none")
interact_pH = Interaction$new(predictor, feature = "pH")
#plot(interact)
interact_pH$results %>%
  ggplot(aes(x = reorder(.feature, .interaction), y = .interaction, color = .class)) +
  facet_wrap(~ .class, ncol = 2) +
  geom_point(size = 3, alpha = 0.8) +
  scale_color_tableau() +
  coord_flip() +
  guides(color = "none")
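iml's Interaction reports Friedman's H-statistic: how much of the joint partial dependence of two features is not explained by their separate (main) effects. A minimal base-R sketch of the two-way version is below. It evaluates partial dependence exactly at every data row, which costs O(n^2) predictions, whereas iml works on a grid and reports one value per class; the helper names and toy models are hypothetical.

```r
# Partial dependence at one point: fix `features` to `values` in every row, average predictions
pd_at <- function(predict_fun, data, features, values) {
  d <- data
  for (f in features) d[[f]] <- values[[f]]
  mean(predict_fun(d))
}

# Two-way H-statistic between features j and k (predict_fun returns one number per row)
h_statistic <- function(predict_fun, data, j, k) {
  n <- nrow(data)
  pd_jk <- pd_j <- pd_k <- numeric(n)
  for (i in seq_len(n)) {
    obs <- data[i, ]
    pd_jk[i] <- pd_at(predict_fun, data, c(j, k), obs)
    pd_j[i]  <- pd_at(predict_fun, data, j, obs)
    pd_k[i]  <- pd_at(predict_fun, data, k, obs)
  }
  # center each partial dependence function before comparing
  pd_jk <- pd_jk - mean(pd_jk)
  pd_j  <- pd_j - mean(pd_j)
  pd_k  <- pd_k - mean(pd_k)
  # H^2 = share of the joint effect's variation not explained by the two main effects
  sqrt(sum((pd_jk - pd_j - pd_k)^2) / sum(pd_jk^2))
}

set.seed(2)
toy <- data.frame(x1 = rnorm(60), x2 = rnorm(60), x3 = rnorm(60))
additive    <- function(d) 2 * d$x1 + 3 * d$x2 + d$x3
interacting <- function(d) d$x1 * d$x2 + d$x3
h_add <- h_statistic(additive, toy, "x1", "x2")     # essentially 0: no interaction
h_int <- h_statistic(interacting, toy, "x1", "x2")  # close to 1: pure interaction
print(c(additive = h_add, interacting = h_int))
```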
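Tree Surrogate
A global surrogate fits an interpretable decision tree (here up to depth 5) to the random forest's predictions. Its r.squared reports how well the tree reproduces the forest's predicted probabilities, one value per class.
tree = TreeSurrogate$new(predictor, maxdepth = 5)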
tree$r.squared
## [1] 0.2850662 0.2850662
glimpse(tree)
## Classes 'TreeSurrogate', 'InterpretationMethod', 'R6' <TreeSurrogate>
## Inherits from: <InterpretationMethod>
## Public:
## clone: function (deep = FALSE)
## initialize: function (predictor, maxdepth = 2, tree.args = NULL)
## maxdepth: 5
## plot: function (...)
## predict: function (newdata, type = "prob", ...)
## predictor: Predictor, R6
## print: function ()
## r.squared: 0.285066197773788 0.285066197773788
## results: data.frame
## tree: constparty, party
## Private:
## aggregate: function ()
## compute_r2: function (predict.tree, predict.model)
## dataDesign: data.table, data.frame
## dataSample: data.table, data.frame
## feature.names: NULL
## finished: TRUE
## flush: function ()
## generatePlot: function ()
## getData: function (...)
## intervene: function ()
## match_cols: function (newdata)
## multiClass: TRUE
## object.predict.colnames: .y.hat.qual_high .y.hat.qual_low
## plotData: NULL
## predictResults: data.frame
## printParameters: function ()
## q: function (x)
## qResults: data.frame
## run: function (force = FALSE, ...)
## run.prediction: function (dataDesign)
## sampler: Data, R6
## tree.args: NULL
## tree.predict.colnames: .y.hat.tree.qual_high .y.hat.tree.qual_low
## weightSamples: function ()
plot(tree)
# print(tree$results) shows one row per training instance (very large output)
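The r.squared of 0.285 above is a fidelity score: it compares the surrogate's predictions with the random forest's predictions, not with the true labels. So the depth-5 tree reproduces only about 29% of the variance in the forest's predicted probabilities, and its splits should be read as a rough summary. The sketch below shows that kind of calculation in base R (1 minus squared prediction differences over the variance of the black-box predictions). To avoid extra packages it swaps in a linear model as the surrogate instead of iml's decision tree and uses a made-up non-linear function as the black box; surrogate_r2 is a hypothetical helper.

```r
# Fidelity of a global surrogate: R^2 of surrogate predictions vs black-box predictions
surrogate_r2 <- function(black_box_pred, surrogate_pred) {
  1 - sum((surrogate_pred - black_box_pred)^2) /
      sum((black_box_pred - mean(black_box_pred))^2)
}

set.seed(3)
toy <- data.frame(x1 = runif(300, -2, 2), x2 = runif(300, -2, 2))
# a non-linear "black box" standing in for the random forest's probability output
black_box <- function(d) plogis(2 * d$x1 - d$x2^2)

yhat <- black_box(toy)   # the surrogate is trained on these predictions, not on labels
surrogate <- lm(yhat ~ x1 + x2, data = cbind(toy, yhat = yhat))
r2 <- surrogate_r2(yhat, fitted(surrogate))
print(r2)  # below 1: a linear surrogate cannot capture the x2^2 term
```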
Local Model: LIME-type Model Explainability
LocalModel is iml's LIME-style local surrogate in R. It lets you explain local (individual) predictions. To explain row i of the test set, use the method below.
lime_imp <- function(my_predictor, my_data, row_num) {
lime_explain <- LocalModel$new(my_predictor, x.interest = my_data[row_num, ])
return (lime_explain)
}
# remove the categorical column 12 (quality_cat)
# Get LIME explanations for rows 10 and 20
i <- 10
j <- 20
my_lime_explaination_1 = lime_imp(predictor,wine_test[,-12],i)
## Loading required package: glmnet
## Loaded glmnet 4.1-3
my_lime_explaination_2 = lime_imp(predictor,wine_test[,-12],j)
my_lime_explaination_1$results
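The results list, for each feature, its value for this wine (feature.value) and its estimated effect (effect) on each class's predicted probability.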
plot(my_lime_explaination_1)
p1 = my_lime_explaination_1$results %>%
ggplot(aes(x = reorder(feature.value, -effect), y = effect, fill = .class)) +
facet_wrap(~ .class, ncol = 2) +
geom_bar(stat = "identity", alpha = 0.8) +
scale_fill_tableau() +
coord_flip() +
labs(title = paste0("Local Model (LIME) Test case #", i)) +
guides(fill = "none")
p1
p2 = my_lime_explaination_2$results %>%
ggplot(aes(x = reorder(feature.value, -effect), y = effect, fill = .class)) +
facet_wrap(~ .class, ncol = 2) +
geom_bar(stat = "identity", alpha = 0.8) +
scale_fill_tableau() +
coord_flip() +
labs(title = paste0("Local Model (LIME) Test case #", j)) +
guides(fill = "none")
p2
grid.arrange(p1, p2, ncol = 2)
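LocalModel follows the LIME recipe: take data around the instance of interest, weight each sample by its proximity to that instance, and fit a simple linear model to the black-box predictions; the fitted coefficients serve as the local explanation. The base-R sketch below implements a simplified version with Gaussian perturbations, a Gaussian-shaped kernel on standardized distance, and an ordinary weighted lm. It is not iml's exact algorithm (iml uses the dataset stored in the Predictor, Gower distance, and glmnet feature selection), and local_surrogate and its parameters are hypothetical.

```r
# LIME-style local surrogate sketch: returns local slopes for each feature
local_surrogate <- function(predict_fun, data, x_interest, n_samples = 500,
                            kernel_width = 0.75, seed = 4) {
  set.seed(seed)
  feats <- names(x_interest)
  sds <- vapply(data[feats], sd, numeric(1))
  # Gaussian perturbations centred on x_interest, one column per feature
  samples <- as.data.frame(setNames(lapply(feats, function(f)
    x_interest[[f]] + rnorm(n_samples, sd = sds[[f]])), feats))
  # standardized distance to x_interest -> proximity weights
  centred <- sweep(as.matrix(samples), 2, unlist(x_interest[feats]))
  dist <- sqrt(rowSums(sweep(centred, 2, sds, "/")^2))
  w <- exp(-dist^2 / kernel_width^2)
  # weighted linear fit to the black-box predictions = the local explanation
  samples$yhat <- predict_fun(samples)
  fit <- lm(yhat ~ ., data = samples, weights = w)
  coef(fit)[feats]
}

# Toy check: for a linear black box the local slopes equal the true coefficients
set.seed(5)
toy <- data.frame(x1 = rnorm(100), x2 = rnorm(100))
linear_box <- function(d) 2 * d$x1 - 3 * d$x2
effects <- local_surrogate(linear_box, toy, toy[1, ])
print(effects)
```

For a non-linear model such as the random forest, the slopes change from instance to instance, which is exactly what a local explanation is meant to capture.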
Shapley
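The method below shows how to use Shapley values to explain a single prediction. The Shapley value comes from cooperative game theory: it fairly distributes the difference between this instance's prediction and the average prediction among the features. iml estimates it by Monte Carlo sampling.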
my_shapley = Shapley$new(predictor, x.interest = wine_test[,-12][i, ])
head(my_shapley$results)
my_shapley$results %>%
ggplot(aes(x = reorder(feature.value, -phi), y = phi, fill = class)) +
facet_wrap(~ class, ncol = 2) +
geom_bar(stat = "identity", alpha = 0.8) +
scale_fill_tableau() +
coord_flip() +
guides(fill = "none")
The plot above shows the Shapley values: how the difference between this instance's prediction and the dataset's average prediction is split among the features.
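With only a handful of features, Shapley values can be computed exactly by enumerating every coalition of features, which makes the definition concrete. The sketch below does that in base R with an interventional value function v(S): the average prediction when the features in S are fixed to the instance's values and the others come from background data. Each feature j then receives the weighted average of its marginal contributions v(S ∪ {j}) − v(S). shapley_exact is a hypothetical helper, and its cost grows as 2^p, which is why iml samples instead.

```r
# Exact Shapley values by enumerating all coalitions (only feasible for a few features)
shapley_exact <- function(predict_fun, background, x_interest) {
  feats <- names(x_interest)
  p <- length(feats)
  # v(S): mean prediction with features in S fixed to x_interest, the rest from background
  value <- function(S) {
    d <- background
    for (f in S) d[[f]] <- x_interest[[f]]
    mean(predict_fun(d))
  }
  phi <- setNames(numeric(p), feats)
  for (j in feats) {
    others <- setdiff(feats, j)
    for (size in 0:(p - 1)) {
      subsets <- if (size == 0) list(character(0)) else combn(others, size, simplify = FALSE)
      # Shapley weight |S|! (p - |S| - 1)! / p!
      w <- factorial(size) * factorial(p - size - 1) / factorial(p)
      for (S in subsets) phi[[j]] <- phi[[j]] + w * (value(c(S, j)) - value(S))
    }
  }
  phi
}

# Toy check: for a linear model, phi_j = b_j * (x_j - mean of x_j in the background)
set.seed(6)
bg <- data.frame(x1 = rnorm(50), x2 = rnorm(50), x3 = rnorm(50))
lin <- function(d) 1 + 2 * d$x1 - 1 * d$x2 + 0.5 * d$x3
x0 <- data.frame(x1 = 1, x2 = 2, x3 = -1)
phi <- shapley_exact(lin, bg, x0)
print(phi)
```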
DALEX
DALEX: Descriptive mAchine Learning EXplanations. The DALEX package contains various explainers that describe the relationship between the model's inputs (independent variables) and its predictions.
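# Probability of the second class level, "qual_low"
pred_fun = function(object, newdata){predict(object, newdata = newdata, type = "prob")[, 2]}
# Encode the target on the same scale as pred_fun: 1 = qual_low, 0 = qual_high
yTest_data = as.numeric(wine_test$quality_cat == "qual_low")
explainer_classif_rf = DALEX::explain(rf_model, label = "rf",
                                      data = wine_test %>% dplyr::select(-quality_cat),
                                      y = yTest_data,
                                      predict_function = pred_fun)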
## Preparation of a new explainer is initiated
## -> model label : rf
## -> data : 1948 rows 12 cols
## -> data : tibble converted into a data.frame
## -> target variable : 1948 values
## -> predict function : pred_fun
## -> predicted values : No value for predict function target column. ( [33m default [39m )
## -> model_info : package caret , ver. 6.0.90 , task classification ( [33m default [39m )
## -> predicted values : numerical, min = 0 , mean = 0.3728943 , max = 0.982
## -> residual function : difference between y and yhat ( [33m default [39m )
## -> residuals : numerical, min = 0.09 , mean = 0.9941489 , max = 1.958
## [32m A new explainer has been created! [39m
Feature Importance
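The variable_importance function computes permutation-based feature importance: how much the loss (here root mean square) grows when a feature's values are shuffled.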
my_classifier = model_performance(explainer_classif_rf)
my_classifier_variable_importance <- variable_importance(explainer_classif_rf, loss_function = loss_root_mean_square)
plot(my_classifier_variable_importance )
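Permutation importance is model-agnostic: break the link between one feature and the target by shuffling that column, re-predict, and see how much the loss grows. A minimal base-R sketch on a toy linear model follows. permutation_importance and rmse are hypothetical helpers; DALEX repeats the permutations several times as well and can report the result as raw losses, ratios, or differences.

```r
# Permutation importance sketch: mean increase in loss after shuffling each feature
permutation_importance <- function(predict_fun, data, y, loss, n_repeats = 10, seed = 7) {
  set.seed(seed)
  base_loss <- loss(y, predict_fun(data))
  sapply(names(data), function(f) {
    mean(replicate(n_repeats, {
      shuffled <- data
      shuffled[[f]] <- sample(shuffled[[f]])   # break the feature-target link
      loss(y, predict_fun(shuffled))
    })) - base_loss
  })
}

rmse <- function(y, yhat) sqrt(mean((y - yhat)^2))

# Toy data: x1 matters a lot, x2 a little, noise not at all
set.seed(8)
toy <- data.frame(x1 = rnorm(200), x2 = rnorm(200), noise = rnorm(200))
y <- 3 * toy$x1 + 0.5 * toy$x2 + rnorm(200, sd = 0.1)
model <- lm(y ~ ., data = cbind(toy, y = y))
imp <- permutation_importance(function(d) predict(model, newdata = d), toy, y, rmse)
print(sort(imp, decreasing = TRUE))
```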
Variable Response
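Use the model_profile function to compute each feature's marginal response: type = "partial" gives partial dependence profiles, and type = "accumulated" gives accumulated local effects (ALE).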
vr_alcohol_response = model_profile(explainer_classif_rf, variable = "alcohol", type = "partial")
vr_density_response = model_profile(explainer_classif_rf, variable = "density", type = "partial")
vr_fixed_response = model_profile(explainer_classif_rf, variable = "fixed_acidity", type = "partial")
vr_volatile_response = model_profile(explainer_classif_rf, variable = "volatile_acidity", type = "partial")
vr_citric_response = model_profile(explainer_classif_rf, variable = "citric_acid", type = "partial")
vr_chlorides_response = model_profile(explainer_classif_rf, variable = "chlorides", type = "partial")
vr_free_response = model_profile(explainer_classif_rf, variable = "free_sulfur_dioxide", type = "partial")
vr_total_response = model_profile(explainer_classif_rf, variable = "total_sulfur_dioxide", type = "partial")
vr_pH = model_profile(explainer_classif_rf, variable = "pH", type = "partial")
vr_sulphates = model_profile(explainer_classif_rf, variable = "sulphates", type = "partial")
plot(vr_alcohol_response)
plot(vr_density_response)
plot(vr_fixed_response)
plot(vr_volatile_response)
plot(vr_citric_response )
plot(vr_chlorides_response)
plot(vr_free_response)
plot(vr_total_response)
plot(vr_pH)
plot(vr_sulphates )
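model_profile(type = "partial") estimates partial dependence: for each grid value, set the feature to that value in every row and average the model's predictions. A minimal base-R sketch follows; partial_dependence is a hypothetical helper, and DALEX chooses its own grid and subset of observations. For a linear model the curve is a straight line whose slope is the feature's coefficient, which the toy check uses.

```r
# Partial dependence sketch: average prediction with the feature forced to each grid value
partial_dependence <- function(predict_fun, data, feature, grid = NULL) {
  if (is.null(grid)) {
    grid <- quantile(data[[feature]], probs = seq(0, 1, 0.1), names = FALSE)
  }
  pd <- vapply(grid, function(v) {
    d <- data
    d[[feature]] <- v          # every row gets the same value of the feature
    mean(predict_fun(d))
  }, numeric(1))
  data.frame(grid = grid, yhat = pd)
}

set.seed(9)
toy <- data.frame(x1 = rnorm(100), x2 = rnorm(100))
linear_box <- function(d) 1 + 2 * d$x1 + 3 * d$x2
pd <- partial_dependence(linear_box, toy, "x1")
print(pd)
```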
ale_alcohol_response = model_profile(explainer_classif_rf, variable = "alcohol", type = "accumulated")
ale_density_response = model_profile(explainer_classif_rf, variable = "density", type = "accumulated")
ale_fixed_response = model_profile(explainer_classif_rf, variable = "fixed_acidity", type = "accumulated")
ale_volatile_response = model_profile(explainer_classif_rf, variable = "volatile_acidity", type = "accumulated")
ale_citric_response = model_profile(explainer_classif_rf, variable = "citric_acid", type = "accumulated")
ale_chlorides_response = model_profile(explainer_classif_rf, variable = "chlorides", type = "accumulated")
ale_free_response = model_profile(explainer_classif_rf, variable = "free_sulfur_dioxide", type = "accumulated")
ale_total_response = model_profile(explainer_classif_rf, variable = "total_sulfur_dioxide", type = "accumulated")
ale_pH = model_profile(explainer_classif_rf, variable = "pH", type = "accumulated")
ale_sulphates = model_profile(explainer_classif_rf, variable = "sulphates", type = "accumulated")
plot(ale_alcohol_response)
plot(ale_density_response)
plot(ale_fixed_response)
plot(ale_volatile_response)
plot(ale_citric_response )
plot(ale_chlorides_response)
plot(ale_free_response)
plot(ale_total_response)
plot(ale_pH)
plot(ale_sulphates )
Breakdown
Break Down is a model-agnostic tool that decomposes a single prediction into contributions from each feature.
X2 = wine_test[,-12]
predict.function = function(model, new_observation) {
  # column 2 is the probability of "qual_low" (class levels: qual_high, qual_low)
  predict(model, new_observation, type="prob")[,2]
}
predict.function(rf_model, wine_test[,-12][1, ])
## [1] 0.94
br = broken(model = rf_model,
new_observation = X2[1, ],
data = X,
baseline = "Intercept",
predict.function = predict.function,
keep_distributions = TRUE)
br
## contribution
## (Intercept) 0.000
## + volatile_acidity = 0.7 0.143
## + alcohol = 9.4 0.143
## + density = 0.9978 0.044
## + chlorides = 0.076 0.056
## + fixed_acidity = 7.4 0.020
## + residual_sugar = 1.9 0.019
## + sulphates = 0.56 0.013
## + citric_acid = 0 0.004
## + pH = 3.51 0.035
## + free_sulfur_dioxide = 11 0.044
## + total_sulfur_dioxide = 34 0.052
## final_prognosis 0.573
## baseline: 0.3667738
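The eleven contributions above add up to final_prognosis (0.573), and adding the baseline gives 0.3668 + 0.573 ≈ 0.94, the predicted probability of qual_low returned by predict.function. The base-R sketch below reproduces the step-up idea with a fixed feature order: start from the mean prediction over the background data, fix the observation's features one at a time, and record how the mean prediction moves. broken() determines the order itself, and break_down is a hypothetical helper.

```r
# Step-up Break Down sketch for one observation, with a fixed feature order
break_down <- function(predict_fun, background, x_interest, order = names(x_interest)) {
  d <- background
  current <- mean(predict_fun(d))      # baseline: average prediction
  baseline <- current
  contrib <- setNames(numeric(length(order)), order)
  for (f in order) {
    d[[f]] <- x_interest[[f]]          # fix this feature to the observation's value
    new_value <- mean(predict_fun(d))
    contrib[[f]] <- new_value - current
    current <- new_value
  }
  # once every feature is fixed, `current` equals the prediction for x_interest
  list(baseline = baseline, contributions = contrib, prediction = current)
}

set.seed(10)
bg <- data.frame(x1 = rnorm(80), x2 = rnorm(80), x3 = rnorm(80))
lin <- function(d) 0.5 + 1.5 * d$x1 - 2 * d$x2 + 0.25 * d$x3
x0 <- data.frame(x1 = 2, x2 = -1, x3 = 0.5)
bd <- break_down(lin, bg, x0)
print(bd)
```

For a linear model the order does not matter; for a model with interactions, such as the random forest, the contributions depend on the order in which features are fixed.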
item_to_explain<- wine_test[,-12][1, ]
my_breakdown.function = function(my_model, item_to_explain) {
  # probability of "qual_low" for the observation passed in
  predict(my_model, item_to_explain, type="prob")[,2]
}
my_breakdown.function(rf_model, item_to_explain)
## [1] 0.94
my_breakdown = broken(model = rf_model,
new_observation = item_to_explain,
data = X,
baseline = "Intercept",
predict.function = predict.function,
keep_distributions = TRUE)
my_breakdown
## contribution
## (Intercept) 0.000
## + volatile_acidity = 0.7 0.143
## + alcohol = 9.4 0.143
## + density = 0.9978 0.044
## + chlorides = 0.076 0.056
## + fixed_acidity = 7.4 0.020
## + residual_sugar = 1.9 0.019
## + sulphates = 0.56 0.013
## + citric_acid = 0 0.004
## + pH = 3.51 0.035
## + free_sulfur_dioxide = 11 0.044
## + total_sulfur_dioxide = 34 0.052
## final_prognosis 0.573
## baseline: 0.3667738
data.frame(y = my_breakdown$contribution,
x = my_breakdown$variable) %>%
ggplot(aes(x = reorder(x, y), y = y)) +
geom_bar(stat = "identity", fill = "#1F77B4", alpha = 0.8) +
coord_flip()
plot(my_breakdown)
plot(my_breakdown, plot_distributions = TRUE) + ggtitle ("Breakdown plot")
## Warning: `fun.y` is deprecated. Use `fun` instead.