def __init__(self, hidden_size):
    """Dot-product attention layer.

    Args:
        hidden_size: dimensionality of both the source and target
            hidden states; the query projection is square
            (hidden_size -> hidden_size), with no bias.
    """
    # NOTE(review): the enclosing class header is outside this chunk, but
    # submodules (nn.Linear/nn.Softmax) are registered as attributes, which
    # requires nn.Module.__init__ to have run first — the original skipped it.
    super().__init__()
    # Bias-free projection of the target hidden state into query space.
    self.linear = nn.Linear(hidden_size, hidden_size, bias=False)
    # Softmax over the last dim (source length) to turn scores into weights.
    self.softmax = nn.Softmax(dim=-1)
def forward(self, h_src, h_t_tgt, mask=None):
    """Compute a dot-product attention context vector.

    Args:
        h_src: source (encoder) hidden states,
            shape (batch_size, length, hidden_size).
        h_t_tgt: target (decoder) hidden state for a single time-step,
            shape (batch_size, 1, hidden_size).
        mask: optional bool/byte tensor of shape (batch_size, length);
            positions where the mask is 1/True (padding time-steps) get
            zero attention weight.

    Returns:
        context_vector: weighted sum of h_src,
            shape (batch_size, 1, hidden_size).
    """
    query = self.linear(h_t_tgt.squeeze(1)).unsqueeze(-1)
    # |query| = (batch_size, hidden_size, 1)
    weight = torch.bmm(h_src, query).squeeze(-1)
    # |weight| = (batch_size, length)
    if mask is not None:
        # Set each weight to -inf where the mask value equals 1.
        # Softmax maps -inf to 0, so padding time-steps of samples
        # shorter than the longest in the mini-batch contribute
        # nothing to the context vector.
        # BUG FIX: the original described this masking in comments but
        # never applied it, so padding positions received attention.
        weight.masked_fill_(mask, -float('inf'))
    weight = self.softmax(weight)
    context_vector = torch.bmm(weight.unsqueeze(1), h_src)
    # |context_vector| = (batch_size, 1, hidden_size)
    # BUG FIX: the original computed the context vector but never
    # returned it, so callers always received None.
    return context_vector