#include "conntrack_api.h"

/*
 * cmhi_skb_ct - fetch the conntrack entry attached to @skb.
 *
 * Thin wrapper around nf_ct_get(); the ctinfo value it reports is
 * discarded. Returns whatever nf_ct_get() returns (NULL when the packet
 * has no usable conntrack). @skb must already have been checked for NULL.
 */
static struct nf_conn *cmhi_skb_ct(const struct sk_buff *skb)
{
	enum ip_conntrack_info ctinfo = 0;

	return nf_ct_get(skb, &ctinfo);
}

/*
 * cmhi_set_conntrack_u16 - store a 16-bit value in the ct extension.
 * @skb:   packet whose conntrack entry is modified
 * @value: value to store
 * @type:  which field to write (USER_VERSION or APP_ID)
 *
 * Return: CMHI_EXT_OK on success, CMHI_EXT_ERR if @skb is NULL, the
 * packet has no conntrack, or @type is not a 16-bit field.
 */
int cmhi_set_conntrack_u16(const struct sk_buff *skb, uint16_t value, cmhi_ext_type type)
{
	struct nf_conn *conn;

	if (skb == NULL) {
		printk(KERN_ERR "[CT_API]set_conntrack_u16: skb is null\n");
		return CMHI_EXT_ERR;
	}

	conn = cmhi_skb_ct(skb);
	if (conn == NULL) {
		printk(KERN_ERR "[CT_API]set-conntrak_u16: ct is not existed or ct is untracked\n");
		return CMHI_EXT_ERR;
	}

	if (type == USER_VERSION) {
		conn->cmhi.user_version = value;
	} else if (type == APP_ID) {
		conn->cmhi.app_id = value;
	} else {
		printk(KERN_INFO "[CT_API]set-conntrak_u16: value is not u16\n");
		return CMHI_EXT_ERR;
	}

	return CMHI_EXT_OK;
}

/*
 * cmhi_set_conntrack_u32 - store a 32-bit value in the ct extension.
 * @skb:   packet whose conntrack entry is modified
 * @value: value to store
 * @type:  which field to write (USER_ID, NODE_INDEX, POLICY_VERSION, ACTION)
 *
 * Writing ACTION replaces the whole action word (no bit merging).
 *
 * Return: CMHI_EXT_OK on success, CMHI_EXT_ERR if @skb is NULL, the
 * packet has no conntrack, or @type is not a 32-bit field.
 */
int cmhi_set_conntrack_u32(const struct sk_buff *skb, uint32_t value, cmhi_ext_type type)
{
	struct nf_conn *conn;

	if (skb == NULL) {
		printk(KERN_ERR "[CT_API]set-conntrack_u32: skb is null\n");
		return CMHI_EXT_ERR;
	}

	conn = cmhi_skb_ct(skb);
	if (conn == NULL) {
		printk(KERN_ERR "[CT_API]set-conntrak_u32: ct is not existed or ct is untracked\n");
		return CMHI_EXT_ERR;
	}

	if (type == USER_ID) {
		conn->cmhi.user_id = value;
	} else if (type == NODE_INDEX) {
		conn->cmhi.node_index = value;
	} else if (type == POLICY_VERSION) {
		conn->cmhi.policy_version = value;
	} else if (type == ACTION) {
		conn->cmhi.action = value;
	} else {
		printk(KERN_INFO "[CT_API]set-conntrak_u32: value is not u32\n");
		return CMHI_EXT_ERR;
	}

	return CMHI_EXT_OK;
}

/*
 * cmhi_set_conntrack_action_by_bit - OR one action bit into ct->cmhi.action.
 * @skb:  packet whose conntrack entry is modified
 * @abit: bit to set; only CMHI_EXT_PASS and CMHI_EXT_GOTO_DPI are accepted
 *
 * Return: CMHI_EXT_OK on success, CMHI_EXT_ERR on NULL @skb, missing
 * conntrack, or an unsupported bit.
 */
