
Conversation

@Xreki Xreki (Collaborator) commented Dec 24, 2025

PR Category

Feature Enhancement

Description

Initial implementation of the backward-graph extractor.
Using a simple example, the sample model.py is as follows:

import torch

class GraphModule(torch.nn.Module):

    def forward(self, w_0: torch.Tensor, w_1: torch.Tensor, w_2: torch.Tensor, w_3: torch.Tensor, w_4: torch.Tensor, in_0: torch.Tensor):
        tmp_0 = torch.conv2d(in_0, w_4, None, (1, 1), (1, 1), (1, 1), 1)
        in_0 = w_4 = None
        tmp_1 = torch.nn.functional.batch_norm(tmp_0, w_0, w_1, w_3, w_2, False, 0.1, 0.001)
        tmp_0 = w_0 = w_1 = w_3 = w_2 = None
        return (tmp_1,)
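For context, the backward graph below is the kind of graph AOTAutograd hands to a bw_compiler callback. A minimal sketch of how it can be intercepted for the model above (this illustrates the general mechanism only, not necessarily the extractor implemented in this PR; all tensor shapes are made-up assumptions):

import torch
from functorch.compile import aot_module, make_boxed_func
from model import GraphModule   # the sample model.py above

def fw_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Leave the forward graph untouched.
    return make_boxed_func(gm.forward)

def bw_compiler(gm: torch.fx.GraphModule, example_inputs):
    # gm.graph is the backward fx.Graph; this is where it can be dumped or saved.
    print(gm.graph)
    return make_boxed_func(gm.forward)

model = GraphModule()
w_0 = torch.zeros(8)                                # running_mean
w_1 = torch.ones(8)                                 # running_var
w_2 = torch.zeros(8, requires_grad=True)            # batch-norm bias
w_3 = torch.ones(8, requires_grad=True)             # batch-norm weight
w_4 = torch.randn(8, 3, 3, 3, requires_grad=True)   # conv weight
in_0 = torch.randn(1, 3, 8, 8, requires_grad=True)

compiled = aot_module(model, fw_compiler=fw_compiler, bw_compiler=bw_compiler)
out = compiled(w_0, w_1, w_2, w_3, w_4, in_0)
out[0].sum().backward()   # forward + backward make AOTAutograd trace and hand over the backward graph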

The fx.Graph of the backward graph:

graph():
    %primals_1 : [num_users=1] = placeholder[target=primals_1]
    %primals_2 : [num_users=1] = placeholder[target=primals_2]
    %primals_4 : [num_users=1] = placeholder[target=primals_4]
    %primals_5 : [num_users=1] = placeholder[target=primals_5]
    %primals_6 : [num_users=1] = placeholder[target=primals_6]
    %convolution : [num_users=1] = placeholder[target=convolution]
    %getitem_1 : [num_users=1] = placeholder[target=getitem_1]
    %getitem_2 : [num_users=1] = placeholder[target=getitem_2]
    %tangents_1 : [num_users=1] = placeholder[target=tangents_1]
    %native_batch_norm_backward : [num_users=3] = call_function[target=torch.ops.aten.native_batch_norm_backward.default](args = (%tangents_1, %convolution, %primals_4, %primals_1, %primals_2, %getitem_1, %getitem_2, False, 0.001, [True, True, True]), kwargs = {})
    %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%native_batch_norm_backward, 0), kwargs = {})
    %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%native_batch_norm_backward, 1), kwargs = {})
    %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%native_batch_norm_backward, 2), kwargs = {})
    %convolution_backward : [num_users=2] = call_function[target=torch.ops.aten.convolution_backward.default](args = (%getitem_3, %primals_6, %primals_5, [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]), kwargs = {})
    %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%convolution_backward, 0), kwargs = {})
    %getitem_7 : [num_users=1] = call_function[target=operator.getitem](args = (%convolution_backward, 1), kwargs = {})
    return (getitem_5, getitem_4, getitem_7, getitem_6)
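The file that gets saved is essentially the Python source that torch.fx code generation produces for this graph. A minimal sketch of how the backward GraphModule could be dumped to disk (an assumption about the mechanism, the PR may serialize it differently; print_readable(print_output=False) needs a reasonably recent PyTorch):

import torch

def save_backward_graph(gm: torch.fx.GraphModule, path: str = "model.py") -> None:
    # print_readable(print_output=False) returns the generated
    # "class GraphModule(torch.nn.Module): ..." source without printing it;
    # prepend the import so the saved file is standalone.
    src = "import torch\n\n" + gm.print_readable(print_output=False)
    with open(path, "w") as f:
        f.write(src)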

The saved backward-graph sample model.py:

import torch

class GraphModule(torch.nn.Module):

    def forward(self, primals_1, primals_2, primals_4, primals_5, primals_6, convolution, getitem_1, getitem_2, tangents_1):
        native_batch_norm_backward = torch.ops.aten.native_batch_norm_backward.default(tangents_1, convolution, primals_4, primals_1, primals_2, getitem_1, getitem_2, False, 0.001, [True, True, True]);  tangents_1 = convolution = primals_4 = primals_1 = primals_2 = getitem_1 = getitem_2 = None
        getitem_3 = native_batch_norm_backward[0]
        getitem_4 = native_batch_norm_backward[1]
        getitem_5 = native_batch_norm_backward[2];  native_batch_norm_backward = None
        convolution_backward = torch.ops.aten.convolution_backward.default(getitem_3, primals_6, primals_5, [0], [1, 1], [1, 1], [1, 1], False, [0, 0], 1, [True, True, False]);  getitem_3 = primals_6 = primals_5 = None
        getitem_6 = convolution_backward[0]
        getitem_7 = convolution_backward[1];  convolution_backward = None
        return (getitem_5, getitem_4, getitem_7, getitem_6)
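The saved module can then be run on its own. A usage sketch with random tensors (all shapes are assumptions matching the example above: 3 input channels, 8 conv output channels, 8x8 spatial size; the file is assumed to be importable as model.py):

import torch
from model import GraphModule   # the saved backward-graph model.py

C = 8
bw = GraphModule()
grads = bw(
    primals_1=torch.zeros(C),              # running_mean (w_0)
    primals_2=torch.ones(C),               # running_var (w_1)
    primals_4=torch.ones(C),               # batch-norm weight (w_3)
    primals_5=torch.randn(C, 3, 3, 3),     # conv weight (w_4)
    primals_6=torch.randn(1, 3, 8, 8),     # conv input (in_0)
    convolution=torch.randn(1, C, 8, 8),   # saved conv output from the forward
    getitem_1=torch.zeros(C),              # saved mean (eval-mode backward uses the running stats)
    getitem_2=torch.ones(C),               # saved invstd
    tangents_1=torch.randn(1, C, 8, 8),    # gradient w.r.t. the forward output
)
# grads = (grad of bn bias, grad of bn weight, grad of conv weight, grad of input),
# following the output masks of the two backward ops above.
print([g.shape for g in grads])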

@paddle-bot bot commented Dec 24, 2025

Thanks for your contribution!

@Xreki Xreki force-pushed the torch_backward_extractor branch from 3106ee1 to de2a081 on December 24, 2025 05:28