Understand 6 basic functions in PyTorch

squeeze(), unsqueeze(), tensor[None], max(), argmax(), and view()

Annie Wang
May 5, 2021

When I started learning PyTorch, I found several functions whose behavior was hard to grasp at first. Today I would like to summarize them with examples, which I think helps a lot.

The functions are:

  • squeeze()
  • unsqueeze() and a[None]
  • max()
  • argmax()
  • view()

You can also look up each function in the official documentation, but I find it helpful to summarize them in one place.

1. squeeze()

squeeze(i) is a kind of dimension reduction: if dimension i has size 1, that dimension can be removed; otherwise the tensor is left unchanged. Let’s check an example:
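A minimal sketch of this behavior, assuming a tensor of shape [2, 1, 4]. The variable names and the values (filled in with torch.arange) are placeholders chosen for illustration; only the shape matters here.

```python
import torch

# Hypothetical tensor: the values are arbitrary, only the shape [2, 1, 4] matters.
a = torch.arange(8).reshape(2, 1, 4)
print(a.shape)  # torch.Size([2, 1, 4])

# Dimension 0 has size 2, so squeeze(0) cannot remove it; the shape is unchanged.
b = a.squeeze(0)
print(b.shape)  # torch.Size([2, 1, 4])

# Dimension 1 has size 1, so squeeze(1) removes it and leaves a 2-D tensor.
c = a.squeeze(1)
print(c.shape)  # torch.Size([2, 4])
```

Note that squeeze() does not raise an error when the chosen dimension has a size other than 1; it simply returns a tensor with the same shape.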

In the example, the original tensor has shape [2, 1, 4]. Dimension 0 has size 2, so it can’t be squeezed and the shape stays the same. Dimension 1 has size 1, so it can be squeezed, which turns the 3-dimensional tensor into a 2-dimensional matrix of shape [2, 4].
