We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 56453cd commit a4ae812Copy full SHA for a4ae812
official/legacy/image_classification/resnet/common.py
@@ -106,10 +106,14 @@ def get_config(self):
106
}
107
108
109
-def get_optimizer(learning_rate=0.1):
+def get_optimizer(learning_rate=0.1, use_legacy_optimizer=True):
110
"""Returns optimizer to use."""
111
# The learning_rate is overwritten at the beginning of each step by callback.
112
- return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
+ if use_legacy_optimizer:
113
+ return tf.keras.optimizers.legacy.SGD(
114
+ learning_rate=learning_rate, momentum=0.9)
115
+ else:
116
+ return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
117
118
119
def get_callbacks(pruning_method=None,
0 commit comments