如何让有监督学习变得有解释性

机器学习目前面临的一个大问题是模型缺少解释性,当别人问你,为何我要相信你的模型的时候,大多模型无法给我一个可靠的解释,仅仅提升准确率,并不能解决问题,这限制了在诸如金融,医疗等模型必须解释性,应用机器学习。今天介绍的模型LIME来自17年的NIPS会议,对应的python/R包能够对目前包括图像,自然语言,以及对数据表的传统分类及回归模型给出模型为何为何做出这样分类的解释。这篇小文先讲原理,再展示效果,最后介绍代码。

/pic/1_AZONQu3UI7wmiavWNia6CicoZF3tCQ.png

LIME的全称是“ Local Interpretable Model-Agnostic Explanations.”,局部性假设一个数据点被分类模型做了标记,LIME只会针对这个数据点,利用该点本身的特征对其进行解释,而Model-Agnostic指的是给出的解释不会与你用的是什么样的模型有关,不会涉及模型如何做出分类,这使得LIME具有通用性。

LIME的原理很简单,下如代表的一个二分类器,模型拟合的函数很复杂,而要待分类的数据点X是图中的粗线条红色十字,LIME在X周围随机生成一些数据点,让分类器进行分类,之后按照随机生成的位点和原位点X的距离,来最优化出图中的虚线,这里虚线的斜率代表了该模型针对数据X分类时,各个特征的重要性,以及当数据点发生变化时分类结果会怎么变化,从而解释了模型为何对X进行如下的分类。

/pic/2_DXHd2KCGOictxY139IQmtqAd3I4b2Q.png

接下来看分类的代码

1
<pre style="box-sizing: border-box;font family: SFMono-Regular, Consolas, "Liberation Mono", Menlo, Courier, monospace;font-size: 13.6px;overflow-wrap: normal;background-color: rgb(246, 248, 250);border-radius: 3px;line-height: 1.45;overflow: auto;padding: 16px;word-break: normal;color: rgb(36, 41, 46);text-align: start;">library(<span class="pl-smi" style="box-sizing: border-box;">caret</span>)<br></br>library(<span class="pl-smi" style="box-sizing: border-box;">lime</span>)<br></br><br></br><span class="pl-c" style="box-sizing: border-box;color: rgb(106, 115, 125);"><span class="pl-c" style="box-sizing: border-box;">#</span> Split up the data set</span><br></br><span class="pl-smi" style="box-sizing: border-box;">iris_test</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);"><-</span> <span class="pl-smi" style="box-sizing: border-box;">iris</span>[<span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">1</span><span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">:</span><span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">5</span>, <span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">1</span><span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">:</span><span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">4</span>]<br></br><span class="pl-smi" style="box-sizing: border-box;">iris_train</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);"><-</span> <span class="pl-smi" style="box-sizing: border-box;">iris</span>[<span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">-</span>(<span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">1</span><span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">:</span><span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">5</span>), <span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">1</span><span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">:</span><span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">4</span>]<br></br><span class="pl-smi" style="box-sizing: border-box;">iris_lab</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);"><-</span> <span class="pl-smi" style="box-sizing: border-box;">iris</span>[[<span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">5</span>]][<span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">-</span>(<span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">1</span><span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">:</span><span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">5</span>)]<br></br><br></br><span class="pl-c" style="box-sizing: border-box;color: rgb(106, 115, 125);"><span class="pl-c" style="box-sizing: border-box;">#</span> Create Random Forest model on iris data</span><br></br><span class="pl-smi" style="box-sizing: border-box;">model</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);"><-</span> train(<span class="pl-smi" style="box-sizing: border-box;">iris_train</span>, <span class="pl-smi" style="box-sizing: border-box;">iris_lab</span>, <span class="pl-v" style="box-sizing: border-box;color: rgb(227, 98, 9);">method</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">=</span> <span class="pl-s" style="box-sizing: border-box;color: rgb(3, 47, 98);"><span class="pl-pds" style="box-sizing: border-box;">'</span>rf<span class="pl-pds" style="box-sizing: border-box;">'</span></span>)<br></br><br></br><span class="pl-c" style="box-sizing: border-box;color: rgb(106, 115, 125);"><span class="pl-c" style="box-sizing: border-box;">#</span> Create an explainer object</span><br></br><span class="pl-smi" style="box-sizing: border-box;">explainer</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);"><-</span> lime(<span class="pl-smi" style="box-sizing: border-box;">iris_train</span>, <span class="pl-smi" style="box-sizing: border-box;">model</span>)<br></br><br></br><span class="pl-c" style="box-sizing: border-box;color: rgb(106, 115, 125);"><span class="pl-c" style="box-sizing: border-box;">#</span> Explain new observation</span><br></br><p><span class="pl-smi" style="box-sizing: border-box;">explanation</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);"><-</span> explain(<span class="pl-smi" style="box-sizing: border-box;">iris_test</span>, <span class="pl-smi" style="box-sizing: border-box;">explainer</span>, <span class="pl-v" style="box-sizing: border-box;color: rgb(227, 98, 9);">n_labels</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">=</span> <span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">1</span>, <span class="pl-v" style="box-sizing: border-box;color: rgb(227, 98, 9);">n_features</span> <span class="pl-k" style="box-sizing: border-box;color: rgb(215, 58, 73);">=</span> <span class="pl-c1" style="box-sizing: border-box;color: rgb(0, 92, 197);">2</span>)</p><p><span style="font-size: 13.6px;"><br></br></span></p><p><span style="font-size: 13.6px;">explanation</span></p><p><span class="pl-c" style="color: rgb(106, 115, 125);font-size: 13.6px;box-sizing: border-box;"><br></br></span></p><p><span class="pl-c" style="color: rgb(106, 115, 125);font-size: 13.6px;box-sizing: border-box;">#</span><span style="color: rgb(106, 115, 125);font-size: 13.6px;"> And can be visualised directly</span></p><p><span style="font-size: 13.6px;">plot_features(</span><span class="pl-smi" style="font-size: 13.6px;box-sizing: border-box;">explanation</span><span style="font-size: 13.6px;">)</span></p><p><br></br></p>

