-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
77 lines (63 loc) · 3.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import datetime
from RL.env import StockTradingEnv
import RL.config as config
from RL.train_test import *
from erl.plot import backtest_stats, backtest_plot
# The environment *class* (not an instance); it is passed as `env=` to
# train() and test() below.
env = StockTradingEnv

# Training and back-testing windows, as 'YYYY-MM-DD' strings.
TRAIN_START_DATE, TRAIN_END_DATE = '2014-01-01', '2020-07-30'
TEST_START_DATE, TEST_END_DATE = '2020-08-01', '2021-10-01'

# Technical-indicator names requested from the data pipeline (order preserved).
TECHNICAL_INDICATORS_LIST = 'macd boll_ub boll_lb rsi_30 dx_30 close_30_sma close_60_sma'.split()

# ElegantRL hyper-parameters, passed to train() as `erl_params`.
# `net_dimension` is also needed again when loading the agent in test().
ERL_PARAMS = dict(
    learning_rate=3e-5,
    batch_size=2048,
    gamma=0.985,
    seed=312,
    net_dimension=512,
    target_step=5000,
    eval_gap=60,
)
# Demo for ElegantRL: train a DDPG agent on the training window.
# `cwd` is the working directory handed to train(); test() below is pointed
# at the same path.
account_value_train = train(
    # --- data ---
    data_source='yahoofinance',
    ticker_list=config.DOW_30_TICKER,
    start_date=TRAIN_START_DATE,
    end_date=TRAIN_END_DATE,
    time_interval='1D',
    technical_indicator_list=TECHNICAL_INDICATORS_LIST,
    # --- agent / environment ---
    drl_lib='elegantrl',
    model_name='ddpg',
    env=env,
    erl_params=ERL_PARAMS,
    cwd='./test_ddpg',
    break_step=1e5,
)
# Back-test the trained DDPG agent over the held-out test window.
# The returned sequence of account values is aligned with the baseline
# dates further below.
account_value_erl = test(
    start_date=TEST_START_DATE,
    end_date=TEST_END_DATE,
    ticker_list=config.DOW_30_TICKER,
    data_source='yahoofinance',
    time_interval='1D',
    technical_indicator_list=TECHNICAL_INDICATORS_LIST,
    drl_lib='elegantrl',
    env=env,
    model_name='ddpg',
    # Same directory that train() was given, so the trained agent is found.
    cwd='./test_ddpg',
    # Must equal the width used at training time; read it from ERL_PARAMS
    # instead of repeating the literal so the two cannot drift apart.
    net_dimension=ERL_PARAMS["net_dimension"],
)
#### Plot
# NOTE(review): `pd` and `DataProcessor` are not imported explicitly in this
# file; presumably they come from `from RL.train_test import *` — confirm.

# DJIA prices over the test window, used as the baseline and as the date axis.
baseline_df = DataProcessor('yahoofinance').download_data(
    ticker_list=["^DJI"],
    start_date=TEST_START_DATE,
    end_date=TEST_END_DATE,
    time_interval="1D",
)
# `stats` is not used further in this script.
stats = backtest_stats(baseline_df, value_col_name='close')

# The agent's series is expected to hold exactly one more value than there
# are baseline dates; the last value is dropped to pair them up. Check that
# assumption explicitly so a mismatch fails with a readable message instead
# of an opaque pandas length error.
account_values = account_value_erl[:-1]
if len(account_values) != len(baseline_df):
    raise ValueError(
        f"cannot align {len(account_value_erl)} account values with "
        f"{len(baseline_df)} baseline dates (expected exactly one extra value)"
    )
account_value_erl = pd.DataFrame({'date': baseline_df.date, 'account_value': account_values})

print("==============Get Backtest Results===========")
now = datetime.datetime.now().strftime('%Y%m%d-%Hh%M')
perf_stats_all = pd.DataFrame(backtest_stats(account_value=account_value_erl))
# `now` timestamps the file name so successive runs do not overwrite each other.
perf_stats_all.to_csv("./perf_stats_all_" + now + ".csv")

print("==============Compare to DJIA===========")
# Other baseline tickers: S&P 500 ^GSPC, Dow Jones Index ^DJI, NASDAQ 100 ^NDX
# Use positional access for the first/last date so the lookup does not depend
# on the index labels inherited from baseline_df.
dates = account_value_erl['date']
backtest_plot(
    account_value_erl,
    baseline_ticker='^DJI',
    baseline_start=dates.iloc[0],
    baseline_end=dates.iloc[-1],
)