
MultiTaskLoss

Bases: ModuleDict

__init__

__init__(head_losses: Dict[str, JointLoss], loss_weights: Dict[str, float] = None, **kwargs) -> None

Multi-task loss wrapper.

Combines the losses from different heads into one loss function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `head_losses` | `Dict[str, JointLoss]` | Dictionary of head names mapped to a loss module, e.g. `{"inst": JointLoss(MSE(), Dice()), "type": Dice()}`. | required |
| `loss_weights` | `Dict[str, float]` | Dictionary of head names mapped to the weight used for that head's loss. | `None` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If the input arguments have different lengths, or if their keys do not match. |
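A minimal sketch of the validation contract described above: mismatching lengths or keys between `head_losses` and `loss_weights` raise `ValueError`, and a missing `loss_weights` defaults every head to weight `1.0`. The function name `validate_multitask_args` is illustrative, not the library's actual implementation.

```python
from typing import Dict, Optional


def validate_multitask_args(
    head_losses: Dict[str, object],
    loss_weights: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Illustrative sketch of the __init__ checks; not the library's code.

    Returns the resolved per-head weights, defaulting to 1.0 per head.
    """
    if loss_weights is None:
        # No weights given: weight every head loss equally.
        return {name: 1.0 for name in head_losses}

    if len(loss_weights) != len(head_losses):
        raise ValueError(
            "`head_losses` and `loss_weights` have different lengths."
        )
    if set(loss_weights) != set(head_losses):
        raise ValueError(
            "`head_losses` and `loss_weights` have mismatching keys."
        )
    return loss_weights
```

With `head_losses={"inst": ..., "type": ...}` and no `loss_weights`, this resolves to `{"inst": 1.0, "type": 1.0}`.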

forward

forward(yhats: Dict[str, Tensor], targets: Dict[str, Tensor], mask: Tensor = None, **kwargs) -> torch.Tensor

Compute the joint loss of the multi-task network.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `yhats` | `Dict[str, Tensor]` | Dictionary of head names mapped to the predicted masks, e.g. `{"inst": (B, C, H, W), "type": (B, C, H, W)}`. | required |
| `targets` | `Dict[str, Tensor]` | Dictionary of head names mapped to the GT masks, e.g. `{"inst": (B, C, H, W), "type": (B, C, H, W)}`. | required |
| `mask` | `torch.Tensor` | The mask for masked losses. Shape: (B, H, W). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | The computed multi-task loss (scalar). |
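The combination performed by `forward` can be sketched as a weighted sum over heads: each head's loss is evaluated on its own prediction/target pair, scaled by that head's weight, and the results are summed into one scalar. This is a hedged sketch with plain callables standing in for the loss modules; `multitask_forward` is a hypothetical name, not the library's API.

```python
from typing import Callable, Dict


def multitask_forward(
    head_losses: Dict[str, Callable[[float, float], float]],
    loss_weights: Dict[str, float],
    yhats: Dict[str, float],
    targets: Dict[str, float],
) -> float:
    """Sum each head's loss, scaled by its weight, into one scalar."""
    total = 0.0
    for name, loss_fn in head_losses.items():
        # Each head is matched to its own prediction and target by key.
        total += loss_weights.get(name, 1.0) * loss_fn(yhats[name], targets[name])
    return total


# Toy usage: a squared-error stand-in for each head's loss module.
mse = lambda pred, target: (pred - target) ** 2
total = multitask_forward(
    head_losses={"inst": mse, "type": mse},
    loss_weights={"inst": 1.0, "type": 2.0},
    yhats={"inst": 1.0, "type": 2.0},
    targets={"inst": 0.0, "type": 0.0},
)
# total = 1.0 * 1.0 + 2.0 * 4.0 = 9.0
```

In the real module the per-head losses operate on `(B, C, H, W)` tensors (and optionally the `(B, H, W)` mask), but the weighted-sum structure is the same.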

extra_repr

extra_repr() -> str

Add extra info to the module's printed representation.