JointLoss

Bases: ModuleDict

__init__

__init__(losses: List[Module], weights: List[float] = None) -> None

Joint loss function.

Takes a list of nn.Module losses, computes each one, and sums the outputs into a single joint loss.

Parameters:

Name      Type           Description                             Default
losses    List[Module]   List of initialized nn.Module losses.   required
weights   List[float]    List of weights for each loss.          None

Raises:

Type         Description
ValueError   If more than 4 losses are given, or if any weight lies outside the range [0, 1].
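
The two documented ValueError conditions can be illustrated with a small standalone check. This is only a sketch: `check_joint_loss_args` and `MAX_LOSSES` are hypothetical names, not part of the library, and the fallback to equal weights of 1.0 when `weights` is None is an assumption that this page does not confirm.

```python
from typing import List, Optional, Sequence

MAX_LOSSES = 4  # upper bound taken from the Raises section above


def check_joint_loss_args(
    n_losses: int, weights: Optional[Sequence[float]] = None
) -> List[float]:
    """Hypothetical helper mirroring the documented ValueError conditions."""
    if n_losses > MAX_LOSSES:
        raise ValueError(f"At most {MAX_LOSSES} losses allowed, got {n_losses}.")
    if weights is None:
        # Assumption: without explicit weights, every loss counts equally.
        return [1.0] * n_losses
    if not all(0.0 <= w <= 1.0 for w in weights):
        raise ValueError(f"Weights must lie in [0, 1], got {list(weights)}.")
    return list(weights)
```

Under these assumptions, `check_joint_loss_args(2, [0.5, 1.0])` passes the weights through unchanged, while five losses or a weight such as 1.5 raises ValueError.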

forward

forward(**kwargs) -> torch.Tensor

Compute the joint loss.

Returns:

Type     Description
Tensor   The computed joint loss.
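
To make the forward pass concrete, here is a minimal sketch of a weighted-sum joint loss built on nn.ModuleDict. It is not the library's implementation: `JointLossSketch`, the toy `MSE` and `MAE` losses, the keyword names `yhat` and `target`, and the default weight of 1.0 are all assumptions chosen for illustration. The sketch also assumes that every sub-loss accepts the same keyword arguments and that each weight multiplies its loss before summing.

```python
from typing import List, Optional

import torch
import torch.nn as nn


class MSE(nn.Module):
    """Toy loss; accepts the shared keyword arguments and ignores extras."""

    def forward(self, yhat: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
        return torch.mean((yhat - target) ** 2)


class MAE(nn.Module):
    """Toy loss with the same keyword interface as MSE."""

    def forward(self, yhat: torch.Tensor, target: torch.Tensor, **kwargs) -> torch.Tensor:
        return torch.mean(torch.abs(yhat - target))


class JointLossSketch(nn.ModuleDict):
    """Hypothetical stand-in for a joint loss: a weighted sum of sub-losses."""

    def __init__(self, losses: List[nn.Module], weights: Optional[List[float]] = None) -> None:
        super().__init__()
        # Assumption: missing weights default to 1.0 for every loss.
        self.weights = [1.0] * len(losses) if weights is None else list(weights)
        for i, loss in enumerate(losses):
            self[f"loss{i + 1}"] = loss  # register each loss as a sub-module

    def forward(self, **kwargs) -> torch.Tensor:
        # Every sub-loss receives the same keyword arguments.
        terms = [w * loss(**kwargs) for loss, w in zip(self.values(), self.weights)]
        return torch.stack(terms).sum()


yhat = torch.tensor([1.0, 2.0, 3.0])
target = torch.tensor([1.0, 2.0, 5.0])
joint = JointLossSketch([MSE(), MAE()], weights=[1.0, 0.5])
loss = joint(yhat=yhat, target=target)  # scalar tensor: 1.0 * MSE + 0.5 * MAE
```

Passing inputs as keyword arguments lets each sub-loss pick out the arguments it needs, which is one plausible reason for the `forward(**kwargs)` signature.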

extra_repr

extra_repr() -> str

Return extra information to include in the module's printed representation.
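
`extra_repr` is the standard nn.Module hook for customizing what `print(module)` shows. What JointLoss actually reports is not stated here, so the sketch below simply assumes the weights are displayed; `WeightedSketch` is a hypothetical class used only for illustration.

```python
from typing import List

import torch.nn as nn


class WeightedSketch(nn.Module):
    """Hypothetical module that reports its weights when printed."""

    def __init__(self, weights: List[float]) -> None:
        super().__init__()
        self.weights = list(weights)

    def extra_repr(self) -> str:
        # The returned string is placed inside the parentheses of the module's repr.
        return f"weights={self.weights}"


module = WeightedSketch([1.0, 0.5])
print(module)  # WeightedSketch(weights=[1.0, 0.5])
```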