MultiTaskLoss
Bases: ModuleDict
__init__
__init__(head_losses: Dict[str, JointLoss], loss_weights: Dict[str, float] = None, **kwargs) -> None
Multi-task loss wrapper.
Combines the losses of the different prediction heads into a single loss function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
head_losses | Dict[str, Module] | Dictionary mapping head names to loss modules, e.g. {"inst": JointLoss(MSE(), Dice()), "type": Dice()}. | required |
loss_weights | Dict[str, float], default=None | Dictionary mapping head names to the weight applied to that head's loss. | None |
Raises:
Type | Description |
---|---|
ValueError | If head_losses and loss_weights have different lengths or mismatching keys. |
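The argument checks described above can be sketched in plain Python. This is an illustration, not the library's code: check_loss_weights is a hypothetical helper, plain callables stand in for the torch loss modules, and giving every head a weight of 1.0 when loss_weights is None is an assumption.

```python
from typing import Callable, Dict, Optional


def check_loss_weights(
    head_losses: Dict[str, Callable],
    loss_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Hypothetical helper mirroring the documented ValueError conditions."""
    if loss_weights is None:
        # Assumption: without explicit weights, every head gets weight 1.0.
        return {name: 1.0 for name in head_losses}
    if len(loss_weights) != len(head_losses):
        raise ValueError(
            f"Got {len(head_losses)} head losses but {len(loss_weights)} weights."
        )
    if set(loss_weights) != set(head_losses):
        raise ValueError(
            f"Mismatching keys: {sorted(head_losses)} vs {sorted(loss_weights)}."
        )
    return dict(loss_weights)


# Usage: the keys match, so the weights are returned unchanged.
print(check_loss_weights({"inst": abs, "type": abs}, {"inst": 1.0, "type": 0.5}))
# {'inst': 1.0, 'type': 0.5}
```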
forward
forward(yhats: Dict[str, Tensor], targets: Dict[str, Tensor], mask: Tensor = None, **kwargs) -> torch.Tensor
Compute the joint loss of the multi-task network.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
yhats | Dict[str, Tensor] | Dictionary mapping head names to the predicted masks, e.g. {"inst": (B, C, H, W), "type": (B, C, H, W)}. | required |
targets | Dict[str, Tensor] | Dictionary mapping head names to the ground-truth (GT) masks, e.g. {"inst": (B, C, H, W), "type": (B, C, H, W)}. | required |
mask | torch.Tensor, default=None | The mask used by masked losses. Shape (B, H, W). | None |
Returns:
Type | Description |
---|---|
Tensor | The computed multi-task loss (a scalar tensor). |
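A minimal sketch of how the per-head losses might be combined, assuming the total is the sum of each head's loss multiplied by its weight. multi_task_loss and abs_error are hypothetical names, plain callables stand in for the head loss modules, and passing mask to every head loss as a keyword argument is an assumption rather than the documented behavior.

```python
from typing import Any, Callable, Dict, Optional


def multi_task_loss(
    head_losses: Dict[str, Callable[..., Any]],
    yhats: Dict[str, Any],
    targets: Dict[str, Any],
    loss_weights: Optional[Dict[str, float]] = None,
    mask: Any = None,
) -> Any:
    """Weighted sum of per-head losses (illustrative sketch, not the library code)."""
    if loss_weights is None:
        # Assumption: without explicit weights, every head gets weight 1.0.
        loss_weights = {name: 1.0 for name in head_losses}
    total = 0.0
    for name, loss_fn in head_losses.items():
        # Assumption: each head loss takes (prediction, target) plus an optional mask.
        head_loss = loss_fn(yhats[name], targets[name], mask=mask)
        total = total + loss_weights[name] * head_loss
    return total


def abs_error(yhat: float, target: float, mask: Any = None) -> float:
    """Toy stand-in for a real head loss; ignores the mask."""
    return abs(yhat - target)


losses = {"inst": abs_error, "type": abs_error}
yhats = {"inst": 3.0, "type": 1.0}
targets = {"inst": 1.0, "type": 0.0}
print(multi_task_loss(losses, yhats, targets, {"inst": 1.0, "type": 0.5}))  # 2.5
```

Because the sketch only uses + and *, the same loop also works when the head losses return scalar torch tensors.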