parseq_loss.py 265 B

123456789101112
  1. from torch import nn
  2. class PARSeqLoss(nn.Module):
  3. def __init__(self, **kwargs):
  4. super(PARSeqLoss, self).__init__()
  5. def forward(self, predicts, batch):
  6. # predicts = predicts['res']
  7. loss, _ = predicts
  8. return {'loss': loss}