#include "conntrack_api.h"

/************************************************************
*  Purpose: write one of the 16-bit fields of the CMHI extension
*           (ct->cmhi) of the conntrack attached to a packet.
*  Input:   skb   - packet whose conntrack is updated
*           value - value to store
*           type  - field selector: USER_VERSION or APP_ID
*  Output:  none
*  Returns: CMHI_EXT_OK on success; CMHI_EXT_ERR if skb is NULL,
*           nf_ct_get() finds no conntrack, or type does not name a
*           16-bit field.
************************************************************/
int cmhi_set_conntrack_u16(const struct sk_buff *skb, uint16_t value, cmhi_ext_type type)
{
    enum ip_conntrack_info info = 0;
    struct nf_conn *conn;

    if (skb == NULL) {
        printk(KERN_ERR"[CT_API]set_conntrack_u16: skb is null\n");
        return CMHI_EXT_ERR;
    }

    conn = nf_ct_get(skb, &info);
    if (conn == NULL) {
        printk(KERN_ERR"[CT_API]set-conntrak_u16: ct is not existed or ct is untracked\n");
        return CMHI_EXT_ERR;
    }

    switch (type) {
    case USER_VERSION:
        conn->cmhi.user_version = value;
        return CMHI_EXT_OK;
    case APP_ID:
        conn->cmhi.app_id = value;
        return CMHI_EXT_OK;
    default:
        break;
    }

    /* type selects a field that is not 16 bits wide */
    printk(KERN_INFO"[CT_API]set-conntrak_u16: value is not u16\n");
    return CMHI_EXT_ERR;
}

/***********************************************************
*  Purpose: write one of the 32-bit fields of the CMHI extension
*           (ct->cmhi) of the conntrack attached to a packet.
*  Input:   skb   - packet whose conntrack is updated
*           value - value to store
*           type  - field selector: USER_ID, NODE_INDEX,
*                   POLICY_VERSION or ACTION
*  Output:  none
*  Returns: CMHI_EXT_OK on success; CMHI_EXT_ERR if skb is NULL,
*           nf_ct_get() finds no conntrack, or type does not name a
*           32-bit field.
***********************************************************/
int cmhi_set_conntrack_u32(const struct sk_buff *skb, uint32_t value, cmhi_ext_type type)
{
    enum ip_conntrack_info info = 0;
    struct nf_conn *conn;

    if (skb == NULL) {
        printk(KERN_ERR"[CT_API]set-conntrack_u32: skb is null\n");
        return CMHI_EXT_ERR;
    }

    conn = nf_ct_get(skb, &info);
    if (conn == NULL) {
        printk(KERN_ERR"[CT_API]set-conntrak_u32: ct is not existed or ct is untracked\n");
        return CMHI_EXT_ERR;
    }

    switch (type) {
    case USER_ID:
        conn->cmhi.user_id = value;
        return CMHI_EXT_OK;
    case NODE_INDEX:
        conn->cmhi.node_index = value;
        return CMHI_EXT_OK;
    case POLICY_VERSION:
        conn->cmhi.policy_version = value;
        return CMHI_EXT_OK;
    case ACTION:
        conn->cmhi.action = value;
        return CMHI_EXT_OK;
    default:
        break;
    }

    /* type selects a field that is not 32 bits wide */
    printk(KERN_INFO"[CT_API]set-conntrak_u32: value is not u32\n");
    return CMHI_EXT_ERR;
}

/**********************************************************
*  Purpose: set a single bit in the action field of the CMHI
*           extension (ct->cmhi.action) by OR-ing abit into it.
*  Input:   skb  - packet whose conntrack is updated
*           abit - bit to set: CMHI_EXT_PASS or CMHI_EXT_GOTO_DPI
*  Output:  none
*  Returns: CMHI_EXT_OK on success; CMHI_EXT_ERR if skb is NULL,
*           nf_ct_get() finds no conntrack, or abit is not one of the
*           accepted bits.
**********************************************************/
int cmhi_set_conntrack_action_by_bit(const struct sk_buff *skb, cmhi_ext_action_bit abit)
{
    enum ip_conntrack_info info = 0;
    struct nf_conn *conn;

    if (skb == NULL) {
        printk(KERN_ERR"[CT_API]set-conntrack_action_by_bit: skb is null\n");
        return CMHI_EXT_ERR;
    }

    conn = nf_ct_get(skb, &info);
    if (conn == NULL) {
        printk(KERN_ERR"[CT_API]set-conntrak_action_by_bit: ct is not existed or ct is untracked\n");
        return CMHI_EXT_ERR;
    }

    switch (abit) {
    case CMHI_EXT_PASS:
    case CMHI_EXT_GOTO_DPI:
        /* other bits of action are left untouched */
        conn->cmhi.action |= abit;
        return CMHI_EXT_OK;
    default:
        break;
    }

    printk(KERN_INFO"[CT_API]set-conntrak_action_by_bit: wrong action bit\n");
    return CMHI_EXT_ERR;
}


