Understand the 6 basic functions in PyTorch
squeeze(), unsqueeze(), tensor[None], max(), argmax(), and view()
When I started to learn PyTorch, I found that several of its functions were hard for me to understand. Today I would like to summarize them with examples, which I think will be very helpful.
The functions are:
- squeeze()
- unsqueeze() and a[None]
- max()
- argmax()
- view()
You can also look up each function in the official documentation one by one, but summarizing them together helped me.
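For a quick side-by-side view before going through each function, here is a minimal sketch that applies all of them to one small tensor of shape [2, 1, 3]. The tensor values are chosen arbitrarily for illustration; `.tolist()` is used only so the printed results are easy to read.

```python
import torch

# Arbitrary example values, shape [2, 1, 3].
x = torch.tensor([[[1.0, 5.0, 3.0]],
                  [[4.0, 2.0, 6.0]]])
print(x.shape)                    # torch.Size([2, 1, 3])

# squeeze(): remove dimensions of size 1.
print(x.squeeze().shape)          # torch.Size([2, 3])

# unsqueeze(0) and x[None]: both insert a new size-1 dimension at the front.
print(x.unsqueeze(0).shape)       # torch.Size([1, 2, 1, 3])
print(x[None].shape)              # torch.Size([1, 2, 1, 3])

# max(dim) returns both the largest values and their indices along that dimension.
values, indices = x.max(dim=2)
print(values.tolist())            # [[5.0], [6.0]]
print(indices.tolist())           # [[1], [2]]

# argmax(dim) returns only the indices.
print(x.argmax(dim=2).tolist())   # [[1], [2]]

# view(): reshape; the total number of elements must stay 2 * 1 * 3 = 6.
print(x.view(3, 2).shape)         # torch.Size([3, 2])
print(x.view(-1).shape)           # torch.Size([6])
```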
1. squeeze()
squeeze(i) is a kind of dimension reduction: if dimension i has size 1, it is removed; otherwise the tensor is left unchanged. Let’s check an example:
In the example, the original tensor has shape [2, 1, 4]. Dimension 0 has size 2, so it can’t be removed and the shape stays the same. Dimension 1 has size 1, so it can be removed, turning the 3-dimensional tensor into a 2-dimensional matrix of shape [2, 4].
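The example above can be sketched as follows. This assumes a tensor filled with random values, since only its shape [2, 1, 4] matters here; the extra [1, 3, 1, 5] tensor is a hypothetical addition to show squeeze() called with no argument.

```python
import torch

# Assumption: random values; only the shape [2, 1, 4] matters.
a = torch.randn(2, 1, 4)
print(a.shape)               # torch.Size([2, 1, 4])

# Dimension 0 has size 2, so squeeze(0) leaves the shape unchanged.
b = a.squeeze(0)
print(b.shape)               # torch.Size([2, 1, 4])

# Dimension 1 has size 1, so squeeze(1) removes it: 3-D tensor -> 2-D matrix.
c = a.squeeze(1)
print(c.shape)               # torch.Size([2, 4])

# With no argument, squeeze() removes every dimension of size 1.
d = torch.randn(1, 3, 1, 5)
print(d.squeeze().shape)     # torch.Size([3, 5])

# Squeezing changes only the shape, not the data.
print(torch.equal(c, a.view(2, 4)))  # True
```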