编程技术网

关注微信公众号,定时推送前沿、专业、深度的编程技术资料。

 找回密码
 立即注册

QQ登录

只需一步,快速开始

极客时间

如何在 PyTorch 中更新神经网络的参数?:How can I update the parameters of a neural network in PyTorch?

Ken-B CNN 2022-5-7 12:40 15人围观

腾讯云服务器
如何在 PyTorch 中更新神经网络的参数?的处理方法

假设我想在 PyTorch(继承自 torch.nn.Module) 通过 0.9.我该怎么做?

Let's say I wanted to multiply all parameters of a neural network in PyTorch (an instance of a class inheriting from torch.nn.Module) by 0.9. How would I do that?

问题解答

net 成为你的神经网络类的一个实例.然后你可以做

Let net an instance of your neural network class. You can then do

state_dict = net.state_dict() for name, param in state_dict.items(): # Transform the parameter as required. transformed_param = param * 0.9 # Update the parameter. state_dict[name].copy_(transformed_param) 

将所有参数乘以0.9.

如果你只想更新权重而不是所有参数,你可以这样做

If you ever only want to update weights instead of all parameters, you can do

state_dict = net.state_dict() for name, param in state_dict.items(): # Don't update if this is not a weight. if not "weight" in name: continue # Transform the parameter as required. transformed_param = param * 0.9 # Update the parameter. state_dict[name].copy_(transformed_param) 

这篇关于如何在 PyTorch 中更新神经网络的参数?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程技术网(www.editcode.net)!

腾讯云服务器 阿里云服务器
关注微信
^