๐Ÿ“š Study/AI

[๋”ฅ๋Ÿฌ๋‹๊ณผ ์„ค๊ณ„] VAE(Variational AutoEncoder)

์œฐ๊ฐฑ 2024. 7. 10. 22:19

๋ณธ ๊ธ€์€ ์•„๋ž˜ ์˜์ƒ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ์ž‘์„ฑํ•˜์˜€์Šต๋‹ˆ๋‹ค.

https://www.youtube.com/watch?v=GbCAwVVKaHY&list=PLQASD18hjBgyLqK3PgXZSp5FHmME7elWS&index=10

 


 

# Variational Autoencoders(VAE)

 

AE ๊ฐ™์€ ๊ฒฝ์šฐ์—๋Š” encoder๊ฐ€ ์ค‘์š”ํ•œ ๋ฐ˜๋ฉด, VAE๋Š” decoder๊ฐ€ ๋” ์ค‘์š”ํ•˜๋‹ค.

์ฆ‰, AE๋Š” ์ฐจ์›์„ ์ถ•์†Œํ•˜๋Š” ๊ฒŒ ์ค‘์š”ํ•˜๊ณ , VAE๋Š” ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ์ด ์ค‘์š”ํ•˜๋‹ค.

 

encoder๊ฐ™์€ ๊ฒฝ์šฐ์—๋Š” ๋ฐ”๋กœ $z$๋ฅผ ๊ตฌํ•˜๋Š”๊ฒŒ ์•„๋‹ˆ๋ผ

ํ‰๊ท  $\mu$์™€ ๋ถ„์‚ฐ $\sigma$๋ฅผ ๋ฝ‘์•„๋‚ธ ํ›„ ์ƒ˜ํ”Œ๋งํ•ด์„œ $z$๋ฅผ ๊ตฌํ•œ๋‹ค.

๊ทธ๋ฆฌ๊ณ  ์—ฌ๊ธฐ์„œ samplingํ•˜๋Š” ๊ณผ์ •์—์„œ Reparameterization Trick $e$์„ ์‚ฌ์šฉํ•ด์•ผ backpropagation์ด ๊ฐ€๋Šฅํ•˜๋‹ค.

 

 

loss function์€ 2๊ฐ€์ง€์˜ ํ•ฉ์œผ๋กœ ์ด๋ฃจ์–ด์ง„๋‹ค. (1) Reconstruction Error (2) Regularization

(1)

์ผ๋‹จ, input์ด ๊ทธ๋Œ€๋กœ ๋ณต์›๋  ์ˆ˜ ์žˆ๊ฒŒ ํ•˜๋Š” Reconstruction Error๊ฐ€ ์ตœ์†Œํ™”๋  ๋•Œ

$p_{i}$๊ฐ€ ๋ฒ ๋ฅด๋ˆ„์ด ๋ถ„ํฌ๋ฅผ ๋”ฐ๋ฅธ๋‹ค๊ณ  ๊ฐ€์ •ํ•˜๋ฉด Cross Entropy๋กœ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋‹ค.

(2)

๋˜ํ•œ, encoder๋ฅผ ํ†ต๊ณผํ•œ ๊ฐ’์ด normal distribution์„ ๋”ฐ๋ผ์•ผ ํ•˜๊ธฐ ๋•Œ๋ฌธ์— KL divergence ์‹์„ ์“ฐ๋Š”๋ฐ

์ด๋Š” $KL(q_{\emptyset}(z|x_{i}) || p(z))$๋ฅผ ์˜๋ฏธํ•˜๊ณ  ๋‘˜ ์‚ฌ์ด์˜ ์ฐจ์ด๋ฅผ ์ตœ์†Œํ™”ํ•˜๊ธฐ ์œ„ํ•œ ์‹์ด๋‹ค.

 


# Loss Function

 

