From fbc3d0e58f143e296526bc0b00b4299b2f91f026 Mon Sep 17 00:00:00 2001 From: Nick Peng Date: Sun, 7 Jan 2024 17:49:43 +0800 Subject: [PATCH] feature: add client-rules option. --- etc/smartdns/smartdns.conf | 4 + src/dns_conf.c | 418 ++++++++++++++++++++++++++++++++- src/dns_conf.h | 45 +++- src/dns_server.c | 70 +++++- test/cases/test-client-rule.cc | 65 +++++ 5 files changed, 587 insertions(+), 15 deletions(-) create mode 100644 test/cases/test-client-rule.cc diff --git a/etc/smartdns/smartdns.conf b/etc/smartdns/smartdns.conf index f0a70faa59..fb48b854e5 100644 --- a/etc/smartdns/smartdns.conf +++ b/etc/smartdns/smartdns.conf @@ -374,3 +374,7 @@ log-level info # bogus-nxdomain ip-set:ip-list # ip-alias ip-set:ip-list 1.2.3.4 # ip-alias ip-set:ip-list ip-set:ip-map-list + +# set client rules +# client-rules ip-cidr [-group [group]] [-no-rule-addr] [-no-rule-nameserver] [-no-rule-ipset] [-no-speed-check] [-no-cache] [-no-rule-soa] [-no-dualstack-selection] +# client-rules option is same as bind option, please see bind option for detail. diff --git a/src/dns_conf.c b/src/dns_conf.c index 3566aad66e..6d0a40341a 100644 --- a/src/dns_conf.c +++ b/src/dns_conf.c @@ -143,8 +143,9 @@ int dns_conf_audit_console; int dns_conf_audit_syslog; /* address rules */ -art_tree dns_conf_domain_rule; +struct dns_conf_domain_rule dns_conf_domain_rule; struct dns_conf_address_rule dns_conf_address_rule; +struct dns_conf_client_rule dns_conf_client_rule; /* dual-stack selection */ int dns_conf_dualstack_ip_selection = 1; @@ -196,6 +197,8 @@ static int _conf_client_subnet(char *subnet, struct dns_edns_client_subnet *ipv4 struct dns_edns_client_subnet *ipv6_ecs); static int _conf_domain_rule_address(char *domain, const char *domain_address); static struct dns_domain_rule *_config_domain_rule_get(const char *domain); +typedef int (*set_rule_add_func)(const char *value, void *priv); +static int _config_ip_rule_set_each(const char *ip_set, set_rule_add_func callback, void *priv); static void *_new_dns_rule_ext(enum domain_rule domain_rule, int ext_size) { @@ -851,11 +854,22 @@ static int _config_domain_iter_free(void *data, const unsigned char *key, uint32 static void _config_domain_destroy(void) { - art_iter(&dns_conf_domain_rule, _config_domain_iter_free, NULL); - art_tree_destroy(&dns_conf_domain_rule); + struct dns_conf_doamin_rule_group *group; + struct hlist_node *tmp = NULL; + unsigned long i = 0; + + hash_for_each_safe(dns_conf_domain_rule.group, i, tmp, group, node) + { + hlist_del_init(&group->node); + art_iter(&group->rule, _config_domain_iter_free, NULL); + art_tree_destroy(&group->rule); + free(group); + } + + art_iter(&dns_conf_domain_rule.default_rule, _config_domain_iter_free, NULL); + art_tree_destroy(&dns_conf_domain_rule.default_rule); } -typedef int (*set_rule_add_func)(const char *value, void *priv); static int _config_set_rule_each_from_list(const char *file, set_rule_add_func callback, void *priv) { FILE *fp = NULL; @@ -1004,7 +1018,7 @@ static __attribute__((unused)) struct dns_domain_rule *_config_domain_rule_get(c return NULL; } - return art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len); + return art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len); } static int _config_domain_rule_add(const char *domain, enum domain_rule type, void *rule) @@ -1036,7 +1050,7 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo } /* Get existing or create domain rule */ - domain_rule = art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len); + domain_rule = art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len); if (domain_rule == NULL) { add_domain_rule = malloc(sizeof(*add_domain_rule)); if (add_domain_rule == NULL) { @@ -1059,7 +1073,8 @@ static int _config_domain_rule_add(const char *domain, enum domain_rule type, vo /* update domain rule */ if (add_domain_rule) { - old_domain_rule = art_insert(&dns_conf_domain_rule, (unsigned char *)domain_key, len, add_domain_rule); + old_domain_rule = + art_insert(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len, add_domain_rule); if (old_domain_rule) { _config_domain_rule_free(old_domain_rule); } @@ -1097,7 +1112,7 @@ static int _config_domain_rule_delete(const char *domain) } /* delete existing rules */ - void *rule = art_delete(&dns_conf_domain_rule, (unsigned char *)domain_key, len); + void *rule = art_delete(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len); if (rule) { _config_domain_rule_free(rule); } @@ -1140,7 +1155,7 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u } /* Get existing or create domain rule */ - domain_rule = art_search(&dns_conf_domain_rule, (unsigned char *)domain_key, len); + domain_rule = art_search(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len); if (domain_rule == NULL) { add_domain_rule = malloc(sizeof(*add_domain_rule)); if (add_domain_rule == NULL) { @@ -1170,7 +1185,8 @@ static int _config_domain_rule_flag_set(const char *domain, unsigned int flag, u /* update domain rule */ if (add_domain_rule) { - old_domain_rule = art_insert(&dns_conf_domain_rule, (unsigned char *)domain_key, len, add_domain_rule); + old_domain_rule = + art_insert(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, len, add_domain_rule); if (old_domain_rule) { _config_domain_rule_free(old_domain_rule); } @@ -2805,6 +2821,224 @@ static void _dns_ip_rule_put(struct dns_ip_rule *rule) } } +static radix_node_t *_create_client_rules_node(const char *addr) +{ + radix_node_t *node = NULL; + void *p = NULL; + prefix_t prefix; + const char *errmsg = NULL; + + p = prefix_pton(addr, -1, &prefix, &errmsg); + if (p == NULL) { + return NULL; + } + + node = radix_lookup(dns_conf_client_rule.rule, &prefix); + return node; +} + +static void *_new_dns_client_rule_ext(enum client_rule client_rule, int ext_size) +{ + struct dns_client_rule *rule; + int size = 0; + + if (client_rule >= CLIENT_RULE_MAX) { + return NULL; + } + + switch (client_rule) { + case CLIENT_RULE_FLAGS: + size = sizeof(struct client_rule_flags); + break; + case CLIENT_RULE_GROUP: + size = sizeof(struct client_rule_group); + break; + default: + return NULL; + } + + size += ext_size; + rule = malloc(size); + if (!rule) { + return NULL; + } + memset(rule, 0, size); + rule->rule = client_rule; + atomic_set(&rule->refcnt, 1); + return rule; +} + +static void *_new_dns_client_rule(enum client_rule client_rule) +{ + return _new_dns_client_rule_ext(client_rule, 0); +} + +static void _dns_client_rule_get(struct dns_client_rule *rule) +{ + atomic_inc(&rule->refcnt); +} + +static void _dns_client_rule_put(struct dns_client_rule *rule) +{ + int refcount = atomic_dec_return(&rule->refcnt); + if (refcount > 0) { + return; + } + + free(rule); +} + +static int _config_client_rules_free(struct dns_client_rules *client_rules) +{ + int i = 0; + + if (client_rules == NULL) { + return 0; + } + + for (i = 0; i < CLIENT_RULE_MAX; i++) { + if (client_rules->rules[i] == NULL) { + continue; + } + + _dns_client_rule_put(client_rules->rules[i]); + client_rules->rules[i] = NULL; + } + + free(client_rules); + return 0; +} + +static int _config_client_rule_flag_set(const char *ip_cidr, unsigned int flag, unsigned int is_clear); +static int _config_client_rule_flag_callback(const char *ip_cidr, void *priv) +{ + struct dns_set_rule_flags_callback_args *args = (struct dns_set_rule_flags_callback_args *)priv; + return _config_client_rule_flag_set(ip_cidr, args->flags, args->is_clear_flag); +} + +static int _config_client_rule_flag_set(const char *ip_cidr, unsigned int flag, unsigned int is_clear) +{ + struct dns_client_rules *client_rules = NULL; + struct dns_client_rules *add_client_rules = NULL; + struct client_rule_flags *client_rule_flags = NULL; + radix_node_t *node = NULL; + + if (strncmp(ip_cidr, "ip-set:", sizeof("ip-set:") - 1) == 0) { + struct dns_set_rule_flags_callback_args args; + args.flags = flag; + args.is_clear_flag = is_clear; + return _config_ip_rule_set_each(ip_cidr + sizeof("ip-set:") - 1, _config_client_rule_flag_callback, &args); + } + + /* Get existing or create domain rule */ + node = _create_client_rules_node(ip_cidr); + if (node == NULL) { + tlog(TLOG_ERROR, "create addr node failed."); + goto errout; + } + + client_rules = node->data; + if (client_rules == NULL) { + add_client_rules = malloc(sizeof(*add_client_rules)); + if (add_client_rules == NULL) { + goto errout; + } + memset(add_client_rules, 0, sizeof(*add_client_rules)); + client_rules = add_client_rules; + node->data = client_rules; + } + + /* add new rule to domain */ + if (client_rules->rules[CLIENT_RULE_FLAGS] == NULL) { + client_rule_flags = _new_dns_client_rule(CLIENT_RULE_FLAGS); + client_rule_flags->flags = 0; + client_rules->rules[CLIENT_RULE_FLAGS] = &client_rule_flags->head; + } + + client_rule_flags = container_of(client_rules->rules[CLIENT_RULE_FLAGS], struct client_rule_flags, head); + if (is_clear == false) { + client_rule_flags->flags |= flag; + } else { + client_rule_flags->flags &= ~flag; + } + client_rule_flags->is_flag_set |= flag; + + return 0; +errout: + if (add_client_rules) { + free(add_client_rules); + } + + tlog(TLOG_ERROR, "set ip %s flags failed", ip_cidr); + + return 0; +} + +static int _config_client_rule_add(const char *ip_cidr, enum client_rule type, void *rule); +static int _config_client_rule_add_callback(const char *ip_cidr, void *priv) +{ + struct dns_set_rule_add_callback_args *args = (struct dns_set_rule_add_callback_args *)priv; + return _config_client_rule_add(ip_cidr, args->type, args->rule); +} + +static int _config_client_rule_add(const char *ip_cidr, enum client_rule type, void *rule) +{ + struct dns_client_rules *client_rules = NULL; + struct dns_client_rules *add_client_rules = NULL; + radix_node_t *node = NULL; + + if (ip_cidr == NULL) { + goto errout; + } + + if (type >= CLIENT_RULE_MAX) { + goto errout; + } + + if (strncmp(ip_cidr, "ip-set:", sizeof("ip-set:") - 1) == 0) { + struct dns_set_rule_add_callback_args args; + args.type = type; + args.rule = rule; + return _config_ip_rule_set_each(ip_cidr + sizeof("ip-set:") - 1, _config_client_rule_add_callback, &args); + } + + /* Get existing or create domain rule */ + node = _create_client_rules_node(ip_cidr); + if (node == NULL) { + tlog(TLOG_ERROR, "create addr node failed."); + goto errout; + } + + client_rules = node->data; + if (client_rules == NULL) { + add_client_rules = malloc(sizeof(*add_client_rules)); + if (add_client_rules == NULL) { + goto errout; + } + memset(add_client_rules, 0, sizeof(*add_client_rules)); + client_rules = add_client_rules; + node->data = client_rules; + } + + /* add new rule to domain */ + if (client_rules->rules[type]) { + _dns_client_rule_put(client_rules->rules[type]); + client_rules->rules[type] = NULL; + } + + client_rules->rules[type] = rule; + _dns_client_rule_get(rule); + + return 0; +errout: + if (add_client_rules) { + free(add_client_rules); + } + + tlog(TLOG_ERROR, "add client %s rule failed", ip_cidr); + return -1; +} + static int _config_qtype_soa(void *data, int argc, char *argv[]) { int i = 0; @@ -3638,6 +3872,22 @@ static void _config_ip_iter_free(radix_node_t *node, void *cbctx) node->data = NULL; } +static void _config_client_rule_iter_free_cb(radix_node_t *node, void *cbctx) +{ + struct dns_client_rules *client_rules = NULL; + if (node == NULL) { + return; + } + + if (node->data == NULL) { + return; + } + + client_rules = node->data; + _config_client_rules_free(client_rules); + node->data = NULL; +} + static void _config_ip_set_name_table_destroy(void) { struct dns_ip_set_name_list *set_name_list = NULL; @@ -4408,6 +4658,146 @@ static void _config_host_table_destroy(int only_dynamic) dns_hosts_record_num = 0; } +static int _config_client_rule_group_add(const char *client, const char *group_name) +{ + struct client_rule_group *client_rule = NULL; + const char *group = NULL; + + client_rule = _new_dns_client_rule(CLIENT_RULE_GROUP); + if (client_rule == NULL) { + goto errout; + } + + group = _dns_conf_get_group_name(group_name); + if (group == NULL) { + goto errout; + } + + client_rule->group_name = group; + if (_config_client_rule_add(client, CLIENT_RULE_GROUP, client_rule) != 0) { + goto errout; + } + + _dns_client_rule_put(&client_rule->head); + + return 0; +errout: + if (client_rule != NULL) { + _dns_client_rule_put(&client_rule->head); + } + return -1; +} + +static int _config_client_rules(void *data, int argc, char *argv[]) +{ + int opt = 0; + const char *client = argv[1]; + unsigned int server_flag = 0; + + /* clang-format off */ + static struct option long_options[] = { + {"group", required_argument, NULL, 'g'}, + {"no-rule-addr", no_argument, NULL, 'A'}, + {"no-rule-nameserver", no_argument, NULL, 'N'}, + {"no-rule-ipset", no_argument, NULL, 'I'}, + {"no-rule-sni-proxy", no_argument, NULL, 'P'}, + {"no-rule-soa", no_argument, NULL, 'O'}, + {"no-speed-check", no_argument, NULL, 'S'}, + {"no-cache", no_argument, NULL, 'C'}, + {"no-dualstack-selection", no_argument, NULL, 'D'}, + {"no-ip-alias", no_argument, NULL, 'a'}, + {"force-aaaa-soa", no_argument, NULL, 'F'}, + {"no-serve-expired", no_argument, NULL, 253}, + {"force-https-soa", no_argument, NULL, 254}, + {NULL, no_argument, NULL, 0} + }; + /* clang-format on */ + + if (argc <= 1) { + tlog(TLOG_ERROR, "invalid parameter."); + goto errout; + } + + /* process extra options */ + optind = 1; + while (1) { + opt = getopt_long_only(argc, argv, "g:", long_options, NULL); + if (opt == -1) { + break; + } + + switch (opt) { + case 'g': { + const char *group = optarg; + if (_config_client_rule_group_add(client, group) != 0) { + tlog(TLOG_ERROR, "add group rule failed."); + goto errout; + } + break; + } + case 'A': { + server_flag |= BIND_FLAG_NO_RULE_ADDR; + break; + } + case 'a': { + server_flag |= BIND_FLAG_NO_IP_ALIAS; + break; + } + case 'N': { + server_flag |= BIND_FLAG_NO_RULE_NAMESERVER; + break; + } + case 'I': { + server_flag |= BIND_FLAG_NO_RULE_IPSET; + break; + } + case 'P': { + server_flag |= BIND_FLAG_NO_RULE_SNIPROXY; + break; + } + case 'S': { + server_flag |= BIND_FLAG_NO_SPEED_CHECK; + break; + } + case 'C': { + server_flag |= BIND_FLAG_NO_CACHE; + break; + } + case 'O': { + server_flag |= BIND_FLAG_NO_RULE_SOA; + break; + } + case 'D': { + server_flag |= BIND_FLAG_NO_DUALSTACK_SELECTION; + break; + } + case 'F': { + server_flag |= BIND_FLAG_FORCE_AAAA_SOA; + break; + } + case 253: { + server_flag |= BIND_FLAG_NO_SERVE_EXPIRED; + break; + } + case 254: { + server_flag |= BIND_FLAG_FORCE_HTTPS_SOA; + break; + } + } + } + + if (server_flag != 0) { + if (_config_client_rule_flag_set(client, server_flag, 0) != 0) { + tlog(TLOG_ERROR, "set client rule flags failed."); + goto errout; + } + } + + return 0; +errout: + return -1; +} + int dns_server_check_update_hosts(void) { struct stat statbuf; @@ -4605,6 +4995,7 @@ static struct config_item _config_item[] = { CONF_CUSTOM("ddns-domain", _conf_ddns_domain, NULL), CONF_CUSTOM("dnsmasq-lease-file", _conf_dhcp_lease_dnsmasq_file, NULL), CONF_CUSTOM("hosts-file", _conf_hosts_file, NULL), + CONF_CUSTOM("client-rules", _config_client_rules, NULL), CONF_STRING("ca-file", (char *)&dns_conf_ca_file, DNS_MAX_PATH), CONF_STRING("ca-path", (char *)&dns_conf_ca_path, DNS_MAX_PATH), CONF_STRING("user", (char *)&dns_conf_user, sizeof(dns_conf_user)), @@ -4732,12 +5123,14 @@ static int _dns_server_load_conf_init(void) { dns_conf_address_rule.ipv4 = New_Radix(); dns_conf_address_rule.ipv6 = New_Radix(); - if (dns_conf_address_rule.ipv4 == NULL || dns_conf_address_rule.ipv6 == NULL) { + dns_conf_client_rule.rule = New_Radix(); + if (dns_conf_address_rule.ipv4 == NULL || dns_conf_address_rule.ipv6 == NULL || dns_conf_client_rule.rule == NULL) { tlog(TLOG_WARN, "init radix tree failed."); return -1; } - art_tree_init(&dns_conf_domain_rule); + art_tree_init(&dns_conf_domain_rule.default_rule); + hash_init(dns_conf_domain_rule.group); hash_init(dns_ipset_table.ipset); hash_init(dns_nftset_table.nftset); @@ -4790,6 +5183,7 @@ static void _config_ip_rules_destroy(void) { Destroy_Radix(dns_conf_address_rule.ipv4, _config_ip_iter_free, NULL); Destroy_Radix(dns_conf_address_rule.ipv6, _config_ip_iter_free, NULL); + Destroy_Radix(dns_conf_client_rule.rule, _config_client_rule_iter_free_cb, NULL); } void dns_server_load_exit(void) diff --git a/src/dns_conf.h b/src/dns_conf.h index 0700986cf7..44b6217ddd 100644 --- a/src/dns_conf.h +++ b/src/dns_conf.h @@ -94,6 +94,12 @@ enum ip_rule { IP_RULE_MAX, }; +enum client_rule { + CLIENT_RULE_FLAGS = 0, + CLIENT_RULE_GROUP, + CLIENT_RULE_MAX, +}; + typedef enum { DNS_BIND_TYPE_UDP, DNS_BIND_TYPE_TCP, @@ -240,7 +246,6 @@ extern struct dns_nftset_names dns_conf_nftset_no_speed; extern struct dns_nftset_names dns_conf_nftset; struct dns_domain_rule { - struct dns_rule head; unsigned char sub_rule_only : 1; unsigned char root_rule_only : 1; struct dns_rule *rules[DOMAIN_RULE_MAX]; @@ -273,6 +278,17 @@ struct dns_response_mode_rule { enum response_mode_type mode; }; +struct dns_conf_doamin_rule_group { + struct hlist_node node; + art_tree rule; + const char *group_name; +}; + +struct dns_conf_domain_rule { + art_tree default_rule; + DECLARE_HASHTABLE(group, 8); +}; + struct dns_group_table { DECLARE_HASHTABLE(group, 8); }; @@ -393,6 +409,30 @@ struct dns_conf_address_rule { radix_tree_t *ipv6; }; +struct dns_client_rule { + atomic_t refcnt; + enum client_rule rule; +}; + +struct client_rule_flags { + struct dns_client_rule head; + unsigned int flags; + unsigned int is_flag_set; +}; + +struct client_rule_group { + struct dns_client_rule head; + const char *group_name; +}; + +struct dns_client_rules { + struct dns_client_rule *rules[CLIENT_RULE_MAX]; +}; + +struct dns_conf_client_rule { + radix_tree_t *rule; +}; + struct nftset_ipset_rules { struct dns_ipset_rule *ipset; struct dns_ipset_rule *ipset_ip; @@ -572,8 +612,9 @@ extern int dns_conf_audit_console; extern int dns_conf_audit_syslog; extern char dns_conf_server_name[DNS_MAX_SERVER_NAME_LEN]; -extern art_tree dns_conf_domain_rule; +extern struct dns_conf_domain_rule dns_conf_domain_rule; extern struct dns_conf_address_rule dns_conf_address_rule; +extern struct dns_conf_client_rule dns_conf_client_rule; extern int dns_conf_dualstack_ip_selection; extern int dns_conf_dualstack_ip_allow_force_AAAA; diff --git a/src/dns_server.c b/src/dns_server.c index 4571428842..030cc5bb10 100644 --- a/src/dns_server.c +++ b/src/dns_server.c @@ -3027,6 +3027,48 @@ static int _dns_server_check_speed(struct dns_request *request, char *ip) return -1; } +static struct dns_client_rules *_dns_server_get_client_rules(struct sockaddr_storage *addr, socklen_t addr_len) +{ + prefix_t prefix; + radix_node_t *node = NULL; + uint8_t *netaddr = NULL; + int netaddr_len = 0; + + switch (addr->ss_family) { + case AF_INET: { + struct sockaddr_in *addr_in = NULL; + addr_in = (struct sockaddr_in *)addr; + netaddr = (unsigned char *)&(addr_in->sin_addr.s_addr); + netaddr_len = 4; + } break; + case AF_INET6: { + struct sockaddr_in6 *addr_in6 = NULL; + addr_in6 = (struct sockaddr_in6 *)addr; + if (IN6_IS_ADDR_V4MAPPED(&addr_in6->sin6_addr)) { + netaddr = addr_in6->sin6_addr.s6_addr + 12; + netaddr_len = 4; + } else { + netaddr = addr_in6->sin6_addr.s6_addr; + netaddr_len = 16; + } + } break; + default: + return NULL; + break; + } + + if (prefix_from_blob(netaddr, netaddr_len, netaddr_len * 8, &prefix) == NULL) { + return NULL; + } + + node = radix_search_best(dns_conf_client_rule.rule, &prefix); + if (node == NULL) { + return NULL; + } + + return node->data; +} + static struct dns_ip_rules *_dns_server_ip_rule_get(struct dns_request *request, unsigned char *addr, int addr_len, dns_type_t addr_type) { @@ -4609,7 +4651,7 @@ static void _dns_server_get_domain_rule_by_domain(struct dns_request *request, c domain_key[domain_len] = 0; /* find domain rule */ - art_substring_walk(&dns_conf_domain_rule, (unsigned char *)domain_key, domain_len, _dns_server_get_rules, + art_substring_walk(&dns_conf_domain_rule.default_rule, (unsigned char *)domain_key, domain_len, _dns_server_get_rules, &walk_args); if (likely(dns_conf_log_level > TLOG_DEBUG)) { return; @@ -5543,6 +5585,29 @@ static void _dns_server_request_set_client(struct dns_request *request, struct d _dns_server_conn_get(conn); } +static void _dns_server_request_set_client_rules(struct dns_request *request, struct dns_client_rules *client_rule) +{ + if (client_rule == NULL) { + return; + } + + tlog(TLOG_DEBUG, "match client rule.\n"); + + if (client_rule->rules[CLIENT_RULE_GROUP]) { + struct client_rule_group *group = (struct client_rule_group *)client_rule->rules[CLIENT_RULE_GROUP]; + if (group && group->group_name[0] != '\0') { + safe_strncpy(request->dns_group_name, group->group_name, sizeof(request->dns_group_name)); + } + } + + if (client_rule->rules[CLIENT_RULE_FLAGS]) { + struct client_rule_flags *flags = (struct client_rule_flags *)client_rule->rules[CLIENT_RULE_FLAGS]; + if (flags) { + request->server_flags = flags->flags; + } + } +} + static void _dns_server_request_set_id(struct dns_request *request, unsigned short id) { request->id = id; @@ -6096,6 +6161,7 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in char name[DNS_MAX_CNAME_LEN]; struct dns_packet *packet = (struct dns_packet *)packet_buff; struct dns_request *request = NULL; + struct dns_client_rules *client_rules = NULL; /* decode packet */ tlog(TLOG_DEBUG, "recv query packet from %s, len = %d, type = %d", @@ -6116,6 +6182,7 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in packet->head.qdcount, packet->head.ancount, packet->head.nscount, packet->head.nrcount, inpacket_len, packet->head.id, packet->head.tc, packet->head.rd, packet->head.ra, packet->head.rcode); + client_rules = _dns_server_get_client_rules(from, from_len); request = _dns_server_new_request(); if (request == NULL) { tlog(TLOG_ERROR, "malloc failed.\n"); @@ -6124,6 +6191,7 @@ static int _dns_server_recv(struct dns_server_conn_head *conn, unsigned char *in memcpy(&request->localaddr, local, local_len); _dns_server_request_set_client(request, conn); + _dns_server_request_set_client_rules(request, client_rules); _dns_server_request_set_client_addr(request, from, from_len); _dns_server_request_set_id(request, packet->head.id); diff --git a/test/cases/test-client-rule.cc b/test/cases/test-client-rule.cc new file mode 100644 index 0000000000..15a265cc7e --- /dev/null +++ b/test/cases/test-client-rule.cc @@ -0,0 +1,65 @@ +/************************************************************************* + * + * Copyright (C) 2018-2023 Ruilin Peng (Nick) . + * + * smartdns is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * smartdns is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +#include "client.h" +#include "dns.h" +#include "include/utils.h" +#include "server.h" +#include "gtest/gtest.h" + +class ClientRule : public ::testing::Test +{ + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(ClientRule, bogus_nxdomain) +{ + smartdns::MockServer server_upstream; + smartdns::MockServer server_upstream2; + smartdns::Server server; + + server_upstream.Start("udp://0.0.0.0:61053", [](struct smartdns::ServerRequestContext *request) { + if (request->qtype != DNS_T_A) { + return smartdns::SERVER_REQUEST_SOA; + } + + smartdns::MockServer::AddIP(request, request->domain.c_str(), "1.2.3.4", 611); + return smartdns::SERVER_REQUEST_OK; + }); + + server_upstream2.Start("udp://0.0.0.0:62053", + [](struct smartdns::ServerRequestContext *request) { return smartdns::SERVER_REQUEST_SOA; }); + + /* this ip will be discard, but is reachable */ + server.MockPing(PING_TYPE_ICMP, "1.2.3.4", 60, 10); + + server.Start(R"""(bind [::]:60053 +server udp://127.0.0.1:61053 -g client -e +server udp://127.0.0.1:62053 +client-rules 127.0.0.1 -g client +)"""); + smartdns::Client client; + ASSERT_TRUE(client.Query("b.com", 60053)); + std::cout << client.GetResult() << std::endl; + ASSERT_EQ(client.GetAnswerNum(), 1); + EXPECT_EQ(client.GetStatus(), "NOERROR"); + EXPECT_EQ(client.GetAnswer()[0].GetName(), "b.com"); + EXPECT_EQ(client.GetAnswer()[0].GetData(), "1.2.3.4"); +}