Python 深度学习
上QQ阅读APP看书,第一时间看更新

4.3 Tensor的变换、拼接和拆分

PyTorch提供了大量对Tensor进行操作的函数或方法,这些函数内部使用指针实现对矩阵的形状变换、拼接和拆分等操作,使得我们无须关心Tensor在内存的物理结构或者管理指针就可以方便快速地执行这些操作。Tensor.nelement()、Tensor.ndimension()或ndimension.size()可分别用于查看矩阵元素的个数、轴的个数以及维度,属性Tensor.shape也可以用于查看Tensor的维度。

在PyTorch中,Tensor.reshape和Tensor.view都能被用于更改Tensor的维度。它们的区别在于:Tensor.view要求Tensor的物理存储必须是连续的,否则将报错,而Tensor.reshape则没有这种要求。但是,Tensor.view返回的一定是一个索引,若更改返回值,则原始值同样被更改,Tensor.reshape返回的是引用还是拷贝是不确定的。它们的相同之处是都接收要输出的维度作为参数,且输出的矩阵元素个数不能改变,若在维度中输入-1,PyTorch会自动推断它的数值。

torch.squeeze和torch.unsqueeze用于给Tensor去掉和添加轴。torch.squeeze可以去掉维度为1的轴,而torch.unsqueeze用于给Tensor的指定位置添加一个维度为1的轴。

torch.t和torch.transpose用于转置二维矩阵,同时只接收二维Tensor。值得注意的是,torch.t是torch.transpose的简化版。

对于高维度Tensor,可以使用permute方法来变换维度。

PyTorch提供了torch.cat和torch.stack用于拼接矩阵,区别在于:torch.cat在已有的轴dim上拼接矩阵,给定轴的维度可以不同,而其他轴的维度必须相同。torch.stack在新的轴上拼接,同时它要求被拼接矩阵的所有维度都相同。下面的例子中可以很清楚地表明它们的使用方式和区别。

除了拼接矩阵外,PyTorch还提供了torch.split和torch.chunk用于拆分矩阵。它们的不同之处在于:torch.split传入的是拆分后每个矩阵的大小,既可以传入list,也可以传入整数,而torch.chunk传入的是拆分的矩阵个数。