[Paper Review] NeRF : Representing Scenes as Neural Radiance Fields for View Synthesis (ECCV2020)
NeRF ๋ชจ๋ธ์ ๋ง์ ๋ธ๋ก๊ทธ์ ์ ํ๋ธ ์๋ฃ๋ฅผ ์ฐพ์๋ณด๋ฉฐ ์ดํดํ๋ ์์ค์ ๊ทธ์ณค๋๋ฐ ๋ ผ๋ฌธ์ ์ ๋ ํ๋ ํจ์ฌ ๋ ์ดํด ์ ๋๊ฐ ๊น์ด์ง ๊ธฐ๋ถ์ด๋ค. ์ง์ ๊ธ์ ์จ๋ณด๋ฉฐ ์๋ฒฝํ ๋ด ๊ฒ์ผ๋ก ๋ง๋ค์! ๋ค์๊ณผ ๊ฐ์
dusruddl2.tistory.com
โ ๋
ผ๋ฌธ ๋ฆฌ๋ทฐ๋ฅผ ํ์๋๋ฐ
์ ๋ง ๋ด๊ฐ NeRF ๋ชจ๋ธ์ ์ ๋๋ก ์ดํดํ๊ณ ์๋? ์๋ฌธ์ด ๋ค์ด์ ์ฐ๊ฒ ๋ ํฌ์คํธ
ํท๊ฐ๋ ธ๋ ๋ถ๋ถ ์์ฃผ๋ก ๊ฐ๋จ ๋ฆฌ๋ทฐํ ์์ ์ด๋ค.
nerf/tiny_nerf.ipynb at master ยท bmild/nerf
Code release for NeRF (Neural Radiance Fields). Contribute to bmild/nerf development by creating an account on GitHub.
github.com
์ฝ๋๋ NeRF ์ฝ๋ ์ค์์ tiny_nerf.ipynb๋ฅผ ์ฐธ๊ณ ํ์๋ค.
NeRF

NeRF๋ชจ๋ธ์ ์ฌ๋ฌ ๋ทฐ์์ ์ดฌ์ํ ์ด๋ฏธ์ง๋ค์ input์ผ๋ก ๋ฐ๊ณ
์ด๋ฅผ ํตํด ๊ฐ์์ 3D ๊ณต๊ฐ์ ๋ง๋๋ ๋ชจ๋ธ์ด๋ค.
์ฌ๊ธฐ์ ๊ฐ์์ 3D ๊ณต๊ฐ์ด๋ผ๊ณ ํํ์ ํ ์ด์ ๋,
point cloud, voxel, mesh์ ๊ฐ์ด ์ค์ ๋ก ์กด์ฌํ๋ 3D ๊ณต๊ฐ๊ณผ ๋ฌ๋ฆฌ
NeRF๋ชจ๋ธ์์๋ ์์์ ์์ ์์ ์ฐ์ 2D image๋ฅผ ๋ ๋๋งํ ์ ์๋ค๋ฉด 3D ๊ณต๊ฐ์ ์์ฑํ๋ค๋ผ๊ณ ์ ์ํ๊ธฐ ๋๋ฌธ์ด๋ค.
(์์ ๊ทธ๋ฆผ optimize NeRF์์ ๋๋ผ์ด ๋ฐํฌ๋ช
์ผ๋ก ํ์๋ ์ด์ ๊ฐ ์ค์ ๋ก ์กด์ฌํ์ง ์๊ธฐ ๋๋ฌธ์ ๊ทธ๋ ๋ค)
์๋ฅผ ๋ค์ด ์ค๋ช
ํ๋ฉด,
A, B, C ๋ฐฉํฅ์ผ๋ก ๋ฌผ์ฒด๋ฅผ ๋ฐ๋ผ๋ณธ ์ด๋ฏธ์ง๋ค์ input์ผ๋ก ๋ฃ์ด์
NeRF๋ฅผ ํ์ต์ํค๋ฉด
๊ทธ ์ด๋ ํ ์๋ก์ด ๋ฐฉํฅ(D,E,F,...)์์๋ ์ํ๋ ์ด๋ฏธ์ง๋ ์ป์ด๋ผ ์ ์๋ค!
์ฐ๋ฆฌ๋ ์ด๊ฒ์ 3D ๊ณต๊ฐ์ ์์ฑํ๋ค๊ณ ๋งํ๋ค. (explicit ๋ฐฉ๋ฒ๊ณผ ๋ฌ๋ฆฌ ์ค์ ๋ก ์กด์ฌํ๋๊ฒ x)
NeRF Overview

