9
9
10
10
import unittest
11
11
12
- from google .protobuf import field_mask_pb2
12
+ from google .protobuf import descriptor
13
13
from google .protobuf .internal import field_mask
14
14
from google .protobuf .internal import test_util
15
- from google .protobuf import descriptor
15
+
16
+ from google .protobuf import field_mask_pb2
16
17
from google .protobuf import map_unittest_pb2
18
+ from google .protobuf import unittest_no_field_presence_pb2
17
19
from google .protobuf import unittest_pb2
18
20
19
21
@@ -106,22 +108,16 @@ def testCanonicalFrom(self):
106
108
self .assertEqual ('bar,foo.b1,foo.b2' , out_mask .ToJsonString ())
107
109
108
110
# Test more deeply nested cases.
109
- mask .FromJsonString (
110
- 'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2' )
111
+ mask .FromJsonString ('foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2' )
111
112
out_mask .CanonicalFormFromMask (mask )
112
- self .assertEqual ('foo.bar.baz1,foo.bar.baz2' ,
113
- out_mask .ToJsonString ())
114
- mask .FromJsonString (
115
- 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz' )
113
+ self .assertEqual ('foo.bar.baz1,foo.bar.baz2' , out_mask .ToJsonString ())
114
+ mask .FromJsonString ('foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz' )
116
115
out_mask .CanonicalFormFromMask (mask )
117
- self .assertEqual ('foo.bar.baz1,foo.bar.baz2' ,
118
- out_mask .ToJsonString ())
119
- mask .FromJsonString (
120
- 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar' )
116
+ self .assertEqual ('foo.bar.baz1,foo.bar.baz2' , out_mask .ToJsonString ())
117
+ mask .FromJsonString ('foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar' )
121
118
out_mask .CanonicalFormFromMask (mask )
122
119
self .assertEqual ('foo.bar' , out_mask .ToJsonString ())
123
- mask .FromJsonString (
124
- 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo' )
120
+ mask .FromJsonString ('foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo' )
125
121
out_mask .CanonicalFormFromMask (mask )
126
122
self .assertEqual ('foo' , out_mask .ToJsonString ())
127
123
@@ -291,6 +287,36 @@ def testMergeMessageWithoutMapFields(self):
291
287
self .assertTrue (dst .HasField ('foo_message' ))
292
288
self .assertFalse (dst .HasField ('foo_lazy_message' ))
293
289
290
+ def testMergeMessageWithoutMapFieldsOrFieldPresence (self ):
291
+ # Test merge one field.
292
+ src = unittest_no_field_presence_pb2 .TestAllTypes ()
293
+ test_util .SetAllFields (src )
294
+ for field in src .DESCRIPTOR .fields :
295
+ if field .containing_oneof :
296
+ continue
297
+ field_name = field .name
298
+ dst = unittest_no_field_presence_pb2 .TestAllTypes ()
299
+ # Only set one path to mask.
300
+ mask = field_mask_pb2 .FieldMask ()
301
+ mask .paths .append (field_name )
302
+ mask .MergeMessage (src , dst )
303
+ # The expected result message.
304
+ msg = unittest_no_field_presence_pb2 .TestAllTypes ()
305
+ if field .label == descriptor .FieldDescriptor .LABEL_REPEATED :
306
+ repeated_src = getattr (src , field_name )
307
+ repeated_msg = getattr (msg , field_name )
308
+ if field .cpp_type == descriptor .FieldDescriptor .CPPTYPE_MESSAGE :
309
+ for item in repeated_src :
310
+ repeated_msg .add ().CopyFrom (item )
311
+ else :
312
+ repeated_msg .extend (repeated_src )
313
+ elif field .cpp_type == descriptor .FieldDescriptor .CPPTYPE_MESSAGE :
314
+ getattr (msg , field_name ).CopyFrom (getattr (src , field_name ))
315
+ else :
316
+ setattr (msg , field_name , getattr (src , field_name ))
317
+ # Only field specified in mask is merged.
318
+ self .assertEqual (msg , dst )
319
+
294
320
def testMergeMessageWithMapField (self ):
295
321
empty_map = map_unittest_pb2 .TestRecursiveMapMessage ()
296
322
src_level_2 = map_unittest_pb2 .TestRecursiveMapMessage ()
@@ -314,6 +340,74 @@ def testMergeMessageWithMapField(self):
314
340
self .assertEqual (dst .a ['src level 1' ], src_level_2 )
315
341
self .assertEqual (dst .a ['dst level 1' ], empty_map )
316
342
343
+ def testMergeMessageWithUnsetFieldsWithFieldPresence (self ):
344
+ # Test merging each empty field one at a time.
345
+ src = unittest_pb2 .TestAllTypes ()
346
+ for field in src .DESCRIPTOR .fields :
347
+ if field .containing_oneof :
348
+ continue
349
+ field_name = field .name
350
+ dst = unittest_pb2 .TestAllTypes ()
351
+ # Only set one path to mask.
352
+ mask = field_mask_pb2 .FieldMask ()
353
+ mask .paths .append (field_name )
354
+ mask .MergeMessage (src , dst )
355
+ # Nothing should be merged.
356
+ self .assertEqual (unittest_pb2 .TestAllTypes (), dst )
357
+
358
+ # Test merge clears previously set fields when source is unset.
359
+ dst_template = unittest_pb2 .TestAllTypes ()
360
+ test_util .SetAllFields (dst_template )
361
+ for field in src .DESCRIPTOR .fields :
362
+ if field .containing_oneof :
363
+ continue
364
+ dst = unittest_pb2 .TestAllTypes ()
365
+ dst .CopyFrom (dst_template )
366
+ # Only set one path to mask.
367
+ mask = field_mask_pb2 .FieldMask ()
368
+ mask .paths .append (field .name )
369
+ mask .MergeMessage (
370
+ src , dst , replace_message_field = True , replace_repeated_field = True
371
+ )
372
+ msg = unittest_pb2 .TestAllTypes ()
373
+ msg .CopyFrom (dst_template )
374
+ msg .ClearField (field .name )
375
+ self .assertEqual (msg , dst )
376
+
377
+ def testMergeMessageWithUnsetFieldsWithoutFieldPresence (self ):
378
+ # Test merging each empty field one at a time.
379
+ src = unittest_no_field_presence_pb2 .TestAllTypes ()
380
+ for field in src .DESCRIPTOR .fields :
381
+ if field .containing_oneof :
382
+ continue
383
+ field_name = field .name
384
+ dst = unittest_no_field_presence_pb2 .TestAllTypes ()
385
+ # Only set one path to mask.
386
+ mask = field_mask_pb2 .FieldMask ()
387
+ mask .paths .append (field_name )
388
+ mask .MergeMessage (src , dst )
389
+ # Nothing should be merged.
390
+ self .assertEqual (unittest_no_field_presence_pb2 .TestAllTypes (), dst )
391
+
392
+ # Test merge clears previously set fields when source is unset.
393
+ dst_template = unittest_no_field_presence_pb2 .TestAllTypes ()
394
+ test_util .SetAllFields (dst_template )
395
+ for field in src .DESCRIPTOR .fields :
396
+ if field .containing_oneof :
397
+ continue
398
+ dst = unittest_no_field_presence_pb2 .TestAllTypes ()
399
+ dst .CopyFrom (dst_template )
400
+ # Only set one path to mask.
401
+ mask = field_mask_pb2 .FieldMask ()
402
+ mask .paths .append (field .name )
403
+ mask .MergeMessage (
404
+ src , dst , replace_message_field = True , replace_repeated_field = True
405
+ )
406
+ msg = unittest_no_field_presence_pb2 .TestAllTypes ()
407
+ msg .CopyFrom (dst_template )
408
+ msg .ClearField (field .name )
409
+ self .assertEqual (msg , dst )
410
+
317
411
def testMergeErrors (self ):
318
412
src = unittest_pb2 .TestAllTypes ()
319
413
dst = unittest_pb2 .TestAllTypes ()
@@ -322,25 +416,26 @@ def testMergeErrors(self):
322
416
mask .FromJsonString ('optionalInt32.field' )
323
417
with self .assertRaises (ValueError ) as e :
324
418
mask .MergeMessage (src , dst )
325
- self .assertEqual ('Error: Field optional_int32 in message '
326
- 'proto2_unittest.TestAllTypes is not a singular '
327
- 'message field and cannot have sub-fields.' ,
328
- str (e .exception ))
419
+ self .assertEqual (
420
+ 'Error: Field optional_int32 in message '
421
+ 'proto2_unittest.TestAllTypes is not a singular '
422
+ 'message field and cannot have sub-fields.' ,
423
+ str (e .exception ),
424
+ )
329
425
330
426
def testSnakeCaseToCamelCase (self ):
331
- self .assertEqual ('fooBar' ,
332
- field_mask ._SnakeCaseToCamelCase ('foo_bar' ))
333
- self .assertEqual ('FooBar' ,
334
- field_mask ._SnakeCaseToCamelCase ('_foo_bar' ))
335
- self .assertEqual ('foo3Bar' ,
336
- field_mask ._SnakeCaseToCamelCase ('foo3_bar' ))
427
+ self .assertEqual ('fooBar' , field_mask ._SnakeCaseToCamelCase ('foo_bar' ))
428
+ self .assertEqual ('FooBar' , field_mask ._SnakeCaseToCamelCase ('_foo_bar' ))
429
+ self .assertEqual ('foo3Bar' , field_mask ._SnakeCaseToCamelCase ('foo3_bar' ))
337
430
338
431
# No uppercase letter is allowed.
339
432
self .assertRaisesRegex (
340
433
ValueError ,
341
434
'Fail to print FieldMask to Json string: Path name Foo must '
342
435
'not contain uppercase letters.' ,
343
- field_mask ._SnakeCaseToCamelCase , 'Foo' )
436
+ field_mask ._SnakeCaseToCamelCase ,
437
+ 'Foo' ,
438
+ )
344
439
# Any character after a "_" must be a lowercase letter.
345
440
# 1. "_" cannot be followed by another "_".
346
441
# 2. "_" cannot be followed by a digit.
@@ -349,28 +444,34 @@ def testSnakeCaseToCamelCase(self):
349
444
ValueError ,
350
445
'Fail to print FieldMask to Json string: The character after a '
351
446
'"_" must be a lowercase letter in path name foo__bar.' ,
352
- field_mask ._SnakeCaseToCamelCase , 'foo__bar' )
447
+ field_mask ._SnakeCaseToCamelCase ,
448
+ 'foo__bar' ,
449
+ )
353
450
self .assertRaisesRegex (
354
451
ValueError ,
355
452
'Fail to print FieldMask to Json string: The character after a '
356
453
'"_" must be a lowercase letter in path name foo_3bar.' ,
357
- field_mask ._SnakeCaseToCamelCase , 'foo_3bar' )
454
+ field_mask ._SnakeCaseToCamelCase ,
455
+ 'foo_3bar' ,
456
+ )
358
457
self .assertRaisesRegex (
359
458
ValueError ,
360
459
'Fail to print FieldMask to Json string: Trailing "_" in path '
361
- 'name foo_bar_.' , field_mask ._SnakeCaseToCamelCase , 'foo_bar_' )
460
+ 'name foo_bar_.' ,
461
+ field_mask ._SnakeCaseToCamelCase ,
462
+ 'foo_bar_' ,
463
+ )
362
464
363
465
def testCamelCaseToSnakeCase (self ):
364
- self .assertEqual ('foo_bar' ,
365
- field_mask ._CamelCaseToSnakeCase ('fooBar' ))
366
- self .assertEqual ('_foo_bar' ,
367
- field_mask ._CamelCaseToSnakeCase ('FooBar' ))
368
- self .assertEqual ('foo3_bar' ,
369
- field_mask ._CamelCaseToSnakeCase ('foo3Bar' ))
466
+ self .assertEqual ('foo_bar' , field_mask ._CamelCaseToSnakeCase ('fooBar' ))
467
+ self .assertEqual ('_foo_bar' , field_mask ._CamelCaseToSnakeCase ('FooBar' ))
468
+ self .assertEqual ('foo3_bar' , field_mask ._CamelCaseToSnakeCase ('foo3Bar' ))
370
469
self .assertRaisesRegex (
371
470
ValueError ,
372
471
'Fail to parse FieldMask: Path name foo_bar must not contain "_"s.' ,
373
- field_mask ._CamelCaseToSnakeCase , 'foo_bar' )
472
+ field_mask ._CamelCaseToSnakeCase ,
473
+ 'foo_bar' ,
474
+ )
374
475
375
476
376
477
if __name__ == '__main__' :
0 commit comments