Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions docs/how-to-guides/ogm.md
Original file line number Diff line number Diff line change
Expand Up @@ -395,19 +395,57 @@ To check which constraints have been created, run:
print(db.get_constraints())
```

## Using enums

Memgraph's built-in [enum data type](https://memgraph.com/docs/fundamentals/data-types#enum) can be utilized on your GQLAlchemy OGM models. GQLAlchemy's enum implementation extends Python's [enum support](https://docs.python.org/3.11/library/enum.html).

First, create an enum.

```python
from enum import Enum

class SubscriptionType(Enum):
FREE = 1
BASIC = 2
EXTENDED = 3
```

Then, use the defined enum class in your model definition. Using the `Field` class, set the `enum` attribute to `True`. This will indicate that GQLAlchemy should treat the property value stored as a Memgraph enum. If the enum does not exist in the database, it will be created.

```python
class User(Node):
id: str = Field(index=True, db=db)
username: str
subscription: SubscriptionType = Field(enum=True, db=db)
```

Enum types may be defined for properties on Nodes and Relationships.

!!! info
If the `Field` class specification on the property isn't specified, or if `enum` is explicitly set to `False`, GQLAlchemy will use the `value` of the enum member when serializing to a Cypher query. A corresponding enum will not be created in the database.

This functionality allows for flexiblity when using the Python `Enum` class, and would, for instance, respect an overridden `__getattribute__` method to customize the value passed to Cypher.

## Full code example

The above mentioned examples can be merged into a working code example which you can run. Here is the code:

```python
from gqlalchemy import Memgraph, Node, Relationship, Field
from typing import Optional
from enum import Enum

db = Memgraph()

class SubscriptionType(Enum):
FREE = 1
BASIC = 2
EXTENDED = 3

class User(Node):
id: str = Field(index=True, db=db)
username: str = Field(exists=True, db=db)
subscription: SubscriptionType = Field(enum=True, db=db)

class Streamer(User):
id: str
Expand All @@ -423,8 +461,8 @@ class ChatsWith(Relationship, type="CHATS_WITH"):
class Speaks(Relationship, type="SPEAKS"):
since: Optional[str]

john = User(id="1", username="John").save(db)
jane = Streamer(id="2", username="janedoe", followers=111).save(db)
john = User(id="1", username="John", subscription=SubscriptionType(1)).save(db)
jane = Streamer(id="2", username="janedoe", subscription=SubscriptionType(3), followers=111).save(db)
language = Language(name="en").save(db)

