Я попытался сериализовать модель pydantic с атрибутом, который может принадлежать к классу нескольких подклассов базового класса. Однако при простой реализации подклассы сериализуются в базовый класс.
Прочитав этот выпуск, я написал следующий код, но безуспешно:
from typing import Dict, Literal, Union
from pydantic import BaseModel, Field, RootModel
class NodeBase(BaseModel):
id: str
class StartNode(NodeBase):
type: Literal["start"] = "start"
class EndNode(NodeBase):
type: Literal["end"] = "end"
class LLMNode(NodeBase):
type: Literal["llm"] = "llm"
name: str = Field(default_factory=lambda: id)
purpose: str
prompt: str
model: Literal[
"gpt-4o", "gpt4-turbo", "gpt-4", "gpt-3.5-turbo", "azure-gpt-3.5-turbo"
]
class NodeModel(RootModel):
root: Union[StartNode, EndNode, LLMNode]
class Graph(BaseModel):
nodes: Dict[str, NodeModel] = Field(default_factory=dict)
def add_node(self, node: Union[StartNode, EndNode, LLMNode]) -> None:
self.nodes[node.id] = NodeModel(root=node)
start_node = StartNode(id = "start", type = "start")
llm_node = LLMNode(id = "llm", type = "llm", purpose = "test", prompt = "test", model = "gpt-4o")
end_node = EndNode(id = "end", type = "end")
# ========= Node tests =========
start_node_dict = start_node.model_dump()
llm_node_dict = llm_node.model_dump()
end_node_dict = end_node.model_dump()
# Is it possible to use model_validate with the base class?
start_node_from_dict = NodeBase.model_validate(start_node_dict)
llm_node_from_dict = NodeBase.model_validate(llm_node_dict)
end_node_from_dict = NodeBase.model_validate(end_node_dict)
assert start_node == start_node_from_dict
assert llm_node == llm_node_from_dict
assert end_node == end_node_from_dict
# ========= Graph tests =========
g = Graph()
g.add_node(start_node)
g.add_node(llm_node)
g.add_node(end_node)
g_dict = g.model_dump()
g_from_dict = Graph.model_validate(g_dict)
assert g == g_from_dict
выдать следующие ошибки:
UserWarning: Pydantic serializer warnings:
Expected `str` but got `builtin_function_or_method` - serialized value may not be as expected
return self.__pydantic_serializer__.to_python(
Traceback (most recent call last):
File "file.py", line 53, in <module>
assert start_node == start_node_from_dict
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
Я хотел бы иметь возможность выгружать график независимо от того, являются ли добавленные узлы подклассами Node, такими как StartNode или LLMNode, и иметь возможность десериализовать граф обратно, где все узлы имеют правильные типы. Кроме того, было бы здорово, если бы я мог также десериализовать подкласс Node, не зная, какой тип напрямую с помощью NodeBase.model_validate(subclass_of_NodeBase_that_I_dont_know_the_type_of)






Я бы предложил использовать что-то вроде этого. Это всего лишь пример того, как вам может помочь сопоставление по типу и пользовательской функции:
from enum import Enum
from typing import Union, Type
from pydantic import BaseModel
class NodeType(str, Enum):
START = "start"
END = "end"
LLM = "llm"
class NodeBase(BaseModel):
type: NodeType
class StartNode(NodeBase):
field: int
class EndNode(NodeBase):
field2: int
class NodeModel(BaseModel):
root: Union[StartNode, EndNode]
@classmethod
def from_dict(cls, data: dict) -> "NodeModel":
type_map: dict[NodeType, Type[NodeBase]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
}
node_type = NodeType(data["type"])
type_class = type_map[node_type]
node = type_class.model_validate(obj=data)
return cls(root=node)
NodeModel.from_dict(data = {"field": 1, "type": "start"}) # root=StartNode(type=<NodeType.START: 'start'>, field=1)
NodeModel.from_dict(data = {"field2": 2, "type": "end"}) # root=EndNode(type=<NodeType.END: 'end'>, field2=2)
Не забудьте добавить дополнительную проверку! Надеюсь, это поможет!
Спасибо за ваши отзывы, вот решение, которое я использовал для своей проблемы.
# Create the NodeTypes union from the node types list
NodeTypes = Union[tuple(node_types)] # shouldn't contain NodeBase
class NodeModel(RootModel):
root: NodeTypes
@model_validator(mode = "after")
@classmethod
def get_root(cls, obj):
if hasattr(obj, "root"):
return obj.root
return obj
```
And have a different way of adding nodes to the graph
```python
def add_node(self: Self, node: NodeBase) -> None:
"""Add a node to the graph.
:param node: An instance of the Node class
"""
self.nodes[node.id] = node
```
Почему бы не использовать валидатор модели, который проверяет свойство
typeвходного словаря и строит модель в соответствии с ним?