@@ -92,7 +92,7 @@ def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
9292
9393"""
9494
95- from typing import cast
95+ from typing import Literal , cast
9696
9797import numpy as np
9898import numpy .typing as npt
@@ -270,6 +270,183 @@ def function(self, x: pt.TensorLike, sigma: pt.TensorLike) -> TensorVariable:
270270 }
271271
272272
273+ class HalfGaussianBasis (Basis ):
274+ R"""One-sided Gaussian basis transformation.
275+
276+ .. plot::
277+ :context: close-figs
278+
279+ import matplotlib.pyplot as plt
280+ from pymc_marketing.mmm.events import HalfGaussianBasis
281+ from pymc_extras.prior import Prior
282+ half_gaussian = HalfGaussianBasis(
283+ priors={
284+ "sigma": Prior("Gamma", mu=[3, 4], sigma=1, dims="event"),
285+ }
286+ )
287+ coords = {"event": ["PyData-Berlin", "PyCon-Finland"]}
288+ prior = half_gaussian.sample_prior(coords=coords)
289+ curve = half_gaussian.sample_curve(prior)
290+ fig, axes = half_gaussian.plot_curve(
291+ curve, subplot_kwargs={"figsize": (6, 3), "sharey": True}
292+ )
293+ for ax in axes:
294+ ax.set_xlabel("")
295+ plt.show()
296+
297+ Parameters
298+ ----------
299+ mode : Literal["after", "before"]
300+ Whether the basis is located before or after the event.
301+ include_event : bool
302+ Whether to include the event days in the basis.
303+ priors : dict[str, Prior]
304+ Prior for the sigma parameter.
305+ prefix : str
306+ Prefix for the parameter names.
307+ """
308+
309+ lookup_name = "half_gaussian"
310+
311+ def __init__ (
312+ self ,
313+ mode : Literal ["after" , "before" ] = "after" ,
314+ include_event : bool = True ,
315+ ** kwargs ,
316+ ):
317+ super ().__init__ (** kwargs )
318+ self .mode = mode
319+ self .include_event = include_event
320+
321+ def function (self , x : pt .TensorLike , sigma : pt .TensorLike ) -> TensorVariable :
322+ """One-sided Gaussian bump function."""
323+ rv = pm .Normal .dist (mu = 0.0 , sigma = sigma )
324+ out = pm .math .exp (pm .logp (rv , x ))
325+ # Sign determines if the zeroing happens after or before the event.
326+ sign = 1 if self .mode == "after" else - 1
327+ # Build boolean mask(s) in x's shape and broadcast to out's shape.
328+ pre_mask = sign * x < 0
329+ if not self .include_event :
330+ pre_mask = pm .math .or_ (pre_mask , sign * x == 0 )
331+
332+ # Ensure mask matches output shape for elementwise switch
333+ pre_mask = pt .broadcast_to (pre_mask , out .shape )
334+
335+ return pt .switch (pre_mask , 0 , out )
336+
337+ def to_dict (self ) -> dict :
338+ """Convert the half Gaussian basis to a dictionary."""
339+ return {
340+ ** super ().to_dict (),
341+ "mode" : self .mode ,
342+ "include_event" : self .include_event ,
343+ }
344+
345+ default_priors = {
346+ "sigma" : Prior ("Gamma" , mu = 7 , sigma = 1 ),
347+ }
348+
349+
350+ class AsymmetricGaussianBasis (Basis ):
351+ R"""Asymmetric Gaussian bump basis transformation.
352+
353+ Allows different widths (sigma_before, sigma_after) and amplitudes (a_after)
354+ after the event.
355+
356+ .. plot::
357+ :context: close-figs
358+
359+ import matplotlib.pyplot as plt
360+ from pymc_marketing.mmm.events import AsymmetricGaussianBasis
361+ from pymc_extras.prior import Prior
362+ asy_gaussian = AsymmetricGaussianBasis(
363+ priors={
364+ "sigma_before": Prior("Gamma", mu=[3, 4], sigma=1, dims="event"),
365+ "a_after": Prior("Normal", mu=[-.75, .5], sigma=.2, dims="event"),
366+ }
367+ )
368+ coords = {"event": ["PyData-Berlin", "PyCon-Finland"]}
369+ prior = asy_gaussian.sample_prior(coords=coords)
370+ curve = asy_gaussian.sample_curve(prior)
371+ fig, axes = asy_gaussian.plot_curve(
372+ curve, subplot_kwargs={"figsize": (6, 3), "sharey": True}
373+ )
374+ for ax in axes:
375+ ax.set_xlabel("")
376+ plt.show()
377+
378+ Parameters
379+ ----------
380+ event_in : Literal["before", "after", "exclude"]
381+ Whether to include the event in the before or after part of the basis,
382+ or leave it out entirely. Default is "after".
383+ priors : dict[str, Prior]
384+ Prior for the sigma_before, sigma_after, a_before, and a_after parameters.
385+ prefix : str
386+ Prefix for the parameters.
387+ """
388+
389+ lookup_name = "asymmetric_gaussian"
390+
391+ def __init__ (
392+ self ,
393+ event_in : Literal ["before" , "after" , "exclude" ] = "after" ,
394+ ** kwargs ,
395+ ):
396+ super ().__init__ (** kwargs )
397+ self .event_in = event_in
398+
399+ def function (
400+ self ,
401+ x : pt .TensorLike ,
402+ sigma_before : pt .TensorLike ,
403+ sigma_after : pt .TensorLike ,
404+ a_after : pt .TensorLike ,
405+ ) -> pt .TensorVariable :
406+ """Asymmetric Gaussian bump function."""
407+ match self .event_in :
408+ case "before" :
409+ indicator_before = pt .cast (x <= 0 , "float32" )
410+ indicator_after = pt .cast (x > 0 , "float32" )
411+ case "after" :
412+ indicator_before = pt .cast (x < 0 , "float32" )
413+ indicator_after = pt .cast (x >= 0 , "float32" )
414+ case "exclude" :
415+ indicator_before = pt .cast (x < 0 , "float32" )
416+ indicator_after = pt .cast (x > 0 , "float32" )
417+ case _:
418+ raise ValueError (f"Invalid event_in: { self .event_in } " )
419+
420+ rv_before = pm .Normal .dist (mu = 0.0 , sigma = sigma_before )
421+ rv_after = pm .Normal .dist (mu = 0.0 , sigma = sigma_after )
422+
423+ y_before = pt .switch (
424+ indicator_before ,
425+ pm .math .exp (pm .logp (rv_before , x )),
426+ 0 ,
427+ )
428+ y_after = pt .switch (
429+ indicator_after ,
430+ pm .math .exp (pm .logp (rv_after , x )) * a_after ,
431+ 0 ,
432+ )
433+
434+ return y_before + y_after
435+
436+ def to_dict (self ) -> dict :
437+ """Convert the asymmetric Gaussian basis to a dictionary."""
438+ return {
439+ ** super ().to_dict (),
440+ "event_in" : self .event_in ,
441+ }
442+
443+ default_priors = {
444+ "sigma_before" : Prior ("Gamma" , mu = 3 , sigma = 1 ),
445+ "sigma_after" : Prior ("Gamma" , mu = 7 , sigma = 2 ),
446+ "a_after" : Prior ("Normal" , mu = 1 , sigma = 0.5 ),
447+ }
448+
449+
273450def days_from_reference (
274451 dates : pd .Series | pd .DatetimeIndex ,
275452 reference_date : str | pd .Timestamp ,
0 commit comments