Как назначить поток FastAPI Uvicorn идентификатору сеанса?

Я интегрирую модель искусственного интеллекта в веб-приложение (эта модель должна иметь некоторый контекст, чтобы поддерживать плавный диалог с пользователем) при локальном развертывании. Проблема во внутренней структуре потока.

Я знаю, как работает поток пула. И проблема в том, что при выполнении нескольких POST-запросов (для общения с ботом) существует вероятность того, что на этот запрос ответит другой поток, который не использовался. Тогда контекст будет сохранен в памяти этого потока, а не в том, который мы использовали ранее.

Основная проблема: контекст сохраняется в разных потоках с разной памятью.

Во-первых, хочу отметить, что ни одно решение не должно заключаться ни в реализации файлов cookie, ни в создании файла для сохранения контекста. Идея состоит в том, чтобы назначить поток для каждого session_token.

Я пробовал следующее:

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from dotenv import load_dotenv
from openai import AzureOpenAI
import os
from typing import Dict, List
from uuid import UUID, uuid4

load_dotenv()
router = APIRouter()

class Message(BaseModel):
    """Class representing the message of a conversation"""
    role: str  # either 'user' or 'assistant'
    content: str

class Prompt(BaseModel):
    """Class for sending or receiving messages"""
    session_id: UUID
    prompt: str

class NewSessionResponse(BaseModel):
    session_id: UUID
    
class ResetRequest(BaseModel):
    session_id: UUID

# Dictionary to store conversations
conversations: Dict[UUID, List[Message]] = {}

@router.post("/ai/chat")
async def chat(prompt: Prompt):
    """In charge of executing and obtaining the connection with the model"""
    api_key = os.getenv("key")
    api_url = os.getenv("endpoint url")
    
    if not api_key or not api_url:
        raise HTTPException(status_code=500, detail='API key or API endpoint not found! Try again')
    
    client = AzureOpenAI("here goes some parameters")
    
    # Retrieve the conversation history for the session
    session_id = prompt.session_id
    if session_id not in conversations:
        conversations[session_id] = []
    
    # Add the user's prompt to the conversation history
    conversations[session_id].append(Message(role = "user", content=prompt.prompt))
    
    # Create the context for the API request
    context = [{'role': msg.role, 'content': msg.content} for msg in conversations[session_id]]
    
    # Request completion from the model
    response = client.chat.completions.create(
        model = "gpt-35-turbo-4k-0613",
        messages=context
    )
    
    # Extract the model's response
    model_response = response.choices[0].message.content
    
    # Add the assistant's response to the conversation history
    conversations[session_id].append(Message(role='assistant', content=model_response))
    print(conversations)
    return {"response": model_response}

@router.post("/ai/new_session", response_model=NewSessionResponse)
async def newSession():
    """Creates a new conversation session"""
    session_id = uuid4()
    conversations[session_id] = []
    return NewSessionResponse(session_id=session_id)

Я также пытаюсь реализовать многопоточность при управлении запросами, но это не сработало. Я предполагаю, что это происходит потому, что нити Uvicorn и темы, созданные здесь, не одно и то же.

Это тоже может помочь
Chris 07.08.2024 13:15
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
0
1
53
1
Перейти к ответу Данный вопрос помечен как решенный

Ответы 1

Ответ принят как подходящий

Я реализовал базу данных Redis для сохранения сообщений в зависимости от идентификатора сеанса, который получает конечная точка.

Обратите внимание, что мне пришлось изменить файл docker-compose.yaml для подключения к БД.

Вот у вас есть код:

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from dotenv import load_dotenv
from openai import AzureOpenAI
import os
from typing import List
from uuid import UUID, uuid4
import json
from app.core.redis_config import redis_db

load_dotenv()
router = APIRouter()

class Message(BaseModel):
    role: str
    content: str

class Prompt(BaseModel):
    session_id: UUID
    prompt: str
    ai_model: str

class NewSessionResponse(BaseModel):
    session_id: UUID

class ResetRequest(BaseModel):
    session_id: UUID

def get_conversation(session_id: UUID) -> List[Message]:
    """In charge of loading the conversation from the ddbb"""
    data = redis_db.get(str(session_id))
    if data:
        return [Message(**msg) for msg in json.loads(data)]
    return []

def save_conversation(session_id: UUID, messages: List[Message]):
    """In charge of inserting the conversation into the ddbb"""
    redis_db.set(str(session_id), json.dumps([msg.dict() for msg in messages])) # TODO: works for the moment with this although is deprecated

@router.post("/ai/chat")
async def chat(prompt: Prompt):
    api_key = os.getenv("OPENAI_API_KEY", "")
    api_url = os.getenv("OPENAI_API_BASE", "")
    
    ai_model = prompt.ai_model
    
    if not api_key or not api_url:
        raise HTTPException(status_code=500, detail='API key or API endpoint not found! Try again')
    
    client = AzureOpenAI("some parameters")
    
    session_id = prompt.session_id
    conversation = get_conversation(session_id)
    
    conversation.append(Message(role = "user", content=prompt.prompt))
    
    context = [{'role': msg.role, 'content': msg.content} for msg in conversation]
    
    response = client.chat.completions.create(
        model=ai_model,
        messages=context
    )
    
    model_response = response.choices[0].message.content
    
    conversation.append(Message(role='assistant', content=model_response))
    save_conversation(session_id, conversation)
    
    return {"response": model_response}

@router.post("/ai/reset")
async def resetConversation(reset_request: ResetRequest):
    session_id = reset_request.session_id
    if redis_db.exists(str(session_id)):
        redis_db.delete(str(session_id))
    else:
        raise HTTPException(status_code=400, detail='Invalid session ID')
    return {"response": True}

@router.post("/ai/new_session", response_model=NewSessionResponse)
async def newSession():
    session_id = uuid4()
    save_conversation(session_id, [])
    return NewSessionResponse(session_id=session_id)

Другие вопросы по теме