最近火到出圈的一篇数学论文,到底说了什么?为什么能掀起波澜?
moboyou 2025-05-18 14:31 40 浏览
四月,arXiv上出现了一篇题为《KAN: Kolmogorov-Arnold Networks》的论文。该论文获得约5000个赞,对于一篇学术论文来说,可谓是相当火爆。随附的GitHub库已有7600多个星标,且数字还在持续增长。
Kolmogorov-Arnold 网络(KAN)是一种全新的神经网络构建块。它比多层感知器(MLP)更具表达力、更不易过拟合且更易于解释。多层感知器在深度学习模型中无处不在。例如,我们知道它们被用于GPT-2、3以及(可能的)4等模型的Transformer模块之间。对MLP的改进将对机器学习世界产生广泛的影响。
MLP
MLP实际上是一种非常古老的架构,可以追溯到50年代。其设计初衷是模仿大脑结构;由许多互联的神经元组成,这些神经元将信息向前传递,因此得名前馈网络(feed-forward network)。
MLP通常通过类似上图的示意图来展示。对于外行来说,这很有用,但在我看来,它并没有传达出真正正在发生的事情的深刻理解。用数学来表示它要容易得多。
假设有一些输入x和一些输出y。一个两层的MLP将如下所示:
其中W是可学习权重的矩阵,b是偏差向量。函数f是一个非线性函数。看到这些方程,很明显,一个MLP是一系列带有非线性间隔的线性回归模型。这是一个非常基本的设置。
尽管基本,但它表达力极强。有数学保证,MLP是通用逼近器,即:它们可以逼近任何函数,类似于所有函数都可以用泰勒级数来表示。
为了训练模型的权重,我们使用了反向传播(backpropagation),这要归功于自动微分(autodiff)。我不会在这里深入讨论,但重要的是要注意自动微分可以对任何可微函数起作用,这在后面会很重要。
MLP的问题
MLP在广泛的用例中被使用,但存在一些严重的缺点。
- 因为它们作为模型极其灵活,可以很好地适应任何数据。结果,它们很可能过拟合。
- 模型中往往包含大量的权重,解释这些权重以从数据中得出结论变得非常困难。我们常说深度学习模型是“黑盒”。
- 拥有大量的权重还意味着它们的训练可能会很长,GPT-3的大部分参数都在MLP层中。
Kolmogorov-Arnold 网络
Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 表示定理的目标类似于支撑MLP的通用逼近定理,但前提不同。它本质上说,任何多变量函数都可以用1维非线性函数的加法来表示。例如:向量v=(x1, x2)的除法运算可以用对数和指数代替:
为什么这会有用呢?这究竟实现了什么?
这为我们提供了一种不同但简单的范式来开始构建神经网络架构。作者声称,这种架构比使用多层感知器(MLP)更易于解释、更高效地使用参数,并且具有更好的泛化能力。在MLP中,非线性函数是固定的,在训练过程中从未改变。而在KAN中,不再有权重矩阵或偏差,只有适应数据的一维非线性函数。然后将这些非线性函数相加。我们可以堆叠越来越多的层来创建更复杂的函数。
B样条(B-splines)
在KAN中表示非线性的方式中有一点重要的是需要注意的。与MLP中明确定义的非线性函数(如ReLU()、Tanh()、silu()等)不同,KAN的作者使用样条。这些基本上是分段多项式。它们源自计算机图形领域,在该领域中,过度参数化并不是一个问题。
样条解决了在多个点之间平滑插值的问题。如果你熟悉机器学习理论,你会知道要在n个数据点之间完美插值,需要一个n-1阶的多项式。问题是高阶多项式可能变得非常曲折,看起来不平滑。
- 10个数据点被一个9阶多项式完美拟合
通过将分段多项式函数适应于数据点之间的部分,样条解决了这个问题。这里我们使用三次样条。
- 三次样条插值更好,但不能泛化
对于三次样条(样条的一种类型),为了确保平滑,需要在数据点(或结点)的位置对一阶和二阶导数设置约束。数据点两侧的曲线必须在数据点处具有匹配的一阶导数和二阶导数。
KAN使用的是B样条,另一种类型的样条,具有局部性(移动一个点不会影响曲线的整体形状)和匹配的二阶导数(也称为C2连续性)的特性。这样做的代价是实际上不会通过这些点(除了在极端情况下)。
- 3条B样条对应5个数据点。注意曲线是如何不通过数据点的。
在机器学习中,特别是在应用于物理学时,不经过每一个数据点是可以接受的,因为我们预计测量会有噪声。
这就是在KAN的计算图的每一个边缘发生的事情。一维数据用一组B样条进行拟合。
进入KAN
因此,现在我们在计算图的每个边缘都有一个分段的参数曲线。在每个节点,这些曲线被求和:我们之前看到,可以通过这种方式逼近任何函数。
为了训练这样的模型,我们可以使用标准的反向传播。在这种情况下,作者使用的是LBFGS(Limited-memory
Broyden-Fletcher-Goldfarb-Shanno),这是一种二阶优化方法(与Adam这种一阶方法相比)。另一个需要注意的细节是:在每个代表一维函数的边上,有一个B样条,但作者还增加了一个非线性函数:silu函数。
对此的解释不是很清楚,但很可能是由于梯度消失(这是我的猜测)。
我们来试用一下
我打算使用作者提供的代码,它运行得非常出色,有许多示例可以帮助我们更好地理解它。
他们使用由以下函数生成的合成数据:
定义模型
model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
这里定义了三个参数:
- 宽度,其定义方式与多层感知器(MLP)类似:一个列表,其中每个元素对应一个层,元素值是该层的宽度。在这种情况下,有三层;输入维度为2,有5个隐藏维度,输出维度为1
- 网格与B样条相关,它描述了数据点之间的网格可以有多细致。增加这个参数可以创建更多曲折的函数。
- k是B样条的多项式阶数,一般来说,三次曲线是个不错的选择,因为三次曲线对样条有很好的属性。
- seed,随机种子:样条的权重用高斯噪声随机初始化(就像在常规MLP中一样)。
训练
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.0)
该库的API非常直观,我们可以看到我们正在使用LBFGS优化器,训练20步。接下来的两个参数与网络的正则化相关。
训练后的下一步是修剪模型,这会移除低于相关性阈值的边和节点,完成后建议重新训练一下。然后将每个样条边转换为符号函数(log、exp、sin等)。这可以手动或自动完成。库提供了一个极好的工具,借助model.plot()方法可以看到模型内部的情况。
# Code to fit symbolic functions to the fitted splines
if mode == "manual":
# manual mode
model.fix_symbolic(0, 0, 0, "sin")
model.fix_symbolic(0, 1, 0, "x^2")
model.fix_symbolic(1, 0, 0, "exp")
elif mode == "auto":
# automatic mode
lib = ["x", "x^2", "x^3", "x^4", "exp", "log", "sqrt", "sin", "abs"]
model.auto_symbolic(lib=lib)
一旦在每个边上设置了符号函数,就会进行最终的再训练,以确保每个边的仿射参数是合理的。
整个训练过程在下面的图表中总结。
- 使用KAN进行符号回归的示例。图片来自论文。
完整的训练代码如下所示:
# Define the model
model = KAN(width=[2, 5, 1], grid=5, k=3, seed=0)
# First training
model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.0)
# Prune edges that have low importance
model = model.prune()
# Retrain the pruned model with no regularisation
model.train(dataset, opt="LBFGS", steps=50)
# Find the symbolic functions
model.auto_symbolic(lib=["x", "x^2", "x^3", "x^4", "exp", "log", "sqrt", "sin", "abs"])
# Find the afine parameters of the fitted functions without regularisation
model.train(dataset, opt="LBFGS", steps=50)
# Display the resultant equation
model.symbolic_formula()[0][0] # Print the resultant symbolic function
一些思考
模型中有相当多的超参数可以调整。这些可以产生非常不同的结果。例如,在上面的示例中:将隐藏神经元的数量从5改为6意味着KAN找不到正确的函数。
在机器学习中,“超参数”(hyperparameters)是指那些在学习过程开始之前需要设置的参数。这些参数控制着训练过程的各个方面,但它们并不是通过训练数据自动学习得到的。超参数的设置对模型的性能和效率有着重要的影响。
- 由KAN[2,6,1]找到的结果函数
这种变化性是预期的,因为这种架构是全新的。花了几十年时间,人们才找到了调整MLP超参数(如学习率、批大小、初始化等)的最佳方式。
结论
MLP已经存在很长时间了,早该升级了。我们知道这种改变是可能的,大约6年前,LSTMs在序列建模中无处不在,后来被transformers作为标准的语言模型架构构建块所取代。如果MLP也能发生这种变化,那将是令人兴奋的。另一方面,这种架构仍然不稳定,而且运行效果并不是非常出色。时间将告诉我们,否能找到一种方法来绕过这种不稳定性并释放KAN的真正潜力,或者KAN是否会被遗忘,成为机器学习的一个小知识点。
我对这种新架构感到非常兴奋,但我也持怀疑态度。
相关推荐
- Linux集群自动化监控系统Zabbix集群搭建到实战
-
自动化监控系统Cacti特点:将监控到的数据,绘制成各种图形基于SNMP协议(网络管理协议)的监控软件,强大的绘图能力Nagios特点:状态检查和报警机制(例如:内存不足或CPU负载高时,及时的...
- 快速掌握Kafka系列《三》配置项总结
-
往期系列文章:1.快速掌握Kafka系列《一》基本概念入门2.快速掌握Kafka系列《二》常用操作命令汇总目录一、前言二、broker配置2.1三个基本配置2.2其它配置2.3...
- 8.mxGraph 命名空间与 Hello World 示例实践.md
-
2.2.2GeneralJavaScriptDevelopment常规JavaScript开发2.2.2.1JavaScriptObfuscation/JavaScript混淆[翻...
- 英特尔 i9-12900KS 最新爆料:基础功耗 150W,790 美元
-
IT之家2月14日消息,据爆料者@momomo_us的消息,现在已有海外经销商列出了i9-12900KS的商品信息。i9-12900KS的产品代码为BX8071512900KS,基...
- Spring Boot集成OAuth2:实现安全认证与授权的详细指南
-
SpringBoot集成OAuth2:实现安全认证与授权的详细指南引言在当今数字化时代,Web应用的安全认证和授权至关重要。OAuth2作为一种广泛应用的开放标准协议,为第三方应用提供了安全、便捷的...
- DNF人造神团本男气功加点攻略(dnf男气功用什么神话)
-
SP方面:加点从下往上点起,大技能全部点满,剩余sp在雷霆踏和念雷轰之间根据个人喜好二选一。加点代码:eJwNzTEKglAAx+Hf35D0pU8bImxpkSgHt47QFNRSi2cIkkJ...
- Python连接Mysql数据库的几种方式以及问题排查方法
-
一、使用pymysql连接Mysql数据库连接示例:conn=pymysql.connect(host=host,user=user,password=passwd,db=db,port=int(...
- 37【源码】数据可视化:基于 Echarts + Python 动态实时大屏
-
效果图展示1.动态效果演示2.静态切片效果图一、确定需求方案1.确定产品上线部署的屏幕LED分辨率本案例基于16:9屏宽比,F11全屏显示。2.部署方式浏览器打开播放,Chrome浏览器、360浏览...
- 36【源码】数据可视化:基于 Echarts + Python 动态实时大屏
-
效果图展示动态效果演示2.静态切片效果图一、确定需求方案1.确定产品上线部署的屏幕LED分辨率本案例于16:9屏宽比,F11全屏显示。2.部署方式浏览器打开播放,Chrome浏览器、360浏览器等。...
- Jsp Servlet Mysql实现的在线商城项目源码附带视频指导运行教程
-
今天给大家演示一款由jspservletMySQL实现的在线商城系统,系统项目源码在【猿来入此】获取!本系统实现了管理员管理用户、商品(商品分类)、订单、留言、新闻等功能,前台会员注册登录,查看商...
- MySQL大数据表处理策略,原来一直都用错了……
-
场景当我们业务数据库表中的数据越来越多,如果你也和我遇到了以下类似场景,那让我们一起来解决这个问题。数据的插入,查询时长较长后续业务需求的扩展,在表中新增字段,影响较大表中的数据并不是所有的都为有效数...
- 基于SpringBoot 的CMS系统,拿去开发企业官网真香(附源码)
-
前言推荐这个项目是因为使用手册部署手册非常完善,项目也有开发教程视频对小白非常贴心,接私活可以直接拿去二开非常舒服开源说明系统100%开源模块化开发模式,铭飞所开发的模块都发布到了maven中央库。可...
- 「Qt入门第22篇」 数据库(二)编译MySQL数据库驱动
-
导语在上一节的末尾我们已经看到,现在可用的数据库驱动只有两类3种,那么怎样使用其他的数据库呢?在Qt中,我们需要自己编译其他数据库驱动的源码,然后当做插件来使用。下面就以现在比较流行的MySQL数据库...
- 基于SpringBoot从0到1编写一个图书管理系统(附源码)
-
项目源码地址:https://muzidong.com/productDetail/8ff44c71db6b4b6aa30c71e646b1c557需求分析基于SSM+MySql+LayUI...
- Jsp+Ssm+Mysql实现的投票管理系统源码附带视频指导配置运行教程
-
今天给大家演示的是一款由jsp+ssm框架+mysql实现的投票管理系统,系统分为前端和后台管理模块,系统项目源码在【猿来入此】获取!前端用户可以登录注册、查看投票信息,登录后可以进行投票,也可以查看...
- 一周热门
- 最近发表
-
- Linux集群自动化监控系统Zabbix集群搭建到实战
- 快速掌握Kafka系列《三》配置项总结
- 8.mxGraph 命名空间与 Hello World 示例实践.md
- 英特尔 i9-12900KS 最新爆料:基础功耗 150W,790 美元
- Spring Boot集成OAuth2:实现安全认证与授权的详细指南
- DNF人造神团本男气功加点攻略(dnf男气功用什么神话)
- Python连接Mysql数据库的几种方式以及问题排查方法
- 37【源码】数据可视化:基于 Echarts + Python 动态实时大屏
- 36【源码】数据可视化:基于 Echarts + Python 动态实时大屏
- Jsp Servlet Mysql实现的在线商城项目源码附带视频指导运行教程
- 标签列表
-
- 外键约束 oracle (36)
- oracle的row number (32)
- 唯一索引 oracle (34)
- oracle in 表变量 (28)
- oracle导出dmp导出 (28)
- oracle两个表 (20)
- oracle 数据库 字符集 (20)
- oracle安装补丁 (19)
- matlab化简多项式 (20)
- 多线程的创建方式 (29)
- 多线程 python (30)
- java多线程并发处理 (32)
- 宏程序代码一览表 (35)
- c++需要学多久 (25)
- css class选择器用法 (25)
- css样式引入 (30)
- css教程文字移动 (33)
- php简单源码 (36)
- php个人中心源码 (25)
- php小说爬取源码 (23)
- 云电脑app源码 (22)
- html画折线图 (24)
- docker好玩的应用 (28)
- linux有没有pe工具 (34)
- mysql数据库源码 (21)