pytorchの関数gatherの処理内容について
pytorchの関数gather
の処理内容が公式ドキュメントを読んでもよく分かりません。
例えばExampleのt
が、どのような計算をした結果、出力のようなテンソルになるのか、
具体的に教えていただけないでしょうか。
Example:
>>> t = torch.tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
tensor([[ 1, 1],
[ 4, 3]])
dim = 0
だと、上記の入力tは下記のような出力になります。
tensor([[1, 2],
[3, 2]])
公式ドキュメント:
https://pytorch.org/docs/stable/torch.html#torch.gather
ご回答、何卒宜しくお願い致します。