์ผ๋‹จ ์šฐ๋ฆฌ๊ฐ€ ์•Œ๊ณ  ์‹ถ์€ ๊ฒƒ์€ ์ตœ์ ํ™”๋œ $\theta$์ธ $\theta*$์ด๋‹ค

์ด๋ฅผ ์–ด๋–ป๊ฒŒ ํ›ˆ๋ จ์‹œํ‚ฌ ์ˆ˜ ์žˆ์„๊นŒ?

 

 

 

p_{\theta}(x|z)$๋Š” decoder NN์ด๋ผ๊ณ  ํ–ˆ์„ ๋•Œ

z๋Š” ๋ฌดํ•œํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋ชจ๋“  z์— ๋Œ€ํ•ด ์ ๋ถ„ํ•  ์ˆ˜ ์—†์–ด์„œ $p_{\theta}(x)$๋ฅผ ๊ตฌํ•  ์ˆ˜ ์—†๋‹ค.

๊ทธ๋ž˜์„œ ๋‹ค๋ฅธ ์‹์œผ๋กœ ๊ตฌํ•ด๋ณด๊ธฐ๋กœ ํ–ˆ๋‹ค.

๋งŒ์•ฝ $p_{\theta}(z|x)$๋ฅผ ๊ตฌํ•ด๋ณด๋ฉด ์–ด๋–จ๊นŒ?

์ด ์‹์„ ๋‹ค์‹œ ์จ๋ณด๋ฉด, $p_{\theta}(x|z) * p_{\theta}(z) / p_{\theta}(x)$์ธ๋ฐ

์—ฌ๊ธฐ์„œ ๋‹ค์‹œ $p_{\theta}(z|x)$๋ฅผ ๋ชจ๋ฅด๊ธฐ ๋•Œ๋ฌธ์— ๊ตฌํ•  ์ˆ˜ ์—†๋‹ค.

 

๊ทธ๋Ÿผ ์–ด๋–ป๊ฒŒ ํ•ด๊ฒฐํ• ๊นŒ? decoder๋ฅผ ๊ตฌํ•˜๊ธฐ ์œ„ํ•ด encoder๋ฅผ ์ถ”๊ฐ€๋กœ ์ •์˜ํ•˜๊ธฐ๋กœ ํ–ˆ๋‹ค.

์ฆ‰, ์—ฌ๊ธฐ์„œ๋Š” decoder๋ฅผ modelingํ•˜๊ธฐ ์œ„ํ•ด์„œ encoder $q_{\emptyset}(z|x)$๋ฅผ ์ •์˜ํ•œ ๊ฒƒ์ด๋‹ค.

๊ทธ๋ฆฌ๊ณ  ์ด $q_{\emptyset}(z|x)$๋Š” $p_{\emptyset}(z|x)$์— ๊ทผ์‚ฌํ•˜๋Š” ๊ฐ’์ด๋‹ค.

 

ํ•ต์‹ฌ์€ ์›๋ž˜ decoder๋งŒ ์žˆ์œผ๋ฉด ๋˜๋Š”๋ฐ ์ด๋ฅผ ํ•™์Šต์‹œํ‚ค๊ธฐ ์œ„ํ•ด์„œ encoder์˜ ๋„์›€์„ ๋ฐ›์€ ๊ฒƒ์ด๋‹ค.

 

 

 

๊ตฌ์ฒด์ ์ธ ํ’€์ด๊ณผ์ •์€ ์œ„์˜ ์‹๊ณผ ๊ฐ™๋‹ค.

 

 

๋”ฐ๋ผ์„œ, $log(p_{\theta}(x^(i)))$๋ฅผ ์ตœ๋Œ€ํ™”ํ•  ๋•Œ, $L(x^(i), \theta, \emptyset)$๋ฅผ lower bound๋กœ ์„ค์ •ํ•˜๋Š” ๊ฒƒ์ด๋‹ค.

 

 

 

๊ฒฐ๋ก ์ ์œผ๋กœ๋Š”,

