Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

小舟、杜伟编辑

速度数百倍之差,有人断言KNN面临淘汰,更快更强的ANN将取而代之

数据科学经典算法 KNN 已被嫌慢,ANN 比它快 380 倍。

模式识别领域中,K - 近邻算法(K-Nearest Neighbor, KNN)是一种用于分类和回归的非参数统计方法。K - 近邻算法非常简单而有效,它的模型表示就是整个训练数据集。就原理而言,对新数据点的预测结果是通过在整个训练集上搜索与该数据点最相似的 K 个实例(近邻)并且总结这 K 个实例的输出变量而得出的。KNN 可能需要大量的内存或空间来存储所有数据,并且使用距离或接近程度的度量方法可能会在维度非常高的情况下(有许多输入变量)崩溃,这可能会对算法在你的问题上的性能产生负面影响。这就是所谓的维数灾难

近似最近邻算法(Approximate Nearest Neighbor, ANN)则是一种通过牺牲精度来换取时间和空间的方式从大量样本中获取最近邻的方法,并以其存储空间少、查找效率高等优点引起了人们的广泛关注。

近日,一家技术公司的数据科学主管 Marie Stephen Leo 撰文对 KNN 与 ANN 进行了比较,结果表明,在搜索到最近邻的相似度为 99.3% 的情况下,ANN 比 sklearn 上的 KNN 快了 380 倍

作者表示,几乎每门数据科学课程中都会讲授 KNN 算法,但它正在走向「淘汰」!

KNN 简述

在机器学习社区中,找到给定项的「K」个相似项被称为相似性搜索或最近邻(NN)搜索。最广为人知的 NN 搜索算法是 KNN 算法。在 KNN 中,给定诸如手机电商目录之类的对象集合,则对于任何新的搜索查询,我们都可以从整个目录中找到少量(K 个)最近邻。例如,在下面示例中,如果设置 K = 3,则每个「iPhone」的 3 个最近邻是另一个「iPhone」。同样,每个「Samsung」的 3 个最近邻也都是「Samsung」。

KNN 存在的问题

尽管 KNN 擅长查找相似项,但它使用详细的成对距离计算来查找邻居。如果你的数据包含 1000 个项,如若找出新产品的 K=3 最近邻,则算法需要对数据库中所有其他产品执行 1000 次新产品距离计算。这还不算太糟糕,但是想象一下,现实世界中的客户对客户(Customer-to-Customer,C2C)市场,其中的数据库包含数百万种产品,每天可能会上传数千种新产品。将每个新产品与全部数百万种产品进行比较是不划算的,而且耗时良久,也就是说这种方法根本无法扩展。

解决方案

将最近邻算法扩展至大规模数据的方法是彻底避开暴力距离计算,使用 ANN 算法。

近似最近距离算法(ANN)

严格地讲,ANN 是一种在 NN 搜索过程中允许少量误差的算法。但在实际的 C2C 市场中,真实的邻居数量比被搜索的 K 近邻数量要多。与暴力 KNN 相比,人工神经网络可以在短时间内获得卓越的准确性。ANN 算法有以下几种:

  • Spotify 的 ANNOY

  • Google 的 ScaNN

  • Facebook 的 Faiss

  • HNSW

分层的可导航小世界(Hierarchical Navigable Small World, HNSW)

在 HNSW 中,作者描述了一种使用多层图的 ANN 算法。在插入元素阶段,通过指数衰减概率分布随机选择每个元素的最大层,逐步构建 HNSW 图。这确保 layer=0 时有很多元素能够实现精细搜索,而 layer=2 时支持粗放搜索的元素数量少了 e^-2。最近邻搜索从最上层开始进行粗略搜索,然后逐步向下处理,直至最底层。使用贪心图路径算法遍历图,并找到所需邻居数量。

HNSW 图结构。最近邻搜索从最顶层开始(粗放搜索),在最底层结束(精细搜索)。

HNSW Python 包

整个 HNSW 算法代码已经用带有 Python 绑定的 C++ 实现了,用户可以通过键入以下命令将其安装在机器上:pip install hnswlib。安装并导入软件包之后,创建 HNSW 图需要执行一些步骤,这些步骤已经被封装到了以下函数中:

