secgateway/Platform/user/kernel_hook/lkh_hook.c

531 lines
13 KiB
C
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#include <linux/kernel.h>
#include <linux/bug.h>
#include <linux/if.h>
#include <linux/netdevice.h>
#include <linux/netfilter_ipv6.h>
#include <linux/inetdevice.h>
#include <linux/mutex.h>
#include <linux/mm.h>
#include <linux/rcupdate.h>
#include <net/net_namespace.h>
#include <net/sock.h>
#include <linux/list.h>
#include <linux/skbuff.h>
#include <linux/netdevice.h>
#include "lkh_hook.h"
/*
 * Global hook registry. Its ->list links struct lkh_hook entries, each
 * identified by ->net_ptr (a network namespace); lkh_hook_entry_head() and
 * lkh_hash_struct_show() walk this list. Not static, so presumably shared
 * with other translation units via lkh_hook.h -- TODO confirm.
 */
struct lkh_hook_handle g_lkh_hook_handle;
/*
 * Placeholder hook callback: ignores all of its arguments and always
 * returns LKH_ACCEPT, which (per the original author's note) makes the
 * traversal continue with the next hook. lkh_remove_net_hook() installs it
 * in the slot of a hook that is being unregistered.
 */
static unsigned int lkh_accept_all(void *priv, struct sk_buff *skb, const struct lkh_hook_state *state)
{
	const unsigned int verdict = LKH_ACCEPT;

	(void)priv;
	(void)skb;
	(void)state;
	return verdict;
}
/*
 * Sentinel ops stored in the ops array in place of an unregistered hook.
 * Its address marks a dead slot: the grow, shrink and validate helpers
 * compare against &dummy_ops and skip such slots.
 */
static const struct lkh_hook_ops dummy_ops = {
	.priority = INT_MIN,
	.hook     = lkh_accept_all,
};
/*
 * Return the ops-pointer array that shares @e's allocation. It starts
 * immediately after the last element of e->hooks[], i.e. at
 * e->hooks[e->num_hook_entries] (see lkh_allocate_hook_entries_size()
 * for the layout).
 */
static inline struct lkh_hook_ops **lkh_hook_entries_get_hook_ops(const struct lkh_hook_entries *e)
{
	/* One element past the end of ->hooks[]. */
	return (struct lkh_hook_ops **)(e->hooks + e->num_hook_entries);
}
/* Registration flow */

/*
 * RCU callback queued by lkh_hook_entries_free(). It recovers the
 * lkh_hook_entries_rcu_head that embeds @h and kvfree()s the whole
 * hook-entries allocation recorded in it.
 */
static void __lkh_hook_entries_free(struct rcu_head *h)
{
	struct lkh_hook_entries_rcu_head *rh =
		container_of(h, struct lkh_hook_entries_rcu_head, head);

	kvfree(rh->allocation);
}
/*
 * Schedule @e to be freed after an RCU grace period; NULL is ignored.
 * The rcu_head wrapper lives in the tail of @e's own allocation, directly
 * after the ops-pointer array, so freeing needs no extra memory.
 */
static void lkh_hook_entries_free(struct lkh_hook_entries *e)
{
	struct lkh_hook_entries_rcu_head *tail;

	if (e != NULL) {
		tail = (void *)(lkh_hook_entries_get_hook_ops(e) + e->num_hook_entries);
		tail->allocation = e;
		call_rcu(&tail->head, __lkh_hook_entries_free);
	}
}
/*
 * Debug-only consistency check (compiled out unless CONFIG_DEBUG_KERNEL).
 *
 * lkh_hook_entries_grow() keeps the live (non-dummy) hooks of an entries
 * array in ascending priority order. This walks @hooks and emits a
 * WARN_ON() backtrace for every live hook whose priority is lower than the
 * highest priority seen before it. Dummy slots are skipped.
 *
 * Previously the running maximum was computed but never compared, so the
 * function validated nothing.
 */
