GraphGAN是一种结合了图神经网络和生成对抗网络(GAN)概念的机器学习模型。它旨在通过对抗学习框架解决图数据中的节点分类和链接预测问题,特别是在缺少标签数据的情况下。GraphGAN通过结合生成器和判别器两个主要组件来学习图中节点的有效表示。

GraphGAN的框架如下:设G=(V,E)为一个给定的图,其中V={v1,,vV}代表顶点的集合,而E={eij}i,j=1V代表边的集合。对于给定的顶点νc,定义N(νc)为直接连接到νc的顶点集,也就是顶点νc的1-hop邻居。我们将顶点νc的真实连通性分布表示为条件概率ptrue(ν|νc),它反映了νc的连通性偏好和它所连接顶点的类型。从这个视角来看,N(νc)可以被看作是从ptrue(ν|νc)抽取的一组观察样本。给定图G,GraphGAN旨在学习以下两个模型:

  • 生成器G(ν|νc;θG):尝试近似顶点νc的真实连通性分布ptrue(ν|νc),并生成(或选择)最有可能与νc相连接的顶点集。
  • 判别器D(ν,νc;θD):旨在鉴别顶点对(ν,νc)的连通性。D(ν,νc;θD)输出一个标量值,表示边(ν,νc)存在的概率。

生成器G和判别器D作为两个对手: 生成器G将试图完美地匹配ptrue(ν|νc)并生成类似于νc的真实直接邻居的相关顶点以欺骗判别器。而判别器D则相反,会尝试检测这些顶点是νc的真实邻居还是由其对手G生成的。形式上, GD在一个有以下价值函数V(G,D)的双人博弈中进行对抗:

minθGmaxθDV(G,D)=c=1V(Eνptne(ν|νc)[logD(ν,νc;θD)]+EνpG(ν|νc;θG)[log(1D(ν,νc;θD))])       (1)

上式描述的最小最大(Minimax)博弈公式是生成对抗网络的核心。在图生成对抗网络(GraphGAN)的上下文中,生成器G的目标是生成看起来像真实数据的样本。在GraphGAN的情境下,生成器试图生成看似是νc的1-hop邻居的顶点。而鉴别器D的目标是区分输入样本是来自于真实数据还是生成器生成的假数据。在GraphGAN中,它试图区分一个顶点对(ν,νc)是否是真实的1-hop邻点对。这个公式包含了两个期望(Expectations):

第一部分为真实数据的期望Evptrue(vvc)[logD(v,vc;θD)]。这个期望表示对所有真实的顶点对(ν,νc),鉴别器D输出它们是真实邻居的概率的对数。V是从顶点νc的真实邻居分布ptrue(ν|νc)中采样的。理想情况下: 如果ν确实是νc的真实邻居,我们希望鉴别器D的输出D(ν,νc;θD)接近于1 (最大的可能性)。此时logD(ν,νc;θD)会接近0(因为log(1)=0)。

第二部分是关于生成数据的期望EνpG(ν|νc;θG)[log(1D(ν,νc;θD))]。这个期望表示对所有生成器G生成的顶点对(ν,νc),鉴别器D输出它们不是真实邻居的概率的对数。ν是从生成器G的分布pG(ν|νc;θG)中采样的。理想情况下:如果ν是生成器G生成的,并不是νc的真实邻居,我们希望判别器D的输出D(ν,νc;θD)接近于0。此时, log(1D(ν,νc;θD))也会接近0(因为log(10)=0)。

可以看出这两部分期望的目标都是希望经过训练后,其值越来越小,理想情况下为0,因此可以看作GraphGAN的损失函数,GraphGAN的训练状态如图1所示。

图1 GraphGAN训练状态

阴影的圆代表从真实图结构中采样的真实与顶点νc连接的邻居,横条纹的圆代表生成器认为的应该与νc连接的邻居。在训练的初始阶段(左图),生成器G的表现并不佳,其生成的顶点与实际的邻居相比,差异较大,很容易被鉴别器D区分。随着对抗训练的进行(中图),生成器G逐渐学习到更好的生成策略,使得生成的邻居和实际的邻居更难以区分。在经过充分的训练之后,生成器G生成的邻居与实际的邻居非常接近,即右图中,阴影的圆和横条纹的圆存在大量重叠。鉴别器D几乎无法区分。在这个博弈中,鉴别器尽量提高识别真实和生成数据对的准确性,而生成器尽量生成看起来越真实的数据。通过这种对立的训练过程,两个模型都在不断改进,直到达到一个平衡点,即生成器生成的数据非常逼真,鉴别器很难区分真假数据对。

下面具体的讲解GraphGAN的优化过程,对于判别器而言,定义鉴别器D的输出为两个输入顶点(ν,νc)内积的Sigmoid函数:

D(ν,νc)=σ(dνdνc)=11+exp(dνdνc)      (2)

定义θD是判别器的可学习参数。任何具有区分能力的模型都可以在此作为D,例如一个神经网络模型,对顶点ννc的特征向量进行学习映射后得到的k维新的向量表示为dν,dνc,然后进行内积和Sigmoid处理。这个内积表示了两个顶点之间的相似性或关系强度。内积的结果通过Sigmoid函数进行处理,得到一个在0到1之间的值,可以解释为顶点对ννc是图中真实连接的概率。其中ν的采样有两种情况,换句话说,判别器随机接收两类输入数据(νptrueνG)。根据损失函数公式(1),判别器的梯度θDV(G,D)计算如下:

