Adilbai committed on
Commit
34aa90f
·
verified ·
1 Parent(s): 392748e

Upload 2 files

Browse files
Files changed (2) hide show
  1. dataprocessor.py +431 -0
  2. enviromentcreator.py +463 -0
dataprocessor.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yfinance as yf
2
+ import pandas as pd
3
+ import numpy as np
4
+ from typing import List, Dict, Optional, Tuple
5
+ import os
6
+ import logging
7
+ from datetime import datetime, timedelta
8
+ import pickle
9
+ from concurrent.futures import ThreadPoolExecutor, as_completed
10
+ import time
11
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
12
+ import warnings
13
# NOTE(review): this silences *every* warning process-wide (including pandas /
# NumPy deprecation warnings). Kept for backward compatibility.
warnings.filterwarnings('ignore')

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
18
+
19
class StockDataProcessor:
    """
    A comprehensive class for downloading, processing, and preprocessing stock data
    from Yahoo Finance for reinforcement learning applications.

    The processing methods expect a long-format frame with at least the columns
    ``Date``, ``Ticker``, ``Open``, ``High``, ``Low``, ``Close`` and ``Volume``
    (the layout produced by ``download_multiple_stocks``).
    """

    def __init__(self, data_dir: str = "stock_data", cache_dir: str = "cache"):
        """
        Args:
            data_dir: Directory where processed CSV / pickle files are written
            cache_dir: Directory reserved for cached artefacts
        """
        self.data_dir = data_dir
        self.cache_dir = cache_dir
        self.scalers = {}

        # Create directories if they don't exist
        os.makedirs(data_dir, exist_ok=True)
        os.makedirs(cache_dir, exist_ok=True)

    @staticmethod
    def _iter_ticker_frames(df: pd.DataFrame):
        """Yield one date-sorted frame per ticker, in order of first appearance."""
        for _, ticker_data in df.groupby('Ticker', sort=False):
            # sort_values returns a new object, so callers may add columns freely.
            yield ticker_data.sort_values('Date')

    def get_sp500_tickers(self) -> List[str]:
        """
        Get S&P 500 stock tickers.

        Returns:
            The symbols scraped from Wikipedia with '.' replaced by '-', or a
            short hard-coded fallback list if the download/parsing fails.
        """
        try:
            # Download S&P 500 list from Wikipedia
            url = 'https://en.wikipedia.org/wiki/List_of_S%26P_500_companies'
            tables = pd.read_html(url)
            sp500_table = tables[0]
            tickers = sp500_table['Symbol'].tolist()
            # Clean tickers: replace '.' with '-' (e.g. BRK.B -> BRK-B)
            tickers = [ticker.replace('.', '-') for ticker in tickers]
            logger.info(f"Retrieved {len(tickers)} S&P 500 tickers")
            return tickers
        except Exception as e:
            # Deliberate best-effort: any failure falls back to a small list.
            logger.error(f"Error fetching S&P 500 tickers: {e}")
            # Fallback to a smaller list of popular stocks
            return ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA', 'META', 'NVDA', 'JPM', 'JNJ', 'V']

    def download_stock_data(self,
                            ticker: str,
                            period: str = "10y",
                            interval: str = "1d") -> Optional[pd.DataFrame]:
        """
        Download stock data for a single ticker

        Args:
            ticker: Stock symbol
            period: Time period (1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max)
            interval: Data interval (1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo)

        Returns:
            The price history with a ``Ticker`` column and the index reset to a
            regular column, or ``None`` when nothing was returned or an error
            occurred (the error is logged, not raised).
        """
        try:
            stock = yf.Ticker(ticker)
            data = stock.history(period=period, interval=interval)

            if data.empty:
                logger.warning(f"No data found for {ticker}")
                return None

            # Add ticker column
            data['Ticker'] = ticker
            data.reset_index(inplace=True)

            logger.info(f"Downloaded {len(data)} records for {ticker}")
            return data

        except Exception as e:
            logger.error(f"Error downloading data for {ticker}: {e}")
            return None

    def download_multiple_stocks(self,
                                 tickers: List[str],
                                 period: str = "10y",
                                 interval: str = "1d",
                                 max_workers: int = 10) -> pd.DataFrame:
        """
        Download stock data for multiple tickers using parallel processing

        Args:
            tickers: Stock symbols to download
            period: Time period passed to ``download_stock_data``
            interval: Data interval passed to ``download_stock_data``
            max_workers: Size of the thread pool

        Returns:
            All successful downloads concatenated, or an empty DataFrame when
            no ticker produced data.
        """
        all_data = []

        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all download tasks
            future_to_ticker = {
                executor.submit(self.download_stock_data, ticker, period, interval): ticker
                for ticker in tickers
            }

            # Collect results
            for future in as_completed(future_to_ticker):
                ticker = future_to_ticker[future]
                try:
                    data = future.result()
                    if data is not None:
                        all_data.append(data)
                except Exception as e:
                    logger.error(f"Error processing {ticker}: {e}")

                # Rate limiting
                # NOTE(review): all requests are already submitted at this point,
                # so this pause only slows result collection.
                time.sleep(0.1)

        if all_data:
            combined_data = pd.concat(all_data, ignore_index=True)
            logger.info(f"Combined data shape: {combined_data.shape}")
            return combined_data
        else:
            return pd.DataFrame()

    def calculate_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Calculate technical indicators for each stock

        Adds moving averages, MACD, RSI, Bollinger Bands, volatility, price
        change and volume features, computed independently per ticker.

        Returns:
            A new frame (grouped by ticker, sorted by date, index reset). An
            empty input is returned as a copy.
        """
        logger.info("Calculating technical indicators...")

        result_dfs = []

        for ticker_data in self._iter_ticker_frames(df):
            close = ticker_data['Close']

            # Moving averages
            for window in (5, 10, 20, 50):
                ticker_data[f'SMA_{window}'] = close.rolling(window=window).mean()

            # Exponential moving averages
            ticker_data['EMA_12'] = close.ewm(span=12).mean()
            ticker_data['EMA_26'] = close.ewm(span=26).mean()

            # MACD
            ticker_data['MACD'] = ticker_data['EMA_12'] - ticker_data['EMA_26']
            ticker_data['MACD_Signal'] = ticker_data['MACD'].ewm(span=9).mean()
            ticker_data['MACD_Histogram'] = ticker_data['MACD'] - ticker_data['MACD_Signal']

            # RSI (simple 14-period rolling means of gains and losses)
            delta = close.diff()
            gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
            rs = gain / loss
            ticker_data['RSI'] = 100 - (100 / (1 + rs))

            # Bollinger Bands
            ticker_data['BB_Middle'] = close.rolling(window=20).mean()
            bb_std = close.rolling(window=20).std()
            ticker_data['BB_Upper'] = ticker_data['BB_Middle'] + (bb_std * 2)
            ticker_data['BB_Lower'] = ticker_data['BB_Middle'] - (bb_std * 2)
            ticker_data['BB_Width'] = ticker_data['BB_Upper'] - ticker_data['BB_Lower']
            ticker_data['BB_Position'] = (close - ticker_data['BB_Lower']) / ticker_data['BB_Width']

            # Volatility
            ticker_data['Volatility'] = bb_std

            # Price change features
            ticker_data['Price_Change'] = close.pct_change()
            ticker_data['Price_Change_5d'] = close.pct_change(periods=5)
            ticker_data['High_Low_Ratio'] = ticker_data['High'] / ticker_data['Low']
            ticker_data['Open_Close_Ratio'] = ticker_data['Open'] / close

            # Volume features
            ticker_data['Volume_SMA'] = ticker_data['Volume'].rolling(window=20).mean()
            ticker_data['Volume_Ratio'] = ticker_data['Volume'] / ticker_data['Volume_SMA']

            result_dfs.append(ticker_data)

        if not result_dfs:
            # pd.concat([]) raises; an empty input simply has nothing to add.
            return df.copy()

        result = pd.concat(result_dfs, ignore_index=True)
        logger.info(f"Technical indicators calculated. New shape: {result.shape}")
        return result

    def create_lagged_features(self, df: pd.DataFrame, lags: Optional[List[int]] = None) -> pd.DataFrame:
        """
        Create lagged features for time series analysis

        Args:
            df: Frame with ``Ticker``/``Date`` and any of the columns Close,
                Volume, Price_Change, RSI, MACD, Volatility
            lags: Shift distances in rows; defaults to [1, 2, 3, 5, 10]

        Returns:
            A new frame with ``<column>_lag_<n>`` columns, shifted per ticker.
        """
        logger.info("Creating lagged features...")

        if lags is None:
            # Built per call so the default is never shared between calls.
            lags = [1, 2, 3, 5, 10]

        result_dfs = []
        feature_columns = ['Close', 'Volume', 'Price_Change', 'RSI', 'MACD', 'Volatility']

        for ticker_data in self._iter_ticker_frames(df):
            for col in feature_columns:
                if col in ticker_data.columns:
                    for lag in lags:
                        ticker_data[f'{col}_lag_{lag}'] = ticker_data[col].shift(lag)

            result_dfs.append(ticker_data)

        if not result_dfs:
            return df.copy()

        result = pd.concat(result_dfs, ignore_index=True)
        logger.info(f"Lagged features created. New shape: {result.shape}")
        return result

    def create_future_returns(self, df: pd.DataFrame, horizons: Optional[List[int]] = None) -> pd.DataFrame:
        """
        Create future return targets for prediction

        For each horizon ``h`` three columns are added per ticker:
        ``Future_Return_{h}d`` (Close[t+h] / Close[t] - 1),
        ``Future_Up_{h}d`` (1.0 if the return is positive, else 0.0) and
        ``Future_Category_{h}d`` (0-3 using the bins -2%, 0, +2%).
        All three are NaN for the trailing rows whose future close is unknown.

        Args:
            df: Frame with ``Ticker``, ``Date`` and ``Close``
            horizons: Look-ahead distances in rows; defaults to [1, 5, 10, 20]
        """
        logger.info("Creating future return targets...")

        if horizons is None:
            horizons = [1, 5, 10, 20]

        result_dfs = []

        for ticker_data in self._iter_ticker_frames(df):
            for horizon in horizons:
                returns = ticker_data['Close'].shift(-horizon) / ticker_data['Close'] - 1
                ticker_data[f'Future_Return_{horizon}d'] = returns

                # Create binary classification targets. Rows without a known
                # future stay NaN instead of being mislabelled as "down".
                ticker_data[f'Future_Up_{horizon}d'] = (returns > 0).astype(float).where(returns.notna())

                # Create categorical targets (strong down, down, up, strong up)
                ticker_data[f'Future_Category_{horizon}d'] = pd.cut(
                    returns,
                    bins=[-np.inf, -0.02, 0, 0.02, np.inf],
                    labels=[0, 1, 2, 3]
                ).astype(float)

            result_dfs.append(ticker_data)

        if not result_dfs:
            return df.copy()

        result = pd.concat(result_dfs, ignore_index=True)
        logger.info(f"Future return targets created. New shape: {result.shape}")
        return result

    def clean_and_normalize_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Clean the data for ML/RL

        Steps: drop rows with fewer than 70% non-null values, forward-fill the
        remaining numeric feature gaps *within each ticker*, turn +/-inf into
        NaN and drop every row that still contains a NaN.

        ``Future_*`` target columns are never forward-filled: filling them
        would invent targets from stale values, so rows without a known
        future are dropped instead.

        Returns:
            A new, cleaned frame; the input is not modified.
        """
        logger.info("Cleaning and normalizing data...")

        # Remove rows with too many NaN values
        min_non_null = int(np.ceil(len(df.columns) * 0.7))
        df = df.dropna(thresh=min_non_null).copy()

        # Forward fill remaining NaN values (features only, per ticker)
        numeric_columns = df.select_dtypes(include=[np.number]).columns
        fill_columns = [col for col in numeric_columns if not str(col).startswith('Future_')]
        if fill_columns:
            if 'Ticker' in df.columns:
                df[fill_columns] = df.groupby('Ticker', sort=False, dropna=False)[fill_columns].ffill()
            else:
                df[fill_columns] = df[fill_columns].ffill()

        # Remove infinite values
        df = df.replace([np.inf, -np.inf], np.nan)
        df = df.dropna()

        logger.info(f"Data cleaned. Final shape: {df.shape}")
        return df

    def create_rl_states_actions(self, df: pd.DataFrame, sequence_length: int = 60) -> Tuple[Dict, object]:
        """
        Create state and action spaces suitable for reinforcement learning

        Args:
            df: Cleaned frame with ``Ticker``, ``Date`` and feature columns
            sequence_length: Number of rows (days) in each look-back window

        Returns:
            ``(rl_data, scaler)`` where ``scaler`` is the fitted StandardScaler
            and ``rl_data[ticker]`` is a dict with:
              - 'states': array (n_windows, sequence_length, n_features) of
                standardized features for rows [i - sequence_length, i)
              - 'rewards': Future_Return_1d at row i (or the next-day return
                computed from unscaled closes when that column is absent)
              - 'dates': Date at row i
              - 'state_features': feature names, in state column order
              - 'prices': unscaled Close of the last row of each window
                (present only when ``df`` has a ``Close`` column)
        """
        logger.info("Creating RL state and action representations...")

        if sequence_length < 1:
            raise ValueError("sequence_length must be a positive integer")

        # Define state features (technical indicators and market data)
        state_features = [
            'Open', 'High', 'Low', 'Close', 'Volume',
            'SMA_5', 'SMA_10', 'SMA_20', 'SMA_50',
            'EMA_12', 'EMA_26', 'MACD', 'MACD_Signal', 'RSI',
            'BB_Position', 'BB_Width', 'Volatility',
            'Price_Change', 'High_Low_Ratio', 'Volume_Ratio'
        ]

        # Add lagged features to state
        state_features.extend(col for col in df.columns if '_lag_' in str(col))

        # Filter existing features
        state_features = [feat for feat in state_features if feat in df.columns]

        # Normalize state features
        scaler = StandardScaler()
        df_scaled = df.copy()
        df_scaled[state_features] = scaler.fit_transform(df[state_features])

        # Carry the unscaled close alongside the scaled features: standardized
        # prices can be zero or negative and are useless for trading maths.
        has_close = 'Close' in df.columns
        if has_close:
            df_scaled['_raw_close'] = df['Close'].to_numpy()

        # Define action space (0: Hold, 1: Buy, 2: Sell)
        # You can expand this based on your RL strategy

        # Create sequences for each stock
        rl_data = {}

        for ticker, ticker_data in df_scaled.groupby('Ticker', sort=False):
            ticker_data = ticker_data.sort_values('Date')
            n_rows = len(ticker_data)
            feature_matrix = ticker_data[state_features].to_numpy()

            # State: sequence of technical indicators
            if n_rows > sequence_length:
                states = np.stack([
                    feature_matrix[end - sequence_length:end]
                    for end in range(sequence_length, n_rows)
                ])
            else:
                states = np.empty((0, sequence_length, len(state_features)))

            raw_close = ticker_data['_raw_close'].to_numpy(dtype=float) if has_close else None

            # Reward: next day return (can be modified based on your RL objective)
            if 'Future_Return_1d' in ticker_data.columns:
                rewards = np.array(ticker_data['Future_Return_1d'].to_numpy()[sequence_length:])
            else:
                if raw_close is None:
                    raise KeyError('Close')
                next_day_return = np.zeros(n_rows)
                if n_rows > 1:
                    next_day_return[:-1] = raw_close[1:] / raw_close[:-1] - 1
                rewards = next_day_return[sequence_length:]

            entry = {
                'states': states,
                'rewards': rewards,
                'dates': list(ticker_data['Date'].iloc[sequence_length:]),
                'state_features': state_features
            }
            if raw_close is not None:
                # Close of the most recent row the agent has seen in each window.
                entry['prices'] = raw_close[sequence_length - 1:max(n_rows - 1, sequence_length - 1)]

            rl_data[ticker] = entry

        logger.info(f"RL data created for {len(rl_data)} stocks")
        return rl_data, scaler

    def save_processed_data(self, data: pd.DataFrame, rl_data: Dict, scaler, filename_prefix: str = "processed_stock_data"):
        """
        Save processed data to files

        Writes a CSV of ``data`` plus pickles of ``rl_data`` and ``scaler``
        into ``self.data_dir``; all three share one timestamp suffix.

        Returns:
            ``(csv_filename, rl_filename, scaler_filename)``
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Save CSV data
        csv_filename = f"{self.data_dir}/{filename_prefix}_{timestamp}.csv"
        data.to_csv(csv_filename, index=False)
        logger.info(f"CSV data saved to {csv_filename}")

        # Save RL data
        rl_filename = f"{self.data_dir}/{filename_prefix}_rl_{timestamp}.pkl"
        with open(rl_filename, 'wb') as f:
            pickle.dump(rl_data, f)
        logger.info(f"RL data saved to {rl_filename}")

        # Save scaler
        scaler_filename = f"{self.data_dir}/{filename_prefix}_scaler_{timestamp}.pkl"
        with open(scaler_filename, 'wb') as f:
            pickle.dump(scaler, f)
        logger.info(f"Scaler saved to {scaler_filename}")

        return csv_filename, rl_filename, scaler_filename

    def process_stocks_pipeline(self,
                                tickers: Optional[List[str]] = None,
                                period: str = "10y",
                                interval: str = "1d",
                                use_sp500: bool = True) -> Tuple[pd.DataFrame, Dict, object]:
        """
        Complete pipeline for processing stock data

        Args:
            tickers: Symbols to process; when None, the S&P 500 list
                (``use_sp500=True``) or a five-symbol default list is used
            period: Time period passed to the downloader
            interval: Data interval passed to the downloader
            use_sp500: Only consulted when ``tickers`` is None

        Returns:
            ``(cleaned_data, rl_data, scaler)``, or ``(None, None, None)`` when
            nothing could be downloaded.
        """
        logger.info("Starting stock data processing pipeline...")

        # Get tickers
        if tickers is None:
            if use_sp500:
                tickers = self.get_sp500_tickers()
            else:
                tickers = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA']  # Default list

        # Download data
        logger.info(f"Downloading data for {len(tickers)} tickers...")
        raw_data = self.download_multiple_stocks(tickers, period, interval)

        if raw_data.empty:
            logger.error("No data downloaded. Exiting.")
            return None, None, None

        # Process data
        data_with_indicators = self.calculate_technical_indicators(raw_data)
        data_with_lags = self.create_lagged_features(data_with_indicators)
        data_with_targets = self.create_future_returns(data_with_lags)
        cleaned_data = self.clean_and_normalize_data(data_with_targets)

        # Create RL data
        rl_data, scaler = self.create_rl_states_actions(cleaned_data)

        # Save data
        self.save_processed_data(cleaned_data, rl_data, scaler)

        logger.info("Pipeline completed successfully!")
        return cleaned_data, rl_data, scaler
390
+
391
+ # Example usage and utility functions
392
def example_usage():
    """
    Example of how to use the StockDataProcessor

    Runs the full pipeline for the S&P 500 universe (this downloads data over
    the network and writes output files) and prints a short summary.
    """
    # Initialize processor
    processor = StockDataProcessor()

    # Option 1: Process S&P 500 stocks
    print("Processing S&P 500 stocks...")
    data, rl_data, scaler = processor.process_stocks_pipeline(use_sp500=True, period="5y")

    # Option 2: Process specific stocks
    # custom_tickers = ['AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA', 'META', 'NVDA']
    # data, rl_data, scaler = processor.process_stocks_pipeline(tickers=custom_tickers, period="10y")

    # The pipeline returns (None, None, None) when nothing could be downloaded.
    if data is not None:
        print(f"Processed data shape: {data.shape}")
        print(f"Features: {data.columns.tolist()}")
        print(f"RL data available for {len(rl_data)} stocks")

        # Example: Access RL data for a specific stock
        if 'AAPL' in rl_data:
            aapl_states = rl_data['AAPL']['states']
            print(f"AAPL: {aapl_states.shape[0]} sequences, each with {aapl_states.shape[1]} timesteps and {aapl_states.shape[2]} features")
417
+
418
def load_processed_data(rl_filename: str, scaler_filename: str) -> Tuple[Dict, object]:
    """
    Load previously processed RL data

    Args:
        rl_filename: Path to the pickled RL data dictionary
        scaler_filename: Path to the pickled scaler

    Returns:
        ``(rl_data, scaler)`` exactly as they were pickled.

    SECURITY: ``pickle.load`` can execute arbitrary code. Only load files that
    this project wrote itself; never load pickles from an untrusted source.
    """
    with open(rl_filename, 'rb') as f:
        rl_data = pickle.load(f)

    with open(scaler_filename, 'rb') as f:
        scaler = pickle.load(f)

    return rl_data, scaler
429
+
430
# Run the example pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    example_usage()
enviromentcreator.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import gymnasium as gym
4
+ from gymnasium import spaces
5
+ from typing import Dict, Tuple, List, Optional
6
+ import logging
7
+ from dataclasses import dataclass
8
+ from enum import Enum
9
+ import matplotlib.pyplot as plt
10
+ import seaborn as sns
11
+ from collections import deque
12
+ import json
13
+
14
+ # Configure logging
15
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
17
+
18
class ActionType(Enum):
    """Discrete trade decisions; the values are the integer action codes."""
    HOLD = 0
    BUY = 1
    SELL = 2
22
+
23
@dataclass
class TradingMetrics:
    """Comprehensive trading metrics for evaluation"""
    total_return: float = 0.0
    sharpe_ratio: float = 0.0
    max_drawdown: float = 0.0
    win_rate: float = 0.0
    total_trades: int = 0
    profitable_trades: int = 0
    average_trade_return: float = 0.0
    volatility: float = 0.0
    calmar_ratio: float = 0.0
    sortino_ratio: float = 0.0
36
+
37
class EnhancedStockTradingEnvironment(gym.Env):
    """
    Enhanced stock trading environment with comprehensive metrics and logging

    One episode walks through the pre-built state windows of a single ticker.
    The agent holds cash plus whole shares and pays a proportional transaction
    cost on every trade.
    """

    def __init__(self,
                 rl_data: Dict,
                 ticker: str,
                 initial_balance: float = 10000,
                 transaction_cost: float = 0.001,  # 0.1% transaction cost
                 max_position_size: float = 1.0,  # Maximum position size as fraction of portfolio
                 lookback_window: int = 60,  # Number of days to look back
                 reward_type: str = "return",  # "return", "sharpe", "sortino"
                 enable_logging: bool = True):
        """
        Args:
            rl_data: Mapping ticker -> dict with at least 'states'
                (n_steps, window, n_features) and 'dates'; optional 'prices'
                (unscaled, one per state) and 'state_features'
            ticker: Key of ``rl_data`` to trade
            initial_balance: Starting cash
            transaction_cost: Proportional cost charged on each trade's value
            max_position_size: Cap on position value as a fraction of net worth
            lookback_window: Stored for reference; the window length actually
                used is the one baked into ``states``
            reward_type: "return", "sharpe" or "sortino" (anything else
                behaves like "return")
            enable_logging: Emit informational log messages

        Raises:
            KeyError: ``ticker`` is not in ``rl_data``
            ValueError: the ticker has no usable state windows
        """
        super().__init__()

        self.rl_data = rl_data
        self.ticker = ticker
        self.initial_balance = initial_balance
        self.transaction_cost = transaction_cost
        self.max_position_size = max_position_size
        self.lookback_window = lookback_window
        self.reward_type = reward_type
        self.enable_logging = enable_logging

        # Get data for the specific ticker
        self.stock_data = rl_data[ticker]
        self.states = np.asarray(self.stock_data['states'])
        if self.states.ndim != 3 or len(self.states) == 0:
            raise ValueError(
                f"No usable state windows for {ticker}: expected a non-empty "
                f"(steps, window, features) array, got shape {self.states.shape}"
            )
        self.prices = self._extract_prices()  # Extract actual prices
        self.dates = self.stock_data['dates']

        # Environment parameters
        self.current_step = 0
        self.max_steps = len(self.states) - 1

        # Portfolio state
        self.reset_portfolio()

        # Trading history
        self.trade_history = []
        self.portfolio_history = []
        self.action_history = []
        self.reward_history = []

        # Performance tracking
        self.daily_returns = deque(maxlen=252)  # 1 year of returns for Sharpe calculation
        self.drawdown_history = []
        self.peak_portfolio_value = initial_balance

        # Action space: [action_type (0-2), position_size (0-1)]; action_type
        # is rounded to the nearest of 0 = Hold, 1 = Buy, 2 = Sell.
        self.action_space = spaces.Box(
            low=np.array([0.0, 0.0], dtype=np.float32),
            high=np.array([2.0, 1.0], dtype=np.float32),
            dtype=np.float32
        )

        # Observation space: flattened market window + portfolio state
        market_state_size = self.states.shape[1] * self.states.shape[2]
        portfolio_state_size = 8  # Extended portfolio state

        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(market_state_size + portfolio_state_size,),
            dtype=np.float32
        )

        if self.enable_logging:
            logger.info(f"Environment initialized for {ticker}")
            logger.info(f"Data shape: {self.states.shape}")
            logger.info(f"Price range: ${self.prices.min():.2f} - ${self.prices.max():.2f}")

    def _extract_prices(self) -> np.ndarray:
        """
        Return one trading price per state window.

        Prefers the unscaled 'prices' array when the data provides one of the
        right length. Otherwise falls back to the Close feature of the last
        row of each window; those values come from the scaled state tensor,
        so they may be zero or negative and trades at such prices are skipped.
        """
        raw_prices = self.stock_data.get('prices')
        if raw_prices is not None and len(raw_prices) == len(self.states):
            return np.asarray(raw_prices, dtype=float)

        # Locate Close by name when possible; index 3 matches the
        # Open/High/Low/Close ordering used when no names are supplied.
        feature_names = list(self.stock_data.get('state_features', []))
        close_index = feature_names.index('Close') if 'Close' in feature_names else 3
        if self.enable_logging:
            logger.warning(
                "No raw 'prices' in RL data for %s; falling back to the scaled "
                "Close feature, which is not a real price.", self.ticker
            )
        return self.states[:, -1, close_index]

    def reset_portfolio(self):
        """Reset portfolio to initial state"""
        self.balance = self.initial_balance
        self.shares_held = 0
        self.net_worth = self.initial_balance
        self.max_net_worth = self.initial_balance
        self.position_value = 0
        self.total_transaction_costs = 0

    def reset(self, seed=None, options=None):
        """Start a new episode; returns ``(observation, info)``."""
        super().reset(seed=seed)

        self.current_step = 0
        self.reset_portfolio()

        # Clear histories
        self.trade_history.clear()
        self.portfolio_history.clear()
        self.action_history.clear()
        self.reward_history.clear()
        self.daily_returns.clear()
        self.drawdown_history.clear()
        self.peak_portfolio_value = self.initial_balance

        return self._get_observation(), {}

    def step(self, action):
        """
        Apply ``action`` = [action_type, position_size] and advance one step.

        Returns:
            ``(observation, reward, done, truncated, info)``; ``info`` holds
            the episode metrics once ``done`` is True, otherwise it is empty.
        """
        # Parse action. Round (not truncate) so each discrete action owns an
        # interval of the continuous range; with truncation SELL was only
        # reachable at exactly 2.0.
        action = np.asarray(action, dtype=float).reshape(-1)
        action_type = int(np.rint(np.clip(action[0], 0, 2)))
        position_size = float(np.clip(action[1], 0, 1))

        # Execute action
        reward = self._execute_action(action_type, position_size)

        # Update portfolio metrics
        self._update_portfolio_metrics()

        # Store history
        self._store_step_data(action_type, position_size, reward)

        # Move to next step
        self.current_step += 1

        # Check if episode is done
        done = self.current_step >= self.max_steps
        truncated = False

        # Calculate final metrics if done
        info = {}
        if done:
            info = self._calculate_episode_metrics()

        return self._get_observation(), reward, done, truncated, info

    def _execute_action(self, action_type: int, position_size: float) -> float:
        """Execute trading action and return reward"""
        current_price = self.prices[self.current_step]
        previous_net_worth = self.net_worth
        tradable = current_price > 0  # never trade at a zero/negative price

        if tradable and action_type == ActionType.BUY.value:
            # Calculate how much to buy: a fraction of what the cash balance
            # covers, capped by (a) what is affordable once the transaction
            # cost is included and (b) the maximum allowed position value.
            held_value = self.shares_held * current_price
            exposure_room = self.max_position_size * (self.balance + held_value) - held_value
            spendable = max(0.0, min(self.balance, exposure_room))
            unit_cost = current_price * (1 + self.transaction_cost)

            shares_to_buy = int(self.balance / current_price * position_size)
            shares_to_buy = min(shares_to_buy,
                                int(self.balance / unit_cost),
                                int(spendable / current_price))

            if shares_to_buy > 0:
                cost = shares_to_buy * current_price
                transaction_cost = cost * self.transaction_cost

                if self.balance >= cost + transaction_cost:
                    self.shares_held += shares_to_buy
                    self.balance -= (cost + transaction_cost)
                    self.total_transaction_costs += transaction_cost

                    self.trade_history.append({
                        'step': self.current_step,
                        'action': 'BUY',
                        'shares': shares_to_buy,
                        'price': current_price,
                        'cost': cost,
                        'transaction_cost': transaction_cost
                    })

        elif tradable and action_type == ActionType.SELL.value:
            # Calculate how much to sell
            shares_to_sell = int(self.shares_held * position_size)

            if shares_to_sell > 0:
                revenue = shares_to_sell * current_price
                transaction_cost = revenue * self.transaction_cost

                self.shares_held -= shares_to_sell
                self.balance += (revenue - transaction_cost)
                self.total_transaction_costs += transaction_cost

                self.trade_history.append({
                    'step': self.current_step,
                    'action': 'SELL',
                    'shares': shares_to_sell,
                    'price': current_price,
                    'revenue': revenue,
                    'transaction_cost': transaction_cost
                })

        # Calculate new net worth
        self.position_value = self.shares_held * current_price
        self.net_worth = self.balance + self.position_value

        # Calculate reward based on selected method
        reward = self._calculate_reward(previous_net_worth)

        return reward

    def _relative_change(self, previous_value: float) -> float:
        """Fractional change of net worth versus ``previous_value`` (0 if that is 0)."""
        if previous_value == 0:
            return 0.0
        return (self.net_worth - previous_value) / previous_value

    def _calculate_reward(self, previous_net_worth: float) -> float:
        """Calculate reward based on the selected reward type"""
        if self.reward_type == "sharpe":
            # Sharpe ratio-based reward
            if len(self.daily_returns) > 1:
                returns = np.array(self.daily_returns)
                if np.std(returns) > 0:
                    sharpe = np.mean(returns) / np.std(returns) * np.sqrt(252)
                    return sharpe / 100  # Scale down
            return 0

        if self.reward_type == "sortino":
            # Sortino ratio-based reward
            if len(self.daily_returns) > 1:
                returns = np.array(self.daily_returns)
                negative_returns = returns[returns < 0]
                if len(negative_returns) > 0 and np.std(negative_returns) > 0:
                    sortino = np.mean(returns) / np.std(negative_returns) * np.sqrt(252)
                    return sortino / 100  # Scale down
            return 0

        # "return" and any unrecognised type: simple return-based reward
        return self._relative_change(previous_net_worth)

    def _update_portfolio_metrics(self):
        """Update portfolio performance metrics"""
        # Calculate daily return
        if len(self.portfolio_history) > 0:
            self.daily_returns.append(self._relative_change(self.portfolio_history[-1]['net_worth']))

        # Update peak and drawdown
        if self.net_worth > self.peak_portfolio_value:
            self.peak_portfolio_value = self.net_worth

        current_drawdown = (self.peak_portfolio_value - self.net_worth) / self.peak_portfolio_value
        self.drawdown_history.append(current_drawdown)

    def _store_step_data(self, action_type: int, position_size: float, reward: float):
        """Store data for analysis"""
        self.action_history.append({
            'step': self.current_step,
            'action_type': action_type,
            'position_size': position_size
        })

        self.portfolio_history.append({
            'step': self.current_step,
            'balance': self.balance,
            'shares_held': self.shares_held,
            'position_value': self.position_value,
            'net_worth': self.net_worth,
            'price': self.prices[self.current_step]
        })

        self.reward_history.append(reward)

    def _calculate_episode_metrics(self) -> Dict:
        """
        Calculate comprehensive episode metrics

        Returns an empty dict when no step has been taken. Return-based
        statistics use ``daily_returns``, which keeps at most the last 252
        steps. ``win_rate`` is the share of steps with a positive reward.
        """
        if len(self.portfolio_history) == 0:
            return {}

        # Basic returns
        total_return = (self.net_worth - self.initial_balance) / self.initial_balance

        # Risk metrics
        returns = np.array(self.daily_returns) if self.daily_returns else np.array([0])
        max_drawdown = max(self.drawdown_history) if self.drawdown_history else 0
        volatility = np.std(returns) * np.sqrt(252)

        # Sharpe ratio
        sharpe_ratio = np.mean(returns) / np.std(returns) * np.sqrt(252) if np.std(returns) > 0 else 0

        # Sortino ratio
        negative_returns = returns[returns < 0]
        sortino_ratio = np.mean(returns) / np.std(negative_returns) * np.sqrt(252) if len(negative_returns) > 0 and np.std(negative_returns) > 0 else 0

        # Calmar ratio
        calmar_ratio = (np.mean(returns) * 252) / max_drawdown if max_drawdown > 0 else 0

        # Trading metrics
        total_trades = len(self.trade_history)
        buy_trades = [t for t in self.trade_history if t['action'] == 'BUY']
        sell_trades = [t for t in self.trade_history if t['action'] == 'SELL']

        # Win rate calculation (simplified)
        profitable_trades = len([r for r in self.reward_history if r > 0])
        win_rate = profitable_trades / len(self.reward_history) if len(self.reward_history) > 0 else 0

        metrics = {
            'total_return': total_return,
            'sharpe_ratio': sharpe_ratio,
            'sortino_ratio': sortino_ratio,
            'calmar_ratio': calmar_ratio,
            'max_drawdown': max_drawdown,
            'volatility': volatility,
            'win_rate': win_rate,
            'total_trades': total_trades,
            'buy_trades': len(buy_trades),
            'sell_trades': len(sell_trades),
            'final_balance': self.balance,
            'final_shares': self.shares_held,
            'final_net_worth': self.net_worth,
            'total_transaction_costs': self.total_transaction_costs,
            'average_reward': np.mean(self.reward_history) if self.reward_history else 0
        }

        if self.enable_logging:
            logger.info(f"Episode completed for {self.ticker}")
            logger.info(f"Total Return: {total_return:.2%}")
            logger.info(f"Sharpe Ratio: {sharpe_ratio:.2f}")
            logger.info(f"Max Drawdown: {max_drawdown:.2%}")
            logger.info(f"Win Rate: {win_rate:.2%}")

        return metrics

    def _get_observation(self):
        """Get current observation: flattened market window + 8 portfolio values."""
        if self.current_step >= len(self.states):
            # Return last available state if we're at the end
            market_state = self.states[-1].flatten()
        else:
            market_state = self.states[self.current_step].flatten()

        # Portfolio state (normalized)
        current_price = self.prices[min(self.current_step, len(self.prices) - 1)]

        portfolio_state = np.array([
            self.balance / self.initial_balance,  # Normalized balance
            self.shares_held * current_price / self.initial_balance,  # Normalized position value
            self.net_worth / self.initial_balance,  # Normalized net worth
            (self.net_worth - self.initial_balance) / self.initial_balance,  # Return
            len(self.trade_history) / 100,  # Number of trades (normalized)
            self.total_transaction_costs / self.initial_balance,  # Transaction costs
            max(self.drawdown_history) if self.drawdown_history else 0,  # Current max drawdown
            np.std(self.daily_returns) if len(self.daily_returns) > 1 else 0  # Volatility
        ])

        return np.concatenate([market_state, portfolio_state]).astype(np.float32)

    def render(self, mode='human'):
        """Render environment state"""
        current_price = self.prices[min(self.current_step, len(self.prices) - 1)]

        print(f"\n=== {self.ticker} Trading Environment ===")
        print(f"Step: {self.current_step}/{self.max_steps}")
        print(f"Current Price: ${current_price:.2f}")
        print(f"Balance: ${self.balance:.2f}")
        print(f"Shares Held: {self.shares_held}")
        print(f"Position Value: ${self.position_value:.2f}")
        print(f"Net Worth: ${self.net_worth:.2f}")
        print(f"Total Return: {((self.net_worth - self.initial_balance) / self.initial_balance):.2%}")
        print(f"Total Trades: {len(self.trade_history)}")
        print(f"Transaction Costs: ${self.total_transaction_costs:.2f}")

        if self.drawdown_history:
            print(f"Max Drawdown: {max(self.drawdown_history):.2%}")

        print("=" * 40)

    def plot_performance(self, save_path: Optional[str] = None):
        """
        Plot comprehensive performance metrics

        Draws portfolio value, stock price, drawdown and the action
        distribution; saves the figure to ``save_path`` when given, then
        shows it. Prints a message and returns if no step has been recorded.
        """
        if len(self.portfolio_history) == 0:
            print("No data to plot")
            return

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle(f'{self.ticker} Trading Performance', fontsize=16)

        # Portfolio value over time
        steps = [p['step'] for p in self.portfolio_history]
        net_worths = [p['net_worth'] for p in self.portfolio_history]
        prices = [p['price'] for p in self.portfolio_history]

        axes[0, 0].plot(steps, net_worths, label='Portfolio Value', linewidth=2)
        axes[0, 0].axhline(y=self.initial_balance, color='r', linestyle='--', label='Initial Balance')
        axes[0, 0].set_title('Portfolio Value Over Time')
        axes[0, 0].set_xlabel('Time Steps')
        axes[0, 0].set_ylabel('Portfolio Value ($)')
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        # Stock price over time
        axes[0, 1].plot(steps, prices, label='Stock Price', color='orange', linewidth=2)
        axes[0, 1].set_title('Stock Price Over Time')
        axes[0, 1].set_xlabel('Time Steps')
        axes[0, 1].set_ylabel('Price ($)')
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        # Drawdown
        if self.drawdown_history:
            axes[1, 0].fill_between(range(len(self.drawdown_history)),
                                    self.drawdown_history, 0,
                                    alpha=0.3, color='red')
            axes[1, 0].plot(self.drawdown_history, color='red', linewidth=2)
            axes[1, 0].set_title('Drawdown Over Time')
            axes[1, 0].set_xlabel('Time Steps')
            axes[1, 0].set_ylabel('Drawdown')
            axes[1, 0].grid(True)

        # Action distribution
        actions = [a['action_type'] for a in self.action_history]
        action_counts = [actions.count(i) for i in range(3)]
        action_labels = ['Hold', 'Buy', 'Sell']

        axes[1, 1].pie(action_counts, labels=action_labels, autopct='%1.1f%%')
        axes[1, 1].set_title('Action Distribution')

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Performance plot saved to {save_path}")

        plt.show()

    def get_metrics_summary(self) -> TradingMetrics:
        """
        Get trading metrics as a structured object

        NOTE: ``profitable_trades`` is an estimate (step-level win rate times
        the number of trades) and ``average_trade_return`` is the mean
        per-step reward; neither is derived from matched buy/sell pairs.
        """
        metrics_dict = self._calculate_episode_metrics()

        return TradingMetrics(
            total_return=metrics_dict.get('total_return', 0),
            sharpe_ratio=metrics_dict.get('sharpe_ratio', 0),
            max_drawdown=metrics_dict.get('max_drawdown', 0),
            win_rate=metrics_dict.get('win_rate', 0),
            total_trades=metrics_dict.get('total_trades', 0),
            profitable_trades=int(metrics_dict.get('win_rate', 0) * metrics_dict.get('total_trades', 0)),
            average_trade_return=metrics_dict.get('average_reward', 0),
            volatility=metrics_dict.get('volatility', 0),
            calmar_ratio=metrics_dict.get('calmar_ratio', 0),
            sortino_ratio=metrics_dict.get('sortino_ratio', 0)
        )