JointLoss
Bases: ModuleDict
__init__
Joint loss function.
Takes a list of nn.Module losses, computes each loss in the list, and sums the (optionally weighted) outputs into a single joint loss.
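The following is a minimal sketch of the idea described above, not this library's implementation. It assumes that each component loss is called as loss(yhat, target), that weights=None gives every loss a weight of 1.0, and that losses are stored under the keys loss1, loss2, and so on. The class name JointLossSketch and these key names are hypothetical.

```python
from typing import List, Optional

import torch
import torch.nn as nn


class JointLossSketch(nn.ModuleDict):
    """Hypothetical sketch: a weighted sum of several nn.Module losses."""

    def __init__(
        self, losses: List[nn.Module], weights: Optional[List[float]] = None
    ) -> None:
        super().__init__()
        # Assumption: weights=None gives every loss a weight of 1.0.
        self.weights = [1.0] * len(losses) if weights is None else list(weights)
        # Register each loss as a submodule under hypothetical keys "loss1", "loss2", ...
        for i, loss in enumerate(losses):
            self[f"loss{i + 1}"] = loss

    def forward(self, yhat: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Assumption: every component loss takes the same (yhat, target) pair.
        total = torch.zeros((), dtype=yhat.dtype, device=yhat.device)
        for weight, loss in zip(self.weights, self.values()):
            total = total + weight * loss(yhat, target)
        return total


# Usage: 0.5 * MSE + 0.5 * L1 on a tiny example.
crit = JointLossSketch([nn.MSELoss(), nn.L1Loss()], weights=[0.5, 0.5])
yhat = torch.tensor([1.0, 2.0])
target = torch.tensor([0.0, 0.0])
loss = crit(yhat, target)
print(loss.item())  # 2.0  (MSE = 2.5, L1 = 1.5)
```

Subclassing nn.ModuleDict registers every component loss as a submodule, so any parameters or buffers those losses hold follow the joint loss through .to(device) and state_dict().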
Parameters:
Name | Type | Description | Default |
---|---|---|---|
losses | List[Module] | List of initialized nn.Module losses. | required |
weights | Optional[List[float]] | List of weights, one per loss; each must lie in [0, 1]. | None |
Raises:
Type | Description |
---|---|
ValueError | If more than 4 losses are given, or if any weight lies outside [0, 1]. |
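The two documented ValueError conditions can be sketched as a plain-Python check. The function name check_joint_loss_args is hypothetical; the 1.0 default for weights=None and the requirement of exactly one weight per loss are assumptions not stated in this reference.

```python
from typing import List, Optional, Sequence


def check_joint_loss_args(
    losses: Sequence[object], weights: Optional[Sequence[float]] = None
) -> List[float]:
    """Hypothetical validator for the documented ValueError conditions."""
    # Documented: more than 4 losses is an error.
    if len(losses) > 4:
        raise ValueError(f"At most 4 losses are supported, got {len(losses)}.")
    # Assumption: no weights means every loss counts fully.
    if weights is None:
        return [1.0] * len(losses)
    # Assumption (not documented): exactly one weight is required per loss.
    if len(weights) != len(losses):
        raise ValueError(f"Expected {len(losses)} weights, got {len(weights)}.")
    # Documented: every weight must lie in [0, 1].
    if not all(0.0 <= w <= 1.0 for w in weights):
        raise ValueError(f"Weights must lie in [0, 1], got {list(weights)}.")
    return [float(w) for w in weights]


print(check_joint_loss_args(["mse", "l1"], [0.7, 0.3]))  # [0.7, 0.3]

try:
    check_joint_loss_args(["mse"], [1.5])
except ValueError as err:
    print(err)  # Weights must lie in [0, 1], got [1.5].
```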
forward
Compute the joint loss.
Returns:
Type | Description |
---|---|
torch.Tensor | The computed joint loss. |
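Assuming each component loss reduces to a scalar (as the default nn.MSELoss and nn.L1Loss do), the joint loss is a single 0-dim tensor, so one backward() call propagates gradients through every component loss. The sketch below uses a toy nn.Linear model and computes the weighted sum inline with built-in PyTorch losses rather than through JointLoss itself; the model, losses, and weights are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and data (illustrative only).
model = nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)

# Hypothetical component losses and weights in [0, 1].
losses = [nn.MSELoss(), nn.L1Loss()]
weights = [0.7, 0.3]

yhat = model(x)
# Weighted sum of the component losses: a single 0-dim tensor.
joint = sum(w * fn(yhat, y) for w, fn in zip(weights, losses))

# One backward pass fills gradients for every parameter any component loss depends on.
joint.backward()
print(joint.shape, model.weight.grad.shape)  # torch.Size([]) torch.Size([1, 4])
```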