Auto Byte

专注未来出行及智能汽车科技

微信扫一扫获取更多资讯

Science AI

关注人工智能与其他前沿技术、基础学科的交叉研究与融合发展

微信扫一扫获取更多资讯

Wee Tee Soh作者

图卷积实战——文本分类

本文将介绍基于基于文本的 GCN,使用 Pytorch 和基本的库。GCN 模型是目前很新颖的半监督学习方法。

总的来说,它是将整个语料库嵌入到一个以文档或者单词为节点(有的带标签,有的不带标签)的图中,各节点之间根据它们的关系存在带有不同权重的边。然后输入带有标签的文档节点,训练GCN模型,预测没有标签的文档。

本文选择圣经作为语料库,因为它是世界上最多阅读并且有着非常丰富文本结构的书籍。圣经包括66本书(创世纪、出埃及记等)以及1189个章节。本监督学习任务是训练一个语言模型能够准确分类没有标签的章节属于哪本书。(由于这个数据集的标签是全部已知,所以本文选择10%~20%的数据隐藏标签作为测试集。)

为了解决这个问题,语言模型首先要学会区分各种书籍(比如《创世纪》更多的是谈论亚当和夏娃,而《传道书》讲述了所罗门王的生平)。下面我们将展示文本 GCN 能够很好的捕捉这些信息。

语料库



为了能够让 GCN 捕捉章节的上下文,我们建立了表示章节和单词关系的图,如上图所示。节点由1189个章节和所有词汇组成,这些节点之间是带有权重的边。权重 Aij 定义如下:

上式中,PMI 是滑动窗口多对共现单词之间的逐点互信息。#W 定义10个字节的长度,#W(i) 是包含单词 i 的滑动窗口数,#W(i,j) 是包含单词 i的滑动窗口数, #W 是语料库中滑动窗口总数。TF-IDF 是一种加权技术,字词的重要性随着它在文件中出现的次数成正比增加,但同时会随着它在语料库中出现的频率成反比下降。直观地说,具有高、正的 PMI 值的词之间具有高语义相关性,相反,我们不会在具有负 PMI 的词之间建立边。总的来说,TF-IDF 加权文档和 word 之间的边,捕获文档内的上下文。PMI 加权词汇之间的边,可以跨文档捕获上下文。


相比之下,不是基于图的模型,这种跨文档的上下文信息很难作为输入特征,并且模型必须标签从头开始学习它们。由于 GCN 可以提供文档之间的关联关系,这些信息于NLP 任务明确相关,所以可以预期 GCN 的表现将会更好。

计算 TF-IDF

计算 TF-IDF 相对简单,我们知道数学公式,并且理解它的原理,只需要在1189 个文档中使用TfidfVectorizer 模块,并将结果存储在 dataframe 中。为后面创建图时文档-单词之间的权重。代码如下:

### Tfidf
vectorizer = TfidfVectorizer(input="content", max_features=None, tokenizer=dummy_fun, preprocessor=dummy_fun)
vectorizer.fit(df_data["c"])df_tfidf = vectorizer.transform(df_data["c"])
df_tfidf = df_tfidf.toarray()
vocab = vectorizer.get_feature_names()
vocab = np.array(vocab)
df_tfidf = pd.DataFrame(df_tfidf,columns=vocab)


计算词汇间 PMI

计算词汇之间的 PMI 要更复杂一些,首先我们需要在10个单词长度的滑动窗口内找到单词ij的共现,以方块矩阵的形式存储在 dataframe 中,其中行和列表示词汇表。然后使用之前的定义计算 PMI 。代码如下:

### PMI between words
window = 10 # sliding window size to calculate point-wise mutual information between words
names = vocaboccurrences = OrderedDict((name, OrderedDict((name, 0) for name in names)) for name in names)
# Find the co-occurrences:no_windows = 0; print("calculating co-occurences")for l in df_data["c"]:    
for i in range(len(l)-window):        
no_windows += 1        
d = l[i:(i+window)]; 
dum = []       
 for x in range(len(d)):            
for item in d[:x] + d[(x+1):]:                
if item not in dum:                    
occurrences[d[x]][item] += 1; dum.append(item)        
df_occurences = pd.DataFrame(occurrences, columns=occurrences.keys())
df_occurences = (df_occurences + df_occurences.transpose())/2 
## symmetrize it as window size on both sides may not be samedel occurrences
### convert to PMIp_i = df_occurences.sum(axis=0)/no_windows
p_ij = df_occurences/no_windowsdel 
df_occurencesfor col in p_ij.columns:    
p_ij[col] = p_ij[col]/p_i[col]for row in p_ij.index:   
p_ij.loc[row,:] = p_ij.loc[row,:]/p_i[row]
p_ij = p_ij + 1E-9for col in p_ij.columns:   
 p_ij[col] = p_ij[col].apply(lambda x: math.log(x))


构图

现在我们得到了所有边的权重,可以开始构图 G 了。我们使用 networkx 模块来构图。这里要提的是整个项目的繁重计算主要在于计算词汇边的权重,因为需要迭代所有可能成对的单词组合,大约有6500个单词。(我们差不多花了两天时间计算这个,代码如下)

def word_word_edges(p_ij):   
 dum = []; word_word = []; counter = 0    
