# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in
# the LICENSE file in the root directory of this source tree.

import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from fairseq import search, utils
from fairseq.data import data_utils
from fairseq.models import FairseqIncrementalDecoder
from fairseq.models.fairseq_encoder import EncoderOut
from torch import Tensor


class SequenceGenerator(nn.Module):
    def __init__(
        self,
        models,
        tgt_dict,
        beam_size=1,
        max_len_a=0,
        max_len_b=200,
        min_len=1,
        normalize_scores=True,
        len_penalty=1.0,
        unk_penalty=0.0,
        temperature=1.0,
        match_source_len=False,
        no_repeat_ngram_size=0,
        search_strategy=None,
        eos=None,
        symbols_to_strip_from_output=None,
    ):
        """Generates translations of a given source sentence.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models,
                currently supports fairseq.models.TransformerModel for scripting
            beam_size (int, optional): beam width (default: 1)
            max_len_a/b (int, optional): generate sequences of maximum length
                ax + b, where x is the source length
            min_len (int, optional): the minimum length of the generated output
                (not including end-of-sentence)
            normalize_scores (bool, optional): normalize scores by the length
                of the output (default: True)
            len_penalty (float, optional): length penalty, where <1.0 favors
                shorter and >1.0 favors longer sentences (default: 1.0)
            unk_penalty (float, optional): unknown word penalty, where <0
                produces more unks and >0 produces fewer (default: 0.0)
            temperature (float, optional): temperature, where values
                >1.0 produce more uniform samples and values <1.0 produce
                sharper samples (default: 1.0)
            match_source_len (bool, optional): outputs should match the source
                length (default: False)
        """
        super().__init__()
        if isinstance(models, EnsembleModel):
            self.model = models
        else:
            self.model = EnsembleModel(models)
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos() if eos is None else eos
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )
        self.vocab_size = len(tgt_dict)
        self.beam_size = beam_size
        # the max beam size is the dictionary size - 1, since otherwise we can
        # never select pad
        self.beam_size = min(beam_size, self.vocab_size - 1)
        self.max_len_a = max_len_a
        self.max_len_b = max_len_b
        self.min_len = min_len

        self.normalize_scores = normalize_scores
        self.len_penalty = len_penalty
        self.unk_penalty = unk_penalty
        self.temperature = temperature
        self.match_source_len = match_source_len
        self.no_repeat_ngram_size = no_repeat_ngram_size
        assert temperature > 0, "--temperature must be greater than 0"

        self.search = (
            search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
        )
        # We only need to set src_lengths in LengthConstrainedBeamSearch.
        # As a module attribute, setting it would break in multithreaded
        # scenarios where the model is shared.
        self.should_set_src_lengths = (
            hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
        )

        self.model.eval()
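
    # Example usage (a minimal sketch; `checkpoint_utils`, `task`, and `sample`
    # are assumed to come from the usual fairseq loading/data pipeline and are
    # not defined in this module):
    #
    #   models, _ = checkpoint_utils.load_model_ensemble(["checkpoint.pt"])
    #   generator = SequenceGenerator(models, task.target_dictionary, beam_size=5)
    #   hypos = generator.generate(models, sample)
    #   best_tokens = hypos[0][0]["tokens"]  # best hypothesis, first sentence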

    def cuda(self):
        self.model.cuda()
        return self

    @torch.no_grad()
    def forward(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        """Generate a batch of translations.

        Args:
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        return self._generate(sample, prefix_tokens, bos_token)

    def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            cuda (bool, optional): use GPU for generation
            timer (StopwatchMeter, optional): time generations
        """
        for sample in data_itr:
            s = utils.move_to_cuda(sample) if cuda else sample
            if "net_input" not in s:
                continue
            input = s["net_input"]
            # model.forward normally channels prev_output_tokens into the decoder
            # separately, but SequenceGenerator directly calls model.encoder
            encoder_input = {
                k: v for k, v in input.items() if k != "prev_output_tokens"
            }
            if timer is not None:
                timer.start()
            with torch.no_grad():
                # _generate expects the full sample dict, keyed by "net_input"
                hypos = self._generate({"net_input": encoder_input})
            if timer is not None:
                timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
            for i, id in enumerate(s["id"].data):
                # remove padding
                src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
                ref = (
                    utils.strip_pad(s["target"].data[i, :], self.pad)
                    if s["target"] is not None
                    else None
                )
                yield id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
        """Generate translations. Matches the API of other fairseq generators.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)
        """
        return self._generate(sample, **kwargs)
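
    # Sketch of the expected `sample` layout (the standard fairseq batch
    # format; `src_tokens`, `src_lengths`, and `generator` are placeholders):
    #
    #   sample = {
    #       "net_input": {
    #           "src_tokens": src_tokens,    # LongTensor of shape [bsz, src_len]
    #           "src_lengths": src_lengths,  # LongTensor of shape [bsz]
    #       }
    #   }
    #   hypos = generator.generate(models, sample)
    #   # hypos[i] is a list of up to beam_size dicts for sentence i,
    #   # sorted by descending score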

    def _generate(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(self.model.models_size)
            ],
        )
        net_input = sample["net_input"]

        if "src_tokens" in net_input:
            src_tokens = net_input["src_tokens"]
            # length of the source text, excluding end-of-sentence and pad
            src_lengths = (
                (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
            )
        elif "source" in net_input:
            src_tokens = net_input["source"]
            src_lengths = (
                net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
                if net_input["padding_mask"] is not None
                else torch.tensor(src_tokens.size(-1))
            )
        else:
            raise Exception("expected src_tokens or source in net input")

        # bsz: total number of sentences in the batch
        input_size = src_tokens.size()
        bsz, src_len = input_size[0], input_size[1]
        beam_size = self.beam_size

        max_len: int = -1
        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                self.model.max_decoder_positions() - 1,
            )
        assert (
            self.min_len <= max_len
        ), "min_len cannot be larger than max_len, please adjust these!"

        # compute the encoder output for each beam
        encoder_outs = self.model.forward_encoder(net_input)

        # replicate the encoder output beam_size times per sentence
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
        # ensure encoder_outs is a List.
        assert encoder_outs is not None

        # initialize buffers
        scores = (
            torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
        )  # +1 for eos
        tokens = (
            torch.zeros(bsz * beam_size, max_len + 2)
            .to(src_tokens)
            .long()
            .fill_(self.pad)
        )  # +2 for eos and pad
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        attn: Optional[Tensor] = None

        # A mask marking candidates that should be ignored. For example,
        # suppose we're sampling and have already finalized 2/5 samples; then
        # cands_to_ignore marks 2 positions as ignored, so that we only
        # finalize the remaining 3 samples.
        cands_to_ignore = (
            torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
        )  # forward- and backward-compatible False mask

        # list of completed sentences; contains lists of dictionaries of
        # information about each hypothesis being finalized at each step
        finalized = torch.jit.annotate(
            List[List[Dict[str, Tensor]]],
            [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
        )

        finished = [
            False for i in range(bsz)
        ]  # whether the sentence at each index is finished or not
        num_remaining_sent = bsz  # number of sentences remaining

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)
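        # E.g. with bsz=2 and beam_size=3, bbsz_offsets is [[0], [3]]: adding
        # it to per-sentence beam indices (0..2) yields flat row indices into
        # the (bsz * beam_size)-row token/score buffers.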

        reorder_state: Optional[Tensor] = None
        batch_idxs: Optional[Tensor] = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
                        batch_idxs
                    )
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size
                    )
                self.model.reorder_incremental_state(incremental_states, reorder_state)
                encoder_outs = self.model.reorder_encoder_out(
                    encoder_outs, reorder_state
                )

            lprobs, avg_attn_scores = self.model.forward_decoder(
                tokens[:, : step + 1],
                encoder_outs,
                incremental_states,
                self.temperature,
            )
            # replace NaNs with -inf so they are never selected
            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # handle max length constraint: allow only EOS at the final step
            if step >= max_len:
                lprobs[:, : self.eos] = -math.inf
                lprobs[:, self.eos + 1 :] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            if (
                prefix_tokens is not None
                and step < prefix_tokens.size(1)
                and step < max_len
            ):
                lprobs, tokens, scores = self._prefix_tokens(
                    step, lprobs, scores, tokens, prefix_tokens, beam_size
                )
            elif step < self.min_len:
                # minimum length constraint (does not apply if using prefix_tokens)
                lprobs[:, self.eos] = -math.inf

            # record attention scores (only supported when avg_attn_scores is a Tensor)
            if avg_attn_scores is not None:
                if attn is None:
                    attn = torch.empty(
                        bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                    ).to(scores)
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            # indices and scores of hypotheses ending in eos (finished sentences)
            eos_bbsz_idx = torch.empty(0).to(tokens)
            eos_scores = torch.empty(0).to(scores)

            if self.should_set_src_lengths:
                self.search.set_src_lengths(src_lengths)

            if self.no_repeat_ngram_size > 0:
                lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step)

            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos, except for ignored
            # candidates and those with a score of -inf
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)

            # only consider eos when it's among the top beam_size indices
            eos_bbsz_idx = torch.masked_select(
                cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
            )

            finalized_sents: List[int] = []
            if eos_bbsz_idx.numel() > 0:
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
                )
                finalized_sents = self.finalize_hypos(
                    step,
                    eos_bbsz_idx,
                    eos_scores,
                    tokens,
                    scores,
                    finalized,
                    finished,
                    beam_size,
                    attn,
                    src_lengths,
                    max_len,
                )
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < max_len

            # remove finalized sentences (ones for which {beam_size} finished
            # hypotheses have been generated) from the batch
            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs, which holds the indices of batches to keep
                batch_mask = torch.ones(bsz).to(cand_indices)
                batch_mask[
                    torch.tensor(finalized_sents).to(cand_indices)
                ] = torch.tensor(0).to(batch_mask)
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                cands_to_ignore = cands_to_ignore[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1
                    )
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos. After
            # this, the min values per row are the top candidate active hypos.
            # The masking is written with ~ and & because elementwise "or" is
            # not supported in TorchScript.
            eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[: eos_mask.size(1)],
            )
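            # E.g. with cand_size=4, a row whose eos_mask is [1, 0, 1, 0]
            # becomes [4, 1, 6, 3]: finished candidates are pushed above
            # cand_size, so the topk(..., largest=False) below picks the
            # still-active candidates first.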

            # get the top beam_size active hypotheses, which are just the
            # hypos with the smallest values in active_mask; {active_hypos}
            # indicates which of the candidates were selected
            new_cands_to_ignore, active_hypos = torch.topk(
                active_mask, k=beam_size, dim=1, largest=False
            )

            # update cands_to_ignore to ignore any finalized hypos
            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
            # make sure there is at least one active item for each sentence in the batch
            assert (~cands_to_ignore).any(dim=1).all()

            # {active_bbsz_idx} denotes which beam number is continued for each
            # new hypothesis (a beam can be selected more than once)
            active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
            active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses, then append the
            # chosen next token and its score
            tokens[:, : step + 1] = torch.index_select(
                tokens[:, : step + 1], dim=0, index=active_bbsz_idx
            )
            tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
                cand_indices, dim=1, index=active_hypos
            )
            if step > 0:
                scores[:, :step] = torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx
                )
            scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
                cand_scores, dim=1, index=active_hypos
            )

            # copy attention for active hypotheses
            if attn is not None:
                attn[:, :, : step + 2] = torch.index_select(
                    attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
                )

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            # wrap in BeamContainer so TorchScript can sort the hypotheses
            BCList = [
                BeamContainer(elem["score"].item(), elem) for elem in finalized[sent]
            ]
            BCList.sort()
            BCList.reverse()
            finalized[sent] = torch.jit.annotate(
                List[Dict[str, Tensor]], [x.elem for x in BCList]
            )

        return finalized

    def _prefix_tokens(
        self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
    ):
        """Handle prefix tokens"""
        prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
        prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
        prefix_mask = prefix_toks.ne(self.pad)
        # mask out everything except the forced prefix token, keeping its
        # original log-probability
        lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
        lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
            -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
        )
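        # E.g. forcing token 42 on some row: that row of lprobs becomes -inf
        # everywhere except index 42, which keeps lprobs[row, 42], so the
        # search can only continue with the prefix token while the beam still
        # accumulates the model's true score for it.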
        # if the prefix includes eos, then we should make sure tokens and
        # scores are the same across all beams
        eos_mask = prefix_toks.eq(self.eos)
        if eos_mask.any():
            # validate that the first beam matches the prefix
            first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
                :, 0, 1 : step + 1
            ]
            eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
            target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
            assert (first_beam == target_prefix).all()

            # copy tokens, scores and lprobs from the first beam to all beams
            tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
            scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
            lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
        return lprobs, tokens, scores

    def replicate_first_beam(self, tensor, mask, beam_size: int):
        tensor = tensor.view(-1, beam_size, tensor.size(-1))
        tensor[mask] = tensor[mask][:, :1, :]
        return tensor.view(-1, tensor.size(-1))

    def finalize_hypos(
        self,
        step: int,
        bbsz_idx,
        eos_scores,
        tokens,
        scores,
        finalized: List[List[Dict[str, Tensor]]],
        finished: List[bool],
        beam_size: int,
        attn: Optional[Tensor],
        src_lengths,
        max_len: int,
    ):
        """Finalize hypotheses, store finalized information in `finalized`,
        and change `finished` accordingly.

        Returns the list of sentence indices that were newly finished.

        Args:
            bbsz_idx (Tensor): flat (batch * beam) indices of the hypotheses
                that just emitted eos
        """
        assert bbsz_idx.numel() == eos_scores.numel()

        # clone relevant token and attention tensors.
        # tokens is (batch * beam, max_len), so the index_select gets the
        # newly-EOS rows, then selects cols 1..{step + 2}
        tokens_clone = tokens.index_select(0, bbsz_idx)[
            :, 1 : step + 2
        ]  # skip the first index, which is EOS

        tokens_clone[:, step] = self.eos
        attn_clone = (
            attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
            if attn is not None
            else None
        )

        # compute scores per token position
        pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
        pos_scores[:, step] = eos_scores
        # convert from cumulative to per-position scores
        pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

        # normalize sentence-level scores
        if self.normalize_scores:
            eos_scores /= (step + 1) ** self.len_penalty
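        # E.g. a 10-token hypothesis with cumulative log-prob -12.0 scores
        # -12.0 / 10**1.0 = -1.2 under the default len_penalty=1.0; with
        # len_penalty=2.0 it would score -12.0 / 100 = -0.12, so larger
        # penalties favor longer outputs (matching the docstring above).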

        # cum_unfin records, for each unfinished sentence, how many finished
        # sentences precede it in the original batch
        cum_unfin: List[int] = []
        prev = 0
        for f in finished:
            if f:
                prev += 1
            else:
                cum_unfin.append(prev)

        # set() is not supported in script export
        sents_seen: Dict[str, Optional[Tensor]] = {}
        for i in range(bbsz_idx.size()[0]):
            idx = bbsz_idx[i]
            score = eos_scores[i]
            # sentence index in the current (possibly reduced) batch
            unfin_idx = idx // beam_size
            # sentence index in the original (unreduced) batch
            sent = unfin_idx + cum_unfin[unfin_idx]

            seen = str(sent.item()) + "_" + str(unfin_idx.item())
            if seen not in sents_seen:
                sents_seen[seen] = None

            if self.match_source_len and step > src_lengths[unfin_idx]:
                score = torch.tensor(-math.inf).to(score)

            # an input sentence is finished once beam_size hypotheses have
            # been collected for it
            if len(finalized[sent]) < beam_size:
                if attn_clone is not None:
                    # remove padding tokens from attn scores
                    hypo_attn = attn_clone[i]
                else:
                    hypo_attn = torch.empty(0)
                finalized[sent].append(
                    {
                        "tokens": tokens_clone[i],
                        "score": score,
                        "attention": hypo_attn,
                        "alignment": torch.empty(0),
                        "positional_scores": pos_scores[i],
                    }
                )

        newly_finished: List[int] = []
        for seen in sents_seen.keys():
            # check termination conditions for this sentence
            sent: int = int(float(seen.split("_")[0]))
            unfin_idx: int = int(float(seen.split("_")[1]))
            if not finished[sent] and self.is_finished(
                step, unfin_idx, max_len, len(finalized[sent]), beam_size
            ):
                finished[sent] = True
                newly_finished.append(unfin_idx)
        return newly_finished

    def is_finished(
        self,
        step: int,
        unfin_idx: int,
        max_len: int,
        finalized_sent_len: int,
        beam_size: int,
    ):
        """
        Check whether decoding for a sentence is finished, which is the case
        when beam_size finalized hypotheses have been collected for it, or
        when we reach the maximum generation length.
        """
        assert finalized_sent_len <= beam_size
        if finalized_sent_len == beam_size or step == max_len:
            return True
        return False

    def calculate_banned_tokens(
        self,
        tokens,
        step: int,
        gen_ngrams: List[Dict[str, List[int]]],
        no_repeat_ngram_size: int,
        bbsz_idx: int,
    ):
        tokens_list: List[int] = tokens[
            bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1
        ].tolist()
        # before decoding the next token, prevent decoding of ngrams that
        # have already appeared
        ngram_index = ",".join([str(x) for x in tokens_list])
        return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], []))

    def transpose_list(self, l: List[List[int]]):
        # GeneratorExps aren't supported in TorchScript, so use list comprehensions
        min_len = min([len(x) for x in l])
        l2 = [[row[i] for row in l] for i in range(min_len)]
        return l2

    def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
        # for each hypothesis, build a map from each (n-1)-gram of generated
        # tokens to the tokens that followed it
        gen_ngrams: List[Dict[str, List[int]]] = [
            torch.jit.annotate(Dict[str, List[int]], {})
            for bbsz_idx in range(bsz * beam_size)
        ]
        cpu_tokens = tokens.cpu()
        for bbsz_idx in range(bsz * beam_size):
            gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist()
            for ngram in self.transpose_list(
                [gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]
            ):
                key = ",".join([str(x) for x in ngram[:-1]])
                gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get(
                    key, torch.jit.annotate(List[int], [])
                ) + [ngram[-1]]
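
        # E.g. with no_repeat_ngram_size=2 and generated tokens [5, 7, 5, 9],
        # gen_ngrams for that hypothesis is {"5": [7, 9], "7": [5]}; if the
        # last token is again 5, both 7 and 9 become banned continuations.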

        # no tokens are banned until at least no_repeat_ngram_size tokens
        # have been generated
        if step + 2 - self.no_repeat_ngram_size >= 0:
            banned_tokens = [
                self.calculate_banned_tokens(
                    tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx
                )
                for bbsz_idx in range(bsz * beam_size)
            ]
        else:
            banned_tokens = [
                torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
            ]
        for bbsz_idx in range(bsz * beam_size):
            lprobs[bbsz_idx][
                torch.tensor(banned_tokens[bbsz_idx]).long()
            ] = torch.tensor(-math.inf, dtype=torch.float)
        return lprobs


class EnsembleModel(nn.Module):
    """A wrapper around an ensemble of models."""

    def __init__(self, models):
        super().__init__()
        self.models_size = len(models)
        # method '__len__' is not supported in ModuleList for torch script
        self.single_model = models[0]
        self.models = nn.ModuleList(models)

        self.has_incremental: bool = False
        if all(
            hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
            for m in models
        ):
            self.has_incremental = True

    def forward(self):
        pass

    def has_encoder(self):
        return hasattr(self.single_model, "encoder")

    def has_incremental_states(self):
        return self.has_incremental

    def max_decoder_positions(self):
        return min([m.max_decoder_positions() for m in self.models])

    @torch.jit.export
    def forward_encoder(self, net_input: Dict[str, Tensor]):
        if not self.has_encoder():
            return None
        return [model.encoder.forward_torchscript(net_input) for model in self.models]

    @torch.jit.export
    def forward_decoder(
        self,
        tokens,
        encoder_outs: List[EncoderOut],
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        temperature: float = 1.0,
    ):
        log_probs = []
        avg_attn: Optional[Tensor] = None
        encoder_out: Optional[EncoderOut] = None
        for i, model in enumerate(self.models):
            if self.has_encoder():
                encoder_out = encoder_outs[i]
            # decode each model
            if self.has_incremental_states():
                decoder_out = model.decoder.forward(
                    tokens,
                    encoder_out=encoder_out,
                    incremental_state=incremental_states[i],
                )
            else:
                decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out)

            attn: Optional[Tensor] = None
            decoder_len = len(decoder_out)
            if decoder_len > 1 and decoder_out[1] is not None:
                if isinstance(decoder_out[1], Tensor):
                    attn = decoder_out[1]
                else:
                    attn_holder = decoder_out[1]["attn"]
                    if isinstance(attn_holder, Tensor):
                        attn = attn_holder
                    elif attn_holder is not None:
                        attn = attn_holder[0]
                if attn is not None:
                    attn = attn[:, -1, :]

            # keep only the last time step and apply temperature scaling
            decoder_out_tuple = (
                decoder_out[0][:, -1:, :].div_(temperature),
                None if decoder_len <= 1 else decoder_out[1],
            )

            probs = model.get_normalized_probs(
                decoder_out_tuple, log_probs=True, sample=None
            )
            probs = probs[:, -1, :]
            if self.models_size == 1:
                return probs, attn

            log_probs.append(probs)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
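
        # Average the ensemble in probability space:
        #   log((1/K) * sum_k p_k) = logsumexp_k(log p_k) - log(K)
        # which is what the two lines below compute.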
        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
            self.models_size
        )
        if avg_attn is not None:
            avg_attn.div_(self.models_size)
        return avg_probs, avg_attn

    @torch.jit.export
    def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_outs: outputs from the encoder's ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_outs* rearranged according to *new_order*
        """
        new_outs: List[EncoderOut] = []
        if not self.has_encoder():
            return new_outs
        for i, model in enumerate(self.models):
            assert encoder_outs is not None
            new_outs.append(
                model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
            )
        return new_outs

    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        new_order,
    ):
        if not self.has_incremental_states():
            return
        for i, model in enumerate(self.models):
            model.decoder.reorder_incremental_state_scripting(
                incremental_states[i], new_order
            )


class SequenceGeneratorWithAlignment(SequenceGenerator):
    def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs):
        """Generates translations of a given source sentence.

        Produces alignments following "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            left_pad_target (bool, optional): Whether or not the hypotheses
                should be left-padded when they are teacher-forced to
                generate alignments.
        """
        super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs)
        self.left_pad_target = left_pad_target

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        finalized = super()._generate(sample, **kwargs)

        src_tokens = sample["net_input"]["src_tokens"]
        bsz = src_tokens.shape[0]
        beam_size = self.beam_size
        src_tokens, src_lengths, prev_output_tokens, tgt_tokens = self._prepare_batch_for_alignment(
            sample, finalized
        )
        if any(getattr(m, "full_context_alignment", False) for m in self.model.models):
            attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens)
        else:
            attn = [
                finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0)
                for i in range(bsz * beam_size)
            ]

        if src_tokens.device != "cpu":
            src_tokens = src_tokens.to("cpu")
            tgt_tokens = tgt_tokens.to("cpu")
            attn = [i.to("cpu") for i in attn]

        # process the attention matrices to extract hard alignments
        for i in range(bsz * beam_size):
            alignment = utils.extract_hard_alignment(
                attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos
            )
            finalized[i // beam_size][i % beam_size]["alignment"] = alignment
        return finalized

    def _prepare_batch_for_alignment(self, sample, hypothesis):
        src_tokens = sample["net_input"]["src_tokens"]
        bsz = src_tokens.shape[0]
        # expand each source sentence beam_size times
        src_tokens = (
            src_tokens[:, None, :]
            .expand(-1, self.beam_size, -1)
            .contiguous()
            .view(bsz * self.beam_size, -1)
        )
        src_lengths = sample["net_input"]["src_lengths"]
        src_lengths = (
            src_lengths[:, None]
            .expand(-1, self.beam_size)
            .contiguous()
            .view(bsz * self.beam_size)
        )
        # teacher-forcing inputs: hypothesis tokens shifted right by one
        prev_output_tokens = data_utils.collate_tokens(
            [beam["tokens"] for example in hypothesis for beam in example],
            self.pad,
            self.eos,
            self.left_pad_target,
            move_eos_to_beginning=True,
        )
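
        # E.g. a hypothesis [a, b, c, eos] becomes [eos, a, b, c] here, while
        # tgt_tokens below keeps [a, b, c, eos]; pairing the two gives the
        # standard shifted teacher-forcing inputs and targets.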
        tgt_tokens = data_utils.collate_tokens(
            [beam["tokens"] for example in hypothesis for beam in example],
            self.pad,
            self.eos,
            self.left_pad_target,
            move_eos_to_beginning=False,
        )
        return src_tokens, src_lengths, prev_output_tokens, tgt_tokens


class EnsembleModelWithAlignment(EnsembleModel):
    """A wrapper around an ensemble of models that also averages alignment attention."""

    def __init__(self, models):
        super().__init__(models)

    def forward_align(self, src_tokens, src_lengths, prev_output_tokens):
        avg_attn = None
        for model in self.models:
            decoder_out = model(src_tokens, src_lengths, prev_output_tokens)
            attn = decoder_out[1]["attn"]
            if avg_attn is None:
                avg_attn = attn
            else:
                avg_attn.add_(attn)
        if len(self.models) > 1:
            avg_attn.div_(len(self.models))
        return avg_attn


@torch.jit.script
class BeamContainer(object):
    def __init__(self, score: float, elem: Dict[str, Tensor]):
        self.score = score
        self.elem = elem

    def __lt__(self, other):
        # type: (BeamContainer) -> bool
        # TorchScript classes need the old-style type comment here. The
        # comparison is deliberately non-strict (<=) to match the behavior of
        # the sorted() call this replaces when two scores are equal.
        return self.score <= other.score
|