static void lkh_hooks_validate(const struct lkh_hook_entries *hooks)
{
#ifdef CONFIG_DEBUG_KERNEL
	struct lkh_hook_ops **orig_ops;
	int prio = INT_MIN;
	unsigned int i;

	orig_ops = lkh_hook_entries_get_hook_ops(hooks);
	for (i = 0; i < hooks->num_hook_entries; i++)
	{
		/* Dead slots left by lkh_remove_net_hook() carry no meaningful priority. */
		if (orig_ops[i] == &dummy_ops)
		{
			continue;
		}
		/* Live hooks must appear in non-decreasing priority order. */
		WARN_ON(orig_ops[i]->priority < prio);
		if (orig_ops[i]->priority > prio)
		{
			prio = orig_ops[i]->priority;
		}
	}
#endif
}
/*
 * Allocate a new, zeroed hook-entries storage block for @num hooks.
 *
 * A single kvzalloc() allocation holds, in order:
 *   struct lkh_hook_entries                 header
 *   struct lkh_hook_entry[num]              the hooks (assumes ->hooks[] is a
 *                                           trailing flexible array -- confirm
 *                                           in lkh_hook.h)
 *   struct lkh_hook_ops *[num]              see lkh_hook_entries_get_hook_ops()
 *   struct lkh_hook_entries_rcu_head        used by lkh_hook_entries_free()
 *
 * Returns NULL if @num is 0 or the allocation fails. Otherwise returns the
 * block with ->num_hook_entries set to @num and everything else zeroed.
 */
static struct lkh_hook_entries * lkh_allocate_hook_entries_size(u16 num)
{
	struct lkh_hook_entries *blob;
	size_t bytes;

	if (!num)
		return NULL;

	bytes = sizeof(*blob);
	bytes += sizeof(struct lkh_hook_entry) * num;
	bytes += sizeof(struct lkh_hook_ops *) * num;
	bytes += sizeof(struct lkh_hook_entries_rcu_head);

	blob = kvzalloc(bytes, GFP_KERNEL);
	if (!blob)
		return NULL;

	blob->num_hook_entries = num;
	return blob;
}
/*
 * Build a new hook-entries array containing every live (non-dummy) hook of
 * @old plus @reg, with live hooks kept in ascending priority order.
 *
 * @old: current array; may be NULL when nothing is registered yet. It is only
 *       read here. The caller publishes the result and frees @old.
 * @reg: hook to add. Its address, not a copy, is stored in the ops array,
 *       so it must outlive the registration.
 *
 * @reg is placed before the first live hook whose priority is >= its own,
 * so it precedes existing hooks of equal priority. If there is no such
 * hook, it is appended. Dummy slots from @old are not carried over.
 *
 * Return: the new array, ERR_PTR(-E2BIG) if the live count would exceed
 * MAX_HOOK_COUNT, or ERR_PTR(-ENOMEM) if allocation fails.
 */