int cmhi_set_conntrack_action_by_bit(const struct sk_buff *skb, cmhi_ext_action_bit abit)
{
	struct nf_conn *conn;

	if (skb == NULL) {
		printk(KERN_ERR "[CT_API]set-conntrack_action_by_bit: skb is null\n");
		return CMHI_EXT_ERR;
	}

	conn = cmhi_skb_ct(skb);
	if (conn == NULL) {
		printk(KERN_ERR "[CT_API]set-conntrak_action_by_bit: ct is not existed or ct is untracked\n");
		return CMHI_EXT_ERR;
	}

	if (abit != CMHI_EXT_PASS && abit != CMHI_EXT_GOTO_DPI) {
		printk(KERN_INFO "[CT_API]set-conntrak_action_by_bit: wrong action bit\n");
		return CMHI_EXT_ERR;
	}

	conn->cmhi.action |= abit;
	return CMHI_EXT_OK;
}

/*
 * cmhi_get_conntrack_u16 - read a 16-bit field of the ct extension.
 * @skb:   packet whose conntrack entry is read
 * @value: out-parameter receiving the field; left untouched on error
 * @type:  which field to read (USER_VERSION or APP_ID)
 *
 * Return: CMHI_EXT_OK on success, CMHI_EXT_ERR if @value or @skb is NULL,
 * the packet has no conntrack, or @type is not a 16-bit field.
 */
int cmhi_get_conntrack_u16(const struct sk_buff *skb, uint16_t *value, cmhi_ext_type type)
{
	struct nf_conn *conn;

	if (value == NULL) {
		printk(KERN_ERR "[CT_API]get-conntrak_u16: value is null\n");
		return CMHI_EXT_ERR;
	}

	if (skb == NULL) {
		printk(KERN_ERR "[CT_API]get-conntrack_u16: skb is null\n");
		return CMHI_EXT_ERR;
	}

	conn = cmhi_skb_ct(skb);
	if (conn == NULL) {
		printk(KERN_ERR "[CT_API]get-conntrak_u16: ct is not existed or ct is untracked\n");
		return CMHI_EXT_ERR;
	}

	if (type == USER_VERSION) {
		*value = conn->cmhi.user_version;
	} else if (type == APP_ID) {
		*value = conn->cmhi.app_id;
	} else {
		printk(KERN_INFO "[CT_API]get-conntrak_u16: value is not u16\n");
		return CMHI_EXT_ERR;
	}

	return CMHI_EXT_OK;
}

/*
 * cmhi_get_conntrack_u32 - read a 32-bit field of the ct extension.
 * @skb:   packet whose conntrack entry is read
 * @value: out-parameter receiving the field; left untouched on error
 * @type:  which field to read (USER_ID, NODE_INDEX, POLICY_VERSION, ACTION)
 *
 * Return: CMHI_EXT_OK on success, CMHI_EXT_ERR if @value or @skb is NULL,
 * the packet has no conntrack, or @type is not a 32-bit field.
 */
int cmhi_get_conntrack_u32(const struct sk_buff *skb, uint32_t *value, cmhi_ext_type type)
{
	struct nf_conn *conn;

	if (value == NULL) {
		printk(KERN_ERR "[CT_API]get-conntrak_u32: value is null\n");
		return CMHI_EXT_ERR;
	}

	if (skb == NULL) {
		printk(KERN_ERR "[CT_API]get-conntrack_u32: skb is null\n");
		return CMHI_EXT_ERR;
	}

	conn = cmhi_skb_ct(skb);
	if (conn == NULL) {
		printk(KERN_ERR "[CT_API]get-conntrak_u32: ct is not existed or ct is untracked\n");
		return CMHI_EXT_ERR;
	}

	if (type == USER_ID) {
		*value = conn->cmhi.user_id;
	} else if (type == NODE_INDEX) {
		*value = conn->cmhi.node_index;
	} else if (type == POLICY_VERSION) {
		*value = conn->cmhi.policy_version;
	} else if (type == ACTION) {
		*value = conn->cmhi.action;
	} else {
		printk(KERN_INFO "[CT_API]get-conntrak_u32: value is not u32\n");
		return CMHI_EXT_ERR;
	}

	return CMHI_EXT_OK;
}

/*
 * cmhi_del_conntrack - reset one ct extension field to zero.
 * @skb:  packet whose conntrack entry is modified
 * @type: field to clear (any of the six cmhi fields)
 *
 * Return: CMHI_EXT_OK on success, CMHI_EXT_ERR on NULL @skb, missing
 * conntrack, or an unknown @type.
 */
