Optionally set a user's full name in a PostgreSQL session variable for connections.
parent
401d5ef679
commit
cfb6ac8e5f
|
|
@ -3,9 +3,9 @@ import Vapor
|
|||
struct UserAuthenticator<User: ManagedUser>: AsyncBasicAuthenticator where User.SessionID == ExpiringUserId {
|
||||
|
||||
func authenticate (basic: BasicAuthorization, for request: Request) async throws {
|
||||
if let user = try await request.db.withSQLConnection { connection in
|
||||
if let user = try await request.db.withSQLConnection ({ connection in
|
||||
return try await User.find (email: basic.username, password: basic.password, on: connection)
|
||||
} {
|
||||
}) {
|
||||
request.auth.login (user)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ public struct UserSessionAuthenticator<SessionUser: ManagedUser>: AsyncSessionAu
|
|||
public typealias User = SessionUser
|
||||
|
||||
public func authenticate (sessionID: SessionUser.SessionID, for request: Request) async throws {
|
||||
if let user = try await request.db.withSQLConnection { connection in
|
||||
if let user = try await request.db.withSQLConnection ({ connection in
|
||||
return try await SessionUser.authenticate (sessionID: sessionID, on: connection)
|
||||
} {
|
||||
}) {
|
||||
request.logger.info ("Seeing user \(user)")
|
||||
request.auth.login (user)
|
||||
request.logger.info ("Saw user \(user)")
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ public struct AdminController<User: ManagedUser>: Sendable where User.SessionID
|
|||
}
|
||||
let save = try request.content.decode(Save.self)
|
||||
|
||||
return try await request.db.withSQLConnection { connection in
|
||||
return try await request.db.withSQLConnection (user: try request.auth.require (BasicUser.self)) { connection in
|
||||
guard (try await User.find (userId, on: connection)) != nil else {
|
||||
throw Abort (.notFound)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,13 +2,24 @@ import Fluent
|
|||
import SQLKit
|
||||
|
||||
extension Database {
|
||||
public func withSQLConnection<T: Sendable>(_ closure: @escaping @Sendable (SQLDatabase) async throws -> T) async throws -> T {
|
||||
public func withSQLConnection<T: Sendable>(user: BasicUser? = nil, _ closure: @escaping @Sendable (any SQLDatabase) async throws -> T) async throws -> T {
|
||||
return try await self.withConnection { database in
|
||||
guard let connection = database as? SQLDatabase else {
|
||||
guard let connection = database as? (any SQLDatabase) else {
|
||||
throw NoSQLDatabaseError (database: database)
|
||||
}
|
||||
try await connection.raw ("""
|
||||
SELECT pg_catalog.set_config ('manageable_users.active_user', \(bind: user?.fullName), FALSE)
|
||||
""")
|
||||
.run()
|
||||
|
||||
return try await closure (connection)
|
||||
let result = try await closure (connection)
|
||||
|
||||
try await connection.raw ("""
|
||||
SELECT pg_catalog.set_config ('manageable_users.active_user', NULL, FALSE)
|
||||
""")
|
||||
.run()
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue