Source code for cached_historical_data_fetcher.cache.base

from __future__ import annotations

from abc import ABCMeta, abstractmethod
from logging import getLogger
from pathlib import Path
from typing import Any, Literal

from pandas import DataFrame, Timedelta, Timestamp
from slugify import slugify

from ..io import get_path, read, update

BASE_PATH = Path(f"~/.cache/{__name__.split('.')[0]}").expanduser()
LOG = getLogger(__name__)


[docs]class HistoricalDataCache(metaclass=ABCMeta): """Base class for historical data cache. Usage ----- 1. Override `self.get()` to implement the logic. 2. Override `self.to_update()` if the index is not Timestamp. 3. Call `self.update()` to get historical data. Examples -------- .. code-block:: python from cached_historical_data_fetcher import HistoricalDataCache from pandas import DataFrame, Timedelta, Timestamp, date_range class MyCache(HistoricalDataCache): interval: Timedelta = Timedelta(days=1) async def get(self, start: Timestamp | None, *args: Any, **kwargs: Any) -> DataFrame: start = start or Timestamp.utcnow().floor("10D") date_range_chunk = date_range(start, Timestamp.utcnow(), freq="D") return DataFrame({"day": [d.day for d in date_range_chunk]}, index=date_range_chunk) df = await MyCache().update() """ add_interval: bool = True """If True, `start` in `self.get()` is the last index of historical data + `self.interval`. If False, `start` in `self.get()` is the last index of historical data.""" folder: str """The folder name. By default, the class name.""" interval: Timedelta """The interval to update cache file.""" mismatch: Literal["warn", "raise"] | int | None = "warn" """The action when data mismatch. If int, log level. If None, do nothing.""" keep: Literal["first", "last"] = "last" """Which duplicated index to keep.""" compress: int | str | tuple[str, int] = ("lz4", 3) """The compression level.""" protocol: int | None = None """The pickle protocol.""" def __init__(self) -> None: """Initialize HistoricalDataCache.""" self.folder = self.__class__.__name__
[docs] def path(self, name: str) -> Path: return get_path(self.folder, name)
[docs] async def update( self, reload: bool = False, *args: Any, **kwargs: Any, ) -> DataFrame: """Update cache file with DataFrame. Parameters ---------- reload : bool, optional Whether to ignore cache file and reload, by default False *args : Any The arguments for `self.get()` and `self.to_update()`. **kwargs : Any The keyword arguments for `self.get()` and `self.to_update()`. Returns ------- DataFrame The DataFrame read from cache file. Raises ------ RuntimeError If unexpected type read from cache file or `self.get()` does not return DataFrame or `self.to_update()` does not return bool. """ # generate name name = "_".join([str(arg) for arg in args]) + "_".join( [f"{key}-{value}" for key, value in kwargs.items()] ) name = slugify(name) # read df = await read(self.path(name)) if not reload else DataFrame() if not isinstance(df, DataFrame): raise TypeError(f"Unexpected type read from {self.path(name)}: {type(df)}") # check if need to update start = None if not df.empty: start = df.index.max() if isinstance(start, tuple): start = start[0] to_update = self.to_update(start, *args, **kwargs) if not isinstance(to_update, bool): raise TypeError(f"to_update must return bool: {type(to_update)}") # update if to_update: df = await self.get( ((start + self.interval) if self.add_interval else start) if not df.empty else None, *args, **kwargs, ) if not isinstance(df, DataFrame): raise TypeError(f"get must return DataFrame: {type(df)}") old_len = len(df) df = await update( self.path(name), df, reload=reload, mismatch=self.mismatch, keep=self.keep, compress=self.compress, protocol=self.protocol, ) LOG.debug( f"Updated {name} from {self.path}, [{df.index.min()}~{df.index.max()}]" f" ({old_len}->{len(df)} rows)" ) else: LOG.debug( f"Loaded {name} from {self.path}, [{df.index.min()}~{df.index.max()}]" f" ({len(df)} rows)" ) return df
[docs] @abstractmethod async def get( self, start: Timestamp | Any | None, *args: Any, **kwargs: Any ) -> DataFrame: """Get historical data. Override this method to implement the logic. Parameters ---------- start : Timestamp | Any | None The last index of historical data. Returns ------- DataFrame The historical data. It is recommended to set index to Timestamp or unique incremental number. If the index is not Timestamp, override `self.to_update()` to implement the logic as well. Multiindex is supported. It is recommended to set the first level to Timestamp. """
[docs] def to_update(self, end: Timestamp | Any | None, *args: Any, **kwargs: Any) -> bool: """Check if need to update cache file. Override this method to implement the logic. By default, update if cache file is older than self.interval. Parameters ---------- end : Timestamp | Any | None The last index of historical data. If the cache file is empty, end is None. Returns ------- bool Whether to update cache file. """ return end is None or end + self.interval < Timestamp.utcnow().tz_convert( tz=end.tz )