int cmhi_del_conntrack(const struct sk_buff *skb, cmhi_ext_type type)
{
	struct nf_conn *conn;

	if (skb == NULL) {
		printk(KERN_ERR "[CT_API]del-conntrack: skb is null\n");
		return CMHI_EXT_ERR;
	}

	conn = cmhi_skb_ct(skb);
	if (conn == NULL) {
		printk(KERN_ERR "[CT_API]del-conntrak: ct is not existed or ct is untracked\n");
		return CMHI_EXT_ERR;
	}

	switch (type) {
	case USER_ID:
		conn->cmhi.user_id = 0;
		return CMHI_EXT_OK;
	case USER_VERSION:
		conn->cmhi.user_version = 0;
		return CMHI_EXT_OK;
	case NODE_INDEX:
		conn->cmhi.node_index = 0;
		return CMHI_EXT_OK;
	case APP_ID:
		conn->cmhi.app_id = 0;
		return CMHI_EXT_OK;
	case POLICY_VERSION:
		conn->cmhi.policy_version = 0;
		return CMHI_EXT_OK;
	case ACTION:
		conn->cmhi.action = 0;
		return CMHI_EXT_OK;
	default:
		printk(KERN_INFO "[CT_API]del-conntrak: wrong type\n");
		return CMHI_EXT_ERR;
	}
}

/*
 * cmhi_del_conntrack_action_by_bit - clear one bit in ct->cmhi.action.
 * @skb:  packet whose conntrack entry is modified
 * @abit: bit to clear; only CMHI_EXT_PASS and CMHI_EXT_GOTO_DPI are accepted
 *
 * Return: CMHI_EXT_OK on success, CMHI_EXT_ERR on NULL @skb, missing
 * conntrack, or an unsupported bit.
 */
int cmhi_del_conntrack_action_by_bit(const struct sk_buff *skb, cmhi_ext_action_bit abit)
{
	struct nf_conn *conn;

	if (skb == NULL) {
		printk(KERN_ERR "[CT_API]del-conntrack_action_by_bit: skb is null\n");
		return CMHI_EXT_ERR;
	}

	conn = cmhi_skb_ct(skb);
	if (conn == NULL) {
		printk(KERN_ERR "[CT_API]del-conntrack_action_by_bit: ct is not existed or ct is untracked\n");
		return CMHI_EXT_ERR;
	}

	if (abit != CMHI_EXT_PASS && abit != CMHI_EXT_GOTO_DPI) {
		printk(KERN_INFO "[CT_API]del-conntrack_action_by_bit: wrong action bit\n");
		return CMHI_EXT_ERR;
	}

	conn->cmhi.action &= ~abit;
	return CMHI_EXT_OK;
}

/*
 * cmhi_hash_conntrack_raw - compute the raw hash used to pick a bucket.
 * @tuple: tuple to hash
 *
 * jhash2() runs over the leading (sizeof(src) + sizeof(dst.u3)) / 4 words
 * of *tuple (assumes src and dst.u3 sit contiguously at the start of
 * struct nf_conntrack_tuple — TODO confirm against the kernel headers).
 * The destination port word and protocol number are folded into the seed
 * instead; dst.dir is never read.
 *
 * Return: 32-bit hash value.
 */
static u32 cmhi_hash_conntrack_raw(const struct nf_conntrack_tuple *tuple)
{
	const u32 seed = cmhi_seed;
	const unsigned int words =
		(sizeof(tuple->src) + sizeof(tuple->dst.u3)) / sizeof(u32);
	const u32 port_proto =
		((__force __u16)tuple->dst.u.all << 16) | tuple->dst.protonum;

	return jhash2((u32 *)tuple, words, seed ^ port_proto);
}

/*
 * cmhi_nf_ct_key_equal - does hash node @h match @tuple?
 * @h:     tuple-hash node found in the conntrack table
 * @tuple: tuple built from DPI data
 *
 * A match requires nf_ct_tuple_equal() on the tuples AND that the owning
 * conntrack is confirmed: a conntrack can be recreated with an equal
 * tuple, so unconfirmed entries are rejected. Zone and network namespace
 * are not compared.
 *
 * Return: true on match, false otherwise.
 */
