# app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import json
from fastapi import FastAPI, Request, HTTPException, Security, Depends
from fastapi.security import APIKeyHeader
from fastapi.responses import JSONResponse
import logging
import time
import os
from typing import Optional
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global state
model = None
tokenizer = None
device = "cpu"

# Security configuration
TEST_MODE: bool = os.getenv("TEST_MODE", "false").lower() == "true"
API_KEYS = os.getenv("API_KEYS", "your-secret-key-1,your-secret-key-2").split(",")

# API key header authentication
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
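# A minimal sketch of how these settings might be supplied before launching the
# service; the key values below are placeholders, not real credentials:
#
#   export TEST_MODE=false
#   export API_KEYS="key-one,key-two"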
def load_model():
    """Load the model - uses the Qwen3-4B model."""
    global model, tokenizer, device
    if model is not None:
        return True
    try:
        # Use the Qwen3-4B model
        model_name = "Qwen/Qwen3-4B"
        logger.info(f"Loading model: {model_name}")
        # Check whether a GPU is available
        if torch.cuda.is_available():
            device = "cuda"
            logger.info("GPU detected, using GPU acceleration")
            # Configure quantization to save GPU memory
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16
            )
        else:
            logger.info("No GPU detected, using CPU")
            bnb_config = None
        # Load the tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        # Make sure the tokenizer has a pad_token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Model loading configuration
        model_kwargs = {
            "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
            "trust_remote_code": True,
            "device_map": "auto" if device == "cuda" else None,
        }
        # When running on GPU, add the quantization config
        if device == "cuda" and bnb_config:
            model_kwargs["quantization_config"] = bnb_config
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            **model_kwargs
        )
        # On CPU, move the model manually
        if device == "cpu":
            model = model.to(device)
        logger.info(f"Model {model_name} loaded successfully!")
        return True
    except Exception as e:
        logger.error(f"Failed to load the Qwen3-4B model: {e}")
        # If the 4B model fails, fall back to the smaller 1.5B model
        logger.info("Trying to load the Qwen2.5-1.5B model instead...")
        try:
            model_name = "Qwen/Qwen2.5-1.5B-Instruct"
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True
            )
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model_kwargs = {
                "torch_dtype": torch.float16 if device == "cuda" else torch.float32,
                "trust_remote_code": True,
                "device_map": "auto" if device == "cuda" else None,
            }
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                **model_kwargs
            )
            if device == "cpu":
                model = model.to(device)
            logger.info(f"Fallback model {model_name} loaded successfully!")
            return True
        except Exception as e2:
            logger.error(f"The fallback model also failed to load: {e2}")
            return False
def verify_api_key(
    request_key_header: Optional[str] = Security(api_key_header) if not TEST_MODE else None,
) -> str:
    """Dependency that validates the API key."""
    logger.info(f"Current security mode: {'test mode' if TEST_MODE else 'production mode'}")
    if TEST_MODE:
        logger.info("Skipping API key verification in test mode")
        return "test_mode_bypass"
    if request_key_header is None:
        logger.warning("API key missing from request headers")
        raise HTTPException(
            status_code=401,
            detail="Missing API key; please send it in the X-API-Key request header"
        )
    if request_key_header not in API_KEYS:
        logger.warning(f"Invalid API key attempt: {request_key_header}")
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )
    logger.info("API key verified")
    return request_key_header
def generate_response(message, max_tokens=512, temperature=0.7):
    """Generate a model response."""
    if not load_model():
        return "Model failed to load, please try again later"
    try:
        # Build the conversation format
        messages = [
            {"role": "user", "content": message}
        ]
        # Apply the Qwen3 chat template
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        # Encode the input
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048
        ).to(device)
        # Generate the reply
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                eos_token_id=tokenizer.eos_token_id
            )
        # Decode only the newly generated portion of the output
        response = tokenizer.decode(
            outputs[0][inputs.input_ids.shape[-1]:],
            skip_special_tokens=True
        )
        return response.strip()
    except Exception as e:
        logger.error(f"Error while generating a reply: {str(e)}")
        return f"Error while generating a reply: {str(e)}"
# Create the FastAPI application
app = FastAPI(title="Qwen3-4B LLM API Service", description="API service backed by the Qwen3-4B large language model")

# API health-check endpoints
@app.get("/")
async def root():
    return {
        "message": "Qwen3-4B LLM API service is running",
        "timestamp": int(time.time()),
        "model": "Qwen3-4B",
        "device": device
    }

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": device,
        "gpu_available": torch.cuda.is_available()
    }
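# Quick way to confirm the service is up, assuming the default port 7860 set at
# the bottom of this file:
#   curl http://localhost:7860/health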
# Protected chat API endpoint
@app.post("/api/chat")
async def chat_api(
    request: Request,
    api_key: str = Depends(verify_api_key)
):
    """OpenAI-compatible chat API endpoint."""
    try:
        data = await request.json()
        messages = data.get("messages", [])
        model_name = data.get("model", "Qwen3-4B")
        max_tokens = data.get("max_tokens", 512)
        temperature = data.get("temperature", 0.7)
        # Extract the user message
        user_message = ""
        for msg in messages:
            if msg["role"] == "user":
                user_message = msg["content"]
                break
        if not user_message:
            return JSONResponse({
                "error": "No user message found",
                "choices": []
            }, status_code=400)
        response_text = generate_response(user_message, max_tokens, temperature)
        if not response_text:
            response_text = "Sorry, I could not generate a suitable reply."
        return JSONResponse({
            "id": "chatcmpl-" + str(int(time.time())),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response_text
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": len(tokenizer.encode(user_message)) if tokenizer else 0,
                "completion_tokens": len(tokenizer.encode(response_text)) if tokenizer else 0,
                "total_tokens": (len(tokenizer.encode(user_message)) + len(tokenizer.encode(response_text))) if tokenizer else 0
            }
        })
    except Exception as e:
        logger.error(f"API call error: {str(e)}")
        return JSONResponse({
            "error": f"API call error: {str(e)}",
            "choices": []
        }, status_code=500)
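# A usage sketch for the endpoint above. The key is one of the placeholder
# defaults from API_KEYS, and the URL assumes the port configured at the bottom
# of this file:
#
#   curl -X POST http://localhost:7860/api/chat \
#     -H "Content-Type: application/json" \
#     -H "X-API-Key: your-secret-key-1" \
#     -d '{"model": "Qwen3-4B", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 256}'
#
# The response mirrors the OpenAI chat.completion shape built above, so
# OpenAI-style clients should be able to parse it.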
# Build the Gradio interface
with gr.Blocks(title="Qwen3-4B LLM API Service", theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # Qwen3-4B LLM API Service
    *Powered by the high-performance Qwen3-4B large language model*
    ## Security status: {'Test mode (authentication disabled)' if TEST_MODE else 'Production mode (authentication enabled)'}
    ## Running on: {device.upper()}
    ## GPU availability: {'Available' if torch.cuda.is_available() else 'Not available'}
    ## API endpoints
    - **Chat endpoint**: `/api/chat` (requires API key authentication)
    - **Health check**: `/health` (public)
    - **Model name**: `Qwen3-4B`
    ## Models in use
    - Primary model: Qwen3-4B
    - Fallback model: Qwen2.5-1.5B-Instruct
    """)
    with gr.Row():
        with gr.Column(scale=2):
            message_input = gr.Textbox(
                label="Input message",
                placeholder="Type your question here...",
                lines=3
            )
            with gr.Row():
                submit_button = gr.Button("Send", variant="primary")
                clear_button = gr.Button("Clear")
            with gr.Accordion("Advanced settings", open=False):
                max_tokens = gr.Slider(
                    minimum=64, maximum=1024, value=512,
                    label="Maximum generation length"
                )
                temperature = gr.Slider(
                    minimum=0.1, maximum=1.0, value=0.7,
                    label="Temperature (higher values are more creative)"
                )
        with gr.Column(scale=3):
            output_area = gr.Textbox(
                label="Model response",
                lines=10,
                interactive=False
            )

    def respond(message, max_tokens, temperature):
        if not message.strip():
            return ""
        response = generate_response(message, max_tokens, temperature)
        return response

    submit_button.click(
        respond,
        inputs=[message_input, max_tokens, temperature],
        outputs=output_area
    )
    message_input.submit(
        respond,
        inputs=[message_input, max_tokens, temperature],
        outputs=output_area
    )
    clear_button.click(lambda: "", inputs=[], outputs=output_area)
# Mount the Gradio app onto FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

# Preload the model
try:
    load_model()
except Exception as e:
    logger.error(f"Failed to preload the model: {e}")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
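# The service can also be started with the uvicorn CLI instead of `python app.py`,
# assuming the file keeps its app.py name:
#   uvicorn app:app --host 0.0.0.0 --port 7860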