Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fief/alembic.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ hooks = codemod,format,lint

codemod.type = exec
codemod.executable = python
codemod.options = -m libcst.tool codemod -x fief.alembic.table_prefix_codemod.ConvertTablePrefixStrings REVISION_SCRIPT_FILENAME
codemod.options = -m libcst.tool codemod --no-format -x fief.alembic.table_prefix_codemod.ConvertTablePrefixStrings REVISION_SCRIPT_FILENAME

format.type = exec
format.executable = ruff
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Add unique constraint on OAuthAccount provider_id and account_id

Revision ID: a736fe95ec4f
Revises: 6d9fa141730c
Create Date: 2024-08-29 09:01:04.397106

"""

import sqlalchemy as sa
from alembic import op

import fief

# revision identifiers, used by Alembic.
revision = "a736fe95ec4f"
down_revision = "6d9fa141730c"
branch_labels = None
depends_on = None


def upgrade():
table_prefix = op.get_context().opts["table_prefix"]
# ### commands auto generated by Alembic - please adjust! ###
op.execute(
f"""
DELETE FROM {table_prefix}oauth_accounts
WHERE user_id IS NULL
"""
)

connection = op.get_bind()
if connection.dialect.name == "sqlite":
with op.batch_alter_table(f"{table_prefix}oauth_accounts") as batch_op:
batch_op.create_unique_constraint(
op.f(f"{table_prefix}oauth_accounts_oauth_provider_id_account_id_key"),
["oauth_provider_id", "account_id"],
)
else:
op.create_unique_constraint(
op.f(f"{table_prefix}oauth_accounts_oauth_provider_id_account_id_key"),
f"{table_prefix}oauth_accounts",
["oauth_provider_id", "account_id"],
)
# ### end Alembic commands ###


def downgrade():
table_prefix = op.get_context().opts["table_prefix"]
# ### commands auto generated by Alembic - please adjust! ###

connection = op.get_bind()
if connection.dialect.name == "sqlite":
with op.batch_alter_table(f"{table_prefix}oauth_accounts") as batch_op:
batch_op.drop_constraint(
op.f(f"{table_prefix}oauth_accounts_oauth_provider_id_account_id_key"),
type_="unique",
)
else:
op.drop_constraint(
op.f(f"{table_prefix}oauth_accounts_oauth_provider_id_account_id_key"),
f"{table_prefix}oauth_accounts",
type_="unique",
)
# ### end Alembic commands ###
46 changes: 31 additions & 15 deletions fief/apps/auth/routers/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ async def callback(
)

# Existing account
if oauth_account is not None and oauth_account.user is not None:
if oauth_account is not None:
user = oauth_account.user
if not user.is_active:
if user and not user.is_active:
raise OAuthException(
OAuthError.get_inactive_user(_("Your account is inactive.")),
oauth_providers=oauth_providers,
Expand All @@ -149,23 +149,39 @@ async def callback(
oauth_account.expires_at = expires_at
await oauth_account_repository.update(oauth_account)

# Redirect to consent or profile
if login_session is not None:
response = RedirectResponse(
tenant.url_path_for(request, "auth:consent"),
status_code=status.HTTP_302_FOUND,
if user:
# Redirect to consent or profile
if login_session is not None:
response = RedirectResponse(
tenant.url_path_for(request, "auth:consent"),
status_code=status.HTTP_302_FOUND,
)
else:
response = RedirectResponse(
tenant.url_path_for(request, "auth.dashboard:profile"),
status_code=status.HTTP_302_FOUND,
)
response = await authentication_flow.rotate_session_token(
response, user.id, session_token=session_token
)
else:
response = RedirectResponse(
tenant.url_path_for(request, "auth.dashboard:profile"),
status_code=status.HTTP_302_FOUND,
response = await authentication_flow.set_login_hint(
response, str(oauth_provider.id)
)
response = await authentication_flow.rotate_session_token(
response, user.id, session_token=session_token
return response

# Redirect to register
response = RedirectResponse(
tenant.url_path_for(request, "register:register"),
status_code=status.HTTP_302_FOUND,
)
response = await authentication_flow.set_login_hint(
response, str(oauth_provider.id)

await registration_flow.create_registration_session(
response,
RegistrationSessionFlow.OAUTH,
tenant=tenant,
oauth_account=oauth_account,
)

return response

# New account to create
Expand Down
11 changes: 11 additions & 0 deletions fief/models/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase

from fief.settings import settings
Expand All @@ -16,6 +17,16 @@ def get_prefixed_tablename(name: str) -> str:


class Base(DeclarativeBase):
metadata = MetaData(
naming_convention={
"ix": "ix_%(column_0_label)s",
"uq": "%(table_name)s_%(column_0_N_name)s_key",
"ck": "ck_%(table_name)s_`%(constraint_name)s`",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
)

def __init_subclass__(cls) -> None:
cls.__tablename__ = get_prefixed_tablename(cls.__tablename__)
super().__init_subclass__()
5 changes: 4 additions & 1 deletion fief/models/oauth_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

class OAuthAccount(UUIDModel, CreatedUpdatedAt, Base):
__tablename__ = "oauth_accounts"
__table_args__ = (UniqueConstraint("oauth_provider_id", "user_id"),)
__table_args__ = (
UniqueConstraint("oauth_provider_id", "user_id"),
UniqueConstraint("oauth_provider_id", "account_id"),
)

access_token: Mapped[str] = mapped_column(
StringEncryptedType(Text, settings.encryption_key, FernetEngine), nullable=False
Expand Down
2 changes: 1 addition & 1 deletion fief/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ async def all(self) -> list[M]:

async def get_one_or_none(self, statement: Select) -> M | None:
result = await self._execute_query(statement)
return result.scalar()
return result.scalar_one_or_none()

async def create(self, object: M) -> M:
self.session.add(object)
Expand Down
73 changes: 73 additions & 0 deletions tests/test_apps_auth_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,79 @@ async def test_existing_oauth_account(
assert oauth_account.expires_at is not None
assert updated_oauth_account.expires_at > oauth_account.expires_at

async def test_dangling_account(
self,
mocker: MockerFixture,
test_client_auth: httpx.AsyncClient,
test_data: TestData,
main_session: AsyncSession,
):
login_session = test_data["login_sessions"]["default"]
client = login_session.client
tenant = client.tenant
path_prefix = tenant.slug if not tenant.default else ""

oauth_session = test_data["oauth_sessions"]["default_google"]
oauth_account = test_data["oauth_accounts"]["new_user_google"]

cookies = {}
cookies[settings.login_session_cookie_name] = login_session.token

oauth_provider_service_mock = MagicMock(spec=BaseOAuth2)
oauth_provider_service_mock.get_access_token.side_effect = AsyncMock(
return_value={
"access_token": "ACCESS_TOKEN",
"expires_in": 3600,
"expires_at": int(datetime.now(UTC).timestamp() + 3600),
"refresh_token": "REFRESH_TOKEN",
}
)
oauth_provider_service_mock.get_id_email.side_effect = AsyncMock(
return_value=(oauth_account.account_id, oauth_account.account_email)
)
mocker.patch(
"fief.apps.auth.routers.oauth.get_oauth_provider_service"
).return_value = oauth_provider_service_mock

response = await test_client_auth.get(
"/oauth/callback",
params={
"code": "CODE",
"redirect_uri": oauth_session.redirect_uri,
"state": oauth_session.token,
},
cookies=cookies,
)

assert response.status_code == status.HTTP_302_FOUND

redirect_uri = response.headers["Location"]
assert redirect_uri.endswith(f"{path_prefix}/register")

oauth_account_repository = OAuthAccountRepository(main_session)
updated_oauth_account = (
await oauth_account_repository.get_by_provider_and_account_id(
oauth_session.oauth_provider_id, oauth_account.account_id
)
)
assert updated_oauth_account is not None
assert updated_oauth_account.id == oauth_account.id
assert updated_oauth_account.access_token == "ACCESS_TOKEN"
assert updated_oauth_account.refresh_token == "REFRESH_TOKEN"
assert updated_oauth_account.user is None

registration_session_cookie = response.cookies[
settings.registration_session_cookie_name
]
registration_session_repository = RegistrationSessionRepository(main_session)
registration_session = await registration_session_repository.get_by_token(
registration_session_cookie
)
assert registration_session is not None
assert registration_session.flow == RegistrationSessionFlow.OAUTH
assert registration_session.oauth_account_id == updated_oauth_account.id
assert registration_session.email == updated_oauth_account.account_email

async def test_new_account(
self,
mocker: MockerFixture,
Expand Down