static bool cmhi_nf_ct_key_equal(struct nf_conntrack_tuple_hash *h,
				 const struct nf_conntrack_tuple *tuple)
{
	struct nf_conn *owner = nf_ct_tuplehash_to_ctrack(h);

	if (!nf_ct_tuple_equal(tuple, &h->tuple))
		return false;

	return nf_ct_is_confirmed(owner);
}
/********************************************************** * 函数功能:判断此ct是否是expired并设置ct_general * 输入:链接跟踪指针struct nf_conn * * 输出:无 * 返回值: 无 **********************************************************/ static void cmhi_nf_ct_gc_expired(struct nf_conn *ct) { if (!atomic_inc_not_zero(&ct->ct_general.use)) return; if (nf_ct_should_gc(ct)) nf_ct_kill(ct); nf_ct_put(ct); } /* * Warning : * - Caller must take a reference on returned object * and recheck nf_ct_tuple_equal(tuple, &h->tuple) */ static struct nf_conntrack_tuple_hash * cmhi_nf_conntrack_find( const struct nf_conntrack_tuple *tuple, u32 hash) { struct nf_conntrack_tuple_hash *h; struct hlist_nulls_head *ct_hash; struct hlist_nulls_node *n; unsigned int bucket, hsize; begin: nf_conntrack_get_ht(&ct_hash, &hsize); bucket = reciprocal_scale(hash, hsize); //printk(KERN_INFO"[CT_API]###bucket=%d\n", bucket); hlist_nulls_for_each_entry_rcu(h, n, &ct_hash[bucket], hnnode) { struct nf_conn *ct; ct = nf_ct_tuplehash_to_ctrack(h); if (nf_ct_is_expired(ct)) { cmhi_nf_ct_gc_expired(ct); continue; } if (nf_ct_is_dying(ct)) continue; if (cmhi_nf_ct_key_equal(h, tuple)) return h; } /* * if the nulls value we got at the end of this lookup is * not the expected one, we must restart lookup. * We probably met an item that was moved to another chain. 
*/ if (get_nulls_value(n) != bucket) { //printk(KERN_INFO"[CT_API]nulls_value != bucket, bucket=%d\n\n\n\n", bucket); goto begin; } return NULL; } /********************************************************** * 函数功能: 根据DPI信息构造的tuple和hash值查找tuple所在hash * 桶的指针,并做相应判断 * 输入: 构造的tuple:const struct nf_conntrack_tuple *、根据 * 五元组构造的tuple生成的hash值:u32 * 输出: tuple所在的hashmap的首指针 * 返回值: tuple所在的hashmap的首指针 **********************************************************/ static struct nf_conntrack_tuple_hash * cmhi_nf_conntrack_find_get( const struct nf_conntrack_tuple *tuple, u32 hash) { struct nf_conntrack_tuple_hash *h; struct nf_conn *ct; rcu_read_lock(); begin: h = cmhi_nf_conntrack_find(tuple, hash); if (h) { ct = nf_ct_tuplehash_to_ctrack(h); if (unlikely(nf_ct_is_dying(ct) || !atomic_inc_not_zero(&ct->ct_general.use))) h = NULL; else { if (unlikely(!cmhi_nf_ct_key_equal(h, tuple))) { nf_ct_put(ct); goto begin; } } } rcu_read_unlock(); return h; } /********************************************************** * 函数功能: 根据自定义DPI tuple获取链接跟踪指针ct * 输入: 自定义DPI五元组tuple地址: struct dpi_tuple * * 输出: 链接跟踪指针struct nf_conn * * 返回值: 链接跟踪指针struct nf_conn * **********************************************************/ struct nf_conn *get_conntrack_from_tuple(struct dpi_tuple *dpi_tuple) { u32 hash; struct nf_conntrack_tuple tuple = {0}; struct nf_conntrack_tuple_hash *h= NULL; struct nf_conn *ct = NULL; if(!dpi_tuple){ printk(KERN_ERR"[CT_API]get-conntrack-from-tuple: dpi_tuple is null.\n"); return NULL; } tuple.src.l3num = 2; tuple.src.u3.ip = dpi_tuple->sip; tuple.dst.u3.ip = dpi_tuple->dip; tuple.dst.protonum = dpi_tuple->protonum; tuple.dst.dir = IP_CT_DIR_ORIGINAL; /* UDP */ if(dpi_tuple->protonum == IPPROTO_UDP){ tuple.src.u.udp.port = dpi_tuple->sport; tuple.dst.u.udp.port = dpi_tuple->dport; } /* TCP */ if(dpi_tuple->protonum == IPPROTO_TCP){ tuple.src.u.tcp.port = dpi_tuple->sport; tuple.dst.u.tcp.port = dpi_tuple->dport; } printk(KERN_INFO"[CT_API]:src.l3num=%d, src.u3.ip=%x, dst.u3.ip=%x, 
dst.protonum=%d, dst.dir=%d,dst.u.all=%d\n", tuple.src.l3num, tuple.src.u3.ip, tuple.dst.u3.ip, tuple.dst.protonum, tuple.dst.dir, tuple.dst.u.all); hash = cmhi_hash_conntrack_raw(&tuple); printk(KERN_INFO"[CT_API]get-conntrack-from-tuple:hash=%d\n", hash); h = cmhi_nf_conntrack_find_get(&tuple, hash); if (!h) { printk(KERN_ERR"[CT_API]get-conntrack-from-tuple: h is null.\n"); return NULL; } ct = nf_ct_tuplehash_to_ctrack(h); return ct; } /********************************************************** * 函数功能: 将指定应用id aid设置到相应的链接跟踪ct中 * 输入: 链接跟踪表指针struct nf_conn *、应用id aid * 输出: 无 * 返回值: 0或-1 **********************************************************/ int __set_aid_by_dpi_tuple(struct nf_conn *ct, uint16_t aid) { if(!ct){ printk(KERN_ERR"[CT_API]__set-appid-by-dpituple:input ct is null.\n"); return CMHI_EXT_ERR; } if(!ct){ printk(KERN_ERR"[CT_API]__set-appid-by-dpituple: ct is not existed or ct is untracked\n"); return CMHI_EXT_ERR; } ct->cmhi.app_id = aid; return CMHI_EXT_OK; } /********************************************************** * 函数功能: 将指定应用id aid设置到相应的链接跟踪ct中 * 输入: 应用id aid * 输出: 无 * 返回值: 0或-1 **********************************************************/ int set_aid_by_dpi_tuple(struct dpi *dpi) { int ret; struct nf_conn *ct; if(!dpi){ printk(KERN_ERR"[CT_API]set-appid-by-dpituple: dpi is null.\n"); return CMHI_EXT_ERR; } ct = get_conntrack_from_tuple(&dpi->tuple); if(!ct){ printk(KERN_ERR"[CT_API]set-appid-by-dpituple:find ct failed, now return.\n"); return CMHI_EXT_ERR; } printk(KERN_INFO"[CT_API]set-appid-by-dpituple: Find ct !!!\n"); ret = __set_aid_by_dpi_tuple(ct, dpi->aid); if(CMHI_EXT_OK != ret){ printk(KERN_ERR"[CT_API]set-appid-by-dpituple:set aid failed.\n"); return CMHI_EXT_ERR; } printk(KERN_INFO"[CT_API]set-appid-by-dpituple: set aid ok\n"); return CMHI_EXT_OK; } EXPORT_SYMBOL(cmhi_set_conntrack_u16); EXPORT_SYMBOL(cmhi_set_conntrack_u32); EXPORT_SYMBOL(cmhi_set_conntrack_action_by_bit); EXPORT_SYMBOL(cmhi_get_conntrack_u16); 
EXPORT_SYMBOL(cmhi_get_conntrack_u32); EXPORT_SYMBOL(cmhi_del_conntrack); EXPORT_SYMBOL(cmhi_del_conntrack_action_by_bit); EXPORT_SYMBOL(get_conntrack_from_tuple); EXPORT_SYMBOL(set_aid_by_dpi_tuple); static int __init API_init(void) { printk(KERN_INFO"[API]module init.\n"); return 0; } static void __exit API_exit(void) { printk(KERN_INFO"[API]module exit.\n"); return; } module_init(API_init); module_exit(API_exit); MODULE_DESCRIPTION("CMHI API"); MODULE_VERSION("0.1"); MODULE_LICENSE("GPL");