From 4182f9483192ff7fda9d47dee2046a5ca9be05c9 Mon Sep 17 00:00:00 2001 From: "Huang, Huan (PG/T - Comp Sci & Elec Eng)" <ch01402@surrey.ac.uk> Date: Thu, 2 May 2024 19:36:23 +0000 Subject: [PATCH] OAuth2 with JWT (JSON Web Token) can be a powerful and relatively straightforward way to handle authentication in your FastAPI application. Here's how you can implement it effectively: Understanding OAuth2 and JWT: OAuth2: A standard protocol for authorization, allowing third-party applications to access resources securely. JWT: A compact, self-contained token format that encodes user information and claims. JWTs can be signed with a secret key or a public/private key pair. auth.py is for later database. #Auth# in main is for current my_ride. --- main.py | 355 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 355 insertions(+) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..9200138 --- /dev/null +++ b/main.py @@ -0,0 +1,355 @@ +from fastapi import FastAPI, HTTPException, Depends, status +from fastapi.middleware.cors import CORSMiddleware +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from pydantic import BaseModel +from jose import JWTError, jwt +from passlib.context import CryptContext +from pathlib import Path +import sqlite3 +from typing import Optional, List +from datetime import datetime, timedelta + +app = FastAPI() + +# Allow requests from all origins +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["*"], +) + +# SQLite connection +parent_directory = Path(__file__).resolve().parent.parent +db_file_path = parent_directory / "my_ride.db" +# print(db_file_path) +conn = sqlite3.connect(db_file_path) + +cursor = conn.cursor() + +# Create the Users table if it doesn't exist +cursor.execute(''' + CREATE TABLE IF NOT EXISTS Users ( + user_id INTEGER PRIMARY KEY, + username TEXT, + password TEXT, + email TEXT, + phone_number TEXT, + credit_card_info TEXT, + registration_date DATETIME, + last_login DATETIME + ) +''') +conn.commit() + + +class User(BaseModel): + username: str + password: str + email: str + phone_number: str + credit_card_info: Optional[str] = None + registration_date: str + last_login: Optional[str] = None + + +class UserResponse(BaseModel): + user_id: int + username: str + email: str + phone_number: str + registration_date: str + + +# Modify the read_users() function to return a list of UserResponse objects +@app.get("/users/", response_model=List[UserResponse]) +async def read_users(): + cursor.execute('SELECT user_id, username, email, phone_number, registration_date FROM Users') + print("got record") + users = cursor.fetchall() + user_objects = [] + for user in users: + user_obj = UserResponse( + user_id=user[0], + username=user[1], + email=user[2], + phone_number=user[3], + registration_date=user[4] + ) + user_objects.append(user_obj) + return user_objects + + +# Routes +@app.post("/users/", response_model=User) +async def create_user(user: User): + print('hit') + # Use a tuple to conditionally include the optional fields in the SQL query + if user.credit_card_info is None and user.last_login is None: + cursor.execute(''' + INSERT INTO Users + (username, password, email, phone_number, registration_date) + VALUES (?, ?, ?, ?, ?) + ''', ( + user.username, user.password, user.email, user.phone_number, + user.registration_date + )) + elif user.credit_card_info is None: + cursor.execute(''' + INSERT INTO Users + (username, password, email, phone_number, registration_date, last_login) + VALUES (?, ?, ?, ?, ?, ?) + ''', ( + user.username, user.password, user.email, user.phone_number, + user.registration_date, user.last_login + )) + elif user.last_login is None: + cursor.execute(''' + INSERT INTO Users + (username, password, email, phone_number, credit_card_info, registration_date) + VALUES (?, ?, ?, ?, ?, ?) + ''', ( + user.username, user.password, user.email, user.phone_number, + user.credit_card_info, user.registration_date + )) + else: + cursor.execute(''' + INSERT INTO Users + (username, password, email, phone_number, credit_card_info, registration_date, last_login) + VALUES (?, ?, ?, ?, ?, ?, ?) + ''', ( + user.username, user.password, user.email, user.phone_number, + user.credit_card_info, user.registration_date, user.last_login + )) + conn.commit() + return user + + + +@app.get("/users/{user_id}", response_model=User) +async def read_user(user_id: int): + cursor.execute('SELECT * FROM Users WHERE user_id = ?', (user_id,)) + user = cursor.fetchone() + if user is None: + raise HTTPException(status_code=404, detail="User not found") + return user + + +@app.put("/users/{user_id}", response_model=User) +async def update_user(user_id: int, user: User): + cursor.execute(''' + UPDATE Users + SET username = ?, password = ?, email = ?, phone_number = ?, + credit_card_info = ?, registration_date = ?, last_login = ? + WHERE user_id = ? + ''', ( + user.username, user.password, user.email, user.phone_number, + user.credit_card_info, user.registration_date, user.last_login, user_id + )) + conn.commit() + return user + + +@app.delete("/users/{user_id}") +async def delete_user(user_id: int): + cursor.execute('DELETE FROM Users WHERE user_id = ?', (user_id,)) + conn.commit() + return {"message": "User deleted successfully"} + +#####################################################Auth######################################### +#pip install python-jose +#pip install passlib[bcrypt] + +# Configurations +SECRET_KEY = "your_secret_key" +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + +# Initialize the FastAPI application +app = FastAPI() + +# Allow requests from all origins +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["*"], +) + +# SQLite connection setup +db_file_path = Path(__file__).resolve().parent.parent / "my_ride.db" + +def get_db_connection(): + return sqlite3.connect(db_file_path) + +# Password hashing context +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +# OAuth2PasswordBearer for OAuth2 flow +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +# Helper Functions +def verify_password(plain_password: str, hashed_password: str) -> bool: + return pwd_context.verify(plain_password, hashed_password) + +def get_password_hash(password: str) -> str: + return pwd_context.hash(password) + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + to_encode.update({"exp": expire}) + return jwt.encode(to_encode, SECRET_KEY, ALGORITHM) + +def get_user_by_username(username: str): + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute("SELECT * FROM Users WHERE username = ?", (username,)) + user = cursor.fetchone() + if not user: + return None + return { + "user_id": user[0], + "username": user[1], + "password": user[2], + "email": user[3], + "phone_number": user[4], + "credit_card_info": user[5], + "registration_date": user[6], + "last_login": user[7] + } + +def authenticate_user(username: str, password: str): + user = get_user_by_username(username) + if not user or not verify_password(password, user["password"]): + return None + return user + +def get_current_user(token: str = Depends(oauth2_scheme)): + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username = payload.get("sub") + if not username: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + user = get_user_by_username(username) + if not user: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found") + except JWTError: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") + return user + +# Models +class User(BaseModel): + username: str + password: str + email: str + phone_number: str + credit_card_info: Optional[str] = None + registration_date: str + last_login: Optional[str] = None + +class UserResponse(BaseModel): + user_id: int + username: str + email: str + phone_number: str + registration_date: str + +# Token endpoint +@app.post("/token") +async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()): + user = authenticate_user(form_data.username, form_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + access_token = create_access_token(data={"sub": user["username"]}, expires_delta=access_token_expires) + return {"access_token": access_token, "token_type": "bearer"} + +# Routes +@app.post("/users/", response_model=User) +async def create_user(user: User): + conn = get_db_connection() + cursor = conn.cursor() + + hashed_password = get_password_hash(user.password) + cursor.execute(''' + INSERT INTO Users + (username, password, email, phone_number, credit_card_info, registration_date, last_login) + VALUES (?, ?, ?, ?, ?, ?, ?) + ''', ( + user.username, hashed_password, user.email, user.phone_number, + user.credit_card_info, user.registration_date, user.last_login + )) + conn.commit() + return user + +@app.get("/users/{user_id}", response_model=User) +async def read_user(user_id: int): + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute('SELECT * FROM Users WHERE user_id = ?', (user_id,)) + user = cursor.fetchone() + if user is None: + raise HTTPException(status_code=404, detail="User not found") + return { + "user_id": user[0], + "username": user[1], + "password": user[2], + "email": user[3], + "phone_number": user[4], + "credit_card_info": user[5], + "registration_date": user[6], + "last_login": user[7] + } + +@app.put("/users/{user_id}", response_model=User) +async def update_user(user_id: int, user: User): + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute(''' + UPDATE Users + SET username = ?, password = ?, email = ?, phone_number = ?, + credit_card_info = ?, registration_date = ?, last_login = ? + WHERE user_id = ? + ''', ( + user.username, get_password_hash(user.password), user.email, user.phone_number, + user.credit_card_info, user.registration_date, user.last_login, user_id + )) + conn.commit() + return user + +@app.delete("/users/{user_id}") +async def delete_user(user_id: int): + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute('DELETE FROM Users WHERE user_id = ?', (user_id,)) + conn.commit() + return {"message": "User deleted successfully"} + +@app.get("/users/", response_model=List[UserResponse]) +async def read_users(): + conn = get_db_connection() + cursor = conn.cursor() + cursor.execute('SELECT user_id, username, email, phone_number, registration_date FROM Users') + users = cursor.fetchall() + + user_objects = [] + for user in users: + user_obj = UserResponse( + user_id=user[0], + username=user[1], + email=user[2], + phone_number=user[3], + registration_date=user[4] + ) + user_objects.append(user_obj) + return user_objects \ No newline at end of file -- GitLab