微信号:lls_tech

介绍:Help everyone become a global citizen!

流利说基于 TensorFlow 的自适应系统实践

2017-03-25 09:51 Xulin


今年2月15日,谷歌举办了首届 TensorFlow Dev Summit,并且发布了 TensorFlow 1.0 正式版。 3月18号,上海的谷歌开发者社区(GDG)组织了针对峰会的专场回顾活动。本文是我在活动上分享的一些回顾,主要介绍了在流利说,我们是如何使用 TensorFlow 来构建学生模型并应用在自适应系统里面的。

一、应用背景

自适应学习是什么

自适应学习是现在教育科技领域谈得比较多的一个概念,它的核心问题可以用一句话概括,即通过个性化的学习路径规划,提高学生的学习效率。为什么需要自适应学习?在传统的教学过程中,每个学生的学习路径是一致的,由于学生个人基础和学习能力的差异性,这种千人一面的做法对大部分学生来说其实比较低效。由此我们可以很自然地想到:如果我们能够根据学生的能力,去匹配合适的教学内容,就应该可以提高他们的学习效率。而这正是自适应系统希望达成的目标。

学生模型

那么自适应学习是如何达成这个目标的呢?这包含了两个核心问题,首先是学生的能力评估,正确的评估学生的能力是后续一切工作的基础,这是学生模型关心的问题。 其次是在评估好学生能力后,如何推送合适的内容,这是教学模型所关心的问题。本篇文章我们来讲讲如何利用 TensorFlow 来构建学生模型。

为了选择一个合适的学生模型,首先需要了解学生学习的过程。一个典型的学习过程是一个时间序列,用户在这个时间序列的各个时刻进行了一些学习行为,从而提高了自身的能力。我们可以假设学生的能力是可以通过学生在各个时刻回答问题的对错来反映的。要注意的是,由于学生学习时间的跨度可能很大,不能认为学生的水平保持不变,所以直接使用一些评测的方法(做了学生能力不变的假设)是不合适的。

Deep Knowledge Tracing

为了对学习序列建模,并评估学生各个时刻的能力,我们采用了 Deep Knowledge Tracing(DKT)模型,这个模型是由 Stanford 大学的 Piech Chris 等人在 NIPS 2015发表的,其本质是一个 Seq2Seq 的 RNN 模型,我们来看下模型的结构图:

上图是 DKT 模型按照时间展开的示意图,其输入序列x1, x2, x3 ...对应了t1, t2, t3 ... 时刻学生答题信息的编码,隐层状态对应了各个时刻学生的知识点掌握情况,模型的输出序列对应了各时刻学生回答题库中的所有习题回答正确的概率。

现在以上图为例来看看模型的各层结构。简单起见,假设题库总共有4道习题,那么首先可以确定的输出层节点数量为4,对应了各题回答正确的概率。接着,如果我们对输出采用 one-hot 编码,输入层的节点数就是题目数量 * 答题结果 = 4 * 2 = 8个。首先将输入层全连接到 RNN 的隐层,接着建立隐层到输出层的全连接,最后使用 Sigmoid 函数作为激活函数,一个基础的 DKT 模型就构建完毕了。接着为了训练模型,定义如下的损失函数:

其中 y 是 t 时刻的模型预测输出,q_{t+1} 是 t+1时刻用户回答的题目 ID(one-hot向量),a_{t+1}是 t+1 时刻的用户答题的对错, l是 binary cross entropy 损失函数。下面我们用几十行 TensorFlow 代码来实现一下这个模型。

二、模型构建

首先初始化模型参数,并且用tf.placeholder来接收模型的输入:

接着构建RNN层:

这里我们用 tf.dynamic_rnn 构建了一个多层循环神经网络,cell 参数用来指定了隐层神经元的结构,sequence_len 参数表示一个 batch 中各个序列的有效长度。state_series 表示隐层的输出,是一个三阶的 Tensor,self.current_state 表示 batch 各个序列的最后一个 step 的隐状态。

输出层:

输出层我们构建了两个变量作为隐层到输出层的连接的参数,并用 tf.sigmoid 作为激活函数。到这里我们已经可以得到模型的预测输出 self.pred_all,这也是一个三阶的张量, shape (batch_size, self.max_steps, num_skills)

为了训练模型,还需要计算模型损失函数和梯度,我们结合预测和标签信息来获得损失函数:

