张量有以上几个属性
这是pytorch中的梯度计算图:y = a * b = (x + w) * (w + 1),若此时, x=2, w=1, y对x求导是2,y对w求导是5
叶子节点:x,w,也是用户创建的节点,所有梯度的计算都要依赖叶子节点,梯度反向传播结束之后,非叶子节点的梯度都会被释放掉
import torch
w = torch.tensor([1.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为True
x = torch.tensor([2.], requires_grad=True) #由于需要计算梯度,所以requires_grad设置为True
a = torch.add(w, x) # a = w + x
b = torch.add(w, 1) # b = w + 1
y = torch.mul(a, b) # y = a * b
y.backward() #对y进行反向传播
print(w.grad) #输出w的梯度
#查看叶子结点
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
#输出为True True False False False,只有前面两个是叶子节点
#查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)
#输出为tensor([5.]) tensor([2.]) None None None,因为非叶子节点都被释放掉了
若想保存非叶子节点梯度使用 retain_grad()
a.retain_grad()
#保存非叶子结点a的梯度,输出为tensor([5.]) tensor([2.]) tensor([2.]) None None
grad_fn 的作用是记录创建该张量时所用的方法,例如y在反向传播的时候会记录y是用乘法得到的,所用在求解a和b的梯度的时候就会用到乘法的求导法则去求解a和b的梯度
# 查看 grad_fn
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)
#上面代码的输出结果为
grad_fn:
None None <AddBackward0 object at 0x000001EEAA829308> <AddBackward0 object at 0x000001EE9C051548> <MulBackward0 object at 0x000001EE9C29F948>
pytorch中,每次前向传播都会动态构建计算图,如果有两个损失函数,涉及到同一个计算图,需要保留,例如Gan中 disc_loss.backward(retain_graph=True),两个损失都用到了disc_fake
反向传播前,梯度计算所需要的变量不能原地修改
# 判别器判别
disc_real = discriminator(obs_real)
disc_fake = discriminator(obs_fake)
# 判别器损失
disc_loss = discriminator_loss(disc_real, disc_fake)
# 更新判别器
discriminator_optimizer.zero_grad()
disc_loss.backward(retain_graph=True)
# 必须有,因为两个损失函数都需要用到disc_fake,由discriminator前向传播,如果不保留计算图,会默认清除,
# 导致生成器的损失再用到discriminator的计算图时会报错
discriminator_optimizer.step()
# 生成器损失
disc_fake = discriminator(obs_fake)
# 必须有,adam优化器discriminator_optimizer.step()导致梯度涉及到的变量原地操作,
#不然gen_loss.backward()这句会报错
real_action_probs = victim_model.actor_local(obs_real)
perturbed_action_probs = victim_model.actor_local(obs_fake)
gen_loss = generator_loss(disc_fake, real_action_probs, perturbed_action_probs, obs_real, obs_fake)
# 更新生成器
generator_optimizer.zero_grad()
gen_loss.backward()
generator_optimizer.step()