@@ -455,18 +455,42 @@ def _repr_html_(self):
455455def sample (
456456 compiled_model : CompiledModel ,
457457 * ,
458- draws : int | None ,
459- tune : int | None ,
460- chains : int ,
461- cores : Optional [int ],
462- seed : Optional [int ],
463- save_warmup : bool ,
464- progress_bar : bool ,
458+ draws : int | None = None ,
459+ tune : int | None = None ,
460+ chains : int | None = None ,
461+ cores : int | None = None ,
462+ seed : int | None = None ,
463+ save_warmup : bool = True ,
464+ progress_bar : bool = True ,
465+ low_rank_modified_mass_matrix : bool = False ,
466+ transform_adapt : bool = False ,
467+ init_mean : np .ndarray | None = None ,
468+ return_raw_trace : bool = False ,
469+ progress_template : str | None = None ,
470+ progress_style : str | None = None ,
471+ progress_rate : int = 100 ,
472+ ) -> arviz .InferenceData : ...
473+
474+
475+ @overload
476+ def sample (
477+ compiled_model : CompiledModel ,
478+ * ,
479+ draws : int | None = None ,
480+ tune : int | None = None ,
481+ chains : int | None = None ,
482+ cores : int | None = None ,
483+ seed : int | None = None ,
484+ save_warmup : bool = True ,
485+ progress_bar : bool = True ,
465486 low_rank_modified_mass_matrix : bool = False ,
466487 transform_adapt : bool = False ,
467- init_mean : Optional [ np .ndarray ] ,
468- return_raw_trace : bool ,
488+ init_mean : np .ndarray | None = None ,
489+ return_raw_trace : bool = False ,
469490 blocking : Literal [True ],
491+ progress_template : str | None = None ,
492+ progress_style : str | None = None ,
493+ progress_rate : int = 100 ,
470494 ** kwargs ,
471495) -> arviz .InferenceData : ...
472496
@@ -475,18 +499,21 @@ def sample(
475499def sample (
476500 compiled_model : CompiledModel ,
477501 * ,
478- draws : int | None ,
479- tune : int | None ,
480- chains : int ,
481- cores : Optional [ int ] ,
482- seed : Optional [ int ] ,
483- save_warmup : bool ,
484- progress_bar : bool ,
502+ draws : int | None = None ,
503+ tune : int | None = None ,
504+ chains : int | None = None ,
505+ cores : int | None = None ,
506+ seed : int | None = None ,
507+ save_warmup : bool = True ,
508+ progress_bar : bool = True ,
485509 low_rank_modified_mass_matrix : bool = False ,
486510 transform_adapt : bool = False ,
487- init_mean : Optional [ np .ndarray ] ,
488- return_raw_trace : bool ,
511+ init_mean : np .ndarray | None = None ,
512+ return_raw_trace : bool = False ,
489513 blocking : Literal [False ],
514+ progress_template : str | None = None ,
515+ progress_style : str | None = None ,
516+ progress_rate : int = 100 ,
490517 ** kwargs ,
491518) -> _BackgroundSampler : ...
492519
@@ -496,21 +523,21 @@ def sample(
496523 * ,
497524 draws : int | None = None ,
498525 tune : int | None = None ,
499- chains : int = 6 ,
500- cores : Optional [ int ] = None ,
501- seed : Optional [ int ] = None ,
526+ chains : int | None = None ,
527+ cores : int | None = None ,
528+ seed : int | None = None ,
502529 save_warmup : bool = True ,
503530 progress_bar : bool = True ,
504531 low_rank_modified_mass_matrix : bool = False ,
505532 transform_adapt : bool = False ,
506- init_mean : Optional [ np .ndarray ] = None ,
533+ init_mean : np .ndarray | None = None ,
507534 return_raw_trace : bool = False ,
508535 blocking : bool = True ,
509- progress_template : Optional [ str ] = None ,
510- progress_style : Optional [ str ] = None ,
536+ progress_template : str | None = None ,
537+ progress_style : str | None = None ,
511538 progress_rate : int = 100 ,
512539 ** kwargs ,
513- ) -> arviz .InferenceData :
540+ ) -> arviz .InferenceData | _BackgroundSampler :
514541 """Sample the posterior distribution for a compiled model.
515542
516543 Parameters
@@ -618,7 +645,8 @@ def sample(
618645 settings .num_tune = tune
619646 if draws is not None :
620647 settings .num_draws = draws
621- settings .num_chains = chains
648+ if chains is not None :
649+ settings .num_chains = chains
622650
623651 for name , val in kwargs .items ():
624652 setattr (settings , name , val )
@@ -629,7 +657,10 @@ def sample(
629657 available = os .process_cpu_count () # type: ignore
630658 except AttributeError :
631659 available = os .cpu_count ()
632- cores = min (chains , cast (int , available ))
660+ if chains is None :
661+ cores = available
662+ else :
663+ cores = min (chains , cast (int , available ))
633664
634665 if init_mean is None :
635666 init_mean = np .zeros (compiled_model .n_dim )
0 commit comments