Skip to content

Commit 1c87ac3

Browse files
Bug fix for FieldMask.MergeFrom() with unset fields.
The Python field mask implementation was incorrectly setting fields in the `destination` even if the `source` proto had presence bits and the field was unset. This behavior differed from C++ where the [`HasField()`](http://google3/third_party/protobuf/util/field_mask_util.cc;l=460;rcl=647700575) is checked before setting. PiperOrigin-RevId: 746132583
1 parent 67cf7ea commit 1c87ac3

File tree

3 files changed

+155
-43
lines changed

3 files changed

+155
-43
lines changed

python/google/protobuf/internal/field_mask.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,10 @@ def _MergeMessage(
293293
destination.ClearField(_StrConvert(name))
294294
if source.HasField(name):
295295
getattr(destination, name).MergeFrom(getattr(source, name))
296-
else:
296+
elif not field.has_presence or source.HasField(name):
297297
setattr(destination, name, getattr(source, name))
298+
else:
299+
destination.ClearField(_StrConvert(name))
298300

299301

300302
def _AddFieldPaths(node, prefix, field_mask):

python/google/protobuf/internal/field_mask_test.py

+136-35
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99

1010
import unittest
1111

12-
from google.protobuf import field_mask_pb2
12+
from google.protobuf import descriptor
1313
from google.protobuf.internal import field_mask
1414
from google.protobuf.internal import test_util
15-
from google.protobuf import descriptor
15+
16+
from google.protobuf import field_mask_pb2
1617
from google.protobuf import map_unittest_pb2
18+
from google.protobuf import unittest_no_field_presence_pb2
1719
from google.protobuf import unittest_pb2
1820

1921

@@ -106,22 +108,16 @@ def testCanonicalFrom(self):
106108
self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString())
107109

108110
# Test more deeply nested cases.
109-
mask.FromJsonString(
110-
'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
111+
mask.FromJsonString('foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2')
111112
out_mask.CanonicalFormFromMask(mask)
112-
self.assertEqual('foo.bar.baz1,foo.bar.baz2',
113-
out_mask.ToJsonString())
114-
mask.FromJsonString(
115-
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
113+
self.assertEqual('foo.bar.baz1,foo.bar.baz2', out_mask.ToJsonString())
114+
mask.FromJsonString('foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz')
116115
out_mask.CanonicalFormFromMask(mask)
117-
self.assertEqual('foo.bar.baz1,foo.bar.baz2',
118-
out_mask.ToJsonString())
119-
mask.FromJsonString(
120-
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
116+
self.assertEqual('foo.bar.baz1,foo.bar.baz2', out_mask.ToJsonString())
117+
mask.FromJsonString('foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar')
121118
out_mask.CanonicalFormFromMask(mask)
122119
self.assertEqual('foo.bar', out_mask.ToJsonString())
123-
mask.FromJsonString(
124-
'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
120+
mask.FromJsonString('foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo')
125121
out_mask.CanonicalFormFromMask(mask)
126122
self.assertEqual('foo', out_mask.ToJsonString())
127123

@@ -291,6 +287,36 @@ def testMergeMessageWithoutMapFields(self):
291287
self.assertTrue(dst.HasField('foo_message'))
292288
self.assertFalse(dst.HasField('foo_lazy_message'))
293289

290+
def testMergeMessageWithoutMapFieldsOrFieldPresence(self):
291+
# Test merge one field.
292+
src = unittest_no_field_presence_pb2.TestAllTypes()
293+
test_util.SetAllFields(src)
294+
for field in src.DESCRIPTOR.fields:
295+
if field.containing_oneof:
296+
continue
297+
field_name = field.name
298+
dst = unittest_no_field_presence_pb2.TestAllTypes()
299+
# Only set one path to mask.
300+
mask = field_mask_pb2.FieldMask()
301+
mask.paths.append(field_name)
302+
mask.MergeMessage(src, dst)
303+
# The expected result message.
304+
msg = unittest_no_field_presence_pb2.TestAllTypes()
305+
if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
306+
repeated_src = getattr(src, field_name)
307+
repeated_msg = getattr(msg, field_name)
308+
if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
309+
for item in repeated_src:
310+
repeated_msg.add().CopyFrom(item)
311+
else:
312+
repeated_msg.extend(repeated_src)
313+
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
314+
getattr(msg, field_name).CopyFrom(getattr(src, field_name))
315+
else:
316+
setattr(msg, field_name, getattr(src, field_name))
317+
# Only field specified in mask is merged.
318+
self.assertEqual(msg, dst)
319+
294320
def testMergeMessageWithMapField(self):
295321
empty_map = map_unittest_pb2.TestRecursiveMapMessage()
296322
src_level_2 = map_unittest_pb2.TestRecursiveMapMessage()
@@ -314,6 +340,74 @@ def testMergeMessageWithMapField(self):
314340
self.assertEqual(dst.a['src level 1'], src_level_2)
315341
self.assertEqual(dst.a['dst level 1'], empty_map)
316342

343+
def testMergeMessageWithUnsetFieldsWithFieldPresence(self):
344+
# Test merging each empty field one at a time.
345+
src = unittest_pb2.TestAllTypes()
346+
for field in src.DESCRIPTOR.fields:
347+
if field.containing_oneof:
348+
continue
349+
field_name = field.name
350+
dst = unittest_pb2.TestAllTypes()
351+
# Only set one path to mask.
352+
mask = field_mask_pb2.FieldMask()
353+
mask.paths.append(field_name)
354+
mask.MergeMessage(src, dst)
355+
# Nothing should be merged.
356+
self.assertEqual(unittest_pb2.TestAllTypes(), dst)
357+
358+
# Test merge clears previously set fields when source is unset.
359+
dst_template = unittest_pb2.TestAllTypes()
360+
test_util.SetAllFields(dst_template)
361+
for field in src.DESCRIPTOR.fields:
362+
if field.containing_oneof:
363+
continue
364+
dst = unittest_pb2.TestAllTypes()
365+
dst.CopyFrom(dst_template)
366+
# Only set one path to mask.
367+
mask = field_mask_pb2.FieldMask()
368+
mask.paths.append(field.name)
369+
mask.MergeMessage(
370+
src, dst, replace_message_field=True, replace_repeated_field=True
371+
)
372+
msg = unittest_pb2.TestAllTypes()
373+
msg.CopyFrom(dst_template)
374+
msg.ClearField(field.name)
375+
self.assertEqual(msg, dst)
376+
377+
def testMergeMessageWithUnsetFieldsWithoutFieldPresence(self):
378+
# Test merging each empty field one at a time.
379+
src = unittest_no_field_presence_pb2.TestAllTypes()
380+
for field in src.DESCRIPTOR.fields:
381+
if field.containing_oneof:
382+
continue
383+
field_name = field.name
384+
dst = unittest_no_field_presence_pb2.TestAllTypes()
385+
# Only set one path to mask.
386+
mask = field_mask_pb2.FieldMask()
387+
mask.paths.append(field_name)
388+
mask.MergeMessage(src, dst)
389+
# Nothing should be merged.
390+
self.assertEqual(unittest_no_field_presence_pb2.TestAllTypes(), dst)
391+
392+
# Test merge clears previously set fields when source is unset.
393+
dst_template = unittest_no_field_presence_pb2.TestAllTypes()
394+
test_util.SetAllFields(dst_template)
395+
for field in src.DESCRIPTOR.fields:
396+
if field.containing_oneof:
397+
continue
398+
dst = unittest_no_field_presence_pb2.TestAllTypes()
399+
dst.CopyFrom(dst_template)
400+
# Only set one path to mask.
401+
mask = field_mask_pb2.FieldMask()
402+
mask.paths.append(field.name)
403+
mask.MergeMessage(
404+
src, dst, replace_message_field=True, replace_repeated_field=True
405+
)
406+
msg = unittest_no_field_presence_pb2.TestAllTypes()
407+
msg.CopyFrom(dst_template)
408+
msg.ClearField(field.name)
409+
self.assertEqual(msg, dst)
410+
317411
def testMergeErrors(self):
318412
src = unittest_pb2.TestAllTypes()
319413
dst = unittest_pb2.TestAllTypes()
@@ -322,25 +416,26 @@ def testMergeErrors(self):
322416
mask.FromJsonString('optionalInt32.field')
323417
with self.assertRaises(ValueError) as e:
324418
mask.MergeMessage(src, dst)
325-
self.assertEqual('Error: Field optional_int32 in message '
326-
'proto2_unittest.TestAllTypes is not a singular '
327-
'message field and cannot have sub-fields.',
328-
str(e.exception))
419+
self.assertEqual(
420+
'Error: Field optional_int32 in message '
421+
'proto2_unittest.TestAllTypes is not a singular '
422+
'message field and cannot have sub-fields.',
423+
str(e.exception),
424+
)
329425

330426
def testSnakeCaseToCamelCase(self):
331-
self.assertEqual('fooBar',
332-
field_mask._SnakeCaseToCamelCase('foo_bar'))
333-
self.assertEqual('FooBar',
334-
field_mask._SnakeCaseToCamelCase('_foo_bar'))
335-
self.assertEqual('foo3Bar',
336-
field_mask._SnakeCaseToCamelCase('foo3_bar'))
427+
self.assertEqual('fooBar', field_mask._SnakeCaseToCamelCase('foo_bar'))
428+
self.assertEqual('FooBar', field_mask._SnakeCaseToCamelCase('_foo_bar'))
429+
self.assertEqual('foo3Bar', field_mask._SnakeCaseToCamelCase('foo3_bar'))
337430

338431
# No uppercase letter is allowed.
339432
self.assertRaisesRegex(
340433
ValueError,
341434
'Fail to print FieldMask to Json string: Path name Foo must '
342435
'not contain uppercase letters.',
343-
field_mask._SnakeCaseToCamelCase, 'Foo')
436+
field_mask._SnakeCaseToCamelCase,
437+
'Foo',
438+
)
344439
# Any character after a "_" must be a lowercase letter.
345440
# 1. "_" cannot be followed by another "_".
346441
# 2. "_" cannot be followed by a digit.
@@ -349,28 +444,34 @@ def testSnakeCaseToCamelCase(self):
349444
ValueError,
350445
'Fail to print FieldMask to Json string: The character after a '
351446
'"_" must be a lowercase letter in path name foo__bar.',
352-
field_mask._SnakeCaseToCamelCase, 'foo__bar')
447+
field_mask._SnakeCaseToCamelCase,
448+
'foo__bar',
449+
)
353450
self.assertRaisesRegex(
354451
ValueError,
355452
'Fail to print FieldMask to Json string: The character after a '
356453
'"_" must be a lowercase letter in path name foo_3bar.',
357-
field_mask._SnakeCaseToCamelCase, 'foo_3bar')
454+
field_mask._SnakeCaseToCamelCase,
455+
'foo_3bar',
456+
)
358457
self.assertRaisesRegex(
359458
ValueError,
360459
'Fail to print FieldMask to Json string: Trailing "_" in path '
361-
'name foo_bar_.', field_mask._SnakeCaseToCamelCase, 'foo_bar_')
460+
'name foo_bar_.',
461+
field_mask._SnakeCaseToCamelCase,
462+
'foo_bar_',
463+
)
362464