cols = list(p_ij.columns); cols = [str(w) for w in cols]    
for w1 in cols:        
for w2 in cols:            
if (counter % 300000) == 0:               
 print("Current Count: %d; %s %s" % (counter, w1, w2))            
if (w1 != w2) and ((w1,w2) not in dum) and (p_ij.loc[w1,w2] > 0):               
 word_word.append((w1,w2,{"weight":p_ij.loc[w1,w2]})); dum.append((w2,w1))           
 counter += 1   
 return word_word   
 ### Build graphG = nx.Graph()G.add_nodes_from(df_tfidf.index) 
## document nodesG.add_nodes_from(vocab) 
## word nodes### build edges between document-word
 pairsdocument_word = [(doc,w,{"weight":df_tfidf.loc[doc,w]}) for doc in df_tfidf.index for w in df_tfidf.columns]
G.add_edges_from(document_word)
### build edges between word-word 
pairsword_word = word_word_edges(p_ij)G.add_edges_from(word_word)

图卷积神经网络

图节点没有明确的物理空间信息,不像像素点,能够明确在中心点的左边还是右边。所以,要想进行卷积,要找到最能捕获图结构的每个节点的特征表示。作者通过将每个节点的卷积核权重 和特征空间 X 投影到图的傅立叶空间中来解决这个问题,使得卷积变成具有特征的节点的逐点乘法。具体的这里不做赘述,可以参照 kipf 他们的文章。

我们将在这里使用两层 GCN ,两层 GCN 之后的复杂张量由下式给出:

这里:

这里的 A 是图 G 的邻接矩阵,D 是图 G 的度矩阵。W0 W1 分别是 GCN 第一层和第二层可学习的卷积核权重,也是需要被训练学习的。X 是输入特征矩阵,是与节点数相同的维度的对角方形矩阵,这意味着输入是图中每个节点的 one-hot 编码。最后将输出馈送到具有softmax 函数的层,用于书籍分类。

双层 GCN 的Pytorch 代码如下:

class gcn(nn.Module):   
 def __init__(self, X_size, A_hat, bias=True): # X_size = num features      
  super(gcn, self).__init__()      
  self.A_hat = torch.tensor(A_hat, requires_grad=False).float()     
   self.weight = nn.parameter.Parameter(torch.FloatTensor(X_size, 330))      
  var = 2./(self.weight.size(1)+self.weight.size(0))     
   self.weight.data.normal_(0,var)   
     self.weight2 = nn.parameter.Parameter(torch.FloatTensor(330, 130))  
      var2 = 2./(self.weight2.size(1)+self.weight2.size(0))     
   self.weight2.data.normal_(0,var2)     
   if bias:         
   self.bias = nn.parameter.Parameter(torch.FloatTensor(330))      
      self.bias.data.normal_(0,var)    
        self.bias2 = nn.parameter.Parameter(torch.FloatTensor(130))     
       self.bias2.data.normal_(0,var2)    
    else:        
    self.register_parameter("bias", None)   
     self.fc1 = nn.Linear(130,66)        
    def forward(self, X): ### 2-layer GCN architecture   
     X = torch.mm(X, self.weight)   
     if self.bias is not None:   
         X = (X + self.bias)    
    X = F.relu(torch.mm(self.A_hat, X))   
     X = torch.mm(X, self.weight2)      
  if self.bias2 is not None:       
     X = (X + self.bias2)    
    X = F.relu(torch.mm(self.A_hat, X))    
    return self.fc1(X)

训练

总共的 1189 个章节,我们掩盖了111个标签(10%左右)。由于1189 个章节的标签分布很不均匀(如下图所示),所以一些分布低的类别我们不会隐藏,以确保 GCN 可以学习到 66 个类别。

结果

从上面的 loss vs epoch 图中可以看到,训练是很顺利的,损失大约在 2000 epoch的时候达到饱和。

随着训练进行,训练集准确度和测试集准确度同时增加,到 2000 epoch的时候,测试集准确度趋于饱和在 50% 左右。考虑到我们的类别有 66 个,假设模型以纯粹的机会预测,基准精度为 1.5%,相比之下 50% 已经很不错了。意思是,GCN 模型可以正确预测章节属于哪本书的时候有 50% ,即使它之前并没有见过这些章节。

错误分类的章节

GCN 可以捕捉文档内或者文档间的上下文信息,但是分类错误的章节呢?这意味着 GCN 模型失败吗?

例子:

书籍《马太福音》章节 27被模型预测为《路加》

查阅该章节的具体内容可以发现,该章节的内容和路加的部分内容很相似,都是在讲耶稣被处死的事情。

书籍《以赛亚》章节12 被模型预测为《诗篇》

同样的,《以赛亚》的 12 章节主要描述的是一些对上帝的赞美和歌颂,而这正是《诗篇》的全文主旨。

总结

GCN 在文本分类上是一个强大的模型,因为它能够捕捉文档内部和文档之间的信息。

项目地址

圣经数据集地址:

https://github.com/scrollmapper/bible_databases

图神经网络的参照:

https://arxiv.org/abs/1809.05679

项目源代码:

https://github.com/plkmo/Bible_Text_GCN

极验
极验

极验是全球顶尖的交互安全技术服务商,于2012年在武汉成立。全球首创 “行为式验证技术” ,利用生物特征与人工智能技术解决交互安全问题,为企业抵御恶意攻击防止资产损失提供一站式解决方案。

工程文本分类
9
暂无评论
暂无评论~