111 lines
4.7 KiB
Swift
111 lines
4.7 KiB
Swift
import Vapor
|
|
import SQLKit
|
|
import Crypto
|
|
|
|
/// A user account that can be stored in a Vapor session and administered
/// through the SQL-backed operations declared below.
///
/// Conformers are `Content`, so they can be encoded into responses, and
/// `SessionAuthenticatable`, so their `sessionID` can be persisted in a session.
public protocol ManagedUser: Content, SessionAuthenticatable {

    /// Primary key of the account.
    var id: UUID { get }

    /// Email address; the key used by `find(email:on:)` and `find(email:password:on:)`.
    var email: String { get }

    /// Display name of the account holder.
    var fullName: String { get }

    /// Stored password hash, or `nil` when no password has been set.
    /// The default `update(userId:password:on:)` writes a Bcrypt hash here.
    var password: String? { get }

    /// Names of the roles granted to the account.
    var roles: [String] { get }

    /// Whether the account is enabled; the default lookups only return active accounts.
    var isActive: Bool { get }

    // MARK: - Authentication

    /// Fetches the user with primary key `id`, or `nil` if there is none.
    static func find(_ id: UUID, on connection: any SQLDatabase) async throws -> Self?

    /// Fetches the user with `email` when `password` matches its stored hash, otherwise `nil`.
    static func find(email: String, password: String, on connection: any SQLDatabase) async throws -> Self?

    /// Fetches the user with `email`, or `nil` if there is none.
    static func find(email: String, on connection: any SQLDatabase) async throws -> Self?

    /// Replaces the stored password of user `userId` with (a hash of) `password`.
    static func update(userId: UUID, password: String, on connection: any SQLDatabase) async throws

    // MARK: - Administration

    /// Fetches the user accounts known to the store.
    /// NOTE(review): filtering (e.g. active-only) is left to the conforming type — confirm.
    static func all(on connection: any SQLDatabase) async throws -> [Self]

    /// Creates a new account with the given attributes.
    /// NOTE(review): `token` is presumably an invitation/activation token — confirm with the implementation.
    static func create(email: String, fullname: String, roles: [String], token: String, on connection: any SQLDatabase) async throws

    /// Writes the given attributes to the existing account `id`.
    static func save(id: UUID, email: String, fullname: String, roles: [String], isActive: Bool, on connection: any SQLDatabase) async throws

    /// Persists `token` for the account `userId`.
    static func store(token: String, userId: UUID, on connection: any SQLDatabase) async throws
}
|
|
|
|
extension ManagedUser where SessionID == ExpiringUserId {

    /// The identifier persisted in the session for this user
    /// (satisfies `SessionAuthenticatable.sessionID`).
    public var sessionID: ExpiringUserId {
        ExpiringUserId(user: self)
    }

    /// Hashes a plaintext password with Bcrypt (default cost) for storage.
    /// - Parameter password: The plaintext password.
    /// - Returns: The Bcrypt hash string.
    /// - Throws: Any error raised by `Bcrypt.hash`.
    internal static func encrypt(password: String) throws -> String {
        try Bcrypt.hash(password)
    }

    /// Resolves a session identifier back into a user by looking up `sessionID.id`.
    ///
    /// NOTE(review): only the id is used here; expiry of the `ExpiringUserId`
    /// is not checked in this method — confirm it is enforced elsewhere.
    /// - Parameters:
    ///   - sessionID: The identifier previously produced by `sessionID`.
    ///   - connection: Database to query.
    /// - Returns: The user returned by `find(_:on:)`, or `nil`.
    public static func authenticate(sessionID: SessionID, on connection: any SQLDatabase) async throws -> Self? {
        try await find(sessionID.id, on: connection)
    }

