Skip to content

Commit fa102a8

Browse files
committed
Support deriving with mutual datatypes
1 parent 0a5b086 commit fa102a8

File tree

3 files changed

+28
-20
lines changed

3 files changed

+28
-20
lines changed

examples/deriv/Deriv.hs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ extendA "A" [] [t|T|] $ defaultExtA {
77
typeA = Ann $ \a -> [t| [$a] |]
88
}
99

10+
extendB "B" [] [t|T|] $ defaultExtB
11+
1012
main = print $
11-
A "" 5 ["a", "b"] ==
12-
(A "" 5 ["a", "b", "c"] :: A String)
13+
(A (B "") 5 ["a", "b"]) ==
14+
(A (B "") 5 ["a", "b", "c"] :: A String)
1315
-- annotation needed until pattern synonyms have
1416
-- type signatures :/

examples/deriv/DerivBase.hs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@ module DerivBase where
22

33
import Extensible
44

5-
extensible [d| data A a = A a Int deriving Eq |]
5+
extensible [d|
6+
data A a = A (B a) Int deriving Eq
7+
data B a = B a deriving Eq
8+
|]

src/Extensible.hs

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,9 @@ makeExtensible1 conf home nameMap (SimpleData name tvs cs derivs) = do
404404
let cx = extensionCon conf name ext tvs
405405
efs <- traverse (extendFam conf tvs) cs
406406
efx <- extensionFam conf name tvs
407-
(bname, bnd) <- constraintBundle conf name ext tvs cs
408-
insts <- fmap concat $ traverse (makeInstances name' bname ext tvs) derivs
407+
bnd <- constraintBundle conf name ext tvs cs
408+
insts <- fmap concat $
409+
traverse (makeInstances conf name' (map fst nameMap) ext tvs) derivs
409410
(rname, fcnames, fname, rec) <- extRecord conf name tvs cs
410411
(_dname, defRec) <- extRecDefault conf rname fcnames fname
411412
(_ename, extFun) <- makeExtender conf home name rname tvs cs
@@ -461,39 +462,41 @@ extensionFam conf name tvs =
461462
constraintBundle :: Config
462463
-> Name -- ^ datatype name
463464
-> Name -- ^ extension type variable name
464-
-> [TyVarBndr] -> [SimpleCon] -> Q (Name, Dec)
465+
-> [TyVarBndr] -> [SimpleCon] -> DecQ
465466
constraintBundle conf name ext tvs cs = do
466467
c <- newName "c"
467468
ckind <- [t|K.Type -> Constraint|]
468469
let cnames = map scName cs
469-
aname = applyAffix (bundleName conf) name
470+
bname = applyAffix (bundleName conf) name
470471
tvs' = kindedTV c ckind : plainTV ext : tvs
471472
con1 n = varT c `appT`
472473
foldl appT (conT n) (varT ext : map (varT . tyvarName) tvs)
473474
tupled ts = foldl appT (tupleT (length ts)) ts
474-
d <- tySynD aname tvs' $ tupled $ map con1 $
475+
tySynD bname tvs' $ tupled $ map con1 $
475476
map (applyAffix $ annotationName conf) cnames ++
476477
[applyAffix (extensionName conf) name]
477-
pure (aname, d)
478478

479-
makeInstances :: Name -- ^ name of the __output__ datatype
480-
-> Name -- ^ name of the constraint bundle
481-
-> Name -- ^ extension type variable name
479+
makeInstances :: Config
480+
-> Name -- ^ name of the __output__ datatype
481+
-> [Name] -- ^ names of all datatypes in this group
482+
-> Name -- ^ extension type variable name
482483
-> [TyVarBndr]
483484
-> SimpleDeriv
484485
-> DecsQ
485-
makeInstances name bname ext tvs (SimpleDeriv strat prds) =
486+
makeInstances conf name names ext tvs (SimpleDeriv strat prds) =
486487
pure $ map make1 prds
487488
where
488-
make1 :: Pred -> Dec
489489
make1 prd = StandaloneDerivD strat'
490-
(map (AppT prd . VarT . tyvarName) tvs
491-
++ [appExtTvs (ConT bname `AppT` prd) ext tvs])
490+
(map tvPred tvs ++ map allPred names)
492491
(prd `AppT` appExtTvs (ConT name) ext tvs)
493-
strat' = case strat of
494-
SBlank -> Nothing
495-
SStock -> Just StockStrategy
496-
SAnyclass -> Just AnyclassStrategy
492+
where
493+
tvPred = AppT prd . VarT . tyvarName
494+
allPred name' = appExtTvs (ConT bname `AppT` prd) ext tvs
495+
where bname = applyAffix (bundleName conf) name'
496+
strat' = case strat of
497+
SBlank -> Nothing
498+
SStock -> Just StockStrategy
499+
SAnyclass -> Just AnyclassStrategy
497500

498501
extendFam' :: Name -> [TyVarBndr] -> DecQ
499502
extendFam' name tvs = do

0 commit comments

Comments
 (0)