{-# LANGUAGE CPP #-}

{-|
Module      : System.Linux.Netlink
Description : The base module for the netlink package
Maintainer  : ongy
Stability   : testing
Portability : Linux

This is the base module for the netlink package.
It contains functions and datatypes used by every netlink module.
All definitions are (supposed to be) generic enough to be used
by implementations of more specific netlink interfaces.
-}
module System.Linux.Netlink
(   Header(..)
  , Attributes
  , Packet(..)
  , Convertable(..)
  , NoData(..)
  , NetlinkSocket

  , getPacket
  , getAttributes
  , getHeader
  , putHeader
  , putAttributes
  , putPacket
  , getPackets

  , makeSocket
  , makeSocketGeneric
  , getNetlinkFd
  , closeSocket
  , joinMulticastGroup
  , leaveMulticastGroup

  , query
  , queryOne
  , recvOne
  , showNLAttrs
  , showAttrs
  , showAttr
  , showPacket
)
where

#if MIN_VERSION_base(4,8,0)
#else
import Control.Applicative ((<$>))
#endif


import Data.List (intersperse)
import Hexdump (prettyHex)
import Control.Monad (when, replicateM_, unless)
import Control.Monad.Loops (whileM)
import Data.Bits (Bits, (.&.))
import qualified Data.ByteString as BS (length)
import Data.ByteString (ByteString)
import Data.Map (Map, fromList, toList)
import Data.Serialize.Get
import Data.Serialize.Put
import Data.Word (Word16, Word32)
import Foreign.C.Types (CInt)

import System.Posix.Types (Fd(Fd))
import qualified System.Linux.Netlink.C as C
import System.Linux.Netlink.Helpers
import System.Linux.Netlink.Constants

--Generic protocol stuff

{- |Typeclass used by the system. Basically 'Storable' for 'Get' and 'Put'.

'getGet' returns a 'Get' function for the convertable, and 'getPut' returns
the matching 'Put' function.

The 'MessageType' is passed to 'getGet' so that the function can parse
different data structures based on the message type.
-}
class Convertable a where
  -- |Decoder for the static data, selected by the message type
  getGet :: MessageType -> Get a
  -- |Encoder for the static data
  getPut :: a -> Put


-- |Placeholder for protocols that have no additional static header between
-- the netlink header and the attributes.
data NoData
  = NoData
  deriving (Eq, Show)

-- |'NoData' reads and writes zero bytes, whatever the message type.
instance Convertable NoData where
  getGet _ = pure NoData
  getPut _ = pure ()

-- |Data type for the netlink header.
--
-- The on-wire length field is not stored here; it is returned separately by
-- 'getHeader' and supplied explicitly to 'putHeader'.
data Header = Header
    { messageType   :: MessageType -- ^The message type
    , messageFlags  :: Word16      -- ^The message flags
    , messageSeqNum :: Word32      -- ^The sequence message number
    , messagePID    :: Word32
      -- ^The pid of the sending process (0 is from kernel for receiving or
      -- "let the kernel set it" for sending)
    }
    deriving (Eq)

-- |Human-readable, single-line rendering of every header field.
instance Show Header where
  show (Header ty flags sq pid) = concat
    [ "Type: ", show ty
    , ", Flags: ", show flags
    , ", Seq: ", show sq
    , ", Pid: ", show pid
    ]

-- |Type used for netlink attributes: a map from the numeric attribute type
-- to the raw attribute payload (without the 4-byte attribute header and
-- without alignment padding).
type Attributes = Map Int ByteString

-- |The generic netlink message type
data Packet a
    -- The "normal" message
    = Packet
        { packetHeader     :: Header     -- ^The netlink message header
        , packetCustom     :: a          -- ^The datatype for additional static data for the interface
        , packetAttributes :: Attributes -- ^The netlink attributes
        }
    -- The error message
    | ErrorMsg
        { packetHeader :: Header   -- ^The netlink message header
        , packetError  :: CInt     -- ^The error ID for this error message
        , errorPacket  :: Packet a -- ^The offending message
        }
    -- The done message, which ends a multipart reply; usually not seen by a user
    | DoneMsg
        { packetHeader :: Header -- ^The header of the done message
        }
    deriving (Eq)

-- |Helper function for the 'Show' instance of 'Packet' and for further
-- specializations. Renders the header, then the constructor-specific
-- payload on following lines.
showPacket :: Show a => Packet a -> String
showPacket packet = case packet of
  Packet hdr custom attrs -> concat
    [ "NetlinkPacket: ", show hdr, "\n"
    , "Custom data: ", show custom, "\n"
    , "Attrs: \n", showNLAttrs attrs
    ]
  ErrorMsg hdr code offending -> concat
    [ "Error packet: \n"
    , show hdr, "\n"
    , "Error code: ", show code, "\n"
    , show offending
    ]
  DoneMsg hdr -> "Done: " ++ show hdr

-- |Generic rendering via 'showPacket'; lists are separated by "===" lines.
-- Marked overlappable so protocol modules can provide more specific instances.
instance {-# OVERLAPPABLE #-} Show a => Show (Packet a) where
  show = showPacket
  showList packets rest =
    concat (intersperse "===\n" (map show packets)) ++ rest

-- |Convert generic NLAttrs into a string: each attribute's numeric id
-- followed by a hexdump of its payload.
showNLAttrs :: Attributes -> String
showNLAttrs attrs = showAttrs show attrs

-- |Helper function to convert attributes into a string, in ascending order
-- of attribute id.
showAttrs
  :: (Int -> String) -- ^A function from element id to its name
  -> Attributes -- ^The attributes
  -> String -- ^A string with Element name and hexdump of element
showAttrs sh = concatMap (showAttr sh) . toList

-- |Helper function to generically show a single attribute as
-- "<name>: <hexdump>".
showAttr :: (Int -> String) -> (Int, ByteString) -> String
showAttr nameOf (attrId, payload) = concat [nameOf attrId, ": ", prettyHex payload]

-- |Read packets from the buffer by applying the given parser repeatedly
-- until the buffer is exhausted.
getPacket
  :: ByteString  -- ^The buffer to read from
  -> Get a -- ^The function to read a single message
  -> Either String [a] -- ^Either an error message or a list of messages read
getPacket bytes f = runGet parser bytes
  where
    parser = do
      pkts <- whileM (fmap not isEmpty) f
      -- Defensive check: the loop above only stops once the input is empty.
      atEnd <- isEmpty
      unless atEnd $ fail "Incomplete message parse"
      return pkts

-- |'Get' all remaining attributes from the input. If an attribute id appears
-- more than once, the last occurrence wins.
getAttributes :: Get Attributes
getAttributes = fmap fromList (whileM (fmap not isEmpty) getSingleAttribute)

-- |'Get' a single attribute as a (type, payload) pair.
--
-- The 16-bit length field counts the 4-byte attribute header itself, so a
-- length below 4 is malformed and is rejected with a parse failure. The
-- payload is followed by zero padding up to the next 4-byte boundary; the
-- padding is skipped unless the attribute ends the input.
getSingleAttribute :: Get (Int, ByteString)
getSingleAttribute = do
    len <- fromIntegral <$> g16
    ty  <- fromIntegral <$> g16
    -- Without this check a short length would request a negative number of
    -- payload bytes from getByteString.
    when (len < 4) $
      fail ("Invalid netlink attribute length: " ++ show len)
    val <- getByteString (len - 4)
    let padding = (4 - len `mod` 4) `mod` 4
    atEnd <- isEmpty
    unless (atEnd || padding == 0) $ skip padding
    return (ty, val)

-- |'Get' the netlink 'Header' from exactly 16 bytes.
--
-- Also returns the number of bytes that follow the header in this message,
-- which is the length field minus the 16 header bytes.
getHeader :: Get (Int, Header)
getHeader = isolate 16 $ do
    len <- fromIntegral <$> g32
    hdr <- Header <$> (fromIntegral <$> g16) -- type
                  <*> (fromIntegral <$> g16) -- flags
                  <*> g32                    -- sequence number
                  <*> g32                    -- pid
    return (len - 16, hdr)

-- |'Put' the netlink 'Header', preceded by the given total message length.
putHeader
  :: Int -- ^The length of the message
  -> Header -- ^The header itself
  -> Put
putHeader len hdr = do
    p32 (fromIntegral len)
    p16 (fromIntegral (messageType hdr))
    p16 (fromIntegral (messageFlags hdr))
    p32 (messageSeqNum hdr)
    p32 (messagePID hdr)


-- |'Put' a 'Map' of 'Attributes', in ascending order of attribute id.
--
-- Each attribute is written as a 16-bit length (payload length + 4), a
-- 16-bit type, the payload, and zero bytes padding it to a 4-byte boundary.
-- Calls 'error' when a payload is too large for the 16-bit length field;
-- previously the length silently wrapped and produced a corrupt message.
putAttributes :: Attributes -> Put
putAttributes = mapM_ putAttr . toList
  where
    putAttr :: (Int, ByteString) -> Put
    putAttr (ty, value) = do
        let payloadLen = BS.length value
            len        = payloadLen + 4
        when (len > 0xFFFF) $
          error ("Netlink attribute " ++ show ty ++ " too large: "
                 ++ show len ++ " bytes")
        p16 (fromIntegral len)
        p16 (fromIntegral ty)
        putByteString value
        replicateM_ ((4 - payloadLen `mod` 4) `mod` 4) (p8 0)

-- |'Put' a 'Packet' so it can be sent.
--
-- Returns three buffers in order: header, static data, and attributes. The
-- header's length field covers all three. Only the 'Packet' constructor can
-- be serialized; any other constructor calls 'error'.
putPacket :: (Convertable a, Eq a, Show a) => Packet a -> [ByteString]
putPacket (Packet header custom attributes) = [hdr, cus, attrs]
  where
    attrs = runPut (putAttributes attributes)
    cus   = runPut (getPut custom)
    total = 16 + BS.length cus + BS.length attrs
    hdr   = runPut (putHeader total header)
putPacket _ = error "Cannot convert this for transmission"


-- |'Get' the body of an error message: a host-order 32-bit error code
-- followed by the offending packet.
getError :: (Convertable a, Eq a, Show a) => Header -> Get (Packet a)
getError hdr = ErrorMsg hdr <$> (fromIntegral <$> getWord32host) <*> getGenPacket


-- |'Get' the body of a packet (the 'Header' has already been read from the
-- buffer). The message type selects the done, error, or regular decoder.
getGenPacketContent :: (Convertable a, Eq a, Show a) => Header -> Get (Packet a)
getGenPacketContent hdr
  | ty == eNLMSG_DONE  = DoneMsg hdr <$ skip 4
  | ty == eNLMSG_ERROR = getError hdr
  | otherwise          = Packet hdr <$> getGet ty <*> getAttributes
  where
    ty = messageType hdr

{- | 'Get' a packet

This returns a 'Get' function for a netlink message.
The message may have additional static data defined by the protocol.
The body is parsed within the length announced by the header.
-}
getGenPacket :: (Convertable a, Eq a, Show a) => Get (Packet a)
getGenPacket =
    getHeader >>= \(len, header) -> isolate len (getGenPacketContent header)


{- | Read all 'Packet's from a buffer

The packets may have additional static data defined by the protocol.
This is 'getPacket' specialised to 'getGenPacket'.
-}
getPackets :: (Convertable a, Eq a, Show a) => ByteString -> Either String [Packet a]
getPackets bytes = getPacket bytes getGenPacket

-- |Type-safe wrapper around the raw 'CInt' file descriptor of a netlink
-- socket.
newtype NetlinkSocket
  = NS CInt

-- |Open and return a 'NetlinkSocket'. For legacy reasons this opens a
-- route socket.
makeSocket :: IO NetlinkSocket
makeSocket = fmap NS C.makeSocket

-- |Open a 'NetlinkSocket' for an arbitrary netlink family. This is the
-- generic function.
makeSocketGeneric
  :: Int -- ^The netlink family to use
  -> IO NetlinkSocket
makeSocketGeneric family = NS <$> C.makeSocketGeneric family

-- |Get the raw 'Fd' used for netlink communication (this can be plugged
-- into eventing).
getNetlinkFd :: NetlinkSocket -> Fd
getNetlinkFd (NS rawFd) = Fd rawFd

{- |Send a message over netlink.

This is an internal function.
Its signature directly reflects the interface of the C functions; the
socket is only unwrapped to its raw descriptor.
-}
sendmsg :: NetlinkSocket -> [ByteString] -> IO ()
sendmsg (NS rawFd) buffers = C.sendmsg rawFd buffers

{- |Receive a message over netlink.

This is an internal function.
Its signature directly reflects the interface of the C functions; the
socket is only unwrapped to its raw descriptor.
-}
recvmsg :: NetlinkSocket -> Int -> IO ByteString
recvmsg (NS rawFd) size = C.recvmsg rawFd size

-- |Close a 'NetlinkSocket' once it is no longer used.
closeSocket :: NetlinkSocket -> IO ()
closeSocket (NS rawFd) = C.closeSocket rawFd

-- |Join a netlink multicast group.
joinMulticastGroup
  :: NetlinkSocket -- ^The socket to join with
  -> Word32  -- ^The id of the group to join, values of System.Linux.Netlink.Constants.eRTNLGRP_*
  -> IO ()
joinMulticastGroup (NS rawFd) groupId = C.joinMulticastGroup rawFd groupId

-- |Leave a netlink multicast group.
leaveMulticastGroup
  :: NetlinkSocket -- ^The socket to leave
  -> Word32  -- ^The id of the group to leave, values of System.Linux.Netlink.Constants.eRTNLGRP_*
  -> IO ()
leaveMulticastGroup (NS rawFd) groupId = C.leaveMulticastGroup rawFd groupId



-- generic query functions
{- |Query data over netlink.

This sends a 'Packet' over netlink and returns the answer, collecting every
part of a multipart reply.
This blocks in a safe foreign function until the other side replies.
-}
query :: (Convertable a, Eq a, Show a) => NetlinkSocket -> Packet a -> IO [Packet a]
query sock req = sendmsg sock (putPacket req) >> recvMulti sock

-- |The same as 'query', but requires the answer to be a single message;
-- any other number of packets is reported through 'fail'.
queryOne :: (Convertable a, Eq a, Show a) => NetlinkSocket -> Packet a -> IO (Packet a)
queryOne sock req = do
    replies <- query sock req
    case replies of
      [reply] -> return reply
      _       -> fail ("Expected one packet, received " ++ show (length replies))

-- |Internal function to receive multiple netlink messages.
--
-- If the first packet of a batch has 'fNLM_F_MULTI' set, this keeps calling
-- 'recvOne' until a batch ends with an 'eNLMSG_DONE' message. That message
-- is dropped from the result.
recvMulti :: (Convertable a, Eq a, Show a) => NetlinkSocket -> IO [Packet a]
recvMulti sock = do
    pkts <- recvOne sock
    case pkts of
      [] -> error "Got empty list from recvOne in recvMulti, this shouldn't happen"
      (firstPkt : _)
        | not (isMulti firstPkt) -> return pkts
        -- pkts is non-empty here, so last/init are safe
        | isDone (last pkts)     -> return (init pkts)
        | otherwise              -> (pkts ++) <$> recvMulti sock
  where
    isMulti :: Packet b -> Bool
    isMulti = isFlagSet fNLM_F_MULTI . messageFlags . packetHeader
    isDone :: Packet b -> Bool
    isDone = (== eNLMSG_DONE) . messageType . packetHeader

{- | Calls recvmsg once and returns all received messages

This should only be used outside of the package when reading multicast messages.

The prototype of this function is unintuitive, but this cannot be avoided without
buffering in userspace with the netlink api.
A parse failure of the received buffer is reported through 'fail'.
-}
recvOne :: (Convertable a, Eq a, Show a) => NetlinkSocket -> IO [Packet a]
recvOne sock = do
    buf <- recvmsg sock bufferSize
    either fail return (getPackets buf)


-- |True when every bit set in the first argument (the flag mask) is also
-- set in the second argument (the value).
isFlagSet :: Bits a => a -> a -> Bool
isFlagSet flag value = flag == flag .&. value

-- |Size in bytes of the buffer passed to 'recvmsg' by 'recvOne'.
--
-- Netlink is datagram based, so a message longer than the buffer is
-- truncated; the truncated data then presumably fails to parse in
-- 'getPackets'. The kernel's netlink documentation asks for at least 8 KiB
-- (or the page size, if larger) and recommends 32 KiB for dumps. The
-- previous value of 4096 was below that minimum.
bufferSize :: Num a => a
bufferSize = 32768