diff --git a/libs/wire-api/default.nix b/libs/wire-api/default.nix index 8564a5767a..816c540079 100644 --- a/libs/wire-api/default.nix +++ b/libs/wire-api/default.nix @@ -51,6 +51,7 @@ , iproute , iso3166-country-codes , iso639 +, jose , lens , lib , memory @@ -152,6 +153,7 @@ mkDerivation { iproute iso3166-country-codes iso639 + jose lens memory metrics-wai diff --git a/libs/wire-api/src/Wire/API/OAuth.hs b/libs/wire-api/src/Wire/API/OAuth.hs new file mode 100644 index 0000000000..a6365f5ba2 --- /dev/null +++ b/libs/wire-api/src/Wire/API/OAuth.hs @@ -0,0 +1,490 @@ +-- This file is part of the Wire Server implementation. +-- +-- Copyright (C) 2022 Wire Swiss GmbH +-- +-- This program is free software: you can redistribute it and/or modify it under +-- the terms of the GNU Affero General Public License as published by the Free +-- Software Foundation, either version 3 of the License, or (at your option) any +-- later version. +-- +-- This program is distributed in the hope that it will be useful, but WITHOUT +-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +-- details. +-- +-- You should have received a copy of the GNU Affero General Public License along +-- with this program. If not, see . + +module Wire.API.OAuth where + +import Cassandra hiding (Set) +import Control.Lens (preview, view) +import Control.Monad.Except +import Crypto.JWT hiding (params, uri) +import qualified Data.Aeson as A +import qualified Data.Aeson.KeyMap as M +import qualified Data.Aeson.Types as A +import Data.ByteString.Conversion +import Data.ByteString.Lazy (toStrict) +import qualified Data.HashMap.Strict as HM +import Data.Id +import Data.Range +import Data.Schema +import qualified Data.Set as Set +import Data.String.Conversions (cs) +import qualified Data.Swagger as S +import qualified Data.Text as T +import Data.Text.Ascii +import qualified Data.Text.Encoding as TE +import Data.Text.Encoding.Error as TErr +import Data.Time (NominalDiffTime) +import Imports hiding (exp) +import Servant hiding (Handler, Tagged, Unauthorized) +import URI.ByteString +import Web.FormUrlEncoded (Form (..), FromForm (..), ToForm (..), parseUnique) +import Wire.API.Error +import Wire.API.Routes.MultiVerb +import Wire.API.Routes.Named (Named (..)) +import Wire.API.Routes.Public (ZUser) + +-------------------------------------------------------------------------------- +-- Types + +newtype RedirectUrl = RedirectUrl {unRedirectUrl :: URIRef Absolute} + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema RedirectUrl) + +instance ToByteString RedirectUrl where + builder = serializeURIRef . unRedirectUrl + +instance FromByteString RedirectUrl where + parser = RedirectUrl <$> uriParser strictURIParserOptions + +instance ToSchema RedirectUrl where + schema = + (TE.decodeUtf8 . serializeURIRef' . unRedirectUrl) + .= (RedirectUrl <$> parsedText "RedirectUrl" (runParser (uriParser strictURIParserOptions) . TE.encodeUtf8)) + +instance ToHttpApiData RedirectUrl where + toUrlPiece = TE.decodeUtf8With TErr.lenientDecode . toHeader + toHeader = serializeURIRef' . unRedirectUrl + +instance FromHttpApiData RedirectUrl where + parseUrlPiece = parseHeader . TE.encodeUtf8 + parseHeader = bimap (T.pack . show) RedirectUrl . parseURI strictURIParserOptions + +newtype OAuthApplicationName = OAuthApplicationName {unOAuthApplicationName :: Range 1 256 Text} + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthApplicationName) + +instance ToSchema OAuthApplicationName where + schema = OAuthApplicationName <$> unOAuthApplicationName .= schema + +data NewOAuthClient = NewOAuthClient + { nocApplicationName :: OAuthApplicationName, + nocRedirectUrl :: RedirectUrl + } + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema NewOAuthClient) + +instance ToSchema NewOAuthClient where + schema = + object "NewOAuthClient" $ + NewOAuthClient + <$> nocApplicationName .= field "applicationName" schema + <*> nocRedirectUrl .= field "redirectUrl" schema + +newtype OAuthClientPlainTextSecret = OAuthClientPlainTextSecret {unOAuthClientPlainTextSecret :: AsciiBase16} + deriving (Eq, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthClientPlainTextSecret) + +instance Show OAuthClientPlainTextSecret where + show _ = "" + +instance ToSchema OAuthClientPlainTextSecret where + schema = (toText . unOAuthClientPlainTextSecret) .= parsedText "OAuthClientPlainTextSecret" (fmap OAuthClientPlainTextSecret . validateBase16) + +instance FromHttpApiData OAuthClientPlainTextSecret where + parseQueryParam = bimap cs OAuthClientPlainTextSecret . validateBase16 . cs + +instance ToHttpApiData OAuthClientPlainTextSecret where + toQueryParam = toText . unOAuthClientPlainTextSecret + +data OAuthClientCredentials = OAuthClientCredentials + { occClientId :: OAuthClientId, + occClientSecret :: OAuthClientPlainTextSecret + } + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthClientCredentials) + +instance ToSchema OAuthClientCredentials where + schema = + object "OAuthClientCredentials" $ + OAuthClientCredentials + <$> occClientId .= field "clientId" schema + <*> occClientSecret .= field "clientSecret" schema + +data OAuthClient = OAuthClient + { ocId :: OAuthClientId, + ocName :: OAuthApplicationName, + ocRedirectUrl :: RedirectUrl + } + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthClient) + +instance ToSchema OAuthClient where + schema = + object "OAuthClient" $ + OAuthClient + <$> ocId .= field "clientId" schema + <*> ocName .= field "applicationName" schema + <*> ocRedirectUrl .= field "redirectUrl" schema + +data OAuthResponseType = OAuthResponseTypeCode + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthResponseType) + +instance ToSchema OAuthResponseType where + schema :: ValueSchema NamedSwaggerDoc OAuthResponseType + schema = + enum @Text "OAuthResponseType" $ + mconcat + [ element "code" OAuthResponseTypeCode + ] + +data OAuthScope + = ConversationCreate + | ConversationCodeCreate + deriving (Eq, Show, Generic, Ord) + +instance ToByteString OAuthScope where + builder = \case + ConversationCreate -> "conversation:create" + ConversationCodeCreate -> "conversation-code:create" + +instance FromByteString OAuthScope where + parser = do + s <- parser + case s & T.toLower of + "conversation:create" -> pure ConversationCreate + "conversation-code:create" -> pure ConversationCodeCreate + _ -> fail "invalid scope" + +newtype OAuthScopes = OAuthScopes {unOAuthScopes :: Set OAuthScope} + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthScopes) + +instance ToSchema OAuthScopes where + schema = OAuthScopes <$> (oauthScopesToText . unOAuthScopes) .= withParser schema oauthScopeParser + +oauthScopesToText :: Set OAuthScope -> Text +oauthScopesToText = T.intercalate " " . fmap (cs . toByteString') . Set.toList + +oauthScopeParser :: Text -> A.Parser (Set OAuthScope) +oauthScopeParser "" = pure Set.empty +oauthScopeParser scope = + pure $ (not . T.null) `filter` T.splitOn " " scope & maybe Set.empty Set.fromList . mapM (fromByteString' . cs) + +data NewOAuthAuthCode = NewOAuthAuthCode + { noacClientId :: OAuthClientId, + noacScope :: OAuthScopes, + noacResponseType :: OAuthResponseType, + noacRedirectUri :: RedirectUrl, + noacState :: Text + } + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema NewOAuthAuthCode) + +instance ToSchema NewOAuthAuthCode where + schema = + object "NewOAuthAuthCode" $ + NewOAuthAuthCode + <$> noacClientId .= field "clientId" schema + <*> noacScope .= field "scope" schema + <*> noacResponseType .= field "responseType" schema + <*> noacRedirectUri .= field "redirectUri" schema + <*> noacState .= field "state" schema + +newtype OAuthAuthCode = OAuthAuthCode {unOAuthAuthCode :: AsciiBase16} + deriving (Show, Eq, Generic) + +instance ToSchema OAuthAuthCode where + schema = (toText . unOAuthAuthCode) .= parsedText "OAuthAuthCode" (fmap OAuthAuthCode . validateBase16) + +instance ToByteString OAuthAuthCode where + builder = builder . unOAuthAuthCode + +instance FromByteString OAuthAuthCode where + parser = OAuthAuthCode <$> parser + +instance FromHttpApiData OAuthAuthCode where + parseQueryParam = bimap cs OAuthAuthCode . validateBase16 . cs + +instance ToHttpApiData OAuthAuthCode where + toQueryParam = toText . unOAuthAuthCode + +data OAuthGrantType = OAuthGrantTypeAuthorizationCode + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthGrantType) + +instance ToSchema OAuthGrantType where + schema = + enum @Text "OAuthGrantType" $ + mconcat + [ element "authorization_code" OAuthGrantTypeAuthorizationCode + ] + +instance FromByteString OAuthGrantType where + parser = do + s <- parser + case s & T.toLower of + "authorization_code" -> pure OAuthGrantTypeAuthorizationCode + _ -> fail "invalid OAuthGrantType" + +instance ToByteString OAuthGrantType where + builder = \case + OAuthGrantTypeAuthorizationCode -> "authorization_code" + +instance FromHttpApiData OAuthGrantType where + parseQueryParam = maybe (Left "invalid OAuthGrantType") pure . fromByteString . cs + +instance ToHttpApiData OAuthGrantType where + toQueryParam = cs . toByteString + +data OAuthAccessTokenRequest = OAuthAccessTokenRequest + { oatGrantType :: OAuthGrantType, + oatClientId :: OAuthClientId, + oatClientSecret :: OAuthClientPlainTextSecret, + oatCode :: OAuthAuthCode, + oatRedirectUri :: RedirectUrl + } + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthAccessTokenRequest) + +instance ToSchema OAuthAccessTokenRequest where + schema = + object "OAuthAccessTokenRequest" $ + OAuthAccessTokenRequest + <$> oatGrantType .= field "grantType" schema + <*> oatClientId .= field "clientId" schema + <*> oatClientSecret .= field "clientSecret" schema + <*> oatCode .= field "code" schema + <*> oatRedirectUri .= field "redirectUri" schema + +instance FromForm OAuthAccessTokenRequest where + fromForm f = + OAuthAccessTokenRequest + <$> parseUnique "grant_type" f + <*> parseUnique "client_id" f + <*> parseUnique "client_secret" f + <*> parseUnique "code" f + <*> parseUnique "redirect_uri" f + +instance ToForm OAuthAccessTokenRequest where + toForm req = + Form $ + mempty + & HM.insert "grant_type" [toQueryParam (oatGrantType req)] + & HM.insert "client_id" [toQueryParam (oatClientId req)] + & HM.insert "client_secret" [toQueryParam (oatClientSecret req)] + & HM.insert "code" [toQueryParam (oatCode req)] + & HM.insert "redirect_uri" [toQueryParam (oatRedirectUri req)] + +data OAuthAccessTokenType = OAuthAccessTokenTypeBearer + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthAccessTokenType) + +instance ToSchema OAuthAccessTokenType where + schema = + enum @Text "OAuthAccessTokenType" $ + mconcat + [ element "Bearer" OAuthAccessTokenTypeBearer + ] + +newtype OAuthAccessToken = OAuthAccessToken {unOAuthAccessToken :: SignedJWT} + deriving (Show, Eq, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via Schema OAuthAccessToken + +instance ToByteString OAuthAccessToken where + builder = builder . encodeCompact . unOAuthAccessToken + +instance FromByteString OAuthAccessToken where + parser = do + t <- parser @Text + case decodeCompact (cs (TE.encodeUtf8 t)) of + Left (err :: JWTError) -> fail $ show err + Right jwt -> pure $ OAuthAccessToken jwt + +instance ToHttpApiData OAuthAccessToken where + toHeader = toByteString' + toUrlPiece = cs . toHeader + +instance FromHttpApiData OAuthAccessToken where + parseHeader = either (Left . cs) pure . runParser parser . cs + parseUrlPiece = parseHeader . cs + +instance ToSchema OAuthAccessToken where + schema = (TE.decodeUtf8 . toByteString') .= withParser schema (either fail pure . runParser parser . cs) + +data OAuthAccessTokenResponse = OAuthAccessTokenResponse + { oatAccessToken :: OAuthAccessToken, + oatTokenType :: OAuthAccessTokenType, + oatExpiresIn :: NominalDiffTime + } + deriving (Eq, Show, Generic) + deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthAccessTokenResponse) + +instance ToSchema OAuthAccessTokenResponse where + schema = + object "OAuthAccessTokenResponse" $ + OAuthAccessTokenResponse + <$> oatAccessToken .= field "accessToken" schema + <*> oatTokenType .= field "tokenType" schema + <*> oatExpiresIn .= field "expiresIn" (fromIntegral <$> roundDiffTime .= schema) + where + roundDiffTime :: NominalDiffTime -> Int32 + roundDiffTime = round + +data OAuthClaimSet = OAuthClaimSet {jwtClaims :: ClaimsSet, scope :: OAuthScopes} + deriving (Eq, Show, Generic) + +instance HasClaimsSet OAuthClaimSet where + claimsSet f s = fmap (\a' -> s {jwtClaims = a'}) (f (jwtClaims s)) + +instance A.FromJSON OAuthClaimSet where + parseJSON = A.withObject "OAuthClaimSet" $ \o -> + OAuthClaimSet + <$> A.parseJSON (A.Object o) + <*> o A..: "scope" + +instance A.ToJSON OAuthClaimSet where + toJSON s = + ins "scope" (scope s) (A.toJSON (jwtClaims s)) + where + ins k v (A.Object o) = A.Object $ M.insert k (A.toJSON v) o + ins _ _ a = a + +csUserId :: OAuthClaimSet -> Maybe UserId +csUserId = + view claimSub + >=> preview string + >=> either (const Nothing) pure . parseIdFromText + +-------------------------------------------------------------------------------- +-- API Internal + +type IOAuthAPI = + Named + "create-oauth-client" + ( Summary "Register an OAuth client" + :> CanThrow 'OAuthFeatureDisabled + :> "i" + :> "oauth" + :> "clients" + :> ReqBody '[JSON] NewOAuthClient + :> Post '[JSON] OAuthClientCredentials + ) + +-------------------------------------------------------------------------------- +-- API Public + +type OAuthAPI = + Named + "get-oauth-client" + ( Summary "Get OAuth client information" + :> CanThrow 'OAuthFeatureDisabled + :> ZUser + :> "oauth" + :> "clients" + :> Capture "ClientId" OAuthClientId + :> MultiVerb + 'GET + '[JSON] + '[ ErrorResponse 'OAuthClientNotFound, + Respond 200 "OAuth client found" OAuthClient + ] + (Maybe OAuthClient) + ) + :<|> Named + "create-oauth-auth-code" + ( Summary "" + :> CanThrow 'UnsupportedResponseType + :> CanThrow 'RedirectUrlMissMatch + :> CanThrow 'OAuthClientNotFound + :> CanThrow 'OAuthFeatureDisabled + :> ZUser + :> "oauth" + :> "authorization" + :> "codes" + :> ReqBody '[JSON] NewOAuthAuthCode + :> MultiVerb + 'POST + '[JSON] + '[WithHeaders '[Header "Location" RedirectUrl] RedirectUrl (RespondEmpty 302 "Found")] + RedirectUrl + ) + :<|> Named + "create-oauth-access-token" + ( Summary "Create an OAuth access token" + :> CanThrow 'JwtError + :> CanThrow 'OAuthAuthCodeNotFound + :> CanThrow 'OAuthClientNotFound + :> CanThrow 'OAuthFeatureDisabled + :> "oauth" + :> "token" + :> ReqBody '[FormUrlEncoded] OAuthAccessTokenRequest + :> Post '[JSON] OAuthAccessTokenResponse + ) + +-------------------------------------------------------------------------------- +-- Errors + +data OAuthError + = OAuthClientNotFound + | RedirectUrlMissMatch + | UnsupportedResponseType + | JwtError + | OAuthAuthCodeNotFound + | OAuthFeatureDisabled + | Unauthorized + +type instance MapError 'OAuthClientNotFound = 'StaticError 404 "not-found" "OAuth client not found" + +type instance MapError 'RedirectUrlMissMatch = 'StaticError 400 "redirect-url-miss-match" "Redirect URL miss match" + +type instance MapError 'UnsupportedResponseType = 'StaticError 400 "unsupported-response-type" "Unsupported response type" + +type instance MapError 'JwtError = 'StaticError 500 "jwt-error" "Internal error while creating JWT" + +type instance MapError 'OAuthAuthCodeNotFound = 'StaticError 404 "not-found" "OAuth authorization code not found" + +type instance MapError 'OAuthFeatureDisabled = 'StaticError 403 "forbidden" "OAuth is disabled" + +type instance MapError 'Unauthorized = 'StaticError 401 "unauthorized" "Unauthorized" + +-------------------------------------------------------------------------------- +-- CQL instances + +instance Cql OAuthApplicationName where + ctype = Tagged TextColumn + toCql = CqlText . fromRange . unOAuthApplicationName + fromCql (CqlText t) = checkedEither t <&> OAuthApplicationName + fromCql _ = Left "OAuthApplicationName: Text expected" + +instance Cql RedirectUrl where + ctype = Tagged BlobColumn + toCql = CqlBlob . toByteString + fromCql (CqlBlob t) = runParser parser (toStrict t) + fromCql _ = Left "RedirectUrl: Blob expected" + +instance Cql OAuthAuthCode where + ctype = Tagged AsciiColumn + toCql = CqlAscii . toText . unOAuthAuthCode + fromCql (CqlAscii t) = OAuthAuthCode <$> validateBase16 t + fromCql _ = Left "OAuthAuthCode: Ascii expected" + +instance Cql OAuthScope where + ctype = Tagged TextColumn + toCql = CqlText . cs . toByteString' + fromCql (CqlText t) = maybe (Left "invalid oauth scope") Right $ fromByteString' (cs t) + fromCql _ = Left "OAuthScope: Text expected" diff --git a/libs/wire-api/src/Wire/API/RawJson.hs b/libs/wire-api/src/Wire/API/RawJson.hs index 295202c1ed..f2806c972f 100644 --- a/libs/wire-api/src/Wire/API/RawJson.hs +++ b/libs/wire-api/src/Wire/API/RawJson.hs @@ -21,7 +21,7 @@ import Imports import Servant -- | Wrap json content as plain 'LByteString' --- This type is intented to be used to receive json content as 'LByteString'. +-- This type is intended to be used to receive json content as 'LByteString'. -- Warning: There is no validation of the json content. It may be any string. newtype RawJson = RawJson {rawJsonBytes :: LByteString} diff --git a/libs/wire-api/src/Wire/API/Routes/Bearer.hs b/libs/wire-api/src/Wire/API/Routes/Bearer.hs index ca88c1c5e4..345e966098 100644 --- a/libs/wire-api/src/Wire/API/Routes/Bearer.hs +++ b/libs/wire-api/src/Wire/API/Routes/Bearer.hs @@ -35,6 +35,10 @@ instance FromHttpApiData a => FromHttpApiData (Bearer a) where _ -> Left "Invalid authorization scheme" parseUrlPiece = parseHeader . T.encodeUtf8 +instance ToHttpApiData a => ToHttpApiData (Bearer a) where + toHeader = (<>) "Bearer " . toHeader . unBearer + toUrlPiece = T.decodeUtf8 . toHeader + type BearerHeader a = Header' '[Lenient] "Authorization" (Bearer a) type BearerQueryParam = @@ -42,6 +46,9 @@ type BearerQueryParam = [Lenient, Description "Access token"] "access_token" +instance ToParamSchema (Bearer a) where + toParamSchema _ = toParamSchema (Proxy @Text) + instance HasSwagger api => HasSwagger (Bearer a :> api) where toSwagger _ = toSwagger (Proxy @api) diff --git a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs index e7f9a2d7a7..c518c3b03d 100644 --- a/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs +++ b/libs/wire-api/src/Wire/API/Routes/Public/Brig.hs @@ -48,6 +48,7 @@ import Wire.API.Error.Brig import Wire.API.Error.Empty import Wire.API.MLS.KeyPackage import Wire.API.MLS.Servant +import Wire.API.OAuth (OAuthAccessToken) import Wire.API.Properties import Wire.API.Routes.Bearer import Wire.API.Routes.Cookies @@ -244,11 +245,14 @@ type UserAPI = RichInfoAssocList ) +type OptOAuth = Header' '[Optional, Strict] "Authorization" (Bearer OAuthAccessToken) + type SelfAPI = Named "get-self" ( Summary "Get your own profile" - :> ZUser + :> ZOptUser + :> OptOAuth :> "self" :> Get '[JSON] SelfProfile ) diff --git a/libs/wire-api/wire-api.cabal b/libs/wire-api/wire-api.cabal index bf841e4d75..8dd0159303 100644 --- a/libs/wire-api/wire-api.cabal +++ b/libs/wire-api/wire-api.cabal @@ -60,6 +60,7 @@ library Wire.API.MLS.Servant Wire.API.MLS.Welcome Wire.API.Notification + Wire.API.OAuth Wire.API.Properties Wire.API.Provider Wire.API.Provider.Bot @@ -238,6 +239,7 @@ library , iproute >=1.5 , iso3166-country-codes >=0.2 , iso639 >=0.1 + , jose , lens >=4.12 , memory , metrics-wai diff --git a/services/brig/src/Brig/API/Internal.hs b/services/brig/src/Brig/API/Internal.hs index 9051b52fe2..dcc9cb951d 100644 --- a/services/brig/src/Brig/API/Internal.hs +++ b/services/brig/src/Brig/API/Internal.hs @@ -30,7 +30,7 @@ import qualified Brig.API.Connection as API import Brig.API.Error import Brig.API.Handler import Brig.API.MLS.KeyPackages.Validation -import Brig.API.OAuth (IOAuthAPI, internalOauthAPI) +import Brig.API.OAuth (internalOauthAPI) import Brig.API.Types import qualified Brig.API.User as API import qualified Brig.API.User as Api @@ -91,6 +91,7 @@ import qualified Wire.API.Error.Brig as E import Wire.API.MLS.Credential import Wire.API.MLS.KeyPackage import Wire.API.MLS.Serialisation +import Wire.API.OAuth import Wire.API.Routes.Internal.Brig import qualified Wire.API.Routes.Internal.Brig as BrigIRoutes import Wire.API.Routes.Internal.Brig.Connection diff --git a/services/brig/src/Brig/API/OAuth.hs b/services/brig/src/Brig/API/OAuth.hs index ca41d2c1a6..17919d83dc 100644 --- a/services/brig/src/Brig/API/OAuth.hs +++ b/services/brig/src/Brig/API/OAuth.hs @@ -17,7 +17,7 @@ module Brig.API.OAuth where -import Brig.API.Error (throwStd) +import Brig.API.Error (badRequest, throwStd) import Brig.API.Handler (Handler) import Brig.App import Brig.Effects.Jwk @@ -29,352 +29,29 @@ import qualified Cassandra as C import Control.Lens (view, (.~), (?~), (^?)) import Control.Monad.Except import Crypto.JWT hiding (params, uri) -import qualified Data.Aeson as A -import qualified Data.Aeson.KeyMap as M -import qualified Data.Aeson.Types as A import Data.ByteString.Conversion -import Data.ByteString.Lazy (toStrict) import Data.Domain -import qualified Data.HashMap.Strict as HM import Data.Id (OAuthClientId, UserId, idToText, randomId) import Data.Misc (PlainTextPassword (PlainTextPassword)) -import Data.Range -import Data.Schema import qualified Data.Set as Set import Data.String.Conversions (cs) -import qualified Data.Swagger as S -import qualified Data.Text as T import Data.Text.Ascii -import qualified Data.Text.Encoding as TE -import Data.Text.Encoding.Error as TErr import Data.Time (NominalDiffTime, addUTCTime) import Imports hiding (exp) import OpenSSL.Random (randBytes) import Polysemy (Member) -import Servant hiding (Handler, Tagged) +import Servant hiding (Handler, Tagged, Unauthorized) import URI.ByteString -import Web.FormUrlEncoded (Form (..), FromForm (..), ToForm (..), parseUnique) import Wire.API.Error -import Wire.API.Routes.MultiVerb +import Wire.API.OAuth +import Wire.API.Routes.Bearer (Bearer (Bearer)) import Wire.API.Routes.Named (Named (..)) -import Wire.API.Routes.Public (ZUser) import Wire.Sem.Now (Now) import qualified Wire.Sem.Now as Now --------------------------------------------------------------------------------- --- Types - -newtype RedirectUrl = RedirectUrl {unRedirectUrl :: URIRef Absolute} - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema RedirectUrl) - -instance ToByteString RedirectUrl where - builder = serializeURIRef . unRedirectUrl - -instance FromByteString RedirectUrl where - parser = RedirectUrl <$> uriParser strictURIParserOptions - -instance ToSchema RedirectUrl where - schema = - (TE.decodeUtf8 . serializeURIRef' . unRedirectUrl) - .= (RedirectUrl <$> parsedText "RedirectUrl" (runParser (uriParser strictURIParserOptions) . TE.encodeUtf8)) - -instance ToHttpApiData RedirectUrl where - toUrlPiece = TE.decodeUtf8With TErr.lenientDecode . toHeader - toHeader = serializeURIRef' . unRedirectUrl - -instance FromHttpApiData RedirectUrl where - parseUrlPiece = parseHeader . TE.encodeUtf8 - parseHeader = bimap (T.pack . show) RedirectUrl . parseURI strictURIParserOptions - -newtype OAuthApplicationName = OAuthApplicationName {unOAuthApplicationName :: Range 1 256 Text} - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthApplicationName) - -instance ToSchema OAuthApplicationName where - schema = OAuthApplicationName <$> unOAuthApplicationName .= schema - -data NewOAuthClient = NewOAuthClient - { nocApplicationName :: OAuthApplicationName, - nocRedirectUrl :: RedirectUrl - } - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema NewOAuthClient) - -instance ToSchema NewOAuthClient where - schema = - object "NewOAuthClient" $ - NewOAuthClient - <$> nocApplicationName .= field "applicationName" schema - <*> nocRedirectUrl .= field "redirectUrl" schema - -newtype OAuthClientPlainTextSecret = OAuthClientPlainTextSecret {unOAuthClientPlainTextSecret :: AsciiBase16} - deriving (Eq, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthClientPlainTextSecret) - -instance Show OAuthClientPlainTextSecret where - show _ = "" - -instance ToSchema OAuthClientPlainTextSecret where - schema = (toText . unOAuthClientPlainTextSecret) .= parsedText "OAuthClientPlainTextSecret" (fmap OAuthClientPlainTextSecret . validateBase16) - -instance FromHttpApiData OAuthClientPlainTextSecret where - parseQueryParam = bimap cs OAuthClientPlainTextSecret . validateBase16 . cs - -instance ToHttpApiData OAuthClientPlainTextSecret where - toQueryParam = toText . unOAuthClientPlainTextSecret - -data OAuthClientCredentials = OAuthClientCredentials - { occClientId :: OAuthClientId, - occClientSecret :: OAuthClientPlainTextSecret - } - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthClientCredentials) - -instance ToSchema OAuthClientCredentials where - schema = - object "OAuthClientCredentials" $ - OAuthClientCredentials - <$> occClientId .= field "clientId" schema - <*> occClientSecret .= field "clientSecret" schema - -data OAuthClient = OAuthClient - { ocId :: OAuthClientId, - ocName :: OAuthApplicationName, - ocRedirectUrl :: RedirectUrl - } - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthClient) - -instance ToSchema OAuthClient where - schema = - object "OAuthClient" $ - OAuthClient - <$> ocId .= field "clientId" schema - <*> ocName .= field "applicationName" schema - <*> ocRedirectUrl .= field "redirectUrl" schema - -data OAuthResponseType = OAuthResponseTypeCode - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthResponseType) - -instance ToSchema OAuthResponseType where - schema :: ValueSchema NamedSwaggerDoc OAuthResponseType - schema = - enum @Text "OAuthResponseType" $ - mconcat - [ element "code" OAuthResponseTypeCode - ] - -data OAuthScope - = ConversationCreate - | ConversationCodeCreate - deriving (Eq, Show, Generic, Ord) - -instance ToByteString OAuthScope where - builder = \case - ConversationCreate -> "conversation:create" - ConversationCodeCreate -> "conversation-code:create" - -instance FromByteString OAuthScope where - parser = do - s <- parser - case s & T.toLower of - "conversation:create" -> pure ConversationCreate - "conversation-code:create" -> pure ConversationCodeCreate - _ -> fail "invalid scope" - -newtype OAuthScopes = OAuthScopes {unOAuthScopes :: Set OAuthScope} - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthScopes) - -instance ToSchema OAuthScopes where - schema = OAuthScopes <$> (oauthScopesToText . unOAuthScopes) .= withParser schema oauthScopeParser - -oauthScopesToText :: Set OAuthScope -> Text -oauthScopesToText = T.intercalate " " . fmap (cs . toByteString') . Set.toList - -oauthScopeParser :: Text -> A.Parser (Set OAuthScope) -oauthScopeParser "" = pure Set.empty -oauthScopeParser scope = - pure $ (not . T.null) `filter` T.splitOn " " scope & maybe Set.empty Set.fromList . mapM (fromByteString' . cs) - -data NewOAuthAuthCode = NewOAuthAuthCode - { noacClientId :: OAuthClientId, - noacScope :: OAuthScopes, - noacResponseType :: OAuthResponseType, - noacRedirectUri :: RedirectUrl, - noacState :: Text - } - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema NewOAuthAuthCode) - -instance ToSchema NewOAuthAuthCode where - schema = - object "NewOAuthAuthCode" $ - NewOAuthAuthCode - <$> noacClientId .= field "clientId" schema - <*> noacScope .= field "scope" schema - <*> noacResponseType .= field "responseType" schema - <*> noacRedirectUri .= field "redirectUri" schema - <*> noacState .= field "state" schema - -newtype OAuthAuthCode = OAuthAuthCode {unOAuthAuthCode :: AsciiBase16} - deriving (Show, Eq, Generic) - -instance ToSchema OAuthAuthCode where - schema = (toText . unOAuthAuthCode) .= parsedText "OAuthAuthCode" (fmap OAuthAuthCode . validateBase16) - -instance ToByteString OAuthAuthCode where - builder = builder . unOAuthAuthCode - -instance FromByteString OAuthAuthCode where - parser = OAuthAuthCode <$> parser - -instance FromHttpApiData OAuthAuthCode where - parseQueryParam = bimap cs OAuthAuthCode . validateBase16 . cs - -instance ToHttpApiData OAuthAuthCode where - toQueryParam = toText . unOAuthAuthCode - -data OAuthGrantType = OAuthGrantTypeAuthorizationCode - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthGrantType) - -instance ToSchema OAuthGrantType where - schema = - enum @Text "OAuthGrantType" $ - mconcat - [ element "authorization_code" OAuthGrantTypeAuthorizationCode - ] - -instance FromByteString OAuthGrantType where - parser = do - s <- parser - case s & T.toLower of - "authorization_code" -> pure OAuthGrantTypeAuthorizationCode - _ -> fail "invalid OAuthGrantType" - -instance ToByteString OAuthGrantType where - builder = \case - OAuthGrantTypeAuthorizationCode -> "authorization_code" - -instance FromHttpApiData OAuthGrantType where - parseQueryParam = maybe (Left "invalid OAuthGrantType") pure . fromByteString . cs - -instance ToHttpApiData OAuthGrantType where - toQueryParam = cs . toByteString - -data OAuthAccessTokenRequest = OAuthAccessTokenRequest - { oatGrantType :: OAuthGrantType, - oatClientId :: OAuthClientId, - oatClientSecret :: OAuthClientPlainTextSecret, - oatCode :: OAuthAuthCode, - oatRedirectUri :: RedirectUrl - } - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthAccessTokenRequest) - -instance ToSchema OAuthAccessTokenRequest where - schema = - object "OAuthAccessTokenRequest" $ - OAuthAccessTokenRequest - <$> oatGrantType .= field "grantType" schema - <*> oatClientId .= field "clientId" schema - <*> oatClientSecret .= field "clientSecret" schema - <*> oatCode .= field "code" schema - <*> oatRedirectUri .= field "redirectUri" schema - -instance FromForm OAuthAccessTokenRequest where - fromForm f = - OAuthAccessTokenRequest - <$> parseUnique "grant_type" f - <*> parseUnique "client_id" f - <*> parseUnique "client_secret" f - <*> parseUnique "code" f - <*> parseUnique "redirect_uri" f - -instance ToForm OAuthAccessTokenRequest where - toForm req = - Form $ - mempty - & HM.insert "grant_type" [toQueryParam (oatGrantType req)] - & HM.insert "client_id" [toQueryParam (oatClientId req)] - & HM.insert "client_secret" [toQueryParam (oatClientSecret req)] - & HM.insert "code" [toQueryParam (oatCode req)] - & HM.insert "redirect_uri" [toQueryParam (oatRedirectUri req)] - -data OAuthAccessTokenType = OAuthAccessTokenTypeBearer - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthAccessTokenType) - -instance ToSchema OAuthAccessTokenType where - schema = - enum @Text "OAuthAccessTokenType" $ - mconcat - [ element "Bearer" OAuthAccessTokenTypeBearer - ] - -newtype OauthAccessToken = OauthAccessToken {unOauthAccessToken :: ByteString} - deriving (Show, Eq, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via Schema OauthAccessToken - -instance ToSchema OauthAccessToken where - schema = (TE.decodeUtf8 . unOauthAccessToken) .= fmap (OauthAccessToken . TE.encodeUtf8) schema - -data OAuthAccessTokenResponse = OAuthAccessTokenResponse - { oatAccessToken :: OauthAccessToken, - oatTokenType :: OAuthAccessTokenType, - oatExpiresIn :: NominalDiffTime - } - deriving (Eq, Show, Generic) - deriving (A.ToJSON, A.FromJSON, S.ToSchema) via (Schema OAuthAccessTokenResponse) - -instance ToSchema OAuthAccessTokenResponse where - schema = - object "OAuthAccessTokenResponse" $ - OAuthAccessTokenResponse - <$> oatAccessToken .= field "accessToken" schema - <*> oatTokenType .= field "tokenType" schema - <*> oatExpiresIn .= field "expiresIn" (fromIntegral <$> roundDiffTime .= schema) - where - roundDiffTime :: NominalDiffTime -> Int32 - roundDiffTime = round - -data OAuthClaimSet = OAuthClaimSet {jwtClaims :: ClaimsSet, scope :: OAuthScopes} - deriving (Eq, Show, Generic) - -instance HasClaimsSet OAuthClaimSet where - claimsSet f s = fmap (\a' -> s {jwtClaims = a'}) (f (jwtClaims s)) - -instance A.FromJSON OAuthClaimSet where - parseJSON = A.withObject "OAuthClaimSet" $ \o -> - OAuthClaimSet - <$> A.parseJSON (A.Object o) - <*> o A..: "scope" - -instance A.ToJSON OAuthClaimSet where - toJSON s = - ins "scope" (scope s) (A.toJSON (jwtClaims s)) - where - ins k v (A.Object o) = A.Object $ M.insert k (A.toJSON v) o - ins _ _ a = a - -------------------------------------------------------------------------------- -- API Internal -type IOAuthAPI = - Named - "create-oauth-client" - ( Summary "Register an OAuth client" - :> CanThrow 'OAuthFeatureDisabled - :> "i" - :> "oauth" - :> "clients" - :> ReqBody '[JSON] NewOAuthClient - :> Post '[JSON] OAuthClientCredentials - ) - internalOauthAPI :: ServerT IOAuthAPI (Handler r) internalOauthAPI = Named @"create-oauth-client" createNewOAuthClient @@ -382,83 +59,12 @@ internalOauthAPI = -------------------------------------------------------------------------------- -- API Public -type OAuthAPI = - Named - "get-oauth-client" - ( Summary "Get OAuth client information" - :> CanThrow 'OAuthFeatureDisabled - :> ZUser - :> "oauth" - :> "clients" - :> Capture "ClientId" OAuthClientId - :> MultiVerb - 'GET - '[JSON] - '[ ErrorResponse 'OAuthClientNotFound, - Respond 200 "OAuth client found" OAuthClient - ] - (Maybe OAuthClient) - ) - :<|> Named - "create-oauth-auth-code" - ( Summary "" - :> CanThrow 'UnsupportedResponseType - :> CanThrow 'RedirectUrlMissMatch - :> CanThrow 'OAuthClientNotFound - :> CanThrow 'OAuthFeatureDisabled - :> ZUser - :> "oauth" - :> "authorization" - :> "codes" - :> ReqBody '[JSON] NewOAuthAuthCode - :> MultiVerb - 'POST - '[JSON] - '[WithHeaders '[Header "Location" RedirectUrl] RedirectUrl (RespondEmpty 302 "Found")] - RedirectUrl - ) - :<|> Named - "create-oauth-access-token" - ( Summary "Create an OAuth access token" - :> CanThrow 'JwtError - :> CanThrow 'OAuthAuthCodeNotFound - :> CanThrow 'OAuthClientNotFound - :> CanThrow 'OAuthFeatureDisabled - :> "oauth" - :> "token" - :> ReqBody '[FormUrlEncoded] OAuthAccessTokenRequest - :> Post '[JSON] OAuthAccessTokenResponse - ) - oauthAPI :: (Member Now r, Member Jwk r) => ServerT OAuthAPI (Handler r) oauthAPI = Named @"get-oauth-client" getOAuthClient :<|> Named @"create-oauth-auth-code" createNewOAuthAuthCode :<|> Named @"create-oauth-access-token" createAccessToken --------------------------------------------------------------------------------- --- Errors - -data OAuthError - = OAuthClientNotFound - | RedirectUrlMissMatch - | UnsupportedResponseType - | JwtError - | OAuthAuthCodeNotFound - | OAuthFeatureDisabled - -type instance MapError 'OAuthClientNotFound = 'StaticError 404 "not-found" "OAuth client not found" - -type instance MapError 'RedirectUrlMissMatch = 'StaticError 400 "redirect-url-miss-match" "Redirect URL miss match" - -type instance MapError 'UnsupportedResponseType = 'StaticError 400 "unsupported-response-type" "Unsupported response type" - -type instance MapError 'JwtError = 'StaticError 500 "jwt-error" "Internal error while creating JWT" - -type instance MapError 'OAuthAuthCodeNotFound = 'StaticError 404 "not-found" "OAuth authorization code not found" - -type instance MapError 'OAuthFeatureDisabled = 'StaticError 403 "forbidden" "OAuth is disabled" - -------------------------------------------------------------------------------- -- Handlers @@ -512,7 +118,7 @@ createAccessToken req = do claims <- mkClaims authCodeUserId domain authCodeScopes exp fp <- view settings >>= maybe (throwStd $ errorToWai @'JwtError) pure . Opt.setOAuthJwkKeyPair key <- lift (liftSem $ Jwk.get fp) >>= maybe (throwStd $ errorToWai @'JwtError) pure - token <- OauthAccessToken . cs . encodeCompact <$> signJwtToken key claims + token <- OAuthAccessToken <$> signJwtToken key claims pure $ OAuthAccessTokenResponse token OAuthAccessTokenTypeBearer exp where mkClaims :: (Member Now r) => UserId -> Domain -> OAuthScopes -> NominalDiffTime -> (Handler r) OAuthClaimSet @@ -552,10 +158,23 @@ createAccessToken req = do rand32Bytes :: MonadIO m => m AsciiBase16 rand32Bytes = liftIO . fmap encodeBase16 $ randBytes 32 -verify :: JWK -> ByteString -> IO (Either JWTError OAuthClaimSet) -verify k s = runJOSE $ do +handleZUserOrOAuth :: (Member Jwk r) => Maybe UserId -> Maybe (Bearer OAuthAccessToken) -> (Handler r) UserId +handleZUserOrOAuth (Just u) Nothing = pure u +handleZUserOrOAuth (Just _) (Just _) = throwStd $ badRequest "Authorization header and ZAuth header are mutually exclusive." +handleZUserOrOAuth Nothing Nothing = throwStd $ errorToWai @'Unauthorized +handleZUserOrOAuth Nothing (Just (Bearer token)) = verifyOAuthAccessToken token >>= maybe (throwStd $ errorToWai @'Unauthorized) pure . csUserId + +-- todo(leif): verify other claims as well +verifyOAuthAccessToken :: (Member Jwk r) => OAuthAccessToken -> (Handler r) OAuthClaimSet +verifyOAuthAccessToken token = do + fp <- view settings >>= maybe (throwStd $ errorToWai @'JwtError) pure . Opt.setOAuthJwkKeyPair + key <- lift (liftSem $ Jwk.get fp) >>= maybe (throwStd $ errorToWai @'JwtError) pure + verifiedOrError <- liftIO $ verify' key (unOAuthAccessToken token) + either (const $ throwStd $ errorToWai @'Unauthorized) pure verifiedOrError + +verify' :: JWK -> SignedJWT -> IO (Either JWTError OAuthClaimSet) +verify' k jwt = runJOSE $ do let audCheck = const True - jwt <- decodeCompact (cs s) verifyJWT (defaultJWTValidationSettings audCheck) k jwt -------------------------------------------------------------------------------- @@ -606,30 +225,3 @@ deleteOAuthAuthCode code = retry x5 . write q $ params LocalQuorum (Identity cod lookupAndDeleteOAuthAuthCode :: (MonadClient m, MonadReader Env m) => OAuthAuthCode -> m (Maybe (OAuthClientId, UserId, OAuthScopes, RedirectUrl)) lookupAndDeleteOAuthAuthCode code = lookupOAuthAuthCode code <* deleteOAuthAuthCode code - --------------------------------------------------------------------------------- --- CQL instances - -instance Cql OAuthApplicationName where - ctype = Tagged TextColumn - toCql = CqlText . fromRange . unOAuthApplicationName - fromCql (CqlText t) = checkedEither t <&> OAuthApplicationName - fromCql _ = Left "OAuthApplicationName: Text expected" - -instance Cql RedirectUrl where - ctype = Tagged BlobColumn - toCql = CqlBlob . toByteString - fromCql (CqlBlob t) = runParser parser (toStrict t) - fromCql _ = Left "RedirectUrl: Blob expected" - -instance Cql OAuthAuthCode where - ctype = Tagged AsciiColumn - toCql = CqlAscii . toText . unOAuthAuthCode - fromCql (CqlAscii t) = OAuthAuthCode <$> validateBase16 t - fromCql _ = Left "OAuthAuthCode: Ascii expected" - -instance Cql OAuthScope where - ctype = Tagged TextColumn - toCql = CqlText . cs . toByteString' - fromCql (CqlText t) = maybe (Left "invalid oauth scope") Right $ fromByteString' (cs t) - fromCql _ = Left "OAuthScope: Text expected" diff --git a/services/brig/src/Brig/API/Public.hs b/services/brig/src/Brig/API/Public.hs index 830cf25ad1..4444471909 100644 --- a/services/brig/src/Brig/API/Public.hs +++ b/services/brig/src/Brig/API/Public.hs @@ -33,7 +33,7 @@ import qualified Brig.API.Connection as API import Brig.API.Error import Brig.API.Handler import Brig.API.MLS.KeyPackages -import Brig.API.OAuth (OAuthAPI, oauthAPI) +import Brig.API.OAuth (handleZUserOrOAuth, oauthAPI) import qualified Brig.API.Properties as API import Brig.API.Public.Swagger import Brig.API.Types @@ -75,6 +75,7 @@ import qualified Cassandra as Data import Control.Error hiding (bool) import Control.Lens (view, (.~), (?~), (^.)) import Control.Monad.Catch (throwM) +import Control.Monad.Except import Data.Aeson hiding (json) import Data.Bifunctor import qualified Data.ByteString.Lazy as Lazy @@ -104,7 +105,7 @@ import Network.Wai.Routing import Network.Wai.Utilities as Utilities import Network.Wai.Utilities.Swagger (mkSwaggerApi) import Polysemy -import Servant hiding (Handler, JSON, addHeader, respond) +import Servant hiding (Handler, JSON, Unauthorized, addHeader, respond) import qualified Servant import Servant.Swagger.Internal.Orphans () import Servant.Swagger.UI @@ -113,6 +114,7 @@ import Util.Logging (logFunction, logHandle, logTeam, logUser) import qualified Wire.API.Connection as Public import Wire.API.Error import qualified Wire.API.Error.Brig as E +import Wire.API.OAuth import qualified Wire.API.Properties as Public import qualified Wire.API.Routes.MultiTablePaging as Public import Wire.API.Routes.Named (Named (Named)) @@ -217,7 +219,7 @@ servantSitemap = brigAPI :<|> oauthAPI selfAPI :: ServerT SelfAPI (Handler r) selfAPI = - Named @"get-self" getSelf + Named @"get-self" (\muid mbearer -> handleZUserOrOAuth muid mbearer >>= getSelf) :<|> Named @"delete-self" deleteSelfUser :<|> Named @"put-self" updateUser :<|> Named @"change-phone" changePhone diff --git a/services/brig/src/Brig/Run.hs b/services/brig/src/Brig/Run.hs index 6656a22816..caa00a4f8f 100644 --- a/services/brig/src/Brig/Run.hs +++ b/services/brig/src/Brig/Run.hs @@ -28,7 +28,6 @@ import Brig.API (sitemap) import Brig.API.Federation import Brig.API.Handler import qualified Brig.API.Internal as IAPI -import Brig.API.OAuth (IOAuthAPI, OAuthAPI) import Brig.API.Public (DocsAPI, docsAPI, servantSitemap) import qualified Brig.API.User as API import Brig.AWS (amazonkaEnv, sesQueue) @@ -74,6 +73,7 @@ import qualified Servant import System.Logger (msg, val, (.=), (~~)) import System.Logger.Class (MonadLogger, err) import Util.Options +import Wire.API.OAuth (IOAuthAPI, OAuthAPI) import Wire.API.Routes.API import Wire.API.Routes.Public.Brig import Wire.API.Routes.Version diff --git a/services/brig/test/integration/API/OAuth.hs b/services/brig/test/integration/API/OAuth.hs index aa6f3f028a..e604be87e1 100644 --- a/services/brig/test/integration/API/OAuth.hs +++ b/services/brig/test/integration/API/OAuth.hs @@ -38,11 +38,15 @@ import Data.Text.Ascii (encodeBase16) import Data.Time import Imports import qualified Network.Wai.Utilities as Error +import Servant.API (ToHttpApiData (toHeader)) import Test.Tasty import Test.Tasty.HUnit import URI.ByteString import Util import Web.FormUrlEncoded +import Wire.API.OAuth +import Wire.API.Routes.Bearer (Bearer (Bearer)) +import Wire.API.User tests :: Manager -> Brig -> Opts -> TestTree tests m b o = do @@ -125,7 +129,8 @@ testCreateOAuthCodeClientNotFound brig = do testCreateAccessTokenSuccess :: Opt.Opts -> Brig -> Http () testCreateAccessTokenSuccess opts brig = do now <- liftIO getCurrentTime - uid <- randomId + user <- createUser "alice" brig + let uid = userId user let redirectUrl = fromMaybe (error "invalid url") $ fromByteString' "https://example.com" let scopes = OAuthScopes $ Set.fromList [ConversationCreate, ConversationCodeCreate] (cid, secret, code) <- generateOAuthClientAndAuthCode brig uid scopes redirectUrl @@ -136,8 +141,8 @@ testCreateAccessTokenSuccess opts brig = do const 404 === statusCode const (Just "not-found") === fmap Error.label . responseJsonMaybe k <- liftIO $ readJwk (fromMaybe "" (Opt.setOAuthJwkKeyPair $ Opt.optSettings opts)) <&> fromMaybe (error "invalid key") - verifiedOrError <- liftIO $ verify k (cs $ unOauthAccessToken $ oatAccessToken accessToken) - verifiedOrErrorWithWrongKey <- liftIO $ verify wrongKey (cs $ unOauthAccessToken $ oatAccessToken accessToken) + verifiedOrError <- liftIO $ verify' k (unOAuthAccessToken $ oatAccessToken accessToken) + verifiedOrErrorWithWrongKey <- liftIO $ verify' wrongKey (unOAuthAccessToken $ oatAccessToken accessToken) let expectedDomain = domainText $ Opt.setFederationDomain $ Opt.optSettings opts liftIO $ do isRight verifiedOrError @?= True @@ -151,6 +156,13 @@ testCreateAccessTokenSuccess opts brig = do diffUTCTime expTime now > 0 @?= True let issuingTime = (\(NumericDate x) -> x) . fromMaybe (error "iat claim missing") . view claimIat $ claims abs (diffUTCTime issuingTime now) < 5 @?= True -- allow for some generous clock skew + get (brig . paths ["self"]) !!! const 401 === statusCode + response :: SelfProfile <- responseJsonError =<< get (brig . paths ["self"] . zUser uid) + response' :: SelfProfile <- responseJsonError =<< get (brig . paths ["self"] . oauth (oatAccessToken accessToken)) + liftIO $ response @?= response' + +oauth :: OAuthAccessToken -> Request -> Request +oauth = header "Authorization" . toHeader . Bearer testCreateAccessTokenWrongClientId :: Brig -> Http () testCreateAccessTokenWrongClientId brig = do