运行代码,要求输入股票代码和名称,其他参数可省略
import akshare as ak
import matplotlib.pyplot as plt
import pandas as pd
import mplfinance as mpf
import matplotlib.dates as mdates
import numpy as np
import os
from datetime import datetime, timedelta
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False def get_stock_data(stock_code, start_date=None, end_date=None, adjust="qfq"):"""使用 AkShare 获取股票数据参数:stock_code (str): 股票代码,如 'sh000001' 或 '000001.SZ'start_date (str): 开始日期,格式 'YYYYMMDD',默认为 3 个月前end_date (str): 结束日期,格式 'YYYYMMDD',默认为今天adjust (str): 复权类型,'qfq' 为前复权,'hfq' 为后复权,None 为不复权返回:DataFrame: 包含股票数据的 DataFrame"""if end_date is None:end_date = datetime.now().strftime('%Y%m%d')if start_date is None:start_date = (datetime.now() - timedelta(days=90)).strftime('%Y%m%d')if not stock_code.startswith(('sh', 'sz')):market = 'sh' if stock_code.startswith(('6', '9')) else 'sz'stock_code = f'{market}{stock_code}'try:stock_data = ak.stock_zh_a_hist_tx(symbol=stock_code, start_date=start_date, end_date=end_date, adjust=adjust)volume_columns = ['成交量', '成交额', 'volume','amount']volume_col = next((col for col in volume_columns if col in stock_data.columns), None)if volume_col is None:print("警告: 数据中未找到成交量列,图表可能不完整")stock_data['volume'] = 0 else:stock_data = stock_data.rename(columns={volume_col: 'volume'})stock_data = stock_data.rename(columns={'日期': 'date', '开盘': 'open', '收盘': 'close', '最高': 'high', '最低': 'low'})stock_data['date'] = pd.to_datetime(stock_data['date'])stock_data = stock_data.sort_values('date')required_columns = ['date', 'open', 'high', 'low', 'close', 'volume']missing_columns = [col for col in required_columns if col not in stock_data.columns]if missing_columns:print(f"错误: 数据缺少必要的列: {', '.join(missing_columns)}")return Nonereturn stock_dataexcept Exception as e:print(f"获取股票数据时出错: {e}")return Nonedef plot_stock_daily_movement(stock_data, stock_name="股票", save_path=None, ma_periods=None):"""绘制股票日变动情况,包括 K 线图和成交量及均线参数:stock_data (DataFrame): 包含股票数据的 DataFrame,必须包含 'date', 'open', 'high', 'low', 'close', 'volume' 列stock_name (str): 股票名称,用于图表标题save_path (str): 图表保存路径,默认为 None 不保存ma_periods (list): 均线周期列表,默认为 [5, 10, 20]"""if stock_data is None or stock_data.empty:print("没有数据可绘制图表")returnif ma_periods is None:ma_periods = [5, 10, 20]plot_data = stock_data.copy()plot_data = plot_data.set_index('date')for period in ma_periods:column_name = f'MA{period}'plot_data[column_name] = plot_data['close'].rolling(window=period).mean()mc = mpf.make_marketcolors(up='r', down='g', inherit=True)ma_colors = ['blue', 'purple', 'orange', 'green', 'red', 'brown', 'gray']ma_colors = ma_colors[:len(ma_periods)] s = mpf.make_mpf_style(marketcolors=mc, gridstyle='--', y_on_right=False,rc={'font.family': ['SimHei', 'WenQuanYi Micro Hei', 'Heiti TC', 'Microsoft YaHei'],'lines.linewidth': 1.5 })fig, axes = mpf.plot(plot_data,type='candle',style=s,title=f'{stock_name} 日变动情况',ylabel='价格 (元)',volume=True,ylabel_lower='成交量 (手)',mav=tuple(ma_periods), returnfig=True,figsize=(14, 10),update_width_config=dict(candle_linewidth=1.2, candle_width=0.6, volume_width=0.8, ))fig.suptitle(f'{stock_name} 日变动情况', fontsize=16, y=0.98)ax = axes[0] for i, period in enumerate(ma_periods):ax.plot([], [], color=ma_colors[i], label=f'MA{period}', linewidth=1.5)ax.legend(loc='upper left')if save_path:save_dir = os.path.dirname(save_path)if not os.path.exists(save_dir):os.makedirs(save_dir)plt.savefig(save_path, dpi=300, bbox_inches='tight')print(f"图表已保存至: {save_path}")plt.show()
if __name__ == "__main__":stock_code = input("请输入股票代码(例如 000001 或 sh000001): ").strip()stock_name = input("请输入股票名称(例如 平安银行): ").strip()print("正在获取股票数据...")stock_data = get_stock_data(stock_code)if stock_data is not None and not stock_data.empty:print(f"成功获取 {len(stock_data)} 天的股票数据")print(f"数据列名: {', '.join(stock_data.columns)}")ma_input = input("请输入均线周期(用逗号分隔,默认 5,10,20): ").strip()if ma_input:try:ma_periods = [int(p.strip()) for p in ma_input.split(',')]except ValueError:print("无效的均线周期,使用默认值")ma_periods = [5, 10, 20]else:ma_periods = [5, 10, 20]plot_stock_daily_movement(stock_data, stock_name, ma_periods=ma_periods)else:print("未能获取股票数据,请检查股票代码是否正确")
