Как выполнить сериализацию пидантической модели с полиморфизмом?

Я попытался сериализовать модель pydantic с атрибутом, который может принадлежать к классу нескольких подклассов базового класса. Однако при простой реализации подклассы сериализуются в базовый класс.

Прочитав этот выпуск, я написал следующий код, но безуспешно:

from typing import Dict, Literal, Union

from pydantic import BaseModel, Field, RootModel


class NodeBase(BaseModel):
    id: str


class StartNode(NodeBase):
    type: Literal["start"] = "start"


class EndNode(NodeBase):
    type: Literal["end"] = "end"


class LLMNode(NodeBase):
    type: Literal["llm"] = "llm"
    name: str = Field(default_factory=lambda: id)
    purpose: str
    prompt: str
    model: Literal[
        "gpt-4o", "gpt4-turbo", "gpt-4", "gpt-3.5-turbo", "azure-gpt-3.5-turbo"
    ]


class NodeModel(RootModel):
    root: Union[StartNode, EndNode, LLMNode]


class Graph(BaseModel):
    nodes: Dict[str, NodeModel] = Field(default_factory=dict)

    def add_node(self, node: Union[StartNode, EndNode, LLMNode]) -> None:
        self.nodes[node.id] = NodeModel(root=node)


start_node = StartNode(id = "start", type = "start")
llm_node = LLMNode(id = "llm", type = "llm", purpose = "test", prompt = "test", model = "gpt-4o")
end_node = EndNode(id = "end", type = "end")

# ========= Node tests =========
start_node_dict = start_node.model_dump()
llm_node_dict = llm_node.model_dump()
end_node_dict = end_node.model_dump()
# Is it possible to use model_validate with the base class?
start_node_from_dict = NodeBase.model_validate(start_node_dict)
llm_node_from_dict = NodeBase.model_validate(llm_node_dict)
end_node_from_dict = NodeBase.model_validate(end_node_dict)


assert start_node == start_node_from_dict
assert llm_node == llm_node_from_dict
assert end_node == end_node_from_dict


# ========= Graph tests =========
g = Graph()
g.add_node(start_node)
g.add_node(llm_node)
g.add_node(end_node)

g_dict = g.model_dump()
g_from_dict = Graph.model_validate(g_dict)
assert g == g_from_dict

выдать следующие ошибки:

UserWarning: Pydantic serializer warnings:
  Expected `str` but got `builtin_function_or_method` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_python(
Traceback (most recent call last):
  File "file.py", line 53, in <module>
    assert start_node == start_node_from_dict
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError

Я хотел бы иметь возможность выгружать график независимо от того, являются ли добавленные узлы подклассами Node, такими как StartNode или LLMNode, и иметь возможность десериализовать граф обратно, где все узлы имеют правильные типы. Кроме того, было бы здорово, если бы я мог также десериализовать подкласс Node, не зная, какой тип напрямую с помощью NodeBase.model_validate(subclass_of_NodeBase_that_I_dont_know_the_type_of)

Почему бы не использовать валидатор модели, который проверяет свойство type входного словаря и строит модель в соответствии с ним?

בנימין כהן 09.06.2024 02:12
Почему в Python есть оператор "pass"?
Почему в Python есть оператор "pass"?
Оператор pass в Python - это простая концепция, которую могут быстро освоить даже новички без опыта программирования.
Некоторые методы, о которых вы не знали, что они существуют в Python
Некоторые методы, о которых вы не знали, что они существуют в Python
Python - самый известный и самый простой в изучении язык в наши дни. Имея широкий спектр применения в области машинного обучения, Data Science,...
Основы Python Часть I
Основы Python Часть I
Вы когда-нибудь задумывались, почему в программах на Python вы видите приведенный ниже код?
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
LeetCode - 1579. Удаление максимального числа ребер для сохранения полной проходимости графа
Алиса и Боб имеют неориентированный граф из n узлов и трех типов ребер:
Оптимизация кода с помощью тернарного оператора Python
Оптимизация кода с помощью тернарного оператора Python
И последнее, что мы хотели бы показать вам, прежде чем двигаться дальше, это
Советы по эффективной веб-разработке с помощью Python
Советы по эффективной веб-разработке с помощью Python
Как веб-разработчик, Python может стать мощным инструментом для создания эффективных и масштабируемых веб-приложений.
1
1
109
2
Перейти к ответу Данный вопрос помечен как решенный

Ответы 2

Я бы предложил использовать что-то вроде этого. Это всего лишь пример того, как вам может помочь сопоставление по типу и пользовательской функции:

from enum import Enum
from typing import Union, Type

from pydantic import BaseModel


class NodeType(str, Enum):
    START = "start"
    END = "end"
    LLM = "llm"


class NodeBase(BaseModel):
    type: NodeType


class StartNode(NodeBase):
    field: int

class EndNode(NodeBase):
    field2: int


class NodeModel(BaseModel):
    root: Union[StartNode, EndNode]

    @classmethod
    def from_dict(cls, data: dict) -> "NodeModel":
        type_map: dict[NodeType, Type[NodeBase]] = {
            NodeType.START: StartNode,
            NodeType.END: EndNode,
        }

        node_type = NodeType(data["type"])
        type_class = type_map[node_type]

        node = type_class.model_validate(obj=data)

        return cls(root=node)


NodeModel.from_dict(data = {"field": 1, "type": "start"})  # root=StartNode(type=<NodeType.START: 'start'>, field=1)
NodeModel.from_dict(data = {"field2": 2, "type": "end"})  # root=EndNode(type=<NodeType.END: 'end'>, field2=2)

Не забудьте добавить дополнительную проверку! Надеюсь, это поможет!

Ответ принят как подходящий

Спасибо за ваши отзывы, вот решение, которое я использовал для своей проблемы.

# Create the NodeTypes union from the node types list
NodeTypes = Union[tuple(node_types)] # shouldn't contain NodeBase


class NodeModel(RootModel):
    root: NodeTypes

    @model_validator(mode = "after")
    @classmethod
    def get_root(cls, obj):
        if hasattr(obj, "root"):
            return obj.root
        return obj
    ```

And have a different way of adding nodes to the graph
```python
def add_node(self: Self, node: NodeBase) -> None:
        """Add a node to the graph.
        :param node: An instance of the Node class
        """
        self.nodes[node.id] = node
    ```

Другие вопросы по теме