#include #include #include #include #include #include #include #include #include #include #include #include #include #include "lkh_hook.h" struct lkh_hook_handle g_lkh_hook_handle; static unsigned int lkh_accept_all(void *priv, struct sk_buff *skb, const struct lkh_hook_state *state) { return LKH_ACCEPT; /* ACCEPT makes nf_hook_slow call next hook */ } static const struct lkh_hook_ops dummy_ops = { .hook = lkh_accept_all, .priority = INT_MIN, }; static inline struct lkh_hook_ops **lkh_hook_entries_get_hook_ops(const struct lkh_hook_entries *e) { unsigned int n = e->num_hook_entries; const void *hook_end; hook_end = &e->hooks[n]; /* this is *past* ->hooks[]! */ return (struct lkh_hook_ops **)hook_end; } /* 注册流程 */ static void __lkh_hook_entries_free(struct rcu_head *h) { struct lkh_hook_entries_rcu_head *head; head = container_of(h, struct lkh_hook_entries_rcu_head, head); kvfree(head->allocation); return; } static void lkh_hook_entries_free(struct lkh_hook_entries *e) { struct lkh_hook_entries_rcu_head *head; struct lkh_hook_ops **ops; unsigned int num; if (NULL == e) { return; } num = e->num_hook_entries; ops = lkh_hook_entries_get_hook_ops(e); head = (void *)&ops[num]; head->allocation = e; call_rcu(&head->head, __lkh_hook_entries_free); return; } static void lkh_hooks_validate(const struct lkh_hook_entries *hooks) { #ifdef CONFIG_DEBUG_KERNEL struct lkh_hook_ops **orig_ops; int prio = INT_MIN; size_t i = 0; orig_ops = lkh_hook_entries_get_hook_ops(hooks); for (i = 0; i < hooks->num_hook_entries; i++) { if (orig_ops[i] == &dummy_ops) { continue; } if (orig_ops[i]->priority > prio) { prio = orig_ops[i]->priority; } } #endif } /* 申请新的存储结构 */ static struct lkh_hook_entries * lkh_allocate_hook_entries_size(u16 num) { struct lkh_hook_entries *e; size_t alloc = sizeof(*e) + sizeof(struct lkh_hook_entry) * num + sizeof(struct lkh_hook_ops *) * num + sizeof(struct lkh_hook_entries_rcu_head); if (num == 0) { return NULL; } e = kvzalloc(alloc, GFP_KERNEL); if (e) { e->num_hook_entries = num; } return e; } static struct lkh_hook_entries * lkh_hook_entries_grow(const struct lkh_hook_entries *old, const struct lkh_hook_ops *reg) { unsigned int i, alloc_entries, nhooks, old_entries; struct lkh_hook_ops **orig_ops = NULL; struct lkh_hook_ops **new_ops; struct lkh_hook_entries *new; bool inserted = false; alloc_entries = 1; old_entries = old ? old->num_hook_entries : 0; if (old != NULL) { orig_ops = lkh_hook_entries_get_hook_ops(old); for (i = 0; i < old_entries; i++) { if (orig_ops[i] != &dummy_ops) { alloc_entries++; } } } if (alloc_entries > MAX_HOOK_COUNT) { return ERR_PTR(-E2BIG); } new = lkh_allocate_hook_entries_size(alloc_entries); if (NULL == new) { return ERR_PTR(-ENOMEM); } new_ops = lkh_hook_entries_get_hook_ops(new); i = 0; nhooks = 0; while (i < old_entries) { if (orig_ops[i] == &dummy_ops) { ++i; continue; } if (inserted || (reg->priority > orig_ops[i]->priority)) { new_ops[nhooks] = (void *)orig_ops[i]; new->hooks[nhooks] = old->hooks[i]; i++; } else { new_ops[nhooks] = (void *)reg; new->hooks[nhooks].hook = reg->hook; new->hooks[nhooks].priv = reg->priv; inserted = true; } nhooks++; } if (!inserted) { new_ops[nhooks] = (void *)reg; new->hooks[nhooks].hook = reg->hook; new->hooks[nhooks].priv = reg->priv; } return new; } static struct lkh_hook_entries __rcu ** lkh_hook_entry_head(struct net *net, int pf, unsigned int hook_stage) { struct lkh_hook * hook; list_for_each_entry(hook, &g_lkh_hook_handle.list, list) { if (hook->net_ptr == net) { switch (pf) { case LKH_PROTO_IPV4: if ((ARRAY_SIZE(hook->hooks_ipv4) <= hook_stage)) { return NULL; } return hook->hooks_ipv4 + hook_stage; case LKH_PROTO_IPV6: if ((ARRAY_SIZE(hook->hooks_ipv6) <= hook_stage)) { return NULL; } return hook->hooks_ipv6 + hook_stage; default: return NULL; } break; } } return NULL; } static int __lkh_register_net_hook(struct net *net, int pf, const struct lkh_hook_ops *reg) { struct lkh_hook_entries *p, *new_hooks; struct lkh_hook_entries __rcu **pp; pp = lkh_hook_entry_head(net, pf, reg->hook_stage); if (NULL == pp) { return -EINVAL; } mutex_lock(&lkh_hook_mutex); /* 多核信号量保护处理 RCU */ p = lkh_entry_dereference(*pp); new_hooks = lkh_hook_entries_grow(p, reg); if (!IS_ERR(new_hooks)) { /* * 该接口被writer用来进行removal的操作,在witer完成新版本数据分配和更新之后, * 调用这个接口可以让RCU protected pointer指向RCU protected data */ rcu_assign_pointer(*pp, new_hooks); } mutex_unlock(&lkh_hook_mutex); if (IS_ERR(new_hooks)) { return PTR_ERR(new_hooks); } lkh_hooks_validate(new_hooks); /* 释放旧的指针内存 */ lkh_hook_entries_free(p); return 0; } int lkh_register_net_hook(struct net *net, const struct lkh_hook_ops *reg) { int err; err = __lkh_register_net_hook(net, reg->pf, reg); if (err < 0) { return err; } return 0; } static void *__lkh_hook_entries_try_shrink(struct lkh_hook_entries *old, struct lkh_hook_entries __rcu **pp) { unsigned int i, j, skip = 0, hook_entries; struct lkh_hook_entries *new = NULL; struct lkh_hook_ops **orig_ops; struct lkh_hook_ops **new_ops; if (NULL == old) { return NULL; } orig_ops = lkh_hook_entries_get_hook_ops(old); for (i = 0; i < old->num_hook_entries; i++) { if (orig_ops[i] == &dummy_ops) { skip++; } } hook_entries = old->num_hook_entries; if (skip == hook_entries) { goto out_assign; } if (skip == 0) { return NULL; } hook_entries -= skip; new = lkh_allocate_hook_entries_size(hook_entries); if (NULL == new) { return NULL; } new_ops = lkh_hook_entries_get_hook_ops(new); for (i = 0, j = 0; i < old->num_hook_entries; i++) { if (orig_ops[i] == &dummy_ops) { continue; } new->hooks[j] = old->hooks[i]; new_ops[j] = (void *)orig_ops[i]; j++; } lkh_hooks_validate(new); out_assign: rcu_assign_pointer(*pp, new); return old; } static bool lkh_remove_net_hook(struct lkh_hook_entries *old, const struct lkh_hook_ops *unreg) { struct lkh_hook_ops **orig_ops; unsigned int i; orig_ops = lkh_hook_entries_get_hook_ops(old); for (i = 0; i < old->num_hook_entries; i++) { if (orig_ops[i] != unreg) { continue; } WRITE_ONCE(old->hooks[i].hook, lkh_accept_all); WRITE_ONCE(orig_ops[i], &dummy_ops); return true; } return false; } static void __lkh_unregister_net_hook(struct net *net, int pf, const struct lkh_hook_ops *reg) { struct lkh_hook_entries __rcu **pp; struct lkh_hook_entries *p; pp = lkh_hook_entry_head(net, pf, reg->hook_stage); if (NULL == pp) { return; } mutex_lock(&lkh_hook_mutex); p = lkh_entry_dereference(*pp); if (NULL == p) { mutex_unlock(&lkh_hook_mutex); return; } lkh_remove_net_hook(p, reg); p = __lkh_hook_entries_try_shrink(p, pp); mutex_unlock(&lkh_hook_mutex); if (NULL == p) { return; } lkh_hook_entries_free(p); return; } void lkh_unregister_net_hook(struct net *net, const struct lkh_hook_ops *reg) { __lkh_unregister_net_hook(net, reg->pf, reg); return; } /********************************************************************************* * Description:   * 指定协议、阶段、优先级解除注册钩子处理函数 * Input:   * net - 网络命名空间指针 * reg - 协议、阶段、优先级、钩子函数指针等信息保存结构 * n - 一次注册的钩子函数数量 * Output: *  无 * Return: * 无 * Others: * 无 **********************************************************************************/ void lkh_unregister_net_hooks(struct net *net, const struct lkh_hook_ops *reg, unsigned int hookcount) { unsigned int i; for (i = 0; i < hookcount; i++) { lkh_unregister_net_hook(net, ®[i]); } return; } /********************************************************************************* * Description:   * 指定协议、阶段、优先级注册钩子处理函数 * Input:   * net - 网络命名空间指针 * reg - 协议、阶段、优先级、钩子函数指针等信息保存结构 * hookcount - 一次注册的钩子函数数量 * Output: *  无 * Return: * 0 - 成功 * 非0 - 失败 * Others: * 入参的reg结构需要调用者确认是否需要保存,因为HOOK管理结构不会记录reg的内容,而是记录reg的指针, netfilter就是这样实现 * 当前按照netfilter的方式实现。 **********************************************************************************/ int lkh_register_net_hooks(struct net *net, const struct lkh_hook_ops *reg, unsigned int hookcount) { unsigned int i; int err = 0; for (i = 0; i < hookcount; i++) { err = lkh_register_net_hook(net, ®[i]); if (err != 0) { goto err; } } return err; err: if (i > 0) { lkh_unregister_net_hooks(net, reg, i); } return err; } void lkh_hash_struct_show(void) { struct lkh_hook * hook; struct lkh_hook_entries * hook_entries= NULL; struct lkh_hook_ops **new_ops; struct lkh_hook_ops * ops = NULL; struct lkh_hook_entry * entry = NULL; list_for_each_entry(hook, &g_lkh_hook_handle.list, list) { if (hook->net_ptr != NULL) { printk(KERN_EMERG "------------------------------ net hook begin-----------------------------"); printk(KERN_EMERG "hook->net_ptr %p", hook->net_ptr); printk(KERN_EMERG "------------------------------ net hook ipv4-----------------------------"); printk(KERN_EMERG "hook->hooks_ipv4 %p", hook->hooks_ipv4); printk(KERN_EMERG "------------------------- LKH_INET_PRE_FORWARD------------------------"); printk(KERN_EMERG "hook_entries %p", hook->hooks_ipv4[LKH_INET_PRE_FORWARD]); hook_entries = hook->hooks_ipv4[LKH_INET_PRE_FORWARD]; if (hook_entries != NULL) { int i = 0; new_ops = lkh_hook_entries_get_hook_ops(hook_entries); printk(KERN_EMERG "hook_entries num_hook_entries %d new_ops: %p", hook_entries->num_hook_entries, new_ops); for (; i < hook_entries->num_hook_entries; i++) { entry = &(hook_entries->hooks[i]); ops = new_ops[i]; printk(KERN_EMERG "----------- hook [%d]------------", i); printk(KERN_EMERG "hook entry [%d] entry ptr: %p new_ops: %p", i, entry, ops); printk(KERN_EMERG "entry hook: %p", entry->hook); printk(KERN_EMERG "ops pf: %d hook_stage: %d priority: %d hook: %p", ops->pf, ops->hook_stage, ops->priority, ops->hook); } } printk(KERN_EMERG "------------------------------ net hook end-----------------------------"); } else { printk(KERN_EMERG "------------------------------ net hook begin-----------------------------"); printk(KERN_EMERG "hook->net_ptr is NULL"); printk(KERN_EMERG "------------------------------ net hook end-----------------------------"); } } return; } EXPORT_SYMBOL(lkh_register_net_hooks); EXPORT_SYMBOL(lkh_unregister_net_hooks);