[PYTHON] Counterfactual sample generation: "DiCE"

Introduction

Interpreting ML (machine learning) models is an important issue in business. By developing an ML model with high predictive accuracy and evaluating the "correspondence between output and features", it becomes possible to devise measures that have a business effect. Example: the "project leader conditions" (features) that increase the "probability of success" (output).

In recent years, various algorithms such as "ELI5", "LIME", and "SHAP" have been developed for interpreting ML models. What these algorithms compute is "the contribution of each feature to the output". With this kind of contribution-based evaluation, the interpretation is limited to "describing the relationship between the output and the features", and it is difficult to generate sample feature values that optimize the output.

"DiCE" developed by Microsoft Research is a model interpretation algorithm that considers anti-real virtual, and is an algorithm that enables feature sampling to obtain the desired output. It is different from other algorithms in that it provides direct materials by sample generation.

In this blog, we take up "DiCE", a model interpretation algorithm based on counterfactuals, and summarize an outline of the algorithm based on the original paper, followed by an operation check through implementation.

Figure: Explanation by a counterfactual model
Source: https://www.microsoft.com/en-us/research/project/dice/

Table of contents

--Introduction
  --What is the interpretation of the ML model?
--Overview of DiCE
  --What is DiCE?
  --DiCE concept: "Interpretation of ML model" by "counterfactual sample generation"
--DiCE algorithm description and implementation
  --Usage data
  --Flow until sampling
  --Algorithm concept
  --Definition of optimization function
  --Try using DiCE
--Summary
--References

What is the interpretation of the ML model?

In supervised learning, a trained model returns a predicted label for the given data. However, the following points remain opaque:

・ Are the predictions obtained from the trained model valid?
・ Has the model correctly learned the causal relationships behind the phenomenon?

To "establish trust in ML" and "use it safely in practice", we need to answer the questions above by evaluating the correspondence between the features and the output (objective variable). In this blog, evaluating this correspondence is called "interpretation of an ML model".

Figure: What is the interpretation of the ML model?

As an aside, the number of papers on the topic of "ML interpretability" has increased about fourfold over the last 20 years. This may simply reflect the growing ML population as ML theory found practical applications, but it does seem to be a topic that attracts steady interest. Source: https://beenkim.github.io/papers/BeenK_FinaleDV_ICML2017_tutorial.pdf

Overview of "DiCE"

・ What is DiCE?


A Python library from Microsoft Research that provides a framework for enumerating counterfactual samples. It can be installed with:

pip install dice_ml

Source: https://www.microsoft.com/en-us/research/project/dice/

・ DiCE concept: "Interpretation of ML model" by "counterfactual sample generation"

Counterfactual: assuming the opposite of the facts, i.e. reasoning of the form "What if ...?"

Figure: Interpretation of an ML model by a counterfactual model

Currently, algorithms such as "ELI5", "LIME", and "SHAP" are used to describe the "relationship between output and features". The basic idea of each is to "calculate the contribution of each feature to the output", and the relationship with the output is interpreted from the sign (positive/negative) and the magnitude (large/small) of that contribution.

On the other hand, a description that only computes "contributions" cannot, by itself, give the "optimal feature values" that optimize the output.

For example, consider a machine learning model that performs "loan screening", as shown in the figure, and judges a given candidate. The candidate is described by variables such as "age", "education", and "past borrowing history", and the model decides whether to lend based on patterns it learned beforehand.

Suppose the model decides that the candidate's loan is "rejected".

In this case, conventional interpretation algorithms can explain "why the candidate was rejected", but cannot give a concrete proposal for "what this candidate would need in order to borrow".

The answer to this problem is the basic concept of DiCE: "counterfactual sample generation".

With DiCE, it is possible to generate counterfactual samples from an existing ML model and present a direct improvement plan. Source: https://qiita.com/OpenJNY/items/ef885c357b4e0a1551c0
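To make the question "how could this candidate borrow?" concrete, here is a minimal brute-force sketch. It is not DiCE itself: the linear scoring rule, its weights and the helper names are hypothetical stand-ins for a trained model, and only one feature (balance) is allowed to change.

```python
# Toy illustration of a counterfactual question: "how could this candidate get a loan?"
# The scoring rule below is HYPOTHETICAL and only stands in for a trained model.

def toy_loan_model(age: float, balance: float) -> int:
    """Return 1 (approve) if the hypothetical linear score is positive, else 0 (reject)."""
    score = 0.03 * age + 0.001 * balance - 2.0
    return int(score > 0)


def smallest_balance_increase(age: float, balance: float, step: float = 100.0, max_steps: int = 1000):
    """Raise only `balance` in fixed steps until the decision flips; return the first flipping value."""
    for i in range(max_steps + 1):
        candidate = balance + i * step
        if toy_loan_model(age, candidate) == 1:
            return candidate
    return None  # no counterfactual found within the search range


query = {"age": 20, "balance": 129}
print(toy_loan_model(**query))                                    # factual decision: 0 (rejected)
print(smallest_balance_increase(query["age"], query["balance"]))  # 1429.0
```

Brute force only works for one or two features; DiCE instead solves an optimization problem over all features at once.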

DiCE algorithm description and implementation

・ Usage data

In this blog, we use the "Bank Marketing Data Set" from the UCI Machine Learning Repository to explain the DiCE algorithm.

This dataset describes features and loan status for a number of individuals, and the objective variable y is set to loan = {0: No, 1: Yes}.

Table: Bank Marketing Data Set (only some of the columns are shown). Source: http://archive.ics.uci.edu/ml/datasets/Bank+Marketing#

・ Flow until sampling

The problem setting and sampling policy used in this explanation are as follows.

(Problem setting)

1. Some candidates cannot take out a loan (loan = 0: No).
2. An ML model has been trained on the loan status and the features.
3. For such a candidate, we sample from the model the conditions under which loan = 1: Yes (the counterfactual).

(Sampling policy)
Prepare a vector c to hold the counterfactual sample and compute the model output f(c).
→ Define a loss function that becomes smaller as the predicted label approaches the desired class (loan = 1).
→ Take the c at which this loss is minimized: this is the counterfactual sample.
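The sampling policy above can be sketched with plain gradient descent. This is a minimal sketch under stated assumptions: the model is a hypothetical two-feature logistic regression, and the loss is binary cross-entropy towards class 1 (DiCE itself uses a hinge loss and extra terms, described in the next section).

```python
import numpy as np

# Minimal sketch of the sampling policy: start from the factual vector x, treat c as the
# variable, and move c by gradient descent on a loss that shrinks as f(c) approaches class 1.
# The logistic model weights are HYPOTHETICAL; the loss is binary cross-entropy towards
# class 1, i.e. -log f(c) (DiCE itself uses a hinge loss plus extra terms).

w = np.array([1.5, -0.5])  # hypothetical model weights
b = -2.0                   # hypothetical bias


def f(c: np.ndarray) -> float:
    """Predicted probability of loan = 1 for feature vector c."""
    return 1.0 / (1.0 + np.exp(-(w @ c + b)))


def yloss_grad(c: np.ndarray) -> np.ndarray:
    """Gradient of -log f(c) with respect to c: -(1 - f(c)) * w."""
    return -(1.0 - f(c)) * w


x = np.array([0.0, 0.0])  # factual instance (predicted loan = 0)
c = x.copy()              # storage vector for the counterfactual sample
learning_rate = 0.5
for _ in range(200):
    c = c - learning_rate * yloss_grad(c)
    if f(c) >= 0.5:       # stop as soon as the desired class is reached
        break

print(f(x), f(c), c)
```

With only this loss, nothing keeps c close to x, which is exactly the gap the "feasibility" term below addresses.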

Figure: Problem setting and sampling flow

・ Algorithm concept

DiCE builds several refinements into the minimization of the loss function described above. I will explain these seven concepts here. The formulas that appear below are quoted from the original paper. Source: https://www.microsoft.com/en-us/research/publication/explaining-machine-learning-classifiers-through-diverse-counterfactual-examples/

    1. "Feasibility" Even if a vector (loan = 0: No) that is too far away from the fact vector is sampled as an anti-real virtual vector (c), it is unrealistic and cannot be realized.

Therefore, DiCE adds the distance between the factual vector (loan = 0: No) and the counterfactual vector (c) to the objective, so that the optimization keeps this distance from growing too large.
Figure: Undesired data points
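A small numeric sketch of why the distance term helps. The two candidates, their yloss values, the per-feature scale and lambda1 are all hypothetical numbers chosen for illustration; both candidates are assumed to already reach loan = 1, so only the distance term separates them.

```python
import numpy as np

# Sketch of the feasibility idea: among candidates that all reach the desired class,
# adding lambda1 * dist(c, x) to the objective favours the one closest to the fact x.
# The candidates, the yloss values, the scale and lambda1 are HYPOTHETICAL.

x = np.array([20.0, 129.0])        # factual instance: (age, balance)
candidates = {
    "near": np.array([22.0, 900.0]),
    "far": np.array([60.0, 50000.0]),
}
yloss = {"near": 0.0, "far": 0.0}  # assume both are already classified as loan = 1
scale = np.array([10.0, 1000.0])   # hypothetical per-feature scale so the units are comparable
lambda1 = 0.5


def dist(c, x):
    """Scaled L1 distance between a candidate c and the fact x."""
    return float(np.sum(np.abs(c - x) / scale))


objective = {name: yloss[name] + lambda1 * dist(c, x) for name, c in candidates.items()}
print(objective)
print(min(objective, key=objective.get))  # the candidate the optimization would prefer: near
```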

2. "Diversity": when sampling multiple counterfactual vectors (c), it is better to have a variety of patterns; we do not want near-duplicate vectors in the set.

Therefore, DiCE defines a diversity measure based on the pairwise distances between the counterfactual vectors (c_i) and optimizes so that this measure is as large as possible, i.e. the vectors are as far apart as possible.
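The diversity measure in the DiCE paper is the determinant of a kernel matrix K with K[i, j] = 1 / (1 + dist(c_i, c_j)). The sketch below computes it for a near-duplicate set and a spread-out set; using a plain L1 distance is a simplifying assumption, not DiCE's exact distance.

```python
import numpy as np

# Sketch of the diversity term described in the DiCE paper: build a kernel matrix
# K[i, j] = 1 / (1 + dist(c_i, c_j)) over the k counterfactuals and use det(K).
# Near-duplicate counterfactuals make rows of K almost equal, driving det(K) towards 0.
# Plain L1 distance is used here for simplicity (an assumption; DiCE uses its own dist).


def dpp_diversity(cfs: np.ndarray) -> float:
    """Determinant of the inverse-distance kernel over the rows of `cfs` (shape k x d)."""
    k = len(cfs)
    K = np.empty((k, k))
    for i in range(k):
        for j in range(k):
            K[i, j] = 1.0 / (1.0 + np.sum(np.abs(cfs[i] - cfs[j])))
    return float(np.linalg.det(K))


similar = np.array([[1.0, 0.0], [1.0, 0.01], [1.01, 0.0]])
diverse = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
print(dpp_diversity(similar), dpp_diversity(diverse))  # the diverse set scores higher
```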

3. "Hinge loss as the loss function": the hinge function, sometimes used in SVMs and elsewhere, is used as the loss (hinge_yloss in the original paper).
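The hinge loss from the paper is max(0, 1 - z * logit(f(c))), with z = -1 when the desired class is 0 and z = 1 when it is 1, and logit(f(c)) the model's unscaled output. A small sketch (the function name and example logits are illustrative): once the logit passes the margin the loss is exactly zero, so, as I understand it, the optimization stops pushing c deeper into the desired class than necessary.

```python
# Hinge yloss as written in the DiCE paper: max(0, 1 - z * logit(f(c))), where z = -1 if the
# desired class y is 0 and z = 1 if y is 1, and logit(f(c)) is the model's unscaled output.
# The function name and example logits below are for illustration only.


def hinge_yloss(logit: float, desired_class: int) -> float:
    z = 1.0 if desired_class == 1 else -1.0
    return max(0.0, 1.0 - z * logit)


print(hinge_yloss(2.0, 1))   # confidently in the desired class: 0.0
print(hinge_yloss(0.5, 1))   # right side of the boundary but inside the margin: 0.5
print(hinge_yloss(-1.0, 1))  # wrong class: 2.0
```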
4. "Distinguishing continuous and categorical variables": DiCE computes distances between multidimensional data points to define "feasibility" and "diversity" (items 1 and 2). Because continuous variables (which follow a continuous distribution) and categorical variables (converted into dummy variables) are distributed differently, the distance is calculated separately for each type (this is my understanding and may be wrong).
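For categorical variables, the paper's distance is simply the fraction of categorical features whose value differs. A minimal sketch; the feature values reuse this post's query instance, and the changed values ("management", "tertiary") are hypothetical.

```python
# Sketch of the categorical distance from the DiCE paper: the fraction of categorical
# features whose value differs between the counterfactual c and the fact x,
# dist_cat(c, x) = (1 / d_cat) * sum_p I(c^p != x^p).
# Feature names come from this post's query instance; the changed values are HYPOTHETICAL.


def dist_cat(c: dict, x: dict, categorical_features: list) -> float:
    mismatches = sum(c[p] != x[p] for p in categorical_features)
    return mismatches / len(categorical_features)


x = {"job": "blue-collar", "marital": "single", "education": "secondary"}
c = {"job": "management", "marital": "single", "education": "tertiary"}
print(dist_cat(c, x, ["job", "marital", "education"]))  # 2 of 3 features differ
```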

5. "Weighting features with a variance-aware distance (continuous variables)": When measuring the distance between multidimensional vectors, comparing only central values such as the mean (or median) does not give an appropriate distance between data groups, because when the variance differs from variable to variable, the apparent distance depends on how widely each variable is spread. In such cases a distance that accounts for the variance, such as the Mahalanobis distance, is needed. For continuous variables, DiCE divides each feature's difference by its median absolute deviation (MAD), a measure of spread that is more robust to outliers than the standard deviation, which makes the distance variance-aware.

MAD: median absolute deviation

Without this weighting, every variable would be compared on its raw scale regardless of its spread, which can lead to unrealistic feature values being sampled.
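A minimal sketch of the MAD-scaled continuous distance as written in the paper. The five-row "training" array is hypothetical (columns: age, balance); it only illustrates that a 5-unit change costs far more in a low-spread feature than in a high-spread one.

```python
import numpy as np

# Sketch of the MAD-scaled continuous distance from the DiCE paper:
# dist_cont(c, x) = (1 / d_cont) * sum_p |c^p - x^p| / MAD_p,
# MAD_p = median(|X[:, p] - median(X[:, p])|) computed on the training data.
# The small "training" array below is HYPOTHETICAL (columns: age, balance).

X_train = np.array([
    [25.0, 100.0],
    [30.0, 1500.0],
    [35.0, 300.0],
    [40.0, 5000.0],
    [50.0, 800.0],
])


def mad(column: np.ndarray) -> float:
    return float(np.median(np.abs(column - np.median(column))))


mads = np.array([mad(X_train[:, p]) for p in range(X_train.shape[1])])


def dist_cont(c: np.ndarray, x: np.ndarray) -> float:
    return float(np.mean(np.abs(c - x) / mads))


x = np.array([20.0, 129.0])
print(mads)                                   # [  5. 700.]
print(dist_cont(np.array([25.0, 129.0]), x))  # +5 years of age: 0.5
print(dist_cont(np.array([20.0, 134.0]), x))  # +5 in balance: much cheaper, balance has a large MAD
```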

6. "Selecting the features to change": as a practical matter in counterfactual sampling, some features cannot be changed, so the features that are allowed to vary can be specified.
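One simple way to keep some features fixed in a gradient-based search is to mask the update. This is a sketch of the idea, not DiCE's internal code; the model weights, feature order and step size are hypothetical. In the dice_ml call used later in this post, the list of changeable features is passed as features_to_vary.

```python
import numpy as np

# Sketch of restricting which features may change: multiply each gradient step by a 0/1 mask,
# so features outside `features_to_vary` keep their factual values.
# The model, the feature order and the step size are HYPOTHETICAL.

feature_names = ["age", "balance", "housing"]
features_to_vary = ["age", "balance"]
mask = np.array([1.0 if name in features_to_vary else 0.0 for name in feature_names])

w = np.array([0.8, 1.2, -0.7])  # hypothetical logistic-model weights
b = -3.0


def prob(c):
    return 1.0 / (1.0 + np.exp(-(w @ c + b)))


x = np.array([0.2, 0.1, 1.0])   # factual (normalized) instance
c = x.copy()
for _ in range(500):
    grad = -(1.0 - prob(c)) * w  # gradient of -log prob(c)
    c = c - 0.5 * mask * grad    # masked step: 'housing' never moves
    if prob(c) >= 0.5:
        break

print(c, prob(c))
```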
7. "Trade-off between feasibility and diversity"

The optimization function defined by DiCE has hyperparameters that set the weights of "feasibility" and "diversity"; adjusting them changes the balance between the two. (It is not covered in this post, but the paper appears to provide evaluation metrics for useful sampling, so those metrics could presumably be used to tune these parameters.)


・ Definition of optimization function

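Taking the above concepts into account, we define the function to be optimized.

Combining the sampling logic for the counterfactual vectors (c_i) with each concept, the optimization function from the original paper can be written as follows, where k is the number of counterfactuals, x is the factual input, y is the desired class, and λ1 and λ2 are hyperparameters:

$$C(x) = \arg\min_{c_1,\ldots,c_k} \frac{1}{k}\sum_{i=1}^{k} \mathrm{yloss}(f(c_i), y) + \frac{\lambda_1}{k}\sum_{i=1}^{k} \mathrm{dist}(c_i, x) - \lambda_2\, \mathrm{dpp\_diversity}(c_1, \ldots, c_k)$$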
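A minimal sketch of evaluating the DiCE objective (average yloss, plus λ1 times the average distance to the fact, minus λ2 times det(K)) for two hypothetical candidate sets. To keep it short, all candidates are assumed to already reach the desired class (yloss = 0) and dist is a plain L1 distance; both are simplifying assumptions. It shows the trade-off: with λ2 = 0 the close but similar set wins, while a large λ2 favors the farther but diverse set.

```python
import numpy as np

# Sketch of the DiCE objective for a set of k counterfactuals (lower is better):
#   (1/k) sum_i yloss(f(c_i), y) + (lambda1/k) sum_i dist(c_i, x) - lambda2 * det(K)
# with K[i, j] = 1 / (1 + dist(c_i, c_j)). All candidates are ASSUMED to reach the desired
# class (yloss = 0) and dist is a plain L1 distance; the candidate sets are HYPOTHETICAL.


def dist(a, b):
    return float(np.sum(np.abs(a - b)))


def dpp_diversity(cfs):
    k = len(cfs)
    K = np.array([[1.0 / (1.0 + dist(cfs[i], cfs[j])) for j in range(k)] for i in range(k)])
    return float(np.linalg.det(K))


def objective(cfs, x, lambda1, lambda2, ylosses=None):
    k = len(cfs)
    ylosses = np.zeros(k) if ylosses is None else np.asarray(ylosses)
    proximity = sum(dist(c, x) for c in cfs) / k
    return float(ylosses.mean() + lambda1 * proximity - lambda2 * dpp_diversity(cfs))


x = np.array([0.0, 0.0])
close_but_similar = [np.array([1.0, 0.0]), np.array([1.0, 0.1])]
farther_but_diverse = [np.array([2.0, 0.0]), np.array([0.0, 2.0])]

for lambda2 in (0.0, 5.0):
    scores = {
        "close_but_similar": objective(close_but_similar, x, lambda1=0.5, lambda2=lambda2),
        "farther_but_diverse": objective(farther_but_diverse, x, lambda1=0.5, lambda2=lambda2),
    }
    print(lambda2, min(scores, key=scores.get), scores)
```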

・ Try using DiCE

Now, let's actually generate counterfactual samples using DiCE.

This implementation is based on Microsoft's documentation (GitHub). Source: https://www.microsoft.com/en-us/research/project/dice/

#Library import
import pandas as pd
import numpy as np
import dice_ml
import tensorflow as tf
from tensorflow import keras
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
#Read data & encode the objective variable as 0/1 (no -> 0, yes -> 1)
data = pd.read_csv('bank-full.csv'
                   ,sep=';'
                   ,usecols=['age', 'job', 'marital', 'education', 'default', 'balance', 'housing','loan'])
data['loan'] = (data['loan'] == 'yes').astype(int)
data.head()

(Figure: output of data.head())

#Define data structures for DiCE
d = dice_ml.Data(dataframe=data, continuous_features=['age', 'balance'], outcome_name='loan')
#Train the ML model
train, _ = d.split_data(d.normalize_data(d.one_hot_encoded_data))
X_train = train.loc[:, train.columns != 'loan']
y_train = train.loc[:, train.columns == 'loan']

model = keras.Sequential()
model.add(keras.layers.Dense(20, input_shape=(X_train.shape[1],), kernel_regularizer=keras.regularizers.l1(0.001), activation=tf.nn.relu))
model.add(keras.layers.Dense(1, activation=tf.nn.sigmoid))
model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(0.01), metrics=['accuracy'])
model.fit(X_train, y_train, validation_split=0.20, epochs=100, verbose=0, class_weight={0:1,1:2})
#Wrap the trained ML model for DiCE
backend = 'TF'+tf.__version__[0]  # 'TF1' or 'TF2', depending on the installed TensorFlow version
m = dice_ml.Model(model=model, backend=backend)
#Counterfactual explainer
exp = dice_ml.Dice(d, m)
#Query instance: the factual instance whose counterfactuals we want to generate
query_instance = {'age':20,
                  'job':'blue-collar',
                  'marital':'single',
                  'education':'secondary',
                  'default':'no',
                  'balance': 129,
                  'housing':'yes'}
#Generate counterfactual samples
dice_exp = exp.generate_counterfactuals(query_instance
                                        ,total_CFs=4
                                        ,desired_class="opposite")
#total_CFs: number of counterfactuals to generate
#desired_class: class the counterfactuals should have; "opposite" flips the predicted class
#Visualization of sample results
dice_exp.visualize_as_dataframe()

(Figure: output of visualize_as_dataframe())

Looking at the results, we can see that four samples with outcome = 1 were generated for the query instance with outcome = 0. The loan column shows the model's sigmoid output (predicted probability): it is about 0.3 for the query instance (outcome = 0) and 0.5 or more for the counterfactuals (outcome = 1).
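The mapping from the sigmoid output to the outcome label can be sketched as a threshold, assuming the usual cut-off of 0.5 (sigmoid(z) >= 0.5 exactly when the pre-activation z >= 0). The probabilities are illustrative, not the values from the result table.

```python
import numpy as np

# Sketch of how a sigmoid output maps to the outcome label, ASSUMING a 0.5 threshold:
# sigmoid(z) >= 0.5 exactly when the pre-activation z >= 0.
# The probabilities below are illustrative, not the values from the result table.


def to_outcome(prob, threshold=0.5):
    return (np.asarray(prob) >= threshold).astype(int)


print(to_outcome([0.30, 0.52, 0.71]))  # [0 1 1]
```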

Looking at the features, we can see that age, job, marital and others have changed: the generated samples indeed differ from the query_instance we set.

To change the weights of features or to specify which features may be changed, set the parameters as follows.

feature_weights = {'age': 1}  # set the weight of age to 1
features_to_vary = ['age','job']  # allow only age and job to change
dice_exp = exp.generate_counterfactuals(query_instance
                                        ,total_CFs=4
                                        ,desired_class="opposite"
                                        ,feature_weights=feature_weights
                                        ,features_to_vary=features_to_vary)
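To see what a per-feature weight does, here is a sketch of a weighted L1 distance. It rests on an assumption from my reading of the dice_ml docstring: by default a continuous feature is weighted by 1 / MAD, and passing feature_weights={'age': 1} replaces that weight for age. The MAD values are hypothetical.

```python
# Sketch of what a per-feature weight does inside a weighted L1 distance.
# ASSUMPTION (based on my reading of the dice_ml docstring): by default a continuous feature
# is weighted by 1 / MAD, and passing feature_weights={'age': 1} replaces that weight for age.
# The MAD values below are HYPOTHETICAL.

mad = {"age": 10.0, "balance": 1000.0}
default_weights = {name: 1.0 / value for name, value in mad.items()}
custom_weights = {**default_weights, "age": 1.0}  # mirrors feature_weights = {'age': 1}


def weighted_dist(c, x, weights):
    return sum(weights[p] * abs(c[p] - x[p]) for p in weights)


x = {"age": 20, "balance": 129}
c = {"age": 25, "balance": 129}  # only age changes, by 5 years
print(weighted_dist(c, x, default_weights))  # 0.5
print(weighted_dist(c, x, custom_weights))   # 5.0: changing age now costs 10x more
```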

Summary

In this blog, we took up "DiCE", a model interpretation algorithm based on counterfactuals, and summarized an outline of the algorithm based on the original paper together with an operation check by implementation. DiCE differs from other algorithms in that it "directly provides actionable material by generating samples", and I think it will contribute to interpreting models from multiple angles. The original paper contains more than what is described here, there are points I have not fully followed yet, and there may be mistakes. I would appreciate it if you could point them out.

That's all.

References

・ Welcome to ELI5's documentation! (ELI5) https://eli5.readthedocs.io/en/latest/
・ "Why Should I Trust You?": Explaining the Predictions of Any Classifier (LIME) https://arxiv.org/abs/1602.04938
・ Try to interpret the prediction result of machine learning with LIME https://qiita.com/fufufukakaka/items/d0081cd38251d22ffebf
・ Explainable AI: ELI5, LIME and SHAP (Kaggle kernel) https://www.kaggle.com/kritidoneria/explainable-ai-eli5-lime-and-shap
・ DiCE: Diverse Counterfactual Explanations for Machine Learning Classifiers (DiCE) https://www.microsoft.com/en-us/research/project/dice/ and https://arxiv.org/pdf/1905.07697.pdf (all formulas reproduced in this post are from here)
・ DiCE: Interpretation / explanation method of machine learning models by counterfactual samples https://qiita.com/OpenJNY/items/ef885c357b4e0a1551c0
・ Introductory Statistical Causal Inference (Japanese edition of "Causal Inference in Statistics: A Primer"): Judea Pearl, Madelyn Glymour, Nicholas P. Jewell (authors), Hiroshi Ochiumi (translator) https://www.amazon.co.jp/%E5%85%A5%E9%96%80-%E7%B5%B1%E8%A8%88%E7%9A%84%E5%9B%A0%E6%9E%9C%E6%8E%A8%E8%AB%96-Judea-Pearl/dp/4254122411
・ Overview of Counterfactual Machine Learning https://usaito.github.io/files/190729_sonyRD.pdf
・ Interpretable Machine Learning: The fuss, the concrete and the questions https://beenkim.github.io/papers/BeenK_FinaleDV_ICML2017_tutorial.pdf