θDV(G,D)={θDlogD(ν,νc), if νptrueθD(1logD(ν,νc)), if νG.        (3)

在训练过程中,GraphGAN交替地更新生成器G和判别器D的参数,以便最小化生成器的损失函数和最大化判别器的损失函数,这个过程通常被称为对抗训练。式(3)便是固定生成器G更新判别器D时,判别器的梯度计算方法。

类似的,生成器G具体的模型类型也没有限制,GCN、GAT甚至简单的神经网络等等都可以,只要能具备预测某顶点与其他顶点间连接概率的能力即可。当固定判别器D,更新生成器G的参数时,式(1)中的Eνptrue(νc)[logD(ν,νc;θD)]部分可以看作常数,只需要关注另一部分,即EνG(νc;θG)[log(1D(ν,νc;θD))])。此时,生成器的梯度θGV(G,D)为:

θGV(G,D)=θGc=1VEνG(νc)[log(1D(ν,νc))]=c=1Vi=1NθGG(νi|νc)log(1D(νi,νc))=c=1Vi=1NG(νi|νc)θGlogG(νi|νc)log(1D(νi,νc))=c=1VEνG(|νc)[θGlogG(ν|νc)log(1D(ν,νc))]            (4)

在式(4)的推导中,c=1VEνG(νc)[log(1D(ν,νc))]是对所有顶点νc的期望求和,其中每个期望是从生成器G生成的顶点ν对应的log(1D(ν,νc))的平均值。这表明我们在考虑所有生成器产生的假顶点与真实顶点νc之间关系的平均对数概率。期望的实现可以转化成对每个生成样本的求和,即i=1NθGG(νi|νc)log(1D(νi,νc)),这里对每个生成样本Vi求和,计算生成概率相对于θG的梯度与log(1D(νi,νc))的乘积。又因为概率的梯度可以写作概率乘以似然比例的梯度,因此可以使用G(νi|νc)θGlogG(νi|νc)log(1D(νi,νc))代替θGG(νi|νc)log(1D(νi,νc))。最后,我们用期望值来代替求和,因为期望值是随机变量的平均值,这里的随机变量是ν的生成概率,得到推导后生成器的梯度计算公式:EνG(νc)[θGlogG(ν|νc)log(1D(ν,νc))]。整体上,这个梯度表达式说明:为了更新生成器G的参数,我们需要考虑由G生成的每个顶点v对损失函数的贡献,并根据该贡献来调整θG。梯度的方向指示了如何更新参数θG以减少生成的顶点v被判别器D识别出的概率,从而使生成器产生更真实的样本。

值得注意的是,在给定顶点νc时,生成器G计算其他顶点与νc的连通性概率G(ν|νc)需要使用Softmax函数计算出νc外的所以节点。公式如下所示:

G(v|vc)=exp(gvgvc)vvcexp(gvgvc)        (5)

当图数据体量庞大时,上述公式的计算复杂度是非常高的。因此,在GraphGAN框架中,提出了一种称为Graph Softmax的新方法,它的核心思想有三点:

(1)规范化(Normalized):生成器应该产生一个有效的概率分布,这意味着对于一个给定的顶点νc,所有可能与之相连的顶点v的生成概率之和应该等于1,这也是Softmax函数的基本思想。

(2)图结构意识(Graph-structure-aware):生成器在确定顶点之间的连通性概率时,应该利用图的结构信息。具体来说,如果两个顶点之间的最短路径距离增加,它们被认为是相连的概率应该下降。

(3)计算效率(Computationally Efficient):与Softmax不同,Graph Softmax在计算时只考虑图中少数的特定顶点,从而提高计算效率。

为了实现Graph Softmax,首先需要执行一个从顶点νc出发的广度优先搜索(Breadth-First Search,BFS),来构建一棵以νc为根的BFS树Tc。使用这棵树,我们可以计算每个邻居顶点νi相对于νc的相关性概率pc(νi),这是一个标准的Softmax函数,它在νc的直接邻居集Nc(ν)上运算。 接着,为了计算任意顶点ννc之间的连通性概率G(ν|νc;θG),我们使用Tc中从νcν的唯一路径。路径上每对相邻顶点的相关性概率的乘积定义了ννc之间的连通性概率。通过这种方式,图Softmax只关注νc通过BFS可达的部分图,而不是整个图,这极大地提高了计算效率。同时,它也自然地考虑到了图的结构信息,因为它直接在BFS树上进行操作,这棵树反映了图中顶点之间的真实距离关系。因此这种方法是规范化的,具有图结构意识,且计算效率高,具体示例如图2所示。

图2 BFS树

如图2所示,从顶点νc开始,我们有原始的图G和一个以νc为根节点的 BFS 树。第一步是从νc的直接邻居中选择下一个顶点。在这个例子中,顶点νr1被选中,与它相关联的概率是0.7。接下来,从νr1的邻居中选择下一个顶点。顶点νr2被选中,与它相关联的概率是 0.3。然后,从νr2的邻居中选择下一个顶点。可能的选择只有两种,与它们相关联的概率是 0.6和0.4,假设此时以0.6的概率选择了νr1,那么νr2的路径采样过程完成,否则路径采样继续延续。最后一步是更新采样路径上所有顶点的概率。计算路径νcνr1νr2上所有顶点的相关性概率的乘积得到0.7×0.3×0.6=0.126 ,这个数字代表从νcνr2的连通性概率。

发表回复

您的电子邮箱地址不会被公开。 必填项已用 * 标注

作者

arwin.yu.98@gmail.com

相关文章

zh-CN Chinese (Simplified)