77
88from investing_algorithm_framework .domain import DataProvider , \
99 OperationalException , ImproperlyConfigured , DataSource , DataType , \
10- BacktestDateRange , tqdm , convert_polars_to_pandas
10+ BacktestDateRange , tqdm , convert_polars_to_pandas , TimeFrame
1111
1212logger = logging .getLogger ("investing_algorithm_framework" )
1313
@@ -26,6 +26,7 @@ def __init__(self, data_providers=[]):
2626 self .data_providers_lookup = defaultdict ()
2727 self .ohlcv_data_providers = defaultdict ()
2828 self .ohlcv_data_providers_no_market = defaultdict ()
29+ self .ohlcv_data_providers_with_timeframe = defaultdict ()
2930 self .ticker_data_providers = defaultdict ()
3031
3132 def add (self , data_provider : DataProvider ):
@@ -80,11 +81,15 @@ def register(self, data_source: DataSource) -> DataProvider:
8081
8182 symbol = data_source .symbol
8283 market = data_source .market
84+ time_frame = data_source .time_frame
8385
8486 if DataType .OHLCV .equals (data_source .data_type ):
8587 if symbol not in self .ohlcv_data_providers :
8688 self .ohlcv_data_providers [(symbol , market )] = best_provider
8789 self .ohlcv_data_providers_no_market [symbol ] = best_provider
90+ self .ohlcv_data_providers_with_timeframe [
91+ (symbol , market , time_frame )
92+ ] = best_provider
8893 else :
8994 try :
9095 # If the symbol already exists, we can update the provider
@@ -105,6 +110,11 @@ def register(self, data_source: DataSource) -> DataProvider:
105110 self .ohlcv_data_providers_no_market [symbol ] = \
106111 best_provider
107112
113+ time_frame_key = (symbol , market , time_frame )
114+ self .ohlcv_data_providers_with_timeframe [
115+ time_frame_key
116+ ] = best_provider
117+
108118 except Exception :
109119 # If the existing provider does not have a time_frame
110120 # attribute, we can safely ignore this
@@ -165,12 +175,16 @@ def register_backtest_data_source(
165175
166176 symbol = data_source .symbol
167177 market = data_source .market
178+ time_frame = data_source .time_frame
168179
169180 if DataType .OHLCV .equals (data_source .data_type ):
170181
171182 if symbol not in self .ohlcv_data_providers :
172183 self .ohlcv_data_providers [(symbol , market )] = best_provider
173184 self .ohlcv_data_providers_no_market [symbol ] = best_provider
185+ self .ohlcv_data_providers_with_timeframe [
186+ (symbol , market , time_frame )
187+ ] = best_provider
174188 else :
175189 try :
176190 # If the symbol already exists, we can update the provider
@@ -192,6 +206,11 @@ def register_backtest_data_source(
192206 self .ohlcv_data_providers_no_market [symbol ] = \
193207 best_provider
194208
209+ time_frame_key = (symbol , market , data_source .time_frame )
210+ self .ohlcv_data_providers_with_timeframe [
211+ time_frame_key
212+ ] = best_provider
213+
195214 except Exception :
196215 # If the existing provider does not have a time_frame
197216 # attribute, we can safely ignore this
@@ -264,20 +283,31 @@ def __len__(self):
264283 return len (self .data_providers_lookup )
265284
266285 def get_ohlcv_data_provider (
267- self , symbol : str , market : Optional [str ] = None
286+ self ,
287+ symbol : str ,
288+ market : Optional [str ] = None ,
289+ time_frame : Optional [str ] = None
268290 ) -> Optional [DataProvider ]:
269291 """
270292 Get the OHLCV data provider for a given symbol and market.
271293
272294 Args:
273295 symbol (str): The symbol to get the data provider for.
274296 market (Optional[str]): The market to get the data provider for.
297+ time_frame (Optional[str]): The time frame to get the
298+ data provider for.
275299
276300 Returns:
277301 DataProvider: The OHLCV data provider for the symbol and market,
278302 or None if no provider is found.
279303 """
280304
305+ if market is not None and time_frame is not None :
306+ time_frame = TimeFrame .from_value (time_frame )
307+ return self .ohlcv_data_providers_with_timeframe .get (
308+ (symbol , market , time_frame ), None
309+ )
310+
281311 if market is None :
282312 # If no market is specified
283313 return self .ohlcv_data_providers_no_market .get (symbol , None )
@@ -462,7 +492,8 @@ def get_ohlcv_data(
462492 start_date : Optional [datetime ] = None ,
463493 end_date : Optional [datetime ] = None ,
464494 window_size : Optional [int ] = None ,
465- pandas : bool = False
495+ pandas : bool = False ,
496+ add_pandas_index : bool = True ,
466497 ):
467498 """
468499 Function to get OHLCV data from the data provider.
@@ -483,6 +514,7 @@ def get_ohlcv_data(
483514 data_provider = self .data_provider_index .get_ohlcv_data_provider (
484515 symbol = symbol ,
485516 market = market ,
517+ time_frame = time_frame
486518 )
487519
488520 if data_provider is None :
@@ -511,7 +543,9 @@ def get_ohlcv_data(
511543
512544 if pandas :
513545 if isinstance (data , pl .DataFrame ):
514- return convert_polars_to_pandas (data )
546+ return convert_polars_to_pandas (
547+ data , add_index = add_pandas_index
548+ )
515549 else :
516550 return data
517551
0 commit comments