(a)
input image๊ฐ ๋ฌผ์ฒด๋ฅผ ์ด๋ ๋ฐฉํฅ์์ ๋ณด๊ณ ์๋์ง (Direction) ๊ตฌํ๋ค. (def get_rays)
(์ด ๊ณผ์ ์ ์ด๋ฏธ์ง์ ์นด๋ฉ๋ผ ํ๋ผ๋ฏธํฐ์ธ ํฌ์ฆ๊ฐ์ ํจ๊ป ์ด์ฉํ๋ค. -- ๋๋ค ์ฒ์์ ์ฃผ์ด์ง)

๋ฐฉํฅ์ ๋ณด๋ฅผ ๊ตฌํ๋ค๋๊ฑด ์นด๋ฉ๋ผ ์ขํ๊ณ(2D)์์ ์ค์ ๊ณต๊ฐ(3D)๋ก ์ด๋ํ๋ค๋ ๊ฒ์ ์๋ฏธํ๊ณ ,
์์ ๊ทธ๋ฆผ์ฒ๋ผ ๊ด์ ์ ์์น์ ๊ด์ ์ ๋ฐฉํฅ์ ์์๋๋ค๊ณ ์ดํดํ ์ ์๋ค.
(get_rays, rays_o, rays_d๋ ์ฝ๋์์ ๋์ด)

3D ๊ณต๊ฐ์ผ๋ก ์ฎ๊ฒจ์ง๋ฉด, ์ด์ ๊ด์ ์ ์์น๋ฅผ ๊ธฐ์ค์ผ๋ก N๊ฐ์ ์ ๋ค์ samplingํ๋ค.
๊ทธ๋ฆฌ๊ณ ๋ชจ๋ ์ ๋ค์ ๋ํด ๊ฐ๊ฐ (x,y,z)๋ก ์ด๋ป๊ฒ ํํํ๋์ง, ์์น ์ ๋ณด(Position)์ ๊ตฌํ๋ค.
(b)
(a)์์ ๊ตฌํ Direction๊ณผ Position์
MLP๋ฅผ ํต๊ณผ์์ผ Color(RGB)์ ๋ฐ๋๊ฐ(Density)์ ๊ตฌํ๋ค.
(c)
์ฐ๋ฆฌ๋ (b)์์ ๊ตฌํ Color์ Density๊ฐ์ด input image๋ฅผ ์ ํํํ๋์ง ํ์ธํด๋ด์ผ ํ๋ค.
๋ฐ๋ผ์, Color์ Density๋ฅผ ์ด์ฉํด 2D image๋ฅผ ๋ง๋ค์ด๋ด๋ ๊ณผ์ ์ธ Volume Rendering์ ์งํํ๋ค.
(โ ํท๊ฐ๋ ธ๋ ๋ถ๋ถ: volume rendering์ด ๋ฌด์จ ์ญํ ์ ํ๋์ง โ )
์์ overview (a)์์ input image ์ฌ์ง์ด ๋์ค๋ค ๋ณด๋๊น
์ด๋ฅผ MLP์ ๋ฃ์ด ํ์ตํ๋ค๊ณ ์ ์ ํผ๋์ด ์์๋ค.
๊ทธ๊ฒ ์๋๋ผ ์ฐ๋ฆฌ๋ input image์ ๋ฐฉํฅ์ ๋ณด๋ง์ ์๊ณ ์์ํ๋ค.
์ดํ์ ์ํ๋ง์ ํตํด ray์ ์ ๋ค์ ์์น์ ๋ณด๋ฅผ ๊ตฌํ๊ณ
์ด๋ฅผ MLP์ ํต๊ณผ์์ผ ์ปฌ๋ฌ๊ฐ๊ณผ ๋ฐ๋๊ฐ์ ๊ตฌํ ๊ฒ์ด๋ค.
๊ทธ๋ฌ๋๊น GT(Ground Truth) ์ด๋ฏธ์ง๋ input image๊ฐ ๋๋ ๊ฒ์ด๊ณ
์ฐ๋ฆฌ๋ MLP๋ฅผ ํตํด ๊ตฌํ ์ปฌ๋ฌ๊ฐ๊ณผ ๋ฐ๋๊ฐ์ ์ด์ฉํด 2D image๋ฅผ ๋ง๋ค์ด
๋์ ๋น๊ตํ๋ ๋ฐฉํฅ์ผ๋ก MLP์ weight๋ฅผ ์ ๋ฐ์ดํธํ๋ ๊ฒ์ด๋ค.
Volume Rendering์ ์๋ ์์์ ํตํด
๊ฐ ray์ ์๋ sampling๋ ์ ๋ค์ color๊ฐ๊ณผ density๊ฐ์ ๋์ ํฉํ์ฌ ํ๋์ color๊ฐ์ ๊ตฌํ๋ ๊ฒ์ด๋ค.
(์์์ ์์ธํ ์ค๋ช
์ ๋
ผ๋ฌธ๋ฆฌ๋ทฐ ์ฐธ๊ณ )

