LoRA插件式微调!
LoRA插件式微调!
月伴飞鱼LoRA是高效FineTune方法的一种。
LoRA论文:https://arxiv.org/abs/2106.09685
LoRA代码: https://github.com/microsoft/LoRA
LoRA原理
大模型都是过参数化的, 当用于特定任务时, 其实只有一小部分参数起主要作用。
也就是参数矩阵维度很高, 但可以用低维矩阵分解近似。
其实这个思想与矩阵特征向量, 主成分分析, 压缩感知等有异曲同工之妙。
具体做法:
在网络中增加一个旁路结构,旁路是A和B两个矩阵相乘。
A矩阵的维度是dxr,B 矩阵的维度是rxd,其中r<<d,一般r取1,2,4,8就够了。
那么这个旁路的参数量将远远小于原来网络的参数W。
LoRA训练时, 我们冻结原来网络的参数W,只训练旁路参数A和B。
由于A和B的参数量远远小于W,那么训练时需要的显存开销就大约等于推理时的开销。
其实采用这种旁路相加的方式, 与ResNet的跳连方式也有异曲同工之妙:
原网络的参数不变, 在旁路上做些微小改变,适应特定新任务。
这样就可以让网络基本保持原来的能力,在特定任务上更精进了一步。
LoRA微调并没有改变原有的预训练参数:
只是针对特定任务微调出了新的少量参数, 新的这些参数要与原有的预训练参数配合使用。
实际使用时, 都是把旁路的参数和原来的参数直接合并, 也就是参数相加, 这样就完全不会增加推理时间。
针对不同的任务, 都可以训练出自己的LoRA参数, 然后与原本的预训练参数结合, 做成插件式的应用。
LoRA原理很简单, 代码实现也不复杂:
简单地说,在模型实现上, 要在特定的模块上加一个旁路, 这个旁路就是两个矩阵相乘的形式。
HuggingFace Pert库把各种FineTune方式都做了集成, 更加简单和方便。
HuggingFace Pert库代码: https://github.com/huggingface/peft