Skip to content

Commit 81560df

Browse files
committed
Added JSONString custom scalar
1 parent ca0d1a3 commit 81560df

File tree

4 files changed

+44
-9
lines changed

4 files changed

+44
-9
lines changed

graphene/contrib/django/converter.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from django.db import models
22

3-
from ...core.classtypes.objecttype import ObjectType
43
from ...core.types.definitions import List
5-
from ...core.types.field import Field
64
from ...core.types.scalars import ID, Boolean, Float, Int, String
5+
from ...core.types.custom_scalars import JSONString
76
from ...core.classtypes.enum import Enum
87
from .compat import RelatedObject, UUIDField, ArrayField, HStoreField, JSONField, RangeField
98
from .utils import get_related_model, import_single_dispatch
@@ -103,7 +102,7 @@ def convert_postgres_array_to_list(field):
103102
@convert_django_field.register(HStoreField)
104103
@convert_django_field.register(JSONField)
105104
def convert_posgres_field_to_string(field):
106-
return String(description=field.help_text)
105+
return JSONString(description=field.help_text)
107106

108107

109108
@convert_django_field.register(RangeField)

graphene/contrib/django/tests/test_converter.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ..fields import (ConnectionOrListField,
99
DjangoModelField)
1010
from ..compat import MissingType, ArrayField, HStoreField, JSONField, RangeField
11+
from graphene.core.types.custom_scalars import JSONString
1112

1213
from .models import Article, Reporter
1314

@@ -168,13 +169,13 @@ def test_should_postgres_array_multiple_convert_list():
168169
@pytest.mark.skipif(HStoreField is MissingType,
169170
reason="HStoreField should exist")
170171
def test_should_postgres_hstore_convert_string():
171-
assert_conversion(HStoreField, graphene.String)
172+
assert_conversion(HStoreField, JSONString)
172173

173174

174175
@pytest.mark.skipif(JSONField is MissingType,
175176
reason="JSONField should exist")
176177
def test_should_postgres_json_convert_string():
177-
assert_conversion(JSONField, graphene.String)
178+
assert_conversion(JSONField, JSONString)
178179

179180

180181
@pytest.mark.skipif(RangeField is MissingType,

graphene/contrib/django/tests/test_query.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,14 @@ def resolve_reporter(self, *args, **kwargs):
6666

6767
@pytest.mark.skipif(RangeField is MissingType,
6868
reason="RangeField should exist")
69-
def test_should_query_ranges():
70-
from django.contrib.postgres.fields import IntegerRangeField
69+
def test_should_query_postgres_fields():
70+
from django.contrib.postgres.fields import IntegerRangeField, ArrayField, JSONField, HStoreField
7171

7272
class Event(models.Model):
73-
ages = IntegerRangeField(help_text='Range desc')
73+
ages = IntegerRangeField(help_text='The age ranges')
74+
data = JSONField(help_text='Data')
75+
store = HStoreField()
76+
tags = ArrayField(models.CharField(max_length=50))
7477

7578
class EventType(DjangoObjectType):
7679
class Meta:
@@ -80,19 +83,30 @@ class Query(graphene.ObjectType):
8083
event = graphene.Field(EventType)
8184

8285
def resolve_event(self, *args, **kwargs):
83-
return Event(ages=(0, 10))
86+
return Event(
87+
ages=(0, 10),
88+
data={'angry_babies': True},
89+
store={'h': 'store'},
90+
tags=['child', 'angry', 'babies']
91+
)
8492

8593
schema = graphene.Schema(query=Query)
8694
query = '''
8795
query myQuery {
8896
event {
8997
ages
98+
tags
99+
data
100+
store
90101
}
91102
}
92103
'''
93104
expected = {
94105
'event': {
95106
'ages': [0, 10],
107+
'tags': ['child', 'angry', 'babies'],
108+
'data': '{"angry_babies": true}',
109+
'store': '{"h": "store"}',
96110
},
97111
}
98112
result = schema.execute(query)

graphene/core/types/custom_scalars.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import json
2+
3+
from graphql.core.language import ast
4+
from ...core.classtypes.scalar import Scalar
5+
6+
7+
class JSONString(Scalar):
8+
'''JSON String'''
9+
10+
@staticmethod
11+
def serialize(dt):
12+
return json.dumps(dt)
13+
14+
@staticmethod
15+
def parse_literal(node):
16+
if isinstance(node, ast.StringValue):
17+
return json.dumps(node.value)
18+
19+
@staticmethod
20+
def parse_value(value):
21+
return json.dumps(value)

0 commit comments

Comments
 (0)