/**********************************************************
*  Purpose: read one of the 16-bit fields of the CMHI extension
*           (ct->cmhi) of the conntrack attached to a packet.
*  Input:   skb   - packet whose conntrack is read
*           value - destination pointer (must not be NULL)
*           type  - field selector: USER_VERSION or APP_ID
*  Output:  *value receives the field; it is written only on success
*  Returns: CMHI_EXT_OK on success; CMHI_EXT_ERR if value or skb is
*           NULL, nf_ct_get() finds no conntrack, or type does not
*           name a 16-bit field.
**********************************************************/
int cmhi_get_conntrack_u16(const struct sk_buff *skb, uint16_t *value, cmhi_ext_type type)
{
    enum ip_conntrack_info info = 0;
    struct nf_conn *conn;

    /* the output pointer is validated before the packet */
    if (value == NULL) {
        printk(KERN_ERR"[CT_API]get-conntrak_u16: value is null\n");
        return CMHI_EXT_ERR;
    }

    if (skb == NULL) {
        printk(KERN_ERR"[CT_API]get-conntrack_u16: skb is null\n");
        return CMHI_EXT_ERR;
    }

    conn = nf_ct_get(skb, &info);
    if (conn == NULL) {
        printk(KERN_ERR"[CT_API]get-conntrak_u16: ct is not existed or ct is untracked\n");
        return CMHI_EXT_ERR;
    }

    switch (type) {
    case USER_VERSION:
        *value = conn->cmhi.user_version;
        return CMHI_EXT_OK;
    case APP_ID:
        *value = conn->cmhi.app_id;
        return CMHI_EXT_OK;
    default:
        break;
    }

    printk(KERN_INFO"[CT_API]get-conntrak_u16: value is not u16\n");
    return CMHI_EXT_ERR;
}

/**********************************************************
*  Purpose: read one of the 32-bit fields of the CMHI extension
*           (ct->cmhi) of the conntrack attached to a packet.
*  Input:   skb   - packet whose conntrack is read
*           value - destination pointer (must not be NULL)
*           type  - field selector: USER_ID, NODE_INDEX,
*                   POLICY_VERSION or ACTION
*  Output:  *value receives the field; it is written only on success
*  Returns: CMHI_EXT_OK on success; CMHI_EXT_ERR if value or skb is
*           NULL, nf_ct_get() finds no conntrack, or type does not
*           name a 32-bit field.
**********************************************************/
int cmhi_get_conntrack_u32(const struct sk_buff *skb, uint32_t *value, cmhi_ext_type type)
{
    enum ip_conntrack_info info = 0;
    struct nf_conn *conn;

    /* the output pointer is validated before the packet */
    if (value == NULL) {
        printk(KERN_ERR"[CT_API]get-conntrak_u32: value is null\n");
        return CMHI_EXT_ERR;
    }

    if (skb == NULL) {
        printk(KERN_ERR"[CT_API]get-conntrack_u32: skb is null\n");
        return CMHI_EXT_ERR;
    }

    conn = nf_ct_get(skb, &info);
    if (conn == NULL) {
        printk(KERN_ERR"[CT_API]get-conntrak_u32: ct is not existed or ct is untracked\n");
        return CMHI_EXT_ERR;
    }

    switch (type) {
    case USER_ID:
        *value = conn->cmhi.user_id;
        return CMHI_EXT_OK;
    case NODE_INDEX:
        *value = conn->cmhi.node_index;
        return CMHI_EXT_OK;
    case POLICY_VERSION:
        *value = conn->cmhi.policy_version;
        return CMHI_EXT_OK;
    case ACTION:
        *value = conn->cmhi.action;
        return CMHI_EXT_OK;
    default:
        break;
    }

    printk(KERN_INFO"[CT_API]get-conntrak_u32: value is not u32\n");
    return CMHI_EXT_ERR;
}


