From 41a4bc1cc1ff96b36ff702d27203f023eb999c1f Mon Sep 17 00:00:00 2001
From: Linlang
Date: Wed, 6 Sep 2023 17:00:50 +0800
Subject: [PATCH 01/25] add_baostock_collector

---
 .../data_collector/baostock_5min/README.md    |  82 +++
 .../data_collector/baostock_5min/collector.py | 529 ++++++++++++++++++
 .../baostock_5min/requirements.txt            |  13 +
 3 files changed, 624 insertions(+)
 create mode 100644 scripts/data_collector/baostock_5min/README.md
 create mode 100644 scripts/data_collector/baostock_5min/collector.py
 create mode 100644 scripts/data_collector/baostock_5min/requirements.txt

diff --git a/scripts/data_collector/baostock_5min/README.md b/scripts/data_collector/baostock_5min/README.md
new file mode 100644
index 0000000000..a2739b9413
--- /dev/null
+++ b/scripts/data_collector/baostock_5min/README.md
@@ -0,0 +1,82 @@
+## Collect Data
+
+### Get Qlib data (`bin file`)
+
+  - get data: `python scripts/get_data.py qlib_data`
+  - parameters:
+    - `target_dir`: save dir, by default *~/.qlib/qlib_data/cn_data_5min*
+    - `version`: dataset version, value from [`v2`], by default `v2`
+      - `v2` end date is *2022-12*
+    - `interval`: `5min`
+    - `region`: `hs300`
+    - `delete_old`: delete existing data from `target_dir` (*features, calendars, instruments, dataset_cache, features_cache*), value from [`True`, `False`], by default `True`
+    - `exists_skip`: if `target_dir` data already exists, skip `get_data`, value from [`True`, `False`], by default `False`
+  - examples:
+    ```bash
+    # hs300 5min
+    python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/hs300_data_5min --region hs300 --interval 5min
+    ```
+
+### Collect *Baostock high frequency* data to qlib
+> collect *Baostock high frequency* data and *dump* it into `qlib` format.
+> If the above ready-made data can't meet users' requirements, users can follow this section to crawl the latest data and convert it to qlib-data.
+  1. download data to csv: `python scripts/data_collector/baostock_5min/collector.py download_data`
+
+     This will download the raw data (date, symbol, open, high, low, close, volume, amount, adjustflag) from Baostock to a local directory, one file per symbol.
+     - parameters:
+       - `source_dir`: directory where the raw csv files are saved
+       - `interval`: `5min`
+       - `region`: `HS300`
+       - `start`: start datetime, by default *None*
+       - `end`: end datetime, by default *None*
+     - examples:
+       ```bash
+       # cn 5min data
+       python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_data_5min --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300
+       ```
+  2. normalize data: `python scripts/data_collector/baostock_5min/collector.py normalize_data`
+
+     This will:
+     1. Normalize the high, low, close and open prices using adjclose.
+     2. Rescale the high, low, close and open prices so that the first valid trading date's close price is 1.
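+
+     As a rough illustration, the two steps amount to the following minimal sketch (made-up numbers, not the collector's actual code):
+     ```python
+     import pandas as pd
+
+     # hypothetical raw and adjusted closes around a 2:1 split on the third day
+     df = pd.DataFrame({"close": [10.0, 10.2, 5.1], "adjclose": [5.0, 5.1, 5.1]})
+     factor = df["adjclose"] / df["close"]                # step 1: per-day adjustment factor
+     adj = df["close"] * factor                           # adjusted price series
+     normalized = adj / adj.loc[adj.first_valid_index()]  # step 2: first valid close == 1
+     ```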
+     - parameters:
+       - `source_dir`: csv directory
+       - `normalize_dir`: result directory
+       - `interval`: `5min`
+         > if **`interval == 5min`**, `qlib_data_1d_dir` cannot be `None`
+       - `region`: `HS300`
+       - `date_field_name`: column *name* identifying time in csv files, by default `date`
+       - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
+       - `end_date`: if not `None`, only data up to and *including* `end_date` is normalized; if `None`, this parameter is ignored; by default `None`
+       - `qlib_data_1d_dir`: qlib directory (1d data)
+         ```
+         if interval==5min, qlib_data_1d_dir cannot be None; normalizing 5min data requires 1d data.
+
+         qlib_data_1d can be obtained like this:
+         $ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3
+         ```
+     - examples:
+       ```bash
+       # normalize 5min cn
+       python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min
+       ```
+  3. dump data: `python scripts/dump_bin.py dump_all`
+
+     This will convert the normalized csv files into numpy arrays and store them under the `features` directory of `qlib_dir`, one directory per symbol and one file per field.
+
+     - parameters:
+       - `csv_path`: stock data path or directory, i.e. the **normalize result (`normalize_dir`)**
+       - `qlib_dir`: qlib (dump) data directory
+       - `freq`: transaction frequency, by default `day`
+         > `freq_map = {1d: day, 5min: 5min}`
+       - `max_workers`: number of threads, by default *16*
+       - `include_fields`: dump fields, by default `""`
+       - `exclude_fields`: fields not dumped, by default `""`
+         > dump_fields = `include_fields if include_fields else set(symbol_df.columns) - set(exclude_fields) if exclude_fields else symbol_df.columns`
+       - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol`
+       - `date_field_name`: column *name* identifying time in csv files, by default `date`
+     - examples:
+       ```bash
+       # dump 5min cn
+       python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/hs300_5min_nor --qlib_dir ~/.qlib/qlib_data/hs300_5min_bin --freq 5min --exclude_fields date,symbol
+       ```
\ No newline at end of file
diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py
new file mode 100644
index 0000000000..d443484603
--- /dev/null
+++ b/scripts/data_collector/baostock_5min/collector.py
@@ -0,0 +1,529 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
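+
+# This module collects 5min bars for HS300 constituents from Baostock and
+# saves one csv per symbol; prices are downloaded unadjusted, and the
+# adjustment factor is applied later in the normalize step from local 1d data.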
+ +import os +import abc +from re import I +from typing import List +from tqdm import tqdm +import sys +import copy +import time +import datetime +import baostock as bs +from abc import ABC +from pathlib import Path +from typing import Iterable + +import fire +import numpy as np +import pandas as pd +from loguru import logger + +import qlib +from qlib.data import D +from qlib.constant import REG_CN as REGION_CN + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) + +from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize + +from data_collector.utils import generate_minutes_calendar_from_daily + + +class BaostockCollectorHS3005min(BaseCollector): + def __init__( + self, + save_dir: [str, Path], + start=None, + end=None, + interval="5min", + max_workers=4, + max_collector_count=2, + delay=0, + check_data_length: int = None, + limit_nums: int = None, + ): + """ + + Parameters + ---------- + save_dir: str + stock save dir + max_workers: int + workers, default 4 + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0 + interval: str + freq, value from [1min, 1d], default 1min + start: str + start datetime, default None + end: str + end datetime, default None + check_data_length: int + check data length, by default None + limit_nums: int + using for debug, by default None + """ + bs.login() + super(BaostockCollectorHS3005min, self).__init__( + save_dir=save_dir, + start=start, + end=end, + interval=interval, + max_workers=max_workers, + max_collector_count=max_collector_count, + delay=delay, + check_data_length=check_data_length, + limit_nums=limit_nums, + ) + + def get_trade_calendar(self): + _format = "%Y-%m-%d" + start = self.start_datetime.strftime(_format) + end = self.end_datetime.strftime(_format) + rs = bs.query_trade_dates(start_date=start, end_date=end) + calendar_list = [] + while (rs.error_code == "0") & rs.next(): + calendar_list.append(rs.get_row_data()) + calendar_df = pd.DataFrame(calendar_list, columns=rs.fields) + trade_calendar_df = calendar_df[~calendar_df["is_trading_day"].isin(["0"])] + # bs.logout() + return trade_calendar_df["calendar_date"].values + + @staticmethod + def process_interval(interval: str): + if interval == "1d": + return {"interval": "d", "fields": "date,code,open,high,low,close,volume,amount,adjustflag"} + if interval == "5min": + return {"interval": "5", "fields": "date,time,code,open,high,low,close,volume,amount,adjustflag"} + + def get_data( + self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + ) -> pd.DataFrame: + df = self.get_data_from_remote( + symbol=symbol, interval=interval, start_datetime=start_datetime, end_datetime=end_datetime + ) + df.columns = ["date", "time", "symbol", "open", "high", "low", "close", "volume", "amount", "adjustflag"] + df["time"] = pd.to_datetime(df["time"], format="%Y%m%d%H%M%S%f") + df["date"] = df["time"].dt.strftime("%Y-%m-%d %H:%M:%S") + df["date"] = df["date"].map(lambda x: pd.Timestamp(x) - pd.Timedelta(minutes=5)) + df.drop(["time"], axis=1, inplace=True) + df["symbol"] = df["symbol"].map(lambda x: str(x).replace(".", "").upper()) + return df + + @staticmethod + def get_data_from_remote( + symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + ) -> pd.DataFrame: + df = pd.DataFrame() + rs = bs.query_history_k_data_plus( + symbol, + BaostockCollectorHS3005min.process_interval(interval=interval)["fields"], + 
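# per baostock's documented conventions: dates are "YYYY-MM-DD" strings, + # frequency "5" selects 5-minute bars, and adjustflag="3" requests + # unadjusted prices (the adjustment factor is applied later, in the normalize + # step, from local 1d data) +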
start_date=str(start_datetime.strftime("%Y-%m-%d")), + end_date=str(end_datetime.strftime("%Y-%m-%d")), + frequency=BaostockCollectorHS3005min.process_interval(interval=interval)["interval"], + adjustflag="3", + ) + if rs.error_code == "0" and len(rs.data) > 0: + data_list = rs.data + columns = rs.fields + df = pd.DataFrame(data_list, columns=columns) + return df + + def get_hs300_symbols(self) -> List[str]: + hs300_stocks = [] + trade_calendar = self.get_trade_calendar() + with tqdm(total=len(trade_calendar)) as p_bar: + for date in trade_calendar: + rs = bs.query_hs300_stocks(date=date) + while rs.error_code == "0" and rs.next(): + hs300_stocks.append(rs.get_row_data()) + p_bar.update() + return sorted(set([e[1] for e in hs300_stocks])) + + def get_instrument_list(self): + logger.info("get HS stock symbols......") + symbols = self.get_hs300_symbols() + logger.info(f"get {len(symbols)} symbols.") + return symbols + + def normalize_symbol(self, symbol: str): + return str(symbol).replace(".", "").upper() + + +class BaostockNormalizeHS3005min(BaseNormalize): + COLUMNS = ["open", "close", "high", "low", "volume"] + DAILY_FORMAT = "%Y-%m-%d" + AM_RANGE = ("09:30:00", "11:29:00") + PM_RANGE = ("13:00:00", "14:59:00") + # Whether the trading day of 1min data is consistent with 1d + CONSISTENT_1d = True + CALC_PAUSED_NUM = True + + def __init__( + self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + ): + """ + + Parameters + ---------- + qlib_data_1d_dir: str, Path + the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data + date_field_name: str + date field name, default is date + symbol_field_name: str + symbol field name, default is symbol + """ + bs.login() + qlib.init(provider_uri=qlib_data_1d_dir) + # self.qlib_data_1d_dir = qlib_data_1d_dir + super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name) + self._all_1d_data = self._get_all_1d_data() + + @staticmethod + def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series: + df = df.copy() + _tmp_series = df["close"].fillna(method="ffill") + _tmp_shift_series = _tmp_series.shift(1) + if last_close is not None: + _tmp_shift_series.iloc[0] = float(last_close) + change_series = _tmp_series / _tmp_shift_series - 1 + return change_series + + def _get_calendar_list(self) -> Iterable[pd.Timestamp]: + # return list(D.calendar(freq="day")) + return self.generate_5min_from_daily(self.calendar_list_1d) + + def _get_all_1d_data(self): + df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") + df.reset_index(inplace=True) + df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) + df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) + return df + + @property + def calendar_list_1d(self): + calendar_list_1d = getattr(self, "_calendar_list_1d", None) + if calendar_list_1d is None: + calendar_list_1d = self._get_1d_calendar_list() + setattr(self, "_calendar_list_1d", calendar_list_1d) + return calendar_list_1d + + @staticmethod + def normalize_baostock( + df: pd.DataFrame, + calendar_list: list = None, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + last_close: float = None, + ): + if df.empty: + return df + symbol = df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name] + columns = copy.deepcopy(BaostockNormalizeHS3005min.COLUMNS) + df = df.copy() + 
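# align rows to the 5min trading calendar: index by datetime, drop duplicate + # timestamps, then reindex against calendar_list (below) so that missing + # bars become NaN rows instead of being silently absent +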
df.set_index(date_field_name, inplace=True) + df.index = pd.to_datetime(df.index) + df = df[~df.index.duplicated(keep="first")] + if calendar_list is not None: + df = df.reindex( + pd.DataFrame(index=calendar_list) + .loc[pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(days=1)] + .index + ) + df.sort_index(inplace=True) + # df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan + df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan + + change_series = BaostockNormalizeHS3005min.calc_change(df, last_close) + # NOTE: The data obtained by Yahoo finance sometimes has exceptions + # WARNING: If it is normal for a `symbol(exchange)` to differ by a factor of *89* to *111* for consecutive trading days, + # WARNING: the logic in the following line needs to be modified + _count = 0 + while True: + # NOTE: may appear unusual for many days in a row + change_series = BaostockNormalizeHS3005min.calc_change(df, last_close) + _mask = (change_series >= 89) & (change_series <= 111) + if not _mask.any(): + break + _tmp_cols = ["high", "close", "low", "open"] + df.loc[_mask, _tmp_cols] = df.loc[_mask, _tmp_cols] / 100 + _count += 1 + if _count >= 10: + _symbol = df.loc[df[symbol_field_name].first_valid_index()]["symbol"] + logger.warning( + f"{_symbol} `change` is abnormal for {_count} consecutive days, please check the specific data file carefully" + ) + + df["change"] = BaostockNormalizeHS3005min.calc_change(df, last_close) + + columns += ["change"] + df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), columns] = np.nan + + df[symbol_field_name] = symbol + df.index.names = [date_field_name] + return df.reset_index() + + def generate_5min_from_daily(self, calendars: Iterable) -> pd.Index: + return generate_minutes_calendar_from_daily( + calendars, freq="5min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE + ) + + def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: + """get 1d data + + Returns + ------ + data_1d: pd.DataFrame + data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] + + """ + return self._all_1d_data[ + (self._all_1d_data[self._symbol_field_name] == symbol.upper()) + & (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start)) + & (self._all_1d_data[self._date_field_name] < pd.Timestamp(end)) + ] + + def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: + # TODO: using daily data factor + if df.empty: + return df + df = df.copy() + df = df.sort_values(self._date_field_name) + symbol = df.iloc[0][self._symbol_field_name] + # get 1d data from baostock + _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) + _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) + data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) + data_1d = data_1d.copy() + if data_1d is None or data_1d.empty: + df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"] + # TODO: np.nan or 1 or 0 + df["paused"] = np.nan + else: + # NOTE: volume is np.nan or volume <= 0, paused = 1 + # FIXME: find a more accurate data source + data_1d["paused"] = 0 + data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1 + data_1d = data_1d.set_index(self._date_field_name) + + # add factor from 1d data + # NOTE: yahoo 1d data info: + # - Close price adjusted for splits. 
Adjusted close price adjusted for both dividends and splits. + # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. + # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` + def _calc_factor(df_1d: pd.DataFrame): + try: + _date = pd.Timestamp(pd.Timestamp(df_1d[self._date_field_name].iloc[0]).date()) + df_1d["factor"] = ( + data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"] + ) + df_1d["paused"] = data_1d.loc[_date]["paused"] + except Exception: + df_1d["factor"] = np.nan + df_1d["paused"] = np.nan + return df_1d + + df = df.groupby([df[self._date_field_name].dt.date]).apply(_calc_factor) + + if self.CONSISTENT_1d: + # the date sequence is consistent with 1d + df.set_index(self._date_field_name, inplace=True) + df = df.reindex( + self.generate_5min_from_daily( + pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates()) + ) + ) + df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][ + self._symbol_field_name + ] + df.index.names = [self._date_field_name] + df.reset_index(inplace=True) + for _col in self.COLUMNS: + if _col not in df.columns: + continue + if _col == "volume": + df[_col] = df[_col] / df["factor"] + else: + df[_col] = df[_col] * df["factor"] + + if self.CALC_PAUSED_NUM: + df = self.calc_paused_num(df) + return df + + def calc_paused_num(self, df: pd.DataFrame): + _symbol = df.iloc[0][self._symbol_field_name] + df = df.copy() + df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) + # remove data that starts and ends with `np.nan` all day + all_data = [] + # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan + all_nan_nums = 0 + # Record the number of consecutive occurrences of trading days that are not nan throughout the day + not_nan_nums = 0 + for _date, _df in df.groupby("_tmp_date"): + _df["paused"] = 0 + if not _df.loc[_df["volume"] < 0].empty: + logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}") + _df.loc[_df["volume"] < 0, "volume"] = np.nan + + check_fields = set(_df.columns) - { + "_tmp_date", + "paused", + "factor", + self._date_field_name, + self._symbol_field_name, + } + if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): + all_nan_nums += 1 + not_nan_nums = 0 + _df["paused"] = 1 + if all_data: + _df["paused_num"] = not_nan_nums + all_data.append(_df) + else: + all_nan_nums = 0 + not_nan_nums += 1 + _df["paused_num"] = not_nan_nums + all_data.append(_df) + all_data = all_data[: len(all_data) - all_nan_nums] + if all_data: + df = pd.concat(all_data, sort=False) + else: + logger.warning(f"data is empty: {_symbol}") + df = pd.DataFrame() + return df + del df["_tmp_date"] + return df + + def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: + return list(D.calendar(freq="day")) + + def normalize(self, df: pd.DataFrame) -> pd.DataFrame: + # normalize + df = self.normalize_baostock(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + # adjusted price + df = self.adjusted_price(df) + return df + + +class Run(BaseRun): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): + """ + + Parameters + ---------- + source_dir: str + The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" + normalize_dir: str + Directory for normalize data, 
default "Path(__file__).parent/normalize" + max_workers: int + Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 + interval: str + freq, value from [1min, 1d], default 1d + region: str + region, value from ["CN", "US", "BR"], default "CN" + """ + super().__init__(source_dir, normalize_dir, max_workers, interval) + self.region = region + + @property + def collector_class_name(self): + return f"BaostockCollector{self.region.upper()}{self.interval}" + + @property + def normalize_class_name(self): + return f"BaostockNormalize{self.region.upper()}{self.interval}" + + @property + def default_base_dir(self) -> [Path, str]: + return CUR_DIR + + def download_data( + self, + max_collector_count=2, + delay=0.5, + start=None, + end=None, + check_data_length=None, + limit_nums=None, + ): + """download data from Internet + + Parameters + ---------- + max_collector_count: int + default 2 + delay: float + time.sleep(delay), default 0.5 + start: str + start datetime, default "2000-01-01"; closed interval(including start) + end: str + end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval(excluding end) + check_data_length: int + check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. + limit_nums: int + using for debug, by default None + + Notes + ----- + check_data_length, example: + daily, one year: 252 // 4 + us 1min, a week: 6.5 * 60 * 5 + cn 1min, a week: 4 * 60 * 5 + + Examples + --------- + # get daily data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d + # get 1m data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m + """ + super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) + + def normalize_data( + self, + date_field_name: str = "date", + symbol_field_name: str = "symbol", + end_date: str = None, + qlib_data_1d_dir: str = None, + ): + """normalize data + + Parameters + ---------- + date_field_name: str + date field name, default date + symbol_field_name: str + symbol field name, default symbol + end_date: str + if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None + qlib_data_1d_dir: str + if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data; + + qlib_data_1d can be obtained like this: + $ python scripts/get_data.py qlib_data --target_dir --interval 1d + $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir --trading_date 2021-06-01 + or: + download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo + + Examples + --------- + $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d + $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min + """ + if qlib_data_1d_dir is None or not 
Path(qlib_data_1d_dir).expanduser().exists(): + raise ValueError( + "If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" + ) + super(Run, self).normalize_data( + date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir + ) + + +if __name__ == "__main__": + fire.Fire(Run) diff --git a/scripts/data_collector/baostock_5min/requirements.txt b/scripts/data_collector/baostock_5min/requirements.txt new file mode 100644 index 0000000000..97802ced86 --- /dev/null +++ b/scripts/data_collector/baostock_5min/requirements.txt @@ -0,0 +1,13 @@ +loguru +fire +requests +numpy +pandas +tqdm +lxml +yahooquery +joblib +beautifulsoup4 +bs4 +soupsieve +baostock \ No newline at end of file From b17afdfda4ed692ea690234533b7476069ec7cc4 Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 6 Sep 2023 20:35:25 +0800 Subject: [PATCH 02/25] modify_comments --- .../data_collector/baostock_5min/README.md | 2 +- .../data_collector/baostock_5min/collector.py | 33 ++++++++----------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/scripts/data_collector/baostock_5min/README.md b/scripts/data_collector/baostock_5min/README.md index a2739b9413..cf6b7789c9 100644 --- a/scripts/data_collector/baostock_5min/README.md +++ b/scripts/data_collector/baostock_5min/README.md @@ -32,7 +32,7 @@ - examples: ```bash # cn 5min data - python collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_data_5min --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300 + python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300 ``` 2. 
normalize data: `python scripts/data_collector/baostock_5min/collector.py normalize_data` diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py index d443484603..60a536f5db 100644 --- a/scripts/data_collector/baostock_5min/collector.py +++ b/scripts/data_collector/baostock_5min/collector.py @@ -58,7 +58,7 @@ def __init__( delay: float time.sleep(delay), default 0 interval: str - freq, value from [1min, 1d], default 1min + freq, value from [5min], default 5min start: str start datetime, default None end: str @@ -69,6 +69,7 @@ def __init__( using for debug, by default None """ bs.login() + interval="5min" super(BaostockCollectorHS3005min, self).__init__( save_dir=save_dir, start=start, @@ -160,7 +161,7 @@ class BaostockNormalizeHS3005min(BaseNormalize): DAILY_FORMAT = "%Y-%m-%d" AM_RANGE = ("09:30:00", "11:29:00") PM_RANGE = ("13:00:00", "14:59:00") - # Whether the trading day of 1min data is consistent with 1d + # Whether the trading day of 5min data is consistent with 1d CONSISTENT_1d = True CALC_PAUSED_NUM = True @@ -172,7 +173,7 @@ def __init__( Parameters ---------- qlib_data_1d_dir: str, Path - the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data + the qlib data to be updated for yahoo, usually from: Normalised to 5min using local 1d data date_field_name: str date field name, default is date symbol_field_name: str @@ -412,7 +413,7 @@ def normalize(self, df: pd.DataFrame) -> pd.DataFrame: class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="5min", region="HS300"): """ Parameters @@ -424,9 +425,9 @@ def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval= max_workers: int Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 interval: str - freq, value from [1min, 1d], default 1d + freq, value from [5min, default 5min region: str - region, value from ["CN", "US", "BR"], default "CN" + region, value from ["HS300"], default "HS300" """ super().__init__(source_dir, normalize_dir, max_workers, interval) self.region = region @@ -472,16 +473,12 @@ def download_data( Notes ----- check_data_length, example: - daily, one year: 252 // 4 - us 1min, a week: 6.5 * 60 * 5 - cn 1min, a week: 4 * 60 * 5 + hs300 5min, a week: 4 * 60 * 5 Examples --------- - # get daily data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d - # get 1m data - $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m + # get hs300 5min data + $ python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300 """ super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) @@ -503,22 +500,20 @@ def normalize_data( end_date: str if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None qlib_data_1d_dir: str - if interval==1min, qlib_data_1d_dir cannot be None, normalize 1min needs to use 1d data; + if interval==5min, qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data; qlib_data_1d can be obtained like 
this: - $ python scripts/get_data.py qlib_data --target_dir --interval 1d - $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir --trading_date 2021-06-01 + $ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3 or: download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo Examples --------- - $ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region cn --interval 1d - $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min + $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --normalize_dir ~/.qlib/stock_data/source/hs300_5min_nor --region HS300 --interval 5min """ if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists(): raise ValueError( - "If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" + "If normalize 5min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" ) super(Run, self).normalize_data( date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir From b02d7893d86248ac2d59414b7357238b01e58725 Mon Sep 17 00:00:00 2001 From: Linlang Date: Thu, 7 Sep 2023 13:52:02 +0800 Subject: [PATCH 03/25] fix_pylint_error --- .github/workflows/test_qlib_from_source.yml | 1 + .../data_collector/baostock_5min/collector.py | 25 ++++------ scripts/data_collector/base.py | 4 +- scripts/data_collector/br_index/collector.py | 3 +- scripts/data_collector/cn_index/collector.py | 3 +- scripts/data_collector/crypto/collector.py | 13 +++--- scripts/data_collector/fund/collector.py | 10 ++-- .../future_calendar_collector.py | 2 +- scripts/data_collector/us_index/collector.py | 5 +- scripts/data_collector/utils.py | 46 ++++++++++--------- scripts/data_collector/yahoo/collector.py | 22 +++------ scripts/dump_bin.py | 2 +- scripts/dump_pit.py | 9 ++-- 13 files changed, 61 insertions(+), 84 deletions(-) diff --git a/.github/workflows/test_qlib_from_source.yml b/.github/workflows/test_qlib_from_source.yml index acf37208fd..3d72961a91 100644 --- a/.github/workflows/test_qlib_from_source.yml +++ b/.github/workflows/test_qlib_from_source.yml @@ -102,6 +102,7 @@ jobs: - name: Check Qlib with pylint run: | pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + pylint 
--disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" # The following flake8 error codes were ignored: # E501 line too long diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py index 60a536f5db..d8f7c3ad4e 100644 --- a/scripts/data_collector/baostock_5min/collector.py +++ b/scripts/data_collector/baostock_5min/collector.py @@ -1,34 +1,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import os -import abc -from re import I -from typing import List -from tqdm import tqdm + import sys import copy -import time -import datetime -import baostock as bs -from abc import ABC -from pathlib import Path -from typing import Iterable - import fire import numpy as np import pandas as pd +import baostock as bs +from tqdm import tqdm +from pathlib import Path from loguru import logger +from typing import Iterable, List import qlib from qlib.data import D -from qlib.constant import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) -from data_collector.base import BaseCollector, BaseNormalize, BaseRun, Normalize - +from data_collector.base import BaseCollector, BaseNormalize, BaseRun from data_collector.utils import generate_minutes_calendar_from_daily @@ -69,7 +60,7 @@ def __init__( using for debug, by default None """ bs.login() - interval="5min" + interval = "5min" super(BaostockCollectorHS3005min, self).__init__( save_dir=save_dir, start=start, @@ -144,7 +135,7 @@ def get_hs300_symbols(self) -> List[str]: while rs.error_code == "0" and rs.next(): hs300_stocks.append(rs.get_row_data()) p_bar.update() - return sorted(set([e[1] for e in hs300_stocks])) + return sorted({e[1] for e in hs300_stocks}) def get_instrument_list(self): logger.info("get HS stock symbols......") diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index 386bb1b2c0..2517e9bce8 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -8,7 +8,7 @@ import importlib from pathlib import Path from typing import Type, Iterable -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from concurrent.futures import ProcessPoolExecutor import pandas as pd from tqdm import tqdm @@ -290,7 +290,7 @@ def _executor(self, file_path: Path): # some symbol_field values such as TRUE, NA are decoded as True(bool), NaN(np.float) by pandas default csv parsing. # manually defines dtype and na_values of the symbol_field. 
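+ # e.g. pandas parses the literal string "NA" as NaN by default, so a symbol + # actually named "NA" would be lost; "NA" is therefore removed from the + # na_values set used for the symbol column below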
- default_na = pd._libs.parsers.STR_NA_VALUES + default_na = pd._libs.parsers.STR_NA_VALUES # pylint: disable=I1101 symbol_na = default_na.copy() symbol_na.remove("NA") columns = pd.read_csv(file_path, nrows=0).columns diff --git a/scripts/data_collector/br_index/collector.py b/scripts/data_collector/br_index/collector.py index 7d32170f06..04b2f96d9f 100644 --- a/scripts/data_collector/br_index/collector.py +++ b/scripts/data_collector/br_index/collector.py @@ -3,7 +3,6 @@ from functools import partial import sys from pathlib import Path -import importlib import datetime import fire @@ -98,7 +97,7 @@ def get_four_month_period(self): now = datetime.datetime.now() current_year = now.year current_month = now.month - for year in [item for item in range(init_year, current_year)]: + for year in [item for item in range(init_year, current_year)]: # pylint: disable=R1721 for el in four_months_period: self.years_4_month_periods.append(str(year) + "_" + el) # For current year the logic must be a little different diff --git a/scripts/data_collector/cn_index/collector.py b/scripts/data_collector/cn_index/collector.py index 40fbe4c9a4..96f68ef9cd 100644 --- a/scripts/data_collector/cn_index/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -4,7 +4,6 @@ import re import abc import sys -import datetime from io import BytesIO from typing import List, Iterable from pathlib import Path @@ -39,7 +38,7 @@ def retry_request(url: str, method: str = "get", exclude_status: List = None): if exclude_status is None: exclude_status = [] method_func = getattr(requests, method) - _resp = method_func(url, headers=REQ_HEADERS) + _resp = method_func(url, headers=REQ_HEADERS, timeout=None) _status = _resp.status_code if _status not in exclude_status and _status != 200: raise ValueError(f"response status: {_status}, url={url}") diff --git a/scripts/data_collector/crypto/collector.py b/scripts/data_collector/crypto/collector.py index d1f7c16d9e..283517da9c 100644 --- a/scripts/data_collector/crypto/collector.py +++ b/scripts/data_collector/crypto/collector.py @@ -5,7 +5,6 @@ from pathlib import Path import fire -import requests import pandas as pd from loguru import logger from dateutil.tz import tzlocal @@ -31,15 +30,15 @@ def get_cg_crypto_symbols(qlib_data_path: [str, Path] = None) -> list: ------- crypto symbols in given exchanges list of coingecko """ - global _CG_CRYPTO_SYMBOLS + global _CG_CRYPTO_SYMBOLS # pylint: disable=W0603 @deco_retry def _get_coingecko(): try: cg = CoinGeckoAPI() resp = pd.DataFrame(cg.get_coins_markets(vs_currency="usd")) - except: - raise ValueError("request error") + except Exception as e: + raise ValueError("request error") from e try: _symbols = resp["id"].to_list() except Exception as e: @@ -226,7 +225,7 @@ def _get_calendar_list(self): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): # pylint: disable=W0246 """ Parameters @@ -254,7 +253,7 @@ def normalize_class_name(self): def default_base_dir(self) -> [Path, str]: return CUR_DIR - def download_data( + def download_data( # pylint: disable=W0246 self, max_collector_count=2, delay=0, @@ -290,7 +289,7 @@ def download_data( super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", 
symbol_field_name: str = "symbol"): # pylint: disable=W0246 """normalize data Parameters diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 9d6c82c815..de375bf07e 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -107,7 +107,7 @@ def get_data_from_remote(symbol, interval, start, end): url = INDEX_BENCH_URL.format( index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end ) - resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}) + resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}, timeout=None) if resp.status_code != 200: raise ValueError("request error") @@ -116,8 +116,8 @@ def get_data_from_remote(symbol, interval, start, end): # Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html SYType = data["Data"]["SYType"] - if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"): - raise Exception("The fund contains 每*份收益") + if SYType in {"每万份收益", "每百份收益", "每百万份收益"}: + raise ValueError("The fund contains 每*份收益") # TODO: should we sort the value by datetime? _resp = pd.DataFrame(data["Data"]["LSJZList"]) @@ -247,7 +247,7 @@ def normalize_class_name(self): def default_base_dir(self) -> [Path, str]: return CUR_DIR - def download_data( + def download_data( # pylint: disable=W0246 self, max_collector_count=2, delay=0, @@ -283,7 +283,7 @@ def download_data( super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): # pylint: disable=W0246 """normalize data Parameters diff --git a/scripts/data_collector/future_calendar_collector.py b/scripts/data_collector/future_calendar_collector.py index e5b1189268..4dfd24e4b7 100644 --- a/scripts/data_collector/future_calendar_collector.py +++ b/scripts/data_collector/future_calendar_collector.py @@ -53,7 +53,7 @@ def _format_datetime(self, datetime_d: [str, pd.Timestamp]): return datetime_d.strftime(self.calendar_format) def write_calendar(self, calendar: Iterable): - calendars_list = list(map(lambda x: self._format_datetime(x), sorted(set(self.calendar_list + calendar)))) + calendars_list = [self._format_datetime(x) for x in sorted(set(self.calendar_list + calendar))] np.savetxt(self.future_path, calendars_list, fmt="%s", encoding="utf-8") @abc.abstractmethod diff --git a/scripts/data_collector/us_index/collector.py b/scripts/data_collector/us_index/collector.py index cb0c3fc955..50278d11ee 100644 --- a/scripts/data_collector/us_index/collector.py +++ b/scripts/data_collector/us_index/collector.py @@ -4,7 +4,6 @@ import abc from functools import partial import sys -import importlib from pathlib import Path from concurrent.futures import ThreadPoolExecutor from typing import List @@ -113,7 +112,7 @@ def calendar_list(self) -> List[pd.Timestamp]: return _calendar_list def _request_new_companies(self) -> requests.Response: - resp = requests.get(self._target_url) + resp = requests.get(self._target_url, timeout=None) if resp.status_code != 200: raise ValueError(f"request error: {self._target_url}") @@ -164,7 +163,7 @@ def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = df = pd.read_pickle(cache_path) else: url = 
self.HISTORY_COMPANIES_URL.format(trade_date=trade_date) - resp = requests.post(url) + resp = requests.post(url, timeout=None) if resp.status_code != 200: raise ValueError(f"request error: {url}") df = pd.DataFrame(resp.json()["aaData"]) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 74ecb541ea..0dbe432422 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -68,7 +68,7 @@ def get_calendar_list(bench_code="CSI300") -> List[pd.Timestamp]: logger.info(f"get calendar list: {bench_code}......") def _get_calendar(url): - _value_list = requests.get(url).json()["data"]["klines"] + _value_list = requests.get(url, timeout=None).json()["data"]["klines"] return sorted(map(lambda x: pd.Timestamp(x.split(",")[0]), _value_list)) calendar = _CALENDAR_MAP.get(bench_code, None) @@ -85,12 +85,14 @@ def _get_calendar(url): def _get_calendar(month): _cal = [] try: - resp = requests.get(SZSE_CALENDAR_URL.format(month=month, random=random.random)).json() + resp = requests.get( + SZSE_CALENDAR_URL.format(month=month, random=random.random), timeout=None + ).json() for _r in resp["data"]: if int(_r["jybz"]): _cal.append(pd.Timestamp(_r["jyrq"])) except Exception as e: - raise ValueError(f"{month}-->{e}") + raise ValueError(f"{month}-->{e}") from e return _cal month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M") @@ -109,7 +111,7 @@ def _get_calendar(month): def return_date_list(date_field_name: str, file_path: Path): date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list() - return sorted(map(lambda x: pd.Timestamp(x), date_list)) + return sorted([pd.Timestamp(x) for x in date_list]) def get_calendar_list_by_ratio( @@ -155,7 +157,7 @@ def get_calendar_list_by_ratio( if date_list: all_oldest_list.append(date_list[0]) for date in date_list: - if date not in _dict_count_trade.keys(): + if date not in _dict_count_trade: _dict_count_trade[date] = 0 _dict_count_trade[date] += 1 @@ -163,7 +165,7 @@ def get_calendar_list_by_ratio( p_bar.update() logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count} + _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade} # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: for oldest_date in all_oldest_list: for date in _dict_count_founding.keys(): @@ -171,9 +173,7 @@ def get_calendar_list_by_ratio( _dict_count_founding[date] -= 1 calendar = [ - date - for date in _dict_count_trade - if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count) + date for date, count in _dict_count_trade.items() if count >= max(int(count * threshold), minimum_count) ] return calendar @@ -186,16 +186,18 @@ def get_hs_stock_symbols() -> list: ------- stock symbols """ - global _HS_SYMBOLS + global _HS_SYMBOLS # pylint: disable=W0603 def _get_symbol(): _res = set() for _k, _v in (("ha", "ss"), ("sa", "sz"), ("gem", "sz")): - resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k)) + resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k), timeout=None) _res |= set( map( - lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), - etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), + [ + "{}.{}".format(re.findall(r"\d+", x)[0], _v) + for x in etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()") + ] ) ) time.sleep(3) @@ -230,12 +232,12 @@ def 
get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list: ------- stock symbols """ - global _US_SYMBOLS + global _US_SYMBOLS # pylint: disable=W0603 @deco_retry def _get_eastmoney(): url = "http://4.push2.eastmoney.com/api/qt/clist/get?pn=1&pz=10000&fs=m:105,m:106,m:107&fields=f12" - resp = requests.get(url) + resp = requests.get(url, timeout=None) if resp.status_code != 200: raise ValueError("request error") @@ -277,7 +279,7 @@ def _get_nyse(): "maxResultsPerPage": 10000, "filterToken": "", } - resp = requests.post(url, json=_parms) + resp = requests.post(url, json=_parms, timeout=None) if resp.status_code != 200: raise ValueError("request error") @@ -317,7 +319,7 @@ def get_in_stock_symbols(qlib_data_path: [str, Path] = None) -> list: ------- stock symbols """ - global _IN_SYMBOLS + global _IN_SYMBOLS # pylint: disable=W0603 @deco_retry def _get_nifty(): @@ -358,7 +360,7 @@ def get_br_stock_symbols(qlib_data_path: [str, Path] = None) -> list: ------- B3 stock symbols """ - global _BR_SYMBOLS + global _BR_SYMBOLS # pylint: disable=W0603 @deco_retry def _get_ibovespa(): @@ -367,7 +369,7 @@ def _get_ibovespa(): # Request agent = {"User-Agent": "Mozilla/5.0"} - page = requests.get(url, headers=agent) + page = requests.get(url, headers=agent, timeout=None) # BeautifulSoup soup = BeautifulSoup(page.content, "html.parser") @@ -375,7 +377,7 @@ def _get_ibovespa(): children = tbody.findChildren("a", recursive=True) for child in children: - _symbols.append(str(child).split('"')[-1].split(">")[1].split("<")[0]) + _symbols.append(str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0]) return _symbols @@ -409,12 +411,12 @@ def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list: ------- fund symbols in China """ - global _EN_FUND_SYMBOLS + global _EN_FUND_SYMBOLS # pylint: disable=W0603 @deco_retry def _get_eastmoney(): url = "http://fund.eastmoney.com/js/fundcode_search.js" - resp = requests.get(url) + resp = requests.get(url, timeout=None) if resp.status_code != 200: raise ValueError("request error") try: diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 143fa12ac2..20a791b35d 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. 
import abc -from re import I import sys import copy import time @@ -21,6 +20,8 @@ from yahooquery import Ticker from dateutil.tz import tzlocal +import qlib +from qlib.data import D from qlib.tests.data import GetData from qlib.utils import code_to_fname, fname_to_code, exists_qlib_data from qlib.constant import REG_CN as REGION_CN @@ -229,9 +230,9 @@ def download_index_data(self): df = pd.DataFrame( map( lambda x: x.split(","), - requests.get(INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end)).json()[ - "data" - ]["klines"], + requests.get( + INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end), timeout=None + ).json()["data"]["klines"], ) ) except Exception as e: @@ -316,7 +317,7 @@ class YahooCollectorIN1min(YahooCollectorIN): class YahooCollectorBR(YahooCollector, ABC): - def retry(cls): + def retry(cls): # pylint: disable=E0213 """ The reason to use retry=2 is due to the fact that Yahoo Finance unfortunately does not keep track of some @@ -356,12 +357,10 @@ def _timezone(self): class YahooCollectorBR1d(YahooCollectorBR): retry = 2 - pass class YahooCollectorBR1min(YahooCollectorBR): retry = 2 - pass class YahooNormalize(BaseNormalize): @@ -527,9 +526,6 @@ def __init__( self.old_qlib_data = self._get_old_data(old_qlib_data_dir) def _get_old_data(self, qlib_data_dir: [str, Path]): - import qlib - from qlib.data import D - qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve()) qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None) df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"]) @@ -774,16 +770,10 @@ def __init__( self._all_1d_data = self._get_all_1d_data() def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: - import qlib - from qlib.data import D - qlib.init(provider_uri=self.qlib_data_1d_dir) return list(D.calendar(freq="day")) def _get_all_1d_data(self): - import qlib - from qlib.data import D - qlib.init(provider_uri=self.qlib_data_1d_dir) df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") df.reset_index(inplace=True) diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index 269366f75e..317546e982 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -176,7 +176,7 @@ def _read_instruments(self, instrument_path: Path) -> pd.DataFrame: def save_calendars(self, calendars_data: list): self._calendars_dir.mkdir(parents=True, exist_ok=True) calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()) - result_calendars_list = list(map(lambda x: self._format_datetime(x), calendars_data)) + result_calendars_list = [self._format_datetime(x) for x in calendars_data] np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8") def save_instruments(self, instruments_data: Union[list, pd.DataFrame]): diff --git a/scripts/dump_pit.py b/scripts/dump_pit.py index c328eb67a8..34d304ed78 100644 --- a/scripts/dump_pit.py +++ b/scripts/dump_pit.py @@ -6,21 +6,18 @@ - seperated insert, delete, update, query operations are required. 
""" -import abc import shutil import struct -import traceback from pathlib import Path -from typing import Iterable, List, Union +from typing import Iterable from functools import partial -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from concurrent.futures import ProcessPoolExecutor import fire -import numpy as np import pandas as pd from tqdm import tqdm from loguru import logger -from qlib.utils import fname_to_code, code_to_fname, get_period_offset +from qlib.utils import fname_to_code, get_period_offset from qlib.config import C From 4574d055d14f91f3ccb0e23cb964de23e2c37575 Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 13 Sep 2023 14:50:30 +0800 Subject: [PATCH 04/25] solve_duplication_methods --- .../data_collector/baostock_5min/collector.py | 204 ++--------------- scripts/data_collector/utils.py | 156 ++++++++++++- scripts/data_collector/yahoo/collector.py | 216 +++--------------- 3 files changed, 208 insertions(+), 368 deletions(-) diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py index d8f7c3ad4e..2d966cffb0 100644 --- a/scripts/data_collector/baostock_5min/collector.py +++ b/scripts/data_collector/baostock_5min/collector.py @@ -20,7 +20,7 @@ sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun -from data_collector.utils import generate_minutes_calendar_from_daily +from data_collector.utils import generate_minutes_calendar_from_daily, calc_adjusted_price class BaostockCollectorHS3005min(BaseCollector): @@ -83,7 +83,6 @@ def get_trade_calendar(self): calendar_list.append(rs.get_row_data()) calendar_df = pd.DataFrame(calendar_list, columns=rs.fields) trade_calendar_df = calendar_df[~calendar_df["is_trading_day"].isin(["0"])] - # bs.logout() return trade_calendar_df["calendar_date"].values @staticmethod @@ -149,12 +148,8 @@ def normalize_symbol(self, symbol: str): class BaostockNormalizeHS3005min(BaseNormalize): COLUMNS = ["open", "close", "high", "low", "volume"] - DAILY_FORMAT = "%Y-%m-%d" AM_RANGE = ("09:30:00", "11:29:00") PM_RANGE = ("13:00:00", "14:59:00") - # Whether the trading day of 5min data is consistent with 1d - CONSISTENT_1d = True - CALC_PAUSED_NUM = True def __init__( self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs @@ -172,9 +167,8 @@ def __init__( """ bs.login() qlib.init(provider_uri=qlib_data_1d_dir) - # self.qlib_data_1d_dir = qlib_data_1d_dir + self.qlib_data_1d_dir = qlib_data_1d_dir super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name) - self._all_1d_data = self._get_all_1d_data() @staticmethod def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series: @@ -187,16 +181,8 @@ def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series: return change_series def _get_calendar_list(self) -> Iterable[pd.Timestamp]: - # return list(D.calendar(freq="day")) return self.generate_5min_from_daily(self.calendar_list_1d) - def _get_all_1d_data(self): - df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") - df.reset_index(inplace=True) - df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) - df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) - return df - @property def calendar_list_1d(self): calendar_list_1d = getattr(self, "_calendar_list_1d", None) @@ -228,7 +214,6 @@ def 
normalize_baostock( .index ) df.sort_index(inplace=True) - # df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan change_series = BaostockNormalizeHS3005min.calc_change(df, last_close) @@ -265,131 +250,14 @@ def generate_5min_from_daily(self, calendars: Iterable) -> pd.Index: calendars, freq="5min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE ) - def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: - """get 1d data - - Returns - ------ - data_1d: pd.DataFrame - data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] - - """ - return self._all_1d_data[ - (self._all_1d_data[self._symbol_field_name] == symbol.upper()) - & (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start)) - & (self._all_1d_data[self._date_field_name] < pd.Timestamp(end)) - ] - def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: - # TODO: using daily data factor - if df.empty: - return df - df = df.copy() - df = df.sort_values(self._date_field_name) - symbol = df.iloc[0][self._symbol_field_name] - # get 1d data from baostock - _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT) - _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT) - data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end) - data_1d = data_1d.copy() - if data_1d is None or data_1d.empty: - df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"] - # TODO: np.nan or 1 or 0 - df["paused"] = np.nan - else: - # NOTE: volume is np.nan or volume <= 0, paused = 1 - # FIXME: find a more accurate data source - data_1d["paused"] = 0 - data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1 - data_1d = data_1d.set_index(self._date_field_name) - - # add factor from 1d data - # NOTE: yahoo 1d data info: - # - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits. - # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. 
- # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` - def _calc_factor(df_1d: pd.DataFrame): - try: - _date = pd.Timestamp(pd.Timestamp(df_1d[self._date_field_name].iloc[0]).date()) - df_1d["factor"] = ( - data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"] - ) - df_1d["paused"] = data_1d.loc[_date]["paused"] - except Exception: - df_1d["factor"] = np.nan - df_1d["paused"] = np.nan - return df_1d - - df = df.groupby([df[self._date_field_name].dt.date]).apply(_calc_factor) - - if self.CONSISTENT_1d: - # the date sequence is consistent with 1d - df.set_index(self._date_field_name, inplace=True) - df = df.reindex( - self.generate_5min_from_daily( - pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates()) - ) - ) - df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][ - self._symbol_field_name - ] - df.index.names = [self._date_field_name] - df.reset_index(inplace=True) - for _col in self.COLUMNS: - if _col not in df.columns: - continue - if _col == "volume": - df[_col] = df[_col] / df["factor"] - else: - df[_col] = df[_col] * df["factor"] - - if self.CALC_PAUSED_NUM: - df = self.calc_paused_num(df) - return df - - def calc_paused_num(self, df: pd.DataFrame): - _symbol = df.iloc[0][self._symbol_field_name] - df = df.copy() - df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) - # remove data that starts and ends with `np.nan` all day - all_data = [] - # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan - all_nan_nums = 0 - # Record the number of consecutive occurrences of trading days that are not nan throughout the day - not_nan_nums = 0 - for _date, _df in df.groupby("_tmp_date"): - _df["paused"] = 0 - if not _df.loc[_df["volume"] < 0].empty: - logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}") - _df.loc[_df["volume"] < 0, "volume"] = np.nan - - check_fields = set(_df.columns) - { - "_tmp_date", - "paused", - "factor", - self._date_field_name, - self._symbol_field_name, - } - if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): - all_nan_nums += 1 - not_nan_nums = 0 - _df["paused"] = 1 - if all_data: - _df["paused_num"] = not_nan_nums - all_data.append(_df) - else: - all_nan_nums = 0 - not_nan_nums += 1 - _df["paused_num"] = not_nan_nums - all_data.append(_df) - all_data = all_data[: len(all_data) - all_nan_nums] - if all_data: - df = pd.concat(all_data, sort=False) - else: - logger.warning(f"data is empty: {_symbol}") - df = pd.DataFrame() - return df - del df["_tmp_date"] + df = calc_adjusted_price( + df=df, + qlib_data_1d_dir=self.qlib_data_1d_dir, + _date_field_name=self._date_field_name, + _symbol_field_name=self._symbol_field_name, + frequence="5min", + ) return df def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: @@ -406,19 +274,7 @@ def normalize(self, df: pd.DataFrame) -> pd.DataFrame: class Run(BaseRun): def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="5min", region="HS300"): """ - - Parameters - ---------- - source_dir: str - The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source" - normalize_dir: str - Directory for normalize data, default "Path(__file__).parent/normalize" - max_workers: int - Concurrent number, default is 1; when collecting data, it is recommended that max_workers be set to 1 - interval: 
str - freq, value from [5min, default 5min - region: str - region, value from ["HS300"], default "HS300" + Changed the default value of: scripts.data_collector.base.BaseRun. """ super().__init__(source_dir, normalize_dir, max_workers, interval) self.region = region @@ -444,22 +300,7 @@ def download_data( check_data_length=None, limit_nums=None, ): - """download data from Internet - - Parameters - ---------- - max_collector_count: int - default 2 - delay: float - time.sleep(delay), default 0.5 - start: str - start datetime, default "2000-01-01"; closed interval(including start) - end: str - end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``; open interval(excluding end) - check_data_length: int - check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. - limit_nums: int - using for debug, by default None + """download data from Baostock Notes ----- @@ -482,21 +323,14 @@ def normalize_data( ): """normalize data - Parameters - ---------- - date_field_name: str - date field name, default date - symbol_field_name: str - symbol field name, default symbol - end_date: str - if not None, normalize the last date saved (including end_date); if None, it will ignore this parameter; by default None - qlib_data_1d_dir: str - if interval==5min, qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data; - - qlib_data_1d can be obtained like this: - $ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3 - or: - download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo + Attention + --------- + qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data; + + qlib_data_1d can be obtained like this: + $ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3 + or: + download 1d data, reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#1d-from-yahoo Examples --------- diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 0dbe432422..bff812fea0 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -21,6 +21,7 @@ from functools import partial from concurrent.futures import ProcessPoolExecutor from bs4 import BeautifulSoup +from qlib.data import D HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" @@ -194,10 +195,8 @@ def _get_symbol(): resp = requests.get(HS_SYMBOLS_URL.format(s_type=_k), timeout=None) _res |= set( map( - [ - "{}.{}".format(re.findall(r"\d+", x)[0], _v) - for x in etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()") - ] + lambda x: "{}.{}".format(re.findall(r"\d+", x)[0], _v), # pylint: disable=W0640 + etree.HTML(resp.text).xpath("//div[@class='result']/ul//li/a/text()"), # pylint: disable=I1101 ) ) time.sleep(3) @@ -607,5 +606,154 @@ def get_instruments( getattr(obj, method)() +def _get_all_1d_data(qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str): + # qlib.init(provider_uri=qlib_data_1d_dir) + df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") + df.reset_index(inplace=True) + df.rename(columns={"datetime": _date_field_name, "instrument": 
_symbol_field_name}, inplace=True) + df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) + return df + + +def get_1d_data( + qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, symbol: str, start: str, end: str +) -> pd.DataFrame: + """get 1d data + + Returns + ------ + data_1d: pd.DataFrame + data_1d.columns = [_date_field_name, _symbol_field_name, "paused", "volume", "factor", "close"] + + """ + _all_1d_data = _get_all_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name) + return _all_1d_data[ + (_all_1d_data[_symbol_field_name] == symbol.upper()) + & (_all_1d_data[_date_field_name] >= pd.Timestamp(start)) + & (_all_1d_data[_date_field_name] < pd.Timestamp(end)) + ] + + +def calc_adjusted_price( + df: pd.DataFrame, + qlib_data_1d_dir: str, + _date_field_name: str, + _symbol_field_name: str, + frequence: str, + consistent_1d: bool = True, + calc_paused: bool = True, +) -> pd.DataFrame: + # TODO: using daily data factor + if df.empty: + return df + df = df.copy() + df.drop_duplicates(subset=_date_field_name, inplace=True) + df.sort_values(_date_field_name, inplace=True) + symbol = df.iloc[0][_symbol_field_name] + df[_date_field_name] = pd.to_datetime(df[_date_field_name]) + # get 1d data from qlib + _start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d") + _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") + data_1d: pd.DataFrame = get_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name, symbol, _start, _end) + data_1d = data_1d.copy() + if data_1d is None or data_1d.empty: + df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"] + # TODO: np.nan or 1 or 0 + df["paused"] = np.nan + else: + # NOTE: volume is np.nan or volume <= 0, paused = 1 + # FIXME: find a more accurate data source + data_1d["paused"] = 0 + data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1 + data_1d = data_1d.set_index(_date_field_name) + + # add factor from 1d data + # NOTE: 1d data info: + # - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits. + # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. 
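+        #   (the division in _calc_factor below therefore maps each day's raw
+        #   intraday closes onto this adjusted daily series: one factor per
+        #   trading day, multiplied into prices and divided out of volume)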
+ # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` + def _calc_factor(df_1d: pd.DataFrame): + try: + _date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date()) + df_1d["factor"] = data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"] + df_1d["paused"] = data_1d.loc[_date]["paused"] + except Exception: + df_1d["factor"] = np.nan + df_1d["paused"] = np.nan + return df_1d + + df = df.groupby([df[_date_field_name].dt.date]).apply(_calc_factor) + if consistent_1d: + # the date sequence is consistent with 1d + df.set_index(_date_field_name, inplace=True) + df = df.reindex( + generate_minutes_calendar_from_daily( + calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()), + freq=frequence, + am_range=("09:30:00", "11:29:00"), + pm_range=("13:00:00", "14:59:00"), + ) + ) + df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name] + df.index.names = [_date_field_name] + df.reset_index(inplace=True) + for _col in ["open", "close", "high", "low", "volume"]: + if _col not in df.columns: + continue + if _col == "volume": + df[_col] = df[_col] / df["factor"] + else: + df[_col] = df[_col] * df["factor"] + if calc_paused: + df = calc_paused_num(df, _date_field_name, _symbol_field_name) + return df + + +def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name): + _symbol = df.iloc[0][_symbol_field_name] + df = df.copy() + df["_tmp_date"] = df[_date_field_name].apply(lambda x: pd.Timestamp(x).date()) + # remove data that starts and ends with `np.nan` all day + all_data = [] + # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan + all_nan_nums = 0 + # Record the number of consecutive occurrences of trading days that are not nan throughout the day + not_nan_nums = 0 + for _date, _df in df.groupby("_tmp_date"): + _df["paused"] = 0 + if not _df.loc[_df["volume"] < 0].empty: + logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}") + _df.loc[_df["volume"] < 0, "volume"] = np.nan + + check_fields = set(_df.columns) - { + "_tmp_date", + "paused", + "factor", + _date_field_name, + _symbol_field_name, + } + if _df.loc[:, list(check_fields)].isna().values.all() or (_df["volume"] == 0).all(): + all_nan_nums += 1 + not_nan_nums = 0 + _df["paused"] = 1 + if all_data: + _df["paused_num"] = not_nan_nums + all_data.append(_df) + else: + all_nan_nums = 0 + not_nan_nums += 1 + _df["paused_num"] = not_nan_nums + all_data.append(_df) + all_data = all_data[: len(all_data) - all_nan_nums] + if all_data: + df = pd.concat(all_data, sort=False) + else: + logger.warning(f"data is empty: {_symbol}") + df = pd.DataFrame() + return df + del df["_tmp_date"] + return df + + if __name__ == "__main__": assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 20a791b35d..41d3265841 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -39,6 +39,7 @@ get_in_stock_symbols, get_br_stock_symbols, generate_minutes_calendar_from_daily, + calc_adjusted_price, ) INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}" @@ -590,6 +591,8 @@ def normalize(self, df: pd.DataFrame) -> 
pd.DataFrame:
 
 
 class YahooNormalize1min(YahooNormalize, ABC):
+    """Normalised to 1min using local 1d data"""
+
     AM_RANGE = None  # type: tuple  # eg: ("09:30:00", "11:29:00")
     PM_RANGE = None  # type: tuple  # eg: ("13:00:00", "14:59:00")
 
@@ -597,6 +600,27 @@ class YahooNormalize1min(YahooNormalize, ABC):
     CONSISTENT_1d = True
     CALC_PAUSED_NUM = True
 
+    def __init__(
+        self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs
+    ):
+        """
+
+        Parameters
+        ----------
+        qlib_data_1d_dir: str, Path
+            the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
+        date_field_name: str
+            date field name, default is date
+        symbol_field_name: str
+            symbol field name, default is symbol
+        """
+        super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name)
+        self.qlib_data_1d_dir = qlib_data_1d_dir
+        qlib.init(provider_uri=self.qlib_data_1d_dir)
+
+    def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]:
+        return list(D.calendar(freq="day"))
+
     @property
     def calendar_list_1d(self):
         calendar_list_1d = getattr(self, "_calendar_list_1d", None)
@@ -610,133 +634,16 @@ def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
             calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
         )
 
-    def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame:
-        """get 1d data
-
-        Returns
-        ------
-        data_1d: pd.DataFrame
-            data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"]
-
-        """
-        data_1d = YahooCollector.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=start, end=end)
-        if not (data_1d is None or data_1d.empty):
-            _class_name = self.__class__.__name__.replace("min", "d")
-            _class: type(YahooNormalize) = getattr(importlib.import_module("collector"), _class_name)
-            data_1d_obj = _class(self._date_field_name, self._symbol_field_name)
-            data_1d = data_1d_obj.normalize(data_1d)
-        return data_1d
-
     def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
-        # TODO: using daily data factor
-        if df.empty:
-            return df
-        df = df.copy()
-        df = df.sort_values(self._date_field_name)
-        symbol = df.iloc[0][self._symbol_field_name]
-        # get 1d data from yahoo
-        _start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
-        _end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
-        data_1d: pd.DataFrame = self.get_1d_data(symbol, _start, _end)
-        data_1d = data_1d.copy()
-        if data_1d is None or data_1d.empty:
-            df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"]
-            # TODO: np.nan or 1 or 0
-            df["paused"] = np.nan
-        else:
-            # NOTE: volume is np.nan or volume <= 0, paused = 1
-            # FIXME: find a more accurate data source
-            data_1d["paused"] = 0
-            data_1d.loc[(data_1d["volume"].isna()) | (data_1d["volume"] <= 0), "paused"] = 1
-            data_1d = data_1d.set_index(self._date_field_name)
-
-            # add factor from 1d data
-            # NOTE: yahoo 1d data info:
-            # - Close price adjusted for splits. Adjusted close price adjusted for both dividends and splits.
-            # - data_1d.adjclose: Adjusted close price adjusted for both dividends and splits. 
- # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` - def _calc_factor(df_1d: pd.DataFrame): - try: - _date = pd.Timestamp(pd.Timestamp(df_1d[self._date_field_name].iloc[0]).date()) - df_1d["factor"] = ( - data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"] - ) - df_1d["paused"] = data_1d.loc[_date]["paused"] - except Exception: - df_1d["factor"] = np.nan - df_1d["paused"] = np.nan - return df_1d - - df = df.groupby([df[self._date_field_name].dt.date]).apply(_calc_factor) - - if self.CONSISTENT_1d: - # the date sequence is consistent with 1d - df.set_index(self._date_field_name, inplace=True) - df = df.reindex( - self.generate_1min_from_daily( - pd.to_datetime(data_1d.reset_index()[self._date_field_name].drop_duplicates()) - ) - ) - df[self._symbol_field_name] = df.loc[df[self._symbol_field_name].first_valid_index()][ - self._symbol_field_name - ] - df.index.names = [self._date_field_name] - df.reset_index(inplace=True) - for _col in self.COLUMNS: - if _col not in df.columns: - continue - if _col == "volume": - df[_col] = df[_col] / df["factor"] - else: - df[_col] = df[_col] * df["factor"] - - if self.CALC_PAUSED_NUM: - df = self.calc_paused_num(df) - return df - - def calc_paused_num(self, df: pd.DataFrame): - _symbol = df.iloc[0][self._symbol_field_name] - df = df.copy() - df["_tmp_date"] = df[self._date_field_name].apply(lambda x: pd.Timestamp(x).date()) - # remove data that starts and ends with `np.nan` all day - all_data = [] - # Record the number of consecutive trading days where the whole day is nan, to remove the last trading day where the whole day is nan - all_nan_nums = 0 - # Record the number of consecutive occurrences of trading days that are not nan throughout the day - not_nan_nums = 0 - for _date, _df in df.groupby("_tmp_date"): - _df["paused"] = 0 - if not _df.loc[_df["volume"] < 0].empty: - logger.warning(f"volume < 0, will fill np.nan: {_date} {_symbol}") - _df.loc[_df["volume"] < 0, "volume"] = np.nan - - check_fields = set(_df.columns) - { - "_tmp_date", - "paused", - "factor", - self._date_field_name, - self._symbol_field_name, - } - if _df.loc[:, check_fields].isna().values.all() or (_df["volume"] == 0).all(): - all_nan_nums += 1 - not_nan_nums = 0 - _df["paused"] = 1 - if all_data: - _df["paused_num"] = not_nan_nums - all_data.append(_df) - else: - all_nan_nums = 0 - not_nan_nums += 1 - _df["paused_num"] = not_nan_nums - all_data.append(_df) - all_data = all_data[: len(all_data) - all_nan_nums] - if all_data: - df = pd.concat(all_data, sort=False) - else: - logger.warning(f"data is empty: {_symbol}") - df = pd.DataFrame() - return df - del df["_tmp_date"] + df = calc_adjusted_price( + df=df, + qlib_data_1d_dir=self.qlib_data_1d_dir, + _date_field_name=self._date_field_name, + _symbol_field_name=self._symbol_field_name, + frequence="1min", + consistent_1d=self.CONSISTENT_1d, + calc_paused=self.CALC_PAUSED_NUM, + ) return df @abc.abstractmethod @@ -748,55 +655,6 @@ def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: raise NotImplementedError("rewrite _get_1d_calendar_list") -class YahooNormalize1minOffline(YahooNormalize1min): - """Normalised to 1min using local 1d data""" - - def __init__( - self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs - ): - """ - - Parameters - ---------- - qlib_data_1d_dir: str, Path - the qlib data to be updated for yahoo, usually from: Normalised to 1min using local 1d data - date_field_name: str - 
date field name, default is date - symbol_field_name: str - symbol field name, default is symbol - """ - self.qlib_data_1d_dir = qlib_data_1d_dir - super(YahooNormalize1minOffline, self).__init__(date_field_name, symbol_field_name) - self._all_1d_data = self._get_all_1d_data() - - def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: - qlib.init(provider_uri=self.qlib_data_1d_dir) - return list(D.calendar(freq="day")) - - def _get_all_1d_data(self): - qlib.init(provider_uri=self.qlib_data_1d_dir) - df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") - df.reset_index(inplace=True) - df.rename(columns={"datetime": self._date_field_name, "instrument": self._symbol_field_name}, inplace=True) - df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) - return df - - def get_1d_data(self, symbol: str, start: str, end: str) -> pd.DataFrame: - """get 1d data - - Returns - ------ - data_1d: pd.DataFrame - data_1d.columns = [self._date_field_name, self._symbol_field_name, "paused", "volume", "factor", "close"] - - """ - return self._all_1d_data[ - (self._all_1d_data[self._symbol_field_name] == symbol.upper()) - & (self._all_1d_data[self._date_field_name] >= pd.Timestamp(start)) - & (self._all_1d_data[self._date_field_name] < pd.Timestamp(end)) - ] - - class YahooNormalizeUS: def _get_calendar_list(self) -> Iterable[pd.Timestamp]: # TODO: from MSN @@ -811,7 +669,7 @@ class YahooNormalizeUS1dExtend(YahooNormalizeUS, YahooNormalize1dExtend): pass -class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1minOffline): +class YahooNormalizeUS1min(YahooNormalizeUS, YahooNormalize1min): CALC_PAUSED_NUM = False def _get_calendar_list(self) -> Iterable[pd.Timestamp]: @@ -834,7 +692,7 @@ class YahooNormalizeIN1d(YahooNormalizeIN, YahooNormalize1d): pass -class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1minOffline): +class YahooNormalizeIN1min(YahooNormalizeIN, YahooNormalize1min): CALC_PAUSED_NUM = False def _get_calendar_list(self) -> Iterable[pd.Timestamp]: @@ -862,7 +720,7 @@ class YahooNormalizeCN1dExtend(YahooNormalizeCN, YahooNormalize1dExtend): pass -class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1minOffline): +class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min): AM_RANGE = ("09:30:00", "11:29:00") PM_RANGE = ("13:00:00", "14:59:00") @@ -889,7 +747,7 @@ class YahooNormalizeBR1d(YahooNormalizeBR, YahooNormalize1d): pass -class YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1minOffline): +class YahooNormalizeBR1min(YahooNormalizeBR, YahooNormalize1min): CALC_PAUSED_NUM = False def _get_calendar_list(self) -> Iterable[pd.Timestamp]: From ab3e6d3caed2fe28ec459a54f66ce028eaa603e1 Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 27 Sep 2023 15:12:20 +0800 Subject: [PATCH 05/25] modified the logic of update_data_to_bin --- scripts/data_collector/yahoo/collector.py | 76 ++++++----------------- 1 file changed, 19 insertions(+), 57 deletions(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 41d3265841..996266a59f 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -524,70 +524,31 @@ def __init__( super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name) self._first_close_field = "first_close" self._ori_close_field = "ori_close" + self.column_list = ["open", "high", "low", "close", "volume", "factor", "change"] self.old_qlib_data = 
self._get_old_data(old_qlib_data_dir) def _get_old_data(self, qlib_data_dir: [str, Path]): qlib_data_dir = str(Path(qlib_data_dir).expanduser().resolve()) qlib.init(provider_uri=qlib_data_dir, expression_cache=None, dataset_cache=None) - df = D.features(D.instruments("all"), ["$close/$factor", "$adjclose/$close"]) - df.columns = [self._ori_close_field, self._first_close_field] + df = D.features(D.instruments("all"), ["$" + col for col in self.column_list]) + df.columns = self.column_list return df - def _get_close(self, df: pd.DataFrame, field_name: str): - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() - _df = self.old_qlib_data.loc(axis=0)[_symbol] - _close = _df.loc[_df.last_valid_index()][field_name] - return _close - - def _get_first_close(self, df: pd.DataFrame) -> float: - try: - _close = self._get_close(df, field_name=self._first_close_field) - except KeyError: - _close = super(YahooNormalize1dExtend, self)._get_first_close(df) - return _close - - def _get_last_close(self, df: pd.DataFrame) -> float: - try: - _close = self._get_close(df, field_name=self._ori_close_field) - except KeyError: - _close = None - return _close - - def _get_last_date(self, df: pd.DataFrame) -> pd.Timestamp: - _symbol = df.loc[df[self._symbol_field_name].first_valid_index()][self._symbol_field_name].upper() - try: - _df = self.old_qlib_data.loc(axis=0)[_symbol] - _date = _df.index.max() - except KeyError: - _date = None - return _date - def normalize(self, df: pd.DataFrame) -> pd.DataFrame: - _last_close = self._get_last_close(df) - # reindex - _last_date = self._get_last_date(df) - if _last_date is not None: - df = df.set_index(self._date_field_name) - df.index = pd.to_datetime(df.index) - df = df[~df.index.duplicated(keep="first")] - _max_date = df.index.max() - df = df.reindex(self._calendar_list).loc[:_max_date].reset_index() - df = df[df[self._date_field_name] > _last_date] - if df.empty: - return pd.DataFrame() - _si = df["close"].first_valid_index() - if _si > df.index[0]: - logger.warning( - f"{df.loc[_si][self._symbol_field_name]} missing data: {df.loc[:_si - 1][self._date_field_name].to_list()}" - ) - # normalize - df = self.normalize_yahoo( - df, self._calendar_list, self._date_field_name, self._symbol_field_name, last_close=_last_close - ) - # adjusted price - df = self.adjusted_price(df) - df = self._manual_adj_data(df) - return df + df = super(YahooNormalize1dExtend, self).normalize(df) + df.set_index(self._date_field_name, inplace=True) + symbol_name = df[self._symbol_field_name].iloc[0] + old_df = self.old_qlib_data.loc[str(symbol_name).upper()] + latest_date = old_df.index[-1] + new_latest_data = df.loc[latest_date] + old_latest_data = old_df.loc[latest_date] + for col in self.column_list[:-1]: + if col == "volume": + df[col] = df[col] / (new_latest_data[col] / old_latest_data[col]) + else: + df[col] = df[col] * (old_latest_data[col] / new_latest_data[col]) + df = df.loc[self._calendar_list[self._calendar_list.index(latest_date) + 1]:] + return df.reset_index() class YahooNormalize1min(YahooNormalize, ABC): @@ -1019,7 +980,8 @@ def update_data_to_bin( # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 - self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) + trading_date = (pd.Timestamp(trading_date) - pd.Timedelta(days=2)).strftime("%Y-%m-%d") + # self.download_data(delay=delay, start=trading_date, end=end_date, 
check_data_length=check_data_length) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( max(multiprocessing.cpu_count() - 2, 1) From 065479e2ec886d474139169a6886fb9d2989d2d5 Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 27 Sep 2023 22:03:28 +0800 Subject: [PATCH 06/25] modified the logic of update_data_to_bin --- scripts/data_collector/yahoo/collector.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 996266a59f..4eb92a7de0 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -540,15 +540,15 @@ def normalize(self, df: pd.DataFrame) -> pd.DataFrame: symbol_name = df[self._symbol_field_name].iloc[0] old_df = self.old_qlib_data.loc[str(symbol_name).upper()] latest_date = old_df.index[-1] - new_latest_data = df.loc[latest_date] + df = df.loc[latest_date:] + new_latest_data = df.iloc[0] old_latest_data = old_df.loc[latest_date] for col in self.column_list[:-1]: if col == "volume": df[col] = df[col] / (new_latest_data[col] / old_latest_data[col]) else: df[col] = df[col] * (old_latest_data[col] / new_latest_data[col]) - df = df.loc[self._calendar_list[self._calendar_list.index(latest_date) + 1]:] - return df.reset_index() + return df.drop(df.index[0]).reset_index() class YahooNormalize1min(YahooNormalize, ABC): @@ -981,7 +981,7 @@ def update_data_to_bin( # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 trading_date = (pd.Timestamp(trading_date) - pd.Timedelta(days=2)).strftime("%Y-%m-%d") - # self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) + self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( max(multiprocessing.cpu_count() - 2, 1) From bbf47df1a8857fa6a43e6f69748e7550efda68b4 Mon Sep 17 00:00:00 2001 From: Linlang Date: Tue, 10 Oct 2023 16:52:59 +0800 Subject: [PATCH 07/25] optimize code --- .../data_collector/baostock_5min/collector.py | 2 ++ scripts/data_collector/utils.py | 22 +++++++++++++------ scripts/data_collector/yahoo/collector.py | 3 ++- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py index 2d966cffb0..3db027e6a0 100644 --- a/scripts/data_collector/baostock_5min/collector.py +++ b/scripts/data_collector/baostock_5min/collector.py @@ -167,6 +167,7 @@ def __init__( """ bs.login() qlib.init(provider_uri=qlib_data_1d_dir) + self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") self.qlib_data_1d_dir = qlib_data_1d_dir super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name) @@ -257,6 +258,7 @@ def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: _date_field_name=self._date_field_name, _symbol_field_name=self._symbol_field_name, frequence="5min", + _1d_data_all=self.all_1d_data, ) return df diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index bff812fea0..3d591cd0ae 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. 
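+# NOTE: `copy` (added below) is used by `_get_all_1d_data` to deep-copy the
+# preloaded 1d DataFrame before reset_index/rename, so the frame cached in the
+# collector and shared across symbols is never mutated in place.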
import re +import copy import importlib import time import bisect @@ -21,7 +22,6 @@ from functools import partial from concurrent.futures import ProcessPoolExecutor from bs4 import BeautifulSoup -from qlib.data import D HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}" @@ -606,9 +606,8 @@ def get_instruments( getattr(obj, method)() -def _get_all_1d_data(qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str): - # qlib.init(provider_uri=qlib_data_1d_dir) - df = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") +def _get_all_1d_data(qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame): + df = copy.deepcopy(_1d_data_all) df.reset_index(inplace=True) df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True) df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) @@ -616,7 +615,13 @@ def _get_all_1d_data(qlib_data_1d_dir: str, _date_field_name: str, _symbol_field def get_1d_data( - qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, symbol: str, start: str, end: str + qlib_data_1d_dir: str, + _date_field_name: str, + _symbol_field_name: str, + symbol: str, + start: str, + end: str, + _1d_data_all: pd.DataFrame, ) -> pd.DataFrame: """get 1d data @@ -626,7 +631,7 @@ def get_1d_data( data_1d.columns = [_date_field_name, _symbol_field_name, "paused", "volume", "factor", "close"] """ - _all_1d_data = _get_all_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name) + _all_1d_data = _get_all_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name, _1d_data_all) return _all_1d_data[ (_all_1d_data[_symbol_field_name] == symbol.upper()) & (_all_1d_data[_date_field_name] >= pd.Timestamp(start)) @@ -636,6 +641,7 @@ def get_1d_data( def calc_adjusted_price( df: pd.DataFrame, + _1d_data_all: pd.DataFrame, qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, @@ -654,7 +660,9 @@ def calc_adjusted_price( # get 1d data from qlib _start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d") _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") - data_1d: pd.DataFrame = get_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name, symbol, _start, _end) + data_1d: pd.DataFrame = get_1d_data( + qlib_data_1d_dir, _date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all + ) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"] diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 4eb92a7de0..4890e3a765 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -578,6 +578,7 @@ def __init__( super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name) self.qlib_data_1d_dir = qlib_data_1d_dir qlib.init(provider_uri=self.qlib_data_1d_dir) + self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: return list(D.calendar(freq="day")) @@ -604,6 +605,7 @@ def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: frequence="1min", consistent_1d=self.CONSISTENT_1d, calc_paused=self.CALC_PAUSED_NUM, + _1d_data_all=self.all_1d_data, ) return df @@ -959,7 +961,6 @@ def update_data_to_bin( Examples ------- $ 
python collector.py update_data_to_bin --qlib_data_1d_dir --trading_date --end_date - # get 1m data """ if self.interval.lower() != "1d": From 0401298387be0c76891fd620db8c116d36537513 Mon Sep 17 00:00:00 2001 From: Linlang Date: Thu, 12 Oct 2023 10:15:37 +0800 Subject: [PATCH 08/25] optimize pylint issue --- .github/workflows/test_qlib_from_source.yml | 4 ++-- scripts/data_collector/crypto/collector.py | 6 +++--- scripts/data_collector/fund/collector.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test_qlib_from_source.yml b/.github/workflows/test_qlib_from_source.yml index 3d72961a91..d65dc46935 100644 --- a/.github/workflows/test_qlib_from_source.yml +++ b/.github/workflows/test_qlib_from_source.yml @@ -101,8 +101,8 @@ jobs: # We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000). - name: Check Qlib with pylint run: | - pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" - pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + pylint --disable=F0002,C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + pylint --disable=F0002,C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" # The following flake8 error codes were ignored: # E501 line too long diff --git a/scripts/data_collector/crypto/collector.py b/scripts/data_collector/crypto/collector.py index 283517da9c..302b89e200 100644 --- a/scripts/data_collector/crypto/collector.py +++ b/scripts/data_collector/crypto/collector.py @@ -225,7 +225,7 @@ def _get_calendar_list(self): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): # pylint: disable=W0246 + def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): """ Parameters @@ -253,7 +253,7 @@ def normalize_class_name(self): def default_base_dir(self) -> 
[Path, str]: return CUR_DIR - def download_data( # pylint: disable=W0246 + def download_data( self, max_collector_count=2, delay=0, @@ -289,7 +289,7 @@ def download_data( # pylint: disable=W0246 super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): # pylint: disable=W0246 + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): """normalize data Parameters diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index de375bf07e..937d3931db 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -247,7 +247,7 @@ def normalize_class_name(self): def default_base_dir(self) -> [Path, str]: return CUR_DIR - def download_data( # pylint: disable=W0246 + def download_data( self, max_collector_count=2, delay=0, @@ -283,7 +283,7 @@ def download_data( # pylint: disable=W0246 super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): # pylint: disable=W0246 + def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): """normalize data Parameters From bd678fcda090775b4a0acf8e7e92447d15626378 Mon Sep 17 00:00:00 2001 From: Linlang Date: Thu, 19 Oct 2023 18:36:56 +0800 Subject: [PATCH 09/25] fix pylint error --- .github/workflows/test_qlib_from_source.yml | 4 ++-- setup.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test_qlib_from_source.yml b/.github/workflows/test_qlib_from_source.yml index d65dc46935..9205a13641 100644 --- a/.github/workflows/test_qlib_from_source.yml +++ b/.github/workflows/test_qlib_from_source.yml @@ -101,8 +101,8 @@ jobs: # We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000). 
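+      # NOTE: both pylint invocations below share one disable list, except that
+      # `scripts` additionally ignores W0246 (useless parent delegation): the
+      # collectors re-declare parent methods purely to attach docstrings, and
+      # the previous commit dropped their per-line pragmas in favour of this.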
- name: Check Qlib with pylint run: | - pylint --disable=F0002,C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" - pylint --disable=F0002,C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" # The following flake8 error codes were ignored: # E501 line too long diff --git a/setup.py b/setup.py index 9d7c185ab9..e9c2387aad 100644 --- a/setup.py +++ b/setup.py @@ -140,7 +140,8 @@ def get_version(rel_path: str) -> str: "wheel", "setuptools", "black", - "pylint", + # Version 3.0 of pylint had problems with the build process, so we limited the version of pylint. + "pylint<=2.17.6", # Using the latest versions(0.981 and 0.982) of mypy, # the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py", # If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy. 
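
The shared `calc_adjusted_price` helper extracted above can be exercised on its own before the review-feedback commit below trims its signature. A minimal sketch of the call pattern, with hypothetical paths and filename, assuming a local 1d qlib dataset exists and the qlib repo root is on `PYTHONPATH`:

```python
import qlib
import pandas as pd
from qlib.data import D
from scripts.data_collector.utils import calc_adjusted_price

QLIB_1D_DIR = "~/.qlib/qlib_data/cn_data"  # hypothetical local 1d dataset
qlib.init(provider_uri=QLIB_1D_DIR)

# Preload the 1d features once, as the collectors now do in __init__.
all_1d = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day")

# Raw 5min bars for one symbol, with "date" and "symbol" columns (hypothetical file).
df_5min = pd.read_csv("~/.qlib/stock_data/source/hs300_5min_original/sh600000.csv")

adjusted = calc_adjusted_price(
    df=df_5min,
    _1d_data_all=all_1d,
    qlib_data_1d_dir=QLIB_1D_DIR,  # still accepted at this point in the series; dropped below
    _date_field_name="date",
    _symbol_field_name="symbol",
    frequence="5min",
)
print(adjusted[["date", "close", "factor", "paused_num"]].head())
```
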
From 48de114f364466770b6d931ff0b87c866ca6b780 Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 25 Oct 2023 17:39:40 +0800 Subject: [PATCH 10/25] changes suggested by the review --- scripts/data_collector/baostock_5min/README.md | 7 +++---- scripts/data_collector/baostock_5min/collector.py | 2 -- scripts/data_collector/utils.py | 8 +++----- scripts/data_collector/yahoo/collector.py | 4 ---- 4 files changed, 6 insertions(+), 15 deletions(-) diff --git a/scripts/data_collector/baostock_5min/README.md b/scripts/data_collector/baostock_5min/README.md index cf6b7789c9..e593ea2e49 100644 --- a/scripts/data_collector/baostock_5min/README.md +++ b/scripts/data_collector/baostock_5min/README.md @@ -49,11 +49,10 @@ - `symbol_field_name`: column *name* identifying symbol in csv files, by default `symbol` - `end_date`: if not `None`, normalize the last date saved (*including end_date*); if `None`, it will ignore this parameter; by default `None` - `qlib_data_1d_dir`: qlib directory(1d data) - ``` if interval==5min, qlib_data_1d_dir cannot be None, normalize 5min needs to use 1d data; - - qlib_data_1d can be obtained like this: - $ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3 + ``` + # qlib_data_1d can be obtained like this: + python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn --version v3 ``` - examples: ```bash diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py index 3db027e6a0..e2ad10dd59 100644 --- a/scripts/data_collector/baostock_5min/collector.py +++ b/scripts/data_collector/baostock_5min/collector.py @@ -168,7 +168,6 @@ def __init__( bs.login() qlib.init(provider_uri=qlib_data_1d_dir) self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") - self.qlib_data_1d_dir = qlib_data_1d_dir super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name) @staticmethod @@ -254,7 +253,6 @@ def generate_5min_from_daily(self, calendars: Iterable) -> pd.Index: def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: df = calc_adjusted_price( df=df, - qlib_data_1d_dir=self.qlib_data_1d_dir, _date_field_name=self._date_field_name, _symbol_field_name=self._symbol_field_name, frequence="5min", diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 3d591cd0ae..8f02d8cf21 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -606,7 +606,7 @@ def get_instruments( getattr(obj, method)() -def _get_all_1d_data(qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame): +def _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame): df = copy.deepcopy(_1d_data_all) df.reset_index(inplace=True) df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True) @@ -615,7 +615,6 @@ def _get_all_1d_data(qlib_data_1d_dir: str, _date_field_name: str, _symbol_field def get_1d_data( - qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, symbol: str, @@ -631,7 +630,7 @@ def get_1d_data( data_1d.columns = [_date_field_name, _symbol_field_name, "paused", "volume", "factor", "close"] """ - _all_1d_data = _get_all_1d_data(qlib_data_1d_dir, _date_field_name, _symbol_field_name, _1d_data_all) + _all_1d_data = _get_all_1d_data(_date_field_name, _symbol_field_name, _1d_data_all) return 
_all_1d_data[ (_all_1d_data[_symbol_field_name] == symbol.upper()) & (_all_1d_data[_date_field_name] >= pd.Timestamp(start)) @@ -642,7 +641,6 @@ def get_1d_data( def calc_adjusted_price( df: pd.DataFrame, _1d_data_all: pd.DataFrame, - qlib_data_1d_dir: str, _date_field_name: str, _symbol_field_name: str, frequence: str, @@ -661,7 +659,7 @@ def calc_adjusted_price( _start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d") _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") data_1d: pd.DataFrame = get_1d_data( - qlib_data_1d_dir, _date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all + _date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all ) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 4890e3a765..ea9225d140 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -522,8 +522,6 @@ def __init__( symbol field name, default is symbol """ super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name) - self._first_close_field = "first_close" - self._ori_close_field = "ori_close" self.column_list = ["open", "high", "low", "close", "volume", "factor", "change"] self.old_qlib_data = self._get_old_data(old_qlib_data_dir) @@ -576,7 +574,6 @@ def __init__( symbol field name, default is symbol """ super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name) - self.qlib_data_1d_dir = qlib_data_1d_dir qlib.init(provider_uri=self.qlib_data_1d_dir) self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") @@ -599,7 +596,6 @@ def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index: def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: df = calc_adjusted_price( df=df, - qlib_data_1d_dir=self.qlib_data_1d_dir, _date_field_name=self._date_field_name, _symbol_field_name=self._symbol_field_name, frequence="1min", From bc5fe98829180ef320b3cf0c5ffb829f5a789d5b Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 25 Oct 2023 20:06:56 +0800 Subject: [PATCH 11/25] fix CI faild --- scripts/data_collector/utils.py | 4 +--- setup.py | 6 +++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 8f02d8cf21..31170de7d0 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -658,9 +658,7 @@ def calc_adjusted_price( # get 1d data from qlib _start = pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d") _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") - data_1d: pd.DataFrame = get_1d_data( - _date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all - ) + data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"] diff --git a/setup.py b/setup.py index e9c2387aad..508fd8c3a4 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,11 @@ def get_version(rel_path: str) -> str: "lightgbm>=3.3.0", "tornado", "joblib>=0.17.0", - "ruamel.yaml>=0.16.12", + # With the upgrading of ruamel.yaml to 0.18, the safe_load method was deprecated, + # which would cause qlib.workflow.cli to not work properly, + # and no good replacement has 
been found, so the version of ruamel.yaml has been restricted for now. + # Refs: https://pypi.org/project/ruamel.yaml/ + "ruamel.yaml<=0.17.36", "pymongo==3.7.2", # For task management "scikit-learn>=0.22", "dill", From df94d6c31a8be7881b1241cdb0c77d3936d2194c Mon Sep 17 00:00:00 2001 From: Linlang Date: Wed, 25 Oct 2023 20:43:37 +0800 Subject: [PATCH 12/25] fix CI faild --- scripts/data_collector/yahoo/collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index ea9225d140..28acff4fc0 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -574,7 +574,7 @@ def __init__( symbol field name, default is symbol """ super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name) - qlib.init(provider_uri=self.qlib_data_1d_dir) + qlib.init(provider_uri=qlib_data_1d_dir) self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: From e87a54ff4cfc007633e1df418b4504c27d64e844 Mon Sep 17 00:00:00 2001 From: Linlang Date: Thu, 26 Oct 2023 19:59:37 +0800 Subject: [PATCH 13/25] fix issue 1121 --- scripts/data_collector/yahoo/collector.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 28acff4fc0..a333e38422 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -934,6 +934,7 @@ def update_data_to_bin( end_date: str = None, check_data_length: int = None, delay: float = 1, + exists_skip: bool = False, ): """update yahoo data to bin @@ -950,6 +951,8 @@ def update_data_to_bin( check data length, if not None and greater than 0, each symbol will be considered complete if its data length is greater than or equal to this value, otherwise it will be fetched again, the maximum number of fetches being (max_collector_count). By default None. 
delay: float time.sleep(delay), default 1 + exists_skip: bool + exists skip, by default False Notes ----- If the data in qlib_data_dir is incomplete, np.nan will be populated to trading_date for the previous trading day @@ -973,7 +976,7 @@ def update_data_to_bin( # download qlib 1d data qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve()) if not exists_qlib_data(qlib_data_1d_dir): - GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region) + GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip) # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 From b168a39a8ce25fb2c7254421621fd27e4ef44b10 Mon Sep 17 00:00:00 2001 From: Linlang Date: Thu, 26 Oct 2023 21:46:38 +0800 Subject: [PATCH 14/25] format with black --- scripts/data_collector/yahoo/collector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index a333e38422..fa59dcfeb0 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -976,7 +976,9 @@ def update_data_to_bin( # download qlib 1d data qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve()) if not exists_qlib_data(qlib_data_1d_dir): - GetData().qlib_data(target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip) + GetData().qlib_data( + target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip + ) # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 From 530bd088274174598268a1e1ee866980034fe2ac Mon Sep 17 00:00:00 2001 From: Linlang Date: Fri, 27 Oct 2023 14:04:53 +0800 Subject: [PATCH 15/25] optimize code logic --- .../data_collector/baostock_5min/collector.py | 20 ------------------- scripts/data_collector/yahoo/README.md | 8 +++----- scripts/data_collector/yahoo/collector.py | 19 +++++++----------- 3 files changed, 10 insertions(+), 37 deletions(-) diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py index e2ad10dd59..e188ec024f 100644 --- a/scripts/data_collector/baostock_5min/collector.py +++ b/scripts/data_collector/baostock_5min/collector.py @@ -216,26 +216,6 @@ def normalize_baostock( df.sort_index(inplace=True) df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan - change_series = BaostockNormalizeHS3005min.calc_change(df, last_close) - # NOTE: The data obtained by Yahoo finance sometimes has exceptions - # WARNING: If it is normal for a `symbol(exchange)` to differ by a factor of *89* to *111* for consecutive trading days, - # WARNING: the logic in the following line needs to be modified - _count = 0 - while True: - # NOTE: may appear unusual for many days in a row - change_series = BaostockNormalizeHS3005min.calc_change(df, last_close) - _mask = (change_series >= 89) & (change_series <= 111) - if not _mask.any(): - break - _tmp_cols = ["high", "close", "low", "open"] - df.loc[_mask, _tmp_cols] = df.loc[_mask, _tmp_cols] / 100 - _count += 1 - if _count >= 10: - _symbol = df.loc[df[symbol_field_name].first_valid_index()]["symbol"] - logger.warning( - f"{_symbol} `change` is abnormal for {_count} consecutive days, please check the specific data file carefully" - ) - df["change"] = 
BaostockNormalizeHS3005min.calc_change(df, last_close)
         columns += ["change"]
 
diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md
index bd852d0523..e3f8be9049 100644
--- a/scripts/data_collector/yahoo/README.md
+++ b/scripts/data_collector/yahoo/README.md
@@ -121,7 +121,6 @@ pip install -r requirements.txt
 
   qlib_data_1d can be obtained like this:
     $ python scripts/get_data.py qlib_data --target_dir <qlib_data_1d_dir> --interval 1d
-    $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <qlib_data_1d_dir> --trading_date 2021-06-01
   or:
     download 1d data from YahooFinance
 
@@ -180,9 +179,8 @@ pip install -r requirements.txt
 
 * Manual update of data
   ```
-  python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --trading_date <start date> --end_date <end date>
+  python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir <user data dir> --end_date <end date>
   ```
-  * `trading_date`: start of trading day
   * `end_date`: end of trading day(not included)
   * `check_data_length`: check the number of rows per *symbol*, by default `None`
     > if `len(symbol_df) < check_data_length`, it will be re-fetched, with the number of re-fetches coming from the `max_collector_count` parameter
@@ -191,10 +189,10 @@ pip install -r requirements.txt
   * `source_dir`: The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
   * `normalize_dir`: Directory for normalize data, default "Path(__file__).parent/normalize"
   * `qlib_data_1d_dir`: the qlib data to be updated for yahoo, usually from: [download qlib data](https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data)
-  * `trading_date`: trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
   * `end_date`: end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
   * `region`: region, value from ["CN", "US"], default "CN"
-
+  * `interval`: interval, default "1d" (currently only 1d data updates are supported)
+  * `exists_skip`: skip the 1d-data download if `qlib_data_1d_dir` already holds data, by default False
 
 ## Using qlib data
 
diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py
index fa59dcfeb0..869eb63fc2 100644
--- a/scripts/data_collector/yahoo/collector.py
+++ b/scripts/data_collector/yahoo/collector.py
@@ -930,7 +930,6 @@ def download_today_data(
     def update_data_to_bin(
         self,
         qlib_data_1d_dir: str,
-        trading_date: str = None,
         end_date: str = None,
         check_data_length: int = None,
         delay: float = 1,
@@ -943,8 +942,6 @@ def update_data_to_bin(
         qlib_data_1d_dir: str
             the qlib data to be updated for yahoo, usually from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data
-        trading_date: str
-            trading days to be updated, by default ``datetime.datetime.now().strftime("%Y-%m-%d")``
         end_date: str
             end datetime, default ``pd.Timestamp(trading_date + pd.Timedelta(days=1))``; open interval(excluding end)
         check_data_length: int
@@ -965,24 +962,22 @@ def update_data_to_bin(
         if self.interval.lower() != "1d":
             logger.warning(f"currently supports 1d data updates: --interval 1d")
 
-        # start/end date
-        if trading_date is None:
-            trading_date = datetime.datetime.now().strftime("%Y-%m-%d")
-            logger.warning(f"trading_date is None, use the current date: {trading_date}")
-
-        if end_date is None:
-            end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d")
-
         # download qlib 1d data
         qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve())
         if not exists_qlib_data(qlib_data_1d_dir): 
GetData().qlib_data( target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip ) + + # start/end date + calendar_df = pd.read_csv(Path(qlib_data_1d_dir).joinpath("calendars/day.txt")) + trading_date = (pd.Timestamp(calendar_df.iloc[-1, 0]) - pd.Timedelta(days=2)).strftime("%Y-%m-%d") + + if end_date is None: + end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 - trading_date = (pd.Timestamp(trading_date) - pd.Timedelta(days=2)).strftime("%Y-%m-%d") self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( From 6633605c8f455287acff3496b3ddf3b84710f6a1 Mon Sep 17 00:00:00 2001 From: Linlang Date: Fri, 27 Oct 2023 14:09:25 +0800 Subject: [PATCH 16/25] optimize code logic --- scripts/data_collector/yahoo/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index e3f8be9049..b49baaf66d 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -121,6 +121,7 @@ pip install -r requirements.txt qlib_data_1d can be obtained like this: $ python scripts/get_data.py qlib_data --target_dir --interval 1d + $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir --trading_date or: download 1d data from YahooFinance From 0fdcb9761921d72febe9321f7156120662eba07e Mon Sep 17 00:00:00 2001 From: Linlang Date: Fri, 27 Oct 2023 14:12:04 +0800 Subject: [PATCH 17/25] fix error code --- scripts/data_collector/yahoo/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/data_collector/yahoo/README.md b/scripts/data_collector/yahoo/README.md index b49baaf66d..17d94c96d7 100644 --- a/scripts/data_collector/yahoo/README.md +++ b/scripts/data_collector/yahoo/README.md @@ -121,7 +121,7 @@ pip install -r requirements.txt qlib_data_1d can be obtained like this: $ python scripts/get_data.py qlib_data --target_dir --interval 1d - $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir --trading_date + $ python scripts/data_collector/yahoo/collector.py update_data_to_bin --qlib_data_1d_dir --end_date or: download 1d data from YahooFinance From fe8c4bc736b25831dd7c859ce50b568bc176af7d Mon Sep 17 00:00:00 2001 From: Linlang Date: Fri, 27 Oct 2023 15:48:42 +0800 Subject: [PATCH 18/25] drop warning during code runs --- scripts/data_collector/utils.py | 2 +- scripts/data_collector/yahoo/collector.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index 31170de7d0..1bcace33c7 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -686,7 +686,7 @@ def _calc_factor(df_1d: pd.DataFrame): df_1d["paused"] = np.nan return df_1d - df = df.groupby([df[_date_field_name].dt.date]).apply(_calc_factor) + df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor) if consistent_1d: # the date sequence is consistent with 1d df.set_index(_date_field_name, inplace=True) diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index 869eb63fc2..1d45d93afa 100644 --- a/scripts/data_collector/yahoo/collector.py +++ 
From efd43c10fb9698cdb8a9e7db9fd0f37d63a9e309 Mon Sep 17 00:00:00 2001
From: Linlang
Date: Fri, 27 Oct 2023 15:50:29 +0800
Subject: [PATCH 19/25] optimize code

---
 scripts/data_collector/baostock_5min/collector.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py
index e188ec024f..337a414aad 100644
--- a/scripts/data_collector/baostock_5min/collector.py
+++ b/scripts/data_collector/baostock_5min/collector.py
@@ -60,7 +60,6 @@ def __init__(
             using for debug, by default None
         """
         bs.login()
-        interval = "5min"
         super(BaostockCollectorHS3005min, self).__init__(
             save_dir=save_dir,
             start=start,
From 3b70cf51ad765cad8ffc6591e747a1a9523cd467 Mon Sep 17 00:00:00 2001
From: Linlang
Date: Fri, 27 Oct 2023 15:51:39 +0800
Subject: [PATCH 20/25] format with black

---
 scripts/data_collector/yahoo/collector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py
index 1d45d93afa..96953d57f4 100644
--- a/scripts/data_collector/yahoo/collector.py
+++ b/scripts/data_collector/yahoo/collector.py
@@ -968,7 +968,7 @@ def update_data_to_bin(
             GetData().qlib_data(
                 target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip
             )
-
+
         # start/end date
         calendar_df = pd.read_csv(Path(qlib_data_1d_dir).joinpath("calendars/day.txt"))
         trading_date = (pd.Timestamp(calendar_df.iloc[-1, 0]) - pd.Timedelta(days=1)).strftime("%Y-%m-%d")
From 3af3cae81aaa140489b0f487fec12e833593141b Mon Sep 17 00:00:00 2001
From: Linlang
Date: Wed, 1 Nov 2023 15:38:21 +0800
Subject: [PATCH 21/25] fix bug

---
 scripts/data_collector/yahoo/collector.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py
index 96953d57f4..11ba886f7f 100644
--- a/scripts/data_collector/yahoo/collector.py
+++ b/scripts/data_collector/yahoo/collector.py
@@ -390,20 +390,26 @@ def normalize_yahoo(
             return df
         symbol = df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name]
         columns = copy.deepcopy(YahooNormalize.COLUMNS)
-        df = df.copy()
         df.set_index(date_field_name, inplace=True)
         df.index = pd.to_datetime(df.index)
         df = df[~df.index.duplicated(keep="first")]
+        df_tmp = df.copy()
         if calendar_list is not None:
-            df = df.reindex(
+            df_tmp = df_tmp.reindex(
                 pd.DataFrame(index=calendar_list)
                 .loc[
-                    pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
+                    pd.Timestamp(df_tmp.index.min()).date() : pd.Timestamp(df_tmp.index.max()).date()
                     + pd.Timedelta(hours=23, minutes=59)
                 ]
                 .index
             )
-        df.sort_index(inplace=True)
+        df_tmp.index = pd.to_datetime(df_tmp.index)
+        df_tmp.sort_index(inplace=True)
+        df_tmp.index = df_tmp.index.tz_localize(None)
+        df.index = df.index.tz_localize(None)
+        df_tmp['symbol'] = df.iloc[0]['symbol']
+        df_tmp = df_tmp.drop(columns=['open', 'high', 'low', 'close', 'volume'])
+        df = df_tmp.merge(df[["open", "high", "low", "close", "volume"]], left_index=True, right_index=True, how='left')
         df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan
         change_series = YahooNormalize.calc_change(df, last_close)
@@ -536,6 +542,9 @@ def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
         df = super(YahooNormalize1dExtend, self).normalize(df)
         df.set_index(self._date_field_name, inplace=True)
         symbol_name = df[self._symbol_field_name].iloc[0]
+        old_symbol_list = self.old_qlib_data.index.get_level_values('instrument').unique().to_list()
+        if str(symbol_name).upper() not in old_symbol_list:
+            return df.reset_index()
         old_df = self.old_qlib_data.loc[str(symbol_name).upper()]
         latest_date = old_df.index[-1]
         df = df.loc[latest_date:]
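The `tz_localize` churn in patch 21 targets one pitfall: a tz-aware Yahoo index reindexed against a tz-naive calendar matches no label, so every row silently turns into NaN. A toy sketch of the failure mode and the `tz_localize(None)` cure; the prices below are made up for the demo and are not the collector's schema:

```python
# Toy illustration of the tz mismatch behind patch 21: strip the timezone
# before reindexing so price timestamps line up with the naive calendar.
import pandas as pd

prices = pd.DataFrame(
    {"close": [10.0, 10.5]},
    index=pd.to_datetime(["2023-01-03", "2023-01-04"]).tz_localize("America/New_York"),
)
calendar = pd.to_datetime(["2023-01-03", "2023-01-04", "2023-01-05"])  # tz-naive

prices.index = prices.index.tz_localize(None)  # drop tz, keep wall-clock dates
print(prices.reindex(calendar))  # 01-03/01-04 keep values, 01-05 becomes NaN
```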
From 34e9553eb40197c805cba39aaca743a2159f7e92 Mon Sep 17 00:00:00 2001
From: Linlang
Date: Wed, 1 Nov 2023 15:38:55 +0800
Subject: [PATCH 22/25] format with black

---
 scripts/data_collector/yahoo/collector.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py
index 11ba886f7f..5717ac95de 100644
--- a/scripts/data_collector/yahoo/collector.py
+++ b/scripts/data_collector/yahoo/collector.py
@@ -407,9 +407,9 @@ def normalize_yahoo(
         df_tmp.sort_index(inplace=True)
         df_tmp.index = df_tmp.index.tz_localize(None)
         df.index = df.index.tz_localize(None)
-        df_tmp['symbol'] = df.iloc[0]['symbol']
-        df_tmp = df_tmp.drop(columns=['open', 'high', 'low', 'close', 'volume'])
-        df = df_tmp.merge(df[["open", "high", "low", "close", "volume"]], left_index=True, right_index=True, how='left')
+        df_tmp["symbol"] = df.iloc[0]["symbol"]
+        df_tmp = df_tmp.drop(columns=["open", "high", "low", "close", "volume"])
+        df = df_tmp.merge(df[["open", "high", "low", "close", "volume"]], left_index=True, right_index=True, how="left")
         df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan
         change_series = YahooNormalize.calc_change(df, last_close)
@@ -542,7 +542,7 @@ def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
         df = super(YahooNormalize1dExtend, self).normalize(df)
         df.set_index(self._date_field_name, inplace=True)
         symbol_name = df[self._symbol_field_name].iloc[0]
-        old_symbol_list = self.old_qlib_data.index.get_level_values('instrument').unique().to_list()
+        old_symbol_list = self.old_qlib_data.index.get_level_values("instrument").unique().to_list()
         if str(symbol_name).upper() not in old_symbol_list:
             return df.reset_index()
         old_df = self.old_qlib_data.loc[str(symbol_name).upper()]
From 0dba040503d25e7b0915a778a7db78910cbe6aba Mon Sep 17 00:00:00 2001
From: Linlang
Date: Fri, 10 Nov 2023 15:10:18 +0800
Subject: [PATCH 23/25] optimize code

---
 scripts/data_collector/yahoo/collector.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py
index 5717ac95de..b9e25e6d80 100644
--- a/scripts/data_collector/yahoo/collector.py
+++ b/scripts/data_collector/yahoo/collector.py
@@ -390,26 +390,21 @@ def normalize_yahoo(
             return df
         symbol = df.loc[df[symbol_field_name].first_valid_index(), symbol_field_name]
         columns = copy.deepcopy(YahooNormalize.COLUMNS)
+        df = df.copy()
         df.set_index(date_field_name, inplace=True)
         df.index = pd.to_datetime(df.index)
         df = df[~df.index.duplicated(keep="first")]
-        df_tmp = df.copy()
+        calendar_list = calendar_list.tz_localize("Asia/Shanghai")
         if calendar_list is not None:
-            df_tmp = df_tmp.reindex(
+            df = df.reindex(
                 pd.DataFrame(index=calendar_list)
                 .loc[
-                    pd.Timestamp(df_tmp.index.min()).date() : pd.Timestamp(df_tmp.index.max()).date()
+                    pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
                     + pd.Timedelta(hours=23, minutes=59)
                 ]
                 .index
             )
-        df_tmp.index = pd.to_datetime(df_tmp.index)
-        df_tmp.sort_index(inplace=True)
-        df_tmp.index = df_tmp.index.tz_localize(None)
-        df.index = df.index.tz_localize(None)
-        df_tmp["symbol"] = df.iloc[0]["symbol"]
-        df_tmp = df_tmp.drop(columns=["open", "high", "low", "close", "volume"])
-        df = df_tmp.merge(df[["open", "high", "low", "close", "volume"]], left_index=True, right_index=True, how="left")
+        df.sort_index(inplace=True)
         df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan
         change_series = YahooNormalize.calc_change(df, last_close)
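Patch 23 drops the temporary-frame merge but keeps the `.loc[start : end_day + pd.Timedelta(hours=23, minutes=59)]` slice: it selects the calendar window covering the data, with the last day included through 23:59 so intraday bars on that final day survive the reindex. A self-contained sketch with a synthetic calendar; the `window` name and Timestamps (in place of `date()` objects) are illustrative:

```python
# Sketch of the calendar-window slice used in the reindex above: keep every
# calendar timestamp from the first data day through the end (23:59) of the
# last data day, so the final day's intraday bars are not cut off.
import pandas as pd

calendar = pd.date_range("2023-01-03 09:30", "2023-01-05 15:00", freq="150min")
first = pd.Timestamp("2023-01-03")
last = pd.Timestamp("2023-01-05")

window = (
    pd.DataFrame(index=calendar)
    .loc[first : last + pd.Timedelta(hours=23, minutes=59)]
    .index
)
print(window)  # includes the 2023-01-05 intraday timestamps
```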
From 6a95c866fb76f9c4247b53d1349c982170191eaf Mon Sep 17 00:00:00 2001
From: Linlang
Date: Fri, 10 Nov 2023 17:04:22 +0800
Subject: [PATCH 24/25] optimize code

---
 scripts/data_collector/yahoo/collector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py
index b9e25e6d80..25e2963883 100644
--- a/scripts/data_collector/yahoo/collector.py
+++ b/scripts/data_collector/yahoo/collector.py
@@ -393,8 +393,8 @@ def normalize_yahoo(
         df = df.copy()
         df.set_index(date_field_name, inplace=True)
         df.index = pd.to_datetime(df.index)
+        df.index = df.index.tz_localize(None)
         df = df[~df.index.duplicated(keep="first")]
-        calendar_list = calendar_list.tz_localize("Asia/Shanghai")
         if calendar_list is not None:
             df = df.reindex(
                 pd.DataFrame(index=calendar_list)
From e95cecb7cea16b86e9ab82242e8f8b8e993e0fa3 Mon Sep 17 00:00:00 2001
From: Linlang
Date: Tue, 21 Nov 2023 19:53:55 +0800
Subject: [PATCH 25/25] add comments

---
 scripts/data_collector/utils.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py
index 1bcace33c7..596eae60ef 100644
--- a/scripts/data_collector/utils.py
+++ b/scripts/data_collector/utils.py
@@ -647,6 +647,21 @@ def calc_adjusted_price(
     consistent_1d: bool = True,
     calc_paused: bool = True,
 ) -> pd.DataFrame:
+    """calc adjusted price
+    This method does 4 things.
+    1. Adds the `paused` field.
+        - The added paused field comes from the paused field of the 1d data.
+    2. Aligns the time of the 1d data.
+    3. The data is reweighted.
+        - The reweighting method:
+            - volume / factor
+            - open * factor
+            - high * factor
+            - low * factor
+            - close * factor
+    4. Calls the `calc_paused_num` method to add the `paused_num` field.
+        - The `paused_num` is the number of consecutive days of trading suspension.
+    """
     # TODO: using daily data factor
     if df.empty:
         return df
@@ -714,6 +729,10 @@ def _calc_factor(df_1d: pd.DataFrame):

 def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name):
+    """calc paused num
+    This method adds the `paused_num` field.
+    - The `paused_num` is the number of consecutive days of trading suspension.
+    """
     _symbol = df.iloc[0][_symbol_field_name]
     df = df.copy()
     df["_tmp_date"] = df[_date_field_name].apply(lambda x: pd.Timestamp(x).date())
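The `paused_num` field documented in patch 25 is a consecutive-suspension counter. A toy sketch of the idea; this is an illustration on made-up data, not the actual `calc_paused_num` implementation:

```python
# Toy sketch of the paused_num idea: for each suspended row, count how many
# consecutive days (including this one) trading has been suspended.
import pandas as pd

df = pd.DataFrame({"paused": [0, 1, 1, 1, 0, 1]})

run_id = (df["paused"] != df["paused"].shift()).cumsum()  # label each run of equal values
df["paused_num"] = df["paused"].groupby(run_id).cumsum().where(df["paused"] == 1)
print(df)  # suspended rows get 1, 2, 3, ... within each suspension streak
```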