点击上方“视学算法”,选择加"星标"或“置顶”重磅干货,第一时间送达作者丨林小平@知乎(已授权)来源丨https://zhuanlan.zhihu.com/p/353365423编辑丨极市平台pytorch也自己实现了transformer的模型,不同于huggingface或者其他地方,pytorch的mask参数要更难理解一些(即便是有文档的情况下),这里做一些补充和说明。(顺带提一句,这里的transformer是需要自己实现position embedding的,别乐呵乐呵的就直接去跑数据了)>>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)>>> src = torch.rand((10, 32, 512))>>> tgt = torch.rand((20, 32, 512))>>> out = transformer_model(src, tgt) # 没有实现position embedding ,也需要自己实现mask机制。否则不是你想象的transformer首先看一下官网的参数src – the sequence to the encoder
………………………………