微信号:PythonPush

介绍:人生苦短,我用 Python.Python 越来越受广大程序员的喜爱.

回归树的原理及Python实现

2019-01-05 19:21 程序君

Linux编程
点击右侧关注,免费入门到精通!


作者丨李小文

https://www.zhihu.com/people/liu-tie-nan


提到回归树相信大家应该都不会觉得陌生(不陌生你点进来干嘛[捂脸]),大名鼎鼎的GBDT算法就是用回归树组合而成的。本文就回归树的基本原理进行讲解,并手把手、肩并肩地带您实现这一算法。


完整实现代码请参考本人的p...哦不是...github:


regression_tree.py


https://github.com/tushushu/imylu/blob/master/imylu/tree/regression_tree.py


regression_tree_example.py


https://github.com/tushushu/imylu/blob/master/examples/regression_tree_example.py


1. 原理篇


我们用人话而不是大段的数学公式来讲讲回归树是怎么一回事。


1.1 最简单的模型


如果预测某个连续变量的大小,最简单的模型之一就是用平均值。比如同事的平均年龄是28岁,那么新来了一批同事,在不知道这些同事的任何信息的情况下,直觉上用平均值28来预测是比较准确的,至少比0岁或者100岁要靠谱一些。我们不妨证明一下我们的直觉:


1. 定义损失函数L,其中y_hat是对y预测值,使用MSE来评估损失:



2. 对y_hat求导:



3. 令导数等于0,最小化MSE,则:



4. 所以,



结论,如果要用一个常量来预测y,用y的均值是一个最佳的选择。


1.2 加一点难度


仍然是预测同事年龄,这次我们预先知道了同事的职级,假设职级的范围是整数1-10,如何能让这个信息帮助我们更加准确的预测年龄呢?


一个思路是根据职级把同事分为两组,这两组分别应用我们之前提到的“平均值”模型。比如职级小于5的同事分到A组,大于或等于5的分到B组,A组的平均年龄是25岁,B组的平均年龄是35岁。如果新来了一个同事,职级是3,应该被分到A组,我们就预测他的年龄是25岁。


1.3 最佳分割点


还有一个问题待解决,如何取一个最佳的分割点对不同职级的同事进行分组呢? 我们尝试所有m个可能的分割点P_i,沿用之前的损失函数,对A、B两组分别计算Loss并相加得到L_i。最小的L_i所对应的P_i就是我们要找的“最佳分割点”。


1.4 运用多个变量


再复杂一些,如果我们不仅仅知道了同事的职级,还知道了同事的工资(貌似不科学),该如何预测同事的年龄呢?


我们可以分别根据职级、工资计算出职级和工资的最佳分割点P_1, P_2,对应的Loss L_1, L_2。然后比较L_1和L2,取较小者。假设L_1 < L_2,那么按照P_1把不同职级的同事分为A、B两组。在A、B组内分别计算工资所对应的分割点,再分为C、D两组。这样我们就得到了AC, AD, BC, BD四组同事以及对应的平均年龄用于预测。


1.5 答案揭晓


如何实现这种1 to 2, 2 to 4, 4 to 8的算法呢?


熟悉数据结构的同学自然会想到二叉树,这种树被称为回归树,顾名思义利用树形结构求解回归问题。


2. 实现篇


本人用全宇宙最简单的编程语言——Python实现了回归树算法,没有依赖任何第三方库,便于学习和使用。简单说明一下实现过程,更详细的注释请参考本人github上的代码。


2.1 创建Node类


初始化,存储预测值、左右结点、特征和分割点


class Node(object):
    def __init__(self, score=None):
        self.score = score
        self.left = None
        self.right = None
        self.feature = None
        self.split = None


2.2 创建回归树类


初始化,存储根节点和树的高度。


class RegressionTree(object):
    def __init__(self):
        self.root = Node()
        self.height = 1


2.3 计算分割点、MSE


根据自变量X、因变量y、X元素中被取出的行号idx,列号feature以及分割点split,计算分割后的MSE。注意这里为了减少计算量,用到了方差公式:


$D(X) = E{[X-E(X)]^2} = E(X^2)-[E(X)]^2$


