ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [Pytorch] pytorch Contiguous
    💫 Computer Science/Python & AI Framework 2022. 8. 3. 21:52

     

    Contiguous

    Contiguous(인접한, 근접한)는 단어의 뜻처럼 Tensor의 각 값들이 메모리에도 순차적으로 저장되어 있는지 여부를 의미한다.

     

    [0, 1, 2, 3, 4]라는 Tensor 가 있을 때, 메모리에 저장된 모양이

    이와 같으면 contiguous한 것이고

     

    이런식으로 요소들이 메모리에 연속적으로 저장되어 있지 않으면 contiguous하지 않은 것이다.

     

    >>> t = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])
    >>> t.is_contiguous()
    True
    >>> t.stride()
    (4, 1)

     

    t라는 Tensor는 처음에는 위와 같이 메모리에 저장되어 있을 것인데,

    >>> t = t.transpose()
    >>> t.stride()
    (1, 4)
    >>> t.is_contiguous()
    False

    transpose()연산을 해주면

     

    이렇게 된다.

    t.stride()에서 결과가 (1, 4)라는 것은 t(0, 0)에서 t(1, 0)으로갈 때 요소 1개만큼 메모리 주소가 이동하고 t(0, 0)에서 t(0, 1)로 이동할 때 요소 4만큼 메모리 주소가 바뀐다는 뜻이다.

     

     

    Pytorch에서 자주 사용하는 view()나  reshape()같은 함수를 보면 비슷한 역할을 하는데 굳이 나눠져 있는 것을 볼 수 있다.

    이는 Contiguous한 Tensor를 Input으로 받을 수 있는지, output이 Contiguous Tensor인지에 대한 차이 때문이라고 볼 수 있다.

     

    예를 들어 view()는 Contiguous한 Tensor를 받고 Contiguous한 Tensor를 리턴하고,

    transpose()는 Non-contiguous Tensor를 받을 순 있지만 리턴하는 Tensor는 Non-contiguous하다.

     

     

    pytorch method 복사본과 메모리 공유
    (새로운 Tensor 생성 X)
    Contiguous tensor에서 동작 Non-Contiguous
    tensor에서 동작
    반환 Tensor
    view() O O X Contiguous
    reshape() O O O Contiguous
    same as contiguous().view()
    transpose() O O O Non-Contiguous
    permute() O O O Non-Contiguous
    narrow() O O O Non-Contiguous
    expand() O O O Non-Contiguous

     

    bbb = aaa.transpose(0,1)
    print(bbb.stride())
    print(bbb.is_contiguous())
    
    #(1, 3)
    #False
    
    
    ccc = aaa.narrow(1,1,2)   ## equivalent to matrix slicing aaa[:,1:3]
    print(ccc.stride())
    print(ccc.is_contiguous())
    
    #(3, 1)
    #False
    
    
    ddd = aaa.repeat(2,1)   # The first dimension repeat once, the second dimension repeat twice
    print(ddd.stride())
    print(ddd.is_contiguous())
    
    #(3, 1)
    #True
    
    
    ## expand is different from repeat.
    ## if a tensor has a shape [d1,d2,1], it can only be expanded using "expand(d1,d2,d3)", which
    ## means the singleton dimension is repeated d3 times
    eee = aaa.unsqueeze(2).expand(2,3,3)
    print(eee.stride())
    print(eee.is_contiguous())
    
    #(3, 1, 0)
    #False
    
    
    fff = aaa.unsqueeze(2).repeat(1,1,8).view(2,-1,2)
    print(fff.stride())
    print(fff.is_contiguous())

     

     

    복사본과 메모리를 공유하지 않는 다는 것은 연산 후 결과를 새로운 Tensor로 생성하지 않는 다는 것을 의미한다. 즉, Tensor object의 새로운 shape을 설명하기 위한 offset과 stride가 포함된 meta information 만 modify 함을 의미한다. 따라서 반환된 값은 원본과 data(memory)를 공유하기 때문에 둘 중 하나만 수정해도 두개 모두 수정된다.

     

    x = torch.randn(3,2)
    y = torch.transpose(x, 0, 1)
    x[0, 0] = 42
    print(y[0,0])
    # prints 42
    
    z = y.contiguous()

    x는 contiguous하지만 y는 contiguous하지 않고 x와 메모리를 공유한다.  copy를 만들어서 메모리에 요소들이 순서대로 저장되게 하고 싶다면 contiguous()를 이용하면 된다.

     

     

     

     

    References

     

    https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html

    https://stackoverflow.com/questions/48915810/what-does-contiguous-do-in-pytorch

     

     

    댓글

Designed by Tistory.