lister_loss.py 332 B

1234567891011121314
  1. from torch import nn
  2. class LISTERLoss(nn.Module):
  3. def __init__(self, **kwargs):
  4. super(LISTERLoss, self).__init__()
  5. def forward(self, predicts, batch):
  6. # predicts = predicts['res']
  7. # loss = predicts
  8. if isinstance(predicts, list):
  9. predicts = predicts[0]
  10. return predicts