SchNet

SchNet是一种用于分子和材料的量子化学计算的深度学习模型,最初由德国的研究人员在2017年开发。它是一种基于图的神经网络,专门设计来处理三维分子结构中的原子间相互作用。SchNet的独特之处在于它使用了连续卷积操作,可以捕捉到原子之间复杂的距离依赖关系,从而能够预测分子的能量、力以及其他重要的化学性质。这种模型通过学习原子环境的表示来实现高精度的预测,而不依赖于任何人工定义的特征或先验的物理知识。SchNet已经在多种化学和材料科学任务中展现出了卓越的性能,成为了分子模拟领域的一种重要工具。
1.1 SchNet模型结构
SchNet的架构如图1所示,这是一个针对分子和材料建模而设计的深度学习网络。图中分为三个主要部分:架构总览、交互模块和滤波器生成网络。

1.架构总览
如图1左图所示,输入表示为\((Z_1,\cdots,Z_n)\)和\((r_1,\cdots,r_n)\),其中\(\mathbf{z}_{i}\)是原子类型(例如氢、氧、碳等)被编码成的一个one-hot编码的向量。这是基本的化学信息,对于识别不同原子的性质非常重要。\(r_{i}\)是每个原子在分子中的三维坐标,这些位置信息对于模型来说至关重要,因为它们直接决定了原子之间的距离和相互作用,这些都是决定分子性质的关键因素。
输入数据首先经过嵌入层(embedding),这是一个常见的深度学习技术,一般使用神经网络层将每个原子的one-hot节点特征向量转换成一个固定大小的连续特征向量。在这个例子中,每个原子经过神经网络层映射成的新特征向量的维度是64,即每个原子被映射为一个包含64个元素的向量。这个向量编码了原子的某些属性,可以在网络训练过程中进行学习和优化。接下来是多个交互层,每个层都旨在捕捉和更新原子之间的相互作用信息,得到更有效的原子特征向量表示。之后,一个原子级别的操作(atom-wise)对每个原子的特征进行转换。原子级别操作将对每个原子的特征向量独立地进行学习变换。这意味着网络在进行这些操作时,不会考虑原子之间的相互作用或原子的邻居,而是将注意力集中在单个原子的特征向量上。在实践中,这通常通过应用一个全连接层来实现。
Shifted softplus是一种激活函数,用于引入非线性,该激活函数定义为:\(\text{Shifted softplus}(x)=\ln(0.5e^x+0.5)\)
其函数图像如图2所示。

它是一个平滑的非线性函数,可以将非线性引入神经网络。Shifted Softplus类似于激活函数ReLU,但对于负输入值具有非零梯度。即使输入具有负值,也可以帮助网络继续学习,且在x=0的位置过度的更加平滑。值得注意的是,Shifted Softplus 在处理几何数据时往往比ReLU更受欢迎,主要原因有下面几个:
(1)平滑的梯度:Shifted Softplus 函数在整个定义域中都有一个平滑的导数。这意味着梯度不会突然变为零(如ReLU在负输入时的表现),从而避免了神经网络在训练过程中的梯度消失问题。在几何结构的数据中,原子或节点的特征可能涉及复杂的相互作用和细微的变化,平滑的梯度允许模型更细致地调整其参数以捕获这些特征。
(2)非饱和性:Shifted Softplus是一个非饱和激活函数,这意味着它的输出不会在较大的正输入值时饱和到一个固定值(与Sigmoid或Tanh不同)。非饱和性有助于减轻训练过程中的梯度饱和问题,提高网络在深层结构中的信息流动。
(3)处理负输入:与ReLU相比,Shifted Softplus 即使在负输入下也能提供小的但非零的梯度。这种特性使得网络能够从所有数据中学习,即使是那些在几何空间中可能表示不同类型相互作用(如排斥力)的负输入值。
(4)连续性:几何深度学习领域通常需要处理连续数据,如原子位置或结构动力学。Shifted Softplus 的平滑性和连续性对于这些应用来说是有利的,因为它允许模型生成连续的输出和梯度,这对于学习数据中的复杂模式至关重要。
在模型架构总览的最后,使用神经网络层(atom-wise level)将特征进行缩减,然后通过求和池化(sum pooling)计算出一个标量,来代表整个分子的能量估计\(\hat{E}\)。
2.交互模块
交互模块如图1中间部分所示,其目的是更新原子的特征,其组成结构通常使用基于连续滤波卷积的网络,这允许模型学习原子之间在连续空间中的复杂相互作用。交互模块的输入是每个原子的特征表示,具体来说是通过嵌入层得到初始的特征向量,也可以是自定义的原子特征表示,例如其类型、电荷、电子排布等编码。交互模块的核心处理操作是连续滤波卷积(cfconv),其将根据原子之间的相对位置动态生成卷积核的大小和权重,这样的设计非常适合处理分子动力学和化学结构数据。
连续性卷积计算之后,再进行原子级操作(atom-wise)对每个原子的特征向量进行独立的非线性变换,具体实现上将应用一个或多个全连接层,并通过激活函数增加模型的非线性。这使得模型能够更新每个原子的状态,反映其在具体问题中的变化。
值得注意的是,在交互模块的最后,还存在残差连接结构。通过元素级别的加法,来自不同层的特征可以被整合,这种方式保留了原始输入特征的信息。残差连接还可以帮助防止深层网络的训练问题,如梯度消失或爆炸。
3.连续滤波卷积网络
在深度学习领域,传统的卷积层是为处理规则化的数据而设计的。这些数据类型通常包括图像像素、视频帧或数字音频数据,它们在本质上是结构化的,因为它们的数据点被整齐地排列在一个固定的网格上。例如,在图像处理中,每个像素都有固定的位置,卷积核可以在这些位置上滑动,以检测边缘、纹理等特征。在处理像视频帧这样的连续数据流时,卷积层可以捕获时间上的变化并学习动态模式。
然而,当数据本身分布在非规则的空间中时,传统的卷积网络就遇到了难题。例如,原子、分子等可能不会均匀地分布在一个固定网格上,而是在连续空间中以各种不规则的间隔排列,如图3所示。