import hnswlibimport numpy as npdef fit_hnsw_index(features, ef=100, M=16, save_index_file=False):    # Convenience function to create HNSW graph    # features : list of lists containing the embeddings    # ef, M: parameters to tune the HNSW algorithm        num_elements = len(features)    labels_index = np.arange(num_elements)    EMBEDDING_SIZE = len(features[0])    # Declaring index    # possible space options are l2, cosine or ip    p = hnswlib.Index(space='l2', dim=EMBEDDING_SIZE)    # Initing index - the maximum number of elements should be known    p.init_index(max_elements=num_elements, ef_construction=ef, M=M)    # Element insertion    int_labels = p.add_items(features, labels_index)    # Controlling the recall by setting ef    # ef should always be > k    p.set_ef(ef)         # If you want to save the graph to a file    if save_index_file:         p.save_index(save_index_file)        return p

创建 HNSW 索引后,查询「K」个最近邻就仅需以下这一行代码:

ann_neighbor_indices, ann_distances = p.knn_query(features, k)

KNN 和 ANN 基准实验

计划

首先下载一个 500K + 行的大型数据集。然后将使用预训练 fasttext 句子向量将文本列转换为 300d 嵌入向量。然后将在不同长度的输入数据 [1000. 10000, 100000, len(data)] 上训练 KNN 和 HNSW ANN 模型,以度量数据大小对速度的影响。最后将查询两个模型中的 K=10 和 K=100 时的最近邻,以度量「K」对速度的影响。首先导入必要的包和模型。这需要一些时间,因为需要从网络上下载 fasttext 模型。

# Imports# For input data pre-processingimport jsonimport gzipimport pandas as pdimport numpy as npimport matplotlib.pyplot as pltimport fasttext.utilfasttext.util.download_model('en', if_exists='ignore') # English pre-trained modelft = fasttext.load_model('cc.en.300.bin')# For KNN vs ANN benchmarkingfrom datetime import datetimefrom tqdm import tqdmfrom sklearn.neighbors import NearestNeighborsimport hnswlib

数据

使用亚[马逊产品数据集],其中包含「手机及配件」类别中的 527000 种产品。然后运行以下代码将其转换为数据框架。记住仅需要产品 title 列,因为将使用它来搜索相似的产品。

# Data: http://deepyeti.ucsd.edu/jianmo/amazon/data = []with gzip.open('meta_Cell_Phones_and_Accessories.json.gz') as f:    for l in f:        data.append(json.loads(l.strip()))# Pre-Processing: https://colab.research.google.com/drive/1Zv6MARGQcrBbLHyjPVVMZVnRWsRnVMpV#scrollTo=LgWrDtZ94w89# Convert list into pandas dataframedf = pd.DataFrame.from_dict(data)df.fillna('', inplace=True)# Filter unformatted rowsdf = df[~df.title.str.contains('getTime')]# Restrict to just 'Cell Phones and Accessories'df = df[df['main_cat']=='Cell Phones & Accessories']# Reset indexdf.reset_index(inplace=True, drop=True)# Only keep the title columnsdf = df[['title']]# Check the dfprint(df.shape)df.head()

如果全部都可以运行精细搜索,你将看到如下输出:

亚马逊产品数据集。

嵌入

要对文本数据进行相似性搜索,则必须首先将其转换为数字向量。一种快速便捷的方法是使用经过预训练的网络嵌入层,例如 Facebook [FastText] 提供的嵌入层。由于希望所有行都具有相同的长度向量,而与 title 中的单词数目无关,所以将在 df 中的 title 列调用 get_sentence_vector 方法。

嵌入完成后,将 emb 列作为一个 list 输入到 NN 算法中。理想情况下可以在此步骤之前进行一些文本清理预处理。同样,使用微调的嵌入模型也是一个好主意。

# Title Embedding using FastText Sentence Embeddingdf['emb'] = df['title'].apply(ft.get_sentence_vector)# Extract out the embeddings column as a list of lists for input to our NN algosX = [item.tolist() for item in df['emb'].values]

基准

有了算法的输入,下一步进行基准测试。具体而言,在搜索空间中的产品数量和正在搜索的 K 个最近邻之间进行循环测试。在每次迭代中,除了记录每种算法的耗时以外,还要检查 pct_overlap,因为一定比例的 KNN 最近邻也被挑选为 ANN 最近邻。

注意整个测试在一台全天候运行的 8 核、30GB RAM 机器上运行大约 6 天,这有些耗时。理想情况下,你可以通过多进程来加快运行速度,因为每次运行都相互独立。

