@@ -278,7 +278,7 @@ def __init__(
278278 rescale : float = 20 ,
279279 threshold : float = 0.5 ,
280280 attribute : str = "cui" ,
281- span_getter : SpanGetterArg = { "ents" : True } ,
281+ span_getter : SpanGetterArg = None ,
282282 context_getter : Optional [SpanGetterArg ] = None ,
283283 reference_mode : Literal ["concept" , "synonym" ] = "concept" ,
284284 probability_mode : Literal ["softmax" , "sigmoid" ] = "sigmoid" ,
@@ -289,6 +289,7 @@ def __init__(
289289 sub_span_getter = getattr (embedding , "span_getter" , None )
290290 if sub_span_getter is not None and span_getter is None : # pragma: no cover
291291 span_getter = sub_span_getter
292+ span_getter = span_getter or {"ents" : True }
292293 sub_context_getter = getattr (embedding , "context_getter" , None )
293294 if (
294295 sub_context_getter is not None and context_getter is None
@@ -309,7 +310,7 @@ def __init__(
309310 self .reference_mode = reference_mode
310311 self .probability_mode = probability_mode
311312 self .init_weights = init_weights
312- self .context_getter : SpanGetter = context_getter or span_getter
313+ self .context_getter : SpanGetter = context_getter or self . span_getter
313314 with warnings .catch_warnings ():
314315 warnings .simplefilter ("ignore" , UserWarning )
315316 self .classifier = Metric (
0 commit comments