图3中展示了离散卷积(传统的2D卷积)和连续卷积如何针对原子位置变化引起的能量变换预测问题的结果对比。图3中分为左右两部分,每部分包含了上方的原子模型示意图和下方的能量-原子位置关系图。图3(a)是“离散卷积核”的示例,上方的两个原子模型示意图显示了原子位置的微小变化,但下方的能量图显示了随原子位置变化的能量预测是不连续的,这种不连续性可能会导致模拟的不准确性,因为现实中原子位置的改变通常对能量的影响是平滑连续的。图3 (b)是“连续卷积核”的示例,上方的原子模型示意图与左侧相似,显示了原子位置的变化。但是,与离散卷积核不同的是,下方的能量图显示了一个平滑的曲线,表明能量预测随原子位置的微小变化而平滑变化。这样的连续变化更能准确地模拟真实世界的物理现象,因为它可以精确地反映出原子位置变化对能量的实际影响。
同样的情况也出现在处理天文观测数据、气候模型以及金融市场的波动等数据时。这些数据类型要求卷积层能够理解和处理在非规则空间中分布的信号。传统方法中的一种解决方案是通过重采样将输入数据重新映射到规则的网格上,然后应用插值方案。但这可能要求有大量的网格点,并且可能在本质上不够精确。
为了解决这个问题,研究者们提出了将连续卷积核应用于卷积层的概念,这是传统的卷积思想被引入到几何深度学习领域来处理几何数据的重要一步。与传统的离散卷积核不同,连续卷积核能够处理输入数据点在空间中任意排列的情况,可以直接在原始的、不规则的空间分布上操作。这是通过使用一种连续卷积核生成函数(Filter-generating Function)实现的,该函数能够根据数据点在空间中的具体位置动态地产生相应的卷积核值。
通过这种方式,连续滤波卷积不仅提高了处理非规则空间数据的能力,而且也为图结构数据、3D 形状以及其他超出欧几里得空间的数据结构提供了强大的新工具。这种方法对于诸如新材料设计、复杂系统建模和精准药物开发等领域具有革命性的意义,因为它们常常需要处理在不规则空间中分布的复杂数据。下面,对比离散卷积(2D卷积),给出连续卷积的数学表达。
在二维空间中,对于一个图像I和一个3×3的卷积核K,在位置(x,y)处的离散卷积C可以定义为:
\(C(x,y)=\sum_{i=-1}^1\sum_{j=-1}^1I(x+i,y+j)\cdot K(i,j)\)
其中,每个I(x+i,y+j)是图像上的一个像素,K(i,j)是卷积核上对应的权重。卷积核覆盖的像素是局部邻域,而卷积操作本身是一个局部的线性操作,因为它仅考虑图像中(x,y)点周围的一个小区域(即局部邻域)。
连续滤波卷积是一种更为泛化的操作,它能够处理非规则分布的数据点,其操作可以用以下公式表示:
\(x_i'=\sum_{j\in N(i)}x_j\odot h_\Theta(\exp(-\gamma(e_{j,i}-\mu)))\)
上式的各个符号和操作解释如下:
(1)\(x_i^{\prime}\):表示更新后的特征向量,是节点i在应用卷积操作后的特征。
(2)\(\sum_{j\in N(i)}\):表示对节点i的邻居节点j进行求和操作,N(i)表示节点i的邻居节点集合。
(3)\(x_{j}\):表示邻居节点j的特征向量。
(4)\(\odot\):表示Hadamard乘积(元素乘积),即对两个向量对应元素相乘。
(5)\(h_{\Theta}\):表示连续卷积核生成函数,在具体实现时,可以选择一些固定的数学函数,也可以选择一些带可学习参数可学习的模型。
(6)\(\exp(-\gamma(e_{j,i}-\mu))\):表示一个高斯径向基函数,其中\(\text{r}\)是一个超参数,控制高斯函数的宽度;\(e_{j,i}\)表示原子i和原子j之间的距离;\(\mu \)是一个均值,表示高斯基函数的中心位置。
值得注意的是,在计算目标节点与其邻居相对距离\((e_{j,i}-\mu)\)的这个操作中,还可以通过增加一个截断变量C来限制连续卷积核只计算目标节点与它一定半径内的邻居之间的距离,而不是全部邻居的距离,如图4所示。

在图4中,目标节点为\(x_{i}\),其他节点全部为目标节点\(x_{i}\)的邻居节点,但是由于截断距离的限制,有两个截断距离外的节点并没有参加相对距离的计算。
最后,针对连续滤波卷积的关键:连续卷积核生成函数进行展开探讨。在具体实现时,可以选择一些常见的数学函数或结构,例如:
(1)余弦函数:
\(C=0.5\times(\cos(\textbf{edge_weight}\times\pi/\text{cutoff })+1.0)\)
其中,edge_weight是边的特征向量,一般是目标节点与其邻居的距离信息经过高斯基函数映射以后得到的特征向量,cutoff是截断变量。为了方便理解,令edge_weight的值域是[0,5],截断变量分别为1、3、5。对余弦结果C的可视化如图5所示。

可以发现,在截断变量的范围内,C的变化值域始终为[0,1]。具体来说,当目标节点的邻居离目标节点无限近时,C接近1;在目标节点的邻居离目标节点较远,接近截断变量限制的边界时,C接近0。实际上,计算结果C可以看作一种权重,通过与邻居特征相乘的方式根据相对距离来控制邻居节点对于当前目标节点的影响。
(2)多项式或径向基函数(RBF):
\(W(r_i-r_j)=\sum_k\alpha_k\phi_k(\parallel r_i-r_j\parallel)\)
其中,\(\phi_{k}\)可以是高斯核、多项式核或其他任何合适的径向基函数,\(\alpha_{k}\)是需要学习的参数。
(3)神经网络:
由于W是从\(R^{D}\)到\(R^{F}\)的映射,可以使用一个小型的前馈神经网络来实现:
\(W(r_i-r_j)=\mathrm{NN}(\parallel r_i-r_j\parallel)\)
其中NN 表示神经网络,它接受位置差的某种表示(例如,向量的模长或者经过某种处理的向量)作为输入,并输出卷积核的权重。
(4)注意力机制:
类似于Transformer模型中的注意力机制,W可以设计为考虑输入特征之间的关系:
\(W(r_i,r_j,x_i,x_j)=\mathrm{softmax}\Bigg(\frac{Q(r_i,x_i)K(r_j,x_j)^T}{\sqrt{d_k}}\Bigg)V(r_j,x_j)\)
其中Q,K,V是查询(Query)、键(Key)、值(Value)的函数,通常是由位置和特征的联合表示经过线性变换得到的,\(d_{k}\)是缩放因子。
(5)距离加权:
可以简单地将权重设置为距离的函数,例如:
\(W(r_i-r_j)=\frac{1}{\|r_i-r_j\|^p+\epsilon}\)
其中p是一个超参数,\(\epsilon \)是一个小的常数,以避免除以零。
离散卷积和连续卷积的主要区别在于它们处理的数据结构和操作方式。在数据结构上,离散卷积处理的是排列在固定网格上的点,例如图像像素。而连续卷积处理的是不依赖于固定网格的数据。它可以应用于任意分布的点,例如不规则排列的原子或其他任意空间分布的数据点。
在操作方式上,离散卷积的局部性是由卷积核的尺寸强制定义的,并在图像数据上局部地应用卷积核,通常这个卷积核是一个小型的矩阵,比如3×3。这个大小的选择意味着每个像素点的新值是由其自身和周围8个像素点(在3×3邻域中)通过卷积核加权求和得到的。与离散卷积不同,连续卷积涉及一个动态生成的卷积核,这个核的值是根据截断变量C或其目标节点的邻居数量计算的。由于图数据的不规则性,即使截断变量C固定,每个图局部范围下的邻居数量大概率是不同的。
1.2 SchNet与GNN
最后,从GNN模型框架的角度来讲解一下SchNet。首先,回顾一下GNN 的核心计算流程,如图6所示。

其中的关键点有三方面,其一是对输入数据的向量化(Embedding);其二是对节点、边或全局特征进行消息传递与更新,最后是对模型顶层的设计来适应具体的任务,将SchNet对应到GNN的框架中如图7所示。

SchNet中对输入数据的向量化相对直接而简单。它通过将原子类型映射成高维空间中的向量来实现。这一过程通常涉及一个可学习的神经网络层,该层将每个原子类型的one-hot编码转换为一个固定长度的向量表示。这些表示可以通过模型训练捕捉原子之间潜在的相似性和差异性。
至于信息的传递,主要依靠连续滤波卷积来处理原子间的相互作用。这种方法直接利用原子位置(通过它们的三维坐标表示)计算原子对之间的距离,并基于这些距离通过径向模型计算范围,动态生成卷积核的权重值。这种方式允许模型更灵活地处理各种大小和形状的分子结构。
在顶层设计中,会基于更新后的节点特征向量对目标性质做预测,在SchNet中也是通过一些神经网络层和池化操作实现的。
作者
arwin.yu.98@gmail.com