# Ai-Teacher/server.py
import os
import json
import logging
import requests
import base64
from quart import Quart, request, jsonify, send_from_directory
from quart_cors import cors
import openai
import asyncio
import aiohttp
import PyPDF2
import time
import re
from dotenv import load_dotenv
# Load environment variables from .env
load_dotenv()

if not os.path.exists('./logs'):
    os.makedirs('./logs')

logging.basicConfig(filename='./logs/server.log', level=logging.INFO)
logger = logging.getLogger(__name__)
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logger.addHandler(console)
# API credentials and model selection
#openai_api_key = os.getenv("OPENAI_API_KEY")
#openai_api_key = "sk-95ab48a1e0754ad39c13e2987f73fe37"
#openai_base_url = "https://api.deepseek.com"
openai_api_key = "sk-iVgiSZeNbLbTtp0lCvpIz2P0TpBGFLrcWdp5vDFtUFGfXCOs"
openai_base_url = "https://api.chatanywhere.tech"
llm_model = "gpt-4o-mini-2024-07-18"

# TTS API base URL
TTS_BASE_URL = "http://server.feng-arch.cn:52861"

if not openai_api_key:
    logger.warning("OpenAI API key not found. AI explanation will use fallback mode.")

# Load settings.
# Keys, server addresses, ports and similar configuration should all live in setting.json.
try:
    with open('setting.json', 'r') as f:
        settings = json.load(f)
    port = settings.get('websocket_port', 6006)
    TTS_BASE_URL = settings.get('TTS_BASE_URL', TTS_BASE_URL)
except Exception as e:
    logger.error(f"Error loading settings: {e}")
    port = 6006
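
# Illustrative setting.json. Only the two keys read above are known from this
# file; the values shown are placeholders, not the project's real configuration:
#
#   {
#       "websocket_port": 6006,
#       "TTS_BASE_URL": "http://server.feng-arch.cn:52861"
#   }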
app = Quart(__name__, static_url_path='')
app = cors(app)

# Path and page count of the currently loaded PDF, plus the running explanation history
current_pdf_path = None
pdfpages = None
chat_history = []
def extract_page_text(pdf_path, page_num):
    """Extract the text content of a specific page of a PDF document."""
    try:
        with open(pdf_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            # Check that the page number is valid
            if page_num < 1 or page_num > len(reader.pages):
                return {
                    "success": False,
                    "error": f"无效的页码: {page_num}，PDF共有 {len(reader.pages)} 页"
                }
            # Extract the text of the requested page
            page = reader.pages[page_num - 1]  # Page numbers start at 1, indices at 0
            page_text = page.extract_text()
            return {
                "success": True,
                "page_text": page_text,
                "page_count": len(reader.pages)
            }
    except Exception as e:
        logger.error(f"Error extracting PDF text: {e}")
        return {
            "success": False,
            "error": str(e)
        }
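
# Illustrative usage (the path below is just the project's default sample PDF):
#   info = extract_page_text('./public/pdf/test.pdf', 1)
#   if info["success"]:
#       print(info["page_count"], info["page_text"][:80])
#   else:
#       print(info["error"])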
def generate_explanation(page_num, page_text):
    """Generate the spoken explanation for a single page."""
    if not openai_api_key:
        return "这是一个示例讲解。请设置OpenAI API密钥以获取真实的AI讲解内容。"
    try:
        start_time = time.time()
        client = openai.OpenAI(api_key=openai_api_key, base_url=openai_base_url)
        response = client.chat.completions.create(
            model=llm_model,
            messages=[
                {"role": "system", "content": f"你是一位幽默的教师，正在为学生讲解PDF文档内容。请提供清晰、简洁的解释，重点突出关键概念。这是你的讲解历史：\n{chat_history}, 你需要与历史保持连贯。"},
                {"role": "user", "content": f"请讲解第{page_num}页(总页数{pdfpages})ppt的内容：{page_text}。首先判断是否要详细或者简略，比如标题页只需要简略，示例稍微展开，记住ppt不宜讲得太长，不超过100字。你的输出应符合老师的风格，句子间连贯、幽默风趣。"}
            ]
        )
        logger.info(f"生成讲解耗时: {time.time()-start_time}")
        explanation = response.choices[0].message.content.strip()
        chat_history.append({"page": page_num, "content": explanation})
        return explanation
    except Exception as e:
        logger.error(f"Error generating explanation: {e}")
        return f"生成讲解时出错: {str(e)}"
def split_text_to_sentences(text):
    """Split text into sentences."""
    # Split on Chinese and English sentence-ending punctuation
    sentence_endings = r'(?<=[。！？.!?])\s*'
    sentences = re.split(sentence_endings, text)
    # Drop empty sentences
    sentences = [s.strip() for s in sentences if s.strip()]
    return sentences
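
# Illustrative behaviour (an assumed example string, not taken from the project):
#   split_text_to_sentences("今天讲第一页。这是标题页！OK, let's go.")
#   -> ["今天讲第一页。", "这是标题页！", "OK, let's go."]
# Note that the ASCII "." in the character class also splits on abbreviations
# such as "e.g.", so TTS chunks may occasionally break mid-sentence.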
async def fetch_audio_data(session, pair_data, voice, speed):
    """Generate audio for one group of sentences via the TTS service."""
    url = f"{TTS_BASE_URL}/tts"
    payload = {
        "text": pair_data["text"],
        "voice": voice,
        "speed": speed,
        "return_type": "base64"
    }
    # Retry up to three times on transport errors
    for i in range(3):
        try:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    logger.error(f"TTS API error: {response.status} - {await response.text()}")
                    return None
                data = await response.json()
                audio_base64 = data.get("audio_base64")
                if not audio_base64:
                    logger.error("No audio data returned for pair")
                    return None
                return {
                    "audio_base64": audio_base64,
                    "sentences": pair_data["sentences"],
                    "indices": pair_data["indices"]
                }
        except Exception as e:
            logger.error(f"Error fetching audio: {e}")
            if i == 2:
                return None
            await asyncio.sleep(0.5)
            continue
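
# Request/response shape assumed by fetch_audio_data, inferred from the code
# above (the TTS service behind TTS_BASE_URL is external and its API is not
# documented in this file):
#
#   POST {TTS_BASE_URL}/tts
#   {"text": "...", "voice": "af_heart", "speed": 1.0, "return_type": "base64"}
#
#   200 OK
#   {"audio_base64": "<base64-encoded audio>"}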
async def text_to_speech(text, voice="af_heart", speed=1.5):
    """Asynchronously convert text to speech, returning base64 audio for each group of one or two sentences."""
    try:
        start_time = time.time()
        # Split the text into sentences
        sentences = split_text_to_sentences(text)
        if not sentences:
            return {
                "success": False,
                "error": "无法分割文本为句子"
            }
        # Group sentences in pairs when their combined length stays under 60
        # characters; otherwise keep a sentence on its own
        sentence_pairs = []
        i = 0
        while i < len(sentences):
            if i + 1 < len(sentences) and len(sentences[i]) + len(sentences[i+1]) < 60:
                sentence_pairs.append({
                    "text": sentences[i] + " " + sentences[i+1],
                    "sentences": [sentences[i], sentences[i+1]],
                    "indices": [i, i+1]
                })
                i += 2
            else:
                sentence_pairs.append({
                    "text": sentences[i],
                    "sentences": [sentences[i]],
                    "indices": [i]
                })
                i += 1
        # Strip characters that are not Chinese, alphanumeric or whitespace
        # (removes emoji and punctuation before sending text to the TTS service)
        for pair in sentence_pairs:
            pair["text"] = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s]', '', pair["text"])
        # Create an async HTTP session and run all TTS requests concurrently
        async with aiohttp.ClientSession() as session:
            tasks = [
                fetch_audio_data(session, pair_data, voice, speed)
                for pair_data in sentence_pairs
            ]
            audio_segments = await asyncio.gather(*tasks)
        # Drop segments that failed
        audio_segments = [seg for seg in audio_segments if seg]
        logger.info(f"生成语音耗时: {time.time()-start_time}")
        return {
            "success": True,
            "audio_segments": audio_segments,
            "sentences": sentences
        }
    except Exception as e:
        logger.error(f"Error in text_to_speech: {e}")
        return {
            "success": False,
            "error": str(e)
        }
@app.route('/')
async def index():
    return await send_from_directory('', 'index.html')

@app.route('/<path:path>')
async def serve_static(path):
    return await send_from_directory('', path)
@app.route('/api/explain', methods=['POST'])
async def explain():
    data = await request.json
    text = data.get('text', '')
    page_num = data.get('page', None)
    # If a page number was given without text, try extracting the text from the PDF
    if page_num and not text and current_pdf_path:
        result = extract_page_text(current_pdf_path, page_num)
        if result["success"]:
            text = result["page_text"]
        else:
            return jsonify({
                'success': False,
                'explanation': f"无法提取页面文本: {result['error']}"
            })
    explanation = generate_explanation(page_num, text)
    return jsonify({
        'success': True,
        'explanation': explanation
    })
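
# Illustrative request (field names come from the handler above; the values are
# made up for the example):
#   POST /api/explain
#   {"page": 3, "text": ""}
# Response: {"success": true, "explanation": "..."}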
@app.route('/api/tts', methods=['POST'])
async def tts():
    data = await request.json
    text = data.get('text', '')
    voice = data.get('voice', 'af_heart')
    speed = data.get('speed', 1.0)
    if not text:
        return jsonify({
            'success': False,
            'error': '文本不能为空'
        })
    # Convert the text to speech
    result = await text_to_speech(text, voice, speed)
    if result["success"]:
        return jsonify({
            'success': True,
            'audio_segments': result["audio_segments"],
            'sentences': result.get("sentences", [])
        })
    else:
        return jsonify({
            'success': False,
            'error': result["error"]
        })
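
# Illustrative request (the values are made up for the example):
#   POST /api/tts
#   {"text": "今天我们讲第一页。", "voice": "af_heart", "speed": 1.0}
# On success the response carries audio_segments, each with audio_base64,
# sentences and indices, as produced by text_to_speech().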
# Cache of pre-generated explanations, keyed by page number; "is_caching_flag"
# tracks pages whose generation is already in flight.
cache_explanation = {"is_caching_flag": []}

# Runs as a background task so that explanations for the current, next and
# previous pages are already cached when they are requested.
async def generate_cache_explanation(page_num, voice, speed):
    global cache_explanation
    global pdfpages
    global current_pdf_path
    # Cache the requested page first, then its neighbours
    for p in (page_num, page_num + 1, page_num - 1):
        if (p not in cache_explanation and 0 < p <= pdfpages
                and p not in cache_explanation["is_caching_flag"]):
            cache_explanation["is_caching_flag"].append(p)
            text = extract_page_text(current_pdf_path, p)["page_text"]
            result = [generate_explanation(p, text)]
            result.append(await text_to_speech(result[0], voice, speed))
            cache_explanation[p] = result
            logger.info(f"已缓存讲解: {p}")
@app.route('/api/explain_with_audio', methods=['POST'])
async def explain_with_audio():
    global cache_explanation
    data = await request.json
    text = data.get('text', '')
    page_num = data.get('page', None)
    voice = data.get('voice', 'af_heart')
    speed = data.get('speed', 1.0)
    # Kick off a background task that pre-caches explanations for the
    # neighbouring pages (only when a page number was supplied)
    if page_num:
        asyncio.create_task(generate_cache_explanation(page_num, voice, speed))
    # If a cached explanation already exists, return it immediately
    if page_num in cache_explanation:
        explanation = cache_explanation[page_num][0]
        audio_segments = cache_explanation[page_num][1]["audio_segments"]
        logger.info(f"已找到缓存讲解: {page_num}")
        return jsonify({
            'success': True,
            'explanation': explanation,
            'audio_segments': audio_segments,
            'sentences': cache_explanation[page_num][1].get("sentences", [])
        })
    logger.info(f"未找到缓存讲解: {page_num}")
    # If a page number was given without text, try extracting the text from the PDF
    if page_num and not text and current_pdf_path:
        result = extract_page_text(current_pdf_path, page_num)
        if result["success"]:
            text = result["page_text"]
        else:
            return jsonify({
                'success': False,
                'explanation': f"无法提取页面文本: {result['error']}",
                'error': result["error"]
            })
    # Generate the explanation
    explanation = generate_explanation(page_num, text)
    # Convert the explanation to speech
    tts_result = await text_to_speech(explanation, voice, speed)
    if tts_result["success"]:
        return jsonify({
            'success': True,
            'explanation': explanation,
            'audio_segments': tts_result["audio_segments"],
            'sentences': tts_result.get("sentences", [])
        })
    else:
        return jsonify({
            'success': True,
            'explanation': explanation,
            'audio_segments': None,
            'tts_error': tts_result["error"]
        })
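
# Illustrative request/response (field names come from the handler above; the
# values are made up for the example):
#   POST /api/explain_with_audio
#   {"page": 3, "voice": "af_heart", "speed": 1.0}
# Response on success:
#   {"success": true, "explanation": "...",
#    "audio_segments": [{"audio_base64": "...", "sentences": [...], "indices": [0, 1]}, ...],
#    "sentences": [...]}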
@app.route('/api/load_pdf', methods=['POST'])
async def load_pdf():
    global current_pdf_path
    global cache_explanation
    global pdfpages
    global chat_history
    # Reset the explanation history for the new document
    # cache_explanation = {"is_caching_flag":[]}
    chat_history = []
    data = await request.json
    logger.info(f"加载PDF: {data}")
    pdf_path = data.get('path', './public/pdf/test.pdf')
    # Clear the cache only when a different PDF is loaded
    if pdf_path != current_pdf_path:
        cache_explanation = {"is_caching_flag": []}
    try:
        # Check that the PDF file exists
        if not os.path.exists(pdf_path):
            return jsonify({
                'success': False,
                'message': f'PDF文件不存在: {pdf_path}'
            })
        # Try opening the PDF to verify that it is valid
        with open(pdf_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            page_count = len(reader.pages)
        # Update the current PDF path and page count
        current_pdf_path = pdf_path
        pdfpages = page_count
        # Pre-generate explanations in the background with the default voice and speed
        voice = 'af_heart'
        speed = 1.0
        start_time = time.time()
        asyncio.create_task(generate_cache_explanation(0, voice, speed))
        logger.info(f"预加载讲解耗时: {time.time()-start_time}")
        return jsonify({
            'success': True,
            'message': '已成功加载PDF',
            'page_count': page_count
        })
    except Exception as e:
        logger.error(f"Error loading PDF: {e}")
        return jsonify({
            'success': False,
            'message': f'加载PDF失败: {str(e)}'
        })
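
# Illustrative request (the default path comes from the handler above; the
# page_count in the response is a made-up value):
#   POST /api/load_pdf
#   {"path": "./public/pdf/test.pdf"}
# Response: {"success": true, "message": "已成功加载PDF", "page_count": 12}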
@app.route('/api/voices', methods=['GET'])
def get_voices():
    """Return the list of available TTS voices."""
    voices = [
        # {"id": "zf_xiaoxiao", "name": "小小", "gender": "female", "lang": "zh"},
        # {"id": "zf_xiaoni", "name": "小妮", "gender": "female", "lang": "zh"},
        # {"id": "zf_xiaoyi", "name": "小怡", "gender": "female", "lang": "zh"},
        # {"id": "zf_xiaobei", "name": "小贝", "gender": "female", "lang": "zh"},
        # {"id": "zm_yunxi", "name": "云熙", "gender": "male", "lang": "zh"},
        # {"id": "zm_yunyang", "name": "云扬", "gender": "male", "lang": "zh"},
        # {"id": "zm_yunxia", "name": "云夏", "gender": "male", "lang": "zh"},
        # {"id": "zm_yunjian", "name": "云健", "gender": "male", "lang": "zh"},
        {"id": "af_heart", "name": "Heart", "gender": "female", "lang": "en"},
        # {"id": "af_bella", "name": "Bella", "gender": "female", "lang": "en"},
        # {"id": "am_michael", "name": "Michael", "gender": "male", "lang": "en"},
        # {"id": "am_puck", "name": "Puck", "gender": "male", "lang": "en"}
    ]
    return jsonify({
        'success': True,
        'voices': voices
    })
if __name__ == '__main__':
    # Set the default PDF path
    default_pdf_path = './public/pdf/test.pdf'
    if os.path.exists(default_pdf_path):
        current_pdf_path = default_pdf_path
        logger.info(f"默认PDF已设置: {default_pdf_path}")
    app.run(host='0.0.0.0', port=port, debug=True)