최근 pytorch의 nn.functional 모듈에 있는 grid_sample이라는 함수를 사용할일이 있었는데, 어떻게 동작하는지 이해가 잘 안되서 하나씩 정리해보려고 한다.
Parameters
- input : 변환할 이미지 - (N,C,H_in,W_in)
- grid : flow-field라고도 부르고, 어떤 포인트가 어디로 옮겨질지에 대한 부분을 나타낸다. output의 포인트들에 대한 x,y위치를 나타낸다. - (N,H_out,W_out,2)
- mode (default = 'bilinear') : 중간값에 대한 부분을 어떤 알고리즘을 활용해서 적용할지 나타낸다. 사용할 수 있는 모드는 ('bilinear' | 'nearest' | 'bicubic')가 있다.
- padding_mode(default = 'zeros') : input 이외의 영역을 어떻게 채울지를 나타낸다. 사용할 수 있는 패딩은 ('zeros' | 'border' | 'reflection')가 있다.
- align_corners(default = None) : boolean 값으로 True, False 값을 가진다.
동작 원리
<그림1>은 mode가 bilinear, padding_mode은 zeros, align_corners 는 False인 예시를 보여준다. Input은 4x4 사이즈로, grid는 8x8사이즈로 테스트를 진행해 보았다.
여기서 하늘색 부분이 grid를 나타내는데 이 부분이 실질적으로 output의 결과를 보여주는 위치를 의미하게 된다. 그리고 빨간색 라인은 input value를 의미하고, 이 위치는 align_corners에 의해서 위치가 바뀌게 된다. 마지막 초록색 점이 변환한 결과를 나타내게 된다.
(1,1)위치에 초록색 점에 대한 값을 하나를 예시로 계산해보면, 0,1,4,5의 숫자들을 거리 기준 비율로 bilinear interpolation하여 결과를 추출하게 된다. 초록색 점은 (-0.625, -0.625)인데, 이를 0,1,4,5 input 숫자들 기준으로 비율을 나타내면, 각각 1:3 비율에 위치하게 된다. 이를 bilinear interpolation 방식을 활용해 계산하면 손쉽게 값을 얻을 수 있게된다.
padding_mode
<그림2>에서 보면, Out(1,1) 점은 위 예시에서 보였던 것 처럼, 둘러쌓인 0,1,4,5 숫자들의 interpolation으로 결과를 추출하였다. 그렇다면 Out(0,0)은 어떤 숫자들로 계산할 수 있을지 정하는 것이 바로 padding_mode이다. 기본적으로는 zeros로 채워져 있어, 위 예시에서output의 0,0 값은 0으로 계산됨을 확인할 수 있다.
align_corners
align_corners는 input의 위치가 변화하는 걸 확인할 수 있다. 왼쪽은 align_corners가 True일때, 오른쪽은 align_corners가 False일때를 의미한다. 계산되는 방식이 바뀌기 때문에 결과 또한 바뀌게 된다.
실험 코드
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
input_tensor = torch.linspace(0,4*4-1,steps=4*4)
input_tensor =input_tensor.reshape(1,1,4,4)
def imshow(tensor):
plt.imshow(tensor[0, 0, :, :], cmap='gray')
plt.axis('off')
plt.show()
print("Original Image:")
imshow(input_tensor)
# 회전 및 이동 변환
theta = np.radians(0)
rotation_matrix = torch.tensor([[np.cos(theta), np.sin(theta),0],
[-np.sin(theta), np.cos(theta),0]], dtype=torch.float32)
# 변환을 grid로 변경
grid = F.affine_grid(rotation_matrix.unsqueeze(0), (1,1,8,8))
output_tensor = F.grid_sample(input_tensor, grid, align_corners=False) # align_corners=True
print("Transformed Image:")
imshow(output_tensor)