$L(x^(i), \theta, \emptyset)$๋ฅผ ์ตœ๋Œ€ํ™”ํ•ด์ฃผ๋Š” $\theta*$์™€ $\emptyset*$์„ ์ฐพ์œผ๋ฉด ๋œ๋‹ค.

์ด๋•Œ $\theta$๋Š” $\emptyset$์€ encoder ์†์˜ W,b์ด๊ณ , decoder ์†์˜ W,b์ด๋‹ค.

 

 

 

 

 


# Optimization

 

 Regularization์„ ๊ตฌํ•˜๋Š” ์‹์„ ๊ตฌํ•ด๋ณด์ž.

์ด๋•Œ ์œ„์™€ ๊ฐ™์€ ๊ฐ€์ • 2๊ฐœ๋ฅผ ์„ค์ •ํ•œ๋‹ค.

 

100% ์ดํ•ด๋œ๊ฑด ์•„๋‹Œ๊ฑฐ ๊ฐ™์€๋ฐ ์ผ๋‹จ ๋„˜์–ด๊ฐˆ๊ฒŒ..

 

 

L์„ 1์ด๋ผ๊ณ  ๊ฐ€์ •์„ ํ•ด๋ฒ„๋ฆผ. randomํ•˜๊ฒŒ ๋”ฑ ํ•˜๋‚˜์˜ sampling์„ ํ–ˆ๋‹ค๊ณ  ๊ฐ€์ •

 

 

 

 

ํ•ต์‹ฌ์€, ํ™•๋ฅ  ๋ถ„ํฌ๋ฅผ ๋ฒ ๋ฅด๋ˆ„์ด๋ผ๊ณ  ๊ฐ€์ •ํ•ด์„œ loss ๊ตฌํ•  ์ˆ˜ ์žˆ์Œ

 

 

ํ™•๋ฅ  ๋ถ„ํฌ๋ฅผ ๊ฐ€์šฐ์‹œ์•ˆ์œผ๋กœ ๊ฐ€์ •ํ•˜๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์€ loss๋ฅผ ๊ตฌํ•  ์ˆ˜ ์žˆ์Œ

 

 

์œ„์™€ ๊ฐ™์ด Decoder์˜ ๋ถ„ํฌ๋ฅผ ๋ฒ ๋ฅด๋ˆ„์ด๊ฐ€ ์•„๋‹Œ ๊ฐ€์šฐ์‹œ์•ˆ์œผ๋กœ๋„ ํ‘œํ˜„ํ•  ์ˆ˜๋„ ์žˆ์Œ

 

 

๋‹น์—ฐํžˆ latent space์˜ ์ฐจ์›์„ ๋Š˜๋ฆด์ˆ˜๋ก ์ด๋ฏธ์ง€๊ฐ€ ์ž˜ ๋ณต์›๋˜๋Š” ๊ฒƒ์„ ๋ณผ ์ˆ˜ ์žˆ์Œ

 

 

 

๋‹ค์Œ ์ฑ•ํ„ฐ์˜ GAN์ด VAE์— ๋น„ํ•ด ์ƒˆ๋กœ์šด ๋ฐ์ดํ„ฐ๋“ค์„ ๋” ์ž˜ ์ƒ์„ฑํ•œ๋‹ค.

๋”ฐ๋ผ์„œ, VAE๋Š” decoder ๋•Œ๋ฌธ์— ์ œ์•ˆ๋œ๊ฑด๋ฐ ์‹ค์ œ ์—ฐ๊ตฌ์—์„œ๋Š” encoder๋ฅผ ๋งŽ์ด ์‚ฌ์šฉํ•จ

 

 

๊ทธ๋ฆฌ๊ณ  ์ฝ”๋“œ๋Š” regularization๊ณผ reconstruction error๋ฅผ ๊ฐœ์„ ํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ์งค ์ˆ˜ ์žˆ์Œ