Discrepancy between MOSAIKS paper and MPC RCF implementation
Hi!
Given the recent open-sourcing of AlphaEarth by DeepMind and its comparison with MOSAIKS [1], I took a closer look at the implementation from @calebrob6's PR (#70).
Expected behavior (from paper)
Based on the MOSAIKS paper, the Random Convolution Features (RCF) process works as follows:
- Given an input image `I`, features are computed by randomly sampling `K` patches from across `N` images from the training set
- These `K` patches are then convolved over `I` to obtain `K` feature maps
- The feature maps are averaged over pixels to produce a `K`-dimensional feature vector `X_i`
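To make the steps above concrete, here is a minimal sketch of patch-based featurization in PyTorch. The helper names `sample_patches` and `featurize` are hypothetical, the random tensors stand in for real training imagery, and the ReLU before pooling is an assumption (the steps above do not mention a nonlinearity). It is not taken from the paper's code.

```python
import torch
import torch.nn.functional as F


def sample_patches(images, num_patches, patch_size, generator=None):
    """Sample `num_patches` random crops from a (N, C, H, W) tensor of training images."""
    n, _, h, w = images.shape
    idx = torch.randint(0, n, (num_patches,), generator=generator)
    rows = torch.randint(0, h - patch_size + 1, (num_patches,), generator=generator)
    cols = torch.randint(0, w - patch_size + 1, (num_patches,), generator=generator)
    patches = [
        images[i, :, r : r + patch_size, c : c + patch_size]
        for i, r, c in zip(idx.tolist(), rows.tolist(), cols.tolist())
    ]
    return torch.stack(patches)  # (K, C, patch_size, patch_size)


def featurize(image, patches):
    """Convolve the K patches over one (C, H, W) image and average each map over pixels."""
    feature_maps = F.conv2d(image.unsqueeze(0), patches)  # (1, K, H', W')
    feature_maps = F.relu(feature_maps)  # assumption: nonlinearity before pooling
    return feature_maps.mean(dim=(2, 3)).squeeze(0)  # (K,)


g = torch.Generator().manual_seed(0)
train_images = torch.rand(8, 3, 32, 32, generator=g)  # stand-in for N training images
patches = sample_patches(train_images, num_patches=16, patch_size=3, generator=g)
x_i = featurize(train_images[0], patches)
print(x_i.shape)  # torch.Size([16])
```

The key point is that the filter bank is built from image content rather than drawn from a parametric distribution.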
[Figure from the paper illustrating this process]
Actual implementation
However, the MPC implementation uses a different approach:
```python
class RCF(nn.Module):
    """A model for extracting Random Convolution Features (RCF) from input imagery."""

    def __init__(self, num_features=16, kernel_size=3, num_input_channels=3):
        super(RCF, self).__init__()

        # We create `num_features / 2` filters so require `num_features` to be divisible by 2
        assert num_features % 2 == 0

        self.conv1 = nn.Conv2d(
            num_input_channels,
            num_features // 2,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            bias=True,
        )

        nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)
        nn.init.constant_(self.conv1.bias, -1.0)

    def forward(self, x):
        x1a = F.relu(self.conv1(x), inplace=True)
        x1b = F.relu(-self.conv1(x), inplace=True)

        x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
        x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()

        if len(x1a.shape) == 1:  # case where we passed a single input
            return torch.cat((x1a, x1b), dim=0)
        elif len(x1a.shape) == 2:  # case where we passed a batch of > 1 inputs
            return torch.cat((x1a, x1b), dim=1)
```

As I understand it, instead of using `K` kernels extracted from `N` training images (as described in the paper), this implementation uses randomly initialized kernels: `num_features // 2` filters drawn from a standard normal distribution, each contributing two features through the ReLU of its response and of its negated response. While randomly initialized kernels can be powerful feature extractors, this approach differs significantly from the paper's methodology.
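For comparison, one way the quoted layer could be seeded from training imagery is to overwrite the convolution weights with randomly cropped patches while keeping the same two-sided ReLU and pooling. This is only a sketch under assumptions: `init_conv_from_patches` is a hypothetical helper, not part of the MPC code, and the random tensors stand in for real training images.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_conv_from_patches(conv, images, generator=None):
    """Overwrite `conv.weight` with random crops from a (N, C, H, W) tensor (hypothetical helper)."""
    num_filters, channels, kh, kw = conv.weight.shape
    n, c, h, w = images.shape
    assert c == channels, "image channels must match the conv's input channels"
    with torch.no_grad():
        for k in range(num_filters):
            i = int(torch.randint(0, n, (1,), generator=generator))
            r = int(torch.randint(0, h - kh + 1, (1,), generator=generator))
            col = int(torch.randint(0, w - kw + 1, (1,), generator=generator))
            conv.weight[k].copy_(images[i, :, r : r + kh, col : col + kw])


num_features = 16
conv1 = nn.Conv2d(3, num_features // 2, kernel_size=3, bias=True)
nn.init.constant_(conv1.bias, -1.0)  # same bias value as the quoted code

g = torch.Generator().manual_seed(0)
train_images = torch.rand(8, 3, 32, 32, generator=g)  # stand-in for training images
init_conv_from_patches(conv1, train_images, generator=g)

# Same two-sided ReLU and average pooling as the quoted forward pass
with torch.no_grad():
    z = conv1(train_images)
    feats = torch.cat(
        (F.relu(z).mean(dim=(2, 3)), F.relu(-z).mean(dim=(2, 3))), dim=1
    )
print(feats.shape)  # torch.Size([8, 16])
```

Only the weight initialization changes here; the rest of the forward pass is left as in the quoted class.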
Am I missing something that explains why this implementation choice was made? Is there a specific reason for deviating from the paper's approach of sampling patches from training images?
Thank you!
References
[1] Rolf, Esther, et al. "A generalizable and accessible approach to machine learning with global satellite imagery." Nature Communications 12.1 (2021): 4392.