(โ ํท๊ฐ๋ ธ๋ ๋ถ๋ถ: ์์ ์์์์๋ color๊ฐ์ด ํ๋๋ง ๋์ค๋๋ฐ ์ด๊ฑธ๋ก ์ด๋ป๊ฒ 2D ์ด๋ฏธ์ง๋ฅผ ๋ง๋ค์ด? โ )
(a)์ (b)์ ๋จ๊ณ๋ ๋ค ํฝ์ ๋จ์๋ก ์ด๋ฃจ์ด์ง๋ค.
๋ฐ๋ผ์, ๊ฐ๊ฐ์ ํฝ์ ์ ๋ํด์ volume rendering์ ํตํด ์ต์ข color๊ฐ์ ๊ตฌํ๊ฒ ๋๋ ๊ฒ์ด๊ณ
์ด๋ฅผ ๋ค ๋ชจ์ผ๋ฉด ํ๋์ 2D ์ด๋ฏธ์ง๊ฐ ๋๋ ๊ฒ์ด๋ค.
(์ฝ๋์์ ray_o์ ray_d์ shape์ ํตํด ์ด๋ฅผ ์ฆ๋ช ํ ์ ์๋ค.)
(d)
GT์ด๋ฏธ์ง์
Volume Rendering๋ 2D ์ด๋ฏธ์ง ์ฌ์ด์ loss๊ฐ์ ๊ตฌํ๊ณ
์ด๋ฅผ ์ค์ด๋ ๋ฐฉํฅ์ผ๋ก weight๋ฅผ ์
๋ฐ์ดํธํ๋ค.
์ฝ๋๋ก ์ดํดํ๊ธฐ
Load Input Images and Poses
data = np.load('tiny_nerf_data.npz')
images = data['images']
poses = data['poses']
focal = data['focal']
H, W = images.shape[1:3]
print(images.shape, poses.shape, focal)
testimg, testpose = images[101], poses[101]
images = images[:100,...,:3]
poses = poses[:100]
plt.imshow(testimg)
plt.show()