/**********************************************************
*  Purpose: clear (reset to 0) one field of the CMHI extension
*           (ct->cmhi) of the conntrack attached to a packet.
*  Input:   skb  - packet whose conntrack is updated
*           type - field selector: USER_ID, USER_VERSION, NODE_INDEX,
*                  APP_ID, POLICY_VERSION or ACTION
*  Output:  none
*  Returns: CMHI_EXT_OK on success; CMHI_EXT_ERR if skb is NULL,
*           nf_ct_get() finds no conntrack, or type is unknown.
**********************************************************/
int cmhi_del_conntrack(const struct sk_buff *skb, cmhi_ext_type type)
{
    enum ip_conntrack_info info = 0;
    struct nf_conn *conn;

    if (skb == NULL) {
        printk(KERN_ERR"[CT_API]del-conntrack: skb is null\n");
        return CMHI_EXT_ERR;
    }

    conn = nf_ct_get(skb, &info);
    if (conn == NULL) {
        printk(KERN_ERR"[CT_API]del-conntrak: ct is not existed or ct is untracked\n");
        return CMHI_EXT_ERR;
    }

    switch (type) {
    case USER_ID:
        conn->cmhi.user_id = 0;
        return CMHI_EXT_OK;
    case USER_VERSION:
        conn->cmhi.user_version = 0;
        return CMHI_EXT_OK;
    case NODE_INDEX:
        conn->cmhi.node_index = 0;
        return CMHI_EXT_OK;
    case APP_ID:
        conn->cmhi.app_id = 0;
        return CMHI_EXT_OK;
    case POLICY_VERSION:
        conn->cmhi.policy_version = 0;
        return CMHI_EXT_OK;
    case ACTION:
        conn->cmhi.action = 0;
        return CMHI_EXT_OK;
    default:
        break;
    }

    printk(KERN_INFO"[CT_API]del-conntrak: wrong type\n");
    return CMHI_EXT_ERR;
}

/**********************************************************
*  Purpose: clear a single bit in the action field of the CMHI
*           extension (ct->cmhi.action) by AND-ing with ~abit.
*  Input:   skb  - packet whose conntrack is updated
*           abit - bit to clear: CMHI_EXT_PASS or CMHI_EXT_GOTO_DPI
*  Output:  none
*  Returns: CMHI_EXT_OK on success; CMHI_EXT_ERR if skb is NULL,
*           nf_ct_get() finds no conntrack, or abit is not one of the
*           accepted bits.
**********************************************************/
int cmhi_del_conntrack_action_by_bit(const struct sk_buff *skb, cmhi_ext_action_bit abit)
{
    enum ip_conntrack_info info = 0;
    struct nf_conn *conn;

    if (skb == NULL) {
        printk(KERN_ERR"[CT_API]del-conntrack_action_by_bit: skb is null\n");
        return CMHI_EXT_ERR;
    }

    conn = nf_ct_get(skb, &info);
    if (conn == NULL) {
        printk(KERN_ERR"[CT_API]del-conntrack_action_by_bit: ct is not existed or ct is untracked\n");
        return CMHI_EXT_ERR;
    }

    switch (abit) {
    case CMHI_EXT_PASS:
    case CMHI_EXT_GOTO_DPI:
        /* other bits of action are left untouched */
        conn->cmhi.action &= (~abit);
        return CMHI_EXT_OK;
    default:
        break;
    }

    printk(KERN_INFO"[CT_API]del-conntrack_action_by_bit: wrong action bit\n");
    return CMHI_EXT_ERR;
}

