`torch.compile` doesn't play nicely with AMP autocasting, and there are occasionally issues when exporting to ONNX or other formats. Would explicitly casting to float and back be preferable? This appears to be the torchtune approach: https://pytorch.org/torchtune/stable/_modules/torchtune/modules/position_embeddings.html
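For concreteness, here is a rough sketch of what "cast to float and back" could look like for a rotary embedding, without depending on an autocast context. It upcasts the input with `x.float()`, does the rotation in float32, and restores the caller's dtype with `type_as(x)`, which is my reading of the pattern in the linked torchtune module. The names `apply_rope_fp32` and `rope_cache` are hypothetical, and I'm assuming the interleaved-pair layout (features `2i` and `2i+1` rotated together).

```python
import torch


def rope_cache(seq_len: int, dim: int, base: float = 10000.0):
    """Hypothetical helper: precompute cos/sin tables in float32, shape [seq_len, dim // 2]."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    pos = torch.arange(seq_len, dtype=torch.float32)
    angles = torch.outer(pos, inv_freq)  # angle for each (position, frequency) pair
    return angles.cos(), angles.sin()


def apply_rope_fp32(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate interleaved feature pairs of x (shape [..., seq_len, dim]).

    The upcast and downcast are ordinary tensor ops, so they do not
    depend on whether an autocast context is active.
    """
    x32 = x.float()  # explicit upcast (e.g. bf16 -> fp32)
    x_even, x_odd = x32[..., 0::2], x32[..., 1::2]
    out_even = x_even * cos - x_odd * sin  # cos/sin broadcast over leading dims
    out_odd = x_odd * cos + x_even * sin
    # Re-interleave the pairs back into [..., seq_len, dim].
    out = torch.stack((out_even, out_odd), dim=-1).flatten(-2)
    return out.type_as(x)  # explicit downcast to the caller's dtype


x = torch.randn(2, 8, 16, dtype=torch.bfloat16)
cos, sin = rope_cache(seq_len=8, dim=16)
y = apply_rope_fp32(x, cos, sin)
print(y.dtype, tuple(y.shape))
```

Because the dtype changes are explicit in the graph rather than implied by autocast state, I'd expect `torch.compile` tracing and ONNX export to see the same ops as eager mode, though I haven't verified this across exporters.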