From eb006c9aac536615f3415dc3b4c07d36ab85e890 Mon Sep 17 00:00:00 2001 From: astr0n0mer <42691857+astr0n0mer@users.noreply.github.com> Date: Wed, 16 Oct 2024 23:44:05 +0530 Subject: [PATCH 1/2] feat: allows adding provider to user --- firebase_admin/_user_mgt.py | 9 ++++++++- integration/test_auth.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index aa0dfb0a4..123793e54 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -688,7 +688,8 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None def update_user(self, uid, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None, - valid_since=None, custom_claims=None, providers_to_delete=None): + valid_since=None, custom_claims=None, providers_to_delete=None, + provider_to_add=None): """Updates an existing user account with the specified properties""" payload = { 'localId': _auth_utils.validate_uid(uid, required=True), @@ -727,6 +728,12 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, custom_claims, dict) else custom_claims payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) + if provider_to_add: + payload['linkProviderUserInfo'] = { + 'rawId': uid, + 'providerId': _auth_utils.validate_provider_id(provider_to_add) + } + if remove_provider: payload['deleteProvider'] = list(set(remove_provider)) diff --git a/integration/test_auth.py b/integration/test_auth.py index e1d01a254..94e2068c0 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -16,6 +16,7 @@ import base64 import datetime import random +import re import string import time from typing import List @@ -30,6 +31,7 @@ import firebase_admin from firebase_admin import auth from firebase_admin import credentials +from firebase_admin import exceptions _verify_token_url = 'https://www.googleapis.com/identitytoolkit/v3/relyingparty/verifyCustomToken' @@ -496,6 +498,39 @@ def test_disable_user(new_user_with_params): assert user.disabled is True assert len(user.provider_data) == 1 +def test_add_valid_provider(new_user_with_provider): + new_provider_id = "microsoft.com" + existing_provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] + assert new_provider_id not in existing_provider_ids + user = auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider_id) + assert user.uid == new_user_with_provider.uid + new_provider_ids = [provider.provider_id for provider in user.provider_data] + assert sorted(new_provider_ids) == sorted(existing_provider_ids + [new_provider_id]) + +def test_add_empty_provider(new_user_with_provider): + new_provider_id = "" + existing_provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] + user = auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider_id) + assert user.uid == new_user_with_provider.uid + new_provider_ids = [provider.provider_id for provider in user.provider_data] + assert sorted(new_provider_ids) == sorted(existing_provider_ids) + +def test_add_invalid_provider(new_user_with_provider): + new_provider_id = "xyz" + existing_provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] + assert new_provider_id not in existing_provider_ids + with pytest.raises(exceptions.InvalidArgumentError, match=re.escape( + f"Error while calling Auth service (INVALID_PROVIDER_ID ). provider {new_provider_id} is not supported for linking." + )): + auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider_id) + +def test_add_duplicate_provider(new_user_with_provider): + new_provider_id = "google.com" + with pytest.raises(exceptions.InvalidArgumentError, match=re.escape( + f"Error while calling Auth service (PROVIDER_ALREADY_LINKED)." + )): + auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider_id) + def test_remove_provider(new_user_with_provider): provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] assert 'google.com' in provider_ids From 8568d0920c5c3e2bdc5a012ee2c108ac98cb0078 Mon Sep 17 00:00:00 2001 From: astr0n0mer <42691857+astr0n0mer@users.noreply.github.com> Date: Thu, 17 Oct 2024 14:00:48 +0530 Subject: [PATCH 2/2] fix: changes provider_to_add type from str to UserProvider --- firebase_admin/_user_mgt.py | 9 +++------ integration/test_auth.py | 29 +++++++++++------------------ 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/firebase_admin/_user_mgt.py b/firebase_admin/_user_mgt.py index 123793e54..22802979d 100644 --- a/firebase_admin/_user_mgt.py +++ b/firebase_admin/_user_mgt.py @@ -25,7 +25,7 @@ from firebase_admin import _rfc3339 from firebase_admin import _user_identifier from firebase_admin import _user_import -from firebase_admin._user_import import ErrorInfo +from firebase_admin._user_import import ErrorInfo, UserProvider MAX_LIST_USERS_RESULTS = 1000 @@ -689,7 +689,7 @@ def create_user(self, uid=None, display_name=None, email=None, phone_number=None def update_user(self, uid, display_name=None, email=None, phone_number=None, photo_url=None, password=None, disabled=None, email_verified=None, valid_since=None, custom_claims=None, providers_to_delete=None, - provider_to_add=None): + provider_to_add: UserProvider | None=None): """Updates an existing user account with the specified properties""" payload = { 'localId': _auth_utils.validate_uid(uid, required=True), @@ -729,10 +729,7 @@ def update_user(self, uid, display_name=None, email=None, phone_number=None, payload['customAttributes'] = _auth_utils.validate_custom_claims(json_claims) if provider_to_add: - payload['linkProviderUserInfo'] = { - 'rawId': uid, - 'providerId': _auth_utils.validate_provider_id(provider_to_add) - } + payload['linkProviderUserInfo'] = provider_to_add.to_dict() if remove_provider: payload['deleteProvider'] = list(set(remove_provider)) diff --git a/integration/test_auth.py b/integration/test_auth.py index 94e2068c0..dfcab5f8c 100644 --- a/integration/test_auth.py +++ b/integration/test_auth.py @@ -499,37 +499,30 @@ def test_disable_user(new_user_with_params): assert len(user.provider_data) == 1 def test_add_valid_provider(new_user_with_provider): - new_provider_id = "microsoft.com" + new_provider = auth.UserProvider(uid=new_user_with_provider.uid, provider_id='microsoft.com') existing_provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] - assert new_provider_id not in existing_provider_ids - user = auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider_id) + assert new_provider.provider_id not in existing_provider_ids + user = auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider) assert user.uid == new_user_with_provider.uid new_provider_ids = [provider.provider_id for provider in user.provider_data] - assert sorted(new_provider_ids) == sorted(existing_provider_ids + [new_provider_id]) - -def test_add_empty_provider(new_user_with_provider): - new_provider_id = "" - existing_provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] - user = auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider_id) - assert user.uid == new_user_with_provider.uid - new_provider_ids = [provider.provider_id for provider in user.provider_data] - assert sorted(new_provider_ids) == sorted(existing_provider_ids) + assert sorted(new_provider_ids) == sorted(existing_provider_ids + [new_provider.provider_id]) def test_add_invalid_provider(new_user_with_provider): - new_provider_id = "xyz" + new_provider = auth.UserProvider(uid=new_user_with_provider.uid, provider_id='xyz.com') existing_provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data] - assert new_provider_id not in existing_provider_ids + assert new_provider.provider_id not in existing_provider_ids with pytest.raises(exceptions.InvalidArgumentError, match=re.escape( - f"Error while calling Auth service (INVALID_PROVIDER_ID ). provider {new_provider_id} is not supported for linking." + f"Error while calling Auth service (INVALID_PROVIDER_ID ). provider {new_provider.provider_id} is not supported for linking." )): - auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider_id) + auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider) def test_add_duplicate_provider(new_user_with_provider): - new_provider_id = "google.com" + google_uid, google_email = _random_id() + duplicate_provider = auth.UserProvider(uid=google_uid, provider_id='google.com', email=google_email) with pytest.raises(exceptions.InvalidArgumentError, match=re.escape( f"Error while calling Auth service (PROVIDER_ALREADY_LINKED)." )): - auth.update_user(new_user_with_provider.uid, provider_to_add=new_provider_id) + auth.update_user(new_user_with_provider.uid, provider_to_add=duplicate_provider) def test_remove_provider(new_user_with_provider): provider_ids = [provider.provider_id for provider in new_user_with_provider.provider_data]