/**********************************************************
*  Purpose: compute the raw conntrack hash of a tuple.
*  jhash2() runs over tuple->src and tuple->dst.u3 as u32 words. The
*  seed is cmhi_seed XOR-ed with dst.u.all (shifted left 16) and
*  dst.protonum. dst.dir is not used.
*  NOTE(review): lookups only find entries if this value matches the
*  hash the running kernel used when inserting into its conntrack
*  table, i.e. cmhi_seed must equal the kernel's conntrack hash seed
*  (and any per-netns mix it applies) — confirm where cmhi_seed is
*  defined and initialised.
*  Input:   tuple - tuple to hash
*  Returns: the u32 hash value (the caller reduces it to a bucket
*           index with reciprocal_scale())
**********************************************************/
static u32 cmhi_hash_conntrack_raw(const struct nf_conntrack_tuple *tuple)
{
	unsigned int n;
	u32 seed;
	seed = cmhi_seed;
	/* number of u32 words covering tuple->src and tuple->dst.u3 */
	n = (sizeof(tuple->src) + sizeof(tuple->dst.u3)) / sizeof(u32);
	/* the dst port/id word and protonum are folded into the seed
	 * rather than hashed as data */
	return jhash2((u32 *)tuple, n, seed ^
		      (((__force __u16)tuple->dst.u.all << 16) |
		      tuple->dst.protonum));
}

/*
static 
int cmhi_net_eq(unsigned long ulnet1, unsigned long ulnet2)
{
	return ulnet1 == ulnet2;
}
*/

/**********************************************************
*  Purpose: decide whether conntrack hash entry h matches tuple.
*  A match requires equal tuples AND a confirmed conntrack.
*  NOTE(review): neither the conntrack zone nor the network namespace
*  is compared, so with several zones/netns the first matching
*  entry in the bucket wins — confirm this is acceptable for DPI.
*  Input:   h     - hash-table entry being examined
*           tuple - tuple built from the DPI 5-tuple
*  Returns: true on match, false otherwise
**********************************************************/
static bool
cmhi_nf_ct_key_equal(struct nf_conntrack_tuple_hash *h,
		const struct nf_conntrack_tuple *tuple)
{
	struct nf_conn *ct;

	if (!nf_ct_tuple_equal(tuple, &h->tuple))
		return false;

	/* A conntrack can be recreated with an equal tuple, so only a
	 * confirmed one counts as a match. */
	ct = nf_ct_tuplehash_to_ctrack(h);
	return nf_ct_is_confirmed(ct);
}

/**********************************************************
*  Purpose: opportunistically reap an expired conntrack met during
*  a table lookup.
*  Takes a temporary reference (skipped entirely if the refcount is
*  already zero), kills the entry if nf_ct_should_gc() says so, then
*  drops the temporary reference.
*  Input:   ct - conntrack entry
*  Returns: nothing
**********************************************************/
static void cmhi_nf_ct_gc_expired(struct nf_conn *ct)
{
	/* refcount already zero: entry is being released, leave it alone */
	if (!atomic_inc_not_zero(&ct->ct_general.use))
		return;

	if (nf_ct_should_gc(ct))
		nf_ct_kill(ct);

	/* release the temporary reference taken above */
	nf_ct_put(ct);
}

/*
 * Purpose: lock-free lookup of tuple in the global conntrack hash table.
 *
 * Walks the bucket selected by reciprocal_scale(hash, hsize) with
 * hlist_nulls_for_each_entry_rcu(), so it must run inside an RCU
 * read-side critical section (cmhi_nf_conntrack_find_get() holds
 * rcu_read_lock() around the call).
 * - Expired entries met on the way are passed to cmhi_nf_ct_gc_expired()
 *   and skipped; dying entries are skipped.
 * - Returns the first entry for which cmhi_nf_ct_key_equal() is true,
 *   or NULL if none matched.
 *
 * Warning :
 * - No reference is taken here. Caller must take a reference on the
 *   returned object and recheck nf_ct_tuple_equal(tuple, &h->tuple)
 */
