How Neural Networks Memorize Data

Introduction

A simple 4-layer MLP with 384 hidden units per layer is used to memorize a single image from positionally encoded pixel coordinates. The MLP is trained with an MSE loss and the Adam optimizer; a sketch of the full training loop is shown after the encoding code below.

import torch.nn as nn


class MLP(nn.Module):
    def __init__(self, in_dims, out_dims):
        super(MLP, self).__init__()
        self.in_dims = in_dims
        self.out_dims = out_dims
        # Four fully connected layers with 384 hidden units each
        self.fc1 = nn.Linear(in_dims, 384)
        self.fc2 = nn.Linear(384, 384)
        self.fc3 = nn.Linear(384, 384)
        self.fc4 = nn.Linear(384, out_dims)
        self.act = nn.LeakyReLU(0.01)
        self.dropout = nn.Dropout(0.2)  # defined but not applied in forward()

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        x = self.act(self.fc3(x))
        x = self.fc4(x)  # no activation on the output layer
        return x
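
As a quick shape check, the hypothetical snippet below instantiates the MLP and runs a random batch through it. The input width of 34 matches the encoder defined below with its default settings ((1 + 2·8) embedding functions × 2 coordinates); the batch size is arbitrary.

import torch

model = MLP(in_dims=34, out_dims=3)   # 34 = encoder output width with the defaults below
dummy = torch.randn(1024, 34)         # a batch of 1024 encoded (x, y) coordinates
print(model(dummy).shape)             # torch.Size([1024, 3]): one RGB value per coordinate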

The image coordinates are encoded with a positional encoding that maps each (x, y) coordinate through a set of sinusoidal functions at different frequencies. The encoding is defined as:


import numpy as np
import torch


class InputEncoding:
    def __init__(self, max_intervals=8, img_size=[256, 256]):
        self.max_intervals = max_intervals
        self.img_size = img_size  # stored for reference; not used by encode()

        self.n_freqs = max_intervals
        # Linearly spaced frequencies from 0.5*pi up to 0.5*pi * 2**(n_freqs - 1)
        freq_bands = torch.linspace(
            0.5*np.pi*2.**0., 0.5*np.pi*2.**(self.n_freqs - 1), self.n_freqs)
        # Identity term, then alternating sin and cos for every frequency
        self.embed_fns = [lambda x: x]
        for freq in freq_bands:
            self.embed_fns.append(lambda x, freq=freq: torch.sin(x * freq))
            self.embed_fns.append(lambda x, freq=freq: torch.cos(x * freq))
        # Each of the 2 input coordinates (x, y) goes through every embedding function
        self.out_dim = len(self.embed_fns) * 2

    def encode(self, x):
        return torch.cat([fn(x) for fn in self.embed_fns], dim=-1)
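
Putting the two pieces together, the sketch below builds the pixel coordinate grid, encodes it, and fits the MLP with an MSE loss and Adam, as described in the introduction. The coordinate normalization to [0, 1], the learning rate, the step count, and the variable `img` (the target image as a [256, 256, 3] float tensor in [0, 1]) are illustrative assumptions, not necessarily the original settings.

import torch
import torch.nn as nn

H, W = 256, 256

# All (x, y) pixel coordinates, normalized to [0, 1] (normalization is an assumption)
ys, xs = torch.meshgrid(torch.linspace(0, 1, H), torch.linspace(0, 1, W), indexing="ij")
coords = torch.stack([xs, ys], dim=-1).reshape(-1, 2)       # [H*W, 2]

encoder = InputEncoding(max_intervals=8, img_size=[H, W])
coords_enc = encoder.encode(coords)                         # [H*W, 34]

# img: target image as a float tensor in [0, 1], shape [H, W, 3] (assumed already loaded)
target = img.reshape(-1, 3)                                 # [H*W, 3]

model = MLP(in_dims=encoder.out_dim, out_dims=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)   # learning rate is an assumption
criterion = nn.MSELoss()

for step in range(2000):                                    # step count is an assumption
    optimizer.zero_grad()
    loss = criterion(model(coords_enc), target)
    loss.backward()
    optimizer.step()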

Target Image

Visualize the memorized image:
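
A minimal sketch of the visualization, assuming `model`, `coords_enc`, `H`, and `W` from the training sketch above; matplotlib is used here for display.

import matplotlib.pyplot as plt
import torch

# Query the trained MLP at every pixel coordinate and reshape back into an image
model.eval()
with torch.no_grad():
    pred = model(coords_enc)                                # [H*W, 3]
recon = pred.clamp(0, 1).reshape(H, W, 3).cpu().numpy()

plt.imshow(recon)
plt.axis("off")
plt.show()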
