EGNN
EGNN(Equivariant Graph Neural Network)是一种经典的具有保持等变性(Equivariance)的图神经网络。这意味着如果输入数据通过某种变换(如旋转、翻转等),网络的输出也会以相同的方式变换,从而保证输出与输入在几何上是一致的。这一特性使得EGNN在处理几何数据和物理系统模拟中表现出优异的能力。
1.EGNN等变性的引入方式
在EGNN这篇工作中,首先对GNNs的框架进行了公式化的定义。具体来说其框架分成三部分:边特征的定义(Edge)、边特征的聚合(Agg)和节点特征的更新(Node)。
对于朴素的GNN模型而言,边特征(Edge)的公式可以定义为\(m_{ij} = \phi_{e}\left( h_{i}^{l}, h_{j}^{l}, a_{ij} \right)\),网络会为每对相连的节点( \(\mathrm{i}\) 和 \(\mathrm{j}\) )计算一个消息向量\(m_{ij}\) 。这个计算是通过一个特定的函数(一般是神经网络层的映射)\(\phi_{e}\)完成的,该函数取决于两个节点的特征向量\(\boldsymbol{h}_i^l\) 和\(\boldsymbol{h}_j^l\) ,以及它们之间的边特征\(a_{ij}\) 。这里的\(\mathrm{l}\) 指的是当前层,此函数是设计来捕捉节点间的直接关系和交互的,可以理解为一个编码节点间关系的方式。
边特征的聚合(Agg)公式可以定义为\(m_i=\sum_{j\in\mathcal{N}(i)}m_{ij}\)。在聚合步骤中,对于每个节点\(\mathrm{i}\) ,它从其所有邻居节点\(\mathrm{j}\)(表示为\(\mathcal{N}(i)\))收集消息向量\(m_{ij}\),并将这些消息向量求和,得到\(m_{i}\)。这一步的目的是将从各个邻居那里收集到的信息汇总到单个节点上,从而为下一步的节点状态更新提供必要的信息输入。
节点特征的更新(Node)公式可以定义为\(h_i^{l+1}=\phi_h\left(h_i^l,m_i\right)\)。在节点更新步骤中,节点\(\mathrm{i}\)的新的隐藏状态\(\boldsymbol{h}_i^{l+1}\)是通过另一个函数\(\phi_{h}\)(一般也是神经网络)来计算的,该函数取决于节点\(\mathrm{i}\)当前的隐藏状态\(\boldsymbol{h}_i^l\) 和聚合后的消息\(\mathrm{i}\)。这一步骤是模型对每个节点进行状态更新,它使得节点能够根据自己和周围邻居的信息进行自我更新。
在上述定义的图神经网络(GNN)的标准框架中,可以发现其设计并没有显式地保证模型输出对节点输入的几何变换具有不变性或等变性。具体来说,在Edge步骤中,\(\phi_{e}\)通常是个普通的神经网络层,它处理输入特征但不会直接考虑这些特征如何在空间变换(如旋转或平移)下变化。也就是说,如果图的整体结构经历了某种几何变换,\(m_{ij}\) 的计算方式并没有直接机制来保证相同的变换反映在输出消息中。同理,在节点特征更新阶段也是类似的情况。最后,分析其聚合步骤,一般是简单的累加所有邻居的信息。虽然这个步骤在数学上是置换不变的(即节点的顺序不会影响求和结果),但这并不意味着它对几何变换是不变的或等变的。聚合操作忽略了节点间的空间关系或其在图中的绝对位置,因此对于图的空间变换并不敏感。
等变性和不变性通常需要通过额外的结构设计来实现,即算法的归纳偏置。以之前讲解过的SchNet模型举例,当使用上述EGNN对GNNs算法框架的定义来描述SchNet时,与传统GNN最大的不同体现在边特征(Edge)的定义步骤中,公式可表示成:
\(\boldsymbol{m}_{ij}=\phi_{\mathrm{cf}}\left(\left\|\boldsymbol{r}_{ij}^l\right\|\right)\phi_{\mathrm{s}}\left(\boldsymbol{h}_j^l\right)\) (1)
可以看出,此时边特征的定义不仅仅是简单地依赖于节点的特征向量,而是通过两个分开的函数\(\phi_{\mathrm{cf}}\)和\(\phi_{\mathrm{s}}\)来定义。具体来说: \(\phi_{\mathrm{cf}}\left( \left\lVert r_{ij}^l \right\rVert \right)\)是一个连续滤波卷积 (Continuous Filter Convolution),它仅基于节点\(\mathrm{i}\) 和\(\mathrm{j}\)之间的距离\(\left|r_{ij}^l\right|\)来计算得到卷积操作的权重,即卷积核的数值。这些权重是专门设计来捕捉原子间距离对相互作用强度的影响,这一点在处理分子结构时极其重要,它允许模型对节点之间的物理距离进行敏感响应。\(\phi_\mathrm{s}\left(\boldsymbol{h}_j^l\right)\)则处理来自节点\(\mathrm{j}\)(目标节点\(\mathrm{i}\)的邻居节点)的特征向量,通过神经网络层映射得到的新特征与距离权重 \(\phi_{\mathrm{cf}}\left( \left\lVert r_{ij}^l \right\rVert \right)\)相乘得到更新后的边特征\(m_{ij}\) 。
通过将原子间的距离作为边的一个关键特征,SchNet模型能够有效应对化学物质分析中的一个主要挑战:准确模拟原子间的相互作用力。这些相互作用力通常强烈依赖于原子间的相对位置。更重要的是,由于这些距离信息本质上是标量值,它们在平移、旋转或反射等几何变换下保持不变。因此,为SchNet模型提供了旋转不变性的归纳偏置,使其在处理几何变换时具有稳定的表现。
类似的,EGNN也可以使用上述GNNs框架进行定义,如表1所示。
根据表1,具体解释一下EGNN的公式表示。首先,公式\(\boldsymbol{m}_{ij} = \phi_e\left( \boldsymbol{h}_i^{l}, \boldsymbol{h}_j^{l}, \left\lVert \boldsymbol{r}_{ij}^{l} \right\rVert^2, a_{ij} \right)\)描述了如何计算节点\(\mathrm{i}\)和节点\(\mathrm{j}\)之间的边的特征\(m_{ij}\)。其中, \(\phi_{e}\)是一个可学习的网络层,\(\boldsymbol{h}_i^l\)和\(\boldsymbol{h}_j^l\)分别是节点\(\mathrm{i}\)和节点\(\mathrm{j}\)在第\(\mathrm{l}\) 层的特征表示, \(\left\lVert \boldsymbol{r}_{ij}^l \right\rVert^2\)是节点\(\mathrm{i}\)和节点\(\mathrm{j}\)之间的距离的平方(通常是欧氏距离),\(a_{ij}\) 是边的特征或类型。这些信息会拼接起来一起作为网络层\(\phi_{e}\)的输入。相较于标准的GNN,EGNN多考虑了距离信息 \(\left\lVert \boldsymbol{r}_{ij}^l \right\rVert^2\);相较于SchNet,其距离信息的考虑非常简单与直接,并没有经过径向函数的处理。值得注意的是,EGNN在边特征定义的阶段比标准GNN和SchNet都多了一种新的输入信息:向量信息\(\hat{m}_{ij}\)。具体来说,就是简单的将节点\(\mathrm{i}\)到节点\(\mathrm{j}\)的向量\(r_{ij}^{l}\)直接与经过神经网络层\(\phi_{x}\)处理后的边特征\(m_{ij}\)进行相乘操作,这种将向量信息与边特征结合的方法能够捕捉节点之间的定向关系,增强了对图中边的方向性和空间结构的理解,从而在各种需要精确空间理解的应用中,提供了比传统GNN和SchNet更为强大的性能。
信息聚合阶段比较简单,公式\(m_i=\sum_{j\in N(i)}m_{ij}\)描述了如何将节点的所有邻居节点发来的消息聚合在一起,得到\(m_{i}\)。类似地,公式\(\hat{\boldsymbol{m}}_i=C\sum_{j\neq i}\hat{\boldsymbol{m}}_{ij}\)是将所有方向化的消息聚合在一起,得到修正的聚合消息\(\hat{m}_i\),其中\(\mathrm{C}\)是一个归一化常数,用来调整聚合的范围。
最后,在节点更新阶段。公式\(h_i^{l+1}=\phi_h\left(h_i^l,m_i\right)\)是节点特征的更新规则,其中\(\phi_{h}\)是一个可学习的神经网络层,用当前层的节点特征\(h_i^{l}\)和聚合后的边特征\(m_{i}\)来更新下一层的节点特征\(h_i^{l+1}\)。而公式\(x_i^{l+1}=x_i^l+\hat{m}_i\)则是EGNN中独特的步骤,也是整个工作的精髓。因为它考虑了节点的空间位置,使得网络能够保持等变性。具体来说,\(x_{i}^{l}\)表示在第\(\mathrm{l}\)层时,节点\(\mathrm{i}\)的位置坐标。\(\hat{m}_i\)是从与节点\(\mathrm{i}\)相连的所有邻居节点\(\mathrm{j}\)方向化的聚合信息,表示节点之间的相对位置信息。 \(x_i^{l+1}\)表示在第\(\text{l+1}\)层时,节点\(\mathrm{j}\)的新位置坐标,更新节点\(\mathrm{j}\)坐标的意义在于,在某些情况下(比如分子动力学模拟),节点的物理位置或坐标是随着时间或网络层次的推进而不断变化的。这个公式通过添加修正消息(可以看作是空间变化的量)到当前的位置上,从而更新每个节点的位置。这样,网络不仅能够更新节点的特征表示(如电荷、能量状态等),也能够更新节点的空间位置,从而使得模型能够考虑到结构的动态变化。EGNN通过这样的位置更新机制,保证了当整个图结构在空间中进行变换(如旋转、反射)时,节点的新位置( \(x_i^{l+1}\))将以一致的方式反映这种变换,这也是EGNN实现等变性的重要方式。
2.EGNN等变性的证明
等变性是指,当输入数据经过一个变换(比如旋转、缩放、平移)时,模型的输出也会以相同的变换改变。在图神经网络中,特别是在处理空间数据时,等变性是一个关键属性,因为它允许网络正确地处理空间关系,而不是仅仅依赖于节点的特征,用公式2表示等变性即为:
\(Qx^{l+1}+g,h^{l+1}=\mathrm{EGCL}\left(Qx^l+g,h^l\right)\) (2)
其中,Q通常是一个旋转矩阵变换矩阵, g是一个全局的平移向量,它们作用在位置坐标\(x^{l}\)上可以模拟输入数据的旋转,缩放和平移变换。\(h^{l}\)是节点的特征,当节点位置发生变换时,\(h^{l}\)也需要以一种保持等变性的方式进行更新。 EGCL表示EGNN的网络层,它接收变换后的位置和当前的节点特征,并输出更新后的位置和特征。如果EGCL被设计为等变的,那么意味着如果输入发生了旋转,EGCL的输出也会反映出相应的旋转。
下面,开始证明上面的EGNN式2具备等变性,EGNN的公式如下:
\(m_{ij}=\phi_e\left(h_i^l,h_j^l,\left\|x_i^l-x_j^l\right\|^2,a_{ij}\right)\) (3)
\(\boldsymbol{x}_i^{l+1}=\boldsymbol{x}_i^l+C\sum_{j\neq i}\left(\boldsymbol{x}_i^l-\boldsymbol{x}_j^l\right)\phi_x\left(\boldsymbol{m}_{ij}\right)\) (4)
\(m_i=\sum_{j\neq i}m_{ij}\) (5)
\(h_i^{l+1}=\phi_h\left(h_i^l,m_i\right)\) (6)
对于式3,不管是距离信息\(\left|x_i^l-x_j^l\right|^2\),还是节点或边信息\(h_i^l,h_j^l,a_{ij}\)都不会因为空间的改变而改变。换句话说:距离、节点或边信息不具备方向性,不会因为输入数据发生旋转、缩放或平移而改变自身的值。因此式3在对节点信息进行迭代更新时,实际上是不变性质的,而不是等变性质。EGNN的等变性质主要体现在式4对位置信息的更新上,假设\(x^{i}\)发生旋转和平移的空间变换\(Qx_i^{l+1}+g\),根据式4可以得到:
\(\boldsymbol{Q}\boldsymbol{x}_i^{l+1}+\boldsymbol{g}=\boldsymbol{Q}\boldsymbol{x}_i^l+\boldsymbol{g}+C\sum_{j\neq i}\left(\boldsymbol{Q}\boldsymbol{x}_i^l+\boldsymbol{g}-\left[\boldsymbol{Q}\boldsymbol{x}_j^l+\boldsymbol{g}\right]\right)\phi_x\left(\boldsymbol{m}_{i,j}\right)\) (7)
对式4进行推导,可得到:
\(\begin{aligned}
& \boldsymbol{Q}\boldsymbol{x}_i^l+\boldsymbol{g}+C\sum_{j\neq i}\left(\boldsymbol{Q}\boldsymbol{x}_i^l+\boldsymbol{g}-\boldsymbol{Q}\boldsymbol{x}_j^l-\boldsymbol{g}\right)\phi_x\left(\boldsymbol{m}_{i,j}\right) \\
& =\boldsymbol{Q}\boldsymbol{x}_i^l+\boldsymbol{g}+\boldsymbol{Q}C\sum_{j\neq i}\left(\boldsymbol{x}_i^l-\boldsymbol{x}_j^l\right)\phi_x\left(\boldsymbol{m}_{i,j}\right) \\
& =\boldsymbol{Q}\left(\boldsymbol{x}_i^l+C\sum_{j\neq i}\left(\boldsymbol{x}_i^l-\boldsymbol{x}_j^l\right)\phi_x\left(\boldsymbol{m}_{i,j}\right)\right)+\boldsymbol{g} \\
& =\boldsymbol{Q}\boldsymbol{x}_i^{l+1}+\boldsymbol{g}
\end{aligned}\) (8)
从式8的推导可以发现,位置信息的更新是可以随着输入数据的旋转、缩放和平移的变化而变化的,即可以实现等变性,至此证明完毕。值得注意的是:由于等变的性质只存在于位置信息\(x^{i}\)中,因此如果整个网络在做下游任务时只用了EGNN更新后的节点特征,而没有利用位置特征\(x^{i}\),那么整个网络还是不变性的网络。简单说,EGNN实现的等变其实并不全面。
作者
arwin.yu.98@gmail.com