363465
def testCamelCaseToSnakeCase(self):
364-
self.assertEqual('foo_bar',
365-
field_mask._CamelCaseToSnakeCase('fooBar'))
366-
self.assertEqual('_foo_bar',
367-
field_mask._CamelCaseToSnakeCase('FooBar'))
368-
self.assertEqual('foo3_bar',
369-
field_mask._CamelCaseToSnakeCase('foo3Bar'))
466+
self.assertEqual('foo_bar', field_mask._CamelCaseToSnakeCase('fooBar'))
467+
self.assertEqual('_foo_bar', field_mask._CamelCaseToSnakeCase('FooBar'))
468+
self.assertEqual('foo3_bar', field_mask._CamelCaseToSnakeCase('foo3Bar'))
370469
self.assertRaisesRegex(
371470
ValueError,
372471
'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.',
373-
field_mask._CamelCaseToSnakeCase, 'foo_bar')
472+
field_mask._CamelCaseToSnakeCase,
473+
'foo_bar',
474+
)
374475

375476

376477
if __name__ == '__main__':

python/google/protobuf/internal/test_util.py

+16-7
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,12 @@ def SetAllNonLazyFields(message):
6767
message.optionalgroup.a = 117
6868
message.optional_nested_message.bb = 118
6969
message.optional_foreign_message.c = 119
70-
message.optional_import_message.d = 120
71-
message.optional_public_import_message.e = 126
70+
if hasattr(message, 'optional_import_message'):
71+
message.optional_import_message.d = 120
72+
if hasattr(message, 'optional_public_import_message'):
73+
message.optional_public_import_message.e = 126
74+
if hasattr(message, 'optional_proto2_message'):
75+
SetAllFields(message.optional_proto2_message)
7276

