I've been trying to get better at Haskell for a while and have recently been working on a lot of small projects in it. This one constructs a binary decision tree.
The command to run it is:
stack exec decision-tree-exe <threshold> <training file> <testing file>
where threshold is in the range (0,1].
I think I've gotten a lot better, but I'm still having problems, especially with performance and readability. For this project, I took a more top-down approach, implementing each function after first using it. src/DecisionTree.hs is where the bulk of the logic is, and its definitions appear pretty much in the order I wrote them. I would love to get some feedback from more experienced people on where I might improve.
module DecisionTree where
import Data.List (genericLength, maximumBy, nub)
import Data.Map (elemAt, foldlWithKey', fromListWith)
import Data.Ord
-- | A binary decision tree over rows of attributes of type @a@ that yields
-- a result of type @b@.
data DecisionTree a b
  -- | An internal node: a test on the attribute list, the subtree followed
  -- when the test is 'False', and the subtree followed when it is 'True'.
  = Node ([a] -> Bool) (DecisionTree a b) (DecisionTree a b)
  -- | A terminal node carrying the result.
  | Leaf b
-- | Labelled rows: each row pairs a category with its list of attributes.
type Dataset cat attrs = [(cat, [attrs])]

-- | Cut-off that a candidate split's score is compared against.
type Threshold = Double

-- | A candidate split: the test on a row's attributes together with the two
-- datasets the test divides the rows into.
type Splitter c a = ([a] -> Bool, Dataset c a, Dataset c a)
-- | Classify an attribute list by walking the tree from the root.
--
-- At each 'Node' the node's test is run on the attributes: a 'False' result
-- descends into the first subtree, a 'True' result into the second. The
-- value stored in the 'Leaf' that is reached is returned.
apply :: DecisionTree a b -> [a] -> b
apply (Leaf result) _ = result
apply (Node test whenFalse whenTrue) attributes
  | test attributes = apply whenTrue attributes
  | otherwise = apply whenFalse attributes
-- | Grow a decision tree by recursively splitting the dataset.
--
-- The supplied splitter is asked for a split of the current rows. When it
-- returns one whose two sides are both non-empty, a 'Node' is built and
-- each side is trained recursively. Otherwise a 'Leaf' is emitted that
-- carries the most frequent category among the current rows; ties go to
-- the smallest category in 'Ord' order.
--
-- A split with an empty side is treated like "no split": recursing on it
-- could never make progress (the empty side has no category to label a
-- leaf with, and the other side may be the unchanged dataset).
--
-- Calling this with an empty dataset is an error, because there is no
-- category available to put in the leaf.
train ::
     (Ord c)
  => (Dataset c a -> Maybe (Splitter c a))
  -> Dataset c a
  -> DecisionTree a c
train splitter dataset =
  case splitter dataset of
    Just (partitioner, left, right)
      | not (null left || null right) ->
          Node partitioner (train splitter left) (train splitter right)
    _ -> Leaf majority
  where
    -- Number of rows per category.
    classCounts = fromListWith (+) [(category, 1 :: Int) | (category, _) <- dataset]
    majority =
      case foldlWithKey' pickLarger Nothing classCounts of
        Just (category, _) -> category
        Nothing ->
          error "DecisionTree.train: cannot label a leaf for an empty dataset"
    -- Keep the current best unless the next category is strictly more
    -- frequent, so the first (smallest) key wins ties.
    pickLarger Nothing category count = Just (category, count)
    pickLarger best@(Just (_, bestCount)) category count
      | count > bestCount = Just (category, count)
      | otherwise = best
-- | Choose the membership test with the largest Gini gain, if it is large
-- enough.
--
-- Every distinct attribute value that occurs in any row yields one
-- candidate test, "the row's attribute list contains this value". Each
-- candidate is scored with 'giniDelta'. The best-scoring candidate is
-- returned when its score is strictly greater than the threshold;
-- otherwise the result is 'Nothing'.
--
-- When the dataset contains no attribute values at all (no rows, or only
-- rows with empty attribute lists) there are no candidates and the result
-- is 'Nothing'.
giniSplitter ::
     (Ord a, Ord c) => Threshold -> Dataset c a -> Maybe (Splitter c a)
giniSplitter threshold dataset =
  case map candidate attrs of
    -- No candidates to compare; 'maximumBy' would fail on an empty list.
    [] -> Nothing
    candidates
      | bestGain > threshold -> Just bestSplit
      | otherwise -> Nothing
      where
        (bestGain, bestSplit) = maximumBy (comparing fst) candidates
  where
    -- Distinct attribute values, in order of first appearance.
    attrs = nub (concatMap snd dataset)
    candidate attr = giniDelta (attr `elem`) dataset
-- | Score a test by how much it lowers the Gini impurity of a dataset.
--
-- The rows are divided into those whose attributes fail the test (left)
-- and those that pass it (right). The score is the impurity of the whole
-- dataset minus the size-weighted impurities of the two sides. The result
-- pairs that score with the test and the two sides.
--
-- Note: for an empty dataset the weights are @0 / 0@, so the score is NaN.
giniDelta :: (Eq c) => ([a] -> Bool) -> Dataset c a -> (Double, Splitter c a)
giniDelta partitioner dataset = (gain, (partitioner, left, right))
  where
    passes = partitioner . snd
    left = [row | row <- dataset, not (passes row)]
    right = [row | row <- dataset, passes row]
    total = genericLength dataset
    -- Impurity of one side, scaled by that side's share of the rows.
    weighted side = genericLength side / total * gini side
    gain = gini dataset - (weighted left + weighted right)
-- | Gini impurity of a dataset: one minus the sum, over the distinct
-- categories present, of the squared fraction of rows in that category.
--
-- An empty dataset has no categories, so the sum is empty and the result
-- is 1.
gini :: (Eq c) => Dataset c a -> Double
gini d = 1 - sum (map squaredShare classes)
  where
    labels = map fst d
    classes = nub labels
    total = genericLength d
    -- Squared fraction of rows labelled with the given category.
    squaredShare c = (genericLength (filter (== c) labels) / total) ** 2