Pytorch框架中一些细节的理解和记录
tensor索引中的的broadcasting
在Stackflow的回答中看到的。
在tensor的索引中使用两个列表的时候,pytorch会将这两个列表进行broadcasting。
比如你现在对一个tensor进行这样的操作:
x = [0, 2, 1, 2, 1]
y = torch.arange(x.shape[0]) // [0, 1, 2, 3, 4]
M = torch.zeros(3, 5)
M[x, y] = 1
那么x,y之间会进行一个broadcasting,形成一个这样的索引: [[0, 0], [2, 1], [1, 2], [2, 3], [1, 4]]
。效果类似于zip
,但是用于多维索引的时候可能会有其他妙用。
model.eval()
以及with torch.no_grad()
model.eval()
会将模型切换到测试模式,届时dropout层以及batchnorm层会被禁用,并且停止反向传播梯度,但是仍然会计算梯度。
with torch.no_grad()
则是用于停止autograd模块的工作,即停止gradient计算,从而节省了GPU算力和显存,但是并不会影响dropout和batchnorm层的行为。
模型类内变量类型变换
在GPU上对模型进行训练时,我们首先会使用model.to(device)
将模型的参数类型转换成cuda使用的类型。但是如果在model.__init__()
中定义了一些运算时需要的变量(常量)或超参数,model.to(device)
并不会将模型类内变量的类型一起转换。
Google以后有以下两种做法:
- 使用Buffers:
- Parameters
- Parameters are tensors that are to be trained and will be returned by model.parameters(). They are easy to register, all you need to do is wrap the tensor in the nn.Parameter type and it will be automatically registered. Note that only floating point tensors can be parameters.
- Buffers
- Buffers are tensors that will be registered in the module so methods like .cuda() will affect them but they will not be returned by model.parameters(). Buffers are not restricted to a particular data type.
通过在模型中使用register_buffer(name: str, tensor: torch.Tensor, persistent: bool = True)
可以将变量注册到model中,之后对model的变量类型进行操作时就会一起处理了。persistent = True
默认为真,将该buffer当做模型state_dict
的一部分。有一点需要注意的是,register_buffer
只接受Tensor
类型的变量,list
等其它类型的变量无法注册。
- 重写模型的
_apply()
函数
model.float()
, model.cuda()
等等的类型变换其实都用到了_apply()
函数。查看torch源码发现实现并不复杂。
def _apply(self, fn):
super(VariationalGenerator, self)._apply(fn)
self._train_noise = fn(self._train_noise)
return self
torch.nn.DataParallel
对模型的列表的修改
最近使用了torch.nn.DataParallel
对模型在多卡上进行训练,中途中断了训练,恢复训练时发现模型参数加载出错,通过报错的信息可以发现保存模型的state_dict
中每一个layer
都加了一个module
的前缀。例如:vgg.0.weight
变成了module.vgg.0.weight
。
这是因为torch.nn.DataParallel
对模型进行了wrap
,在原来模型的基础上包裹了一层用于数据并行的功能,所以在使用了torch.nn.DataParallel
之后保存模型参数,就会发现state_dict
中的名称被加了前缀。
这里分享一个加载参数时忽略这个前缀的代码:
def load_weights(self, weights):
RE_MODULE = re.compile(r"^module\.")
weights = OrderedDict(
[(RE_MODULE.sub("", key), value) for key, value in weights.items()]
)
self.load_state_dict(weights)
同时,你也可以在保存参数时使用torch.save(mode.module.state_dict(), path_to_file)
来避免保存的参数名称带有前缀。