static struct lkh_hook_entries * lkh_hook_entries_grow(const struct lkh_hook_entries *old, const struct lkh_hook_ops *reg)
{
unsigned int i, alloc_entries, nhooks, old_entries;
struct lkh_hook_ops **orig_ops = NULL;
struct lkh_hook_ops **new_ops;
struct lkh_hook_entries *new;
bool inserted = false;
/* Start at 1: one slot is reserved for @reg itself. */
alloc_entries = 1;
old_entries = old ? old->num_hook_entries : 0;
if (old != NULL)
{
orig_ops = lkh_hook_entries_get_hook_ops(old);
/* Count only live hooks; dummy slots are dropped from the new array. */
for (i = 0; i < old_entries; i++)
{
if (orig_ops[i] != &dummy_ops)
{
alloc_entries++;
}
}
}
if (alloc_entries > MAX_HOOK_COUNT)
{
return ERR_PTR(-E2BIG);
}
new = lkh_allocate_hook_entries_size(alloc_entries);
if (NULL == new)
{
return ERR_PTR(-ENOMEM);
}
new_ops = lkh_hook_entries_get_hook_ops(new);
/*
 * Merge: i walks @old, nhooks walks the new array. Each pass either copies
 * one live old hook (advancing i) or inserts @reg (leaving i unchanged, so
 * the same old hook is copied on the next pass).
 */
i = 0;
nhooks = 0;
while (i < old_entries)
{
if (orig_ops[i] == &dummy_ops)
{
++i;
continue;
}
if (inserted || (reg->priority > orig_ops[i]->priority))
{
new_ops[nhooks] = (void *)orig_ops[i];
new->hooks[nhooks] = old->hooks[i];
i++;
}
else
{
new_ops[nhooks] = (void *)reg;
new->hooks[nhooks].hook = reg->hook;
new->hooks[nhooks].priv = reg->priv;
inserted = true;
}
nhooks++;
}
/* @reg outranks every live hook (or there were none): append it last. */
if (!inserted)
{
new_ops[nhooks] = (void *)reg;
new->hooks[nhooks].hook = reg->hook;
new->hooks[nhooks].priv = reg->priv;
}
return new;
}
/*
 * Locate the RCU-protected hook-entries slot for (@net, @pf, @hook_stage).
 *
 * Only the first lkh_hook in g_lkh_hook_handle.list whose ->net_ptr equals
 * @net is considered. Returns a pointer into that entry's hooks_ipv4[]
 * (LKH_PROTO_IPV4) or hooks_ipv6[] (LKH_PROTO_IPV6) array. Returns NULL if
 * the namespace is not listed, @pf is any other value, or @hook_stage is
 * past the end of the selected array.
 *
 * NOTE(review): the list is walked without any lock held; confirm what
 * serialises it against namespace add/remove.
 */
static struct lkh_hook_entries __rcu ** lkh_hook_entry_head(struct net *net, int pf, unsigned int hook_stage)
{
	struct lkh_hook *cur;
	struct lkh_hook *found = NULL;

	list_for_each_entry(cur, &g_lkh_hook_handle.list, list) {
		if (cur->net_ptr == net) {
			found = cur;
			break;
		}
	}

	if (!found)
		return NULL;

	if (pf == LKH_PROTO_IPV4)
		return hook_stage < ARRAY_SIZE(found->hooks_ipv4) ?
		       &found->hooks_ipv4[hook_stage] : NULL;

	if (pf == LKH_PROTO_IPV6)
		return hook_stage < ARRAY_SIZE(found->hooks_ipv6) ?
		       &found->hooks_ipv6[hook_stage] : NULL;

	return NULL;
}
/*
 * Insert @reg into the hook list selected by (@net, @pf, reg->hook_stage).
 *
 * Copy-on-write update serialised by lkh_hook_mutex:
 *  1. build a new entries array containing @reg (lkh_hook_entries_grow()),
 *  2. validate it while it is still private to this thread,
 *  3. publish it to readers with rcu_assign_pointer(),
 *  4. queue the superseded array for freeing after an RCU grace period.
 *
 * Return: 0 on success; -EINVAL if no hook slot exists for
 * (@net, @pf, stage); -E2BIG if MAX_HOOK_COUNT would be exceeded;
 * -ENOMEM on allocation failure.
 */
static int __lkh_register_net_hook(struct net *net, int pf, const struct lkh_hook_ops *reg)
{
	struct lkh_hook_entries *p, *new_hooks;
	struct lkh_hook_entries __rcu **pp;

	/* NOTE(review): looked up before taking lkh_hook_mutex; confirm the namespace list cannot change concurrently. */
	pp = lkh_hook_entry_head(net, pf, reg->hook_stage);
	if (NULL == pp)
	{
		return -EINVAL;
	}
	mutex_lock(&lkh_hook_mutex);
	/* Writers are serialised by lkh_hook_mutex; readers rely on RCU. */
	p = lkh_entry_dereference(*pp);
	new_hooks = lkh_hook_entries_grow(p, reg);
	if (!IS_ERR(new_hooks))
	{
		/*
		 * Validate before publishing, while the mutex is held. Once
		 * rcu_assign_pointer() and mutex_unlock() have run, another writer
		 * may replace new_hooks and call_rcu()-free it. This thread is not
		 * in an RCU read-side section, so validating it afterwards could
		 * read freed memory.
		 */
		lkh_hooks_validate(new_hooks);
		/*
		 * Point the RCU-protected pointer at the fully initialised new
		 * version of the data, making it visible to readers.
		 */
		rcu_assign_pointer(*pp, new_hooks);
	}
	mutex_unlock(&lkh_hook_mutex);
	if (IS_ERR(new_hooks))
	{
		return PTR_ERR(new_hooks);
	}
	/* Release the old array after a grace period (NULL is ignored). */
	lkh_hook_entries_free(p);
	return 0;
}
/*
 * Register a single hook in @net, using reg->pf as the protocol family.
 * Returns 0 on success or the negative errno from __lkh_register_net_hook().
 * @reg is stored by pointer and must remain valid until it is unregistered.
 */
