💫 Computer Science/Python & AI Framework

[Pytorch] pytorch Contiguous

minkyung 2022. 8. 3. 21:52

 

Contiguous

Contiguous(인접한, 근접한)는 단어의 뜻처럼 Tensor의 각 값들이 메모리에도 순차적으로 저장되어 있는지 여부를 의미한다.

 

[0, 1, 2, 3, 4]라는 Tensor 가 있을 때, 메모리에 저장된 모양이

이와 같으면 contiguous한 것이고

 

이런식으로 요소들이 메모리에 연속적으로 저장되어 있지 않으면 contiguous하지 않은 것이다.

 

>>> t = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
>>> t.is_contiguous()
True
>>> t.stride()
(4, 1)

 

t라는 Tensor는 처음에는 위와 같이 메모리에 저장되어 있을 것인데,

>>> t = t.transpose()
>>> t.stride()
(1, 4)
>>> t.is_contiguous()
False

transpose()연산을 해주면

 

이렇게 된다.

t.stride()에서 결과가 (1, 4)라는 것은 t(0, 0)에서 t(1, 0)으로갈 때 요소 1개만큼 메모리 주소가 이동하고 t(0, 0)에서 t(0, 1)로 이동할 때 요소 4만큼 메모리 주소가 바뀐다는 뜻이다.

 

 

Pytorch에서 자주 사용하는 view()나  reshape()같은 함수를 보면 비슷한 역할을 하는데 굳이 나눠져 있는 것을 볼 수 있다.

이는 Contiguous한 Tensor를 Input으로 받을 수 있는지, output이 Contiguous Tensor인지에 대한 차이 때문이라고 볼 수 있다.

 

예를 들어 view()는 Contiguous한 Tensor를 받고 Contiguous한 Tensor를 리턴하고,

transpose()는 Non-contiguous Tensor를 받을 순 있지만 리턴하는 Tensor는 Non-contiguous하다.

 

 

pytorch method 복사본과 메모리 공유
(새로운 Tensor 생성 X)
Contiguous tensor에서 동작 Non-Contiguous
tensor에서 동작
반환 Tensor
view() O O X Contiguous
reshape() O O O Contiguous
same as contiguous().view()
transpose() O O O Non-Contiguous
permute() O O O Non-Contiguous
narrow() O O O Non-Contiguous
expand() O O O Non-Contiguous

 

bbb = aaa.transpose(0,1)
print(bbb.stride())
print(bbb.is_contiguous())

#(1, 3)
#False


ccc = aaa.narrow(1,1,2)   ## equivalent to matrix slicing aaa[:,1:3]
print(ccc.stride())
print(ccc.is_contiguous())

#(3, 1)
#False


ddd = aaa.repeat(2,1)   # The first dimension repeat once, the second dimension repeat twice
print(ddd.stride())
print(ddd.is_contiguous())

#(3, 1)
#True


## expand is different from repeat.
## if a tensor has a shape [d1,d2,1], it can only be expanded using "expand(d1,d2,d3)", which
## means the singleton dimension is repeated d3 times
eee = aaa.unsqueeze(2).expand(2,3,3)
print(eee.stride())
print(eee.is_contiguous())

#(3, 1, 0)
#False


fff = aaa.unsqueeze(2).repeat(1,1,8).view(2,-1,2)
print(fff.stride())
print(fff.is_contiguous())

 

 

복사본과 메모리를 공유하지 않는 다는 것은 연산 후 결과를 새로운 Tensor로 생성하지 않는 다는 것을 의미한다. 즉, Tensor object의 새로운 shape을 설명하기 위한 offset과 stride가 포함된 meta information 만 modify 함을 의미한다. 따라서 반환된 값은 원본과 data(memory)를 공유하기 때문에 둘 중 하나만 수정해도 두개 모두 수정된다.

 

x = torch.randn(3,2)
y = torch.transpose(x, 0, 1)
x[0, 0] = 42
print(y[0,0])
# prints 42

z = y.contiguous()

x는 contiguous하지만 y는 contiguous하지 않고 x와 메모리를 공유한다.  copy를 만들어서 메모리에 요소들이 순서대로 저장되게 하고 싶다면 contiguous()를 이용하면 된다.

 

 

 

 

References

 

https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html

https://stackoverflow.com/questions/48915810/what-does-contiguous-do-in-pytorch