소스 검색

a friendly little helper (#4021)

* a friendly little helper

* addressing comments

* update comment

---------

Co-authored-by: simon <simon@reflex.dev>
Simon Young 7 달 전
부모
커밋
e96b4bf42e
2개의 변경된 파일20개의 추가작업 그리고 3개의 파일을 삭제
  1. 2 3
      reflex/model.py
  2. 18 0
      reflex/utils/compat.py

+ 2 - 3
reflex/model.py

@@ -22,7 +22,7 @@ from reflex import constants
 from reflex.base import Base
 from reflex.base import Base
 from reflex.config import get_config
 from reflex.config import get_config
 from reflex.utils import console
 from reflex.utils import console
-from reflex.utils.compat import sqlmodel
+from reflex.utils.compat import sqlmodel, sqlmodel_field_has_primary_key
 
 
 
 
 def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
 def get_engine(url: str | None = None) -> sqlalchemy.engine.Engine:
@@ -166,8 +166,7 @@ class Model(Base, sqlmodel.SQLModel):  # pyright: ignore [reportGeneralTypeIssue
         non_default_primary_key_fields = [
         non_default_primary_key_fields = [
             field_name
             field_name
             for field_name, field in cls.__fields__.items()
             for field_name, field in cls.__fields__.items()
-            if field_name != "id"
-            and getattr(field.field_info, "primary_key", None) is True
+            if field_name != "id" and sqlmodel_field_has_primary_key(field)
         ]
         ]
         if non_default_primary_key_fields:
         if non_default_primary_key_fields:
             cls.__fields__.pop("id", None)
             cls.__fields__.pop("id", None)

+ 18 - 0
reflex/utils/compat.py

@@ -69,3 +69,21 @@ def pydantic_v1_patch():
 
 
 with pydantic_v1_patch():
 with pydantic_v1_patch():
     import sqlmodel as sqlmodel
     import sqlmodel as sqlmodel
+
+
+def sqlmodel_field_has_primary_key(field) -> bool:
+    """Determines if a field is a priamary.
+
+    Args:
+        field: a rx.model field
+
+    Returns:
+        If field is a primary key (Bool)
+    """
+    if getattr(field.field_info, "primary_key", None) is True:
+        return True
+    if getattr(field.field_info, "sa_column", None) is None:
+        return False
+    if getattr(field.field_info.sa_column, "primary_key", None) is True:
+        return True
+    return False