input image๋ค ์ค ํ๋์ธ ์์ ์ด๋ฏธ์ง๋ฅผ ๊ธฐ์ค์ผ๋ก ์ค๋ช
ํ๊ฒ ๋ค.
์ด๋ฏธ์ง์ ์ฐจ์(shape): B = 106, H = 100, W = 100, C = 3
์นด๋ฉ๋ผ ํ๋ผ๋ฏธํฐ pose์ ์ฐจ์(shape): B = 106, 4x4 ๋ณํ ํ๋ ฌ
> ๋ณํํ๋ ฌ์ ์นด๋ฉ๋ผ์ ์์น์ ๋ฐฉํฅ์ ์ ์
(a)
def get_rays(H, W, focal, c2w):
i, j = tf.meshgrid(tf.range(W, dtype=tf.float32), tf.range(H, dtype=tf.float32), indexing='xy')
dirs = tf.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -tf.ones_like(i)], -1)
rays_d = tf.reduce_sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)
rays_o = tf.broadcast_to(c2w[:3,-1], tf.shape(rays_d))
return rays_o, rays_d
ํจ์ get_rays๋ ์ด๋ฏธ์ง๋ฅผ ํตํด Direction๋ฅผ ์ป๋ ์ฝ๋์ด๋ค.
+ ํธ๋ญ ์ด๋ฏธ์ง๋ฅผ get_rays์ ํต๊ณผ์ํจ ๊ฒฐ๊ณผ์ shape์ ์๊ฐํด๋ณด๋ฉด,
rays_o, rays_d = get_rays(H,W,focal,pose)
rays_o์ ์ฐจ์(shape): (106, H, W, 3)
rays_d์ ์ฐจ์(shape): (106, H, W, 3)
>> RGB ๊ฐ๊ฐ์ ๋ํด์๋ (106, H, W)์ด ๋์ผํ ๊ฒ
์ด๋ ๋ชจ๋ ํฝ์
๋จ์๋ก ์ด๋ฃจ์ด์ง๋ค๋ ๊ฒ์ ํ์ธํด๋ณผ ์ ์๋ค.
(a) & (b) & (c)
def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, rand=False):
def batchify(fn, chunk=1024*32):
return lambda inputs : tf.concat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
#-----------(a) ์ํ๋ง์ ํตํด Position ๊ตฌํ๊ธฐ
# Compute 3D query points
z_vals = tf.linspace(near, far, N_samples) # N_samples๊ฐ์ ์ ๋ค์ ์ํ๋ง
if rand:
z_vals += tf.random.uniform(list(rays_o.shape[:-1]) + [N_samples]) * (far-near)/N_samples
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] #๊ฐ ์ํ ํฌ์ธํธ์ ์์น์ ๋ณด๋ฅผ ์ป์
#-----------(b) MLP๋ฅผ ๊ฑฐ์ณ Color & Density ๊ตฌํ๊ธฐ
# Run network
pts_flat = tf.reshape(pts, [-1,3])
pts_flat = embed_fn(pts_flat)
raw = batchify(network_fn)(pts_flat)
raw = tf.reshape(raw, list(pts.shape[:-1]) + [4])
# Compute opacities and colors
sigma_a = tf.nn.relu(raw[...,3])
rgb = tf.math.sigmoid(raw[...,:3])
#-----------(c) ๋ณผ๋ฅจ๋ ๋๋ง Volume Rendering
# Do volume rendering
dists = tf.concat([z_vals[..., 1:] - z_vals[..., :-1], tf.broadcast_to([1e10], z_vals[...,:1].shape)], -1)
alpha = 1.-tf.exp(-sigma_a * dists)
weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
rgb_map = tf.reduce_sum(weights[...,None] * rgb, -2)
depth_map = tf.reduce_sum(weights * z_vals, -1)
acc_map = tf.reduce_sum(weights, -1)
return rgb_map, depth_map, acc_map
ํจ์ render_rays์๋
(a)์์์ Position์ ๊ตฌํ๋ ๊ณผ์ ๊ณผ
(b)์์์ Color & Density๋ฅผ ๊ตฌํ๋ ๊ณผ์ ๊ณผ
(c) ๋ณผ๋ฅจ๋ ๋๋ง(Volume Rendering)์ด ๋ชจ๋ ํฌํจ๋์ด ์๋ค.
(d)
model = init_model()
optimizer = tf.keras.optimizers.Adam(5e-4)
N_samples = 64
N_iters = 1000
psnrs = []
iternums = []
i_plot = 25
import time
t = time.time()
for i in range(N_iters+1):
img_i = np.random.randint(images.shape[0])
target = images[img_i]
pose = poses[img_i]
rays_o, rays_d = get_rays(H, W, focal, pose)
#-----------(d) loss๋ฅผ ๊ตฌํด gradient ์
๋ฐ์ดํธ
with tf.GradientTape() as tape:
rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples, rand=True)
loss = tf.reduce_mean(tf.square(rgb - target))
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
if i%i_plot==0:
print(i, (time.time() - t) / i_plot, 'secs per iter')
t = time.time()
# Render the holdout view for logging
rays_o, rays_d = get_rays(H, W, focal, testpose)
rgb, depth, acc = render_rays(model, rays_o, rays_d, near=2., far=6., N_samples=N_samples)
loss = tf.reduce_mean(tf.square(rgb - testimg))
psnr = -10. * tf.math.log(loss) / tf.math.log(10.)
psnrs.append(psnr.numpy())
iternums.append(i)
plt.figure(figsize=(10,4))
plt.subplot(121)
plt.imshow(rgb)
plt.title(f'Iteration: {i}')
plt.subplot(122)
plt.plot(iternums, psnrs)
plt.title('PSNR')
plt.show()
print('Done')
์์ ์ฝ๋๋ฅผ ํตํด epoch๋ฅผ 25 ๊ฐ๊ฒฉ์ผ๋ก ๊ฒฐ๊ณผ๋ฅผ ๊ณ์ ์ถ๋ ฅํด๋ณด๋ฉด
์๋ ์ฌ์ง๊ณผ ๊ฐ์ด ํ๋ จ์ ๋ ๋ง์ด ํจ์ ๋ฐ๋ผ ํ๋ฆฌํฐ๊ฐ ์ข์์ง์ ํ์ธํ ์ ์๋ค.
(ํ์ต ๊ณผ์ ์ค ์ผ๋ถ ๋ฐ์ท)




๋ฌผ๋ก ์ด ๊ธ์ ์ฐ๋ฉฐ ์ฐธ๊ณ ํ ์ฝ๋๋ tiny_nerf.ipynb๋ก
์ต์ข
nerf๋ชจ๋ธ๊ณผ๋ ์ผ๋ถ ์ฐจ์ด๊ฐ ์์ง๋ง
๊ทธ๋๋ ํท๊ฐ๋ ธ๋ ๋ถ๋ถ์ ์ดํดํ๋๋ฐ๋ ํฐ ๋์์ด ๋์๋ค :)
(2024.05.12 ์์ฑ)