LoRA插件式微调!

LoRA是高效FineTune方法的一种。

LoRA论文:https://arxiv.org/abs/2106.09685

LoRA代码: https://github.com/microsoft/LoRA

LoRA原理

大模型都是过参数化的, 当用于特定任务时, 其实只有一小部分参数起主要作用。

也就是参数矩阵维度很高, 但可以用低维矩阵分解近似。

其实这个思想与矩阵特征向量, 主成分分析, 压缩感知等有异曲同工之妙。

具体做法:

在网络中增加一个旁路结构,旁路是A和B两个矩阵相乘。

A矩阵的维度是dxr,B 矩阵的维度是rxd,其中r<<d,一般r取1,2,4,8就够了。

那么这个旁路的参数量将远远小于原来网络的参数W。

LoRA训练时, 我们冻结原来网络的参数W,只训练旁路参数A和B。

由于A和B的参数量远远小于W,那么训练时需要的显存开销就大约等于推理时的开销。

其实采用这种旁路相加的方式, 与ResNet的跳连方式也有异曲同工之妙:

原网络的参数不变, 在旁路上做些微小改变,适应特定新任务。

这样就可以让网络基本保持原来的能力,在特定任务上更精进了一步。

LoRA微调并没有改变原有的预训练参数:

只是针对特定任务微调出了新的少量参数, 新的这些参数要与原有的预训练参数配合使用。

实际使用时, 都是把旁路的参数和原来的参数直接合并, 也就是参数相加, 这样就完全不会增加推理时间。

针对不同的任务, 都可以训练出自己的LoRA参数, 然后与原本的预训练参数结合, 做成插件式的应用。

LoRA原理很简单, 代码实现也不复杂:

简单地说,在模型实现上, 要在特定的模块上加一个旁路, 这个旁路就是两个矩阵相乘的形式。

HuggingFace Pert库把各种FineTune方式都做了集成, 更加简单和方便。

HuggingFace Pert库代码: https://github.com/huggingface/peft

官方博客:https://huggingface.co/blog/zh/peft