Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7b2a049
feat: added google auth button
VictoriaTskhondia May 2, 2025
fd4386f
fix: fixed linting errors
VictoriaTskhondia May 2, 2025
e3a855b
feat: add auth0 configs
VictoriaTskhondia May 3, 2025
8200fd3
feat: add middleware for sessions
VictoriaTskhondia May 3, 2025
e419fe7
feat: update main.py
VictoriaTskhondia May 3, 2025
7bbde57
feat: add auth0 api endpoints
VictoriaTskhondia May 3, 2025
b094f3d
feat: add missing library
VictoriaTskhondia May 3, 2025
6f596a0
feat: minor change
VictoriaTskhondia May 3, 2025
2f7f91f
fix: apply ruff formating
VictoriaTskhondia May 3, 2025
0a8f3de
fix: apply ruff formating
VictoriaTskhondia May 3, 2025
aa7c339
feat: integrate backend
VictoriaTskhondia May 11, 2025
2acc74d
fix: remove redundant route
VictoriaTskhondia May 11, 2025
bb6626f
fix: remove redundant service file
VictoriaTskhondia May 11, 2025
bf211a5
feat: add GitHub action for backend testing (#77)
ryan331913 May 4, 2025
6b457c6
fix: update migration file
VictoriaTskhondia May 13, 2025
c9e2ffc
fix: add missing package
VictoriaTskhondia May 13, 2025
0513047
fix: add missing authlib package to pyproject.toml
VictoriaTskhondia May 13, 2025
78d0be7
feat: connect backend to frontend
VictoriaTskhondia May 16, 2025
cd97756
fix: fix lint issues
VictoriaTskhondia May 16, 2025
56efd12
fix: fix lint issues
VictoriaTskhondia May 16, 2025
c7a1354
fix: fix ruff issues
VictoriaTskhondia May 16, 2025
efab959
fix: fix lint errors in services.py
VictoriaTskhondia May 16, 2025
c9299ab
fix: fix lint errors in api.py
VictoriaTskhondia May 16, 2025
facde11
feat: replace hardcoded urls
VictoriaTskhondia May 18, 2025
48b0894
feat: add frontend_url and allowed_redirect_origins env variables
VictoriaTskhondia May 18, 2025
84bc11c
Merge branch 'main' into feat/issue-32-enable-google-with-auth0
VictoriaTskhondia May 18, 2025
8a20de7
fix: fix frontend lint errors
VictoriaTskhondia May 18, 2025
4ddf3f5
feat: minor changes
VictoriaTskhondia May 18, 2025
854f9e1
feat: use dependecy injection for db session in auth0_callback
VictoriaTskhondia May 19, 2025
dcff54f
fix: fix lint errors
VictoriaTskhondia May 19, 2025
9ebfb83
fix: organize imports
VictoriaTskhondia May 19, 2025
9ee051d
fix: fix lint errors in auth0_api
VictoriaTskhondia May 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ [email protected]
FIRST_SUPERUSER_PASSWORD=changethis
USERS_OPEN_REGISTRATION=True

# Frontend
FRONTEND_URL="http://localhost:5173"

# Postgres
POSTGRES_SERVER=localhost
Expand All @@ -24,3 +26,15 @@ AI_API_KEY="dummy_api_key"

COLLECTION_GENERATION_PROMPT="I want to generate flashcards on a specific topic for efficient studying. Please create a set of flashcards covering key concepts, definitions, important details, and examples, with a focus on progressively building understanding of the topic. The flashcards should aim to provide a helpful learning experience by using structured explanations, real-world examples and formatting. Each flashcard should follow this format: Front (Question/Prompt): A clear and concise question or term to test recall, starting with introductory concepts and moving toward more complex details. Back (Answer): If the front is a concept or topic, provide a detailed explanation, broken down into clear paragraphs with easy-to-understand language. If possible, include a real-world example, analogy or illustrative diagrams to make the concept more memorable and relatable. If the front is a vocabulary word (for language learning), provide a direct translation in the target language. Optional Hint: A short clue to aid recall, especially for more complex concepts. Important: Use valid Markdown format for the back of the flashcard."
CARD_GENERATION_PROMPT="I want to generate a flashcard on a specific topic. The contents of the flashcard should provide helpful information that aim to help the learner retain the concepts given. The flashcard must follow this format: Front (Question/Prompt): A clear and concise question or term to test recall. Back (Answer): If the front is a concept or topic, provide a detailed explanation, broken down into clear paragraphs with easy-to-understand language. If possible, include a real-world example, analogy or illustrative diagrams to make the concept more memorable and relatable. If the front is a vocabulary word (for language learning), provide a direct translation in the target language. Important: Use valid Markdown format for the back of the flashcard."

# Auth0 Configuration
AUTH0_DOMAIN=auth0-domain
AUTH0_CLIENT_ID=auth0-client-id
AUTH0_CLIENT_SECRET=auth0-client-secret
AUTH0_CALLBACK_URL=auth0-callback-url

ALLOWED_REDIRECT_ORIGINS=http://localhost:5173

# Session Configuration
SECRET_KEY=secret-key

3 changes: 3 additions & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ dependencies = [
"fastapi-pagination>=0.12.34",
"bcrypt==4.0.1",
"google-genai>=1.5.0",
"starlette (>=0.46.2,<0.47.0)",
"itsdangerous (>=2.2.0,<3.0.0)",
"authlib (>=1.5.2,<2.0.0)",
]

[tool.uv]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Add auth0_id field to Users table

Revision ID: 1425c896d3ef
Revises: cb16ae472c1e
Create Date: 2025-05-10 00:12:00.358973

"""
from alembic import op
import sqlalchemy as sa
import sqlmodel.sql.sqltypes


# revision identifiers, used by Alembic.
revision = '1425c896d3ef'
down_revision = 'cb16ae472c1e'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('user', sa.Column('auth0_id', sa.String(), nullable=True))

op.alter_column('user', 'hashed_password',
existing_type=sa.VARCHAR(),
nullable=True)
op.create_index(op.f('ix_user_auth0_id'), 'user', ['auth0_id'], unique=False)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_user_auth0_id'), table_name='user')
op.drop_column('user', 'auth0_id')
op.alter_column('user', 'hashed_password',
existing_type=sa.VARCHAR(),
nullable=False)
# ### end Alembic commands ###
19 changes: 12 additions & 7 deletions backend/src/auth/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timedelta
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordRequestForm

from src.auth.schemas import Token
Expand All @@ -15,7 +15,9 @@

@router.post("/tokens")
def login_access_token(
session: SessionDep, form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
request: Request,
session: SessionDep,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> Token:
user = services.authenticate(
session=session, email=form_data.username, password=form_data.password
Expand All @@ -24,9 +26,12 @@ def login_access_token(
raise HTTPException(status_code=400, detail="Incorrect email or password")
elif not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")

access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
return Token(
access_token=services.create_access_token(
user.id, expires_delta=access_token_expires
)
)

# Create token
token = services.create_access_token(user.id, expires_delta=access_token_expires)

request.session["user_id"] = str(user.id)

return Token(access_token=token)
68 changes: 68 additions & 0 deletions backend/src/auth/auth0_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from authlib.integrations.starlette_client import OAuth
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse
from sqlmodel import Session

from src.auth.services import get_or_create_user_by_email # Fix the import path
from src.core.config import settings
from src.core.db import get_db

router = APIRouter(tags=["auth"])

oauth = OAuth()
oauth.register(
name="auth0",
client_id=settings.AUTH0_CLIENT_ID,
client_secret=settings.AUTH0_CLIENT_SECRET,
server_metadata_url=f"https://{settings.AUTH0_DOMAIN}/.well-known/openid-configuration",
client_kwargs={"scope": "openid profile email"},
)


@router.get("/login", name="auth0_login")
async def login(request: Request, redirect_to: str = "/collections"):
request.session["redirect_to"] = redirect_to
redirect_uri = request.url_for("auth0_callback")
return await oauth.auth0.authorize_redirect(
request, redirect_uri, prompt="select_account", connection="google-oauth2"
)


@router.get("/callback", name="auth0_callback")
async def auth0_callback(request: Request, db: Session = Depends(get_db)):
# Exchange code for token
token = await oauth.auth0.authorize_access_token(request)

# Extract user info
user_info = token.get("userinfo")
if not user_info:
user_info = await oauth.auth0.userinfo(token=token)

if not user_info or "email" not in user_info:
raise HTTPException(status_code=400, detail="Invalid user info from Auth0")

# Create or get user in local DB
db_user = get_or_create_user_by_email(
session=db,
email=user_info["email"],
defaults={
"auth0_id": user_info["sub"],
"full_name": user_info.get("name"),
"is_active": True,
},
)

# Store user in session
request.session["user_id"] = str(db_user.id)

# Determine redirect target
redirect_to = request.session.pop("redirect_to", "/collections")
redirect_url = f"{settings.FRONTEND_URL}{redirect_to}"

return RedirectResponse(url=redirect_url)


@router.get("/logout", name="auth0_logout")
async def logout(request: Request):
request.session.clear()
return {"detail": "Logged out"}
93 changes: 80 additions & 13 deletions backend/src/auth/services.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import uuid
from datetime import datetime, timedelta, timezone
from typing import Annotated, Any

import jwt
from fastapi import Depends, HTTPException, status
from fastapi import Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from pydantic import ValidationError
from sqlmodel import Session
from sqlmodel import Session, select

from src.auth.schemas import TokenPayload
from src.core.config import settings
from src.core.db import get_db
from src.users.models import User
from src.users.schemas import UserPublic

ALGORITHM = "HS256"

Expand All @@ -23,21 +25,64 @@
TokenDep = Annotated[str, Depends(reusable_oauth2)]


def get_current_user(session: SessionDep, token: TokenDep) -> User:
def get_user_from_session(request: Request, session: SessionDep) -> User:
user_id = request.session.get("user_id")
if not user_id:
raise HTTPException(status_code=401, detail="Not authenticated (no session)")

from src.users.services import get_user_by_id

try:
user_uuid = uuid.UUID(user_id)
user = get_user_by_id(session=session, user_id=user_uuid)
if not user or not user.is_active:
raise HTTPException(status_code=401, detail="Invalid session user")
return UserPublic.model_validate(user)
except (ValueError, TypeError):
raise HTTPException(status_code=401, detail="Invalid user ID in session")


def get_user_from_token(
session: SessionDep,
token: Annotated[str, Depends(reusable_oauth2)],
) -> User:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
token_data = TokenPayload(**payload)

user = session.get(User, token_data.sub)

if not user or not user.is_active:
raise HTTPException(status_code=401, detail="Invalid user")

return user

except (InvalidTokenError, ValidationError):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Could not validate credentials",
)
user = session.get(User, token_data.sub)
if not user:
raise HTTPException(status_code=404, detail="User not found")
if not user.is_active:
raise HTTPException(status_code=400, detail="Inactive user")
return user
raise HTTPException(status_code=403, detail="Invalid token")


def get_current_user(
request: Request,
session: SessionDep,
token: Annotated[
str | None,
Depends(
OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/tokens", auto_error=False
)
),
] = None,
) -> User:
# Check for token-based authentication first
if token:
return get_user_from_token(session, token)

# Check for session-based authentication
if request.session.get("user_id"):
return get_user_from_session(request, session)

# No valid authentication method found
raise HTTPException(status_code=401, detail="Not authenticated")


CurrentUser = Annotated[User, Depends(get_current_user)]
Expand All @@ -53,8 +98,15 @@ def authenticate(*, session: Session, email: str, password: str) -> User | None:
db_user = get_user_by_email(session=session, email=email)
if not db_user:
return None

# Auth0 users may not have a password
if not db_user.hashed_password:
# Return None for users without a password when using password authentication
return None

if not verify_password(password, db_user.hashed_password):
return None

return db_user


Expand All @@ -67,3 +119,18 @@ def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:

def get_password_hash(password: str) -> str:
return pwd_context.hash(password)


def get_or_create_user_by_email(
session: Session,
email: str,
defaults: dict | None = None,
) -> User:
user = session.exec(select(User).where(User.email == email)).first()
if user:
return user
user = User(email=email, **defaults)
session.add(user)
session.commit()
session.refresh(user)
return user
12 changes: 12 additions & 0 deletions backend/src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,12 @@ class Settings(BaseSettings):
SECRET_KEY: str = secrets.token_urlsafe(32)
# 60 minutes * 24 hours * 8 days = 8 days
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
AUTH_EXPIRE_MINUTES: int = 60 * 24 * 8
DOMAIN: str = "localhost"
ENVIRONMENT: Literal["local", "staging", "production"] = "local"

FRONTEND_URL: str

BACKEND_CORS_ORIGINS: Annotated[
list[AnyUrl] | str, BeforeValidator(parse_cors)
] = []
Expand All @@ -46,6 +49,14 @@ class Settings(BaseSettings):
POSTGRES_PASSWORD: str
POSTGRES_DB: str = ""

AUTH0_CLIENT_ID: str
AUTH0_CLIENT_SECRET: str
AUTH0_DOMAIN: str

ALLOWED_REDIRECT_ORIGINS: Annotated[
list[str] | str, BeforeValidator(parse_cors)
] = []

@computed_field # type: ignore[misc]
@property
def SQLALCHEMY_DATABASE_URI(self) -> PostgresDsn:
Expand Down Expand Up @@ -91,6 +102,7 @@ def _enforce_non_default_secrets(self) -> Self:
self._check_default_secret(
"FIRST_SUPERUSER_PASSWORD", self.FIRST_SUPERUSER_PASSWORD
)
self._check_default_secret("AUTH0_CLIENT_SECRET", self.AUTH0_CLIENT_SECRET)

return self

Expand Down
10 changes: 10 additions & 0 deletions backend/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from fastapi.routing import APIRoute
from fastapi_pagination import add_pagination
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware

from src.core.config import settings
from src.routers import api_router
Expand Down Expand Up @@ -29,5 +30,14 @@ def custom_generate_unique_id(route: APIRoute) -> str:
allow_headers=["*"],
)

# Setup session middleware
app.add_middleware(
SessionMiddleware,
secret_key=settings.SECRET_KEY,
same_site="lax", # adjust for production
https_only=False,
max_age=settings.AUTH_EXPIRE_MINUTES * 60,
)

app.include_router(api_router, prefix=settings.API_V1_STR)
add_pagination(app)
2 changes: 2 additions & 0 deletions backend/src/routers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi import APIRouter

from src.auth.api import router as auth_router
from src.auth.auth0_api import router as auth0_router
from src.flashcards.api import router as flashcards_router
from src.stats.api import router as stats_router
from src.users.api import router as user_router
Expand All @@ -11,3 +12,4 @@
api_router.include_router(user_router, prefix="/users", tags=["users"])
api_router.include_router(flashcards_router, tags=["flashcards"])
api_router.include_router(stats_router, tags=["stats"])
api_router.include_router(auth0_router, prefix="/auth0", tags=["auth0"])
3 changes: 2 additions & 1 deletion backend/src/users/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

class User(UserBase, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
hashed_password: str
auth0_id: str | None = Field(default=None, index=True)
hashed_password: str | None = Field(default=None)
collections: list["Collection"] = Relationship(
back_populates="user",
cascade_delete=True,
Expand Down
Loading