Skip to content
Snippets Groups Projects
AuthenticationRequest.scala 2.06 KiB
Newer Older
package models.actions

import com.typesafe.config.ConfigFactory

import play.api.mvc.{ActionTransformer, Request, WrappedRequest}

import scala.concurrent.{Future, ExecutionContext}
import scala.util.Try
import org.bson.types.ObjectId

import pdi.jwt.{JwtJson, JwtAlgorithm, JwtClaim}
import play.api.libs.json.Json

import javax.inject.Inject


/**
 * A [[play.api.mvc.WrappedRequest]] that carries an optional user ID alongside the request it wraps.
 *
 * @param userId  The user ID attached to this request, or `None` when no user ID is available.
 * @param request The underlying request being wrapped.
 * @tparam A The body type of the wrapped request.
 */
class AuthenticationRequest[A](
    val userId: Option[ObjectId],
    request: Request[A]
) extends WrappedRequest[A](request)


/**
 * The authentication action transformer that transforms the incoming base request into an
 * [[AuthenticationRequest]].
 */
class AuthenticationTransformer @Inject() (implicit val executionContext: ExecutionContext) extends ActionTransformer[Request, AuthenticationRequest] {

    /** Authorization scheme prefix that must precede the JWT in the `Authorization` header. */
    private val BearerPrefix = "Bearer "

    /**
     * The key used to verify JWT signatures, read from the `jwt.privateKey` configuration entry.
     * Lazy so the configuration is read on first use rather than on every request.
     */
    private lazy val privateKey: String = ConfigFactory.load().getString("jwt.privateKey")

    /**
     * Transforms the existing request from a Request to an AuthenticationRequest.
     *
     * @param request The incoming request.
     * @return The new parameter to pass to the Action block; its `userId` is `None` when no valid
     *         JWT could be extracted from the request.
     */
    override def transform[A](request: Request[A]): Future[AuthenticationRequest[A]] = Future.successful {
        // Trace the incoming request on stdout.
        println(request)
        new AuthenticationRequest(processJWT(request), request)
    }

    /**
     * Processes the JWT token by decoding and validating it.
     *
     * @param request The incoming request.
     * @return The user ID specified in the JWT's payload, or `None` when the `Authorization` header
     *         is missing, does not carry a Bearer credential, or the token cannot be decoded,
     *         validated or parsed.
     */
    def processJWT[A](request: Request[A]): Option[ObjectId] = {
        // Resolved up front (outside any Try) so a missing configuration key still surfaces as an
        // error instead of being silently treated as "unauthenticated".
        val key = privateKey

        for {
            authHeader <- request.headers.get("Authorization")
            token      <- bearerToken(authHeader)
            userId     <- decodeUserId(token, key)
        } yield userId
    }

    /**
     * Extracts the token part of a `Bearer <token>` header value.
     *
     * @param authHeader The raw `Authorization` header value.
     * @return The non-empty token following the Bearer scheme, or `None` if the header uses
     *         another scheme or carries no token.
     */
    private def bearerToken(authHeader: String): Option[String] = {
        // HTTP authentication scheme names are case-insensitive, so "bearer " is accepted too.
        val hasBearerScheme = authHeader.regionMatches(true, 0, BearerPrefix, 0, BearerPrefix.length)

        if (hasBearerScheme) Some(authHeader.substring(BearerPrefix.length)).filter(_.nonEmpty)
        else None
    }

    /**
     * Verifies the token's HS256 signature and reads the `userId` claim from its payload.
     *
     * The token itself is never written to the output because it is a credential.
     *
     * @param token The encoded JWT.
     * @param key   The key used to verify the signature.
     * @return The user ID from the payload, or `None` if any decoding step fails.
     */
    private def decodeUserId(token: String, key: String): Option[ObjectId] = {
        // Each step is wrapped in a Try, which only captures non-fatal errors; fatal ones
        // (OutOfMemoryError, InterruptedException, ...) keep propagating.
        val userId: Try[ObjectId] = for {
            claim <- JwtJson.decode(token, key, Seq(JwtAlgorithm.HS256))
            json  <- Try(Json.parse(claim.content))
            raw   <- Try((json \ "userId").as[String])
            id    <- Try(new ObjectId(raw))
        } yield id

        // Report why the token was rejected, then fall back to "no user".
        userId.failed.foreach(ex => println(ex))
        userId.toOption
    }
}