这里的数据集是鸢尾花数据集,用随机森林做了分类,测试时选择了5个数据,需要解释的是最后三句,explain函数将分类器,待解释的特征作为参数,而最后对解释器进行可视化,得出下面的图:

/pic/3_6FFhibqEI4SMoF9jdxiadUCkBKIiaw.png

针对每个测试案例,给出了一个柱状图,图中给出了原特征中的四个的俩个,柱状图中横轴的那句话代表该特征需满足的条件,柱状图的长度代表该特征在满足该条件对模型做出该分类的重要性,柱状图的上方给出该数据点的分类结果。可以看出,对于分类为setosa,最重要的是花瓣的宽度小于0.4,由于4个数据点中都指出该条件最重要。

对于回归模型,也能给出类似的结果,下图是波士顿房价数据集,用sklearn的中的随机森林,来做回归后,对模型解释性给出的可视化,这里每个柱状图上的条件按是否对房价提升和降低进行了打分,这可以看成对模型预测结果的解释。

/pic/4_KImjicxx3yJLcVL2U9b5F2GdPhYlVg.jpg

而对图像数据,LIME可以可以标出图像中那个部分导致了模型导致了分类结果,例如下如被分类为草莓,LIME标出的分类原因正好对应是图片中俩个草莓,

/pic/5_G2nH7ZPJFz8WOX4lpm0f4ACHiaib6Q.png

而下图被谷歌的inception从高到低按可能的概率给出不同的分类,LIME可以解释不同的标签为何将图片看成是电吉他,木吉他或者拉布拉多狗。

/pic/6_bic8KWOvX7Cobf8gMbibKuz9fHsNeA.png

这里展示的是文本主题分类中,对新闻主题分类进行解释的可视化,这里柱状图的横轴是待分类文本的关键词,这段文本的主题被分类为无神论,而不是基督教,究竟是哪些关键词导致这样的分类了,这里不同的关键词,他们支持的分类结果及对应权重。

/pic/7_LF59V6PPasOP7qDXuc6tpUHADl6ejw.png

LIME还针对文本分类,做了可视化的交互工具Skiny,来互动的展示文本中对分类影响大的关键词。

/pic/

总结一下,之前提升模型解释性的尝试,只能依靠树模型给出的特征重要性,或者看去掉了那些特征,模型的准确性变差较多,或者看那些特征的相对梯度较大,LIME提供的通用框架,对促成可解释的机器学习,有所帮助,找到了关键特征,对建模时的缺失值补全,数据清洗,以及数据增强,都会有所助益。

更多阅读

图像分类中的隐式标签正则化

GAN的五个神奇应用场景

Louis.W2019-06-10 07:26:46

Shap, partial dependent plot, permutation importance.

徐瑞龙2019-06-10 11:14:15

意思是用局部的线性模型提供解释性?

作者

类似的概念,对回归分类都挺有效的赞 1