    /// Checks a plaintext password against the stored Bcrypt hash.
    /// - Parameter password: The plaintext password to check.
    /// - Returns: `true` on a match; `false` on a mismatch or when no usable hash is stored.
    /// - Throws: Any error raised by `Bcrypt.verify` for a non-empty but malformed hash.
    private func verify(password: String) throws -> Bool {
        // `nil` means no password was ever set. An empty string is not a Bcrypt hash,
        // and Bcrypt would throw `invalidHash` on it. Both cases are simply "no match",
        // so a login attempt fails cleanly instead of erroring.
        guard let storedHash = self.password, !storedHash.isEmpty else {
            return false
        }
        return try Bcrypt.verify(password, created: storedHash)
    }

    /// Middleware that restores the user from the session on each request.
    public static func sessionAuthenticator() -> UserSessionAuthenticator<Self> {
        UserSessionAuthenticator()
    }

    /// Authenticator for credential-based (email/password) requests.
    static func basicAuthenticator() -> UserAuthenticator<Self> {
        UserAuthenticator()
    }
}
|
|
|
|
// MARK: - Defaults for Authentication
|
|
extension ManagedUser where SessionID == ExpiringUserId {

    /// Default lookup of an active user by primary key.
    ///
    /// Reads the user's columns from `users`, gathers its role names from
    /// `user_roles` into a `roles` array, and decodes the first row as `Self`.
    /// - Parameters:
    ///   - id: Primary key of the user.
    ///   - connection: Database to query.
    /// - Returns: The decoded user, or `nil` when no active user has this id.
    public static func find(_ id: UUID, on connection: any SQLDatabase) async throws -> Self? {
        try await connection
            .raw("""
                SELECT "id",
                       "email",
                       "full_name",
                       "active",
                       "password",
                       ARRAY (SELECT "role_name"
                                FROM "user_roles"
                               WHERE "user_roles"."user_id" = "users"."id") AS "roles"
                  FROM "users"
                 WHERE "id" = \(bind: id)
                   AND "active"
                """)
            .first(decoding: Self.self)
    }

    /// Default credential check: loads the user by email, then verifies the password.
    ///
    /// NOTE(review): when no user matches `email` this returns before any Bcrypt
    /// work happens. Response time may therefore reveal whether an email is
    /// registered — consider equalizing.
    /// - Parameters:
    ///   - email: Email to look up.
    ///   - password: Plaintext password to verify.
    ///   - connection: Database to query.
    /// - Returns: The user when the password matches, otherwise `nil`.
    /// - Throws: Database/decoding errors, or errors from password verification.
    public static func find(email: String, password: String, on connection: any SQLDatabase) async throws -> Self? {
        guard let candidate = try await find(email: email, on: connection) else {
            return nil
        }
        return try candidate.verify(password: password) ? candidate : nil
    }

    /// Default lookup of an active user by email address.
    ///
    /// Uses the same projection as `find(_:on:)`, including the aggregated `roles`.
    /// - Parameters:
    ///   - email: Email to match exactly.
    ///   - connection: Database to query.
    /// - Returns: The decoded user, or `nil` when no active user has this email.
    public static func find(email: String, on connection: any SQLDatabase) async throws -> Self? {
        try await connection
            .raw("""
                SELECT "id",
                       "email",
                       "full_name",
                       "active",
                       "password",
                       ARRAY (SELECT "role_name"
                                FROM "user_roles"
                               WHERE "user_roles"."user_id" = "users"."id") AS "roles"
                  FROM "users"
                 WHERE "email" = \(bind: email)
                   AND "active"
                """)
            .first(decoding: Self.self)
    }

    /// Default password change: hashes `password` and writes it to user `userId`.
    ///
    /// The plaintext is never sent to the database; only the Bcrypt hash
    /// produced by `encrypt(password:)` is bound into the statement.
    /// - Parameters:
    ///   - userId: Primary key of the user to update.
    ///   - password: New plaintext password.
    ///   - connection: Database to execute against.
    /// - Throws: Hashing errors or database errors.
    public static func update(userId: UUID, password: String, on connection: any SQLDatabase) async throws {
        let hashedPassword = try Self.encrypt(password: password)
        try await connection
            .raw("""
                UPDATE "users"
                   SET "password" = \(bind: hashedPassword)
                 WHERE "id" = \(bind: userId)
                """)
            .run()
    }
}
|