7377
message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ
7478
message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ
@@ -77,7 +81,8 @@ def SetAllNonLazyFields(message):
7781

7882
message.optional_string_piece = u'124'
7983
message.optional_cord = u'125'
80-
message.optional_bytes_cord = b'optional bytes cord'
84+
if hasattr(message, 'optional_bytes_cord'):
85+
message.optional_bytes_cord = b'optional bytes cord'
8186

8287
#
8388
# Repeated fields.
@@ -103,7 +108,8 @@ def SetAllNonLazyFields(message):
103108
message.repeatedgroup.add().a = 217
104109
message.repeated_nested_message.add().bb = 218
105110
message.repeated_foreign_message.add().c = 219
106-
message.repeated_import_message.add().d = 220
111+
if hasattr(message, 'repeated_import_message'):
112+
message.repeated_import_message.add().d = 220
107113
message.repeated_lazy_message.add().bb = 227
108114

109115
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
@@ -150,7 +156,8 @@ def SetAllNonLazyFields(message):
150156
message.repeatedgroup.add().a = 317
151157
message.repeated_nested_message.add().bb = 318
152158
message.repeated_foreign_message.add().c = 319
153-
message.repeated_import_message.add().d = 320
159+
if hasattr(message, 'repeated_import_message'):
160+
message.repeated_import_message.add().d = 320
154161
message.repeated_lazy_message.add().bb = 327
155162

156163
message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
@@ -193,13 +200,15 @@ def SetAllNonLazyFields(message):
193200
message.oneof_uint32 = 601
194201
message.oneof_nested_message.bb = 602
195202
message.oneof_string = '603'
196-
message.oneof_bytes = b'604'
203+
if hasattr(message, 'oneof_bytes'):
204+
message.oneof_bytes = b'604'
197205

198206

199207
def SetAllFields(message):
200208
SetAllNonLazyFields(message)
201209
message.optional_lazy_message.bb = 127
202-
message.optional_unverified_lazy_message.bb = 128
210+
if hasattr(message, 'optional_unverified_lazy_message'):
211+
message.optional_unverified_lazy_message.bb = 128
203212

204213

205214
def SetAllExtensions(message):

0 commit comments

Comments
 (0)