cd ..
Dense Blocks
import torch
import torch.nn as nn
class DenseBlock(nn.Module):
def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, activation='relu'):
super(DenseBlock, self).__init__()
self.feature_maps = torch.linspace(16, 128, 8).int()
self.layers = nn.ModuleList()
for i in range(len(self.feature_maps)):
self.layers.append(nn.Sequential(
nn.Conv2d(in_channels + (i * 16), 16, kernel_size=kernel_size, stride=stride, padding=padding),
nn.BatchNorm2d(16),
nn.ReLU(inplace=True) if activation == 'relu' else nn.GELU(),
))
def forward(self, x):
outputs = [x]
for layer in self.layers:
x = layer(torch.cat(outputs, dim=1))
outputs.append(x)
return torch.cat(outputs, dim=1)