static struct nf_conntrack_tuple_hash *
cmhi_nf_conntrack_find(
		      const struct nf_conntrack_tuple *tuple, u32 hash)
{
	struct nf_conntrack_tuple_hash *h;
	struct hlist_nulls_head *ct_hash;
	struct hlist_nulls_node *n;
	unsigned int bucket, hsize;

begin:
	/* fetch the current table pointer and size (re-fetched on restart) */
	nf_conntrack_get_ht(&ct_hash, &hsize);
	bucket = reciprocal_scale(hash, hsize);
	hlist_nulls_for_each_entry_rcu(h, n, &ct_hash[bucket], hnnode) {
		struct nf_conn *ct;

		ct = nf_ct_tuplehash_to_ctrack(h);
		if (nf_ct_is_expired(ct)) {
			cmhi_nf_ct_gc_expired(ct);
			continue;
		}

		if (nf_ct_is_dying(ct))
			continue;

		if (cmhi_nf_ct_key_equal(h, tuple))
			return h;
	}
	/*
	 * if the nulls value we got at the end of this lookup is
	 * not the expected one, we must restart lookup.
	 * We probably met an item that was moved to another chain.
	 */
	if (get_nulls_value(n) != bucket) {
		goto begin;
	}

	return NULL;
}



/**********************************************************
*  Purpose: look up tuple in the conntrack table and take a reference
*  on the matching conntrack.
*  Calls cmhi_nf_conntrack_find() under rcu_read_lock(). For a found
*  candidate the refcount is incremented unless the entry is dying or
*  its refcount is already zero (then NULL is returned). After the
*  reference is taken the key is re-checked; on mismatch the reference
*  is dropped and the lookup restarts from scratch.
*  Input:   tuple - tuple to look up
*           hash  - raw hash of tuple (from cmhi_hash_conntrack_raw())
*  Returns: the matching tuple-hash entry, with a reference held on its
*           conntrack that the caller must release with nf_ct_put();
*           or NULL if nothing usable was found.
**********************************************************/
static struct nf_conntrack_tuple_hash *
cmhi_nf_conntrack_find_get(
			const struct nf_conntrack_tuple *tuple, u32 hash)
{
	struct nf_conntrack_tuple_hash *h;
	struct nf_conn *ct;

	rcu_read_lock();
begin:
	h = cmhi_nf_conntrack_find(tuple, hash);
	if (h) {
		ct = nf_ct_tuplehash_to_ctrack(h);
		/* try to pin the candidate; fail if dying or refcount is 0 */
		if (unlikely(nf_ct_is_dying(ct) ||
			     !atomic_inc_not_zero(&ct->ct_general.use)))
			h = NULL;
		else {
			/* the entry may have changed before we pinned it:
			 * re-check the key and retry if it no longer matches */
			if (unlikely(!cmhi_nf_ct_key_equal(h, tuple))) {
				nf_ct_put(ct);
				goto begin;
			}
		}
	}
	rcu_read_unlock();

	return h;
}

/**********************************************************
*  函数功能: 根据自定义DPI tuple获取链接跟踪指针ct
*  输入: 自定义DPI五元组tuple地址: struct dpi_tuple *
*  输出: 链接跟踪指针struct nf_conn *
*  返回值: 链接跟踪指针struct nf_conn *
**********************************************************/
struct nf_conn *get_conntrack_from_tuple(struct dpi_tuple *dpi_tuple)
{
	u32 hash;
	struct nf_conntrack_tuple tuple = {0};
	struct nf_conntrack_tuple_hash *h= NULL;
	struct nf_conn *ct = NULL;

	if(!dpi_tuple){
		printk(KERN_ERR"[CT_API]get-conntrack-from-tuple: dpi_tuple is null.\n");
		return NULL;
	}
	
	tuple.src.l3num = 2;
	tuple.src.u3.ip = dpi_tuple->sip;
	tuple.dst.u3.ip = dpi_tuple->dip;
	tuple.dst.protonum = dpi_tuple->protonum;
	tuple.dst.dir = IP_CT_DIR_ORIGINAL;

	/* UDP */
	if(dpi_tuple->protonum == IPPROTO_UDP){
		tuple.src.u.udp.port = dpi_tuple->sport;
		tuple.dst.u.udp.port = dpi_tuple->dport;
	}

	/* TCP */
	if(dpi_tuple->protonum == IPPROTO_TCP){
		tuple.src.u.tcp.port = dpi_tuple->sport;
		tuple.dst.u.tcp.port = dpi_tuple->dport;
	}
	
