Skip to content

Commit c41b9f8

Browse files
committed
Added generic FromField and ToField classes
1 parent a8f6a90 commit c41b9f8

4 files changed

Lines changed: 197 additions & 40 deletions

File tree

postgresql-simple.cabal

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ Library
7474

7575
if !impl(ghc >= 7.6)
7676
Build-depends:
77-
ghc-prim
77+
ghc-prim,
78+
tagged >= 0.8
7879

7980
extensions: DoAndIfThenElse, OverloadedStrings, BangPatterns, ViewPatterns
8081
TypeOperators

src/Database/PostgreSQL/Simple/FromField.hs

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
{-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-}
33
{-# LANGUAGE PatternGuards, ScopedTypeVariables #-}
44
{-# LANGUAGE RecordWildCards, TemplateHaskell #-}
5+
{-# LANGUAGE MultiWayIf, DefaultSignatures #-}
6+
{-# LANGUAGE FlexibleContexts #-}
57

68
{- |
79
Module: Database.PostgreSQL.Simple.FromField
@@ -83,6 +85,7 @@ instances use 'typename' instead.
8385
module Database.PostgreSQL.Simple.FromField
8486
(
8587
FromField(..)
88+
, genericFromField
8689
, FieldParser
8790
, Conversion()
8891

@@ -113,16 +116,19 @@ module Database.PostgreSQL.Simple.FromField
113116

114117
#include "MachDeps.h"
115118

116-
import Control.Applicative ( (<|>), (<$>), pure, (*>), (<*) )
119+
import Control.Applicative ( Alternative(..), (<|>), (<$>), pure, (*>), (<*), liftA2 )
117120
import Control.Concurrent.MVar (MVar, newMVar)
118121
import Control.Exception (Exception)
119122
import qualified Data.Aeson as JSON
120123
import qualified Data.Aeson.Parser as JSON (value')
121124
import Data.Attoparsec.ByteString.Char8 hiding (Result)
122125
import Data.ByteString (ByteString)
126+
import Data.ByteString.Builder (Builder, toLazyByteString, byteString)
123127
import qualified Data.ByteString.Char8 as B
128+
import Data.Char (toLower)
124129
import Data.Int (Int16, Int32, Int64)
125130
import Data.IORef (IORef, newIORef)
131+
import Data.Proxy (Proxy(..))
126132
import Data.Ratio (Ratio)
127133
import Data.Time ( UTCTime, ZonedTime, LocalTime, Day, TimeOfDay )
128134
import Data.Typeable (Typeable, typeOf)
@@ -150,6 +156,7 @@ import qualified Data.CaseInsensitive as CI
150156
import Data.UUID.Types (UUID)
151157
import qualified Data.UUID.Types as UUID
152158
import Data.Scientific (Scientific)
159+
import GHC.Generics (Generic, Rep, M1(..), K1(..), D1, C1, S1, Rec0, Constructor, (:*:)(..), to, conName)
153160
import GHC.Real (infinity, notANumber)
154161

155162
-- | Exception thrown if conversion from a SQL value to a Haskell
@@ -188,6 +195,8 @@ type FieldParser a = Field -> Maybe ByteString -> Conversion a
188195
-- | A type that may be converted from a SQL type.
189196
class FromField a where
190197
fromField :: FieldParser a
198+
default fromField :: (Generic a, Typeable a, GFromField (Rep a)) => FieldParser a
199+
fromField = genericFromField (map toLower)
191200
-- ^ Convert a SQL value to a Haskell value.
192201
--
193202
-- Returns a list of exceptions if the conversion fails. In the case of
@@ -292,7 +301,8 @@ instance FromField Null where
292301
-- | bool
293302
instance FromField Bool where
294303
fromField f bs
295-
| typeOid f /= $(inlineTypoid TI.bool) = returnError Incompatible f ""
304+
| typeOid f /= $(inlineTypoid TI.bool)
305+
&& typeOid f /= $(inlineTypoid TI.unknown) = returnError Incompatible f ""
296306
| bs == Nothing = returnError UnexpectedNull f ""
297307
| bs == Just "t" = pure True
298308
| bs == Just "f" = pure False
@@ -404,9 +414,9 @@ instance FromField (Binary SB.ByteString) where
404414
instance FromField (Binary LB.ByteString) where
405415
fromField f dat = Binary . LB.fromChunks . (:[]) . unBinary <$> fromField f dat
406416

407-
-- | name, text, \"char\", bpchar, varchar
417+
-- | name, text, \"char\", bpchar, varchar, unknown
408418
instance FromField ST.Text where
409-
fromField f = doFromField f okText $ (either left pure . ST.decodeUtf8')
419+
fromField f = doFromField f okText' $ (either left pure . ST.decodeUtf8')
410420
-- FIXME: check character encoding
411421

412422
-- | name, text, \"char\", bpchar, varchar
@@ -645,10 +655,93 @@ returnError mkErr f msg = do
645655
atto :: forall a. (Typeable a)
646656
=> Compat -> Parser a -> Field -> Maybe ByteString
647657
-> Conversion a
648-
atto types p0 f dat = doFromField f types (go p0) dat
658+
atto types p0 f dat = doFromField f (\t -> types t || (t == $(inlineTypoid TI.unknown))) (go p0) dat
649659
where
650660
go :: Parser a -> ByteString -> Conversion a
651661
go p s =
652662
case parseOnly p s of
653663
Left err -> returnError ConversionFailed f err
654664
Right v -> pure v
665+
666+
667+
-- | Type class for default implementation of FromField using generics.
668+
class GFromField f where
669+
gfromField :: (Typeable p)
670+
=> Proxy p
671+
-> (String -> String)
672+
-> Field
673+
-> [Maybe ByteString]
674+
-> Conversion (f p)
675+
676+
instance (GFromField f) => GFromField (D1 i f) where
677+
gfromField w t f v = M1 <$> gfromField w t f v
678+
679+
instance (GFromField f, Typeable f, Constructor i) => GFromField (C1 i f) where
680+
gfromField w t f (v:[]) = let
681+
tname = B8.pack . t . conName $ (undefined::(C1 i f t))
682+
tcheck = (\t -> t /= "record" && t /= tname)
683+
in tcheck <$> typename f >>= \b -> M1 <$> case b of
684+
True -> returnError Incompatible f ""
685+
False -> maybe
686+
(returnError UnexpectedNull f "")
687+
(either
688+
(returnError ConversionFailed f)
689+
(gfromField w t f)
690+
. (parseOnly record)) v
691+
gfromField _ _ f _ = M1 <$> returnError ConversionFailed f errUnexpectedArgs
692+
693+
instance (GFromField f, Typeable f, GFromField g) => GFromField (f :*: g) where
694+
gfromField _ _ f [] = liftA2 (:*:) (returnError ConversionFailed f errTooFewValues) empty
695+
gfromField w t f (v:vs) = liftA2 (:*:) (gfromField w t f [v]) (gfromField w t f vs)
696+
697+
instance (GFromField f, Typeable f) => GFromField (S1 i f) where
698+
gfromField _ _ f [] = M1 <$> returnError ConversionFailed f errTooFewValues
699+
gfromField w t f (v:[]) = M1 <$> gfromField w t f [v]
700+
gfromField _ _ f _ = M1 <$> returnError ConversionFailed f errTooManyValues
701+
702+
instance (FromField f, Typeable f) => GFromField (Rec0 f) where
703+
gfromField _ _ f [v] = K1 <$> fromField (f {typeOid = typoid TI.unknown}) v
704+
gfromField _ _ f _ = K1 <$> returnError ConversionFailed f errUnexpectedArgs
705+
706+
707+
-- | Common error messages for GFromField instances.
708+
errTooFewValues, errTooManyValues, errUnexpectedArgs :: String
709+
errTooFewValues = "too few values"
710+
errTooManyValues = "too many values"
711+
errUnexpectedArgs = "unexpected arguments"
712+
713+
-- | Parser of a postgresql record.
714+
record :: Parser [Maybe ByteString]
715+
record = (char '(') *> (recordField `sepBy` (char ',')) <* (char ')')
716+
717+
-- | Parser of a postgresql record's field.
718+
recordField :: Parser (Maybe ByteString)
719+
recordField = (Just <$> quotedString) <|> (Just <$> unquotedString) <|> (pure Nothing) where
720+
quotedString = unescape <$> (char '"' *> scan False updateState) where
721+
updateState isBalanced c = if
722+
| c == '"' -> Just . not $ isBalanced
723+
| not isBalanced -> Just False
724+
| c == ',' || c == ')' -> Nothing
725+
| otherwise -> fail $ "unexpected symbol: " ++ [c]
726+
727+
unescape = unescape' '\\' . unescape' '"' . B8.init where
728+
unescape' c = halve c (byteString SB.empty) . groupByChar c
729+
730+
groupByChar c = B8.groupBy $ \a b -> (a == c) == (b == c)
731+
732+
halve :: Char -> Builder -> [ByteString] -> ByteString
733+
halve _ b [] = LB.toStrict . toLazyByteString $ b
734+
halve c b (s:ss) = halve c (b <> b') ss where
735+
b' = if
736+
| (/= c) . B8.head $ s -> byteString s
737+
| otherwise -> byteString . SB.take ((SB.length s) `div` 2) $ s
738+
739+
unquotedString = takeWhile1 (\c -> c /= ',' && c /= ')')
740+
741+
-- | Function that creates fromField for a given type.
742+
genericFromField :: forall a. (Generic a, Typeable a, GFromField (Rep a))
743+
=> (String -> String) -- ^ How to transform constructor's name to match
744+
-- postgresql type's name.
745+
-> FieldParser a
746+
genericFromField t f v = (to <$> (gfromField (Proxy :: Proxy a) t f [v]))
747+

src/Database/PostgreSQL/Simple/ToField.hs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{-# LANGUAGE CPP, DeriveDataTypeable, DeriveFunctor #-}
22
{-# LANGUAGE FlexibleInstances, TypeSynonymInstances #-}
3+
{-# LANGUAGE DefaultSignatures, FlexibleContexts #-}
34

45
------------------------------------------------------------------------------
56
-- |
@@ -39,6 +40,7 @@ import Data.Word (Word, Word8, Word16, Word32, Word64)
3940
import {-# SOURCE #-} Database.PostgreSQL.Simple.ToRow
4041
import Database.PostgreSQL.Simple.Types
4142
import Database.PostgreSQL.Simple.Compat (toByteString)
43+
import GHC.Generics (Generic, Rep, D1, C1, S1, (:*:)(..), Rec0, from, unM1, unK1)
4244

4345
import qualified Data.ByteString as SB
4446
import qualified Data.ByteString.Lazy as LB
@@ -92,6 +94,8 @@ instance Show Action where
9294
-- | A type that may be used as a single parameter to a SQL query.
9395
class ToField a where
9496
toField :: a -> Action
97+
default toField :: (Generic a, GToField (Rep a)) => a -> Action
98+
toField = head . gtoField . from
9599
-- ^ Prepare a value for substitution into a query string.
96100

97101
instance ToField Action where
@@ -369,3 +373,26 @@ instance ToRow a => ToField (Values a) where
369373
(litC ',')
370374
rest
371375
vals
376+
377+
-- Type class for default implementation of ToField using generics.
378+
class GToField f where
379+
gtoField :: f p -> [Action]
380+
381+
instance GToField f => GToField (D1 i f) where
382+
gtoField = gtoField . unM1
383+
384+
instance GToField f => GToField (C1 i f) where
385+
gtoField = (:[]) . Many . tupleWrap . gtoField . unM1
386+
387+
instance (GToField f, GToField g) => GToField (f :*: g) where
388+
gtoField (f :*: g) = gtoField f ++ gtoField g
389+
390+
instance (GToField f) => GToField (S1 i f) where
391+
gtoField = gtoField . unM1
392+
393+
instance (ToField f) => GToField (Rec0 f) where
394+
gtoField = (:[]) . toField . unK1
395+
396+
tupleWrap :: [Action] -> [Action]
397+
tupleWrap xs = (Plain "("): (intersperse (Plain ",") xs) ++ [Plain ")"]
398+

test/Main.hs

Lines changed: 70 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33
{-# LANGUAGE DeriveDataTypeable #-}
44
{-# LANGUAGE DoAndIfThenElse #-}
55
{-# LANGUAGE ScopedTypeVariables #-}
6+
{-# LANGUAGE QuasiQuotes #-}
7+
68
import Common
79
import Database.PostgreSQL.Simple.FromField (FromField)
8-
import Database.PostgreSQL.Simple.Types(Query(..),Values(..))
10+
import Database.PostgreSQL.Simple.ToField (ToField)
11+
import Database.PostgreSQL.Simple.Types (Query(..), Values(..))
912
import Database.PostgreSQL.Simple.HStore
1013
import Database.PostgreSQL.Simple.Copy
14+
import Database.PostgreSQL.Simple.SqlQQ (sql)
1115
import qualified Database.PostgreSQL.Simple.Transaction as ST
1216

1317
import Control.Applicative
@@ -42,25 +46,28 @@ tests :: TestEnv -> TestTree
4246
tests env = testGroup "tests"
4347
$ map ($ env)
4448
[ testBytea
45-
, testCase "ExecuteMany" . testExecuteMany
46-
, testCase "Fold" . testFold
47-
, testCase "Notify" . testNotify
48-
, testCase "Serializable" . testSerializable
49-
, testCase "Time" . testTime
50-
, testCase "Array" . testArray
51-
, testCase "Array of nullables" . testNullableArray
52-
, testCase "HStore" . testHStore
53-
, testCase "JSON" . testJSON
54-
, testCase "Savepoint" . testSavepoint
55-
, testCase "Unicode" . testUnicode
56-
, testCase "Values" . testValues
57-
, testCase "Copy" . testCopy
49+
, testCase "ExecuteMany" . testExecuteMany
50+
, testCase "Fold" . testFold
51+
, testCase "Notify" . testNotify
52+
, testCase "Serializable" . testSerializable
53+
, testCase "Time" . testTime
54+
, testCase "Array" . testArray
55+
, testCase "Array of nullables" . testNullableArray
56+
, testCase "HStore" . testHStore
57+
, testCase "JSON" . testJSON
58+
, testCase "Savepoint" . testSavepoint
59+
, testCase "Unicode" . testUnicode
60+
, testCase "Values" . testValues
61+
, testCase "Copy" . testCopy
5862
, testCopyFailures
59-
, testCase "Double" . testDouble
60-
, testCase "1-ary generic" . testGeneric1
61-
, testCase "2-ary generic" . testGeneric2
62-
, testCase "3-ary generic" . testGeneric3
63-
, testCase "Timeout" . testTimeout
63+
, testCase "Double" . testDouble
64+
, testCase "1-ary generic row" . testGeneric1Row
65+
, testCase "2-ary generic row" . testGeneric2Row
66+
, testCase "3-ary generic row" . testGeneric3Row
67+
, testCase "1-ary generic field" . testGeneric1Field
68+
, testCase "2-ary generic field" . testGeneric2Field
69+
, testCase "3-ary generic field" . testGeneric3Field
70+
, testCase "Timeout" . testTimeout
6471
]
6572

6673
testBytea :: TestEnv -> TestTree
@@ -406,44 +413,73 @@ testDouble TestEnv{..} = do
406413
x @?= (-1 / 0)
407414

408415

409-
testGeneric1 :: TestEnv -> Assertion
410-
testGeneric1 TestEnv{..} = do
416+
testGeneric1Row :: TestEnv -> Assertion
417+
testGeneric1Row TestEnv{..} = do
411418
roundTrip conn (Gen1 123)
412419
where
413420
roundTrip conn x0 = do
414421
r <- query conn "SELECT ?::int" (x0 :: Gen1)
415422
r @?= [x0]
416423

417-
testGeneric2 :: TestEnv -> Assertion
418-
testGeneric2 TestEnv{..} = do
424+
testGeneric2Row :: TestEnv -> Assertion
425+
testGeneric2Row TestEnv{..} = do
419426
roundTrip conn (Gen2 123 "asdf")
420427
where
421428
roundTrip conn x0 = do
422429
r <- query conn "SELECT ?::int, ?::text" x0
423430
r @?= [x0]
424431

425-
testGeneric3 :: TestEnv -> Assertion
426-
testGeneric3 TestEnv{..} = do
432+
testGeneric3Row :: TestEnv -> Assertion
433+
testGeneric3Row TestEnv{..} = do
427434
roundTrip conn (Gen3 123 "asdf" True)
428435
where
429436
roundTrip conn x0 = do
430437
r <- query conn "SELECT ?::int, ?::text, ?::bool" x0
431438
r @?= [x0]
432439

440+
testGeneric1Field :: TestEnv -> Assertion
441+
testGeneric1Field TestEnv{..} = withTransaction conn $ do
442+
-- It's not possible to simply roundtrip a 1-ary tuple
443+
-- as PostgreSQL will treat it as a scalar value.
444+
-- Therefore we will create a separate type for it.
445+
execute_ conn "CREATE TYPE gen1 AS (x bigint)"
446+
execute_ conn [sql|
447+
CREATE FUNCTION test_gen1() RETURNS SETOF gen1 AS $$
448+
(SELECT 1::bigint) UNION ALL (SELECT 2) UNION ALL (SELECT 3)
449+
$$ LANGUAGE sql
450+
|]
451+
query_ conn "SELECT test_gen1()" >>= (@?= [Only (Gen1 1), Only (Gen1 2), Only (Gen1 3)])
452+
rollback conn
453+
454+
testGeneric2Field :: TestEnv -> Assertion
455+
testGeneric2Field TestEnv{..} = roundTripField conn (Gen2 123 "asdf")
456+
457+
testGeneric3Field :: TestEnv -> Assertion
458+
testGeneric3Field TestEnv{..} = roundTripField conn (Gen3 123 "asdf" True)
459+
460+
roundTripField :: (Show a, Eq a, FromField a, ToField a) => Connection -> a -> Assertion
461+
roundTripField conn x0 = query conn "SELECT ?" (Only x0) >>= (@?= [Only x0])
462+
433463
data Gen1 = Gen1 Int
434-
deriving (Show,Eq,Generic)
435-
instance FromRow Gen1
436-
instance ToRow Gen1
464+
deriving (Show, Eq, Generic, Typeable)
465+
instance FromRow Gen1
466+
instance ToRow Gen1
467+
instance FromField Gen1
468+
instance ToField Gen1
437469

438470
data Gen2 = Gen2 Int Text
439-
deriving (Show,Eq,Generic)
440-
instance FromRow Gen2
441-
instance ToRow Gen2
471+
deriving (Show, Eq, Generic, Typeable)
472+
instance FromRow Gen2
473+
instance ToRow Gen2
474+
instance FromField Gen2
475+
instance ToField Gen2
442476

443477
data Gen3 = Gen3 Int Text Bool
444-
deriving (Show,Eq,Generic)
445-
instance FromRow Gen3
446-
instance ToRow Gen3
478+
deriving (Show, Eq, Generic, Typeable)
479+
instance FromRow Gen3
480+
instance ToRow Gen3
481+
instance FromField Gen3
482+
instance ToField Gen3
447483

448484
data TestException
449485
= TestException

0 commit comments

Comments
 (0)