-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
97 lines (76 loc) · 3.14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import random
import pandas as pd
import tensortrade.stochastic as sp
import tensortrade.env.default as default
from tensortrade.feed.core import Stream, DataFeed
from tensortrade.data.cdd import CryptoDataDownload
from tensortrade.oms.wallets import Portfolio, Wallet
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.instruments import USD, BTC, ETH, LTC
from tensortrade.agents import DQNAgent
from tensortrade.env.default.renderers import PlotlyTradingChart, FileLogger
def run() -> None:
    """Download BTC/USD hourly candles, train a DQN agent on them and plot net worth.

    Steps performed:
      1. Fetch Bitfinex BTC/USD 1h data with ``CryptoDataDownload``.
      2. Build a simulated exchange whose USD-BTC price stream is the close
         column, plus a portfolio holding 10,000 USD and 10 BTC on it.
      3. Build the observation feed: log return, RSI(20) and MACD(10, 50, 5)
         of the close price.
      4. Build a renderer feed (date + OHLCV) and create the default
         environment ("managed-risk" actions, "risk-adjusted" reward,
         "file-log" renderer, window of 20 steps).
      5. Train a ``DQNAgent`` for 10 episodes of 200 steps, saving to
         ``agents/``, then plot the portfolio's net worth.
    """
    cdd = CryptoDataDownload()
    data = cdd.fetch("Bitfinex", "USD", "BTC", "1h")

    # The simulated exchange quotes USD-BTC at the hourly close.
    bitfinex = Exchange("bitfinex", service=execute_order)(
        Stream.source(list(data["close"]), dtype="float").rename("USD-BTC"),
    )
    portfolio = Portfolio(USD, [
        Wallet(bitfinex, 10000 * USD),
        Wallet(bitfinex, 10 * BTC),
    ])

    # One float stream per column, skipping the first one (presumably the
    # non-numeric "date" column - TODO confirm cdd.fetch's column order).
    raw_streams = [
        Stream.source(list(data[col]), dtype="float").rename(col)
        for col in data.columns[1:]
    ]
    close = Stream.select(raw_streams, lambda s: s.name == "close")

    # Observation features fed to the agent.
    feed = DataFeed([
        close.log().diff().rename("lr"),
        rsi(close, period=20).rename("rsi"),
        macd(close, fast=10, slow=50, signal=5).rename("macd"),
    ])
    feed.compile()

    # Raw date + OHLCV streams, consumed only by the renderer.
    ohlcv_columns = ("open", "high", "low", "close", "volume")
    renderer_feed = DataFeed(
        [Stream.source(list(data["date"])).rename("date")]
        + [
            Stream.source(list(data[col]), dtype="float").rename(col)
            for col in ohlcv_columns
        ]
    )

    env = default.create(
        portfolio=portfolio,
        action_scheme="managed-risk",
        reward_scheme="risk-adjusted",
        feed=feed,
        renderer_feed=renderer_feed,
        renderer="file-log",
        window_size=20,
    )

    agent = DQNAgent(env)
    agent.train(n_steps=200, n_episodes=10, save_path="agents/", render_interval=100)

    # In TensorTrade 1.x `Portfolio.performance` is a mapping of step -> metrics
    # dict, not a DataFrame, so attribute access like `.net_worth` fails.
    # Convert it to a DataFrame (one row per step) before plotting.
    performance = pd.DataFrame.from_dict(portfolio.performance, orient="index")
    # NOTE(review): outside an interactive session matplotlib needs an explicit
    # show()/savefig() for this plot to be visible.
    performance["net_worth"].plot()
def rsi(price: Stream[float], period: float) -> Stream[float]:
    """Build a Relative Strength Index stream from *price*.

    Average gains and average losses are exponentially weighted means of the
    positive and negative one-step price changes, with smoothing factor
    ``alpha = 1 / period``. The result is ``100 * (1 - 1 / (1 + RS))``, where
    ``RS = avg_gain / avg_loss``.

    Args:
        price: Price stream to compute the indicator on.
        period: Look-back period; its reciprocal is the EWM smoothing factor.

    Returns:
        A new stream with the RSI values.
    """
    smoothing = 1 / period
    change = price.diff()
    gains = change.clamp_min(0).abs()
    losses = change.clamp_max(0).abs()
    avg_gain = gains.ewm(alpha=smoothing).mean()
    avg_loss = losses.ewm(alpha=smoothing).mean()
    relative_strength = avg_gain / avg_loss
    return 100 * (1 - (1 + relative_strength) ** -1)
def macd(price: Stream[float], fast: float, slow: float, signal: float) -> Stream[float]:
    """Build the MACD histogram stream of *price*.

    The MACD line is the difference between a fast and a slow EMA of the
    price (both with ``adjust=False``). The value returned is the MACD line
    minus its own EMA over ``signal`` periods, i.e. the histogram rather
    than the MACD line itself.

    Args:
        price: Price stream to compute the indicator on.
        fast: Span of the fast EMA.
        slow: Span of the slow EMA.
        signal: Span of the EMA applied to the MACD line (the signal line).

    Returns:
        A new stream equal to ``macd_line - signal_line``.
    """
    fast_ema = price.ewm(span=fast, adjust=False).mean()
    slow_ema = price.ewm(span=slow, adjust=False).mean()
    macd_line = fast_ema - slow_ema
    signal_line = macd_line.ewm(span=signal, adjust=False).mean()
    return macd_line - signal_line
if __name__ == "__main__":
    # Only train when executed as a script, not when imported as a module.
    run()