	printk(KERN_INFO"[CT_API]<L3XXX>:src.l3num=%d, src.u3.ip=%x, dst.u3.ip=%x, dst.protonum=%d, dst.dir=%d,dst.u.all=%d\n",
		tuple.src.l3num, tuple.src.u3.ip, tuple.dst.u3.ip, tuple.dst.protonum, tuple.dst.dir, tuple.dst.u.all);
	hash = cmhi_hash_conntrack_raw(&tuple);
	printk(KERN_INFO"[CT_API]get-conntrack-from-tuple:hash=%d\n", hash);
	h = cmhi_nf_conntrack_find_get(&tuple, hash);
	if (!h) {
		printk(KERN_ERR"[CT_API]get-conntrack-from-tuple: h is null.\n");
		return NULL;
	}
	ct = nf_ct_tuplehash_to_ctrack(h);
	return ct;
}

/**********************************************************
*  Purpose: store application id aid in the CMHI extension of ct
*           (ct->cmhi.app_id).
*  Input:   ct  - conntrack to update (must not be NULL)
*           aid - application id
*  Output:  none
*  Returns: CMHI_EXT_OK on success, CMHI_EXT_ERR if ct is NULL
**********************************************************/
int __set_aid_by_dpi_tuple(struct nf_conn *ct, uint16_t aid)
{
	if(!ct){
		printk(KERN_ERR"[CT_API]__set-appid-by-dpituple:input ct is null.\n");
		return CMHI_EXT_ERR;
	}
	ct->cmhi.app_id = aid;
	return CMHI_EXT_OK;
}

/**********************************************************
*  Purpose: find the conntrack matching dpi->tuple and store the
*           application id dpi->aid in it.
*  Input:   dpi - DPI result (tuple + application id), must not be NULL
*  Output:  none
*  Returns: CMHI_EXT_OK on success, CMHI_EXT_ERR otherwise
**********************************************************/
int set_aid_by_dpi_tuple(struct dpi *dpi)
{
	int ret;
	struct nf_conn *ct;

	if(!dpi){
		printk(KERN_ERR"[CT_API]set-appid-by-dpituple: dpi is null.\n");
		return CMHI_EXT_ERR;
	}
	/* On success ct carries a reference taken in
	 * cmhi_nf_conntrack_find_get(); it is released below. */
	ct = get_conntrack_from_tuple(&dpi->tuple);
	if(!ct){
		printk(KERN_ERR"[CT_API]set-appid-by-dpituple:find ct failed, now return.\n");
		return CMHI_EXT_ERR;
	}
	printk(KERN_INFO"[CT_API]set-appid-by-dpituple: Find ct !!!\n");
	ret = __set_aid_by_dpi_tuple(ct, dpi->aid);
	/* Drop the lookup reference on every path; leaking it would keep the
	 * conntrack alive forever. */
	nf_ct_put(ct);
	if(CMHI_EXT_OK != ret){
		printk(KERN_ERR"[CT_API]set-appid-by-dpituple:set aid failed.\n");
		return CMHI_EXT_ERR;
	}
	printk(KERN_INFO"[CT_API]set-appid-by-dpituple: set aid ok\n");
	return CMHI_EXT_OK;
}

/* Exported API: per-packet (skb) accessors for the CMHI conntrack extension */
EXPORT_SYMBOL(cmhi_get_conntrack_u16);
EXPORT_SYMBOL(cmhi_get_conntrack_u32);
EXPORT_SYMBOL(cmhi_set_conntrack_u16);
EXPORT_SYMBOL(cmhi_set_conntrack_u32);
EXPORT_SYMBOL(cmhi_set_conntrack_action_by_bit);
EXPORT_SYMBOL(cmhi_del_conntrack);
EXPORT_SYMBOL(cmhi_del_conntrack_action_by_bit);

/* Exported API: lookups driven by a DPI 5-tuple */
EXPORT_SYMBOL(get_conntrack_from_tuple);
EXPORT_SYMBOL(set_aid_by_dpi_tuple);
/* Module load hook: only logs; the API needs no runtime setup here. */
static int __init API_init(void)
{
    printk(KERN_INFO"[API]module init.\n");
    return 0;
}

/* Module unload hook: only logs. */
static void __exit API_exit(void)
{
    printk(KERN_INFO"[API]module exit.\n");
}

module_init(API_init);
module_exit(API_exit);
MODULE_DESCRIPTION("CMHI API");
MODULE_VERSION("0.1");
MODULE_LICENSE("GPL");