获得梯度并更新参数:

需要注意的是,在用 tf.gradients 得到梯度后,我们使用了 tf.clip_by_global_norm 方法,这主要是为了防止梯度爆炸的现象。最后应用了一次梯度下降得到的 self.train_op 就是计算图的训练结点。得到训练结点后,我们的计算图 (Graph) 就已经构造完毕,接着只需要创建一个 tf.Session 对象,并调用其run()方法来运行计算图就可以进行模型训练和测试了。由于训练和测试的接收的feed_dict类似,我们定义 step 方法来用作训练和测试,如下:

定义 assign_lr 方法来设置学习率:

至此,TensorFlowDKT 类就构造完毕了,我们可以这样使用它:

Demo 的完整代码,见https://github.com/lingochamp/tensorflow-dkt

三、工程实践

截至2016年12月底,流利说的懂你英语课程已经积累了数亿量级用户答题数据,在处理这些数据优化模型指标的过程中,我们也积累了一些实践经验。

Trancated BPTT

我们收集到的学习数据里,最长的序列长度超过五万。出于计算效率的考虑,包括 TensorFlow 在内的多数深度框架在进行 BPTT 的时候都会将序列按照时间维度展开,这在序列长度达到五万的情况下是不现实的(显存会爆)。所以我们需要将长的序列切断分为多个序列,然后保存前一部分序列训练的隐状态作为接下来一部分序列的初始状态输入,这样来进行长序列的训练。

多 GPU 加速

当数据到达数亿的量级以后,进行一次训练已经需要比较多的时间了,这个时候我们可以通过多 GPU 并行来加速训练。这里我们使用 Multi Tower 结构,这是一种数据并行的多 GPU 方案,我们来看下它的示意图:

可以看到在 Multi Tower 结构里,每个 GPU 持有一个模型实例,这些实例之间共享参数的。训练开始后,每次我们将多个 batch 数据分别喂给各个模型实例,在各 GPU 设备分别求得梯度信息。 接着,我们将收集到的梯度返回到 CPU,取平均以后,用来更新模型的参数。 由于模型的参数是共享的,这也就意味着所有模型实例的参数都得到了更新。接着我们来看下,以 TensorFlowDKT 类为例,我们如何用 Multi Tower 结构来构造训练结点:

其中 average_gradients 方法的代码可以参考https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10multigpu_train.py 。接着我们构造一个方法来返回 feed_dict

最后由于 dynamic_rnn 的一些 Operation 尚不支持 GPU,在训练开始前,我们需要配置一下 Session 避免出错:

完成上面的步骤,我们就可以用多 GPU 来加速模型训练了。

模型导出

在流利说,学生模型训练是用 Python API完成的,而学生模型预测服务则是用 C++ 实现的。关于如何从 C++ 如何从 Protobuf 文件中加载 Graph 可以参考https://www.tensorflow.org/tutorials/image_recognition。这里有一个模型导出的问题,即 Python 中的 tf.train.write_graph 方法只能够保存模型的图结构,而不能保存变量的值到 Protobuf 文件中。这个问题可以通过将 Variables 转换为tf.constant来解决, tensorflow.python.tools.freeze_graph提供了这样的方法。

结语

TensorFlow 是一个十分简单易用的机器学习框架,也是目前最流行的深度学习框架,它可以让机器学习研究者更少的关注底层的问题,而专注于问题的解决和算法的优化上。流利说算法团队从16年初开始就将 TensorFlow 应用到内部的机器学习项目里面,积累了很多相关的使用经验,从而帮助我们的用更智能算法来服务用户。

References

  1. Piech, Chris, et al. "Deep knowledge tracing." Advances in Neural Information Processing Systems. 2015.

  2. https://www.tensorflow.org/

  3. https://github.com/tensorflow/tensorflow



 
流利说技术团队 更多文章 从零开始手撸 HashMap 英语流利说前端工程化实践 英语流利说后端基础组件演进 流利说工程师带你了解AWS re:Invent 2016 如何建立数学模型估算日活用户数?
猜您喜欢 【报名微课堂】为什么使用React作为云平台的前端框架 【python】八大排序算法的 Python 实现 理解MySQL——架构与概念 创业板、市盈率、Python!|【量化小讲堂】计算创业板平均市盈率 预告 | SQL 优化之 transformation