diff --git a/booking_service/main.py b/booking_service/main.py index 75700e5f06d6d39483caaf6cf0b1360198cf615e..1bfe1cc51dbbc797f27b27e93750720e14892c71 100644 --- a/booking_service/main.py +++ b/booking_service/main.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import List import mysql.connector from dotenv import load_dotenv +from decimal import Decimal import os @@ -107,7 +108,7 @@ async def get_booking(booking_id: int): @app.post("/bookings/") async def book_bike(booking: Booking): - # Check for existing active booking for the user_id or bike_id + # Check for existing active booking for the same user or bike cursor.execute(""" SELECT 1 FROM Bookings WHERE (user_id = %s OR bike_id = %s) AND payment_state = 'Booked' @@ -117,58 +118,48 @@ async def book_bike(booking: Booking): if existing_booking: raise HTTPException(status_code=400, detail="Already Booked") - # Fetch user details + # Fetch user details to check wallet balance cursor.execute("SELECT wallet_balance FROM Users WHERE user_id = %s", (booking.user_id,)) user_details = cursor.fetchone() if not user_details: - raise HTTPException(status_code=404, detail="User not found.") - wallet_balance = user_details[0] + raise HTTPException(status_code=404, detail="User not found") + wallet_balance = Decimal(user_details[0]) - # Fetch bike details and validate the location + # Fetch bike details to check availability and get price per hour cursor.execute("SELECT current_location, price_per_hour FROM Bikes WHERE bike_id = %s", (booking.bike_id,)) bike_details = cursor.fetchone() - if not bike_details: - raise HTTPException(status_code=404, detail="Bike not found.") - db_location, price_per_hour = bike_details + if not bike_details or booking.location != bike_details[0]: + raise HTTPException(status_code=404, detail="Bike not available at the specified location") + price_per_hour = Decimal(bike_details[1]) - if booking.current_location != db_location: - raise HTTPException(status_code=400, detail="Bike not available at the specified location.") + # Calculate blocked amount based on the booking duration + blocked_amount = price_per_hour * Decimal(booking.duration_hours) - # Calculate the amount to be blocked on the user's wallet - blocked_amount = price_per_hour * 5 - - # Validate wallet balance + # Ensure user's wallet balance can cover the blocked amount if wallet_balance < blocked_amount: - raise HTTPException(status_code=400, detail="Insufficient wallet balance.") + raise HTTPException(status_code=400, detail="Insufficient wallet balance") - # Calculate new wallet balance + # Deduct the blocked amount from user's wallet balance new_wallet_balance = wallet_balance - blocked_amount - # Generate lock code and current time - lock_code = generate_lock_code() + # Insert the booking record + lock_code = ''.join(random.choices('0123456789', k=6)) # Generate a 6-digit lock code start_time = datetime.now() - - # Update bike status and user's wallet balance - cursor.execute("UPDATE Bikes SET status = 'booked' WHERE bike_id = %s", (booking.bike_id,)) - cursor.execute("UPDATE Users SET wallet_balance = %s WHERE user_id = %s", (new_wallet_balance, booking.user_id)) - - # Insert a new booking record with lock code cursor.execute(""" INSERT INTO Bookings (user_id, bike_id, location, blocked_amount, payment_state, start_time, lock_code) VALUES (%s, %s, %s, %s, 'Booked', %s, %s) """, (booking.user_id, booking.bike_id, booking.location, blocked_amount, start_time, lock_code)) + # Update the user's wallet balance + cursor.execute("UPDATE Users SET wallet_balance = %s WHERE user_id = %s", (new_wallet_balance, booking.user_id)) + db_connection.commit() return { - "user_id": booking.user_id, - "bike_id": booking.bike_id, - "location": booking.location, - "blocked_amount": blocked_amount, - "new_wallet_balance": new_wallet_balance, + "message": "Bike booked successfully", + "new_wallet_balance": float(new_wallet_balance), # Return as float for JSON serialization "lock_code": lock_code, - "start_time": start_time.strftime('%Y-%m-%d %H:%M:%S'), - "message": "Bike booked successfully" + "start_time": start_time.isoformat() }