Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate backtest logic from NT #1263

Merged
merged 16 commits into from
Sep 19, 2022
8 changes: 5 additions & 3 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def get_exchange(
def create_account_instance(
start_time: Union[pd.Timestamp, str],
end_time: Union[pd.Timestamp, str],
benchmark: str,
benchmark: Optional[str],
account: Union[float, int, dict],
pos_type: str = "Position",
) -> Account:
Expand Down Expand Up @@ -163,7 +163,9 @@ def create_account_instance(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={
benchmark_config={}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
Expand All @@ -176,7 +178,7 @@ def get_strategy_executor(
end_time: Union[pd.Timestamp, str],
strategy: Union[str, dict, object, Path],
executor: Union[str, dict, object, Path],
benchmark: str = "SH000300",
benchmark: Optional[str] = "SH000300",
account: Union[float, int, dict] = 1e9,
exchange_kwargs: dict = {},
pos_type: str = "Position",
Expand Down
231 changes: 231 additions & 0 deletions qlib/rl/contrib/backtest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import copy
import pickle
import sys
from pathlib import Path
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed

from qlib.backtest import collect_data_loop, get_strategy_executor
from qlib.backtest.decision import TradeRangeByTime
from qlib.backtest.executor import BaseExecutor, NestedExecutor, SimulatorExecutor
from qlib.backtest.high_performance_ds import BaseOrderIndicator
from qlib.rl.contrib.naive_config_parser import get_backtest_config_fromfile
from qlib.rl.contrib.utils import read_order_file
from qlib.rl.data.integration import init_qlib
from qlib.rl.utils.env_wrapper import CollectDataEnvWrapper


def _get_multi_level_executor_config(
strategy_config: dict,
cash_limit: float = None,
generate_report: bool = False,
) -> dict:
executor_config = {
"class": "SimulatorExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "1min",
"verbose": False,
"trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL,
"generate_report": generate_report,
"track_data": True,
},
}

freqs = list(strategy_config.keys())
freqs.sort(key=lambda x: pd.Timedelta(x))
for freq in freqs:
executor_config = {
"class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": freq,
"inner_strategy": strategy_config[freq],
"inner_executor": executor_config,
"track_data": True,
},
}

return executor_config


def _set_env_for_all_strategy(executor: BaseExecutor) -> None:
if isinstance(executor, NestedExecutor):
if hasattr(executor.inner_strategy, "set_env"):
env = CollectDataEnvWrapper()
env.reset()
executor.inner_strategy.set_env(env)
_set_env_for_all_strategy(executor.inner_executor)


def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]:
record_list = []
for time, value_dict in indicator.items():
if isinstance(value_dict, BaseOrderIndicator):
# HACK: for qlib v0.8
value_dict = value_dict.to_series()
try:
value_dict = {k: v for k, v in value_dict.items()}
if value_dict["ffr"].empty:
continue
except Exception:
value_dict = {k: v for k, v in value_dict.items() if k != "pa"}
value_dict = pd.DataFrame(value_dict)
value_dict["datetime"] = time
record_list.append(value_dict)

if not record_list:
return None

records: pd.DataFrame = pd.concat(record_list, 0).reset_index().rename(columns={"index": "instrument"})
records = records.set_index(["instrument", "datetime"])
return records


def _generate_report(decisions: list, report_dict: dict) -> dict:
report = {}
decision_details = pd.concat([d.details for d in decisions if hasattr(d, "details")])
for key in ["1minute", "5minute", "30minute", "1day"]:
if key not in report_dict["indicator"]:
continue
report[key] = report_dict["indicator"][key]
report[key + "_obj"] = _convert_indicator_to_dataframe(
report_dict["indicator"][key + "_obj"].order_indicator_his
)
cur_details = decision_details[decision_details.freq == key.rstrip("ute")].set_index(["instrument", "datetime"])
if len(cur_details) > 0:
cur_details.pop("freq")
report[key + "_obj"] = report[key + "_obj"].join(cur_details, how="outer")
if "1minute" in report_dict["report"]:
report["simulator"] = report_dict["report"]["1minute"][0]
return report


def single(
backtest_config: dict,
orders: pd.DataFrame,
split: str = "stock",
cash_limit: float = None,
generate_report: bool = False,
) -> Union[Tuple[pd.DataFrame, dict], pd.DataFrame]:
if split == "stock":
stock_id = orders.iloc[0].instrument
init_qlib(backtest_config["qlib"], part=stock_id)
else:
day = orders.iloc[0].datetime
init_qlib(backtest_config["qlib"], part=day)

trade_start_time = orders["datetime"].min()
trade_end_time = orders["datetime"].max()
stocks = orders.instrument.unique().tolist()

top_strategy_config = {
"class": "FileOrderStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"file": orders,
"trade_range": TradeRangeByTime(
pd.Timestamp(backtest_config["start_time"]).time(),
pd.Timestamp(backtest_config["end_time"]).time(),
),
},
}

top_executor_config = _get_multi_level_executor_config(
strategy_config=backtest_config["strategies"],
cash_limit=cash_limit,
generate_report=generate_report,
)

tmp_backtest_config = copy.deepcopy(backtest_config["exchange"])
tmp_backtest_config.update(
{
"codes": stocks,
"freq": "1min",
}
)

