
Question regarding MOSAIKS Implementation #307

@alexlopezcifuentes

Description


Discrepancy between the MOSAIKS paper and the MPC RCF implementation

Hi!

Given the recent open-sourcing of AlphaEarth by DeepMind and its comparison with MOSAIKS [1], I was taking a closer look at the implementation from @calebrob6's PR (#70).

Expected behavior (from paper)

Based on the MOSAIKS paper, the Random Convolution Features (RCF) process works as follows:

  1. Given an input image I, K patches are randomly sampled from across N images in the training set
  2. These K patches are then convolved over I to obtain K feature maps
  3. The feature maps are averaged over pixels to produce a K-dimensional feature vector X_i

The following figure from the paper illustrates this process:

[Figure: the RCF process from the MOSAIKS paper]
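
In code, my reading of that three-step process is roughly the following (a minimal sketch; the function name, its arguments, and the omission of the paper's bias/nonlinearity details are my own simplifications):

import torch
import torch.nn.functional as F

def mosaiks_features(images, train_images, k=16, patch_size=3):
    """Sketch of the paper's RCF process: sample K patches from training
    imagery, convolve them over each image, and average over pixels."""
    n, c, h, w = train_images.shape
    # 1. Randomly sample K patches (random training image, random location)
    img_idx = torch.randint(0, n, (k,)).tolist()
    ys = torch.randint(0, h - patch_size + 1, (k,)).tolist()
    xs = torch.randint(0, w - patch_size + 1, (k,)).tolist()
    patches = torch.stack([
        train_images[i, :, y:y + patch_size, x:x + patch_size]
        for i, y, x in zip(img_idx, ys, xs)
    ])  # (K, C, patch_size, patch_size)
    # 2. Convolve the K patches over each image -> K feature maps per image
    fmaps = F.conv2d(images, patches)  # (B, K, H', W')
    # 3. Average each feature map over pixels -> K-dimensional vector X_i
    return fmaps.mean(dim=(2, 3))  # (B, K)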

Actual implementation

However, the MPC implementation uses a different approach:

import torch
import torch.nn as nn
import torch.nn.functional as F


class RCF(nn.Module):
    """A model for extracting Random Convolution Features (RCF) from input imagery."""

    def __init__(self, num_features=16, kernel_size=3, num_input_channels=3):
        super().__init__()
        # We create `num_features / 2` filters, so `num_features` must be divisible by 2
        assert num_features % 2 == 0
        self.conv1 = nn.Conv2d(
            num_input_channels,
            num_features // 2,
            kernel_size=kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            bias=True,
        )
        # Filter weights are drawn from N(0, 1); all biases are fixed at -1
        nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)
        nn.init.constant_(self.conv1.bias, -1.0)

    def forward(self, x):
        # Each filter contributes two features: the ReLU of its response
        # and the ReLU of its negated response
        x1a = F.relu(self.conv1(x), inplace=True)
        x1b = F.relu(-self.conv1(x), inplace=True)
        # Global average pooling collapses each feature map to a single value
        x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()
        x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()
        if len(x1a.shape) == 1:  # case where we passed a single input
            return torch.cat((x1a, x1b), dim=0)
        elif len(x1a.shape) == 2:  # case where we passed a batch of > 1 inputs
            return torch.cat((x1a, x1b), dim=1)
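
For reference, here is a minimal usage sketch of the class above (the input shapes are my own illustration):

model = RCF(num_features=16, kernel_size=3, num_input_channels=3)
x = torch.randn(8, 3, 64, 64)               # hypothetical batch of 8 RGB chips
features = model(x)                         # shape: (8, 16)
single = model(torch.randn(1, 3, 64, 64))   # shape: (16,) after the squeeze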

As I understand it, instead of using K kernels extracted from N training images (as described in the paper), this implementation creates num_features // 2 randomly initialized kernels (weights drawn from N(0, 1), bias fixed at -1) and obtains num_features outputs by taking the ReLU of both the positive and negated filter responses. While randomly initialized kernels can be powerful feature extractors, this approach differs significantly from the paper's methodology.
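
For illustration only: bridging the implementation back to the paper would amount to replacing the N(0, 1) weights with sampled patches. A hypothetical helper (not part of the MPC code; `patches` would come from a sampling step like the one sketched earlier):

import torch

@torch.no_grad()
def seed_with_empirical_patches(model, patches):
    # `patches`: (num_features // 2, num_input_channels, kernel_size, kernel_size),
    # sampled from training imagery; hypothetical helper, not part of the MPC code
    assert patches.shape == model.conv1.weight.shape
    model.conv1.weight.copy_(patches)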

Am I missing something that explains why this implementation choice was made? Is there a specific reason for deviating from the paper's approach of sampling patches from training images?

Thank you!

References
[1] Rolf, Esther, et al. "A generalizable and accessible approach to machine learning with global satellite imagery." Nature Communications 12.1 (2021): 4392.
