Skip to content

Commit

Permalink
add SecurityContextParser.ensureParse api
Browse files Browse the repository at this point in the history
  • Loading branch information
Ahoo-Wang committed Dec 31, 2022
1 parent 435a86f commit 31b8893
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,26 @@
package me.ahoo.cosec.context

import me.ahoo.cosec.api.context.SecurityContext
import org.slf4j.LoggerFactory

/**
* Security Context Parser .
*
* @author ahoo wang
*/
private val LOG = LoggerFactory.getLogger(SecurityContextParser::class.java)

fun interface SecurityContextParser<R> {
fun parse(request: R): SecurityContext

fun ensureParse(request: R): SecurityContext {
return try {
parse(request)
} catch (ignored: Throwable) {
if (LOG.isDebugEnabled) {
LOG.debug(ignored.message, ignored)
}
SimpleSecurityContext.ANONYMOUS
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,16 @@ class ReactiveInjectSecurityContextWebFilter(
private val securityContextParser: SecurityContextParser<ServerWebExchange>
) :
WebFilter, Ordered {

override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> {
try {
val securityContext = securityContextParser.parse(exchange)
exchange.mutate()
.principal(securityContext.principal.toMono())
.build().let {
exchange.setSecurityContext(securityContext)
return chain.filter(it)
.writeSecurityContext(securityContext)
}
} catch (ignored: Throwable) {
// ignored
}
return chain.filter(exchange)
val securityContext = securityContextParser.ensureParse(exchange)
exchange.mutate()
.principal(securityContext.principal.toMono())
.build().let {
exchange.setSecurityContext(securityContext)
return chain.filter(it)
.writeSecurityContext(securityContext)
}
}

override fun getOrder(): Int {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,10 @@ package me.ahoo.cosec.webflux
import me.ahoo.cosec.api.authorization.Authorization
import me.ahoo.cosec.api.authorization.AuthorizeResult
import me.ahoo.cosec.context.SecurityContextParser
import me.ahoo.cosec.context.SimpleSecurityContext
import me.ahoo.cosec.context.request.RequestParser
import me.ahoo.cosec.policy.serialization.CoSecJsonSerializer
import me.ahoo.cosec.token.TokenExpiredException
import me.ahoo.cosec.webflux.ReactiveSecurityContexts.writeSecurityContext
import me.ahoo.cosec.webflux.ServerWebExchanges.setSecurityContext
import org.slf4j.LoggerFactory
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.http.server.reactive.ServerHttpResponse
Expand All @@ -35,19 +32,9 @@ abstract class ReactiveSecurityFilter(
val requestParser: RequestParser<ServerWebExchange>,
val authorization: Authorization
) {
companion object {
private val log = LoggerFactory.getLogger(ReactiveSecurityFilter::class.java)
}

fun filterInternal(exchange: ServerWebExchange, chain: (ServerWebExchange) -> Mono<Void>): Mono<Void> {
val securityContext = try {
securityContextParser.parse(exchange)
} catch (tokenExpiredException: TokenExpiredException) {
if (log.isDebugEnabled) {
log.debug("Token Expired!", tokenExpiredException)
}
SimpleSecurityContext.ANONYMOUS
}
val securityContext = securityContextParser.ensureParse(exchange)
val request = requestParser.parse(exchange)
return authorization.authorize(request, securityContext)
.flatMap { authorizeResult ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@ package me.ahoo.cosec.servlet
import me.ahoo.cosec.api.authorization.Authorization
import me.ahoo.cosec.context.SecurityContextHolder
import me.ahoo.cosec.context.SecurityContextParser
import me.ahoo.cosec.context.SimpleSecurityContext
import me.ahoo.cosec.context.request.RequestParser
import me.ahoo.cosec.policy.serialization.CoSecJsonSerializer
import me.ahoo.cosec.servlet.ServletRequests.setSecurityContext
import me.ahoo.cosec.token.TokenExpiredException
import org.slf4j.LoggerFactory
import org.springframework.http.HttpStatus
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
Expand All @@ -35,22 +32,12 @@ abstract class AbstractAuthorizationInterceptor(
private val securityContextParser: SecurityContextParser<HttpServletRequest>,
private val authorization: Authorization
) {
companion object {
private val log = LoggerFactory.getLogger(AbstractAuthorizationInterceptor::class.java)
}

protected fun authorize(
request: HttpServletRequest,
response: HttpServletResponse
): Boolean {
val securityContext = try {
securityContextParser.parse(request)
} catch (tokenExpiredException: TokenExpiredException) {
if (log.isDebugEnabled) {
log.debug("Token Expired!", tokenExpiredException)
}
SimpleSecurityContext.ANONYMOUS
}
val securityContext = securityContextParser.ensureParse(request)

SecurityContextHolder.setContext(securityContext)
request.setSecurityContext(securityContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import me.ahoo.cosec.api.context.SecurityContext
import me.ahoo.cosec.context.SecurityContextHolder
import me.ahoo.cosec.context.SecurityContextParser
import me.ahoo.cosec.servlet.ServletRequests.setSecurityContext
import org.slf4j.LoggerFactory
import java.io.IOException
import javax.servlet.Filter
import javax.servlet.FilterChain
Expand All @@ -33,9 +32,6 @@ import javax.servlet.http.HttpServletRequest
*/
class InjectSecurityContextFilter(private val securityContextParser: SecurityContextParser<HttpServletRequest>) :
Filter {
companion object {
private val log = LoggerFactory.getLogger(InjectSecurityContextFilter::class.java)
}

@Throws(IOException::class, ServletException::class)
override fun doFilter(servletRequest: ServletRequest, servletResponse: ServletResponse, filterChain: FilterChain) {
Expand All @@ -44,15 +40,9 @@ class InjectSecurityContextFilter(private val securityContextParser: SecurityCon
}

private fun tryInjectSecurityContext(servletRequest: ServletRequest) {
try {
val httpServletRequest = servletRequest as HttpServletRequest
val securityContext: SecurityContext = securityContextParser.parse(httpServletRequest)
SecurityContextHolder.setContext(securityContext)
httpServletRequest.setSecurityContext(securityContext)
} catch (throwable: Throwable) {
if (log.isInfoEnabled) {
log.info(throwable.message, throwable)
}
}
val httpServletRequest = servletRequest as HttpServletRequest
val securityContext: SecurityContext = securityContextParser.ensureParse(httpServletRequest)
SecurityContextHolder.setContext(securityContext)
httpServletRequest.setSecurityContext(securityContext)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import io.mockk.mockk
import me.ahoo.cosec.context.SecurityContextHolder
import me.ahoo.cosec.context.SimpleSecurityContext
import me.ahoo.cosec.jwt.Jwts
import me.ahoo.cosec.servlet.ServletRequests.setSecurityContext
import org.hamcrest.MatcherAssert.assertThat
import org.hamcrest.Matchers.equalTo
import org.junit.jupiter.api.Test
Expand All @@ -31,6 +32,7 @@ internal class InjectSecurityContextFilterTest {
val filter = InjectSecurityContextFilter(InjectSecurityContextParser)
val request = mockk<HttpServletRequest>() {
every { getHeader(Jwts.AUTHORIZATION_KEY) } returns null
every { setSecurityContext(any()) } returns Unit
}
val filterChain = mockk<FilterChain>() {
every { doFilter(request, any()) } returns Unit
Expand All @@ -42,7 +44,10 @@ internal class InjectSecurityContextFilterTest {
@Test
fun doFilterThrow() {
val filter = InjectSecurityContextFilter(InjectSecurityContextParser)
val request = mockk<HttpServletRequest>()
val request = mockk<HttpServletRequest>() {
every { getHeader(Jwts.AUTHORIZATION_KEY) } returns null
every { setSecurityContext(any()) } returns Unit
}
val filterChain = mockk<FilterChain>() {
every { doFilter(request, any()) } returns Unit
}
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# limitations under the License.
#
group=me.ahoo.cosec
version=1.6.0
version=1.6.2
description=RBAC-based And Policy-based Multi-Tenant Reactive Security Framework
website=https://github.com/Ahoo-Wang/CoSec
issues=https://github.com/Ahoo-Wang/CoSec/issues
Expand Down

0 comments on commit 31b8893

Please sign in to comment.