1234567891011121314 |
- from torch import nn
class LISTERLoss(nn.Module):
    """Pass-through "loss" module for LISTER: returns the model output as-is.

    No loss is computed here. ``forward`` hands back ``predicts`` unchanged,
    except that a ``list`` is unwrapped to its first element. Presumably the
    LISTER model computes its own loss and this module only forwards it --
    NOTE(review): confirm against the model/decoder code.
    """

    def __init__(self, **kwargs):
        # Keyword arguments (e.g. coming from a config) are accepted and
        # ignored; this module holds no parameters or state of its own.
        super().__init__()

    def forward(self, predicts, batch):
        """Return the model output, taking element 0 when it is a list.

        Args:
            predicts: Model output. A ``list`` (including list subclasses) is
                unwrapped to its first element; any other type (tuple, dict,
                tensor, ...) is returned untouched.
            batch: Unused here; presumably kept for a uniform loss-call
                signature -- TODO confirm with callers.

        Returns:
            ``predicts[0]`` if ``predicts`` is a list, otherwise ``predicts``.

        Raises:
            IndexError: If ``predicts`` is an empty list.
        """
        return predicts[0] if isinstance(predicts, list) else predicts
|