Skip to content

Commit

Permalink
Replace overlappable HasServer instance with a tyfam
Browse files Browse the repository at this point in the history
  • Loading branch information
theophile-scrive committed Jul 11, 2024
1 parent a7f18b4 commit 4636a97
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 21 deletions.
2 changes: 1 addition & 1 deletion cabal.project
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ optimization: False

-- Development flags
package *
ghc-options: -fshow-hole-constraints -fhide-source-paths
ghc-options: -fshow-hole-constraints -fhide-source-paths -fprint-potential-instances

-- reorder-goals: True

Expand Down
1 change: 1 addition & 0 deletions servant-client/test/Servant/ClientTestUtils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
{-# OPTIONS_GHC -freduction-depth=100 #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-name-shadowing #-}
{-# OPTIONS_GHC -fprint-potential-instances #-}

module Servant.ClientTestUtils where

Expand Down
29 changes: 9 additions & 20 deletions servant-server/src/Servant/Server/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ import Servant.API.Modifiers
unfoldRequestArgument)
import Servant.API.QueryString (FromDeepQuery(..))
import Servant.API.ResponseHeaders
(GetHeaders, Headers, getHeaders, getResponse)
(GetHeaders, Headers, getHeaders, getResponse, Wrap, ExtractHeadersResponse, ExtractedValue, HandlerResponse, extractHeadersResponse)
import Servant.API.Status
(statusFromNat)
import qualified Servant.Types.SourceT as S
Expand All @@ -93,7 +93,7 @@ import Servant.Server.Internal.ServerError

import Servant.API.TypeLevel (AtMostOneFragment, FragmentUnique)
import Servant.API.MultiVerb (MultiVerb, Respond)
import Network.HTTP.Types (Header)
import Network.HTTP.Types (Header)

class HasServer api context where
-- | The type of a server for this API, given a monad to run effects in.
Expand Down Expand Up @@ -318,27 +318,16 @@ noContentRouter method status action = leafRouter route'
env request respond $ \ _output ->
Route $ responseLBS status [] ""

instance {-# OVERLAPPABLE #-}
( AllCTRender contentTypes a, ReflectMethod method, KnownNat status
) => HasServer (MultiVerb method contentTypes '[Respond status (desc :: Symbol) a] a) context where

type ServerT (MultiVerb method contentTypes '[Respond status (desc :: Symbol) a] a) m = m a
hoistServerWithContext _ _ nt s = nt s

route Proxy _ = methodRouter ([],) method (Proxy :: Proxy contentTypes) status
where method = reflectMethod (Proxy :: Proxy method)
status = statusFromNat (Proxy :: Proxy status)

instance {-# OVERLAPPING #-}
( AllCTRender contentTypes a, ReflectMethod method, KnownNat status
, GetHeaders (Headers h a)
) => HasServer (MultiVerb method contentTypes '[Respond status (desc :: Symbol) (Headers h a)] (Headers h a)) context where

type ServerT (MultiVerb method contentTypes '[Respond status (desc :: Symbol) (Headers h a)] (Headers h a)) m = m (Headers h a)
instance ( AllCTRender ctypes (ExtractedValue a (Wrap a))
, ReflectMethod method, KnownNat status
, ExtractHeadersResponse a (Wrap a)
, a ~ HandlerResponse a (Wrap a)
) => HasServer (MultiVerb method ctypes '[Respond status (desc :: Symbol) a]) context where

type ServerT (MultiVerb method ctypes '[Respond status (desc :: Symbol) a]) m = m a
hoistServerWithContext _ _ nt s = nt s

route Proxy _ = methodRouter (\x -> (getHeaders x, getResponse x)) method (Proxy :: Proxy contentTypes) status
route Proxy _ = methodRouter (extractHeadersResponse @a @(Wrap a)) method (Proxy :: Proxy ctypes) status
where method = reflectMethod (Proxy :: Proxy method)
status = statusFromNat (Proxy :: Proxy status)

Expand Down
30 changes: 30 additions & 0 deletions servant/src/Servant/API/ResponseHeaders.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ module Servant.API.ResponseHeaders
, GetHeaders'
, HeaderValMap
, HList(..)
, Wrap
, ExtractHeadersResponse(..)
) where

import Control.DeepSeq
Expand All @@ -48,6 +50,8 @@ import Servant.API.Modifiers
import Servant.API.UVerb.Union
import qualified Data.SOP.BasicFunctors as SOP
import qualified Data.SOP.NS as SOP
import qualified Data.ByteString as B
import Network.HTTP.Types (HeaderName)

-- | Response Header objects. You should never need to construct one directly.
-- Instead, use 'addOptionalHeader'.
Expand Down Expand Up @@ -257,6 +261,32 @@ lookupResponseHeader :: (HasResponseHeader h a headers)
=> Headers headers r -> ResponseHeader h a
lookupResponseHeader = hlistLookupHeader . getHeadersHList

newtype Naked a = Naked a

type family Wrap a where
Wrap (Headers x a) = Headers x a
Wrap a = Naked a

class ExtractHeadersResponse orig wrapped where
type HandlerResponse orig wrapped :: Type
type ExtractedValue orig wrapped :: Type

extractHeadersResponse :: HandlerResponse orig wrapped -> ([(HeaderName, B.ByteString)], ExtractedValue orig wrapped)

instance ExtractHeadersResponse a (Naked a) where
type HandlerResponse a (Naked a) = a
type ExtractedValue a (Naked a) = a

extractHeadersResponse :: a -> ([(HeaderName, B.ByteString)], a)
extractHeadersResponse x = ([], x)

instance GetHeaders (Headers x a) => ExtractHeadersResponse (Headers x a) (Headers x a) where
type HandlerResponse (Headers x a) (Headers x a) = Headers x a
type ExtractedValue (Headers x a) (Headers x a) = a

extractHeadersResponse :: Headers x a -> ([(HeaderName, B.ByteString)], a)
extractHeadersResponse x = (getHeaders x, getResponse x)

-- $setup
-- >>> :set -XFlexibleContexts
-- >>> import Servant.API
Expand Down

0 comments on commit 4636a97

Please sign in to comment.