int lkh_register_net_hook(struct net *net, const struct lkh_hook_ops *reg)
{
	const int ret = __lkh_register_net_hook(net, reg->pf, reg);

	return (ret < 0) ? ret : 0;
}
/*
 * Remove the dummy slots left behind by lkh_remove_net_hook() from @old.
 * The only caller in this file holds lkh_hook_mutex.
 *
 * Outcomes:
 *  - @old is NULL, or it has no dummy slots: nothing is published; returns NULL.
 *  - every slot is dummy: publishes NULL in *@pp; returns @old.
 *  - otherwise: allocates a compacted copy holding only the live hooks,
 *    validates it, publishes it in *@pp, and returns @old. If that allocation
 *    fails, *@pp keeps pointing at @old (its dummy slots just accept) and
 *    NULL is returned.
 *
 * A non-NULL return is the superseded array. The caller must release it
 * after an RCU grace period (lkh_hook_entries_free()).
 */
static void *__lkh_hook_entries_try_shrink(struct lkh_hook_entries *old, struct lkh_hook_entries __rcu **pp)
{
unsigned int i, j, skip = 0, hook_entries;
struct lkh_hook_entries *new = NULL;
struct lkh_hook_ops **orig_ops;
struct lkh_hook_ops **new_ops;
if (NULL == old)
{
return NULL;
}
orig_ops = lkh_hook_entries_get_hook_ops(old);
/* Count dead (dummy) slots. */
for (i = 0; i < old->num_hook_entries; i++)
{
if (orig_ops[i] == &dummy_ops)
{
skip++;
}
}
hook_entries = old->num_hook_entries;
/* Every hook has been removed: publish NULL (new is still NULL here). */
if (skip == hook_entries)
{
goto out_assign;
}
/* Nothing to compact; leave the published array untouched. */
if (skip == 0)
{
return NULL;
}
hook_entries -= skip;
new = lkh_allocate_hook_entries_size(hook_entries);
if (NULL == new)
{
return NULL;
}
new_ops = lkh_hook_entries_get_hook_ops(new);
/* Copy the live hooks, preserving their relative (priority) order. */
for (i = 0, j = 0; i < old->num_hook_entries; i++)
{
if (orig_ops[i] == &dummy_ops)
{
continue;
}
new->hooks[j] = old->hooks[i];
new_ops[j] = (void *)orig_ops[i];
j++;
}
/* Validated before publication, while the array is still private. */
lkh_hooks_validate(new);
out_assign:
rcu_assign_pointer(*pp, new);
return old;
}
/*
 * Neutralise @unreg inside @old in place, without resizing the array.
 * The only caller in this file holds lkh_hook_mutex.
 *
 * The first slot whose ops pointer equals @unreg gets lkh_accept_all as its
 * callback and &dummy_ops as its ops pointer. WRITE_ONCE() is used because
 * @old is an RCU-published array that readers may be traversing
 * concurrently. __lkh_hook_entries_try_shrink() compacts the dummy slots
 * away afterwards.
 *
 * Return: true if @unreg was found in @old, false otherwise.
 */
