Commit 4b78a1c

feat: add SiLUT activation (#4647)
## Summary by CodeRabbit

- **New Features**
  - Introduced a new activation function, "SiLUT", which applies SiLU (the sigmoid-weighted linear unit) below a configurable threshold and a smoothly matched tanh tail above it.
  - Expanded the set of valid activation functions recognized by the module.
  - Added a configuration option for just-in-time compilation of the custom activation operations.
- **Refactor**
  - Standardized the naming of activation functions across all supported frameworks for improved consistency.

These improvements offer users a streamlined experience with enhanced activation handling and performance optimizations.
1 parent d040d08 commit 4b78a1c
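For reference, the SiLUT introduced here is SiLU below a threshold and a value- and slope-matched tanh tail above it. The following minimal NumPy sketch (not part of the commit) restates the piecewise form that each backend implements:

    import numpy as np

    def silut_reference(x, threshold=3.0):
        # SiLU below the threshold; a tanh tail above it, with slope and offset
        # chosen so value and first derivative are continuous at x == threshold.
        sig = 1.0 / (1.0 + np.exp(-threshold))
        slope = sig + threshold * sig * (1.0 - sig)   # SiLU'(threshold)
        const = threshold * sig                       # SiLU(threshold)
        silu = x / (1.0 + np.exp(-x))
        return np.where(x < threshold, silu, np.tanh(slope * (x - threshold)) + const)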

File tree

8 files changed: +335 additions, 0 deletions

deepmd/common.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@
     "gelu",
     "gelu_tf",
     "silu",
+    "silut",
     "none",
     "linear",
 ]

deepmd/dpmodel/utils/network.py

Lines changed: 31 additions & 0 deletions
@@ -325,6 +325,37 @@ def fn(x):
             # generated by GitHub Copilot
             return x / (1 + xp.exp(-x))
 
+        return fn
+    elif activation_function.startswith("silut") or activation_function.startswith(
+        "custom_silu"
+    ):
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        def silu(x):
+            return x * sigmoid(x)
+
+        def silu_grad(x):
+            sig = sigmoid(x)
+            return sig + x * sig * (1 - sig)
+
+        threshold = (
+            float(activation_function.split(":")[-1])
+            if ":" in activation_function
+            else 3.0
+        )
+        slope = float(silu_grad(threshold))
+        const = float(silu(threshold))
+
+        def fn(x):
+            xp = array_api_compat.array_namespace(x)
+            return xp.where(
+                x < threshold,
+                x * (1 / (1 + xp.exp(-x))),
+                xp.tanh(slope * (x - threshold)) + const,
+            )
+
         return fn
     elif activation_function.lower() in ("none", "linear"):
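Both the "silut" and "custom_silu" spellings select this branch, and an optional ":<threshold>" suffix overrides the default threshold of 3.0. A hedged usage sketch follows; the name of the enclosing factory is assumed to be get_activation_fn, since only its body appears in the hunk above:

    # Assumption: the factory modified above is deepmd.dpmodel.utils.network.get_activation_fn.
    import numpy as np
    from deepmd.dpmodel.utils.network import get_activation_fn

    fn = get_activation_fn("silut:2.0")   # "custom_silu:2.0" would select the same branch
    x = np.array([1.9, 2.0, 2.1])
    print(fn(x))                          # the SiLU and tanh pieces join continuously at x = 2.0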

deepmd/pd/utils/utils.py

Lines changed: 42 additions & 0 deletions
@@ -32,10 +32,47 @@
 )
 
 
+class SiLUT(paddle.nn.Layer):
+    def __init__(self, threshold=3.0):
+        super().__init__()
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        def silu(x):
+            return x * sigmoid(x)
+
+        def silu_grad(x):
+            sig = sigmoid(x)
+            return sig + x * sig * (1 - sig)
+
+        self.threshold = threshold
+        self.slope = float(silu_grad(threshold))
+        self.const = float(silu(threshold))
+
+    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
+        silu_part = F.silu(x)
+        mask = x >= self.threshold
+        if paddle.any(mask):
+            tanh_part = paddle.tanh(self.slope * (x - self.threshold)) + self.const
+            return paddle.where(x < self.threshold, silu_part, tanh_part)
+        else:
+            return silu_part
+
+
 class ActivationFn(paddle.nn.Layer):
     def __init__(self, activation: str | None):
         super().__init__()
         self.activation: str = activation if activation is not None else "linear"
+        if self.activation.lower().startswith(
+            "silut"
+        ) or self.activation.lower().startswith("custom_silu"):
+            threshold = (
+                float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
+            )
+            self.silut = SiLUT(threshold=threshold)
+        else:
+            self.silut = None
 
     def forward(self, x: paddle.Tensor) -> paddle.Tensor:
         """Returns the tensor after applying activation function corresponding to `activation`."""
@@ -53,6 +90,11 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
             return F.sigmoid(x)
         elif self.activation.lower() == "silu":
             return F.silu(x)
+        elif self.activation.lower().startswith(
+            "silut"
+        ) or self.activation.lower().startswith("custom_silu"):
+            assert self.silut is not None
+            return self.silut(x)
         elif self.activation.lower() == "linear" or self.activation.lower() == "none":
             return x
         else:
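A minimal usage sketch for the Paddle backend (not part of the commit), using the ActivationFn wrapper extended above:

    import paddle
    from deepmd.pd.utils.utils import ActivationFn

    act = ActivationFn("silut")         # default threshold 3.0; "silut:2.0" would override it
    x = paddle.linspace(-6.0, 6.0, 13)
    y = act(x)                          # SiLU below the threshold, tanh tail above it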

deepmd/pt/entrypoints/main.py

Lines changed: 1 addition & 0 deletions
@@ -249,6 +249,7 @@ def train(
     output: str = "out.json",
 ) -> None:
     log.info("Configuration path: %s", input_file)
+    env.CUSTOM_OP_USE_JIT = True
     if LOCAL_RANK == 0:
         SummaryPrinter()()
     with open(input_file) as fin:

deepmd/pt/utils/env.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,7 @@
 JIT = False
 CACHE_PER_SYS = 5  # keep at most so many sets per sys in memory
 ENERGY_BIAS_TRAINABLE = True
+CUSTOM_OP_USE_JIT = False
 
 PRECISION_DICT = {
     "float16": torch.float16,
@@ -76,6 +77,7 @@
 
 __all__ = [
     "CACHE_PER_SYS",
+    "CUSTOM_OP_USE_JIT",
     "DEFAULT_PRECISION",
     "DEVICE",
     "ENERGY_BIAS_TRAINABLE",

deepmd/pt/utils/utils.py

Lines changed: 157 additions & 0 deletions
@@ -11,17 +11,169 @@
 import torch.nn.functional as F
 
 from deepmd.dpmodel.common import PRECISION_DICT as NP_PRECISION_DICT
+from deepmd.pt.utils import (
+    env,
+)
 
 from .env import (
     DEVICE,
 )
 from .env import PRECISION_DICT as PT_PRECISION_DICT
 
 
+def silut_forward(
+    x: torch.Tensor, threshold: float, slope: float, const_val: float
+) -> torch.Tensor:
+    sig = torch.sigmoid(x)
+    silu = x * sig
+    tanh_part = torch.tanh(slope * (x - threshold)) + const_val
+    return torch.where(x >= threshold, tanh_part, silu)
+
+
+def silut_backward(
+    x: torch.Tensor, grad_output: torch.Tensor, threshold: float, slope: float
+):
+    sig = torch.sigmoid(x)
+    grad_silu = sig * (1 + x * (1 - sig))
+
+    tanh_term = torch.tanh(slope * (x - threshold))
+    grad_tanh = slope * (1 - tanh_term.pow(2))
+
+    grad = torch.where(x >= threshold, grad_tanh, grad_silu)
+    return grad * grad_output, grad
+
+
+def silut_double_backward(
+    x: torch.Tensor,
+    grad_grad_output: torch.Tensor,
+    grad_output: torch.Tensor,
+    threshold: float,
+    slope: float,
+) -> torch.Tensor:
+    # Tanh branch
+    tanh_term = torch.tanh(slope * (x - threshold))
+    grad_grad = -2 * slope * slope * tanh_term * (1 - tanh_term * tanh_term)
+
+    # SiLU branch
+    sig = 1.0 / (1.0 + torch.exp(-x))
+    sig_prime = sig * (1 - sig)
+    silu_term = sig_prime * (2 + x * (1 - 2 * sig))
+
+    grad_grad = torch.where(x >= threshold, grad_grad, silu_term)
+
+    return grad_output * grad_grad * grad_grad_output
+
+
+class SiLUTScript(torch.nn.Module):
+    def __init__(self, threshold: float = 3.0):
+        super().__init__()
+        self.threshold = threshold
+
+        # Precompute parameters for the tanh replacement
+        sigmoid_threshold = 1 / (1 + np.exp(-threshold))
+        self.slope = float(
+            sigmoid_threshold + threshold * sigmoid_threshold * (1 - sigmoid_threshold)
+        )
+        self.const_val = float(threshold * sigmoid_threshold)
+        self.get_script_code()
+
+    def get_script_code(self):
+        silut_forward_script = torch.jit.script(silut_forward)
+        silut_backward_script = torch.jit.script(silut_backward)
+        silut_double_backward_script = torch.jit.script(silut_double_backward)
+
+        class SiLUTFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, threshold, slope, const_val):
+                ctx.save_for_backward(x)
+                ctx.threshold = threshold
+                ctx.slope = slope
+                ctx.const_val = const_val
+                return silut_forward_script(x, threshold, slope, const_val)
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                (x,) = ctx.saved_tensors
+                threshold = ctx.threshold
+                slope = ctx.slope
+
+                grad_input = SiLUTGradFunction.apply(x, grad_output, threshold, slope)
+                return grad_input, None, None, None
+
+        class SiLUTGradFunction(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x, grad_output, threshold, slope):
+                ctx.threshold = threshold
+                ctx.slope = slope
+                grad_input, grad = silut_backward_script(
+                    x, grad_output, threshold, slope
+                )
+                ctx.save_for_backward(x, grad_output, grad)
+                return grad_input
+
+            @staticmethod
+            def backward(ctx, grad_grad_output):
+                (x, grad_output, grad) = ctx.saved_tensors
+                threshold = ctx.threshold
+                slope = ctx.slope
+
+                grad_input = silut_double_backward_script(
+                    x, grad_grad_output, grad_output, threshold, slope
+                )
+                return grad_input, grad * grad_grad_output, None, None
+
+        self.SiLUTFunction = SiLUTFunction
+
+    def forward(self, x):
+        return self.SiLUTFunction.apply(x, self.threshold, self.slope, self.const_val)
+
+
+class SiLUT(torch.nn.Module):
+    def __init__(self, threshold=3.0):
+        super().__init__()
+
+        def sigmoid(x):
+            return 1 / (1 + np.exp(-x))
+
+        def silu(x):
+            return x * sigmoid(x)
+
+        def silu_grad(x):
+            sig = sigmoid(x)
+            return sig + x * sig * (1 - sig)
+
+        self.threshold = threshold
+        self.slope = float(silu_grad(threshold))
+        self.const = float(silu(threshold))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        silu_part = F.silu(x)
+        mask = x >= self.threshold
+        if torch.any(mask):
+            tanh_part = torch.tanh(self.slope * (x - self.threshold)) + self.const
+            return torch.where(x < self.threshold, silu_part, tanh_part)
+        else:
+            return silu_part
+
+
 class ActivationFn(torch.nn.Module):
     def __init__(self, activation: Optional[str]) -> None:
         super().__init__()
         self.activation: str = activation if activation is not None else "linear"
+        if self.activation.lower().startswith(
+            "silut"
+        ) or self.activation.lower().startswith("custom_silu"):
+            threshold = (
+                float(self.activation.split(":")[-1]) if ":" in self.activation else 3.0
+            )
+            if env.CUSTOM_OP_USE_JIT:
+                # for efficient training but can not be jit
+                self.silut = SiLUTScript(threshold=threshold)
+            else:
+                # for jit freeze
+                self.silut = SiLUT(threshold=threshold)
+        else:
+            self.silut = None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Returns the tensor after applying activation function corresponding to `activation`."""
@@ -41,6 +193,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             return torch.sigmoid(x)
         elif self.activation.lower() == "silu":
            return F.silu(x)
+        elif self.activation.lower().startswith(
+            "silut"
+        ) or self.activation.lower().startswith("custom_silu"):
+            assert self.silut is not None
+            return self.silut(x)
         elif self.activation.lower() == "linear" or self.activation.lower() == "none":
             return x
         else:
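The scripted SiLUTScript path and the plain SiLUT module are intended to be numerically interchangeable. A consistency sketch (not part of the commit) comparing values and first derivatives of the two paths:

    import torch
    from deepmd.pt.utils.utils import SiLUT, SiLUTScript

    x = torch.linspace(-6.0, 6.0, steps=101, dtype=torch.float64, requires_grad=True)

    y_ref = SiLUT(threshold=3.0)(x)
    y_jit = SiLUTScript(threshold=3.0)(x)
    print(torch.allclose(y_ref, y_jit))   # forward values should match

    g_ref = torch.autograd.grad(y_ref.sum(), x)[0]
    g_jit = torch.autograd.grad(y_jit.sum(), x)[0]
    print(torch.allclose(g_ref, g_jit))   # first derivatives should match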

deepmd/tf/common.py

Lines changed: 44 additions & 0 deletions
@@ -144,6 +144,47 @@ def silu(x: tf.Tensor) -> tf.Tensor:
     return x * tf.sigmoid(x)
 
 
+def get_silut(activation_function: str = "silut"):
+    import numpy as np
+
+    def sigmoid(x):
+        return 1 / (1 + np.exp(-x))
+
+    def silu(x):
+        return x * sigmoid(x)
+
+    def silu_grad(x):
+        sig = sigmoid(x)
+        return sig + x * sig * (1 - sig)
+
+    threshold = (
+        float(activation_function.split(":")[-1]) if ":" in activation_function else 3.0
+    )
+    slope = float(silu_grad(threshold))
+    const = float(silu(threshold))
+
+    def silut(x: tf.Tensor) -> tf.Tensor:
+        """The customized sigmoid-weighted linear unit with tanh.
+
+        Parameters
+        ----------
+        x : tf.Tensor
+            float Tensor to perform activation
+
+        Returns
+        -------
+        tf.Tensor
+            `x` with the SiLUT activation applied
+        """
+        return tf.where(
+            x < threshold,
+            x * tf.sigmoid(x),
+            tf.nn.tanh(slope * (x - threshold)) + const,
+        )
+
+    return silut
+
+
 ACTIVATION_FN_DICT = {
     "relu": tf.nn.relu,
     "relu6": tf.nn.relu6,
@@ -153,6 +194,7 @@ def silu(x: tf.Tensor) -> tf.Tensor:
     "gelu": gelu,
     "gelu_tf": gelu_tf,
     "silu": silu,
+    "silut": get_silut("silut"),
     "linear": lambda x: x,
     "none": lambda x: x,
 }
@@ -182,6 +224,8 @@ def get_activation_func(
     if activation_fn is None:
         activation_fn = "none"
     assert activation_fn is not None
+    if activation_fn.lower().startswith("silut"):
+        ACTIVATION_FN_DICT[activation_fn.lower()] = get_silut(activation_fn.lower())
     if activation_fn.lower() not in ACTIVATION_FN_DICT:
         raise RuntimeError(f"{activation_fn} is not a valid activation function")
     return ACTIVATION_FN_DICT[activation_fn.lower()]
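A usage sketch for the TensorFlow backend (not part of the commit). get_activation_func registers thresholded variants on demand, so both the plain name and the suffixed form resolve:

    from deepmd.tf.common import get_activation_func

    silut_default = get_activation_func("silut")      # threshold 3.0, pre-registered above
    silut_sharp = get_activation_func("silut:2.5")     # registered on first use with threshold 2.5
    # both are callables mapping a tf.Tensor to a tf.Tensor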
