How Neural Networks Memorize Data
Introduction
A simple 4-layer MLP with 384 hidden units per layer is used to memorize a single image, with the pixel coordinates fed in through a positional encoding. The MLP is trained with an MSE loss and the Adam optimizer.
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dims, out_dims):
        super().__init__()
        self.in_dims = in_dims
        self.out_dims = out_dims
        # Four fully connected layers with 384 hidden units each
        self.fc1 = nn.Linear(in_dims, 384)
        self.fc2 = nn.Linear(384, 384)
        self.fc3 = nn.Linear(384, 384)
        self.fc4 = nn.Linear(384, out_dims)
        self.act = nn.LeakyReLU(0.01)
        self.dropout = nn.Dropout(0.2)  # defined here but not applied in forward()

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        x = self.act(self.fc3(x))
        x = self.fc4(x)  # no activation on the output layer
        return x
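As a quick sanity check, the model can be instantiated and run on a batch of encoded coordinates. The input dimension of 34 below assumes the 2D positional encoding introduced next (2 coordinates times 17 embedding functions), and out_dims=3 assumes an RGB image; both are illustrative values, not fixed by the model itself.

# Sanity check: one RGB prediction per encoded coordinate (dimensions are assumptions)
model = MLP(in_dims=34, out_dims=3)
dummy = torch.randn(16, 34)        # a batch of 16 encoded coordinates
print(model(dummy).shape)          # torch.Size([16, 3])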
The pixel coordinates are encoded with a positional encoding function that maps each coordinate through a set of sinusoidal functions at different frequencies. The positional encoding is defined as:
import numpy as np
import torch

class InputEncoding:
    def __init__(self, max_intervals=8, img_size=(256, 256)):
        self.max_intervals = max_intervals
        self.img_size = img_size
        self.n_freqs = max_intervals
        # Linearly spaced frequencies from 0.5*pi up to 0.5*pi*2^(n_freqs - 1)
        freq_bands = torch.linspace(
            0.5 * np.pi * 2. ** 0., 0.5 * np.pi * 2. ** (self.n_freqs - 1), self.n_freqs)
        # Identity mapping keeps the raw coordinates
        self.embed_fns = [lambda x: x]
        # Alternate sin and cos at each frequency
        for freq in freq_bands:
            self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
            self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))
        # Each of the two input coordinates is expanded by every embedding function
        self.out_dim = len(self.embed_fns) * 2

    def encode(self, x):
        # Apply every embedding function and concatenate along the feature axis
        return torch.cat([fn(x) for fn in self.embed_fns], dim=-1)
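With the model and the encoding in place, training follows the setup described in the introduction: MSE loss and Adam. The sketch below is a minimal version of that loop; the coordinate normalization to [0, 1], the learning rate, the step count, and the placeholder image tensor img are assumptions, not the author's exact configuration.

# Minimal training sketch (hyperparameters and coordinate normalization are assumptions)
H, W = 256, 256
encoder = InputEncoding(max_intervals=8, img_size=(H, W))
model = MLP(in_dims=encoder.out_dim, out_dims=3)

# Normalized pixel coordinate grid of shape [H*W, 2]
ys, xs = torch.meshgrid(
    torch.linspace(0, 1, H), torch.linspace(0, 1, W), indexing="ij")
coords = torch.stack([xs, ys], dim=-1).reshape(-1, 2)
inputs = encoder.encode(coords)      # [H*W, encoder.out_dim]
targets = img.reshape(-1, 3)         # img: [H, W, 3] float tensor in [0, 1] (placeholder, load your own)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

for step in range(2000):
    optimizer.zero_grad()
    pred = model(inputs)
    loss = loss_fn(pred, targets)
    loss.backward()
    optimizer.step()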
Visualize the memorized image:
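A minimal sketch of how the memorized image could be reconstructed and displayed, reusing model, inputs, H, and W from the training sketch above; the use of matplotlib here is an assumption.

# Reconstruct the image from the trained MLP and display it
import matplotlib.pyplot as plt

with torch.no_grad():
    recon = model(inputs).clamp(0, 1).reshape(H, W, 3)

plt.imshow(recon.numpy())
plt.axis("off")
plt.show()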