static bool lkh_remove_net_hook(struct lkh_hook_entries *old, const struct lkh_hook_ops *unreg)
{
	struct lkh_hook_ops **orig_ops;
	unsigned int i;

	orig_ops = lkh_hook_entries_get_hook_ops(old);
	for (i = 0; i < old->num_hook_entries; i++)
	{
		if (orig_ops[i] != unreg)
		{
			continue;
		}
		WRITE_ONCE(old->hooks[i].hook, lkh_accept_all);
		/*
		 * orig_ops[] holds non-const pointers but dummy_ops is const. Cast
		 * explicitly, as the grow/shrink paths do for registered ops.
		 * Depending on the WRITE_ONCE() definition, an implicit assignment
		 * discards the const qualifier and draws a compiler warning.
		 */
		WRITE_ONCE(orig_ops[i], (void *)&dummy_ops);
		return true;
	}
	return false;
}
/*
 * Remove @reg from the hook list selected by (@net, @pf, reg->hook_stage).
 *
 * Under lkh_hook_mutex, the matching slot is first neutralised in place
 * (lkh_remove_net_hook()). Then __lkh_hook_entries_try_shrink() may publish
 * a compacted array, or NULL when no live hooks remain. If it replaced the
 * array, the old one is freed after an RCU grace period.
 *
 * Returns silently when no hook slot exists for (@net, @pf, stage) or the
 * slot is empty. Emits a one-time warning when @reg is not registered in
 * the slot; previously that result was ignored.
 */
static void __lkh_unregister_net_hook(struct net *net, int pf, const struct lkh_hook_ops *reg)
{
	struct lkh_hook_entries __rcu **pp;
	struct lkh_hook_entries *p;

	pp = lkh_hook_entry_head(net, pf, reg->hook_stage);
	if (NULL == pp)
	{
		return;
	}
	mutex_lock(&lkh_hook_mutex);
	p = lkh_entry_dereference(*pp);
	if (NULL == p)
	{
		mutex_unlock(&lkh_hook_mutex);
		return;
	}
	if (!lkh_remove_net_hook(p, reg))
	{
		/* Unregistering a hook that was never registered is a caller bug; report it. */
		WARN_ONCE(1, "lkh: hook not found, pf %d stage %u\n", pf, (unsigned int)reg->hook_stage);
	}
	/* Non-NULL only if a new array (or NULL) was published in *pp. */
	p = __lkh_hook_entries_try_shrink(p, pp);
	mutex_unlock(&lkh_hook_mutex);
	if (NULL == p)
	{
		return;
	}
	lkh_hook_entries_free(p);
}
/*
 * Unregister a single hook from @net; reg->pf selects the protocol family.
 * A namespace/family/stage with no hook slot, or an empty slot, is a no-op.
 */
void lkh_unregister_net_hook(struct net *net, const struct lkh_hook_ops *reg)
{
	const int family = reg->pf;

	__lkh_unregister_net_hook(net, family, reg);
}
/*********************************************************************************
 * Description:
 *   Unregister hook handlers, each identified by its protocol, stage and
 *   priority.
 * Input:
 *   net       - network namespace pointer
 *   reg       - array of hook descriptors (protocol, stage, priority, hook
 *               function pointer, ...)
 *   hookcount - number of entries of reg to unregister
 * Output:
 *   None
 * Return:
 *   None
 * Others:
 *   Entries are processed in array order, reg[0] first.
 **********************************************************************************/
void lkh_unregister_net_hooks(struct net *net, const struct lkh_hook_ops *reg, unsigned int hookcount)
{
	unsigned int idx;

	for (idx = 0; idx != hookcount; ++idx)
		lkh_unregister_net_hook(net, reg + idx);
}
/*********************************************************************************
 * Description:
 *   Register hook handlers, each with its own protocol, stage and priority.
 * Input:
 *   net       - network namespace pointer
 *   reg       - array of hook descriptors (protocol, stage, priority, hook
 *               function pointer, ...)
 *   hookcount - number of hooks to register in this call
 * Output:
 *   None
 * Return:
 *   0     - success
 *   non-0 - failure (negative errno); hooks registered by this call before
 *           the failing one are unregistered again
 * Others:
 *   The hook management structures store a pointer to each reg entry, not a
 *   copy of its contents (the same approach netfilter takes). The caller
 *   must therefore keep reg valid for as long as the hooks stay registered.
 **********************************************************************************/
