文章作者:Tyan 博客:noahsnail.com | CSDN | 简书
0. 测试环境Python 3.6.9, Pytorch 1.5.0
1. 基本概念Tensor是一个多维矩阵,其中包含所有的元素为同一数据类型。默认数据类型为torch.float32。
示例一>>> a = torch.tensor([1.0]) >>> a.data tensor([1.]) >>> a.grad >>> a.requires_grad False >>> a.dtype torch.float32 >>> a.item() 1.0 >>> type(a.item()) <class 'float'>Tensor中只有一个数字时,使用torch.Tensor.item()可以得到一个Python数字。requires_grad为True时,表示需要计算Tensor的梯度。requires_grad=False可以用来冻结部分网络,只更新另一部分网络的参数。
示例二>>> a = torch.tensor([1.0, 2.0]) >>> b = a.data >>> id(b) 139808984381768 >>> id(a) 139811772112328 >>> b.grad >>> a.grad >>> b[0] = 5.0 >>> b tensor([5., 2.]) >>> a tensor([5., 2.])a.data返回的是一个新的Tensor对象b,a, b的id不同,说明二者不是同一个Tensor,但b与a共享数据的存储空间,即二者的数据部分指向同一块内存,因此修改b的元素时,a的元素也对应修改。
2. requires_grad_()与detach()>>> a = torch.tensor([1.0, 2.0]) >>> a.data tensor([1., 2.]) >>> a.grad >>> a.requires_grad False >>> a.requires_grad_() tensor([1., 2.], requires_grad=True) >>> c = a.pow(2).sum() >>> c.backward() >>> a.grad tensor([2., 4.]) >>> b = a.detach() >>> b.grad >>> b.requires_grad False >>> b tensor([1., 2.]) >>> b[0] = 6 >>> b tensor([6., 2.]) >>> a tensor([6., 2.], requires_grad=True)requires_grad_()requires_grad_()函数会改变Tensor的requires_grad属性并返回Tensor,修改requires_grad的操作是原位操作(in place)。其默认参数为requires_grad=True。requires_grad=True时,自动求导会记录对Tensor的操作,requires_grad_()的主要用途是告诉自动求导开始记录对Tensor的操作。
detach()detach()函数会返回一个新的Tensor对象b,并且新Tensor是与当前的计算图分离的,其requires_grad属性为False,反向传播时不会计算其梯度。b与a共享数据的存储空间,二者指向同一块内存。
注:共享内存空间只是共享的数据部分,a.grad与b.grad是不同的。
3. torch.no_grad()torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。
>>> a = torch.tensor([1.0, 2.0], requires_grad=True) >>> with torch.no_grad(): ... b = n.pow(2).sum() ... >>> b tensor(5.) >>> b.requires_grad False >>> c = a.pow(2).sum() >>> c.requires_grad True上面的例子中,当a的requires_grad=True时,不使用torch.no_grad(),c.requires_grad为True,使用torch.no_grad()时,b.requires_grad为False,当不需要进行反向传播时(推断)或不需要计算梯度(网络输入)时,requires_grad=True会占用更多的计算资源及存储资源。
4. 总结requires_grad_()会修改Tensor的requires_grad属性。
detach()会返回一个与计算图分离的新Tensor,新Tensor不会在反向传播中计算梯度,会在特定场合使用。
torch.no_grad()更节省计算资源和存储资源,其作用域范围内的操作不会构建计算图,常用在网络推断中。
Referenceshttps://pytorch.org/docs/stable/tensors.htmlhttps://pytorch.org/docs/stable/tensors.html#torch.Tensor.requires_grad_https://pytorch.org/docs/stable/autograd.html#torch.Tensor.detachhttps://pytorch.org/docs/master/generated/torch.no_grad.html ---来自腾讯云社区的---Tyan
微信扫一扫打赏
支付宝扫一扫打赏