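/*
 * secgateway/Platform/user/kernel_hook/lkh_hook.c
 *
 * Per-network-namespace packet hook registration: struct lkh_hook_ops are
 * registered into / unregistered from RCU-protected, priority-sorted hook
 * arrays, modeled on the netfilter hook core.
 */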

#include <linux/kernel.h>
#include <linux/if.h>
#include <linux/netdevice.h>
#include <linux/netfilter_ipv6.h>
#include <linux/inetdevice.h>
#include <linux/mutex.h>
#include <linux/mm.h>
#include <linux/rcupdate.h>
#include <net/net_namespace.h>
#include <net/sock.h>
#include <linux/list.h>
#include <linux/skbuff.h>
#include <linux/netdevice.h>
#include "lkh_hook.h"
/*
 * Global hook registry. Its ->list holds struct lkh_hook entries, each tied
 * to a network namespace through ->net_ptr (which may be NULL); the list is
 * searched by lkh_hook_entry_head() and dumped by lkh_hash_struct_show().
 */
struct lkh_hook_handle g_lkh_hook_handle;
/*
 * Placeholder callback written into a slot whose hook has been unregistered
 * (by lkh_remove_net_hook()). It ignores all arguments and always returns
 * LKH_ACCEPT; per the original note, ACCEPT lets the hook dispatcher move on
 * to the next hook in the array.
 */
static unsigned int lkh_accept_all(void *priv, struct sk_buff *skb,
				   const struct lkh_hook_state *state)
{
	(void)priv;
	(void)skb;
	(void)state;

	return LKH_ACCEPT;
}
/*
 * Sentinel ops marking an unregistered slot. Slots pointing at &dummy_ops are
 * skipped when hook arrays are grown, shrunk or validated. Its callback is
 * lkh_accept_all and its priority is INT_MIN, the lowest possible value.
 */
static const struct lkh_hook_ops dummy_ops = {
	.priority = INT_MIN,
	.hook     = lkh_accept_all,
};
/*
 * Return the ops-pointer array that shares @e's allocation. It starts right
 * after the last element of e->hooks[] (at &e->hooks[e->num_hook_entries])
 * and has e->num_hook_entries slots.
 */
static inline struct lkh_hook_ops **lkh_hook_entries_get_hook_ops(const struct lkh_hook_entries *e)
{
	/* Intentionally one past the end of ->hooks[]: the ops array lives there. */
	return (struct lkh_hook_ops **)(e->hooks + e->num_hook_entries);
}
/* Registration flow */
/*
 * RCU callback. It recovers the rcu_head wrapper that lkh_hook_entries_free()
 * placed at the tail of a hook-entries allocation, then releases the whole
 * allocation with kvfree().
 */
static void __lkh_hook_entries_free(struct rcu_head *h)
{
	kvfree(container_of(h, struct lkh_hook_entries_rcu_head, head)->allocation);
}
/*
 * Queue @e to be freed after an RCU grace period, so readers still using the
 * old array can finish with it. A NULL @e is ignored.
 *
 * The rcu_head lives in the trailing slot reserved by
 * lkh_allocate_hook_entries_size(), directly after the ops-pointer array.
 */
static void lkh_hook_entries_free(struct lkh_hook_entries *e)
{
	struct lkh_hook_entries_rcu_head *tail;

	if (!e)
		return;

	tail = (void *)(lkh_hook_entries_get_hook_ops(e) + e->num_hook_entries);
	tail->allocation = e;
	call_rcu(&tail->head, __lkh_hook_entries_free);
}
/*
 * Debug-only consistency check, compiled in when CONFIG_DEBUG_KERNEL is set.
 *
 * Live entries of @hooks must appear in ascending priority order, which is
 * the order lkh_hook_entries_grow() produces. Slots holding &dummy_ops
 * (unregistered hooks) are skipped. A WARN_ON() fires for each live entry
 * whose priority is lower than that of an earlier live entry.
 */
static void lkh_hooks_validate(const struct lkh_hook_entries *hooks)
{
#ifdef CONFIG_DEBUG_KERNEL
	struct lkh_hook_ops **orig_ops;
	int prio = INT_MIN;
	size_t i = 0;

	orig_ops = lkh_hook_entries_get_hook_ops(hooks);
	for (i = 0; i < hooks->num_hook_entries; i++)
	{
		if (orig_ops[i] == &dummy_ops)
		{
			continue;
		}
		/* Out-of-order priorities would make hooks run in the wrong order. */
		WARN_ON(orig_ops[i]->priority < prio);
		if (orig_ops[i]->priority > prio)
		{
			prio = orig_ops[i]->priority;
		}
	}
#endif
}
/* Allocate a new hook-entries storage block */
/*
 * Allocate a zeroed hook-entries block sized for @num hooks. The single
 * allocation holds:
 *   - the struct lkh_hook_entries header with its @num ->hooks[] elements,
 *   - @num struct lkh_hook_ops pointers (see lkh_hook_entries_get_hook_ops()),
 *   - one struct lkh_hook_entries_rcu_head used by lkh_hook_entries_free().
 * ->num_hook_entries is set to @num.
 *
 * Returns NULL when @num is 0 or the allocation fails.
 */
static struct lkh_hook_entries * lkh_allocate_hook_entries_size(u16 num)
{
	struct lkh_hook_entries *blk;

	if (!num)
		return NULL;

	blk = kvzalloc(sizeof(*blk) +
		       num * sizeof(struct lkh_hook_entry) +
		       num * sizeof(struct lkh_hook_ops *) +
		       sizeof(struct lkh_hook_entries_rcu_head),
		       GFP_KERNEL);
	if (!blk)
		return NULL;

	blk->num_hook_entries = num;
	return blk;
}
/*
 * Build a new hook-entries array holding every live hook of @old plus @reg,
 * in ascending priority order. @old is left untouched, and its &dummy_ops
 * slots (unregistered hooks) are not carried over.
 *
 * @reg goes in front of the first live entry whose priority is >=
 * reg->priority, so it lands after all existing hooks of lower priority and
 * before those of equal or higher priority.
 *
 * @old may be NULL when nothing is registered yet.
 *
 * The array keeps the @reg pointer itself (the callback and priv are copied
 * into ->hooks[]).
 *
 * Returns the new array, ERR_PTR(-E2BIG) if it would exceed MAX_HOOK_COUNT
 * entries, or ERR_PTR(-ENOMEM) if the allocation fails.
 */
static struct lkh_hook_entries * lkh_hook_entries_grow(const struct lkh_hook_entries *old, const struct lkh_hook_ops *reg)
{
	struct lkh_hook_ops **src_ops = NULL;
	struct lkh_hook_ops **dst_ops;
	struct lkh_hook_entries *grown;
	unsigned int src_count = 0;
	unsigned int live = 0;
	unsigned int src, dst;
	bool placed = false;

	if (old) {
		src_count = old->num_hook_entries;
		src_ops = lkh_hook_entries_get_hook_ops(old);
		for (src = 0; src < src_count; src++) {
			if (src_ops[src] != &dummy_ops)
				live++;
		}
	}

	/* live + 1 accounts for the hook being registered. */
	if (live + 1 > MAX_HOOK_COUNT)
		return ERR_PTR(-E2BIG);

	grown = lkh_allocate_hook_entries_size(live + 1);
	if (!grown)
		return ERR_PTR(-ENOMEM);

	dst_ops = lkh_hook_entries_get_hook_ops(grown);
	dst = 0;
	for (src = 0; src < src_count; src++) {
		if (src_ops[src] == &dummy_ops)
			continue;

		/* Slot @reg in ahead of the first live entry it does not outrank. */
		if (!placed && reg->priority <= src_ops[src]->priority) {
			dst_ops[dst] = (void *)reg;
			grown->hooks[dst].hook = reg->hook;
			grown->hooks[dst].priv = reg->priv;
			dst++;
			placed = true;
		}

		dst_ops[dst] = src_ops[src];
		grown->hooks[dst] = old->hooks[src];
		dst++;
	}

	/* Highest priority so far (or first hook): append at the end. */
	if (!placed) {
		dst_ops[dst] = (void *)reg;
		grown->hooks[dst].hook = reg->hook;
		grown->hooks[dst].priv = reg->priv;
	}

	return grown;
}
/*
 * Locate the RCU-protected hook-array slot for (@net, @pf, @hook_stage).
 *
 * Searches g_lkh_hook_handle.list for the first struct lkh_hook whose
 * ->net_ptr equals @net. Returns NULL in three cases:
 *   - no such entry exists,
 *   - @pf is neither LKH_PROTO_IPV4 nor LKH_PROTO_IPV6,
 *   - @hook_stage is out of range for that protocol's array.
 *
 * NOTE(review): the list is walked without holding any lock here; callers
 * presumably rely on it being stable while hooks are (un)registered. Confirm.
 */
static struct lkh_hook_entries __rcu ** lkh_hook_entry_head(struct net *net, int pf, unsigned int hook_stage)
{
	struct lkh_hook *hook;

	list_for_each_entry(hook, &g_lkh_hook_handle.list, list) {
		if (hook->net_ptr != net)
			continue;

		if (pf == LKH_PROTO_IPV4)
			return hook_stage < ARRAY_SIZE(hook->hooks_ipv4) ?
			       &hook->hooks_ipv4[hook_stage] : NULL;

		if (pf == LKH_PROTO_IPV6)
			return hook_stage < ARRAY_SIZE(hook->hooks_ipv6) ?
			       &hook->hooks_ipv6[hook_stage] : NULL;

		return NULL;
	}

	return NULL;
}
/*
 * Register @reg in the hook array for (@net, @pf, reg->hook_stage).
 *
 * Under lkh_hook_mutex this function:
 *   1. copies the current array into a new one that also contains @reg
 *      (lkh_hook_entries_grow()),
 *   2. validates the copy,
 *   3. publishes it with rcu_assign_pointer().
 * After the mutex is dropped, the previous array is handed to
 * lkh_hook_entries_free(), which frees it through call_rcu().
 *
 * Returns 0 on success, -EINVAL if no hook array exists for the
 * (net, pf, hook_stage) triple, or the error from lkh_hook_entries_grow()
 * (-E2BIG, -ENOMEM).
 */
static int __lkh_register_net_hook(struct net *net, int pf, const struct lkh_hook_ops *reg)
{
	struct lkh_hook_entries *p, *new_hooks;
	struct lkh_hook_entries __rcu **pp;

	pp = lkh_hook_entry_head(net, pf, reg->hook_stage);
	if (NULL == pp)
	{
		return -EINVAL;
	}

	/* Writers are serialized by the mutex; readers rely on RCU. */
	mutex_lock(&lkh_hook_mutex);
	p = lkh_entry_dereference(*pp);
	new_hooks = lkh_hook_entries_grow(p, reg);
	if (!IS_ERR(new_hooks))
	{
		/*
		 * Validate while new_hooks is still private to this writer. Once
		 * it is published and the mutex is released, another writer may
		 * replace it and queue it for freeing. This thread is not in an
		 * RCU read-side section, so the memory could then go away under it.
		 */
		lkh_hooks_validate(new_hooks);
		/*
		 * Publish the fully built copy: after this, RCU readers see the
		 * new array instead of p.
		 */
		rcu_assign_pointer(*pp, new_hooks);
	}
	mutex_unlock(&lkh_hook_mutex);

	if (IS_ERR(new_hooks))
	{
		return PTR_ERR(new_hooks);
	}

	/* Readers may still hold the old array; free it after a grace period. */
	lkh_hook_entries_free(p);
	return 0;
}
/*
 * Register a single hook for @net, using reg->pf as the protocol family.
 * The array stores the @reg pointer itself, so @reg must stay valid until it
 * is unregistered.
 *
 * Returns 0 on success or a negative errno.
 */
int lkh_register_net_hook(struct net *net, const struct lkh_hook_ops *reg)
{
	int ret = __lkh_register_net_hook(net, reg->pf, reg);

	return ret < 0 ? ret : 0;
}
/*
 * Remove the &dummy_ops slots (unregistered hooks) from @old and publish the
 * result through @pp. The only caller in this file holds lkh_hook_mutex.
 *
 * Outcomes:
 *   - every slot is a dummy (or @old is empty): *pp becomes NULL;
 *   - some slots are dummies: *pp becomes a compacted copy of the live
 *     entries;
 *   - no dummies, or the copy cannot be allocated: *pp is left unchanged.
 *
 * Returns @old when *pp was replaced, so the caller can free it after a
 * grace period. Returns NULL when nothing was replaced or @old is NULL.
 */
static void *__lkh_hook_entries_try_shrink(struct lkh_hook_entries *old, struct lkh_hook_entries __rcu **pp)
{
	struct lkh_hook_entries *shrunk = NULL;
	struct lkh_hook_ops **src_ops;
	struct lkh_hook_ops **dst_ops;
	unsigned int total, live = 0, src, dst;

	if (!old)
		return NULL;

	total = old->num_hook_entries;
	src_ops = lkh_hook_entries_get_hook_ops(old);
	for (src = 0; src < total; src++) {
		if (src_ops[src] != &dummy_ops)
			live++;
	}

	if (live != 0) {
		/* Nothing was unregistered: keep @old as is. */
		if (live == total)
			return NULL;

		shrunk = lkh_allocate_hook_entries_size(live);
		/* On allocation failure the dummies simply stay in @old. */
		if (!shrunk)
			return NULL;

		dst_ops = lkh_hook_entries_get_hook_ops(shrunk);
		for (src = 0, dst = 0; src < total; src++) {
			if (src_ops[src] == &dummy_ops)
				continue;
			shrunk->hooks[dst] = old->hooks[src];
			dst_ops[dst] = src_ops[src];
			dst++;
		}
		lkh_hooks_validate(shrunk);
	}

	/* shrunk is still NULL here when no live hook remains. */
	rcu_assign_pointer(*pp, shrunk);
	return old;
}
/*
 * Neutralize @unreg inside @old without reallocating. The matching slot
 * keeps its position, but:
 *   - its callback becomes lkh_accept_all,
 *   - its ops pointer becomes &dummy_ops.
 * Slots are matched by pointer identity with @unreg. The dummy slot is
 * normally dropped afterwards by __lkh_hook_entries_try_shrink().
 *
 * Returns true if @unreg was found, false otherwise.
 */
static bool lkh_remove_net_hook(struct lkh_hook_entries *old, const struct lkh_hook_ops *unreg)
{
	struct lkh_hook_ops **orig_ops;
	unsigned int i;

	orig_ops = lkh_hook_entries_get_hook_ops(old);
	for (i = 0; i < old->num_hook_entries; i++)
	{
		if (orig_ops[i] != unreg)
		{
			continue;
		}
		/* @old may be the published array, so update with single stores. */
		WRITE_ONCE(old->hooks[i].hook, lkh_accept_all);
		/*
		 * The ops array holds non-const pointers; the cast avoids a
		 * discarded-qualifier warning. Nothing in this file writes
		 * through a slot pointing at dummy_ops.
		 */
		WRITE_ONCE(orig_ops[i], (void *)&dummy_ops);
		return true;
	}
	return false;
}
/*
 * Unregister @reg from the hook array for (@net, @pf, reg->hook_stage).
 *
 * Under lkh_hook_mutex this function:
 *   1. neutralizes the slot in place (lkh_remove_net_hook()),
 *   2. tries to publish a compacted array (__lkh_hook_entries_try_shrink()).
 * If the array was replaced, the old one is freed through call_rcu() after
 * the mutex is dropped. A warning is logged when @reg is not registered
 * there.
 */
static void __lkh_unregister_net_hook(struct net *net, int pf, const struct lkh_hook_ops *reg)
{
	struct lkh_hook_entries __rcu **pp;
	struct lkh_hook_entries *p;

	pp = lkh_hook_entry_head(net, pf, reg->hook_stage);
	if (NULL == pp)
	{
		return;
	}

	mutex_lock(&lkh_hook_mutex);
	p = lkh_entry_dereference(*pp);
	if (NULL == p)
	{
		mutex_unlock(&lkh_hook_mutex);
		pr_warn("lkh: hook not found, pf %d hook_stage %u\n", pf, (unsigned int)reg->hook_stage);
		return;
	}

	if (!lkh_remove_net_hook(p, reg))
	{
		/* Unregistering a hook that was never registered here is a caller bug. */
		pr_warn("lkh: hook not found, pf %d hook_stage %u\n", pf, (unsigned int)reg->hook_stage);
	}
	p = __lkh_hook_entries_try_shrink(p, pp);
	mutex_unlock(&lkh_hook_mutex);
	if (NULL == p)
	{
		return;
	}

	/* p is no longer published; free it after RCU readers are done. */
	lkh_hook_entries_free(p);
}
/*
 * Unregister a single hook, using reg->pf as the protocol family. @reg must
 * be the same pointer that was registered, because slots are matched by
 * address.
 */
void lkh_unregister_net_hook(struct net *net, const struct lkh_hook_ops *reg)
{
	const int family = reg->pf;

	__lkh_unregister_net_hook(net, family, reg);
}
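/*********************************************************************************
* Description: unregister an array of hook ops from a network namespace.
*
* Input:
* net       - network namespace the hooks were registered in
* reg       - array of hook ops; each element must be the same pointer that was
*             passed at registration time (hooks are matched by address)
* hookcount - number of elements in reg
* Output:
* None
* Return:
* None
* Others:
* Elements are unregistered one at a time, in array order, via
* lkh_unregister_net_hook().
**********************************************************************************/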
void lkh_unregister_net_hooks(struct net *net, const struct lkh_hook_ops *reg, unsigned int hookcount)
{
unsigned int i;
for (i = 0; i < hookcount; i++)
{
lkh_unregister_net_hook(net, &reg[i]);
}
return;
}
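/*********************************************************************************
* Description: register an array of hook ops in a network namespace.
*
* Input:
* net       - network namespace to register the hooks in
* reg       - array of hook ops to register
* hookcount - number of elements in reg
* Output:
* None
* Return:
* 0  - all hooks were registered
* <0 - negative errno from the first failing registration (e.g. -EINVAL,
*      -E2BIG, -ENOMEM); hooks registered before the failure are unregistered
*      again
* Others:
* Like netfilter, the hook core does not copy the reg structures. It copies
* each hook callback and priv, but keeps a pointer to every reg element. The
* caller must keep reg valid while the hooks are registered, and must pass the
* same pointers to lkh_unregister_net_hooks().
**********************************************************************************/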
int lkh_register_net_hooks(struct net *net, const struct lkh_hook_ops *reg, unsigned int hookcount)
{
	unsigned int done;

	for (done = 0; done < hookcount; done++) {
		int ret = lkh_register_net_hook(net, &reg[done]);

		if (ret != 0) {
			/* Roll back reg[0 .. done - 1], which did register. */
			if (done > 0)
				lkh_unregister_net_hooks(net, reg, done);
			return ret;
		}
	}

	return 0;
}
/*
 * Debug dump of the hook registry, printed at KERN_EMERG level.
 *
 * For every struct lkh_hook on g_lkh_hook_handle.list it prints the
 * namespace pointer. When that pointer is non-NULL, it also prints the IPv4
 * LKH_INET_PRE_FORWARD hook array:
 *   - each slot's callback,
 *   - the pf, hook_stage, priority and hook of the ops stored in that slot.
 * Unregistered slots show the dummy_ops placeholder. Other stages and IPv6
 * are not dumped.
 *
 * Writers replace hook arrays and free the old ones with call_rcu(), so the
 * arrays are read inside an RCU read-side critical section.
 */
void lkh_hash_struct_show(void)
{
	struct lkh_hook * hook;
	struct lkh_hook_entries * hook_entries = NULL;
	struct lkh_hook_ops **new_ops;
	struct lkh_hook_ops * ops = NULL;
	struct lkh_hook_entry * entry = NULL;

	rcu_read_lock();
	list_for_each_entry(hook, &g_lkh_hook_handle.list, list)
	{
		if (hook->net_ptr != NULL)
		{
			printk(KERN_EMERG "------------------------------ net hook begin-----------------------------\n");
			printk(KERN_EMERG "hook->net_ptr %p\n", hook->net_ptr);
			printk(KERN_EMERG "------------------------------ net hook ipv4-----------------------------\n");
			printk(KERN_EMERG "hook->hooks_ipv4 %p\n", hook->hooks_ipv4);
			printk(KERN_EMERG "------------------------- LKH_INET_PRE_FORWARD------------------------\n");
			/* Load the published pointer once, so the array printed is the array walked. */
			hook_entries = rcu_dereference(hook->hooks_ipv4[LKH_INET_PRE_FORWARD]);
			printk(KERN_EMERG "hook_entries %p\n", hook_entries);
			if (hook_entries != NULL)
			{
				int i = 0;
				new_ops = lkh_hook_entries_get_hook_ops(hook_entries);
				printk(KERN_EMERG "hook_entries num_hook_entries %d new_ops: %p\n", hook_entries->num_hook_entries, new_ops);
				for (; i < hook_entries->num_hook_entries; i++)
				{
					entry = &(hook_entries->hooks[i]);
					/*
					 * Pairs with the WRITE_ONCE() stores in
					 * lkh_remove_net_hook(), which may swap in
					 * &dummy_ops concurrently.
					 */
					ops = READ_ONCE(new_ops[i]);
					printk(KERN_EMERG "----------- hook [%d]------------\n", i);
					printk(KERN_EMERG "hook entry [%d] entry ptr: %p new_ops: %p\n", i, entry, ops);
					printk(KERN_EMERG "entry hook: %p\n", READ_ONCE(entry->hook));
					printk(KERN_EMERG "ops pf: %d hook_stage: %d priority: %d hook: %p\n", ops->pf, ops->hook_stage, ops->priority, ops->hook);
				}
			}
			printk(KERN_EMERG "------------------------------ net hook end-----------------------------\n");
		}
		else
		{
			printk(KERN_EMERG "------------------------------ net hook begin-----------------------------\n");
			printk(KERN_EMERG "hook->net_ptr is NULL\n");
			printk(KERN_EMERG "------------------------------ net hook end-----------------------------\n");
		}
	}
	rcu_read_unlock();
}
/* Exported to other kernel modules: the batch register/unregister API. */
EXPORT_SYMBOL(lkh_register_net_hooks);
EXPORT_SYMBOL(lkh_unregister_net_hooks);