int lkh_register_net_hooks(struct net *net, const struct lkh_hook_ops *reg, unsigned int hookcount)
{
	unsigned int done;
	int ret;

	for (done = 0; done < hookcount; done++) {
		ret = lkh_register_net_hook(net, &reg[done]);
		if (ret != 0) {
			/* Roll back reg[0 .. done-1], which were registered. */
			if (done > 0)
				lkh_unregister_net_hooks(net, reg, done);
			return ret;
		}
	}
	return 0;
}
/*
 * Debug dump of the hook registry at KERN_EMERG. For every lkh_hook in
 * g_lkh_hook_handle.list it prints the namespace pointer and the IPv4
 * LKH_INET_PRE_FORWARD entries array: each slot's callback, ops pointer,
 * pf, stage and priority.
 *
 * Hook-entries arrays are replaced with rcu_assign_pointer() and freed via
 * call_rcu(), so they are read with rcu_dereference() inside
 * rcu_read_lock(). Each message ends in '\n' so it is emitted as its own
 * line rather than held in the printk continuation buffer.
 *
 * NOTE(review): the namespace list itself is walked without a lock, as
 * before; confirm what protects it against concurrent add/remove.
 */
void lkh_hash_struct_show(void)
{
	struct lkh_hook *hook;
	struct lkh_hook_entries *hook_entries = NULL;
	struct lkh_hook_ops **new_ops;
	struct lkh_hook_ops *ops = NULL;
	struct lkh_hook_entry *entry = NULL;
	unsigned int i;

	rcu_read_lock();
	list_for_each_entry(hook, &g_lkh_hook_handle.list, list)
	{
		if (hook->net_ptr != NULL)
		{
			printk(KERN_EMERG "------------------------------ net hook begin-----------------------------\n");
			printk(KERN_EMERG "hook->net_ptr %p\n", hook->net_ptr);
			printk(KERN_EMERG "------------------------------ net hook ipv4-----------------------------\n");
			printk(KERN_EMERG "hook->hooks_ipv4 %p\n", hook->hooks_ipv4);
			printk(KERN_EMERG "------------------------- LKH_INET_PRE_FORWARD------------------------\n");
			hook_entries = rcu_dereference(hook->hooks_ipv4[LKH_INET_PRE_FORWARD]);
			printk(KERN_EMERG "hook_entries %p\n", hook_entries);
			if (hook_entries != NULL)
			{
				new_ops = lkh_hook_entries_get_hook_ops(hook_entries);
				printk(KERN_EMERG "hook_entries num_hook_entries %u new_ops: %p\n",
				       (unsigned int)hook_entries->num_hook_entries, new_ops);
				for (i = 0; i < hook_entries->num_hook_entries; i++)
				{
					entry = &hook_entries->hooks[i];
					/* Pairs with WRITE_ONCE() in lkh_remove_net_hook(). */
					ops = READ_ONCE(new_ops[i]);
					printk(KERN_EMERG "----------- hook [%u]------------\n", i);
					printk(KERN_EMERG "hook entry [%u] entry ptr: %p new_ops: %p\n", i, entry, ops);
					printk(KERN_EMERG "entry hook: %p\n", entry->hook);
					printk(KERN_EMERG "ops pf: %d hook_stage: %d priority: %d hook: %p\n",
					       ops->pf, ops->hook_stage, ops->priority, ops->hook);
				}
			}
			printk(KERN_EMERG "------------------------------ net hook end-----------------------------\n");
		}
		else
		{
			printk(KERN_EMERG "------------------------------ net hook begin-----------------------------\n");
			printk(KERN_EMERG "hook->net_ptr is NULL\n");
			printk(KERN_EMERG "------------------------------ net hook end-----------------------------\n");
		}
	}
	rcu_read_unlock();
}
/*
 * Batch register/unregister entry points exported to other modules. The
 * single-hook variants and lkh_hash_struct_show() are not exported here.
 */
EXPORT_SYMBOL(lkh_register_net_hooks);
EXPORT_SYMBOL(lkh_unregister_net_hooks);