# This code tries to improve TalkNet's performance on long phonemes by avoiding
# the use of blank tokens.
# To use it, modify `nemo/collections/tts/modules/talknet.py` by replacing
# `GaussianEmbedding` (no changes to imports).

# v1 of this fix: https://u.smutty.horse/mfyczyuesfn.py
# v2 of this fix: https://u.smutty.horse/mjogmjdqoyq.py

# This version gets rid of blank tokens like the previous versions, but uses
# the standard Gaussian upsampling instead of LengthRegulator. Hopefully, this
# gets the benefits of the standard method (works well for fast speech/short
# phonemes) without the drawbacks of blank tokens (long phonemes change type).


class GaussianEmbedding(nn.Module):
    """Gaussian embedding layer that drops interleaved blank tokens.

    Token embeddings are upsampled to frame resolution with Gaussian weights
    centred on each token's time span. Before upsampling, every blank token
    except the leading one is removed and its duration is folded into the
    token that precedes it.

    Args:
        vocab: Object exposing ``labels`` (its length sets the embedding table
            size) and ``pad`` (the padding token id).
        d_emb: Embedding dimension.
        sigma_c: Divisor applied to a token's duration to obtain the standard
            deviation of its Gaussian.
        merge_blanks: If true, only every second token (starting at index 1)
            contributes to the output.
    """

    EPS = 1e-6

    def __init__(
        self, vocab, d_emb, sigma_c=2.0, merge_blanks=False,
    ):
        super().__init__()

        self.embed = nn.Embedding(len(vocab.labels), d_emb)
        self.pad = vocab.pad
        self.sigma_c = sigma_c
        self.merge_blanks = merge_blanks

    def forward(self, text, durs):
        """Upsample token embeddings to frame resolution.

        Args:
            text: Token ids, indexed as ``[B, N]``. Presumably laid out as
                ``blank, tok, blank, tok, ..., blank`` (odd ``N``) -- TODO
                confirm against the caller.
            durs: Per-token durations with the same layout as ``text``.

        Returns:
            Frame-level embeddings ``[B, T, E]``, where ``T`` is the last
            dimension of the tensor returned by ``repeat_merge``.
        """
        # Remove blank tokens. We keep the first blank so that the model
        # knows if there's silence at the beginning of the clip.
        # NOTE(review): an even-length input makes the duration sum below
        # either fail or silently broadcast -- confirm inputs are always odd.
        text = torch.cat((text[:, 0].unsqueeze(1), text[:, 1::2],), 1)
        # Add the duration of each blank token to the preceding token
        # (again, except for the first blank).
        durs = torch.cat((durs[:, 0].unsqueeze(1), durs[:, 1::2] + durs[:, 2::2],), 1)

        # Everything below is exactly the same as in standard TalkNet.

        # Fake padding: two extra pad tokens with zero duration at the end.
        text = F.pad(text, [0, 2, 0, 0], value=self.pad)
        durs = F.pad(durs, [0, 2, 0, 0], value=0)

        # Presumably expands every token by its duration and pads the batch
        # with `self.pad`; only its length and pad positions are used here.
        repeats = AudioToCharWithDursF0Dataset.repeat_merge(text, durs, self.pad)
        total_time = repeats.shape[-1]

        # Centroids: [B,T,N]
        # Each centre is half the token's duration past the cumulative
        # duration of all earlier tokens.
        c = (durs / 2.0) + F.pad(torch.cumsum(durs, dim=-1)[:, :-1], [1, 0, 0, 0], value=0)
        c = c.unsqueeze(1).repeat(1, total_time, 1)

        # Sigmas: [B,T,N]
        sigmas = durs
        sigmas = sigmas.float() / self.sigma_c
        # EPS keeps the scale strictly positive for zero-duration tokens.
        sigmas = sigmas.unsqueeze(1).repeat(1, total_time, 1) + self.EPS
        assert c.shape == sigmas.shape

        # Times at indexes
        t = torch.arange(total_time, device=c.device).view(1, -1, 1).repeat(durs.shape[0], 1, durs.shape[-1]).float()
        # Evaluate each frame at its midpoint.
        t = t + 0.5

        ns = slice(None)
        if self.merge_blanks:
            # NOTE(review): blanks were already removed above, so this stride
            # now skips real tokens -- confirm `merge_blanks` is meant to stay
            # False with this variant.
            ns = slice(1, None, 2)

        # Weights: [B,T,N]
        d = torch.distributions.normal.Normal(c, sigmas)
        w = d.log_prob(t).exp()[:, :, ns]  # [B,T,N]
        # Pad tokens get no weight; the rest are normalised per frame.
        pad_mask = (text == self.pad)[:, ns].unsqueeze(1).repeat(1, total_time, 1)
        w.masked_fill_(pad_mask, 0.0)  # noqa
        w = w / (w.sum(-1, keepdim=True) + self.EPS)
        # Frames marked as padding: clear every weight, then put the full
        # weight on the last selected token only.
        pad_mask = (repeats == self.pad).unsqueeze(-1).repeat(1, 1, text[:, ns].size(1))  # noqa
        w.masked_fill_(pad_mask, 0.0)  # noqa
        pad_mask[:, :, :-1] = False
        w.masked_fill_(pad_mask, 1.0)  # noqa

        # Embeds
        u = torch.bmm(w, self.embed(text)[:, ns, :])  # [B,T,E]

        return u