def _get_split_mse(self, X, y, idx, feature, split):
    split_sum = [0, 0]
    split_cnt = [0, 0]
    split_sqr_sum = [0, 0]

    for i in idx:
        xi, yi = X[i][feature], y[i]
        if xi split:
            split_cnt[0] += 1
            split_sum[0] += yi
            split_sqr_sum[0] += yi ** 2
        else:
            split_cnt[1] += 1
            split_sum[1] += yi
            split_sqr_sum[1] += yi ** 2

    split_avg = [split_sum[0] / split_cnt[0], split_sum[1] / split_cnt[1]]
    split_mse = [split_sqr_sum[0] - split_sum[0] * split_avg[0],
                    split_sqr_sum[1- split_sum[1] * split_avg[1]]
    return sum(split_mse), splitsplit_avg


2.4 计算最佳分割点


def _choose_split_point(self, X, y, idx, feature):
    unique = set([X[i][feature] for i in idx])
    if len(unique) == 1:
        return None

    unique.remove(min(unique))
    mse, split, split_avg = min(
        (self._get_split_mse(X, y, idx, feature, split)
            for split in unique), key=lambda x: x[0])
    return mse, feature, split, split_avg


2.5 选择最佳特征


遍历所有特征,计算最佳分割点对应的MSE,找出MSE最小的特征、对应的分割点,左右子节点对应的均值和行号。如果所有的特征都没有不重复元素则返回None


def _choose_feature(self, X, y, idx):
    m = len(X[0])
    split_rets = [x for x in map(lambda x: self._choose_split_point(
        X, y, idx, x), range(m)) if x is not None]

    if split_rets == []:
        return None
    _, feature, split, split_avg = min(
        split_rets, key=lambda x: x[0])

    idx_split = [[], []]
    while idx:
        i = idx.pop()
        xi = X[i][feature]
        if xi < split:
            idx_split[0].append(i)
        else:
            idx_split[1].append(i)
    return feature, split, split_avg, idx_split


2.6 规则转文字


将规则用文字表达出来,方便我们查看规则。


def _expr2literal(self, expr):
    feature, op, split = expr
    op = ">=" if op == 1 else "<"
    return "Feature%d %s %.4f" % (feature, op, split)


2.7 获取规则


将回归树的所有规则都用文字表达出来,方便我们了解树的全貌。这里用到了队列+广度优先搜索。有兴趣也可以试试递归或者深度优先搜索。


def _get_rules(self):
    que = [[self.root, []]]
    self.rules = []

    while que:
        nd, exprs = que.pop(0)
        if not(nd.left or nd.right):
            literals = list(map(self._expr2literal, exprs))
            self.rules.append([literals, nd.score])

        if nd.left:
            rule_left = copy(exprs)
            rule_left.append([nd.feature, -1, nd.split])
            que.append([nd.left, rule_left])

        if nd.right:
            rule_right = copy(exprs)
            rule_right.append([nd.feature, 1, nd.split])
            que.append([nd.right, rule_right])


2.8 训练模型


仍然使用队列+广度优先搜索,训练模型的过程中需要注意: 1. 控制树的最大深度max_depth; 2. 控制分裂时最少的样本量min_samples_split; 3. 叶子结点至少有两个不重复的y值; 4. 至少有一个特征是没有重复值的。


def fit(self, X, y, max_depth=5, min_samples_split=2):
    self.root.score = sum(y) / len(y)
    idxs = list(range(len(y)))
    que = [(self.depth + 1, self.root, idxs)]

    while que:
        depth, nd, idxs = que.pop(0)

        if depth > max_depth:
            depth -= 1
            break

        if len(idxs) < min_samples_split or \
                len(set(map(lambda i: y[i], idxs))) == 1:
            continue

        split_ret = self._choose_feature(X, y, idxs)
        if split_ret is None:
            continue

        _, feature, split, split_avg = split_ret

        nd.feature = feature
        nd.split = split
        nd.left = Node(split_avg[0])
        nd.right = Node(split_avg[1])

        idxs_split = list_split(X, idxs, feature, split)
        que.append((depth+1, nd.left, idxs_split[0]))
        que.append((depth+1, nd.right, idxs_split[1]))

    self.depth = depth
    self._get_rules()


2.9 打印规则


def print_rules(self):
    for i, rule in enumerate(self.rules):
        literals, score = rule
        print("Rule %d: " % i, ' | '.join(
            literals) + ' => split_hat %.4f' % score)


2.10 预测一个样本


def _predict(self, row):
    nd = self.root
    while nd.left and nd.right:
        if row[nd.feature] < nd.split:
            nd = nd.left
        else:
            nd = nd.right
    return nd.score


2.11 预测多个样本


def predict(self, X):
    return [self._predict(Xi) for Xi in X]


3 效果评估


3.1 main函数


使用著名的波士顿房价数据集,按照7:3的比例拆分为训练集和测试集,训练模型,并统计准确度。


@run_time
def main():
    print("Tesing the accuracy of RegressionTree...")
    X, y = load_boston_house_prices()
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=10)

    reg = RegressionTree()
    reg.fit(X=X_train, y=y_train, max_depth=5)

    reg.print_rules()
    get_r2(reg, X_test, y_test)


3.2 效果展示


最终生成了15条规则,拟合优度0.801,运行时间1.74秒,效果还算不错~




3.3 工具函数


本人自定义了一些工具函数,可以在github上查看


https://github.com/tushushu/imylu/tree/master/imylu/utils


1. run_time - 测试函数运行时间


2. load_boston_house_prices - 加载波士顿房价数据


3. train_test_split - 拆分训练集、测试集


4. get_r2 - 计算拟合优度


总结


回归树的原理:


损失最小化,平均值大法。 最佳行与列,效果顶呱呱。


回归树的实现:


一顿操作猛如虎,加减乘除二叉树。


 推荐↓↓↓ 

👉16个技术公众号】都在这里!

涵盖:程序员大咖、源码共读、程序员共读、数据结构与算法、黑客技术和网络安全、大数据科技、编程前端、Java、Python、Web编程开发、Android、iOS开发、Linux、数据库研发、幽默程序员等。

万水千山总是情,点个 “ 好看” 行不行
 
Python开发 更多文章 教程 | 十分钟学会函数式 Python 2019年程序员岗位招聘信息分析 浅入深谈:一道Python面试题,让我明白了殊途同归,却开始怀疑自己 程序员必知的 Python 陷阱与缺陷列表 Python3在磁盘上的B+树:Bplustree
猜您喜欢 软件测试中排错的基本方法 专访洪小文:人和当前AI的智能有什么区别?以简制繁 v.s. 以繁制繁 深入探讨 JavaScript 中的错误对象和堆栈追踪 基于AngularJS的个推前端云组件探秘 【达内幽默】让上班族瞬间怒掀桌的20件事,你遇到哪几件?