
A Quick Explanation of NeRF, with a Bit of Code

์œฐ๊ฐฑ 2024. 5. 12. 06:26

[Paper Review] NeRF : Representing Scenes as Neural Radiance Fields for View Synthesis (ECCV2020)


 ↑  I wrote a full paper review earlier,
but I started to wonder whether I really understood the NeRF model, which is why I am writing this post.
It is a quick review focused on the parts that confused me.
 

nerf/tiny_nerf.ipynb at master ยท bmild/nerf


The code below follows tiny_nerf.ipynb from the NeRF repository.
 


NeRF

The NeRF model takes images captured from multiple viewpoints as input
and uses them to build a virtual 3D space.
 
I call it a virtual 3D space because,
unlike explicit representations such as point clouds, voxels, and meshes,
NeRF defines "having generated a 3D space" as being able to render a 2D image from an arbitrary viewpoint.
(This is why the drum set is drawn semi-transparently in the "optimize NeRF" figure above: the 3D scene never actually exists.)
 
For example:
if we train NeRF on input images that view an object from directions A, B, and C,
we can then obtain the image we want from any new direction (D, E, F, ...)!
This is what we mean by having generated the 3D space. (Unlike explicit methods, nothing actually exists.)
 
 


NeRF Overview

 

(a)

Compute the direction from which each input image views the object (Direction) (`def get_rays`).
(This step uses the image together with the camera pose parameters -- both are given up front.)

 
Obtaining the direction information means we have moved from the camera coordinate system (2D) into real space (3D),
and, as in the figure above, it means we now know each ray's origin and direction.
(`get_rays`, `rays_o`, and `rays_d` appear in the code.)
 

 
Once we are in 3D space, N points are sampled along each ray, starting from the ray origin.
For every sampled point we then compute its position information (Position), i.e. how it is expressed as (x, y, z).
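As a minimal sketch of this sampling step (the ray origin, direction, and near/far values below are made up for illustration), each sample position is just p(t) = o + t·d:

```python
import numpy as np

# Hypothetical single ray: origin and (unnormalized) direction
ray_o = np.array([0.0, 0.0, 4.0])   # camera position in world space
ray_d = np.array([0.0, 0.0, -1.0])  # viewing direction

# Sample N depths uniformly between the near and far planes
near, far, N_samples = 2.0, 6.0, 5
z_vals = np.linspace(near, far, N_samples)

# Position of each sample: p(t) = o + t * d
pts = ray_o[None, :] + z_vals[:, None] * ray_d[None, :]
print(pts)  # each row is the (x, y, z) of one sampled point along the ray
```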
 

(b)

The Direction and Position obtained in (a)
are passed through an MLP to produce a Color (RGB) and a density value (Density).
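To make the MLP's input and output shapes concrete, here is a toy NumPy stand-in with hypothetical weights (tiny_nerf's real network is an 8-layer, 256-unit Dense model; the 39-dim input assumes its positional encoding with L_embed = 6):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 2-layer stand-in for the NeRF MLP: encoded position in, 4 raw values out.
# This only illustrates the shapes, not the real architecture.
D_in, D_hidden = 39, 64                      # 39 = positionally-encoded (x, y, z)
W1 = rng.normal(size=(D_in, D_hidden)) * 0.1
b1 = np.zeros(D_hidden)
W2 = rng.normal(size=(D_hidden, 4)) * 0.1
b2 = np.zeros(4)

def mlp(x):
    h = np.maximum(x @ W1 + b1, 0.0)         # ReLU hidden layer
    return h @ W2 + b2                       # raw (r, g, b, sigma) per point

raw = mlp(rng.normal(size=(8, D_in)))        # 8 encoded sample points
rgb = 1.0 / (1.0 + np.exp(-raw[:, :3]))      # sigmoid -> colors in [0, 1]
sigma = np.maximum(raw[:, 3], 0.0)           # ReLU -> non-negative density
print(rgb.shape, sigma.shape)                # (8, 3) (8,)
```

The sigmoid/ReLU activations at the end mirror what render_rays applies to the raw network output.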
 
 

(c)

We need to check whether the Color and Density values obtained in (b) represent the input image well.
So we perform Volume Rendering, the process of producing a 2D image from the Color and Density values.
 

(★ A part that confused me: what role does volume rendering play? ★)

์œ„์˜ overview (a)์—์„œ input image ์‚ฌ์ง„์ด ๋‚˜์˜ค๋‹ค ๋ณด๋‹ˆ๊นŒ
์ด๋ฅผ MLP์— ๋„ฃ์–ด ํ•™์Šตํ–ˆ๋‹ค๊ณ  ์ž ์‹œ ํ˜ผ๋ž€์ด ์™”์—ˆ๋‹ค.

๊ทธ๊ฒŒ ์•„๋‹ˆ๋ผ ์šฐ๋ฆฌ๋Š” input image์˜ ๋ฐฉํ–ฅ์ •๋ณด๋งŒ์„ ์•Œ๊ณ  ์‹œ์ž‘ํ•œ๋‹ค.
์ดํ›„์— ์ƒ˜ํ”Œ๋ง์„ ํ†ตํ•ด ray์œ„ ์ ๋“ค์˜ ์œ„์น˜์ •๋ณด๋ฅผ ๊ตฌํ•˜๊ณ 
์ด๋ฅผ MLP์— ํ†ต๊ณผ์‹œ์ผœ ์ปฌ๋Ÿฌ๊ฐ’๊ณผ ๋ฐ€๋„๊ฐ’์„ ๊ตฌํ•œ ๊ฒƒ์ด๋‹ค.

๊ทธ๋Ÿฌ๋‹ˆ๊นŒ GT(Ground Truth) ์ด๋ฏธ์ง€๋Š” input image๊ฐ€ ๋˜๋Š” ๊ฒƒ์ด๊ณ 
์šฐ๋ฆฌ๋Š” MLP๋ฅผ ํ†ตํ•ด ๊ตฌํ•œ ์ปฌ๋Ÿฌ๊ฐ’๊ณผ ๋ฐ€๋„๊ฐ’์„ ์ด์šฉํ•ด 2D image๋ฅผ ๋งŒ๋“ค์–ด
๋‘˜์„ ๋น„๊ตํ•˜๋Š” ๋ฐฉํ–ฅ์œผ๋กœ MLP์˜ weight๋ฅผ ์—…๋ฐ์ดํŠธํ•˜๋Š” ๊ฒƒ์ด๋‹ค.

 
Volume Rendering uses the equation below to
accumulate the color and density values of the sampled points on each ray into a single color value.
(See my full paper review for a detailed explanation of the equation.)
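For reference, the discrete form of the equation from the paper, which the volume-rendering code in render_rays implements, is:

```latex
\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i \,\bigl(1 - e^{-\sigma_i \delta_i}\bigr)\, \mathbf{c}_i,
\qquad
T_i = \exp\!\Bigl(-\sum_{j=1}^{i-1} \sigma_j \delta_j\Bigr),
\qquad
\delta_i = t_{i+1} - t_i
```

Here \(1 - e^{-\sigma_i \delta_i}\) is `alpha` in the code, the exclusive cumulative product is `tf.math.cumprod(..., exclusive=True)`, and their product is `weights`.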
 

 

(★ A part that confused me: the equation above produces only a single color value, so how does that make a 2D image? ★)

Steps (a) and (b) are all carried out per pixel.
So volume rendering produces a final color value for each individual pixel,
and collecting all of them yields one 2D image.

(You can verify this from the shapes of rays_o and rays_d in the code.)
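Here is a small NumPy sketch of this per-pixel compositing, using made-up densities and colors for one ray with four samples (the same alpha/cumprod logic as the render_rays code below):

```python
import numpy as np

# One ray, 4 samples: densities, per-sample colors, and sample depths.
sigma = np.array([0.0, 5.0, 0.0, 0.0])           # a dense "surface" at sample 1
colors = np.array([[1, 0, 0], [0, 1, 0],
                   [0, 0, 1], [1, 1, 1]], float)
z_vals = np.array([2.0, 3.0, 4.0, 5.0])

dists = np.append(z_vals[1:] - z_vals[:-1], 1e10)  # distance between samples
alpha = 1.0 - np.exp(-sigma * dists)               # per-sample opacity
# T_i: probability the ray reaches sample i without being absorbed earlier
T = np.cumprod(np.append(1.0, 1.0 - alpha[:-1]))
weights = alpha * T
pixel_color = (weights[:, None] * colors).sum(0)   # one final color per pixel
print(pixel_color)
```

The result is dominated by the green sample, because the dense point at sample 1 absorbs the ray before it reaches the samples behind it.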

 

(d)

Compute the loss between the GT image
and the volume-rendered 2D image,
and update the weights in the direction that reduces it.
 




Understanding Through the Code

Load Input Images and Poses

import numpy as np
import matplotlib.pyplot as plt

data = np.load('tiny_nerf_data.npz')
images = data['images']
poses = data['poses']
focal = data['focal']
H, W = images.shape[1:3]
print(images.shape, poses.shape, focal)

testimg, testpose = images[101], poses[101]
images = images[:100,...,:3]
poses = poses[:100]

plt.imshow(testimg)
plt.show()

 
I will explain using the image above, one of the input images, as the running example.

Image shape: B = 106, H = 100, W = 100, C = 3

 

Camera-parameter pose shape: B = 106, each a 4x4 transformation matrix

> The transformation matrix defines the camera's position and orientation.
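A sketch of one such matrix with hypothetical values: the 3x3 rotation block orients the camera, and the last column is its world position, which becomes the shared ray origin rays_o:

```python
import numpy as np

# Sketch of one camera-to-world (c2w) pose matrix: [R | t] over [0 0 0 1].
c2w = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 4.0],   # camera 4 units along +z, looking down -z
    [0.0, 0.0, 0.0, 1.0],
])

R = c2w[:3, :3]   # rotates camera-frame ray directions into world space
t = c2w[:3, -1]   # position shared by every ray from this camera (rays_o)
print(t)          # [0. 0. 4.]
```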

 

(a)

def get_rays(H, W, focal, c2w):
    # Pixel grid: i indexes columns (x), j indexes rows (y)
    i, j = tf.meshgrid(tf.range(W, dtype=tf.float32), tf.range(H, dtype=tf.float32), indexing='xy')
    # Per-pixel ray direction in the camera frame (pinhole model; the camera looks down -z)
    dirs = tf.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -tf.ones_like(i)], -1)
    # Rotate the directions into world space with the camera-to-world rotation
    rays_d = tf.reduce_sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
    # Every ray of this image shares the camera position as its origin
    rays_o = tf.broadcast_to(c2w[:3,-1], tf.shape(rays_d))
    return rays_o, rays_d

The function get_rays obtains the ray directions for an image.
 
+ Consider the shapes we get when a single image's pose is passed through get_rays:

rays_o, rays_d = get_rays(H, W, focal, pose)
rays_o shape: (H, W, 3) = (100, 100, 3)
rays_d shape: (H, W, 3) = (100, 100, 3)

>> one 3-component ray origin and one 3-component ray direction per pixel

This confirms that everything really is done per pixel.
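This can be checked with a NumPy re-sketch of get_rays (the original is TensorFlow; the focal value below is just illustrative): for a single pose, every pixel gets exactly one ray origin and one ray direction.

```python
import numpy as np

def get_rays_np(H, W, focal, c2w):
    # NumPy version of get_rays, identical math to the TF code above
    i, j = np.meshgrid(np.arange(W, dtype=np.float32),
                       np.arange(H, dtype=np.float32), indexing='xy')
    dirs = np.stack([(i - W * .5) / focal,
                     -(j - H * .5) / focal,
                     -np.ones_like(i)], -1)
    rays_d = np.sum(dirs[..., None, :] * c2w[:3, :3], -1)
    rays_o = np.broadcast_to(c2w[:3, -1], rays_d.shape)
    return rays_o, rays_d

H, W, focal = 100, 100, 138.0        # focal is a made-up placeholder
c2w = np.eye(4)                      # identity pose: camera at the origin
rays_o, rays_d = get_rays_np(H, W, focal, c2w)
print(rays_o.shape, rays_d.shape)    # (100, 100, 3) (100, 100, 3)
```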
 

(a) & (b) & (c)

def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, rand=False):

    def batchify(fn, chunk=1024*32):
        return lambda inputs : tf.concat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    
    #-----------(a) Get Position by sampling
    # Compute 3D query points
    z_vals = tf.linspace(near, far, N_samples) # sample N_samples depths along each ray
    if rand:
      z_vals += tf.random.uniform(list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # position of each sample point
    
    
    #-----------(b) Get Color & Density via the MLP
    # Run network
    pts_flat = tf.reshape(pts, [-1,3])
    pts_flat = embed_fn(pts_flat)
    raw = batchify(network_fn)(pts_flat)
    raw = tf.reshape(raw, list(pts.shape[:-1]) + [4])
    
    # Compute opacities and colors
    sigma_a = tf.nn.relu(raw[...,3])
    rgb = tf.math.sigmoid(raw[...,:3]) 
    
    #-----------(c) Volume Rendering
    # Do volume rendering
    dists = tf.concat([z_vals[..., 1:] - z_vals[..., :-1], tf.broadcast_to([1e10], z_vals[...,:1].shape)], -1) 
    alpha = 1.-tf.exp(-sigma_a * dists)  
    weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
    
    rgb_map = tf.reduce_sum(weights[...,None] * rgb, -2) 
    depth_map = tf.reduce_sum(weights * z_vals, -1) 
    acc_map = tf.reduce_sum(weights, -1)

    return rgb_map, depth_map, acc_map

The function render_rays contains all three steps:
the Position computation from (a),
the Color & Density computation from (b),
and the volume rendering from (c).
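Note that render_rays also calls embed_fn, which is not shown in this excerpt; in tiny_nerf it is a positional encoding with L_embed = 6, which can be sketched in NumPy as:

```python
import numpy as np

L_embed = 6  # number of frequency bands (tiny_nerf's value)

def embed_fn(x):
    # x: (N, 3) sample positions -> (N, 3 + 3*2*L_embed) = (N, 39)
    out = [x]
    for i in range(L_embed):
        for fn in (np.sin, np.cos):
            out.append(fn(2.0**i * x))  # sin/cos at frequency 2^i
    return np.concatenate(out, axis=-1)

pts = np.zeros((4, 3))
print(embed_fn(pts).shape)  # (4, 39)
```

This encoding is what expands each (x, y, z) point into the higher-dimensional input the MLP actually sees.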
 

(d)

model = init_model()
optimizer = tf.keras.optimizers.Adam(5e-4)

N_samples = 64
N_iters = 1000
psnrs = []
iternums = []
i_plot = 25

import time
t = time.time()
for i in range(N_iters+1):
    
    img_i = np.random.randint(images.shape[0])
    target = images[img_i]
    pose = poses[img_i]
    rays_o, rays_d = get_rays(H, W, focal, pose)
  #-----------(d) Compute the loss and update the gradients
    with tf.GradientTape() as tape:
        rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples, rand=True)
        loss = tf.reduce_mean(tf.square(rgb - target))
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    if i%i_plot==0:
        print(i, (time.time() - t) / i_plot, 'secs per iter')
        t = time.time()
        
        # Render the holdout view for logging
        rays_o, rays_d = get_rays(H, W, focal, testpose)
        rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
        loss = tf.reduce_mean(tf.square(rgb - testimg))
        psnr = -10. * tf.math.log(loss) / tf.math.log(10.)

        psnrs.append(psnr.numpy())
        iternums.append(i)
        
        plt.figure(figsize=(10,4))
        plt.subplot(121)
        plt.imshow(rgb)
        plt.title(f'Iteration: {i}')
        plt.subplot(122)
        plt.plot(iternums, psnrs)
        plt.title('PSNR')
        plt.show()

print('Done')

 
์œ„์˜ ์ฝ”๋“œ๋ฅผ ํ†ตํ•ด epoch๋ฅผ 25 ๊ฐ„๊ฒฉ์œผ๋กœ ๊ฒฐ๊ณผ๋ฅผ ๊ณ„์† ์ถœ๋ ฅํ•ด๋ณด๋ฉด
์•„๋ž˜ ์‚ฌ์ง„๊ณผ ๊ฐ™์ด ํ›ˆ๋ จ์„ ๋” ๋งŽ์ด ํ•จ์— ๋”ฐ๋ผ ํ€„๋ฆฌํ‹ฐ๊ฐ€ ์ข‹์•„์ง์„ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.
(ํ•™์Šต ๊ณผ์ • ์ค‘ ์ผ๋ถ€ ๋ฐœ์ทŒ)
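The PSNR logged above is just -10·log10(MSE), which works here because the pixel values lie in [0, 1] (so the peak value is 1):

```python
import numpy as np

# PSNR from MSE for images with pixel values in [0, 1]:
# PSNR = -10 * log10(MSE). A hypothetical MSE of 0.01 gives 20 dB.
mse = 0.01
psnr = -10.0 * np.log10(mse)
print(psnr)  # 20.0
```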

 
 


 
 

 
Of course, the code I referenced for this post is tiny_nerf.ipynb,
which differs somewhat from the final NeRF model (e.g., it omits the view-direction input and hierarchical sampling),
but it still helped a lot in understanding the parts that confused me :)
 
(Written 2024.05.12)