Skip to content

Commit 8961be9

Browse files
authored
Map keys of custom types should serialize using MarshalText when available (#461)
* Map keys of custom types should serialize/deserialize using MarshalText/UnmarshalText when available - this brings marshaling/unmarshaling behavior in line with encoding/json - in general, any types that implement the interfaces from the encoding package (TextUnmarshaler, TextMarshaler, etc.) should use the provided method when available
1 parent 0f8241d commit 8961be9

File tree

2 files changed

+55
-36
lines changed

2 files changed

+55
-36
lines changed

reflect_map.go

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,33 @@ func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
4949
return decoder
5050
}
5151
}
52+
53+
ptrType := reflect2.PtrTo(typ)
54+
if ptrType.Implements(unmarshalerType) {
55+
return &referenceDecoder{
56+
&unmarshalerDecoder{
57+
valType: ptrType,
58+
},
59+
}
60+
}
61+
if typ.Implements(unmarshalerType) {
62+
return &unmarshalerDecoder{
63+
valType: typ,
64+
}
65+
}
66+
if ptrType.Implements(textUnmarshalerType) {
67+
return &referenceDecoder{
68+
&textUnmarshalerDecoder{
69+
valType: ptrType,
70+
},
71+
}
72+
}
73+
if typ.Implements(textUnmarshalerType) {
74+
return &textUnmarshalerDecoder{
75+
valType: typ,
76+
}
77+
}
78+
5279
switch typ.Kind() {
5380
case reflect.String:
5481
return decoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String))
@@ -63,31 +90,6 @@ func decoderOfMapKey(ctx *ctx, typ reflect2.Type) ValDecoder {
6390
typ = reflect2.DefaultTypeOfKind(typ.Kind())
6491
return &numericMapKeyDecoder{decoderOfType(ctx, typ)}
6592
default:
66-
ptrType := reflect2.PtrTo(typ)
67-
if ptrType.Implements(unmarshalerType) {
68-
return &referenceDecoder{
69-
&unmarshalerDecoder{
70-
valType: ptrType,
71-
},
72-
}
73-
}
74-
if typ.Implements(unmarshalerType) {
75-
return &unmarshalerDecoder{
76-
valType: typ,
77-
}
78-
}
79-
if ptrType.Implements(textUnmarshalerType) {
80-
return &referenceDecoder{
81-
&textUnmarshalerDecoder{
82-
valType: ptrType,
83-
},
84-
}
85-
}
86-
if typ.Implements(textUnmarshalerType) {
87-
return &textUnmarshalerDecoder{
88-
valType: typ,
89-
}
90-
}
9193
return &lazyErrorDecoder{err: fmt.Errorf("unsupported map key type: %v", typ)}
9294
}
9395
}
@@ -103,6 +105,19 @@ func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder {
103105
return encoder
104106
}
105107
}
108+
109+
if typ == textMarshalerType {
110+
return &directTextMarshalerEncoder{
111+
stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
112+
}
113+
}
114+
if typ.Implements(textMarshalerType) {
115+
return &textMarshalerEncoder{
116+
valType: typ,
117+
stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
118+
}
119+
}
120+
106121
switch typ.Kind() {
107122
case reflect.String:
108123
return encoderOfType(ctx, reflect2.DefaultTypeOfKind(reflect.String))
@@ -117,17 +132,6 @@ func encoderOfMapKey(ctx *ctx, typ reflect2.Type) ValEncoder {
117132
typ = reflect2.DefaultTypeOfKind(typ.Kind())
118133
return &numericMapKeyEncoder{encoderOfType(ctx, typ)}
119134
default:
120-
if typ == textMarshalerType {
121-
return &directTextMarshalerEncoder{
122-
stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
123-
}
124-
}
125-
if typ.Implements(textMarshalerType) {
126-
return &textMarshalerEncoder{
127-
valType: typ,
128-
stringEncoder: ctx.EncoderOf(reflect2.TypeOf("")),
129-
}
130-
}
131135
if typ.Kind() == reflect.Interface {
132136
return &dynamicMapKeyEncoder{ctx, typ}
133137
}

value_tests/map_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ func init() {
3131
map[string]*json.RawMessage{"hello": pRawMessage(json.RawMessage("[]"))},
3232
map[Date]bool{{}: true},
3333
map[Date2]bool{{}: true},
34+
map[customKey]string{customKey(1): "bar"},
3435
)
3536
unmarshalCases = append(unmarshalCases, unmarshalCase{
3637
ptr: (*map[string]string)(nil),
@@ -55,6 +56,9 @@ func init() {
5556
"2018-12-13": true,
5657
"2018-12-14": true
5758
}`,
59+
}, unmarshalCase{
60+
ptr: (*map[customKey]string)(nil),
61+
input: `{"foo": "bar"}`,
5862
})
5963
}
6064

@@ -115,3 +119,14 @@ func (d Date2) UnmarshalJSON(b []byte) error {
115119
func (d Date2) MarshalJSON() ([]byte, error) {
116120
return []byte(d.Time.Format("2006-01-02")), nil
117121
}
122+
123+
type customKey int32
124+
125+
func (c customKey) MarshalText() ([]byte, error) {
126+
return []byte("foo"), nil
127+
}
128+
129+
func (c *customKey) UnmarshalText(value []byte) error {
130+
*c = 1
131+
return nil
132+
}

0 commit comments

Comments
 (0)