Skip to content

Commit

Permalink
Merge pull request #1705 from ranaroussi/fix/tz-cache-init
Browse files Browse the repository at this point in the history
Fix TZ cache exception blocking import
  • Loading branch information
ValueRaider authored Sep 30, 2023
2 parents 95ef486 + 62b2c25 commit 9581b8b
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 62 deletions.
107 changes: 54 additions & 53 deletions tests/ticker.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,10 @@ def test_goodTicker_withProxy(self):
tkr = "IBM"
dat = yf.Ticker(tkr, session=self.session)

dat._fetch_ticker_tz(proxy=self.proxy, timeout=5, debug_mode=False, raise_errors=False)
dat._get_ticker_tz(proxy=self.proxy, timeout=5, debug_mode=False, raise_errors=False)
dat._fetch_ticker_tz(proxy=self.proxy, timeout=5)
dat._get_ticker_tz(proxy=self.proxy, timeout=5)
dat.history(period="1wk", proxy=self.proxy)

v = dat.stats(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertTrue(len(v) > 0)

v = dat.get_recommendations(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_calendar(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_major_holders(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)
Expand All @@ -184,38 +172,6 @@ def test_goodTicker_withProxy(self):
self.assertIsNotNone(v)
self.assertTrue(len(v) > 0)

v = dat.get_sustainability(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_recommendations_summary(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_analyst_price_target(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_rev_forecast(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_earnings_forecast(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_trend_details(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_earnings_trend(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_earnings(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_income_stmt(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)
Expand Down Expand Up @@ -244,10 +200,6 @@ def test_goodTicker_withProxy(self):
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_shares(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)

v = dat.get_shares_full(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertFalse(v.empty)
Expand All @@ -264,11 +216,60 @@ def test_goodTicker_withProxy(self):
self.assertIsNotNone(v)
self.assertFalse(v.empty)

# TODO: enable after merge
# dat.get_history_metadata(proxy=self.proxy)
dat.get_history_metadata(proxy=self.proxy)
self.assertIsNotNone(v)
self.assertTrue(len(v) > 0)

# Below will fail because not ported to Yahoo API

# v = dat.stats(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertTrue(len(v) > 0)

# v = dat.get_recommendations(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_calendar(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_sustainability(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_recommendations_summary(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_analyst_price_target(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_rev_forecast(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_earnings_forecast(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_trend_details(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_earnings_trend(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_earnings(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)

# v = dat.get_shares(proxy=self.proxy)
# self.assertIsNotNone(v)
# self.assertFalse(v.empty)


class TestTickerHistory(unittest.TestCase):
session = None
Expand Down Expand Up @@ -318,7 +319,7 @@ def test_no_expensive_calls_introduced(self):
actual_urls_called = tuple([r.url for r in session.cache.filter()])
session.close()
expected_urls = (
'https://query2.finance.yahoo.com/v8/finance/chart/GOOGL?events=div,splits,capitalGains&includePrePost=False&interval=1d&range=1y',
'https://query2.finance.yahoo.com/v8/finance/chart/GOOGL?events=div%2Csplits%2CcapitalGains&includePrePost=False&interval=1d&range=1y',
)
self.assertEqual(expected_urls, actual_urls_called, "Different than expected url used to fetch history.")

Expand Down
10 changes: 10 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest
# import requests_cache
import tempfile
import os


class TestUtils(unittest.TestCase):
Expand All @@ -40,6 +41,15 @@ def test_storeTzNoRaise(self):
cache.store(tkr, tz1)
cache.store(tkr, tz2)

def test_setTzCacheLocation(self):
self.assertEqual(yf.utils._DBManager.get_location(), self.tempCacheDir.name)

tkr = 'AMZN'
tz1 = "America/New_York"
cache = yf.utils.get_tz_cache()
cache.store(tkr, tz1)

self.assertTrue(os.path.exists(os.path.join(self.tempCacheDir.name, "tkr-tz.db")))

def suite():
suite = unittest.TestSuite()
Expand Down
34 changes: 25 additions & 9 deletions yfinance/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,8 @@ class _TzCacheManager:
@classmethod
def get_tz(cls):
if cls._tz_cache is None:
cls._initialise()
with _cache_init_lock:
cls._initialise()
return cls._tz_cache

@classmethod
Expand Down Expand Up @@ -976,7 +977,13 @@ def _initialise(cls, cache_dir=None):
cls._cache_dir = cache_dir

if not _os.path.isdir(cls._cache_dir):
_os.mkdir(cls._cache_dir)
try:
_os.makedirs(cls._cache_dir)
except OSError as err:
raise _TzCacheException(f"Error creating TzCache folder: '{cls._cache_dir}' reason: {err}")
elif not (_os.access(cls._cache_dir, _os.R_OK) and _os.access(cls._cache_dir, _os.W_OK)):
raise _TzCacheException(f"Cannot read and write in TzCache folder: '{cls._cache_dir}'")

cls._db = _peewee.SqliteDatabase(
_os.path.join(cls._cache_dir, 'tkr-tz.db'),
pragmas={'journal_mode': 'wal', 'cache_size': -64}
Expand All @@ -987,10 +994,16 @@ def _initialise(cls, cache_dir=None):
_os.remove(old_cache_file_path)

@classmethod
def change_location(cls, new_cache_dir):
cls._db.close()
cls._db = None
def set_location(cls, new_cache_dir):
if cls._db is not None:
cls._db.close()
cls._db = None
cls._cache_dir = new_cache_dir

@classmethod
def get_location(cls):
return cls._cache_dir

# close DB when Python exists
_atexit.register(_DBManager.close_db)

Expand All @@ -1000,7 +1013,11 @@ class _KV(_peewee.Model):
value = _peewee.CharField(null=True)

class Meta:
database = _DBManager.get_database()
try:
database = _DBManager.get_database()
except Exception:
# This code runs at import, so Logger won't be ready yet, so must discard exception.
database = None
without_rowid = True


Expand Down Expand Up @@ -1043,8 +1060,7 @@ def get_tz_cache():
dummy cache with same interface as real cash.
"""
# as this can be called from multiple threads, protect it.
with _cache_init_lock:
return _TzCacheManager.get_tz()
return _TzCacheManager.get_tz()


def set_tz_cache_location(cache_dir: str):
Expand All @@ -1055,4 +1071,4 @@ def set_tz_cache_location(cache_dir: str):
:param cache_dir: Path to use for caches
:return: None
"""
_DBManager.change_location(cache_dir)
_DBManager.set_location(cache_dir)

0 comments on commit 9581b8b

Please sign in to comment.