JABS Vision Modules

Modules for jabs.vision models.

decoders

Decoder modules for vision models.

ConvBNReLU

Bases: Sequential

Conv2d + BatchNorm + ReLU block.
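
The block above is a plain sequential stack. The following is a minimal sketch of that pattern in PyTorch, not the jabs implementation. The class name ConvBNReLUSketch, the 3x3 kernel with "same" padding, and bias=False are assumptions made for illustration.

```python
import torch
from torch import nn


class ConvBNReLUSketch(nn.Sequential):
    """Hypothetical Conv2d -> BatchNorm2d -> ReLU stack (assumed 3x3 kernel)."""

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3) -> None:
        super().__init__(
            # Bias is redundant when BatchNorm follows, so it is disabled here.
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=kernel_size // 2, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )


block = ConvBNReLUSketch(16, 32)
y = block(torch.randn(2, 16, 64, 64))
print(y.shape)  # torch.Size([2, 32, 64, 64])
```

With an odd kernel and padding of kernel_size // 2, the spatial size is preserved, so only the channel count changes.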

DecoderBlock

Bases: Module

Single U-Net decoder block: upsample, concat skip, double conv.

__init__(in_ch, skip_ch, out_ch)

Initialize decoder block.

Parameters:

Name     Type  Description                                     Default
in_ch    int   Input channels from previous decoder stage.     required
skip_ch  int   Channels from skip connection (0 if no skip).   required
out_ch   int   Output channels.                                required

forward(x, skip=None)

Apply decoder block.

Parameters:

Name  Type           Description                         Default
x     Tensor         Input tensor from previous stage.   required
skip  Tensor | None  Optional skip connection tensor.    None

Returns:

Type    Description
Tensor  Decoded tensor.
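
The three steps named for DecoderBlock (upsample, concat skip, double conv) can be sketched as follows. This is a hypothetical re-creation, not the jabs source: DecoderBlockSketch and conv_bn_relu are illustrative names, and nearest-neighbour upsampling and 3x3 convolutions are assumptions. It mirrors the documented signature, including skip_ch=0 when no skip tensor is passed.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import nn


def conv_bn_relu(in_ch: int, out_ch: int) -> nn.Sequential:
    # 3x3 Conv2d + BatchNorm2d + ReLU; the kernel size is an assumption.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )


class DecoderBlockSketch(nn.Module):
    """Hypothetical decoder block: upsample, concat skip, double conv."""

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int) -> None:
        super().__init__()
        # The first conv sees the upsampled input plus the skip channels;
        # skip_ch == 0 means nothing is concatenated.
        self.conv1 = conv_bn_relu(in_ch + skip_ch, out_ch)
        self.conv2 = conv_bn_relu(out_ch, out_ch)

    def forward(self, x: torch.Tensor, skip: torch.Tensor | None = None) -> torch.Tensor:
        if skip is not None:
            # Resize to the skip's exact spatial size so odd resolutions align.
            x = F.interpolate(x, size=skip.shape[-2:], mode="nearest")
            x = torch.cat([x, skip], dim=1)
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv2(self.conv1(x))


x = torch.randn(1, 64, 16, 16)
skip = torch.randn(1, 32, 32, 32)
out = DecoderBlockSketch(in_ch=64, skip_ch=32, out_ch=48)(x, skip)
out_noskip = DecoderBlockSketch(in_ch=64, skip_ch=0, out_ch=48)(x)
print(out.shape, out_noskip.shape)  # torch.Size([1, 48, 32, 32]) torch.Size([1, 48, 32, 32])
```

Upsampling to the skip tensor's size, rather than by a fixed factor of 2, is one common way to avoid off-by-one mismatches when an encoder stage rounds odd resolutions.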

UNetDecoder

Bases: Module

U-Net style decoder with skip connections.

Takes multi-scale encoder features and progressively upsamples while fusing with skip connections.

__init__(encoder_channels, decoder_channels=None)

Initialize U-Net decoder.

Parameters:

Name              Type              Description                                                                                     Default
encoder_channels  list[int]         Channel dims from encoder, low-to-high stride (e.g., [16, 24, 40, 112, 960] for MobileNetV3).  required
decoder_channels  list[int] | None  Output channels per decoder stage. Defaults to [256, 128, 64, 32].                              None
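
To see how encoder_channels and decoder_channels could combine into per-stage DecoderBlock arguments, here is a pure-Python sketch. The wiring is an assumption, not confirmed by these docs: the deepest encoder feature is the decoder input, and stage i fuses the skip from the next-shallower encoder level. The helper name plan_decoder_stages is hypothetical.

```python
def plan_decoder_stages(encoder_channels, decoder_channels=None):
    """Return (in_ch, skip_ch, out_ch) per decoder stage (assumed wiring)."""
    if decoder_channels is None:
        decoder_channels = [256, 128, 64, 32]
    # Skips from the deepest-but-one level down to the shallowest,
    # e.g. [112, 40, 24, 16] for MobileNetV3-style channels.
    skips = list(reversed(encoder_channels[:-1]))
    stages = []
    in_ch = encoder_channels[-1]  # deepest feature feeds the first stage
    for i, out_ch in enumerate(decoder_channels):
        skip_ch = skips[i] if i < len(skips) else 0  # 0 when no skip is left
        stages.append((in_ch, skip_ch, out_ch))
        in_ch = out_ch
    return stages


print(plan_decoder_stages([16, 24, 40, 112, 960]))
# [(960, 112, 256), (256, 40, 128), (128, 24, 64), (64, 16, 32)]
```

Under this wiring, four stages starting from the stride-32 feature each double the resolution (32, 16, 8, 4, 2), which matches the documented stride-2 output of forward.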

forward(features)

Run decoder forward pass.

Parameters:

Name      Type          Description                                                                                          Default
features  list[Tensor]  Encoder features ordered low-to-high stride (e.g., [stride2, stride4, stride8, stride16, stride32]).  required

Returns:

Type    Description
Tensor  Decoded feature map at stride 2.
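
The forward pass described above starts from the deepest feature and fuses shallower skips in reverse order. The sketch below shows only that ordering and the resulting shapes. unet_decode and StandInBlock are hypothetical names, and StandInBlock deliberately replaces the double conv with a single 1x1 conv to keep the example short; it is not the jabs DecoderBlock.

```python
import torch
import torch.nn.functional as F
from torch import nn


class StandInBlock(nn.Module):
    """Hypothetical stand-in for a decoder block: upsample, concat, 1x1 conv."""

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int) -> None:
        super().__init__()
        self.proj = nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size=1)

    def forward(self, x, skip=None):
        if skip is not None:
            x = F.interpolate(x, size=skip.shape[-2:], mode="nearest")
            x = torch.cat([x, skip], dim=1)
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.proj(x)


def unet_decode(features, blocks):
    """Run blocks deepest-first; features are ordered low-to-high stride."""
    x = features[-1]            # deepest feature, e.g. stride 32
    skips = features[-2::-1]    # stride 16, 8, 4, 2
    for i, block in enumerate(blocks):
        skip = skips[i] if i < len(skips) else None
        x = block(x, skip)
    return x


encoder_channels = [16, 24, 40, 112, 960]
decoder_channels = [256, 128, 64, 32]
# Fake encoder outputs for a 256x256 input at strides 2, 4, 8, 16, 32.
features = [
    torch.randn(1, c, 256 // s, 256 // s)
    for c, s in zip(encoder_channels, [2, 4, 8, 16, 32])
]
in_chs = [encoder_channels[-1]] + decoder_channels[:-1]
skip_chs = encoder_channels[-2::-1]
blocks = [StandInBlock(i, s, o) for i, s, o in zip(in_chs, skip_chs, decoder_channels)]
out = unet_decode(features, blocks)
print(out.shape)  # torch.Size([1, 32, 128, 128])
```

For a 256x256 input, the result is 128x128, i.e. stride 2, with the channel count of the last decoder stage.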

heads

Task-specific head modules.

HeatmapHead

Bases: Module

Head for heatmap regression (keypoint detection).

forward(x)

Forward computation.
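
The docs do not specify HeatmapHead's layers. As a hedged illustration of heatmap regression for keypoints, the sketch below assumes a small conv stack that emits one heatmap per keypoint, plus a simple argmax decoder that reads integer peak locations. HeatmapHeadSketch, argmax_keypoints, num_keypoints, and mid_ch are all hypothetical names, not part of jabs.

```python
import torch
from torch import nn


class HeatmapHeadSketch(nn.Module):
    """Hypothetical heatmap head: 3x3 conv + ReLU, then a 1x1 conv to K maps."""

    def __init__(self, in_ch: int, num_keypoints: int, mid_ch: int = 32) -> None:
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # One output channel (heatmap) per keypoint.
            nn.Conv2d(mid_ch, num_keypoints, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)  # (N, K, H, W) heatmap logits


def argmax_keypoints(heatmaps: torch.Tensor) -> torch.Tensor:
    """Return the integer (x, y) peak of each heatmap, shape (N, K, 2)."""
    n, k, h, w = heatmaps.shape
    flat_idx = heatmaps.reshape(n, k, h * w).argmax(dim=-1)
    ys = torch.div(flat_idx, w, rounding_mode="floor")
    xs = flat_idx % w
    return torch.stack([xs, ys], dim=-1)


head = HeatmapHeadSketch(in_ch=32, num_keypoints=4)
print(head(torch.randn(2, 32, 64, 64)).shape)  # torch.Size([2, 4, 64, 64])

hm = torch.zeros(1, 2, 8, 8)
hm[0, 0, 3, 5] = 1.0  # keypoint 0 peaks at row 3, col 5
hm[0, 1, 6, 1] = 1.0  # keypoint 1 peaks at row 6, col 1
print(argmax_keypoints(hm).tolist())  # [[[5, 3], [1, 6]]]
```

Argmax decoding is only one option. Sub-pixel refinement or a soft-argmax are common alternatives when integer-grid precision is not enough.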