Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

NIPS2016最佳论文的TensorFlow实现

当地时间 12 月 5 日,机器学习和计算神经科学的国际顶级会议第 30 届神经信息处理系统大会(NIPS 2016)在西班牙巴塞罗那开幕。本届最佳论文奖(Best Paper Award)获奖论文是《Value Iteration Networks》。在获悉了这一消息之后,机器之心就对该论文的作者之一吴翼进行了一次简单的采访,详情参阅独家 | 机器之心对话 NIPS 2016 最佳论文作者:如何打造新型强化学习观?。昨天,GitHub 用户 Abhishek Kumar 发布了一个 VIN 的 TensorFlow 实现。

这个 repository 包含了价值迭代网络(Value Iteration Networks)的一个 TensorFlow 实现。这个代码基于原作者的 Theano 实现。

VIN 的 TensorFlow 实现:https://github.com/TheAbhiKumar/tensorflow-value-iteration-networks

原作者的 Theano 实现:https://github.com/avivt/VIN

训练

从原作者的 repo 下载 16x16 和 28x28 GridWorld 数据集。因为方便程度和大小上的考量,本 repo 包含的是 8×8 的 GridWorld 数据集。

python3 train.py

如果你想监控训练过程,可以将 config.log 改变为 True,并且加载 tensorboard --logdir /tmp/vintf/。日志目录默认在 /tmp/vintf/,但也可以在 config.logdir 中修改。该代码目前默认运行在 8x8 GridWorld 模型上。

 8x8 GridWorld 模型可以在 30 次迭代以下收敛到大约 98.5% 的准确度。论文中提到该准确度应该达到 99.6% 左右,我也用 Theano 代码重现了这一结果。当使用与 Theano 实现在 16×16 和 28×28 的相同代码上的相同参数时,该 TensorFlow 模型并不如 NaN 结果那么完美。

软件要求

  • Python >= 3.5

  • TensorFlow >= 0.12

  • SciPy >= 0.18.1(用于加载数据)

数据集

这里使用的 GridWorld 数据集来自于原作者的 repo。其中也包含用于生成该数据集的 Matlab 脚本。处理该数据集的代码来自原 repo,但进行了一点点地修改。

该模型原本在三个其它域上进行了测试,该作者的原代码:https://github.com/avivt/VIN/issues/4

  • Mars Rover Navigation

  • Continuous control

  • WebNav

论文:Value Iteration Networks

584a47fd77c0d.png

摘要

在本研究中,我们介绍了价值迭代网络(value iteration network, VIN):一个完全可微分的神经网络,其中嵌入了「规划模块」。VIN 可以经过学习获得规划(planning)的能力,适用于预测涉及基于规划的推理结果,例如用于规划强化学习的策略。这种新方法的关键在于价值迭代算法的新型可微近似,它可以被表征为一个卷积神经网络,并以端到端的方式训练使用标准反向传播。我们在离散和连续的路径规划域和一个基于自然语言的搜索任务上评估了 VIN 产生的策略。实验证明,通过学习明确的规划计算,VIN 策略可以更好地泛化到未见过的新域。

vin.png

results.png

入门TensorFlow获奖论文NIPS 2016实现工程吴翼
暂无评论
暂无评论~