diff --git a/parser/pom.xml b/parser/pom.xml index e32e4418..5b935e85 100644 --- a/parser/pom.xml +++ b/parser/pom.xml @@ -26,27 +26,13 @@ 2.12.16 2.12 0.1.5 - 0.23.13 + 0.23.14 4.0.6 1.12.772 4.16.1 - - - - fi.vm.sade - scala-cas_2.12 - 3.0.2-SNAPSHOT - - - org.http4s - * - - - - org.http4s @@ -88,6 +74,12 @@ ${scala.version} + + org.scala-lang.modules + scala-xml_2.12 + 2.0.1 + + io.symphonia diff --git a/parser/src/main/scala/fi/vm/sade/utils/cas/CasAuthenticatingClient.scala b/parser/src/main/scala/fi/vm/sade/utils/cas/CasAuthenticatingClient.scala new file mode 100644 index 00000000..31bc6712 --- /dev/null +++ b/parser/src/main/scala/fi/vm/sade/utils/cas/CasAuthenticatingClient.scala @@ -0,0 +1,82 @@ +package fi.vm.sade.utils.cas + +import cats.effect.IO +import cats.effect.kernel.Resource +import cats.effect.std.Hotswap +import fi.vm.sade.utils.cas.CasClient.SessionCookie +import org.http4s.client.Client +import org.http4s.{Request, Response, Status} +import org.typelevel.ci.CIString + + +/** + * Middleware that handles CAS authentication automatically. Sessions are maintained by keeping + * a central cache of session cookies per service url. If a session cookie is not found for requested service, it is obtained using + * CasClient. Stale sessions are detected and refreshed automatically. + */ +object CasAuthenticatingClient extends Logging { + val DefaultSessionCookieName = "JSESSIONID" + + private val sessions: collection.mutable.Map[CasParams, SessionCookie] = collection.mutable.Map.empty + + def apply( + casClient: CasClient, + casParams: CasParams, + serviceClient: Client[IO], + clientCallerId: String, + sessionCookieName: String = DefaultSessionCookieName + ): Client[IO] = { + def openWithCasSession(request: Request[IO], hotswap: Hotswap[IO, Response[IO]]): IO[Response[IO]] = { + getCasSession(casParams).flatMap(requestWithCasSession(request, hotswap, retry = true)) + } + + def requestWithCasSession + (request: Request[IO], hotswap: Hotswap[IO, Response[IO]], retry: Boolean) + (sessionCookie: SessionCookie) + : IO[Response[IO]] = { + val fullRequest = FetchHelper.addDefaultHeaders( + request.addCookie(sessionCookieName, sessionCookie), + clientCallerId + ) + // Hotswap use inspired by http4s Retry middleware: + hotswap.swap(serviceClient.run(fullRequest)).flatMap { + case r: Response[IO] if sessionExpired(r) && retry => + logger.info("Session for " + casParams + " expired") + refreshSession(casParams).flatMap(requestWithCasSession(request, hotswap, retry = false)) + case r: Response[IO] => IO.pure(r) + } + } + + def isRedirectToLogin(resp: Response[IO]): Boolean = + resp.headers.get(CIString("Location")).exists(_.exists(header => + header.value.contains("/cas/login") || header.value.contains("/cas-oppija/login") + )) + + def sessionExpired(resp: Response[IO]): Boolean = + isRedirectToLogin(resp) || resp.status == Status.Unauthorized + + def getCasSession(params: CasParams): IO[SessionCookie] = { + synchronized(sessions.get(params)) match { + case None => + logger.debug(s"No existing $sessionCookieName found for " + params + ", creating new") + refreshSession(params) + case Some(session) => + IO.pure(session) + } + } + + def refreshSession(params: CasParams): IO[SessionCookie] = { + casClient.fetchCasSession(params, sessionCookieName).map { session => + logger.debug("Storing new session for " + params) + synchronized(sessions.put(params, session)) + session + } + } + + Client { req => + Hotswap.create[IO, Response[IO]].flatMap { hotswap => + Resource.eval(openWithCasSession(req, hotswap)) + } + } + } +} diff --git a/parser/src/main/scala/fi/vm/sade/utils/cas/CasClient.scala b/parser/src/main/scala/fi/vm/sade/utils/cas/CasClient.scala new file mode 100644 index 00000000..521a343c --- /dev/null +++ b/parser/src/main/scala/fi/vm/sade/utils/cas/CasClient.scala @@ -0,0 +1,294 @@ +package fi.vm.sade.utils.cas + +import cats.data.EitherT +import cats.effect.IO +import org.http4s.EntityDecoder.collectBinary +import org.http4s.Status.{Created, Locked} +import org.http4s.client._ +import org.http4s._ +import org.typelevel.ci.CIString + +import scala.util.{Failure, Success, Try} +import scala.xml._ + + +object CasClient { + type SessionCookie = String + type Username = String + type OppijaAttributes = Map[String, String] + type TGTUrl = Uri + type ServiceTicket = String + + val textOrXmlDecoder: EntityDecoder[IO, String] = + EntityDecoder.decodeBy(MediaRange.`text/*`, MediaType.application.xml)(msg => + collectBinary(msg).map(bs => new String( + bs.toArray, + msg.charset.getOrElse(DefaultCharset).nioCharset + )) + ) +} + +/** + * Facade for establishing sessions with services protected by CAS, and also validating CAS service tickets. + */ +class CasClient(casBaseUrl: Uri, client: Client[IO], callerId: String) extends Logging { + import CasClient._ + + def this(casServer: String, client: Client[IO], callerId: String) = this(Uri.fromString(casServer).right.get, client, callerId) + + def validateServiceTicketWithOppijaAttributes(service: String)(serviceTicket: ServiceTicket): IO[OppijaAttributes] = { + validateServiceTicket[OppijaAttributes](casBaseUrl, client, service, decodeOppijaAttributes)(serviceTicket) + } + + def validateServiceTicketWithVirkailijaUsername(service: String)(serviceTicket: ServiceTicket): IO[Username] = { + validateServiceTicket[Username](casBaseUrl, client, service, decodeVirkailijaUsername)(serviceTicket) + } + + def validateServiceTicket[R](service: String)(serviceTicket: ServiceTicket, responseHandler: Response[IO] => IO[R]): IO[R] = { + validateServiceTicket[R](casBaseUrl, client, service, responseHandler)(serviceTicket) + } + + private def validateServiceTicket[R](casBaseUrl: Uri, client: Client[IO], service: String, responseHandler: Response[IO] => IO[R])(serviceTicket: ServiceTicket): IO[R] = { + val pUri: Uri = casBaseUrl.addPath("serviceValidate") + .withQueryParam("ticket", serviceTicket) + .withQueryParam("service", service) + + val request = Request[IO](Method.GET, pUri) + FetchHelper.fetch[R](client, callerId, request, responseHandler) + } + + def authenticateVirkailija(user: CasUser): IO[Boolean] = { + TicketGrantingTicketClient.getTicketGrantingTicket(casBaseUrl, client, user, callerId) + .map(_tgtUrl => true) // Authentication succeeded if we received a tgtUrl + } + + /** + * Establishes session with the requested service by + * + * 1) getting a CAS ticket granting ticket (TGT) + * 2) getting a CAS service ticket + * 3) getting a session cookie from the service. + * + * Returns the session that can be used for communications later. + */ + def fetchCasSession(params: CasParams, sessionCookieName: String): IO[SessionCookie] = { + val serviceUri = Uri.resolve(casBaseUrl, params.service.securityUri) + + for ( + st <- getServiceTicketWithRetryOnce(params, serviceUri); + session <- SessionCookieClient.getSessionCookieValue(client, serviceUri, sessionCookieName, callerId)(st) + ) yield { + session + } + } + + private def getServiceTicketWithRetryOnce(params: CasParams, serviceUri: TGTUrl): IO[ServiceTicket] = { + getServiceTicket(params, serviceUri).attempt.flatMap { + case Right(success) => + IO(success) + case Left(throwable) => + logger.warn("Fetching TGT or ST failed. Retrying once (and only once) in case the error was ephemeral.", throwable) + retryServiceTicket(params, serviceUri) + } + } + + private def retryServiceTicket(params: CasParams, serviceUri: TGTUrl): IO[ServiceTicket] = { + getServiceTicket(params, serviceUri).attempt.map { + case Right(retrySuccess) => + logger.info("Fetching TGT and ST was successful after one retry.") + retrySuccess + case Left(retryThrowable) => + logger.error("Fetching TGT or ST failed also after one retry.", retryThrowable) + throw retryThrowable + } + } + + private def getServiceTicket(params: CasParams, serviceUri: TGTUrl): IO[ServiceTicket] = { + for { + tgt <- TicketGrantingTicketClient.getTicketGrantingTicket(casBaseUrl, client, params.user, callerId) + st <- ServiceTicketClient.getServiceTicketFromTgt(client, serviceUri, callerId)(tgt) + } yield { + st + } + } + + private val oppijaServiceTicketDecoder: EntityDecoder[IO, OppijaAttributes] = + textOrXmlDecoder + .map(s => Utility.trim(scala.xml.XML.loadString(s))) + .flatMapR[OppijaAttributes] { serviceResponse => + Try { + val attributes: NodeSeq = (serviceResponse \ "authenticationSuccess" \ "attributes") + + List("mail", "clientName", "displayName", "givenName", "personOid", "personName", "firstName", "nationalIdentificationNumber", + "impersonatorNationalIdentificationNumber", "impersonatorDisplayName") + .map(key => (key, (attributes \ key).text)) + .toMap + } match { + case Success(decoded) => DecodeResult.successT(decoded) + case Failure(ex) => + DecodeResult.failureT(InvalidMessageBodyFailure( + "Oppija Service Ticket validation response decoding failed: Failed to parse required values from response body", + Some(ex)) + ) + } + } + + private val virkailijaServiceTicketDecoder: EntityDecoder[IO, Username] = + textOrXmlDecoder + .map(s => Utility.trim(scala.xml.XML.loadString(s))) + .flatMapR[Username] { serviceResponse => { + val user = (serviceResponse \ "authenticationSuccess" \ "user") + user.length match { + case 1 => DecodeResult.successT(user.text) + case _ => + DecodeResult.failureT(InvalidMessageBodyFailure( + s"Virkailija Service Ticket validation response decoding failed: response body is of wrong form ($serviceResponse)" + )) + } + } + } + + private def casFailure[R](debugLabel: String, resp: Response[IO]): EitherT[IO, DecodeFailure, R] = { + textOrXmlDecoder + .decode(resp, strict = false) + .flatMap(body => DecodeResult.failureT[IO, R](InvalidMessageBodyFailure( + s"Decoding $debugLabel failed: CAS returned non-ok status code ${resp.status.code}: $body" + ))) + .leftFlatMap(failure => DecodeResult.failureT[IO, R](InvalidMessageBodyFailure( + s"Decoding $debugLabel failed: CAS returned non-ok status code ${resp.status.code}: ${failure.message}" + ))) + } + + /** + * Decode CAS Oppija's service ticket validation response to various oppija attributes. + */ + def decodeOppijaAttributes: Response[IO] => IO[OppijaAttributes] = { response => + decodeCASResponse[OppijaAttributes](response, "oppija attributes", oppijaServiceTicketDecoder) + } + + /** + * Decode CAS Virkailija's service ticket validation response to username. + */ + def decodeVirkailijaUsername: Response[IO] => IO[Username] = { response => + decodeCASResponse[Username](response, "username", virkailijaServiceTicketDecoder) + } + + private def decodeCASResponse[R](response: Response[IO], debugLabel: String, decoder: EntityDecoder[IO, R]): IO[R] = { + val decodeResult = if (response.status.isSuccess) { + decoder + .decode(response, strict = false) + .leftMap(decodeFailure => new CasClientException(s"Decoding $debugLabel failed: " + decodeFailure.message)) + } else { + casFailure(debugLabel, response) + } + decodeResult.rethrowT + } +} + +private[cas] object ServiceTicketClient { + import CasClient._ + + private val stPattern = "(ST-.*)".r + + def getServiceTicketFromTgt(client: Client[IO], service: Uri, callerId: String)(tgtUrl: TGTUrl): IO[ServiceTicket] = { + val urlForm = UrlForm("service" -> service.toString()) + val request = Request[IO](Method.POST, tgtUrl).withEntity(urlForm) + + def handler(response: Response[IO]): IO[ServiceTicket] = { + response match { + case r: Response[IO] if r.status.isSuccess => r.as[String].map { + case stPattern(st) => st + case nonSt: Any => throw new CasClientException(s"Service Ticket decoding failed at ${tgtUrl}: response body is of wrong form ($nonSt)") + } + case r: Response[IO] => r.as[String].map(body => + throw new CasClientException(s"Service Ticket decoding failed at ${tgtUrl}: unexpected status ${r.status.code}: $body") + ) + } + } + FetchHelper.fetch(client, callerId, request, handler) + } +} + +private[cas] object TicketGrantingTicketClient extends Logging { + import CasClient.TGTUrl + + private val tgtPattern = "(.*TGT-.*)".r + + def getTicketGrantingTicket(casBaseUrl: Uri, client: Client[IO], user: CasUser, callerId: String): IO[TGTUrl] = { + val tgtUri: TGTUrl = casBaseUrl.addPath("v1/tickets") + val urlForm = UrlForm("username" -> user.username, "password" -> user.password) + val request = Request[IO](Method.POST, tgtUri).withEntity(urlForm) + + def handler(response: Response[IO]): IO[TGTUrl] = { + response match { + case Created(resp) => + val found: TGTUrl = resp.headers.get(CIString("Location")).map(_.head.value) match { + case Some(tgtPattern(tgtUrl)) => + Uri.fromString(tgtUrl).fold( + (pf: ParseFailure) => throw new CasClientException(pf.message), + (tgt) => tgt + ) + case Some(nonTgtUrl) => + throw new CasClientException(s"TGT decoding failed at ${tgtUri}: location header has wrong format $nonTgtUrl") + case None => + throw new CasClientException(s"TGT decoding failed at ${tgtUri}: no location header") + } + IO.pure(found) + case Locked(_) => + throw new CasAuthenticationException(s"Access denied: username ${user.username} is locked") + case resp: Response[IO] => resp.as[String].map(body => + if (body.contains("authentication_exceptions") || body.contains("error.authentication.credentials.bad")) { + throw new CasAuthenticationException(s"Access denied: bad credentials") + } else { + throw new CasClientException(s"TGT decoding failed at ${tgtUri}: invalid TGT creation status: ${resp.status.code}: $body") + } + ) + } + } + FetchHelper.fetch(client, callerId, request, handler) + } +} + +private[cas] object SessionCookieClient { + import CasClient._ + + def getSessionCookieValue + (client: Client[IO], service: Uri, sessionCookieName: String, callerId: String) + (serviceTicket: ServiceTicket) + : IO[SessionCookie] = { + val sessionIdUri: Uri = service.withQueryParam("ticket", serviceTicket) + val request = Request[IO](Method.GET, sessionIdUri) + + def handler(response: Response[IO]): IO[SessionCookie] = { + response match { + case resp: Response[IO] if resp.status.isSuccess => + IO.pure( + resp.cookies.find(_.name == sessionCookieName).map(_.content) + .getOrElse(throw new CasClientException(s"Decoding $sessionCookieName failed at ${sessionIdUri}: no cookie found")) + ) + case resp: Response[IO] => + resp.as[String].map(body => + throw new CasClientException(s"Decoding $sessionCookieName failed at ${sessionIdUri}: service returned non-ok status code ${resp.status.code}: $body") + ) + } + } + + FetchHelper.fetch(client, callerId, request, handler) + } +} + +object FetchHelper { + def addDefaultHeaders(request: Request[IO], callerId: String): Request[IO] = { + request.putHeaders( + Header.Raw(CIString("Caller-Id"), callerId), + Header.Raw(CIString("CSRF"), callerId) + ).addCookie("CSRF", callerId) + } + + def fetch[A](client: Client[IO], callerId: String, request: Request[IO], handler: Response[IO] => IO[A]): IO[A] = + client.run(addDefaultHeaders(request, callerId)).use(handler) +} + +class CasClientException(message: String) extends RuntimeException(message) + +class CasAuthenticationException(message: String) extends CasClientException(message) diff --git a/parser/src/main/scala/fi/vm/sade/utils/cas/CasLogout.scala b/parser/src/main/scala/fi/vm/sade/utils/cas/CasLogout.scala new file mode 100644 index 00000000..cdabf224 --- /dev/null +++ b/parser/src/main/scala/fi/vm/sade/utils/cas/CasLogout.scala @@ -0,0 +1,13 @@ +package fi.vm.sade.utils.cas + +import scala.xml.{Utility, XML} + +object CasLogout { + def parseTicketFromLogoutRequest(logoutRequest: String): Option[String] = { + Utility.trim(XML.loadString(logoutRequest)) match { + case {nameID}{ticket} => + Some(ticket.text) + case _ => None + } + } +} diff --git a/parser/src/main/scala/fi/vm/sade/utils/cas/CasParams.scala b/parser/src/main/scala/fi/vm/sade/utils/cas/CasParams.scala new file mode 100644 index 00000000..71e46746 --- /dev/null +++ b/parser/src/main/scala/fi/vm/sade/utils/cas/CasParams.scala @@ -0,0 +1,43 @@ +package fi.vm.sade.utils.cas + +import org.http4s.{ParseFailure, Uri} + + +case class CasUser(username: String, password: String) + +case class CasService(securityUri: Uri) + +case class CasParams(service: CasService, user: CasUser) { + override def toString: String = service.securityUri.toString +} + +object CasParams { + + + def apply(servicePath: String, securityUriSuffix: String, username: String, password: String): CasParams = { + Uri.fromString(ensureTrailingSlash(ensureLeadingSlash(servicePath))).fold( + (e: ParseFailure) => throw new IllegalArgumentException(e), + (service: Uri) => CasParams(CasService(Uri.resolve(service, Uri.fromString(removeTrailingAndLeadingSlash(securityUriSuffix)).getOrElse { + throw new IllegalArgumentException(s"Could not parse securityUriSuffix $securityUriSuffix") + })), CasUser(username, password))) + } + + def apply(servicePath: String, username: String, password: String): CasParams = apply( + servicePath = servicePath, + securityUriSuffix = "j_spring_cas_security_check", + username = username, + password = password + ) + + private def ensureTrailingSlash(servicePath: String): String = servicePath.last match { + case '/' => servicePath + case _ => servicePath + "/" + } + + private def ensureLeadingSlash(servicePath: String): String = servicePath.head match { + case '/' => servicePath + case _ => "/" + servicePath + } + + def removeTrailingAndLeadingSlash(value: String): String = value.stripPrefix("/").stripSuffix("/") +} diff --git a/parser/src/main/scala/fi/vm/sade/utils/cas/Logging.scala b/parser/src/main/scala/fi/vm/sade/utils/cas/Logging.scala new file mode 100644 index 00000000..a18980b0 --- /dev/null +++ b/parser/src/main/scala/fi/vm/sade/utils/cas/Logging.scala @@ -0,0 +1,7 @@ +package fi.vm.sade.utils.cas + +import org.slf4j.{LoggerFactory, Logger} + +trait Logging { + protected lazy val logger: Logger = LoggerFactory.getLogger(getClass()) +} diff --git a/parser/src/test/scala/fi/oph/omaopintopolkuloki/LambdaLogParserHandlerTest.scala b/parser/src/test/scala/fi/oph/omaopintopolkuloki/LambdaLogParserHandlerTest.scala index cb8816ab..4e913bcf 100644 --- a/parser/src/test/scala/fi/oph/omaopintopolkuloki/LambdaLogParserHandlerTest.scala +++ b/parser/src/test/scala/fi/oph/omaopintopolkuloki/LambdaLogParserHandlerTest.scala @@ -5,12 +5,14 @@ import com.amazonaws.services.lambda.runtime.events.SQSEvent import fi.oph.omaopintopolkuloki.db.DB import fi.oph.omaopintopolkuloki.repository.{OrganizationPermission, Permission, RemoteOrganizationRepository, RemoteSQSRepository} import org.scalamock.scalatest.MockFactory -import org.scalatest.{BeforeAndAfter, FunSpec, Matchers, PrivateMethodTester} +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import scalacache.Flags import scala.io.Source -class LambdaLogParserHandlerTest extends FunSpec with Matchers with MockFactory with PrivateMethodTester with BeforeAndAfter { +class LambdaLogParserHandlerTest extends AnyFunSpec with Matchers with MockFactory with PrivateMethodTester with BeforeAndAfter { private val sendMessage = PrivateMethod[Unit]('sendMessage) private val purgeQueue = PrivateMethod[Unit]('purgeQueue) diff --git a/parser/src/test/scala/vm/sade/utils/cas/CasClientTest.scala b/parser/src/test/scala/vm/sade/utils/cas/CasClientTest.scala new file mode 100644 index 00000000..cfaded58 --- /dev/null +++ b/parser/src/test/scala/vm/sade/utils/cas/CasClientTest.scala @@ -0,0 +1,54 @@ +package vm.sade.utils.cas + +import fi.vm.sade.utils.cas.CasParams +import org.scalatest.funspec.AnyFunSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.{BeforeAndAfter, PrivateMethodTester} + +class CasClientTest extends AnyFunSpec with Matchers with PrivateMethodTester with BeforeAndAfter { + describe("CasParams") { + it("Removes trailing and leading edge slashes") { + CasParams.removeTrailingAndLeadingSlash("/foo/bar/baz/") should be ("foo/bar/baz") + } + it("Should format securityUri path properly, with leading slash") { + val params = CasParams("/kayttooikeus-service", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/j_spring_cas_security_check") + } + it("Should format securityUri path properly, with trailing and leading slash") { + val params = CasParams("/kayttooikeus-service/", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/j_spring_cas_security_check") + } + it("Should format securityUri path properly, with trailing slash") { + val params = CasParams("kayttooikeus-service/", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/j_spring_cas_security_check") + } + it("Should format securityUri path properly, without trailing or leading slash") { + val params = CasParams("kayttooikeus-service", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/j_spring_cas_security_check") + } + it("Should format securityUri path properly, with leading slash and custom security URI suffix") { + val params = CasParams("/kayttooikeus-service", "security_uri_suffix", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/security_uri_suffix") + } + it("Should format securityUri path properly, with trailing and leading slash and custom security URI suffix") { + val params = CasParams("/kayttooikeus-service/", "security_uri_suffix","example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/security_uri_suffix") + } + it("Should format securityUri path properly, without trailing edge and custom security URI suffix") { + val params = CasParams("kayttooikeus-service/", "security_uri_suffix", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/security_uri_suffix") + } + it("Should format securityUri path properly, without trailing or leading edge and custom security URI suffix (1)") { + val params = CasParams("kayttooikeus-service", "/security_uri_suffix/bar/baz", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/security_uri_suffix/bar/baz") + } + it("Should format securityUri path properly, without trailing or leading edge and custom security URI suffix (2)") { + val params = CasParams("kayttooikeus-service", "/security_uri_suffix/bar", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/security_uri_suffix/bar") + } + it("Should format securityUri path properly, without trailing or leading edge and custom security URI suffix (3)") { + val params = CasParams("kayttooikeus-service", "/security_uri_suffix/bar/baz/", "example", "passwd") + params.service.securityUri.path.renderString should be("/kayttooikeus-service/security_uri_suffix/bar/baz") + } + } +}