# igtr_loss.py (265 B)

  1. from torch import nn
  2. class IGTRLoss(nn.Module):
  3. def __init__(self, **kwargs):
  4. super(IGTRLoss, self).__init__()
  5. def forward(self, predicts, batch):
  6. if isinstance(predicts, list):
  7. predicts = predicts[0]
  8. return predicts