QLoRA: 16-bit์ ์ฑ๋ฅ์ ์ ์งํ๋ฉด์ 65B๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ Single 48GB GPU์ ์ฌ๋ ค finetuning ํ ์ ์๊ฒ ํ๋ค.
# Contribution
QLoRA ๋ฐฉ๋ฒ๋ก
1. 4-bit NormalFloat(NF4): ์ ๊ท๋ถํฌ๋ ๊ฐ์ค์น์ ๋ํด ์ ๋ณด ์ด๋ก ์ ์ผ๋ก ์ต์ ์ธ ์๋ก์ด ๋ฐ์ดํฐ ํ์
2. Double Quantization: ์์ํ ์์๋ฅผ ๋ค์ ์์ํํจ์ผ๋ก์จ ํ๊ท ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ ์ ๊ฐ
3. Paged Optimizers: ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ด ๊ธ์ฆํ๋ ์ํฉ์ ํจ๊ณผ์ ์ผ๋ก ์ ์ด
# Introduction
LLM์ Finetuning ํ๋๊ฑด ํน์ ๋๋ฉ์ธ์์์ ์ฑ๋ฅ์ ํฅ์์ํค๊ธฐ ์ํด ํ์ํ ๊ณผ์ ์ด๋ค.
๊ธฐ์กด์๋ 16-bit finetuning์ ํ๊ธฐ ์ํด์๋ LLaMA 65B ๊ธฐ์ค์ผ๋ก, 780GB ํฌ๊ธฐ์ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ํ์ํ์๋ค.
๊ทธ๋ฌ๋ QLoRA๋ runtime์ด๋ predictive performance์ ์ฑ๋ฅ ์ ํ ์์ด๋ ์ค์ง 48GB ํฌ๊ธฐ์ ๋ฉ๋ชจ๋ฆฌ๋ง ํ์ํ๋ค.
์์ contribution์์ ์๊ธฐํ๋ฏ์ด ํต์ฌ์ ์ธ๊ฐ์ง๋ค.
1. 4-bit NormalFloat
๊ธฐ์กด์ 4-bit quantizaiton ๋ฐฉ์์ธ 4-bit Integer, 4-bit Float quantization ๋ฐฉ์์ ๊ท ๋ฑํ๊ฒ ๊ฐ๊ฒฉ์ ๋๋์ง๋ง, ๋๋ถ๋ถ weight ๋ ์ ๊ท๋ถํฌ (normal distribution, N(0, σ²)) ๋ฅผ ๋ฐ๋ฅธ๋ค.
๋นํธ๋ง๋ค ํํํ๋ ๊ฐ์ ๋ฒ์๋ฅผ "๊ท ๋ฑ"ํ๊ฒ ๋๋๊ธฐ๋ณด๋ค๋, ์ค์ ๋ถํฌ์ ๋ง๊ฒ ๊ฐ์ด ๋ง์ด ๋ชฐ๋ฆฐ ๋ถ๋ถ์ ๋นํธ๋ฅผ ๋ ํ ๋นํ๋ ๋ฐฉ์์ด ๋ฐ๋ก 4-bit NormalFloat์ด๋ค. ์ ๊ท ๋ถํฌ์ ์ต์ ํ๋์๊ธฐ์ ์ ์ ๋นํธ๋ก๋ ๋์ ์ ํ๋๋ฅผ ์ ์งํ๋ค.
2. Double Quantization
์๋ ์์ํ(quantization)๋ ํ๋ผ๋ฏธํฐ ๊ฐ์ ์์ถํ๋ ๋ฐฉ๋ฒ์ธ๋ฐ,
QLoRA ๋ ํ ๋จ๊ณ ๋ ๋์๊ฐ์ ์์ํ์ ์ฌ์ฉ๋๋ ์์๋ค์กฐ์ฐจ ๋ค์ ์์ํํ๋ค.
์ด๋ฅผ ํตํด ํ๋ผ๋ฏธํฐ๋น ํ๊ท 0.37bit ์ ๋๋ฅผ ์ ๊ฐํ ์ ์์ผ๋ฉฐ, ๋๊ท๋ชจ ๋ชจ๋ธ (์: 65B ๋ชจ๋ธ) ์์๋ ์ฝ 3GB ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์๋ ์ ์๋ค.
3. Paged Optimizers
๊ธฐ์กด Optimizer๋ mini-batch ํฌ๊ธฐ๊ฐ ์ปค์ง๊ฑฐ๋ ์ํ์ค ๊ธธ์ด๊ฐ ๊ธธ์ด์ง๋ฉด ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ์ด ํญ๋ฐํ๋ค.
ํนํ gradient checkpointing ๊ฐ์ ๊ธฐ์ ์ ์ฌ์ฉํ ๋ ๋ฉ๋ชจ๋ฆฌ spike ํ์์ด ์๊ธด๋ค.
๋ฐ๋ผ์ NVIDIA์ Unified Memory ๊ธฐ๋ฅ์ ํ์ฉํด์ optimizer state ๋ฅผ CPU ๋ฉ๋ชจ๋ฆฌ๋ก ๋๊ฒจ ์ฌ์ฉํ๋ค.
์๋ก์ด ๋ฐ๊ฒฌ
1. ๋ฐ์ดํฐ ํ์ง์ด ์๋ณด๋ค ํจ์ฌ ์ค์ํ๋ค
9k sample dataset(OASST1)์ 450k sample dataset(FLAN v2, subsampled)๋ณด๋ค ์ฑ๋ด ์ฑ๋ฅ์์ ๋ ์ข์๋ค.
2. MMLU ์ฑ๋ฅ์ด ์ฑ๋ด ์ฑ๋ฅ์ ๋ณด์ฅํ์ง ์๋๋ค
MMLU (Massive Multitask Language Understanding) ๋ฒค์น๋งํฌ์์ ์ ์๊ฐ ๋๋ค๊ณ ํด์ Vicuna ๊ฐ์ ์ฑ๋ด ๋ฒค์น๋งํฌ์์ ๋ฌด์กฐ๊ฑด ์ํ๋ ๊ฒ์ ์๋์๋ค.
์ถ๊ฐ ๋ถ์
์ฌ๋ ํ๊ฐ์ + GPT-4 ๋ฅผ ํจ๊ป ์ฌ์ฉํด์ ํ ๋๋จผํธ ๋ฐฉ์์ผ๋ก ๋ชจ๋ธ๋ค์ ์๋ก ๋๊ฒฐ์์ผ ํ๊ฐํ๋ค. (์ฃผ์ด์ง ํ๋กฌํํธ์ ๋ํด ์ด๋ ๋ชจ๋ธ์ด ๋ ๋์ ๋ต๋ณ์ ์์ฑํ๋์ง)
ํ ๋๋จผํธ ๊ฒฐ๊ณผ๋ Elo ์ ์๋ก ์ง๊ณ๋์ด ์ฑ๋ด์ ์ฑ๋ฅ ์์๊ฐ ๋งค๊ฒจ์ง๋ค.
๊ฒฐ๊ณผ๋ ๋์ฒด๋ก GPT-4์ ์ฌ๋ ํ๊ฐ๊ฐ ์ผ์นํ์์ง๋ง, ๊ทธ๋ ์ง ์์ ๊ฒฝ์ฐ๋ ์์๋ค.
# Background
Block-wise k-bit Quantization
๋ฐ์ดํฐ๋ฅผ ๋ ์ ์ ๋นํธ ์๋ก ํํํ๋ ๋ฐฉ๋ฒ์ด๋ค.
๋ฐ์ดํฐ ์ ์ฒด์์ ๊ฐ์ฅ ํฐ ๊ฐ์ผ๋ก ์ ๊ทํ(normalize) ํ๊ณ ์ค์ผ์ผ๋ง(scale) ํด์ 8๋นํธ๋ก ํํ
๋จ์
๋ง์ฝ์ ์ ๋ ฅ ๋ฐ์ดํฐ ์ค์ ๋๋ฌด ํฐ ๊ฐ (outlier, ์ด์์น) ์ด ์์ผ๋ฉด,
์ ์ฒด ๋ฐ์ดํฐ๋ฅผ ์ต๋๊ฐ์ ๋ง์ถฐ ์ ๊ทํํ๊ธฐ ๋๋ฌธ์:
- ๋๋ถ๋ถ์ "ํ๋ฒํ ๊ฐ๋ค" ์ ์์ํ ๋ฒ์์ ์์ฃผ ์ข์ ๋ถ๋ถ์ ๋ชฐ๋ฆฌ๊ฒ ๋๋ค.
- ๋ฐ๋ฉด์ ํฐ ๊ฐ (outlier) ๋๋ฌธ์ ์์ํ ๋นํธ (quantization bins) ๊ฐ ์ ๋๋ก ํ์ฉ๋์ง ์๊ฒ ๋๋ค.
ex.
$X^{FP32} = [2.0,-1.0,0.0,8.0]$
$absmax(X^{FP32}) = max(|2.0|,|-1.0|,|0.0|,|8.0|) = 8.0$
ํด๊ฒฐ์ฑ
๋ฐ์ดํฐ๋ฅผ ๋ธ๋ก(block) ์ผ๋ก ๋๋๊ณ , ๊ฐ ๋ธ๋ก๋ง๋ค ๋ฐ๋ก ์์ํ ํ๋ ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
๋ธ๋ก๋ง๋ค ๊ฐ๊ฐ ์ต๋๊ฐ์ ๊ธฐ์ค์ผ๋ก ์ ๊ทํํ๋ฏ๋ก outlier์ ์ํฅ์ ์ค์ด๊ณ ๋นํธ ์กฐํฉ์ ๋ ์ ํ์ฉํ ์ ์๋ค.
Low-rank Adapters
๊ธฐ์กด ํฐ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ๋ค์ ๊ทธ๋๋ก ๋๊ณ , ์๊ณ ํ์ต ๊ฐ๋ฅํ "Adapter (์ ์ฐจ์ ํ๋ ฌ)" ๋ง ์ถ๊ฐ๋ก ํ์ตํ๋ ๋ฐฉ๋ฒ์ด๋ค.
- $X$: ์ ๋ ฅํ ์ ($X \in R^{b*h}$)
- $W$: ์ ์ฒด ํ๋ผ๋ฏธํฐ ($W \in R^{h*o}$)
- $L_1$: ์ฐจ์์ ์ถ์ํ๋ projection (down-projection) ($ L_1 \in R^{h*r}$)
- $L_2$: ์ฐจ์์ ํ์ฅํ๋ projection (down-projection) ($ L_2 \in R^{r*o}$)
Memory Requirement of Parameter-Efficient Finetuning
LoRA ๋ ํ์ตํด์ผ ํ ํ๋ผ๋ฏธํฐ ์๋ฅผ ์ค์ธ ํ์ ์ ์ธ PEFT ๋ฐฉ์์ด์ง๋ง, ์ค์ ํ์ธํ๋์์๋ ํ๋ผ๋ฏธํฐ ์์ฒด๋ณด๋ค ํ์ต ์ค ์์ฑ๋๋ ์ค๊ฐ ๊ณ์ฐ๊ฐ(activation gradients)์ด ํจ์ฌ ๋ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฐจ์งํ๋ค. ์๋ฅผ ๋ค์ด, FLAN v2 ๋ฐ์ดํฐ์ ์ผ๋ก ํ์ตํ๋ 7B LLaMA ๋ชจ๋ธ์์๋ LoRA ํ๋ผ๋ฏธํฐ๊ฐ ์๋ ๋ชจ๋ธ์ ์ฝ 0.2% ์์ค์ธ 26MB ์ ๋ถ๊ณผํ์ง๋ง, input gradients ๋ 567MB ๋ฅผ ์ฐจ์งํ๋ค. ์ค๊ฐ activation ๊ฐ์ ์ ์ฅํ์ง ์๊ณ ํ์ํ ๋๋ง๋ค ๋ค์ ๊ณ์ฐํ๋ gradient checkpointing ๊ธฐ๋ฒ์ ์ ์ฉํ๋ฉด, input gradients ๋ฅผ 567MB ์์ 18MB ๋ก ํฌ๊ฒ ์ค์ผ ์ ์๋ค.
์ด๋ฅผ ํตํด, LoRA parameter ๋ฅผ ๋ ์ค์ด๋ ๊ฒ์ ์ ์ฒด ๋ฉ๋ชจ๋ฆฌ์์ ํฐ ํจ๊ณผ๊ฐ ์๊ณ , ๊ทธ ๋์ adapter ์๋ฅผ ๋๋ ค ์ฑ๋ฅ์ ๋์ด๋๋ผ๋ ๋ฉ๋ชจ๋ฆฌ ๋ถ๋ด์ ํฌ์ง ์๋ค๋ ๊ฒ์ ์ ์ ์๋ค. ์ด๋ฐ ์ค๊ณ๊ฐ full 16-bit precision ์ฑ๋ฅ ๋ณต์์ ํต์ฌ์ ์ธ ์ญํ ์ ํ๋ค.
# QLoRA Finetuning
1. 4-bit NormalFloat Quantization
๋ชจ๋ธ์ weight ๋ถํฌ ํน์ฑ์ ํ์ฉํด์ ํจ์จ์ ์ด๊ณ ์ ํํ๊ฒ 4-bit ์์ํ ๋ฅผ ์ํํ๋ ๋ฐฉ๋ฒ์ด๋ค.
- ์ผ๋ฐ์ ์ธ ์์ํ๋ ๊ท ์ผํ๊ฒ ๊ฐ์ ๋๋์ง๋ง, ๋ชจ๋ธ weight ๋ ๋๋ถ๋ถ ์ ๊ท๋ถํฌ(Normal Distribution) ๋ฅผ ๋ฐ๋ฅด๋ฏ๋ก ๋นํจ์จ์ ์ด๋ค.
- NormalFloat ์ ๋ถํฌ์ ๋ง์ถฐ ๋นํธ๋ฅผ ๋ฐฐ๋ถํ์ฌ ์์ฃผ ๋ฑ์ฅํ๋ ๊ฐ ๊ทผ์ฒ๋ ๋ ์ด์ดํ๊ฒ, ํฌ๊ทํ ๊ฐ์ ์ ๊ฒ ๋นํธ๋ฅผ ์ฌ์ฉํ๋ค.
์์ํ ๋จ๊ณ
1. ์ด๋ก ์ ์ธ ์ ๊ท๋ถํฌ $N(0,1)$์ ๋ํด $2^k+1$๊ฐ์ quantile์ ๋ฏธ๋ฆฌ ๊ณ์ฐํ๋ค. -> k-bit ์์ํ์ฉ ๋ฐ์ดํฐ ํ์ ์์ฑ
: ๋งค๋ฒ ๋ชจ๋ธ weight์ ๋ฐ๋ผ ๊ณ์ฐํ์ง ์์๋ ๋๋ฏ๋ก ๋น ๋ฅด๊ณ ํจ์จ์ ์ด๋ค
2. ์ด quantile๋ค์ [-1,1] ๋ฒ์๋ก ์ ๊ทํํ๋ค.
3. ์ ๋ ฅ weight tensor ๋ absolute max rescaling์ ํตํด [-1,1] ๋ฒ์๋ก ๋ง์ถ๋ค.
๋์นญ์ ์ธ k-bit quantization ๋ฐฉ์์์๋ 0 ๊ฐ์ ์ ํํ ํํํ ์ ์๋ ๋ฌธ์ ๊ฐ ์๋ค
ํ์ง๋ง ํจ๋ฉ(padding)์ด๋ ๋ค๋ฅธ 0 ๊ฐ ์์๋ค์ ์ค์ฐจ ์์ด ์์ํํ๋ ค๋ฉด 0 ๊ฐ์ ์ ํํ ํํ์ด ๋งค์ฐ ์ค์ํ๋ค.
์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์ฐ๋ฆฌ๋ ๋น๋์นญ(asymmetric) ๋ฐ์ดํฐ ํ์
์ ์ฌ์ฉํ๋ค.
๊ตฌ์ฒด์ ์ผ๋ก, ์์ ๊ตฌ๊ฐ์๋ $2^{k-1}$ ๊ฐ, ์์ ๊ตฌ๊ฐ์๋ $2^{k-1} + 1$๊ฐ์ ๋ถ์์๋ฅผ ์ถ์ ํ๋ค.
๊ทธ ํ ์ด ๋ ๋ถ์์ ์งํฉ์ ํตํฉํ๊ณ , ์์์ ์์ ๊ตฌ๊ฐ์์ ์ค๋ณต๋ 0 ์ ํ๋ ์ ๊ฑฐํ๋ค.
๊ทธ ๊ฒฐ๊ณผ๋ก, ๋ชจ๋ $2^{k}$ ๋นํธ๋ฅผ ํ์ฉํ๋ฉด์๋ bin ๋ง๋ค ๊ธฐ๋๊ฐ์ด ๋์ผํ๋๋ก ๋ฐฐ๋ถ๋ ๋ฐ์ดํฐ ํ์ ์ ๋ง๋ค ์ ์๊ณ , ์ด๋ฌํ ๋ฐ์ดํฐ ํ์ ์ k-bit NormalFloat (NFk) ๋ผ๊ณ ๋ถ๋ฅธ๋ค.
2. Double Quantization
Double Quantization (DQ) ์ ์์ํ ์์ ์์ฒด๋ฅผ ๋ค์ ํ ๋ฒ ์์ํํด์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ถ๊ฐ๋ก ์ค์ด๋ ๊ธฐ๋ฒ์ด๋ค,
๋ฐฐ๊ฒฝ
์๋ weight ์์ํ ์, ๊ฐ ๋ธ๋ก๋ง๋ค ์์ํ ์์ (scale factor) ๊ฐ ํ์ํ๋ค.
ํ์ง๋ง ๋ธ๋ก ์ฌ์ด์ฆ๊ฐ ์์์๋ก ์ ๋ฐ๋๋ ์ข์์ง์ง๋ง, ์์ํ ์์ ๊ฐ์๊ฐ ๋ง์์ ธ์ ๋ฉ๋ชจ๋ฆฌ ์ค๋ฒํค๋๊ฐ ๋ฐ์ํ๋ค.
(ex. 32-bit ์์ ์ฌ์ฉ, ๋ธ๋ก ์ฌ์ด์ฆ 64 → ํ๊ท ํ๋ผ๋ฏธํฐ๋น 0.5 bit ์๋ชจ)
1. 1๋จ๊ณ: ๊ธฐ๋ณธ weight ์์ํ (1์ฐจ ์์ํ)
์๋ float weight๋ฅผ quantizationํ๋ ค๋ฉด scale factor๊ฐ ํ์ํ๋ค
$w_{int} = round(w_{fp32}/c_2)$
- $c_2$: block๋ง๋ค ์กด์ฌํ๋ scale factor(quantization constant)
- $c_2$ ๋ค์ ์ ์ฅํด์ผ ํ๋๋ฐ, ๋ณดํต FP32 (32-bit float) ๋ก ์ ์ฅํ๋ค.
block size๊ฐ 64๋ผ๊ณ ํ ๋,
$32 bits / 64 parameters = 0.5bits/parameter$
2. 2๋จ๊ณ: scale factor ๋ ์์ํ (2์ฐจ ์์ํ)
$c_2$ ๋ํ block์ผ๋ก ๋ฌถ์ด ์์ํํ๋๊ฒ ๋ฐ๋ก double quantization > ์ ์ฅ ๋น์ฉ์ ๋ ์ค์ผ ์ ์๋ค.
$c_{2}^{int8}= round(c_2 -\mu_{c_2}/c_1)$
- $c_{2}^{int8}$: ์์ํ๋ $c_2$, 8-bit๋ก ์ ์ฅ
- $\mu_{c_2}$: $c_2$๋ค์ ํ๊ท (mean clustering)
: ํ๊ท ์ ๋นผ๋ ์ด์ ) $c_2$ ๊ฐ ์์ ๊ฐ์ด๋ฏ๋ก, ๋์นญ์ ์ผ๋ก ๋ง๋ค์ด์ quantization ํจ์จ ๋์ - $c_1$: ๋ ๋ฒ์งธ quantization์ scale factor
ํจ๊ณผ
block size๊ฐ 64๋ผ๊ณ ํ ๋,
๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ: $32/64 = 0.5 bits$ -> $8/64 + 32/(64 · 256) = 0.127 bits$
$c_2$๋ฅผ ๋ block์ผ๋ก ๋ฌถ์ด์ ์์ํํจ.
$c_2$๋ค์ 256๊ฐ์ฉ ๋ฌถ์ด์
scale factor $c_1$ 1๊ฐ
์์ํ๋ $c_2^{int8}๊ฐ 256๊ฐ
ํ๊ท ์ ์ผ๋ก ํ๋ผ๋ฏธํฐ๋น 0.373 bit ์ ์ฝํ ์ ์๋ค.
3. Paged Optimizers
Paged Optimizer ๋ NVIDIA Unified Memory ๊ธฐ๋ฅ์ ํ์ฉํ๋ค. ์ด ๊ธฐ๋ฅ์ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ถ์กฑํ ๋, CPU ๋ฉ๋ชจ๋ฆฌ(RAM) ์ GPU ๋ฉ๋ชจ๋ฆฌ ๊ฐ์ ์๋์ผ๋ก "ํ์ด์ง ๋จ์ ์ ์ก(page-to-page transfer)" ์ ํด์ฃผ๋ ๊ธฐ๋ฅ์ด๋ค.
์ฝ๊ฒ ๋งํ๋ฉด:
- ์ฐ๋ฆฌ๊ฐ ๋ณดํต PC ์์ RAM ์ด ๋ถ์กฑํ๋ฉด, ๋์คํฌ(HDD/SSD) ๋ก ์ค์ํ์ด ์ผ์ด๋๋ฏ์ด
- GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ถ์กฑํ๋ฉด, CPU RAM ์ผ๋ก ์๋์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ฎ๊ฒผ๋ค๊ฐ ํ์ํ ๋ ๋ค์ ๊ฐ์ ธ์ค๋ ๋ฐฉ์์ด๋ค.
# QLoRA
QLoRA Linear Layer ์๋ฆฌ
- ์ ์ฅ ์:
- Weight: 4-bit NormalFloat (NF4)
- Scale factor: Double Quantization (block size 256)
- ๊ณ์ฐ ์:
- ๋ณต์๋ BF16 precision ์ผ๋ก ์ฐ์ฐ
- LoRA adapter (BF16) ๋ฅผ ํตํด fine-tuning
- gradient:
- base weight ์ ๋ํด gradient ์ ์ฅ ์ ํจ
- adapter weight ์ ๋ํด์๋ง gradient ์ ์ฅ
# Evaluation
QLoRA๋ฅผ ํตํด ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ง์ด ์ค์๋๋ฐ๋ ์ ์๊ฐ ๊ฑฐ์ ์ฐจ์ด๊ฐ ์๋ค๋ ๊ฒ์ ๋ณผ ์ ์์๋ค.
Reference
์๊ฐ์ ์ผ๋ก ๋์์ ๋ฐ์ ์์
https://www.youtube.com/watch?v=6l8GZDPbFn8
https://www.youtube.com/watch?v=aZPAqBov3tQ
https://www.youtube.com/watch?v=XpoKB3usmKc