# Number of products for benchmark loopn_products = [1000, 10000, 100000, len(X)]# Number of neighbors for benchmark loopn_neighbors = [10, 100]# Dictionary to save metric results for each iterationmetrics = {'products':[], 'k':[], 'knn_time':[], 'ann_time':[], 'pct_overlap':[]}for products in tqdm(n_products):    # "products" number of products included in the search space    features = X[:products]        for k in tqdm(n_neighbors):           # "K" Nearest Neighbor search        # KNN         knn_start = datetime.now()        nbrs = NearestNeighbors(n_neighbors=k, metric='euclidean').fit(features)        knn_distances, knn_neighbor_indices = nbrs.kneighbors(X)        knn_end = datetime.now()        metrics['knn_time'].append((knn_end - knn_start).total_seconds())                # HNSW ANN        ann_start = datetime.now()        p = fit_hnsw_index(features, ef=k*10)        ann_neighbor_indices, ann_distances = p.knn_query(features, k)        ann_end = datetime.now()        metrics['ann_time'].append((ann_end - ann_start).total_seconds())                # Average Percent Overlap in Nearest Neighbors across all "products"        metrics['pct_overlap'].append(np.mean([len(np.intersect1d(knn_neighbor_indices[i], ann_neighbor_indices[i]))/k for i in range(len(features))]))                metrics['products'].append(products)        metrics['k'].append(k)        metrics_df = pd.DataFrame(metrics)metrics_df.to_csv('metrics_df.csv', index=False)metrics_df

运行结束时输出如下所示。从表中已经能够看出,HNSW ANN 完全超越了 KNN。

以表格形式呈现的结果。

结果

以图表的形式查看基准测试的结果,以真正了解二者之间的差异,其中使用标准的 matplotlib 代码来绘制这些图表。这种差距是惊人的。根据查询 K=10 和 K=100 最近邻所需的时间,HNSW ANN 将 KNN 彻底淘汰。当搜索空间包含约 50 万个产品时,在 ANN 上搜索 100 个最近邻的速度是 KNN 的 380 倍,同时两者搜索到最近邻的相似度为 99.3%。

在搜索空间包含 500K 个元素,搜索空间中每个元素找到 K=100 最近邻时,HNSW ANN 的速度比 Sklearn 的 KNN 快 380 倍。

在搜索空间包含 500K 个元素,搜索空间中每个元素找到 K=100 最近邻时,HNSW ANN 和 KNN 搜索到最近邻的相似度为 99.3%。

基于以上结果,作者认为可以大胆地说:「KNN 已死」。

本篇文章的代码作者已在 GitHub 上给出:https://github.com/stephenleo/adventures-with-ann/blob/main/knn_is_dead.ipynb

原文链接:https://medium.com/towards-artificial-intelligence/knn-k-nearest-neighbors-is-dead-fc16507eb3e

入门KNN算法ANN算法速度
1
相关数据
维数灾难技术

维数灾难(英语:curse of dimensionality,又名维度的诅咒)是一个最早由理查德·贝尔曼(Richard E. Bellman)在考虑优化问题时首次提出来的术语,用来描述当(数学)空间维度增加时,分析和组织高维空间(通常有成百上千维),因体积指数增加而遇到各种问题场景。这样的难题在低维空间中不会遇到,如物理空间通常只用三维来建模。

最近邻搜索技术

最邻近搜索(Nearest Neighbor Search, NNS)又称为“最近点搜索”(Closest point search),是一个在尺度空间中寻找最近点的优化问题。问题描述如下:在尺度空间M中给定一个点集S和一个目标点q ∈ M,在S中找到距离q最近的点。很多情况下,M为多维的欧几里得空间,距离由欧几里得距离或曼哈顿距离决定。

参数技术

在数学和统计学裡,参数(英语:parameter)是使用通用变量来建立函数和变量之间关系(当这种关系很难用方程来阐述时)的一个数量。

概率分布技术

概率分布(probability distribution)或简称分布,是概率论的一个概念。广义地,它指称随机变量的概率性质--当我们说概率空间中的两个随机变量具有同样的分布(或同分布)时,我们是无法用概率来区别它们的。

模式识别技术

模式识别(英语:Pattern recognition),就是通过计算机用数学技术方法来研究模式的自动处理和判读。 我们把环境与客体统称为“模式”。 随着计算机技术的发展,人类有可能研究复杂的信息处理过程。 信息处理过程的一个重要形式是生命体对环境及客体的识别。其概念与数据挖掘、机器学习类似。

查询技术

一般来说,查询是询问的一种形式。它在不同的学科里涵义有所不同。在信息检索领域,查询指的是数据库和信息系统对信息检索的精确要求

暂无评论
暂无评论~