import torch.nn.functional as F

from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
from deepmd.pt.utils import (
    env,
)

from .env import (
    DEVICE,
)
from .env import PRECISION_DICT as PT_PRECISION_DICT


def silut_forward(
    x: torch.Tensor, threshold: float, slope: float, const_val: float
) -> torch.Tensor:
    """SiLU below `threshold`; a shifted, saturating tanh tail at and above it."""
    sig = torch.sigmoid(x)
    silu = x * sig
    tanh_part = torch.tanh(slope * (x - threshold)) + const_val
    return torch.where(x >= threshold, tanh_part, silu)


def silut_backward(
    x: torch.Tensor, grad_output: torch.Tensor, threshold: float, slope: float
):
    """Backward kernel of silut_forward; returns (grad_input, d(silut)/dx)."""
    sig = torch.sigmoid(x)
    grad_silu = sig * (1 + x * (1 - sig))

    tanh_term = torch.tanh(slope * (x - threshold))
    grad_tanh = slope * (1 - tanh_term.pow(2))

    grad = torch.where(x >= threshold, grad_tanh, grad_silu)
    return grad * grad_output, grad


def silut_double_backward(
    x: torch.Tensor,
    grad_grad_output: torch.Tensor,
    grad_output: torch.Tensor,
    threshold: float,
    slope: float,
) -> torch.Tensor:
    """Double-backward kernel built on the second derivative of silut."""
    # Tanh branch: second derivative of tanh(slope * (x - threshold))
    tanh_term = torch.tanh(slope * (x - threshold))
    grad_grad = -2 * slope * slope * tanh_term * (1 - tanh_term * tanh_term)

    # SiLU branch: second derivative of x * sigmoid(x)
    sig = 1.0 / (1.0 + torch.exp(-x))
    sig_prime = sig * (1 - sig)
    silu_term = sig_prime * (2 + x * (1 - 2 * sig))

    grad_grad = torch.where(x >= threshold, grad_grad, silu_term)

    return grad_output * grad_grad * grad_grad_output


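# Sketch of why the precomputed constants work (derived from the definitions
# above; the numbers are just the default threshold of 3.0 worked through):
# `const_val` is SiLU(threshold) and `slope` is SiLU'(threshold), so the tail
# tanh(slope * (x - threshold)) + const_val matches SiLU in both value and
# first derivative at the switch point, keeping the activation C1-continuous.
#
#   sig       = 1 / (1 + exp(-3.0))          # ~0.9526
#   const_val = 3.0 * sig                    # ~2.8577 = SiLU(3.0)
#   slope     = sig + 3.0 * sig * (1 - sig)  # ~1.0881 = SiLU'(3.0)

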
class SiLUTScript(torch.nn.Module):
    """SiLU-with-tanh-tail activation backed by a custom autograd Function.

    The forward, backward, and double-backward kernels are TorchScript-compiled
    for efficient training.
    """

    def __init__(self, threshold: float = 3.0):
        super().__init__()
        self.threshold = threshold

        # Precompute parameters for the tanh replacement
        sigmoid_threshold = 1 / (1 + np.exp(-threshold))
        self.slope = float(
            sigmoid_threshold + threshold * sigmoid_threshold * (1 - sigmoid_threshold)
        )
        self.const_val = float(threshold * sigmoid_threshold)
        self.get_script_code()

    def get_script_code(self):
        silut_forward_script = torch.jit.script(silut_forward)
        silut_backward_script = torch.jit.script(silut_backward)
        silut_double_backward_script = torch.jit.script(silut_double_backward)

        class SiLUTFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, threshold, slope, const_val):
                ctx.save_for_backward(x)
                ctx.threshold = threshold
                ctx.slope = slope
                ctx.const_val = const_val
                return silut_forward_script(x, threshold, slope, const_val)

            @staticmethod
            def backward(ctx, grad_output):
                (x,) = ctx.saved_tensors
                threshold = ctx.threshold
                slope = ctx.slope

                grad_input = SiLUTGradFunction.apply(x, grad_output, threshold, slope)
                return grad_input, None, None, None

        class SiLUTGradFunction(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, grad_output, threshold, slope):
                ctx.threshold = threshold
                ctx.slope = slope
                grad_input, grad = silut_backward_script(
                    x, grad_output, threshold, slope
                )
                ctx.save_for_backward(x, grad_output, grad)
                return grad_input

            @staticmethod
            def backward(ctx, grad_grad_output):
                (x, grad_output, grad) = ctx.saved_tensors
                threshold = ctx.threshold
                slope = ctx.slope

                grad_input = silut_double_backward_script(
                    x, grad_grad_output, grad_output, threshold, slope
                )
                return grad_input, grad * grad_grad_output, None, None

        self.SiLUTFunction = SiLUTFunction

    def forward(self, x):
        return self.SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val)


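# Consistency sketch (illustrative, not executed at import time): because
# SiLUTGradFunction above defines its own backward, second-order gradients are
# available and can be verified with PyTorch's gradient-checking utilities:
#
#   x = torch.randn(8, dtype=torch.float64, requires_grad=True)
#   fn = SiLUTScript(threshold=3.0)
#   torch.autograd.gradcheck(fn, (x,))
#   torch.autograd.gradgradcheck(fn, (x,))

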
class SiLUT(torch.nn.Module):
    """Plain-PyTorch SiLU-with-tanh-tail activation, usable for JIT freezing."""

    def __init__(self, threshold=3.0):
        super().__init__()

        def sigmoid(x):
            return 1 / (1 + np.exp(-x))

        def silu(x):
            return x * sigmoid(x)

        def silu_grad(x):
            sig = sigmoid(x)
            return sig + x * sig * (1 - sig)

        self.threshold = threshold
        self.slope = float(silu_grad(threshold))
        self.const = float(silu(threshold))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        silu_part = F.silu(x)
        mask = x >= self.threshold
        if torch.any(mask):
            tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const
            return torch.where(x < self.threshold, silu_part, tanh_part)
        else:
            return silu_part


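# Usage sketch (illustrative): both variants share the same constructor and are
# applied like any other activation module.
#
#   act = SiLUT(threshold=3.0)            # JIT-freezable variant
#   # act = SiLUTScript(threshold=3.0)    # custom-autograd variant for training
#   y = act(torch.linspace(-5.0, 5.0, 11))
#   # SiLU below the threshold, bounded tanh tail at and above it

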
class ActivationFn(torch.nn.Module):
    def __init__(self, activation: Optional[str]) -> None:
        super().__init__()
        self.activation: str = activation if activation is not None else "linear"
        if self.activation.lower().startswith(
            "silut"
        ) or self.activation.lower().startswith("custom_silu"):
            threshold = (
                float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
            )
            if env.CUSTOM_OP_USE_JIT:
                # efficient for training, but cannot be JIT compiled
                self.silut = SiLUTScript(threshold=threshold)
            else:
                # for JIT freezing
                self.silut = SiLUT(threshold=threshold)
        else:
            self.silut = None
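
    # Example activation strings handled above (the numeric value is illustrative):
    #   "silut" or "custom_silu"        -> threshold 3.0 (default)
    #   "silut:2.0" / "custom_silu:2.0" -> threshold 2.0, parsed after ":"
    #   anything else                   -> self.silut is None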

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the tensor after applying the activation function corresponding to `activation`."""
@@ -41,6 +193,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.sigmoid(x)
        elif self.activation.lower() == "silu":
            return F.silu(x)
        elif self.activation.lower().startswith(
            "silut"
        ) or self.activation.lower().startswith("custom_silu"):
            assert self.silut is not None
            return self.silut(x)
        elif self.activation.lower() == "linear" or self.activation.lower() == "none":
            return x
        else: