I have an embedding layer:
self.embedding = nn.Embedding(n,m)
I need all of the embedding weights to take part in the computation:
logits = torch.einsum('bd,nd->bn', [over_states, self.embedding.weight.half()])
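For reference, the equation 'bd,nd->bn' takes the dot product of every row of over_states with every row of the embedding table, so it is equivalent to a matrix product with the transposed weight. A minimal sketch, assuming small illustrative sizes (batch 2, n = 5, m = 3) that are not from the original code, with random tensors standing in for the real hidden states:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n, m = 5, 3          # illustrative vocabulary size and embedding dim (assumed)
batch = 2

embedding = nn.Embedding(n, m)           # embedding.weight has shape (n, m)
over_states = torch.randn(batch, m)      # stand-in for the real hidden states

# 'bd,nd->bn': dot product of each state (b) with each embedding row (n)
logits = torch.einsum('bd,nd->bn', [over_states, embedding.weight])

# Same result as a plain matrix product with the transposed weight
logits_mm = over_states @ embedding.weight.t()

print(logits.shape)                                     # torch.Size([2, 5])
print(torch.allclose(logits, logits_mm, atol=1e-6))     # True
```

Because einsum lowers to a batched matrix multiply here, both operands must share a dtype, which is what the error below complains about.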
If I remove .half() and pass self.embedding.weight directly, I get a dtype error when training with apex:
Original Traceback (most recent call last):
  File "/home/mhtan/anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/mhtan/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mhtan/gitee/modeling_bert.py", line 551, in forward
    over_logits = self.vocab(self.over_linear(sentiment_states))
  File "/home/mhtan/gitee/modeling_bert.py", line 520, in vocab
    c_mo_logits = torch.einsum('bd,nd->bn', [over_states, self.embedding.weight]) # (b, 256, 10)
  File "/home/mhtan/anaconda3/lib/python3.7/site-packages/torch/functional.py", line 201, in einsum
    return torch._C._VariableFunctions.einsum(equation, operands)
RuntimeError: Expected object of scalar type Half but got scalar type Float for argument #2 'mat2' in call to _th_bmm

I set up apex with
if args.fp16:
    try:
        from apex import amp
    except ImportError:
        raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=args.fp16_opt_level)
and
if self.args.fp16:
    try:
        from apex import amp
    except ImportError:
        raise ImportError(
            "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
        scaled_loss.backward()
in my training loop.
Having to call .half() explicitly to get float16 seems odd. Which package is this error most likely related to, PyTorch or apex?
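One possible workaround, sketched here as an assumption rather than anything apex prescribes, is to cast the weight to whatever dtype the activations arrive in instead of hard-coding .half(); the same line then works with and without mixed precision. The module name VocabScorer and the sizes are hypothetical, and float64 stands in for "a dtype different from the weight" so the example runs on CPU:

```python
import torch
import torch.nn as nn


class VocabScorer(nn.Module):
    """Hypothetical module: scores hidden states against every embedding row."""

    def __init__(self, n: int, m: int):
        super().__init__()
        self.embedding = nn.Embedding(n, m)

    def forward(self, over_states: torch.Tensor) -> torch.Tensor:
        # Match the weight's dtype to the incoming activations (Half in the
        # failing apex run above, Float otherwise) instead of always calling .half().
        weight = self.embedding.weight.to(over_states.dtype)
        return torch.einsum('bd,nd->bn', [over_states, weight])


scorer = VocabScorer(n=5, m=3)

x32 = torch.randn(2, 3)                        # same dtype as the weight
x64 = torch.randn(2, 3, dtype=torch.float64)   # different dtype than the weight

print(scorer(x32).dtype)   # torch.float32
print(scorer(x64).dtype)   # torch.float64

# Gradients still reach the float32 embedding weight through the cast
scorer(x64).sum().backward()
print(scorer.embedding.weight.grad.dtype)  # torch.float32
```

This does not replace amp.initialize; it only removes the hard-coded dtype from the einsum call.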
Posted on 2020-01-04 09:40:20
You should initialize your model with the amp.initialize call.
Quoting the documentation:
Users should not manually cast their model or data to .half().
In your case, it would look something like this:
from apex import amp

model = YourModel().cuda() # includes your embedding layer
optimizer = ... # any optimizer you want
# Usually you want O1 or O2 for mixed precision
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# inside the training loop, after computing `loss`:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
The above should cast your model and its updates appropriately for mixed-precision training.
https://stackoverflow.com/questions/59588524