From 9ce761b6df4d3931c91197cca1b3f5d281df9e96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fernando=20P=C3=A9rez-Garc=C3=ADa?= Date: Tue, 26 Sep 2023 23:52:01 +0100 Subject: [PATCH] Test 1D --- unet/encoding.py | 6 +++--- unet/unet.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/unet/encoding.py b/unet/encoding.py index f91a221..917161c 100644 --- a/unet/encoding.py +++ b/unet/encoding.py @@ -108,9 +108,9 @@ def __init__( dropout=dropout, ) - if dimensions == 2: - out_channels_second = out_channels_first - elif dimensions == 3: + if dimensions == 3: + out_channels_second = 2 * out_channels_first + else: out_channels_second = 2 * out_channels_first self.conv2 = ConvolutionalBlock( dimensions, diff --git a/unet/unet.py b/unet/unet.py index f97293f..ecc8f1d 100644 --- a/unet/unet.py +++ b/unet/unet.py @@ -78,10 +78,10 @@ def __init__( ) # Decoder - if dimensions == 2: - power = depth - 1 - elif dimensions == 3: + if dimensions == 3: power = depth + else: + power = depth - 1 in_channels = self.bottom_block.out_channels in_channels_skip_connection = out_channels_first_layer * 2**power num_decoding_blocks = depth