Diffusion Model ์ํ์ด ํฌํจ๋ tutorial (1/2)
๋ณธ ๊ธ์ ์๋ ์์์ ๋ณด๊ณ ์ ๋ฆฌํ์์ต๋๋ค.
https://www.youtube.com/watch?v=uFoGaIVHfoE
GAN์ ์ฑ๋ฅ์ ์ด๊ฒจ๋ฒ๋ฆฐ Diffusion
์ต์ด์ ์ฐ๊ธฐ๋ฅผ ์ฐพ์๋ณด๋ ๊ฒ์ด ํต์ฌ!
์ค์ ๋ฌผ๋ฆฌ์ ์ผ๋ก ๋ถ์๊ฐ ํ์ฐ๋ ๋,
๊ฐ์ฐ์์ ๋ถํฌ ์์ ๋ค์ ์์น๊ฐ ๊ฒฐ์ ๋๋ค.
์์ ์ด๋ฏธ์ง์์ noise๋ฅผ ์ถ๊ฐํด์ ์ ์ฒด๋ฅผ noise๋ก ๋ง๋ค์ด๋ฒ๋ฆฌ๋ forward์
noise๋ฅผ ์ค์ฌ๋๊ฐ๋ฉด์ ์ด๋ฏธ์ง๋ฅผ ์์ฑํ๋ reverse๊ฐ ์กด์ฌํ๋ค.
์ด๋, ์์์ ์ค๋ช ํ ๊ฒ์ฒ๋ผ ์ด๋ฏธ์ง์์ noise๋ฅผ ์ถ๊ฐํ๋ ๊ณผ์ , ์ฆ ์ฐ๊ธฐ๊ฐ ํผ์ ธ ๋๊ฐ๋ ๊ณผ์ ์ ๋งค์ฐ ์ฝ๋ค.
๊ทธ๋ฌ๋ ๊ทธ ๋ฐ๋์ ๊ณผ์ ์ธ reverse๋ ์ด๋ ต๋ค.
์์ ๋ถ์๊ฐ ํ์ฐ๋ ๋ ์ค๋ช ํ ๊ฒ์ฒ๋ผ, noise๊ฐ ์ถ๊ฐ๋๋ ํ์์ ๊ฐ์ฐ์์ ๋ถํฌ๋ฅผ ๋ฐ๋ฅด๋ค.
$q(x_{t} | x_{t-1})$ = $N(x_{t}: \sqrt{1-\beta _{t}} x_{t-1}, \beta_{t}I)$
- ์ด์ step์ธ $x_{t-1}$ ์ฃผ์ด์ก์ ๋ ๋ค์ step์ธ $x_{t}$๋ฅผ ๊ณ์ฐํ๋ ์์ด๋ค.
- ๋ step์ฌ์ด๋ ๊ฐ์ฐ์์ ๋ถํฌ๋ก ์ฐ๊ฒฐ๋์ด ์๋ค.
- ํ๊ท ์ $\sqrt{1-\beta _{t}} x_{t-1}$์ด๊ณ , ๋ถ์ฐ์ $\beta_{t}I$์ด๋ค.
- ์ฌ๊ธฐ์ $\beta_{t}$์ ๊ฐ์ 0.001 ์ ๋๋ก ๋งค์ฐ ์๋ค๊ณ ํ๋ค. ์ด ๊ฐ์ด ์์์๋ก ์ด์ step๊ณผ ๋น์ทํ๊ณ , ํด์๋ก ๋ณํ๊ฐ ํฌ๋ค.
$q(x_{1:T} | x_{0}) =\prod_{1}^{t}{q(x_{t}|x_{t-1})}$
- $x_{T}$๋ฒ์งธ ์ด๋ฏธ์ง๋, $x_{0}$์์ $q$๋ฅผ ๊ณ์ ๊ณฑํ๋ฉด ๋ง๋ ๋ค๋ ๋ป์ด๋ค.
$\beta_{t}$๋ฅผ ์ ์ํ๋ ๋ฐฉ์์ ๋ ผ๋ฌธ๋ง๋ค ๋ค๋ฅด๊ณ , ์๊ทผํ ์ฑ๋ฅ์ฐจ์ด๋ ๋ง์ด ๋๋ค.
timestep $t$์ ๋ฐ๋ผ์ ๋คํธ์ํฌ๊ฐ ํด์ผ ํ๋ ์ผ์ด ๋ฌ๋ผ์ง๋๋ฐ,
๊ฐ๊ฐ ์ผ๋งํผ์ noise๊ฐ ์ถ๊ฐ๋ ์ง ๊ฒฐ์ ํด์ฃผ๋ ๊ฒ์ด ๋ฐ๋ก $beta_{t}$์ด๋ค.
์ด๋ ๋ฌ๋ฆฌ ์๊ธฐํ๋ฉด, ๊ณง ๋คํธ์ํฌ์๊ฒ ์ด๋ค ์ผ์ ํด๋ผํ๊ณ ๊ฐ์ ์ ์ผ๋ก ๋งํ๋ ๊ฒ๊ณผ ๊ฐ์์
์ฑ๋ฅ์ด ์ฐจ์ด๊ฐ ๋ ์๋ฐ์ ์๋ ๊ฒ์ด๋ค.
์ ์์ด ์ด๋ ๊ฒ ๊ตฌ์ฑ์ด ๋๋์? ๋ผ๋ ์ง๋ฌธ์ ๋๋ต์ผ๋ก๋ ์ด๋ ๊ฒ ๋ตํ ์ ์๋ค.
๋ค์ ๋ถํฌ์ ๋ฐ๋ฅด๋ฉด, $x_{t} = \sqrt{1-\beta_{t}}x_{t-1} + \sqrt{\beta_{t}}e$์ด๊ณ , $e$ ~ $N(0,I)$์ด๋ค.
$x_{t}$์ ๋ถ์ฐ์ ๊ตฌํ๋ฉด,
$Var(x_{t}) = Var( \sqrt{1-\beta_{t}}x_{t-1} + \sqrt{\beta_{t}}e )$
$= (1-\beta_{t})Var(x_{t-1}) + \beta_{t}I$
๋ง์ฝ $Var(x_{t-1})$์ ๊ฐ์ด 1์ด๋ผ๊ณ ํ๋ค๋ฉด, $Var(x_{t})$ ๋ํ 1์ด ๋๊ธฐ ๋๋ฌธ์ ์์ ๊ฐ์ด ์ ์ํ๋ ๊ฒ์ด๋ค.
ํน์ timestep์ผ๋ก ํ๋ฒ์ ๋ํ ์ ์์ง ์์๊น?๊ฐ ddpm์ ์ฃผ๋ ์๋
ํ์์คํ ์ผ๋ก ๋์ด๊ฐ๋ ์์ $x_{t} = \sqrt{\alpha_{t}}x_{0} + \sqrt{1-\alpha_{t}}e$ ์ด๋ ๊ฒ ๋ ผ๋ฌธ์์ ์ ์ํด๋์๋ค.
์ฌ๊ธฐ์ $\sqrt{\alpha_{t}}$๋ ์ด๊ธฐ์ํ $x_{0}$์ ์ ๋ณด๋ฅผ ์ผ๋ง๋ ์ ์งํ ์ง๋ฅผ ๊ฒฐ์ ํ๊ณ ,
$\sqrt{(1-\alpha_{t})}$๋ ๋ ธ์ด์ฆ์ ํฌ๊ธฐ๋ฅผ ๊ฒฐ์ ํ๋ค.
๋ฐ๋ผ์ $\alpha_{t}$๊ฐ ํฌ๋ฉด $x_{t}$๋ ์ฃผ๋ก $x_{0}$์ ์ํฅ์ ๋ฐ๊ณ , ์์ผ๋ฉด ์ฃผ๋ก noise์ ์ํฅ์ ๋ฐ์
$\alpha$ ์์ ๊ดํ ์ค๋ช ์ด๋ค.
๊ฒฐ๊ณผ์ ์ผ๋ก noise๋ฅผ ์ถ๊ฐํด๊ฐ๋ฉด์ $q(x_{T})$์ธ gaussian distribution์ ์ฐพ์๋ด๋ ๊ณผ์ ์ธ ๊ฒ์ด๋ค.
denoisingํ๋ ๊ณผ์ ๋ํ gaussian์ด๋ผ๊ณ ๊ฐ์ ํ๊ณ 'ํ๊ท '๊ณผ '๋ถ์ฐ'์ ์ฐพ๋ ๊ฒ์ด ๋ชฉํ์ธ ๊ฒ์ด๋ค.
์์ ์์์๋ $\mu_{\theta}(x_{t},t)$์ $\sigma_{t}^2I$์ธ ๊ฒ์ด๋ค.
์ฉ์ด๋ฅผ ํท๊ฐ๋ฆฌ์ง ๋ง์!!
$q$ํจ์์ ๊ฒฝ์ฐ์๋ '์ฐธ ๋ถํฌ', ์ด์ฉ๋ฉด ์ ๋ต์ด ๋๋ ๋ถํฌ๋ฅผ ์๋ฏธํ๊ณ
$p$ํจ์์ ๊ฒฝ์ฐ์๋ '๋ชจ๋ธ์ด ์์ธกํ ๋ถํฌ', ํ์ตํ ๋ถํฌ๋ฅผ ์๋ฏธํ๋ค.
๋ฐ๋ผ์, ๋ชจ๋ธ์ training์ ํตํด
์ด ๋ถํฌ์ 'ํ๊ท '๊ณผ '๋ถ์ฐ'์ ์ฐพ์๋ธ๋ค.
์ฌ๊ธฐ์ $\theta$๋ฅผ ํฌํจํ๋๋ก ํ์ฌ, neural network๋ก ํ์ต๋๋ ํ๋ฅ ๋ชจ๋ธ์์ ๋ช ์ํ๋ค.
๊ธฐ์กด ๋ ผ๋ฌธ์์ ์ ์ํ Loss ์ฌ์ฉ
$L_{T}$์์
$q(x_{T}|x_{0})๊ณผ p(x_{T})$๋ ๋ชจ๋ gaussian distribution์ด๊ธฐ ๋๋ฌธ์ KL-div๋ฅผ ๊ตฌํ๋๊ฒ ์๋ฏธ๊ฐ ์์ด์ ์ฌ์ฉ ์๋จ
$q(x_{t-1} | x_{t}, x_{0})$ ๋ํ gaussian distribution์ด๊ณ , $p_{\theta}(x_{t-1}|x_{t})$ ๋ํ ๊ทธ๋ ๊ฒ ๊ฐ์ ํ๋ฏ๋ก
๊ฐ์ฐ์์ ์ฌ์ด์ KL-div๋ก ์ ์ํ ์ ์๋ค.
์์ ์ ๊ฐํ๋ค๋ณด๋, ๋งจ ์๋ ์๊ณผ ๊ฐ์ด noise prediction๋ง ์๋ฉด loss๋ฅผ ๊ณ์ฐํ ์ ์์์ ์์๋
ddpm ๋ ผ๋ฌธ์์๋ coefficient $\mu_{t} = 1$์ผ๋ก ๊ฐ์ ํ์ง๋ง, ๋ ผ๋ฌธ์ ๋ฐ๋ผ ๋ค๋ฅด๊ฒ ์ค์ ํ์๋ผ๊ณ ์ฃผ์ฅํ๋ ๊ฒ๋ ๋ง๋ค.
(hyperparameter ๊ฐ์ ์กด์ฌ์ผ๊น..?)
๊ฐ์ฐ์์ ๋ชจ๋ธ์ ํ๊ท ์ ์์ธกํ๊ธฐ ์ํด ์์ ์ฐ๋ค๋ณด๋,
๊ฐ step ์ฌ์ด์ noise๋ฅผ ์์ธกํ๋ ๋ชจ๋ธ์ ๋ง๋ค๋ฉด ๋๋ค๋ ๊ฒฐ๋ก ์ด ๋์๋ค.
U-Net์ ๊ธฐ๋ณธ ์์ด๋์ด๋ ์ ์ฐจ์ ๋ฟ๋ง ์๋๋ผ ๊ณ ์ฐจ์ ์ ๋ณด๋ ์ด์ฉํ์ฌ ์ด๋ฏธ์ง์ ํน์ง์ ์ถ์ถํจ๊ณผ ๋์์ ์ ํํ ์์น ํ์ ๋ ๊ฐ๋ฅํ๊ฒ ํ์๋ ๊ฒ์ด๋ค. ์ด๋ฅผ ์ํด์ ์ธ์ฝ๋ฉ ๋จ๊ณ์ ๊ฐ ๋ ์ด์ด์์ ์ป์ ํน์ง์ ๋์ฝ๋ฉ ๋จ๊ณ์ ๊ฐ ๋ ์ด์ด์ ํฉ์น๋(concatenation) ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๋ค. ์ธ์ฝ๋ ๋ ์ด์ด์ ๋์ฝ๋ ๋ ์ด์ด์ ์ง์ ์ฐ๊ฒฐ์ ์คํต ์ฐ๊ฒฐ(skip connection)์ด๋ผ๊ณ ํ๋ค.
https://velog.io/@lighthouse97/UNet%EC%9D%98-%EC%9D%B4%ED%95%B4
$x_{0}$์์ ์์ํ์ง๋ง noise๋ฅผ ํ๋ฒ์ ์ ํ์ $x_{t}$๋ฅผ ๋ง๋ค๊ณ ,
์ด๋ฅผ timestep $t$์ ํจ๊ป U-net์ ๋ฃ์ด์ ์ด๋ค noise๊ฐ ๋ํด์ง๊ฑด์ง predictionํ๋ ๊ตฌ์กฐ์ด๋ค.
Q. ์ $t$๋ฅผ ๊ฐ์ด ๋ฃ์ด์ฃผ๋์?
A. $x_{t}$๊ฐ ์ผ๋งํผ์ noise๊ฐ ์ถ๊ฐ๋ ์ํ์ธ์ง ์์์ผ ํ๊ธฐ ๋๋ฌธ์
+ $t$๋ฅผ ๊ทธ๋ฅ ์ซ์๋ก ๋ฃ์ด์ฃผ๋ฉด ์๋๋ฏ๋ก, embedding์ ํด์ค๋ค (sin, cos) -- like positional encoding
+ ์ผ๋ฐ์ ์ผ๋ก $t$๋ 1,000์ด์์ธ๋ฐ ๋ง์ฝ $t$๋ฅผ ๋ฐ๋ก ๋ฃ์ด์ฃผ์ง ์์ผ๋ฉด 1,000๊ฐ์ ํจ์๋ฅผ ๋ง๋ค์ด์ผ ํ๋ค๋ ์๋ฆฌ๋ค.
์ด๋ฅผ ๋ฐฉ์งํ๊ธฐ ์ํด ๋ฃ์ด์ฃผ๋ ๊ฒ์ด๋ค
ddpm์์๋ $\beta_{i}$์ $\sigma_{i}^2$๋ฅผ ๋์ผํ๊ฒ ๋ง๋ค์๋ค.
ํ์ง๋ง $\beta_{i}$๋, $\sigma_{i}^2$๋ฅผ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ผ๋ก ์ ์ํ์๋ ๋ ผ๋ฌธ ๋ํ ๋ง๋ค.
์ด๋ฐ์๋ noise์์ ์์ํด์ low-frequency content๋ฅผ ๋ง๋ค์ด๋๊ฐ๋ ๊ณผ์ ์ด๋ค.
์ด๋ ์ด๋ค content๊ฐ ์์ฑํ ์ง์ ๋ํ ๋ด์ฉ์ด๋ค. (์์ ํ๋ฆฟ -> ์ฝ๊ฐ์ ์ค๋ฃจ์ฃ)
์ผ์ ์์ค์ด ๋์ด๊ฐ๋ฉด,
detailํ ์ ๋ณด๋ค๋ง ์ถ๊ฐ๋๋ high-frequency content๋ฅผ ๋ง๋ค์ด๋๊ฐ๋ ๊ณผ์ ์ด ์ด์ด์ง๋ค.
timestep $t$์ ๋ฐ๋ผ์ ๋ด๋นํ๋ ๋ด์ฉ์ด ๋ค๋ฅด๋ค๋ ์ฌ์ค!
(์กฐ๊ธ ๋ ์์ธํ)
์ง๊ธ๊น์ง ์ฐ๋ฆฌ๋ diffusion ๋ชจ๋ธ์ ๋ชฉ์ ์ ์ดํด๋ดค๋ค.
๊ฐ์ฐ์์์ด๋ผ๊ณ ๊ฐ์ ํ reverse process์ ๋ถํฌ๋ฅผ ์ฐพ๊ธฐ ์ํด, 'ํ๊ท '์ ๊ตฌํ๋๊ฒ ์ฐ๋ฆฌ์ ๋ชฉ์ ์ด๋ค.
('๋ถ์ฐ'์ forward๋ ๊ฐ๋ค๊ณ ๊ฐ์ ํจ)
์ด๋, ๊ตฌํ๋ ๋ฐฉ๋ฒ์ noise prediction์ ์์ผ์ sampling์ ํ๋ค๋ ๊ฒ
์กฐ๊ธ ๋ ์์ธํ ์ดํด๋ด ์๋ค :)
# forward
์ $q(x_{1:T} | x_{0}) = \prod_{t=1}^{T} q(x_{t}|x_{t-1})$์ ์์ด ์ฑ๋ฆฝํ๋์ง์ ๊ดํ ์ฆ๋ช ์ด๋ค.
(markov chain)
$P(x_{t+1} | x_{0}, ..., x_{t}) = P(x_{t+1} | x_{t})$
$0$๋ฒ์งธ๋ถํฐ $t$๋ฒ์งธ๊น์ง์ ๋ชจ๋ data๊ฐ ์์ ๋ (๊ณผ๊ฑฐ),
$(t+1)$๋ฒ์งธ ๋ฐ์ดํฐ(๋ฏธ๋)๋ $t$๋ฒ์งธ ๋ฐ์ดํฐ์๋ง ์ํฅ์ ๋ฐ๋๋ค๋ ๋ป!
์ฆ, ๋ง์ฝํ ์ฒด์ธ์ ํ๋ฅ ๊ณผ์ ์ ๋ฏธ๋๊ฐ ๊ณผ๊ฑฐ์๋ ๋ ๋ฆฝ์ด๊ณ ์ค๋ก์ง ์ง์ ์์ ์๋ง ์ํฅ์ ๋ฐ๋๋ค๋ ๊ฒ
https://blog.naver.com/jinis_stat/221686989847
์์ ๊ฐ์ด ํ๋์ step์ฉ ๊ตฌํ ์ ์๋ $q(x_{t}|x_{t-1})$๋ฅผ ํตํด
ํ๋ฒ์ $x_{t}$๊น์ง ๊ตฌํ ์ ์๋ $q(x_{t}|x_{0})$๊น์ง ์ ๋ํ๋ค.
# reverse
reverse ์ ๋ํ, markov๋ฅผ ์ด์ฉํด์ ํํํ ์ ์๋ค.
# Learning Denoising Model
# VAE -> DDPM
$E_{x_{T}~g(x_{T} | x_{0})} ([-log (g(x_{1:T}|x_{0}) / p_{\theta}(x_{1}, x_{2}, ..., x_{T} | x_{0})])$
$= D_{KL}( g(x_{1:T}|x_{0}) || p_{\theta}(x_{1}, x_{2}, ..., x_{T} | x_{0})) >= 0$ ์ด๋ฏ๋ก ๋ถ๋ฑํธ๊ฐ ์ฑ๋ฆฝํ๋ค.
# DDPM Loss
๋งจ ์๋ ์๊ณผ ๊ฐ์ด, $q$๊ฐ ์ ์ ๋ฌํ ๋ถํฌ๋ฅผ ๊ฐ๊ณ ์๋์ง์ ๋ํ ์ค๋ช ์ ์๋์ ๊ฐ๋ค.
$q$์ ๋ถํฌ๋ฅผ gaussian contribution์ผ๋ก ํ๊ธฐ ๋๋ฌธ์, $exp$๋ก ํ์ด์ ์์ ์ธ ์ ์๋ ๊ฒ์ด๋ค.
(์ค๊ฐ ์ ์ ์ดํด ๋ชปํ์..)
* network architecture -- u-net ์ฌ์ฉ
* objective weighting -- (loss์์ ๋ถ์ด์๋ ๊ฐ) ddpm์์๋ 1์ ์ฌ์ฉํ์ง๋ง, ๋ ผ๋ฌธ์ ๋ฐ๋ผ ๋ค๋ฆ
์ค์ ์ฝ๋์์๋ $L_{simple}$๋ง ์๋๊ฑธ ํ์ธํ ์ ์๋ค.

์ฌ๊ธฐ๊น์ง ddpm์ ์ค๋ช ์ ๋๋ค :)