strategy, executor = get_strategy_executor(
start_time=pd.Timestamp(trade_start_time),
end_time=pd.Timestamp(trade_end_time) + pd.DateOffset(1),
strategy=top_strategy_config,
executor=top_executor_config,
benchmark=None,
account=cash_limit if cash_limit is not None else int(1e12),
exchange_kwargs=tmp_backtest_config,
pos_type="Position" if cash_limit is not None else "InfPosition",
)
_set_env_for_all_strategy(executor=executor)

report_dict: dict = {}
decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict))

records = _convert_indicator_to_dataframe(report_dict["indicator"]["1day_obj"].order_indicator_his)
assert records is None or not np.isnan(records["ffr"]).any()

if generate_report:
report = _generate_report(decisions, report_dict)
if split == "stock":
stock_id = orders.iloc[0].instrument
report = {stock_id: report}
else:
day = orders.iloc[0].datetime
report = {day: report}
return records, report
else:
return records


def backtest(backtest_config: dict) -> pd.DataFrame:
order_df = read_order_file(backtest_config["order_file"])

cash_limit = backtest_config["exchange"].pop("cash_limit")
generate_report = backtest_config["exchange"].pop("generate_report")

stock_pool = order_df["instrument"].unique().tolist()
stock_pool.sort()

mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"}
torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199
res = Parallel(**mp_config)(
delayed(single)(
backtest_config=backtest_config,
orders=order_df[order_df["instrument"] == stock].copy(),
split="stock",
cash_limit=cash_limit,
generate_report=generate_report,
)
for stock in stock_pool
)

output_path = Path(backtest_config["output_dir"])
if generate_report:
with (output_path / "report.pkl").open("wb") as f:
report = {}
for r in res:
report.update(r[1])
pickle.dump(report, f)
res = pd.concat([r[0] for r in res], 0)
else:
res = pd.concat(res)

res.to_csv(output_path / "summary.csv")
return res


if __name__ == "__main__":
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

path = sys.argv[1]
backtest(get_backtest_config_fromfile(path))
103 changes: 103 additions & 0 deletions qlib/rl/contrib/naive_config_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import platform
import shutil
import sys
import tempfile
from importlib import import_module

import yaml


def merge_a_into_b(a: dict, b: dict) -> dict:
b = b.copy()
for k, v in a.items():
if isinstance(v, dict) and k in b:
v.pop("_delete_", False) # TODO: make this more elegant
b[k] = merge_a_into_b(v, b[k])
else:
b[k] = v
return b


def check_file_exist(filename: str, msg_tmpl: str = 'file "{}" does not exist') -> None:
if not os.path.isfile(filename):
raise FileNotFoundError(msg_tmpl.format(filename))


def parse_backtest_config(path: str) -> dict:
abs_path = os.path.abspath(path)
check_file_exist(abs_path)

file_ext_name = os.path.splitext(abs_path)[1]
if file_ext_name not in (".py", ".json", ".yaml", ".yml"):
raise IOError("Only py/yml/yaml/json type are supported now!")

with tempfile.TemporaryDirectory() as tmp_config_dir:
with tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) as tmp_config_file:
if platform.system() == "Windows":
tmp_config_file.close()

tmp_config_name = os.path.basename(tmp_config_file.name)
shutil.copyfile(abs_path, tmp_config_file.name)

if abs_path.endswith(".py"):
tmp_module_name = os.path.splitext(tmp_config_name)[0]
sys.path.insert(0, tmp_config_dir)
module = import_module(tmp_module_name)
sys.path.pop(0)

config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")}

del sys.modules[tmp_module_name]
else:
config = yaml.safe_load(open(tmp_config_file.name))

if "_base_" in config:
base_file_name = config.pop("_base_")
if not isinstance(base_file_name, list):
base_file_name = [base_file_name]

for f in base_file_name:
base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f))
config = merge_a_into_b(a=config, b=base_config)

return config


def _convert_all_list_to_tuple(config: dict) -> dict:
for k, v in config.items():
if isinstance(v, list):
config[k] = tuple(v)
elif isinstance(v, dict):
config[k] = _convert_all_list_to_tuple(v)
return config


def get_backtest_config_fromfile(path: str) -> dict:
backtest_config = parse_backtest_config(path)

exchange_config_default = {
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5.0,
"trade_unit": 100.0,
"cash_limit": None,
"generate_report": False,
}
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default)
backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"])

backtest_config_default = {
"debug_single_stock": None,
"debug_single_day": None,
"concurrency": -1,
"multiplier": 1.0,
"output_dir": "outputs/",
# "runtime": {},
}
backtest_config = merge_a_into_b(a=backtest_config, b=backtest_config_default)

return backtest_config
29 changes: 29 additions & 0 deletions qlib/rl/contrib/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from pathlib import Path

import pandas as pd


def read_order_file(order_file: Path | pd.DataFrame) -> pd.DataFrame:
if isinstance(order_file, pd.DataFrame):
return order_file

order_file = Path(order_file)

if order_file.suffix == ".pkl":
order_df = pd.read_pickle(order_file).reset_index()
elif order_file.suffix == ".csv":
order_df = pd.read_csv(order_file)
else:
raise TypeError(f"Unsupported order file type: {order_file}")

if "date" in order_df.columns:
# legacy dataframe columns
order_df = order_df.rename(columns={"date": "datetime", "order_type": "direction"})
order_df["datetime"] = order_df["datetime"].astype(str)

return order_df
Loading