In the PyTorch example at https://github.com/pytorch/examples/tree/main/word_language_model, the transformer model only uses nn.TransformerEncoder, and the decoder (nn.TransformerDecoder) is overwritten with a simple Linear layer. I wanted to use torch's actual decoder in my code, but I'm not sure whether I'm doing it correctly.
Here is the original code:
class TransformerModel(nn.Transformer):
    """Container module with an encoder, a recurrent or transformer module, and a decoder."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_encoder_layers=nlayers)
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz, sz)))

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None
        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.encoder(src, mask=self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)
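For comparison, the original example drives this encoder-only model with a single input batch, roughly like the sketch below (the (seq_len, batch) shape of data and the flattened targets are my assumptions about the example's get_batch, not quoted from it):

ntokens = len(corpus.dictionary)
criterion = nn.NLLLoss()
# data: (seq_len, batch) token indices; targets: the same stream shifted by one step, flattened
output = model(data)                                 # (seq_len, batch, ntokens) log-probabilities
loss = criterion(output.view(-1, ntokens), targets)  # next-token prediction at every position
loss.backward()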
In my version, I moved the Linear layer to a variable called _decoder, so that self.decoder refers to torch's original TransformerDecoder; I run the original decoder first and then pass its output through _decoder to generate the final logits.
Torch's decoder takes four important arguments: tgt, the decoder's input (the target sequence so far), memory, the encoder's last output, and the masks for these two inputs, namely tgt_mask and memory_mask.
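For reference, this is how I understand the decoder's call signature, as a self-contained toy sketch (default batch_first=False layout; d_model=200, nhead=2, seq_len=35, batch=20 are made-up numbers):

import torch
import torch.nn as nn

dec_layer = nn.TransformerDecoderLayer(d_model=200, nhead=2)
decoder = nn.TransformerDecoder(dec_layer, num_layers=2)

tgt = torch.rand(35, 20, 200)      # (tgt_len, batch, d_model): embedded target tokens
memory = torch.rand(35, 20, 200)   # (src_len, batch, d_model): encoder output
# causal mask: -inf above the diagonal blocks attention to future target positions
tgt_mask = torch.triu(torch.full((35, 35), float('-inf')), diagonal=1)
out = decoder(tgt, memory, tgt_mask=tgt_mask)  # (tgt_len, batch, d_model)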
Since there is no decoder output initially, I take the start_token index from the vocabulary, put it through an embedding layer, and use it as tgt for the decoder. This is where I'm having difficulty: since the original source sentence has a different seq_len, should I repeat the start_token seq_len times? If I don't, the model outputs only one word, which is not compatible with the original idea of the example (see the shape sketch below).
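To make the shape problem concrete, reusing the toy decoder and memory from the sketch above (numbers still made up): the decoder always produces one output per tgt position, so a single start token yields a single prediction.

start = torch.rand(1, 20, 200)        # one embedded start token: (1, batch, d_model)
out_one = decoder(start, memory)      # (1, batch, d_model): a prediction for only one position
start_rep = start.repeat(35, 1, 1)    # start token repeated seq_len times: (35, batch, d_model)
out_all = decoder(start_rep, memory)  # (35, batch, d_model): one prediction per position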
The model trains until the loss reaches about 1.00, but it learns very slowly, taking around 100 epochs, whereas the original example learns very quickly.
I wonder what I'm doing wrong; any help is appreciated.
Here is my code:
class TransformerModel(nn.Transformer):
    """Container module with an encoder, a recurrent or transformer module, and a decoder."""

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_encoder_layers=nlayers)
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.output_emb = nn.Embedding(ntoken, ninp)  # embedding for the decoder input (tgt)
        self.ninp = ninp
        self._decoder = nn.Linear(ninp, ntoken)  # final projection to vocabulary (was self.decoder in the original)
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz, sz)))

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        # nn.init.zeros_(self.decoder.bias)
        # nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, tgt, has_mask=True):
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None
        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.encoder(src, mask=self.src_mask)
        tgt = self.output_emb(tgt)
        tgt_mask = torch.triu(torch.ones((src.shape[0], src.shape[0])), diagonal=1).bool()
        # self.decoder here is the nn.TransformerDecoder inherited from nn.Transformer
        output = self.decoder(tgt, output, tgt_mask=tgt_mask, memory_mask=self.src_mask)
        output = self._decoder(output)
        return F.log_softmax(output, dim=-1)
This is how I call the model during training:
start = [[int(corpus.dictionary.word2idx['<empty>'])]]
start = torch.LongTensor(start).repeat(data.shape[0], 1).to(device)
output = model(data, start)
EDIT: I changed the target so that it is the output shifted by one position, with the start token at the beginning, but still no luck.
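For concreteness, this is the kind of shifted target I mean, as a sketch (it assumes, like the example, that data is (seq_len, batch) and that the prediction target at step t is data[t+1]; start_row and tgt_in are just names I made up here):

start_idx = int(corpus.dictionary.word2idx['<empty>'])
start_row = torch.full((1, data.shape[1]), start_idx, dtype=torch.long, device=device)
# decoder input: start token followed by the targets shifted right by one (teacher forcing),
# so it still lines up with the (seq_len, batch) source
tgt_in = torch.cat([start_row, data[1:]], dim=0)   # (seq_len, batch)
output = model(data, tgt_in)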