ChatsWith(
Expand All @@ -449,7 +487,7 @@ try:
streamer = Streamer(id="3").load(db=db)
except:
print("Creating new Streamer node in the database.")
streamer = Streamer(id="3", username="anne", followers=222).save(db=db)
streamer = Streamer(id="3", username="anne", subscription=SubscriptionType(2), followers=222).save(db=db)

try:
speaks = Speaks(_start_node_id=streamer._id, _end_node_id=language._id).load(db)
Expand Down
1 change: 1 addition & 0 deletions gqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
MemgraphConstraintExists,
MemgraphConstraintUnique,
MemgraphIndex,
MemgraphEnum,
MemgraphKafkaStream,
MemgraphPulsarStream,
MemgraphTrigger,
Expand Down
91 changes: 89 additions & 2 deletions gqlalchemy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, date, time, timedelta
from enum import Enum
import json
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from enum import Enum, EnumMeta

from pydantic.v1 import BaseModel, Extra, Field, PrivateAttr # noqa F401

Expand Down Expand Up @@ -59,6 +59,36 @@ def _format_timedelta(duration: timedelta) -> str:
return f"P{days}DT{hours}H{minutes}M{remainder_sec}S"


class GraphEnum(ABC):
def __init__(self, enum):

if not isinstance(enum, (Enum, EnumMeta)):
raise TypeError()

self.enum = enum if isinstance(enum, Enum) else None
self.cls = enum.__class__ if isinstance(enum, Enum) else enum

@property
def name(self):
return self.cls.__name__

@property
def members(self):
return self.cls.__members__

@abstractmethod
def _to_cypher(self):
pass


class MemgraphEnum(GraphEnum):
def _to_cypher(self):
return f"{{ {', '.join(self.cls._member_names_)} }}"

def __repr__(self):
return f"<enum '{self.name}'>" if self.enum is None else f"{self.name}::{self.enum.name}"


class TriggerEventType:
"""An enum representing types of trigger events."""

Expand Down Expand Up @@ -308,6 +338,17 @@ class GraphObject(BaseModel):
class Config:
extra = Extra.allow

def __init__(self, **data):
for field in self.__class__.__fields__:
attrs = self.__class__.__fields__[field].field_info.extra
cls = self.__fields__[field].type_
if issubclass(cls, Enum) and not attrs.get("enum", False):
value = data.get(field)
if isinstance(value, dict):
member = value.get("__value").split("::")[1]
data[field] = cls[member].value
super().__init__(**data)

def __init_subclass__(cls, type=None, label=None, labels=None, index=None, db=None):
"""Stores the subclass by type if type is specified, or by class name
when instantiating a subclass.
Expand Down Expand Up @@ -372,6 +413,8 @@ def escape_value(
return repr(value)
elif value_type == float:
return repr(value)
elif isinstance(value, Enum):
return repr(MemgraphEnum(value))
elif isinstance(value, str):
return json.dumps(value)
elif isinstance(value, list):
Expand Down Expand Up @@ -446,7 +489,11 @@ def _get_cypher_set_properties(self, variable_name: str) -> str:
cypher_set_properties = []
for field in self.__fields__:
attributes = self.__fields__[field].field_info.extra
value = getattr(self, field)
cls = self.__fields__[field].type_
if issubclass(cls, Enum) and not attributes.get("enum", False):
value = getattr(self, field).value
else:
value = getattr(self, field)
if value is not None and not attributes.get("on_disk", False):
cypher_set_properties.append(f" SET {variable_name}.{field} = {self.escape_value(value)}")

Expand Down Expand Up @@ -512,6 +559,9 @@ def get_base_labels() -> Set[str]:
cls.labels = get_base_labels().union({cls.label}, kwargs.get("labels", set()))

db = kwargs.get("db")

cls.enums = None

if cls.index is True:
if db is None:
raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls)
Expand All @@ -522,12 +572,25 @@ def get_base_labels() -> Set[str]:
for field in cls.__fields__:
attrs = cls.__fields__[field].field_info.extra
field_type = cls.__fields__[field].type_.__name__
field_cls = cls.__fields__[field].type_
label = attrs.get("label", cls.label)
skip_constraints = False

if db is None:
db = attrs.get("db")

if issubclass(field_cls, Enum) and attrs.get("enum", False):
if db is None:
raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls)
if cls.enums is None:
cls.enums = db.get_enums()
enum_names = [x.name for x in cls.enums]
if field_cls.__name__ in enum_names:
existing = cls.enums[enum_names.index(field_cls.__name__)]
db.sync_enum(existing, MemgraphEnum(field_cls))
else:
db.create_enum(MemgraphEnum(field_cls))

for constraint in FieldAttrsConstants.list():
if constraint in attrs and db is None:
base = field_in_superclass(field, constraint)
Expand Down Expand Up @@ -663,6 +726,30 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901
if name != "Relationship":
cls.type = kwargs.get("type", name)

db = kwargs.get("db")

cls.enums = None

for field in cls.__fields__:
attrs = cls.__fields__[field].field_info.extra
field_type = cls.__fields__[field].type_.__name__
field_cls = cls.__fields__[field].type_

if db is None:
db = attrs.get("db")

if issubclass(field_cls, Enum) and attrs.get("enum", False):
if db is None:
raise GQLAlchemyDatabaseMissingInNodeClassError(cls=cls)
if cls.enums is None:
cls.enums = db.get_enums()
enum_names = [x.name for x in cls.enums]
if field_type in enum_names:
existing = cls.enums[enum_names.index(field_type)]
db.sync_enum(existing, MemgraphEnum(field_cls))
else:
db.create_enum(MemgraphEnum(field_cls))

return cls


Expand Down
31 changes: 25 additions & 6 deletions gqlalchemy/vendors/database_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@

from gqlalchemy.connection import Connection
from gqlalchemy.exceptions import GQLAlchemyError
from gqlalchemy.models import (
Constraint,
Index,
Node,
Relationship,
)
from gqlalchemy.models import Constraint, Index, GraphEnum, Node, Relationship


class DatabaseClient(ABC):
Expand Down Expand Up @@ -128,6 +123,30 @@ def ensure_constraints(
for missing_constraint in new_constraints.difference(old_constraints):
self.create_constraint(missing_constraint)

@abstractmethod
def create_enum(self, enum: GraphEnum) -> None:
pass

@abstractmethod
def get_enums(self) -> List[GraphEnum]:
"""Returns a list of all enums defined in the database."""
pass

@abstractmethod
def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None:
"""Ensures that database enum matches input enum."""
pass

@abstractmethod
def drop_enum(self, enum: GraphEnum) -> None:
"""Drops a single enum in the database."""
pass

@abstractmethod
def drop_enums(self) -> None:
"""Drops all enums in the database"""
pass

def drop_database(self):
"""Drops database by removing all nodes and edges."""
self.execute("MATCH (n) DETACH DELETE n;")
Expand Down
27 changes: 27 additions & 0 deletions gqlalchemy/vendors/memgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
MemgraphIndex,
MemgraphStream,
MemgraphTrigger,
MemgraphEnum,
Node,
Relationship,
)
Expand Down Expand Up @@ -167,6 +168,32 @@ def get_constraints(
)
return constraints

def create_enum(self, graph_enum: MemgraphEnum) -> None:
query = f"CREATE ENUM {graph_enum.name} VALUES {graph_enum._to_cypher()};"
self.execute(query)

def get_enums(self) -> List[MemgraphEnum]:
"""Returns a list of all enums defined in the database."""
enums: List[MemgraphEnum] = []
for result in self.execute_and_fetch("SHOW ENUMS;"):
enums.append(MemgraphEnum(Enum(result["Enum Name"], result["Enum Values"])))
return enums

def sync_enum(self, existing: MemgraphEnum, new: MemgraphEnum) -> None:
"""Ensures that database enum matches input enum."""
for value in new.members:
if value not in existing.members:
query = f"ALTER ENUM {existing.name} ADD VALUE {value};"
self.execute(query)
Comment on lines +182 to +187
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To modify an existing enum by adding a new value use ALTER:

ALTER ENUM Status ADD VALUE Excellent;

To update an existing value in an enum do the following:

ALTER ENUM Status UPDATE VALUE Bad TO Poor;

Maybe instead of having sync_enum procedure, it would make sense to have two procedures matching the above queries?

Full example:

Create enum:

CREATE ENUM Status VALUES { Good, Okay, Bad };
SHOW ENUMS;

Result:

{"Enum Name":"Status","Enum Values":["Good","Okay","Bad"]}

Add new value to existing enum:

ALTER ENUM Status ADD VALUE Excellent;
SHOW ENUMS;

Result

{"Enum Name":"Status","Enum Values":["Good","Okay","Bad","Excellent"]}

Update existing value in an existing enum:

ALTER ENUM Status UPDATE VALUE Bad TO Poor;
SHOW ENUMS;

Result:

{"Enum Name":"Status","Enum Values":["Good","Okay","Poor","Excellent"]}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also listed some errors that can occur when running ALTER query - memgraph/memgraph#2968.


def drop_enum(self, graph_enum: MemgraphEnum):
raise GQLAlchemyError(f"DROP ENUM not yet implemented. Enum {graph_enum.name} is persisted in the database.")

def drop_enums(self, graph_enums: List[MemgraphEnum]):
raise GQLAlchemyError(
f"DROP ENUM not yet implemented. Enums {', '.join(graph_enums)} are persisted in the database."
)

def get_exists_constraints(
self,
) -> List[MemgraphConstraintExists]:
Expand Down
18 changes: 18 additions & 0 deletions gqlalchemy/vendors/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Neo4jConstraintExists,
Neo4jConstraintUnique,
Neo4jIndex,
GraphEnum,
Node,
Relationship,
)
Expand Down Expand Up @@ -99,6 +100,23 @@ def ensure_indexes(self, indexes: List[Neo4jIndex]) -> None:
for missing_index in new_indexes.difference(old_indexes):
self.create_index(missing_index)

def create_enum(self, graph_enum: GraphEnum) -> None:
raise GQLAlchemyError(f"CREATE ENUM not yet implemented in Neo4j.")

def get_enums(self) -> List[GraphEnum]:
"""Returns a list of all enums defined in the database."""
raise GQLAlchemyError(f"SHOW ENUMS not yet implemented in Neo4j.")

def sync_enum(self, existing: GraphEnum, new: GraphEnum) -> None:
"""Ensures that database enum matches input enum."""
raise GQLAlchemyError(f"ALTER ENUM not yet implemented in Neo4j.")

def drop_enum(self, graph_enum: GraphEnum):
raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.")

def drop_enums(self, graph_enums: List[GraphEnum]):
raise GQLAlchemyError(f"DROP ENUM not yet implemented in Neo4j.")

def get_constraints(
self,
) -> List[Union[Neo4jConstraintExists, Neo4jConstraintUnique]]:
Expand Down
16 changes: 16 additions & 0 deletions tests/ogm/test_custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

from pydantic.v1 import Field

from enum import Enum

from gqlalchemy import (
MemgraphConstraintExists,
MemgraphConstraintUnique,
MemgraphIndex,
MemgraphEnum,
Neo4jConstraintUnique,
Neo4jIndex,
Node,
Expand Down Expand Up @@ -56,6 +59,19 @@ def test_create_index(memgraph):
assert actual_constraints == [memgraph_index]


def test_create_graph_enum(memgraph):
enum1 = Enum("MgEnum", (("MEMBER1", "value1"), ("MEMBER2", "value2"), ("MEMBER3", "value3")))

class Node3(Node):
type: enum1

memgraph_enum = MemgraphEnum(enum1)

actual_enums = memgraph.get_enums()

assert actual_enums == [memgraph_enum]


def test_create_constraint_unique_neo4j(neo4j):
class Node2(Node):
id: int = Field(db=neo4j)
Expand Down