diff options
Diffstat (limited to 'net')
102 files changed, 8419 insertions, 348 deletions
diff --git a/net/Kconfig b/net/Kconfig index 129b9fcbf1d0..f73e64ea067f 100644 --- a/net/Kconfig +++ b/net/Kconfig @@ -86,6 +86,12 @@ source "net/netlabel/Kconfig" endif # if INET +config ANDROID_PARANOID_NETWORK + bool "Only allow certain groups to create sockets" + default y + help + none + config NETWORK_SECMARK bool "Security Marking" help diff --git a/net/bluetooth/af_bluetooth.c b/net/bluetooth/af_bluetooth.c index 70306cc9d814..709ce9fb15f3 100644 --- a/net/bluetooth/af_bluetooth.c +++ b/net/bluetooth/af_bluetooth.c @@ -106,11 +106,40 @@ void bt_sock_unregister(int proto) } EXPORT_SYMBOL(bt_sock_unregister); +#ifdef CONFIG_PARANOID_NETWORK +static inline int current_has_bt_admin(void) +{ + return !current_euid(); +} + +static inline int current_has_bt(void) +{ + return current_has_bt_admin(); +} +# else +static inline int current_has_bt_admin(void) +{ + return 1; +} + +static inline int current_has_bt(void) +{ + return 1; +} +#endif + static int bt_sock_create(struct net *net, struct socket *sock, int proto, int kern) { int err; + if (proto == BTPROTO_RFCOMM || proto == BTPROTO_SCO || + proto == BTPROTO_L2CAP) { + if (!current_has_bt()) + return -EPERM; + } else if (!current_has_bt_admin()) + return -EPERM; + if (net != &init_net) return -EAFNOSUPPORT; diff --git a/net/bridge/br_device.c b/net/bridge/br_device.c index 06e88c6bb511..49c29fa7fd30 100644 --- a/net/bridge/br_device.c +++ b/net/bridge/br_device.c @@ -48,16 +48,17 @@ netdev_tx_t br_dev_xmit(struct sk_buff *skb, struct net_device *dev) return NETDEV_TX_OK; } - u64_stats_update_begin(&brstats->syncp); - brstats->tx_packets++; - brstats->tx_bytes += skb->len; - u64_stats_update_end(&brstats->syncp); - BR_INPUT_SKB_CB(skb)->brdev = dev; skb_reset_mac_header(skb); skb_pull(skb, ETH_HLEN); + u64_stats_update_begin(&brstats->syncp); + brstats->tx_packets++; + /* Exclude ETH_HLEN from byte stats for consistency with Rx chain */ + brstats->tx_bytes += skb->len; + u64_stats_update_end(&brstats->syncp); + if (!br_allowed_ingress(br, br_vlan_group_rcu(br), skb, &vid)) goto out; diff --git a/net/bridge/netfilter/ebt_limit.c b/net/bridge/netfilter/ebt_limit.c index 517e78befcb2..61a9f1be1263 100644 --- a/net/bridge/netfilter/ebt_limit.c +++ b/net/bridge/netfilter/ebt_limit.c @@ -105,6 +105,7 @@ static struct xt_match ebt_limit_mt_reg __read_mostly = { .match = ebt_limit_mt, .checkentry = ebt_limit_mt_check, .matchsize = sizeof(struct ebt_limit_info), + .usersize = offsetof(struct ebt_limit_info, prev), #ifdef CONFIG_COMPAT .compatsize = sizeof(struct ebt_compat_limit_info), #endif diff --git a/net/core/dev.c b/net/core/dev.c index 9a0c726bd124..be60554db0af 100644 --- a/net/core/dev.c +++ b/net/core/dev.c @@ -184,7 +184,7 @@ EXPORT_SYMBOL(dev_base_lock); static DEFINE_SPINLOCK(napi_hash_lock); static unsigned int napi_gen_id = NR_CPUS; -static DEFINE_HASHTABLE(napi_hash, 8); +static DEFINE_READ_MOSTLY_HASHTABLE(napi_hash, 8); static DECLARE_RWSEM(devnet_rename_sem); diff --git a/net/core/fib_rules.c b/net/core/fib_rules.c index b9cbab73d0de..dcd40a44b93d 100644 --- a/net/core/fib_rules.c +++ b/net/core/fib_rules.c @@ -18,6 +18,11 @@ #include <net/fib_rules.h> #include <net/ip_tunnels.h> +static const struct fib_kuid_range fib_kuid_range_unset = { + KUIDT_INIT(0), + KUIDT_INIT(~0), +}; + int fib_default_rule_add(struct fib_rules_ops *ops, u32 pref, u32 table, u32 flags) { @@ -33,6 +38,7 @@ int fib_default_rule_add(struct fib_rules_ops *ops, r->table = table; r->flags = flags; r->fr_net = ops->fro_net; + r->uid_range = fib_kuid_range_unset; r->suppress_prefixlen = -1; r->suppress_ifgroup = -1; @@ -172,6 +178,34 @@ void fib_rules_unregister(struct fib_rules_ops *ops) } EXPORT_SYMBOL_GPL(fib_rules_unregister); +static int uid_range_set(struct fib_kuid_range *range) +{ + return uid_valid(range->start) && uid_valid(range->end); +} + +static struct fib_kuid_range nla_get_kuid_range(struct nlattr **tb) +{ + struct fib_rule_uid_range *in; + struct fib_kuid_range out; + + in = (struct fib_rule_uid_range *)nla_data(tb[FRA_UID_RANGE]); + + out.start = make_kuid(current_user_ns(), in->start); + out.end = make_kuid(current_user_ns(), in->end); + + return out; +} + +static int nla_put_uid_range(struct sk_buff *skb, struct fib_kuid_range *range) +{ + struct fib_rule_uid_range out = { + from_kuid_munged(current_user_ns(), range->start), + from_kuid_munged(current_user_ns(), range->end) + }; + + return nla_put(skb, FRA_UID_RANGE, sizeof(out), &out); +} + static int fib_rule_match(struct fib_rule *rule, struct fib_rules_ops *ops, struct flowi *fl, int flags) { @@ -189,6 +223,10 @@ static int fib_rule_match(struct fib_rule *rule, struct fib_rules_ops *ops, if (rule->tun_id && (rule->tun_id != fl->flowi_tun_key.tun_id)) goto out; + if (uid_lt(fl->flowi_uid, rule->uid_range.start) || + uid_gt(fl->flowi_uid, rule->uid_range.end)) + goto out; + ret = ops->match(rule, fl, flags); out: return (rule->flags & FIB_RULE_INVERT) ? !ret : ret; @@ -371,6 +409,21 @@ static int fib_nl_newrule(struct sk_buff *skb, struct nlmsghdr* nlh) } else if (rule->action == FR_ACT_GOTO) goto errout_free; + if (tb[FRA_UID_RANGE]) { + if (current_user_ns() != net->user_ns) { + err = -EPERM; + goto errout_free; + } + + rule->uid_range = nla_get_kuid_range(tb); + + if (!uid_range_set(&rule->uid_range) || + !uid_lte(rule->uid_range.start, rule->uid_range.end)) + goto errout_free; + } else { + rule->uid_range = fib_kuid_range_unset; + } + err = ops->configure(rule, skb, frh, tb); if (err < 0) goto errout_free; @@ -432,6 +485,7 @@ static int fib_nl_delrule(struct sk_buff *skb, struct nlmsghdr* nlh) struct fib_rules_ops *ops = NULL; struct fib_rule *rule, *tmp; struct nlattr *tb[FRA_MAX+1]; + struct fib_kuid_range range; int err = -EINVAL; if (nlh->nlmsg_len < nlmsg_msg_size(sizeof(*frh))) @@ -451,6 +505,14 @@ static int fib_nl_delrule(struct sk_buff *skb, struct nlmsghdr* nlh) if (err < 0) goto errout; + if (tb[FRA_UID_RANGE]) { + range = nla_get_kuid_range(tb); + if (!uid_range_set(&range)) + goto errout; + } else { + range = fib_kuid_range_unset; + } + list_for_each_entry(rule, &ops->rules_list, list) { if (frh->action && (frh->action != rule->action)) continue; @@ -483,6 +545,11 @@ static int fib_nl_delrule(struct sk_buff *skb, struct nlmsghdr* nlh) (rule->tun_id != nla_get_be64(tb[FRA_TUN_ID]))) continue; + if (uid_range_set(&range) && + (!uid_eq(rule->uid_range.start, range.start) || + !uid_eq(rule->uid_range.end, range.end))) + continue; + if (!ops->compare(rule, frh, tb)) continue; @@ -550,6 +617,7 @@ static inline size_t fib_rule_nlmsg_size(struct fib_rules_ops *ops, + nla_total_size(4) /* FRA_FWMARK */ + nla_total_size(4) /* FRA_FWMASK */ + nla_total_size(8); /* FRA_TUN_ID */ + + nla_total_size(sizeof(struct fib_kuid_range)); if (ops->nlmsg_payload) payload += ops->nlmsg_payload(rule); @@ -607,7 +675,9 @@ static int fib_nl_fill_rule(struct sk_buff *skb, struct fib_rule *rule, (rule->target && nla_put_u32(skb, FRA_GOTO, rule->target)) || (rule->tun_id && - nla_put_be64(skb, FRA_TUN_ID, rule->tun_id))) + nla_put_be64(skb, FRA_TUN_ID, rule->tun_id)) || + (uid_range_set(&rule->uid_range) && + nla_put_uid_range(skb, &rule->uid_range))) goto nla_put_failure; if (rule->suppress_ifgroup != -1) { diff --git a/net/core/sock.c b/net/core/sock.c index 5e9ff8d9f9e3..2b938d093bf1 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -2426,8 +2426,11 @@ void sock_init_data(struct socket *sock, struct sock *sk) sk->sk_type = sock->type; sk->sk_wq = sock->wq; sock->sk = sk; - } else + sk->sk_uid = SOCK_INODE(sock)->i_uid; + } else { sk->sk_wq = NULL; + sk->sk_uid = make_kuid(sock_net(sk)->user_ns, 0); + } rwlock_init(&sk->sk_callback_lock); lockdep_set_class_and_name(&sk->sk_callback_lock, diff --git a/net/core/sock_diag.c b/net/core/sock_diag.c index a47f693f9f14..9653798da293 100644 --- a/net/core/sock_diag.c +++ b/net/core/sock_diag.c @@ -214,7 +214,7 @@ void sock_diag_unregister(const struct sock_diag_handler *hnld) } EXPORT_SYMBOL_GPL(sock_diag_unregister); -static int __sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh) +static int __sock_diag_cmd(struct sk_buff *skb, struct nlmsghdr *nlh) { int err; struct sock_diag_req *req = nlmsg_data(nlh); @@ -234,8 +234,12 @@ static int __sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh) hndl = sock_diag_handlers[req->sdiag_family]; if (hndl == NULL) err = -ENOENT; - else + else if (nlh->nlmsg_type == SOCK_DIAG_BY_FAMILY) err = hndl->dump(skb, nlh); + else if (nlh->nlmsg_type == SOCK_DESTROY_BACKPORT && hndl->destroy) + err = hndl->destroy(skb, nlh); + else + err = -EOPNOTSUPP; mutex_unlock(&sock_diag_table_mutex); return err; @@ -261,7 +265,8 @@ static int sock_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh) return ret; case SOCK_DIAG_BY_FAMILY: - return __sock_diag_rcv_msg(skb, nlh); + case SOCK_DESTROY_BACKPORT: + return __sock_diag_cmd(skb, nlh); default: return -EINVAL; } @@ -295,6 +300,18 @@ static int sock_diag_bind(struct net *net, int group) return 0; } +int sock_diag_destroy(struct sock *sk, int err) +{ + if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_ADMIN)) + return -EPERM; + + if (!sk->sk_prot->diag_destroy) + return -EOPNOTSUPP; + + return sk->sk_prot->diag_destroy(sk, err); +} +EXPORT_SYMBOL_GPL(sock_diag_destroy); + static int __net_init diag_net_init(struct net *net) { struct netlink_kernel_cfg cfg = { diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig index 0d17c8516589..daad05c6f25b 100644 --- a/net/ipv4/Kconfig +++ b/net/ipv4/Kconfig @@ -439,6 +439,19 @@ config INET_UDP_DIAG Support for UDP socket monitoring interface used by the ss tool. If unsure, say Y. +config INET_DIAG_DESTROY + bool "INET: allow privileged process to administratively close sockets" + depends on INET_DIAG + default n + ---help--- + Provides a SOCK_DESTROY operation that allows privileged processes + (e.g., a connection manager or a network administration tool such as + ss) to close sockets opened by other processes. Closing a socket in + this way interrupts any blocking read/write/connect operations on + the socket and causes future socket calls to behave as if the socket + had been disconnected. + If unsure, say N. + menuconfig TCP_CONG_ADVANCED bool "TCP: advanced congestion control" ---help--- diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile index c29809f765dc..854c4bfd6eea 100644 --- a/net/ipv4/Makefile +++ b/net/ipv4/Makefile @@ -16,6 +16,7 @@ obj-y := route.o inetpeer.o protocol.o \ obj-$(CONFIG_NET_IP_TUNNEL) += ip_tunnel.o obj-$(CONFIG_SYSCTL) += sysctl_net_ipv4.o +obj-$(CONFIG_SYSFS) += sysfs_net_ipv4.o obj-$(CONFIG_PROC_FS) += proc.o obj-$(CONFIG_IP_MULTIPLE_TABLES) += fib_rules.o obj-$(CONFIG_IP_MROUTE) += ipmr.o diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c index 48d2ae83e268..43f9a0ad479c 100644 --- a/net/ipv4/af_inet.c +++ b/net/ipv4/af_inet.c @@ -89,6 +89,7 @@ #include <linux/netfilter_ipv4.h> #include <linux/random.h> #include <linux/slab.h> +#include <linux/netfilter/xt_qtaguid.h> #include <asm/uaccess.h> @@ -121,6 +122,19 @@ #endif #include <net/l3mdev.h> +#ifdef CONFIG_ANDROID_PARANOID_NETWORK +#include <linux/android_aid.h> + +static inline int current_has_network(void) +{ + return in_egroup_p(AID_INET) || capable(CAP_NET_RAW); +} +#else +static inline int current_has_network(void) +{ + return 1; +} +#endif /* The inetsw table contains everything that inet_create needs to * build a new socket. @@ -260,6 +274,9 @@ static int inet_create(struct net *net, struct socket *sock, int protocol, if (protocol < 0 || protocol >= IPPROTO_MAX) return -EINVAL; + if (!current_has_network()) + return -EACCES; + sock->state = SS_UNCONNECTED; /* Look for the requested type/protocol pair. */ @@ -308,8 +325,7 @@ lookup_protocol: } err = -EPERM; - if (sock->type == SOCK_RAW && !kern && - !ns_capable(net->user_ns, CAP_NET_RAW)) + if (sock->type == SOCK_RAW && !kern && !capable(CAP_NET_RAW)) goto out_rcu_unlock; sock->ops = answer->ops; @@ -398,6 +414,9 @@ int inet_release(struct socket *sock) if (sk) { long timeout; +#ifdef CONFIG_NETFILTER_XT_MATCH_QTAGUID + qtaguid_untag(sock, true); +#endif /* Applications forget to leave groups before exiting */ ip_mc_drop_socket(sk); diff --git a/net/ipv4/fib_frontend.c b/net/ipv4/fib_frontend.c index 1a1163cedf17..c01149331f46 100644 --- a/net/ipv4/fib_frontend.c +++ b/net/ipv4/fib_frontend.c @@ -629,6 +629,7 @@ const struct nla_policy rtm_ipv4_policy[RTA_MAX + 1] = { [RTA_FLOW] = { .type = NLA_U32 }, [RTA_ENCAP_TYPE] = { .type = NLA_U16 }, [RTA_ENCAP] = { .type = NLA_NESTED }, + [RTA_UID] = { .type = NLA_U32 }, }; static int rtm_to_fib_config(struct net *net, struct sk_buff *skb, diff --git a/net/ipv4/icmp.c b/net/ipv4/icmp.c index 0a9fb3d2ba90..358d9dabda1d 100644 --- a/net/ipv4/icmp.c +++ b/net/ipv4/icmp.c @@ -429,6 +429,7 @@ static void icmp_reply(struct icmp_bxm *icmp_param, struct sk_buff *skb) fl4.daddr = daddr; fl4.saddr = saddr; fl4.flowi4_mark = mark; + fl4.flowi4_uid = sock_net_uid(net, NULL); fl4.flowi4_tos = RT_TOS(ip_hdr(skb)->tos); fl4.flowi4_proto = IPPROTO_ICMP; fl4.flowi4_oif = l3mdev_master_ifindex(skb->dev); @@ -495,6 +496,7 @@ static struct rtable *icmp_route_lookup(struct net *net, param->replyopts.opt.opt.faddr : iph->saddr); fl4->saddr = saddr; fl4->flowi4_mark = mark; + fl4->flowi4_uid = sock_net_uid(net, NULL); fl4->flowi4_tos = RT_TOS(tos); fl4->flowi4_proto = IPPROTO_ICMP; fl4->fl4_icmp_type = type; diff --git a/net/ipv4/inet_connection_sock.c b/net/ipv4/inet_connection_sock.c index 9678dd8d70c3..336ef741c950 100644 --- a/net/ipv4/inet_connection_sock.c +++ b/net/ipv4/inet_connection_sock.c @@ -432,7 +432,7 @@ struct dst_entry *inet_csk_route_req(const struct sock *sk, sk->sk_protocol, inet_sk_flowi_flags(sk), (opt && opt->opt.srr) ? opt->opt.faddr : ireq->ir_rmt_addr, ireq->ir_loc_addr, ireq->ir_rmt_port, - htons(ireq->ir_num)); + htons(ireq->ir_num), sk->sk_uid); security_req_classify_flow(req, flowi4_to_flowi(fl4)); rt = ip_route_output_flow(net, fl4, sk); if (IS_ERR(rt)) @@ -468,7 +468,7 @@ struct dst_entry *inet_csk_route_child_sock(const struct sock *sk, sk->sk_protocol, inet_sk_flowi_flags(sk), (opt && opt->opt.srr) ? opt->opt.faddr : ireq->ir_rmt_addr, ireq->ir_loc_addr, ireq->ir_rmt_port, - htons(ireq->ir_num)); + htons(ireq->ir_num), sk->sk_uid); security_req_classify_flow(req, flowi4_to_flowi(fl4)); rt = ip_route_output_flow(net, fl4, sk); if (IS_ERR(rt)) diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c index 386443e780da..fcb83b2a61f0 100644 --- a/net/ipv4/inet_diag.c +++ b/net/ipv4/inet_diag.c @@ -44,6 +44,8 @@ struct inet_diag_entry { u16 dport; u16 family; u16 userlocks; + u32 ifindex; + u32 mark; }; static DEFINE_MUTEX(inet_diag_table_mutex); @@ -96,6 +98,7 @@ static size_t inet_sk_attr_size(void) + nla_total_size(1) /* INET_DIAG_SHUTDOWN */ + nla_total_size(1) /* INET_DIAG_TOS */ + nla_total_size(1) /* INET_DIAG_TCLASS */ + + nla_total_size(4) /* INET_DIAG_MARK */ + nla_total_size(sizeof(struct inet_diag_meminfo)) + nla_total_size(sizeof(struct inet_diag_msg)) + nla_total_size(SK_MEMINFO_VARS * sizeof(u32)) @@ -108,7 +111,8 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, struct sk_buff *skb, const struct inet_diag_req_v2 *req, struct user_namespace *user_ns, u32 portid, u32 seq, u16 nlmsg_flags, - const struct nlmsghdr *unlh) + const struct nlmsghdr *unlh, + bool net_admin) { const struct inet_sock *inet = inet_sk(sk); const struct tcp_congestion_ops *ca_ops; @@ -158,6 +162,9 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, } #endif + if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, sk->sk_mark)) + goto errout; + r->idiag_uid = from_kuid_munged(user_ns, sock_i_uid(sk)); r->idiag_inode = sock_i_ino(sk); @@ -256,10 +263,11 @@ static int inet_csk_diag_fill(struct sock *sk, const struct inet_diag_req_v2 *req, struct user_namespace *user_ns, u32 portid, u32 seq, u16 nlmsg_flags, - const struct nlmsghdr *unlh) + const struct nlmsghdr *unlh, + bool net_admin) { - return inet_sk_diag_fill(sk, inet_csk(sk), skb, req, - user_ns, portid, seq, nlmsg_flags, unlh); + return inet_sk_diag_fill(sk, inet_csk(sk), skb, req, user_ns, + portid, seq, nlmsg_flags, unlh, net_admin); } static int inet_twsk_diag_fill(struct sock *sk, @@ -301,8 +309,9 @@ static int inet_twsk_diag_fill(struct sock *sk, static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb, u32 portid, u32 seq, u16 nlmsg_flags, - const struct nlmsghdr *unlh) + const struct nlmsghdr *unlh, bool net_admin) { + struct request_sock *reqsk = inet_reqsk(sk); struct inet_diag_msg *r; struct nlmsghdr *nlh; long tmo; @@ -316,7 +325,7 @@ static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb, inet_diag_msg_common_fill(r, sk); r->idiag_state = TCP_SYN_RECV; r->idiag_timer = 1; - r->idiag_retrans = inet_reqsk(sk)->num_retrans; + r->idiag_retrans = reqsk->num_retrans; BUILD_BUG_ON(offsetof(struct inet_request_sock, ir_cookie) != offsetof(struct sock, sk_cookie)); @@ -328,6 +337,10 @@ static int inet_req_diag_fill(struct sock *sk, struct sk_buff *skb, r->idiag_uid = 0; r->idiag_inode = 0; + if (net_admin && nla_put_u32(skb, INET_DIAG_MARK, + inet_rsk(reqsk)->ir_mark)) + return -EMSGSIZE; + nlmsg_end(skb, nlh); return 0; } @@ -336,7 +349,7 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, const struct inet_diag_req_v2 *r, struct user_namespace *user_ns, u32 portid, u32 seq, u16 nlmsg_flags, - const struct nlmsghdr *unlh) + const struct nlmsghdr *unlh, bool net_admin) { if (sk->sk_state == TCP_TIME_WAIT) return inet_twsk_diag_fill(sk, skb, portid, seq, @@ -344,23 +357,18 @@ static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, if (sk->sk_state == TCP_NEW_SYN_RECV) return inet_req_diag_fill(sk, skb, portid, seq, - nlmsg_flags, unlh); + nlmsg_flags, unlh, net_admin); return inet_csk_diag_fill(sk, skb, r, user_ns, portid, seq, - nlmsg_flags, unlh); + nlmsg_flags, unlh, net_admin); } -int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, - struct sk_buff *in_skb, - const struct nlmsghdr *nlh, - const struct inet_diag_req_v2 *req) +struct sock *inet_diag_find_one_icsk(struct net *net, + struct inet_hashinfo *hashinfo, + const struct inet_diag_req_v2 *req) { - struct net *net = sock_net(in_skb->sk); - struct sk_buff *rep; struct sock *sk; - int err; - err = -EINVAL; if (req->sdiag_family == AF_INET) sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0], req->id.idiag_dport, req->id.idiag_src[0], @@ -382,15 +390,33 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, } #endif else - goto out_nosk; + return ERR_PTR(-EINVAL); - err = -ENOENT; if (!sk) - goto out_nosk; + return ERR_PTR(-ENOENT); - err = sock_diag_check_cookie(sk, req->id.idiag_cookie); - if (err) - goto out; + if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) { + sock_gen_put(sk); + return ERR_PTR(-ENOENT); + } + + return sk; +} +EXPORT_SYMBOL_GPL(inet_diag_find_one_icsk); + +int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, + struct sk_buff *in_skb, + const struct nlmsghdr *nlh, + const struct inet_diag_req_v2 *req) +{ + struct net *net = sock_net(in_skb->sk); + struct sk_buff *rep; + struct sock *sk; + int err; + + sk = inet_diag_find_one_icsk(net, hashinfo, req); + if (IS_ERR(sk)) + return PTR_ERR(sk); rep = nlmsg_new(inet_sk_attr_size(), GFP_KERNEL); if (!rep) { @@ -401,7 +427,8 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, err = sk_diag_fill(sk, rep, req, sk_user_ns(NETLINK_CB(in_skb).sk), NETLINK_CB(in_skb).portid, - nlh->nlmsg_seq, 0, nlh); + nlh->nlmsg_seq, 0, nlh, + netlink_net_capable(in_skb, CAP_NET_ADMIN)); if (err < 0) { WARN_ON(err == -EMSGSIZE); nlmsg_free(rep); @@ -416,12 +443,11 @@ out: if (sk) sock_gen_put(sk); -out_nosk: return err; } EXPORT_SYMBOL_GPL(inet_diag_dump_one_icsk); -static int inet_diag_get_exact(struct sk_buff *in_skb, +static int inet_diag_cmd_exact(int cmd, struct sk_buff *in_skb, const struct nlmsghdr *nlh, const struct inet_diag_req_v2 *req) { @@ -431,8 +457,12 @@ static int inet_diag_get_exact(struct sk_buff *in_skb, handler = inet_diag_lock_handler(req->sdiag_protocol); if (IS_ERR(handler)) err = PTR_ERR(handler); - else + else if (cmd == SOCK_DIAG_BY_FAMILY) err = handler->dump_one(in_skb, nlh, req); + else if (cmd == SOCK_DESTROY_BACKPORT && handler->destroy) + err = handler->destroy(in_skb, req); + else + err = -EOPNOTSUPP; inet_diag_unlock_handler(handler); return err; @@ -536,6 +566,22 @@ static int inet_diag_bc_run(const struct nlattr *_bc, yes = 0; break; } + case INET_DIAG_BC_DEV_COND: { + u32 ifindex; + + ifindex = *((const u32 *)(op + 1)); + if (ifindex != entry->ifindex) + yes = 0; + break; + } + case INET_DIAG_BC_MARK_COND: { + struct inet_diag_markcond *cond; + + cond = (struct inet_diag_markcond *)(op + 1); + if ((entry->mark & cond->mask) != cond->mark) + yes = 0; + break; + } } if (yes) { @@ -578,7 +624,14 @@ int inet_diag_bc_sk(const struct nlattr *bc, struct sock *sk) entry_fill_addrs(&entry, sk); entry.sport = inet->inet_num; entry.dport = ntohs(inet->inet_dport); + entry.ifindex = sk->sk_bound_dev_if; entry.userlocks = sk_fullsock(sk) ? sk->sk_userlocks : 0; + if (sk_fullsock(sk)) + entry.mark = sk->sk_mark; + else if (sk->sk_state == TCP_NEW_SYN_RECV) + entry.mark = inet_rsk(inet_reqsk(sk))->ir_mark; + else + entry.mark = 0; return inet_diag_bc_run(bc, &entry); } @@ -601,6 +654,17 @@ static int valid_cc(const void *bc, int len, int cc) return 0; } +/* data is u32 ifindex */ +static bool valid_devcond(const struct inet_diag_bc_op *op, int len, + int *min_len) +{ + /* Check ifindex space. */ + *min_len += sizeof(u32); + if (len < *min_len) + return false; + + return true; +} /* Validate an inet_diag_hostcond. */ static bool valid_hostcond(const struct inet_diag_bc_op *op, int len, int *min_len) @@ -650,10 +714,25 @@ static bool valid_port_comparison(const struct inet_diag_bc_op *op, return true; } -static int inet_diag_bc_audit(const void *bytecode, int bytecode_len) +static bool valid_markcond(const struct inet_diag_bc_op *op, int len, + int *min_len) +{ + *min_len += sizeof(struct inet_diag_markcond); + return len >= *min_len; +} + +static int inet_diag_bc_audit(const struct nlattr *attr, + const struct sk_buff *skb) { - const void *bc = bytecode; - int len = bytecode_len; + bool net_admin = netlink_net_capable(skb, CAP_NET_ADMIN); + const void *bytecode, *bc; + int bytecode_len, len; + + if (!attr || nla_len(attr) < sizeof(struct inet_diag_bc_op)) + return -EINVAL; + + bytecode = bc = nla_data(attr); + len = bytecode_len = nla_len(attr); while (len > 0) { int min_len = sizeof(struct inet_diag_bc_op); @@ -665,6 +744,10 @@ static int inet_diag_bc_audit(const void *bytecode, int bytecode_len) if (!valid_hostcond(bc, len, &min_len)) return -EINVAL; break; + case INET_DIAG_BC_DEV_COND: + if (!valid_devcond(bc, len, &min_len)) + return -EINVAL; + break; case INET_DIAG_BC_S_GE: case INET_DIAG_BC_S_LE: case INET_DIAG_BC_D_GE: @@ -672,6 +755,12 @@ static int inet_diag_bc_audit(const void *bytecode, int bytecode_len) if (!valid_port_comparison(bc, len, &min_len)) return -EINVAL; break; + case INET_DIAG_BC_MARK_COND: + if (!net_admin) + return -EPERM; + if (!valid_markcond(bc, len, &min_len)) + return -EINVAL; + break; case INET_DIAG_BC_AUTO: case INET_DIAG_BC_JMP: case INET_DIAG_BC_NOP: @@ -700,7 +789,8 @@ static int inet_csk_diag_dump(struct sock *sk, struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *r, - const struct nlattr *bc) + const struct nlattr *bc, + bool net_admin) { if (!inet_diag_bc_sk(bc, sk)) return 0; @@ -708,7 +798,8 @@ static int inet_csk_diag_dump(struct sock *sk, return inet_csk_diag_fill(sk, skb, r, sk_user_ns(NETLINK_CB(cb->skb).sk), NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh); + cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh, + net_admin); } static void twsk_build_assert(void) @@ -744,6 +835,7 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, struct net *net = sock_net(skb->sk); int i, num, s_i, s_num; u32 idiag_states = r->idiag_states; + bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); if (idiag_states & TCPF_SYN_RECV) idiag_states |= TCPF_NEW_SYN_RECV; @@ -785,7 +877,8 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, cb->args[3] > 0) goto next_listen; - if (inet_csk_diag_dump(sk, skb, cb, r, bc) < 0) { + if (inet_csk_diag_dump(sk, skb, cb, r, + bc, net_admin) < 0) { spin_unlock_bh(&ilb->lock); goto done; } @@ -853,7 +946,7 @@ skip_listen_ht: sk_user_ns(NETLINK_CB(cb->skb).sk), NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NLM_F_MULTI, - cb->nlh); + cb->nlh, net_admin); if (res < 0) { spin_unlock_bh(lock); goto done; @@ -945,7 +1038,7 @@ static int inet_diag_get_exact_compat(struct sk_buff *in_skb, req.idiag_states = rc->idiag_states; req.id = rc->id; - return inet_diag_get_exact(in_skb, nlh, &req); + return inet_diag_cmd_exact(SOCK_DIAG_BY_FAMILY, in_skb, nlh, &req); } static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) @@ -960,13 +1053,13 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) if (nlh->nlmsg_flags & NLM_F_DUMP) { if (nlmsg_attrlen(nlh, hdrlen)) { struct nlattr *attr; + int err; attr = nlmsg_find_attr(nlh, hdrlen, INET_DIAG_REQ_BYTECODE); - if (!attr || - nla_len(attr) < sizeof(struct inet_diag_bc_op) || - inet_diag_bc_audit(nla_data(attr), nla_len(attr))) - return -EINVAL; + err = inet_diag_bc_audit(attr, skb); + if (err) + return err; } { struct netlink_dump_control c = { @@ -979,7 +1072,7 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) return inet_diag_get_exact_compat(skb, nlh); } -static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) +static int inet_diag_handler_cmd(struct sk_buff *skb, struct nlmsghdr *h) { int hdrlen = sizeof(struct inet_diag_req_v2); struct net *net = sock_net(skb->sk); @@ -987,16 +1080,17 @@ static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) if (nlmsg_len(h) < hdrlen) return -EINVAL; - if (h->nlmsg_flags & NLM_F_DUMP) { + if (h->nlmsg_type == SOCK_DIAG_BY_FAMILY && + h->nlmsg_flags & NLM_F_DUMP) { if (nlmsg_attrlen(h, hdrlen)) { struct nlattr *attr; + int err; attr = nlmsg_find_attr(h, hdrlen, INET_DIAG_REQ_BYTECODE); - if (!attr || - nla_len(attr) < sizeof(struct inet_diag_bc_op) || - inet_diag_bc_audit(nla_data(attr), nla_len(attr))) - return -EINVAL; + err = inet_diag_bc_audit(attr, skb); + if (err) + return err; } { struct netlink_dump_control c = { @@ -1006,7 +1100,7 @@ static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) } } - return inet_diag_get_exact(skb, h, nlmsg_data(h)); + return inet_diag_cmd_exact(h->nlmsg_type, skb, h, nlmsg_data(h)); } static @@ -1057,14 +1151,16 @@ int inet_diag_handler_get_info(struct sk_buff *skb, struct sock *sk) static const struct sock_diag_handler inet_diag_handler = { .family = AF_INET, - .dump = inet_diag_handler_dump, + .dump = inet_diag_handler_cmd, .get_info = inet_diag_handler_get_info, + .destroy = inet_diag_handler_cmd, }; static const struct sock_diag_handler inet6_diag_handler = { .family = AF_INET6, - .dump = inet_diag_handler_dump, + .dump = inet_diag_handler_cmd, .get_info = inet_diag_handler_get_info, + .destroy = inet_diag_handler_cmd, }; int inet_diag_register(const struct inet_diag_handler *h) diff --git a/net/ipv4/ip_gre.c b/net/ipv4/ip_gre.c index 900ee28bda99..63f7bacf628a 100644 --- a/net/ipv4/ip_gre.c +++ b/net/ipv4/ip_gre.c @@ -502,6 +502,10 @@ static void __gre_xmit(struct sk_buff *skb, struct net_device *dev, static struct sk_buff *gre_handle_offloads(struct sk_buff *skb, bool csum) { + unsigned char *skb_checksum_start = skb->head + skb->csum_start; + + if (csum && skb_checksum_start < skb->data) + return ERR_PTR(-EINVAL); return iptunnel_handle_offloads(skb, csum, csum ? SKB_GSO_GRE_CSUM : SKB_GSO_GRE); } diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c index 477540b3d320..aad369b767f9 100644 --- a/net/ipv4/ip_output.c +++ b/net/ipv4/ip_output.c @@ -1593,7 +1593,8 @@ void ip_send_unicast_reply(struct sock *sk, struct sk_buff *skb, RT_SCOPE_UNIVERSE, ip_hdr(skb)->protocol, ip_reply_arg_flowi_flags(arg), daddr, saddr, - tcp_hdr(skb)->source, tcp_hdr(skb)->dest); + tcp_hdr(skb)->source, tcp_hdr(skb)->dest, + arg->uid); security_skb_classify_flow(skb, flowi4_to_flowi(&fl4)); rt = ip_route_output_key(net, &fl4); if (IS_ERR(rt)) diff --git a/net/ipv4/netfilter/ip_tables.c b/net/ipv4/netfilter/ip_tables.c index 684003063174..73b1d8e64658 100644 --- a/net/ipv4/netfilter/ip_tables.c +++ b/net/ipv4/netfilter/ip_tables.c @@ -964,10 +964,6 @@ copy_entries_to_user(unsigned int total_size, return PTR_ERR(counters); loc_cpu_entry = private->entries; - if (copy_to_user(userptr, loc_cpu_entry, total_size) != 0) { - ret = -EFAULT; - goto free_counters; - } /* FIXME: use iterator macros --RR */ /* ... then go back and fix counters and names */ @@ -977,6 +973,10 @@ copy_entries_to_user(unsigned int total_size, const struct xt_entry_target *t; e = (struct ipt_entry *)(loc_cpu_entry + off); + if (copy_to_user(userptr + off, e, sizeof(*e))) { + ret = -EFAULT; + goto free_counters; + } if (copy_to_user(userptr + off + offsetof(struct ipt_entry, counters), &counters[num], @@ -990,23 +990,14 @@ copy_entries_to_user(unsigned int total_size, i += m->u.match_size) { m = (void *)e + i; - if (copy_to_user(userptr + off + i - + offsetof(struct xt_entry_match, - u.user.name), - m->u.kernel.match->name, - strlen(m->u.kernel.match->name)+1) - != 0) { + if (xt_match_to_user(m, userptr + off + i)) { ret = -EFAULT; goto free_counters; } } t = ipt_get_target_c(e); - if (copy_to_user(userptr + off + e->target_offset - + offsetof(struct xt_entry_target, - u.user.name), - t->u.kernel.target->name, - strlen(t->u.kernel.target->name)+1) != 0) { + if (xt_target_to_user(t, userptr + off + e->target_offset)) { ret = -EFAULT; goto free_counters; } diff --git a/net/ipv4/netfilter/ipt_CLUSTERIP.c b/net/ipv4/netfilter/ipt_CLUSTERIP.c index 16599bae11dd..28bcde0a2749 100644 --- a/net/ipv4/netfilter/ipt_CLUSTERIP.c +++ b/net/ipv4/netfilter/ipt_CLUSTERIP.c @@ -478,6 +478,7 @@ static struct xt_target clusterip_tg_reg __read_mostly = { .checkentry = clusterip_tg_check, .destroy = clusterip_tg_destroy, .targetsize = sizeof(struct ipt_clusterip_tgt_info), + .usersize = offsetof(struct ipt_clusterip_tgt_info, config), #ifdef CONFIG_COMPAT .compatsize = sizeof(struct compat_ipt_clusterip_tgt_info), #endif /* CONFIG_COMPAT */ diff --git a/net/ipv4/ping.c b/net/ipv4/ping.c index 56d89dfd8831..012a18052c8d 100644 --- a/net/ipv4/ping.c +++ b/net/ipv4/ping.c @@ -800,7 +800,8 @@ static int ping_v4_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) flowi4_init_output(&fl4, ipc.oif, sk->sk_mark, tos, RT_SCOPE_UNIVERSE, sk->sk_protocol, - inet_sk_flowi_flags(sk), faddr, saddr, 0, 0); + inet_sk_flowi_flags(sk), faddr, saddr, 0, 0, + sk->sk_uid); fl4.fl4_icmp_type = user_icmph.type; fl4.fl4_icmp_code = user_icmph.code; diff --git a/net/ipv4/raw.c b/net/ipv4/raw.c index 24ce13a79665..051ed5732a81 100644 --- a/net/ipv4/raw.c +++ b/net/ipv4/raw.c @@ -611,7 +611,7 @@ static int raw_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) hdrincl ? IPPROTO_RAW : sk->sk_protocol, inet_sk_flowi_flags(sk) | (hdrincl ? FLOWI_FLAG_KNOWN_NH : 0), - daddr, saddr, 0, 0); + daddr, saddr, 0, 0, sk->sk_uid); if (!saddr && ipc.oif) { err = l3mdev_get_saddr(net, ipc.oif, &fl4); diff --git a/net/ipv4/route.c b/net/ipv4/route.c index ed6483181676..f79f9a6dd046 100644 --- a/net/ipv4/route.c +++ b/net/ipv4/route.c @@ -516,7 +516,8 @@ void __ip_select_ident(struct net *net, struct iphdr *iph, int segs) } EXPORT_SYMBOL(__ip_select_ident); -static void __build_flow_key(struct flowi4 *fl4, const struct sock *sk, +static void __build_flow_key(const struct net *net, struct flowi4 *fl4, + const struct sock *sk, const struct iphdr *iph, int oif, u8 tos, u8 prot, u32 mark, int flow_flags) @@ -532,19 +533,21 @@ static void __build_flow_key(struct flowi4 *fl4, const struct sock *sk, flowi4_init_output(fl4, oif, mark, tos, RT_SCOPE_UNIVERSE, prot, flow_flags, - iph->daddr, iph->saddr, 0, 0); + iph->daddr, iph->saddr, 0, 0, + sock_net_uid(net, sk)); } static void build_skb_flow_key(struct flowi4 *fl4, const struct sk_buff *skb, const struct sock *sk) { + const struct net *net = dev_net(skb->dev); const struct iphdr *iph = ip_hdr(skb); int oif = skb->dev->ifindex; u8 tos = RT_TOS(iph->tos); u8 prot = iph->protocol; u32 mark = skb->mark; - __build_flow_key(fl4, sk, iph, oif, tos, prot, mark, 0); + __build_flow_key(net, fl4, sk, iph, oif, tos, prot, mark, 0); } static void build_sk_flow_key(struct flowi4 *fl4, const struct sock *sk) @@ -561,7 +564,7 @@ static void build_sk_flow_key(struct flowi4 *fl4, const struct sock *sk) RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, inet->hdrincl ? IPPROTO_RAW : sk->sk_protocol, inet_sk_flowi_flags(sk), - daddr, inet->inet_saddr, 0, 0); + daddr, inet->inet_saddr, 0, 0, sk->sk_uid); rcu_read_unlock(); } @@ -827,6 +830,7 @@ static void ip_do_redirect(struct dst_entry *dst, struct sock *sk, struct sk_buf struct rtable *rt; struct flowi4 fl4; const struct iphdr *iph = (const struct iphdr *) skb->data; + struct net *net = dev_net(skb->dev); int oif = skb->dev->ifindex; u8 tos = RT_TOS(iph->tos); u8 prot = iph->protocol; @@ -834,7 +838,7 @@ static void ip_do_redirect(struct dst_entry *dst, struct sock *sk, struct sk_buf rt = (struct rtable *) dst; - __build_flow_key(&fl4, sk, iph, oif, tos, prot, mark, 0); + __build_flow_key(net, &fl4, sk, iph, oif, tos, prot, mark, 0); __ip_do_redirect(rt, skb, &fl4, true); } @@ -1058,7 +1062,7 @@ void ipv4_update_pmtu(struct sk_buff *skb, struct net *net, u32 mtu, if (!mark) mark = IP4_REPLY_MARK(net, skb->mark); - __build_flow_key(&fl4, NULL, iph, oif, + __build_flow_key(net, &fl4, NULL, iph, oif, RT_TOS(iph->tos), protocol, mark, flow_flags); rt = __ip_route_output_key(net, &fl4); if (!IS_ERR(rt)) { @@ -1074,7 +1078,7 @@ static void __ipv4_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, u32 mtu) struct flowi4 fl4; struct rtable *rt; - __build_flow_key(&fl4, sk, iph, 0, 0, 0, 0, 0); + __build_flow_key(sock_net(sk), &fl4, sk, iph, 0, 0, 0, 0, 0); if (!fl4.flowi4_mark) fl4.flowi4_mark = IP4_REPLY_MARK(sock_net(sk), skb->mark); @@ -1093,6 +1097,7 @@ void ipv4_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, u32 mtu) struct rtable *rt; struct dst_entry *odst = NULL; bool new = false; + struct net *net = sock_net(sk); bh_lock_sock(sk); @@ -1106,7 +1111,7 @@ void ipv4_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, u32 mtu) goto out; } - __build_flow_key(&fl4, sk, iph, 0, 0, 0, 0, 0); + __build_flow_key(net, &fl4, sk, iph, 0, 0, 0, 0, 0); rt = (struct rtable *)odst; if (odst->obsolete && !odst->ops->check(odst, 0)) { @@ -1146,7 +1151,7 @@ void ipv4_redirect(struct sk_buff *skb, struct net *net, struct flowi4 fl4; struct rtable *rt; - __build_flow_key(&fl4, NULL, iph, oif, + __build_flow_key(net, &fl4, NULL, iph, oif, RT_TOS(iph->tos), protocol, mark, flow_flags); rt = __ip_route_output_key(net, &fl4); if (!IS_ERR(rt)) { @@ -1161,9 +1166,10 @@ void ipv4_sk_redirect(struct sk_buff *skb, struct sock *sk) const struct iphdr *iph = (const struct iphdr *) skb->data; struct flowi4 fl4; struct rtable *rt; + struct net *net = sock_net(sk); - __build_flow_key(&fl4, sk, iph, 0, 0, 0, 0, 0); - rt = __ip_route_output_key(sock_net(sk), &fl4); + __build_flow_key(net, &fl4, sk, iph, 0, 0, 0, 0, 0); + rt = __ip_route_output_key(net, &fl4); if (!IS_ERR(rt)) { __ip_do_redirect(rt, skb, &fl4, false); ip_rt_put(rt); @@ -2579,6 +2585,11 @@ static int rt_fill_info(struct net *net, __be32 dst, __be32 src, u32 table_id, nla_put_u32(skb, RTA_MARK, fl4->flowi4_mark)) goto nla_put_failure; + if (!uid_eq(fl4->flowi4_uid, INVALID_UID) && + nla_put_u32(skb, RTA_UID, + from_kuid_munged(current_user_ns(), fl4->flowi4_uid))) + goto nla_put_failure; + error = rt->dst.error; if (rt_is_input_route(rt)) { @@ -2631,6 +2642,7 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh) int mark; struct sk_buff *skb; u32 table_id = RT_TABLE_MAIN; + kuid_t uid; err = nlmsg_parse(nlh, sizeof(*rtm), tb, RTA_MAX, rtm_ipv4_policy); if (err < 0) @@ -2658,6 +2670,10 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh) dst = tb[RTA_DST] ? nla_get_in_addr(tb[RTA_DST]) : 0; iif = tb[RTA_IIF] ? nla_get_u32(tb[RTA_IIF]) : 0; mark = tb[RTA_MARK] ? nla_get_u32(tb[RTA_MARK]) : 0; + if (tb[RTA_UID]) + uid = make_kuid(current_user_ns(), nla_get_u32(tb[RTA_UID])); + else + uid = (iif ? INVALID_UID : current_uid()); memset(&fl4, 0, sizeof(fl4)); fl4.daddr = dst; @@ -2665,6 +2681,7 @@ static int inet_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh) fl4.flowi4_tos = rtm->rtm_tos; fl4.flowi4_oif = tb[RTA_OIF] ? nla_get_u32(tb[RTA_OIF]) : 0; fl4.flowi4_mark = mark; + fl4.flowi4_uid = uid; if (netif_index_is_l3_master(net, fl4.flowi4_oif)) fl4.flowi4_flags = FLOWI_FLAG_L3MDEV_SRC | FLOWI_FLAG_SKIP_NH_OIF; diff --git a/net/ipv4/syncookies.c b/net/ipv4/syncookies.c index d9e7d61a0197..ba0301860985 100644 --- a/net/ipv4/syncookies.c +++ b/net/ipv4/syncookies.c @@ -381,7 +381,7 @@ struct sock *cookie_v4_check(struct sock *sk, struct sk_buff *skb) RT_CONN_FLAGS(sk), RT_SCOPE_UNIVERSE, IPPROTO_TCP, inet_sk_flowi_flags(sk), opt->srr ? opt->faddr : ireq->ir_rmt_addr, - ireq->ir_loc_addr, th->source, th->dest); + ireq->ir_loc_addr, th->source, th->dest, sk->sk_uid); security_req_classify_flow(req, flowi4_to_flowi(&fl4)); rt = ip_route_output_key(sock_net(sk), &fl4); if (IS_ERR(rt)) { diff --git a/net/ipv4/sysctl_net_ipv4.c b/net/ipv4/sysctl_net_ipv4.c index 6413e36d639d..a433886311e3 100644 --- a/net/ipv4/sysctl_net_ipv4.c +++ b/net/ipv4/sysctl_net_ipv4.c @@ -156,6 +156,21 @@ static int ipv4_ping_group_range(struct ctl_table *table, int write, return ret; } +/* Validate changes from /proc interface. */ +static int proc_tcp_default_init_rwnd(struct ctl_table *ctl, int write, + void __user *buffer, + size_t *lenp, loff_t *ppos) +{ + int old_value = *(int *)ctl->data; + int ret = proc_dointvec(ctl, write, buffer, lenp, ppos); + int new_value = *(int *)ctl->data; + + if (write && ret == 0 && (new_value < 3 || new_value > 100)) + *(int *)ctl->data = old_value; + + return ret; +} + static int proc_tcp_congestion_control(struct ctl_table *ctl, int write, void __user *buffer, size_t *lenp, loff_t *ppos) { @@ -775,6 +790,13 @@ static struct ctl_table ipv4_table[] = { .proc_handler = proc_dointvec_ms_jiffies, }, { + .procname = "tcp_default_init_rwnd", + .data = &sysctl_tcp_default_init_rwnd, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_tcp_default_init_rwnd + }, + { .procname = "icmp_msgs_per_sec", .data = &sysctl_icmp_msgs_per_sec, .maxlen = sizeof(int), diff --git a/net/ipv4/sysfs_net_ipv4.c b/net/ipv4/sysfs_net_ipv4.c new file mode 100644 index 000000000000..0cbbf10026a6 --- /dev/null +++ b/net/ipv4/sysfs_net_ipv4.c @@ -0,0 +1,88 @@ +/* + * net/ipv4/sysfs_net_ipv4.c + * + * sysfs-based networking knobs (so we can, unlike with sysctl, control perms) + * + * Copyright (C) 2008 Google, Inc. + * + * Robert Love <rlove@google.com> + * + * This software is licensed under the terms of the GNU General Public + * License version 2, as published by the Free Software Foundation, and + * may be copied, distributed, and modified under those terms. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + */ + +#include <linux/kobject.h> +#include <linux/string.h> +#include <linux/sysfs.h> +#include <linux/init.h> +#include <net/tcp.h> + +#define CREATE_IPV4_FILE(_name, _var) \ +static ssize_t _name##_show(struct kobject *kobj, \ + struct kobj_attribute *attr, char *buf) \ +{ \ + return sprintf(buf, "%d\n", _var); \ +} \ +static ssize_t _name##_store(struct kobject *kobj, \ + struct kobj_attribute *attr, \ + const char *buf, size_t count) \ +{ \ + int val, ret; \ + ret = sscanf(buf, "%d", &val); \ + if (ret != 1) \ + return -EINVAL; \ + if (val < 0) \ + return -EINVAL; \ + _var = val; \ + return count; \ +} \ +static struct kobj_attribute _name##_attr = \ + __ATTR(_name, 0644, _name##_show, _name##_store) + +CREATE_IPV4_FILE(tcp_wmem_min, sysctl_tcp_wmem[0]); +CREATE_IPV4_FILE(tcp_wmem_def, sysctl_tcp_wmem[1]); +CREATE_IPV4_FILE(tcp_wmem_max, sysctl_tcp_wmem[2]); + +CREATE_IPV4_FILE(tcp_rmem_min, sysctl_tcp_rmem[0]); +CREATE_IPV4_FILE(tcp_rmem_def, sysctl_tcp_rmem[1]); +CREATE_IPV4_FILE(tcp_rmem_max, sysctl_tcp_rmem[2]); + +static struct attribute *ipv4_attrs[] = { + &tcp_wmem_min_attr.attr, + &tcp_wmem_def_attr.attr, + &tcp_wmem_max_attr.attr, + &tcp_rmem_min_attr.attr, + &tcp_rmem_def_attr.attr, + &tcp_rmem_max_attr.attr, + NULL +}; + +static struct attribute_group ipv4_attr_group = { + .attrs = ipv4_attrs, +}; + +static __init int sysfs_ipv4_init(void) +{ + struct kobject *ipv4_kobject; + int ret; + + ipv4_kobject = kobject_create_and_add("ipv4", kernel_kobj); + if (!ipv4_kobject) + return -ENOMEM; + + ret = sysfs_create_group(ipv4_kobject, &ipv4_attr_group); + if (ret) { + kobject_put(ipv4_kobject); + return ret; + } + + return 0; +} + +subsys_initcall(sysfs_ipv4_init); diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index 4080cf1a369d..081537596ea7 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -3118,6 +3118,52 @@ void tcp_done(struct sock *sk) } EXPORT_SYMBOL_GPL(tcp_done); +int tcp_abort(struct sock *sk, int err) +{ + if (!sk_fullsock(sk)) { + if (sk->sk_state == TCP_NEW_SYN_RECV) { + struct request_sock *req = inet_reqsk(sk); + + local_bh_disable(); + inet_csk_reqsk_queue_drop_and_put(req->rsk_listener, + req); + local_bh_enable(); + return 0; + } + sock_gen_put(sk); + return -EOPNOTSUPP; + } + + /* Don't race with userspace socket closes such as tcp_close. */ + lock_sock(sk); + + if (sk->sk_state == TCP_LISTEN) { + tcp_set_state(sk, TCP_CLOSE); + inet_csk_listen_stop(sk); + } + + /* Don't race with BH socket closes such as inet_csk_listen_stop. */ + local_bh_disable(); + bh_lock_sock(sk); + + if (!sock_flag(sk, SOCK_DEAD)) { + sk->sk_err = err; + /* This barrier is coupled with smp_rmb() in tcp_poll() */ + smp_wmb(); + sk->sk_error_report(sk); + if (tcp_need_reset(sk->sk_state)) + tcp_send_active_reset(sk, GFP_ATOMIC); + tcp_done(sk); + } + + bh_unlock_sock(sk); + local_bh_enable(); + release_sock(sk); + sock_put(sk); + return 0; +} +EXPORT_SYMBOL_GPL(tcp_abort); + extern struct tcp_congestion_ops tcp_reno; static __initdata unsigned long thash_entries; diff --git a/net/ipv4/tcp_diag.c b/net/ipv4/tcp_diag.c index b31604086edd..4d610934fb39 100644 --- a/net/ipv4/tcp_diag.c +++ b/net/ipv4/tcp_diag.c @@ -10,6 +10,8 @@ */ #include <linux/module.h> +#include <linux/net.h> +#include <linux/sock_diag.h> #include <linux/inet_diag.h> #include <linux/tcp.h> @@ -46,12 +48,29 @@ static int tcp_diag_dump_one(struct sk_buff *in_skb, const struct nlmsghdr *nlh, return inet_diag_dump_one_icsk(&tcp_hashinfo, in_skb, nlh, req); } +#ifdef CONFIG_INET_DIAG_DESTROY +static int tcp_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *req) +{ + struct net *net = sock_net(in_skb->sk); + struct sock *sk = inet_diag_find_one_icsk(net, &tcp_hashinfo, req); + + if (IS_ERR(sk)) + return PTR_ERR(sk); + + return sock_diag_destroy(sk, ECONNABORTED); +} +#endif + static const struct inet_diag_handler tcp_diag_handler = { .dump = tcp_diag_dump, .dump_one = tcp_diag_dump_one, .idiag_get_info = tcp_diag_get_info, .idiag_type = IPPROTO_TCP, .idiag_info_size = sizeof(struct tcp_info), +#ifdef CONFIG_INET_DIAG_DESTROY + .destroy = tcp_diag_destroy, +#endif }; static int __init tcp_diag_init(void) diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c index 0919183b003f..7232c68dde09 100644 --- a/net/ipv4/tcp_input.c +++ b/net/ipv4/tcp_input.c @@ -102,6 +102,7 @@ int sysctl_tcp_thin_dupack __read_mostly; int sysctl_tcp_moderate_rcvbuf __read_mostly = 1; int sysctl_tcp_early_retrans __read_mostly = 3; int sysctl_tcp_invalid_ratelimit __read_mostly = HZ/2; +int sysctl_tcp_default_init_rwnd __read_mostly = TCP_INIT_CWND * 2; #define FLAG_DATA 0x01 /* Incoming frame contained data. */ #define FLAG_WIN_UPDATE 0x02 /* Incoming ACK was a window update. */ diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index f9d55dd2dec8..d9722c42d85d 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -692,6 +692,7 @@ static void tcp_v4_send_reset(const struct sock *sk, struct sk_buff *skb) arg.bound_dev_if = sk->sk_bound_dev_if; arg.tos = ip_hdr(skb)->tos; + arg.uid = sock_net_uid(net, sk && sk_fullsock(sk) ? sk : NULL); ip_send_unicast_reply(*this_cpu_ptr(net->ipv4.tcp_sk), skb, &TCP_SKB_CB(skb)->header.h4.opt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr, @@ -713,8 +714,8 @@ release_sk1: outside socket context is ugly, certainly. What can I do? */ -static void tcp_v4_send_ack(struct net *net, - struct sk_buff *skb, u32 seq, u32 ack, +static void tcp_v4_send_ack(const struct sock *sk, struct sk_buff *skb, + u32 seq, u32 ack, u32 win, u32 tsval, u32 tsecr, int oif, struct tcp_md5sig_key *key, int reply_flags, u8 tos) @@ -729,6 +730,7 @@ static void tcp_v4_send_ack(struct net *net, ]; } rep; struct ip_reply_arg arg; + struct net *net = sock_net(sk); memset(&rep.th, 0, sizeof(struct tcphdr)); memset(&arg, 0, sizeof(arg)); @@ -777,6 +779,7 @@ static void tcp_v4_send_ack(struct net *net, if (oif) arg.bound_dev_if = oif; arg.tos = tos; + arg.uid = sock_net_uid(net, sk_fullsock(sk) ? sk : NULL); ip_send_unicast_reply(*this_cpu_ptr(net->ipv4.tcp_sk), skb, &TCP_SKB_CB(skb)->header.h4.opt, ip_hdr(skb)->saddr, ip_hdr(skb)->daddr, @@ -790,8 +793,7 @@ static void tcp_v4_timewait_ack(struct sock *sk, struct sk_buff *skb) struct inet_timewait_sock *tw = inet_twsk(sk); struct tcp_timewait_sock *tcptw = tcp_twsk(sk); - tcp_v4_send_ack(sock_net(sk), skb, - tcptw->tw_snd_nxt, tcptw->tw_rcv_nxt, + tcp_v4_send_ack(sk, skb, tcptw->tw_snd_nxt, tcptw->tw_rcv_nxt, tcptw->tw_rcv_wnd >> tw->tw_rcv_wscale, tcp_time_stamp + tcptw->tw_ts_offset, tcptw->tw_ts_recent, @@ -810,17 +812,9 @@ static void tcp_v4_reqsk_send_ack(const struct sock *sk, struct sk_buff *skb, /* sk->sk_state == TCP_LISTEN -> for regular TCP_SYN_RECV * sk->sk_state == TCP_SYN_RECV -> for Fast Open. */ - u32 seq = (sk->sk_state == TCP_LISTEN) ? tcp_rsk(req)->snt_isn + 1 : - tcp_sk(sk)->snd_nxt; - - /* RFC 7323 2.3 - * The window field (SEG.WND) of every outgoing segment, with the - * exception of <SYN> segments, MUST be right-shifted by - * Rcv.Wind.Shift bits: - */ - tcp_v4_send_ack(sock_net(sk), skb, seq, - tcp_rsk(req)->rcv_nxt, - req->rsk_rcv_wnd >> inet_rsk(req)->rcv_wscale, + tcp_v4_send_ack(sk, skb, (sk->sk_state == TCP_LISTEN) ? + tcp_rsk(req)->snt_isn + 1 : tcp_sk(sk)->snd_nxt, + tcp_rsk(req)->rcv_nxt, req->rsk_rcv_wnd, tcp_time_stamp, req->ts_recent, 0, @@ -2395,6 +2389,7 @@ struct proto tcp_prot = { .destroy_cgroup = tcp_destroy_cgroup, .proto_cgroup = tcp_proto_cgroup, #endif + .diag_destroy = tcp_abort, }; EXPORT_SYMBOL(tcp_prot); diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c index 3f061bbf842a..17388adedc7d 100644 --- a/net/ipv4/tcp_output.c +++ b/net/ipv4/tcp_output.c @@ -196,7 +196,7 @@ u32 tcp_default_init_rwnd(u32 mss) * (RFC 3517, Section 4, NextSeg() rule (2)). Further place a * limit when mss is larger than 1460. */ - u32 init_rwnd = TCP_INIT_CWND * 2; + u32 init_rwnd = sysctl_tcp_default_init_rwnd; if (mss > 1460) init_rwnd = max((1460 * init_rwnd) / mss, 2U); diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index aba49b23e65f..51e5ab92a5a4 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -1027,7 +1027,8 @@ int udp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) flowi4_init_output(fl4, ipc.oif, sk->sk_mark, tos, RT_SCOPE_UNIVERSE, sk->sk_protocol, flow_flags, - faddr, saddr, dport, inet->inet_sport); + faddr, saddr, dport, inet->inet_sport, + sk->sk_uid); if (!saddr && ipc.oif) { err = l3mdev_get_saddr(net, ipc.oif, fl4); @@ -2274,6 +2275,20 @@ unsigned int udp_poll(struct file *file, struct socket *sock, poll_table *wait) } EXPORT_SYMBOL(udp_poll); +int udp_abort(struct sock *sk, int err) +{ + lock_sock(sk); + + sk->sk_err = err; + sk->sk_error_report(sk); + udp_disconnect(sk, 0); + + release_sock(sk); + + return 0; +} +EXPORT_SYMBOL_GPL(udp_abort); + struct proto udp_prot = { .name = "UDP", .owner = THIS_MODULE, @@ -2305,6 +2320,7 @@ struct proto udp_prot = { .compat_getsockopt = compat_udp_getsockopt, #endif .clear_sk = sk_prot_clear_portaddr_nulls, + .diag_destroy = udp_abort, }; EXPORT_SYMBOL(udp_prot); diff --git a/net/ipv4/udp_diag.c b/net/ipv4/udp_diag.c index 6116604bf6e8..092aa60e8b92 100644 --- a/net/ipv4/udp_diag.c +++ b/net/ipv4/udp_diag.c @@ -20,7 +20,7 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, struct netlink_callback *cb, const struct inet_diag_req_v2 *req, - struct nlattr *bc) + struct nlattr *bc, bool net_admin) { if (!inet_diag_bc_sk(bc, sk)) return 0; @@ -28,7 +28,7 @@ static int sk_diag_dump(struct sock *sk, struct sk_buff *skb, return inet_sk_diag_fill(sk, NULL, skb, req, sk_user_ns(NETLINK_CB(cb->skb).sk), NETLINK_CB(cb->skb).portid, - cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh); + cb->nlh->nlmsg_seq, NLM_F_MULTI, cb->nlh, net_admin); } static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, @@ -75,7 +75,8 @@ static int udp_dump_one(struct udp_table *tbl, struct sk_buff *in_skb, err = inet_sk_diag_fill(sk, NULL, rep, req, sk_user_ns(NETLINK_CB(in_skb).sk), NETLINK_CB(in_skb).portid, - nlh->nlmsg_seq, 0, nlh); + nlh->nlmsg_seq, 0, nlh, + netlink_net_capable(in_skb, CAP_NET_ADMIN)); if (err < 0) { WARN_ON(err == -EMSGSIZE); kfree_skb(rep); @@ -98,6 +99,7 @@ static void udp_dump(struct udp_table *table, struct sk_buff *skb, { int num, s_num, slot, s_slot; struct net *net = sock_net(skb->sk); + bool net_admin = netlink_net_capable(cb->skb, CAP_NET_ADMIN); s_slot = cb->args[0]; num = s_num = cb->args[1]; @@ -132,7 +134,7 @@ static void udp_dump(struct udp_table *table, struct sk_buff *skb, r->id.idiag_dport) goto next; - if (sk_diag_dump(sk, skb, cb, r, bc) < 0) { + if (sk_diag_dump(sk, skb, cb, r, bc, net_admin) < 0) { spin_unlock_bh(&hslot->lock); goto done; } @@ -165,12 +167,88 @@ static void udp_diag_get_info(struct sock *sk, struct inet_diag_msg *r, r->idiag_wqueue = sk_wmem_alloc_get(sk); } +#ifdef CONFIG_INET_DIAG_DESTROY +static int __udp_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *req, + struct udp_table *tbl) +{ + struct net *net = sock_net(in_skb->sk); + struct sock *sk; + int err; + + rcu_read_lock(); + + if (req->sdiag_family == AF_INET) + sk = __udp4_lib_lookup(net, + req->id.idiag_dst[0], req->id.idiag_dport, + req->id.idiag_src[0], req->id.idiag_sport, + req->id.idiag_if, tbl); +#if IS_ENABLED(CONFIG_IPV6) + else if (req->sdiag_family == AF_INET6) { + if (ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_dst) && + ipv6_addr_v4mapped((struct in6_addr *)req->id.idiag_src)) + sk = __udp4_lib_lookup(net, + req->id.idiag_dst[3], req->id.idiag_dport, + req->id.idiag_src[3], req->id.idiag_sport, + req->id.idiag_if, tbl); + + else + sk = __udp6_lib_lookup(net, + (struct in6_addr *)req->id.idiag_dst, + req->id.idiag_dport, + (struct in6_addr *)req->id.idiag_src, + req->id.idiag_sport, + req->id.idiag_if, tbl); + } +#endif + else { + rcu_read_unlock(); + return -EINVAL; + } + + if (sk && !atomic_inc_not_zero(&sk->sk_refcnt)) + sk = NULL; + + rcu_read_unlock(); + + if (!sk) + return -ENOENT; + + if (sock_diag_check_cookie(sk, req->id.idiag_cookie)) { + sock_put(sk); + return -ENOENT; + } + + err = sock_diag_destroy(sk, ECONNABORTED); + + sock_put(sk); + + return err; +} + +static int udp_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *req) +{ + return __udp_diag_destroy(in_skb, req, &udp_table); +} + +static int udplite_diag_destroy(struct sk_buff *in_skb, + const struct inet_diag_req_v2 *req) +{ + return __udp_diag_destroy(in_skb, req, &udplite_table); +} + +#endif + static const struct inet_diag_handler udp_diag_handler = { .dump = udp_diag_dump, .dump_one = udp_diag_dump_one, .idiag_get_info = udp_diag_get_info, .idiag_type = IPPROTO_UDP, .idiag_info_size = 0, +#ifdef CONFIG_INET_DIAG_DESTROY + .destroy = udp_diag_destroy, +#endif }; static void udplite_diag_dump(struct sk_buff *skb, struct netlink_callback *cb, @@ -192,6 +270,9 @@ static const struct inet_diag_handler udplite_diag_handler = { .idiag_get_info = udp_diag_get_info, .idiag_type = IPPROTO_UDPLITE, .idiag_info_size = 0, +#ifdef CONFIG_INET_DIAG_DESTROY + .destroy = udplite_diag_destroy, +#endif }; static int __init udp_diag_init(void) diff --git a/net/ipv4/xfrm4_policy.c b/net/ipv4/xfrm4_policy.c index d9758ecdcba6..0a6bcf7e7b54 100644 --- a/net/ipv4/xfrm4_policy.c +++ b/net/ipv4/xfrm4_policy.c @@ -22,14 +22,16 @@ static struct xfrm_policy_afinfo xfrm4_policy_afinfo; static struct dst_entry *__xfrm4_dst_lookup(struct net *net, struct flowi4 *fl4, int tos, int oif, const xfrm_address_t *saddr, - const xfrm_address_t *daddr) + const xfrm_address_t *daddr, + u32 mark) { struct rtable *rt; memset(fl4, 0, sizeof(*fl4)); fl4->daddr = daddr->a4; fl4->flowi4_tos = tos; - fl4->flowi4_oif = oif; + fl4->flowi4_oif = l3mdev_master_ifindex_by_index(net, oif); + fl4->flowi4_mark = mark; if (saddr) fl4->saddr = saddr->a4; @@ -44,20 +46,22 @@ static struct dst_entry *__xfrm4_dst_lookup(struct net *net, struct flowi4 *fl4, static struct dst_entry *xfrm4_dst_lookup(struct net *net, int tos, int oif, const xfrm_address_t *saddr, - const xfrm_address_t *daddr) + const xfrm_address_t *daddr, + u32 mark) { struct flowi4 fl4; - return __xfrm4_dst_lookup(net, &fl4, tos, oif, saddr, daddr); + return __xfrm4_dst_lookup(net, &fl4, tos, oif, saddr, daddr, mark); } static int xfrm4_get_saddr(struct net *net, int oif, - xfrm_address_t *saddr, xfrm_address_t *daddr) + xfrm_address_t *saddr, xfrm_address_t *daddr, + u32 mark) { struct dst_entry *dst; struct flowi4 fl4; - dst = __xfrm4_dst_lookup(net, &fl4, 0, oif, NULL, daddr); + dst = __xfrm4_dst_lookup(net, &fl4, 0, oif, NULL, daddr, mark); if (IS_ERR(dst)) return -EHOSTUNREACH; diff --git a/net/ipv6/addrconf.c b/net/ipv6/addrconf.c index 392fc8ac4c6a..25891f661883 100644 --- a/net/ipv6/addrconf.c +++ b/net/ipv6/addrconf.c @@ -112,6 +112,27 @@ static inline u32 cstamp_delta(unsigned long cstamp) return (cstamp - INITIAL_JIFFIES) * 100UL / HZ; } +static inline s32 rfc3315_s14_backoff_init(s32 irt) +{ + /* multiply 'initial retransmission time' by 0.9 .. 1.1 */ + u64 tmp = (900000 + prandom_u32() % 200001) * (u64)irt; + do_div(tmp, 1000000); + return (s32)tmp; +} + +static inline s32 rfc3315_s14_backoff_update(s32 rt, s32 mrt) +{ + /* multiply 'retransmission timeout' by 1.9 .. 2.1 */ + u64 tmp = (1900000 + prandom_u32() % 200001) * (u64)rt; + do_div(tmp, 1000000); + if ((s32)tmp > mrt) { + /* multiply 'maximum retransmission time' by 0.9 .. 1.1 */ + tmp = (900000 + prandom_u32() % 200001) * (u64)mrt; + do_div(tmp, 1000000); + } + return (s32)tmp; +} + #ifdef CONFIG_SYSCTL static int addrconf_sysctl_register(struct inet6_dev *idev); static void addrconf_sysctl_unregister(struct inet6_dev *idev); @@ -187,6 +208,7 @@ static struct ipv6_devconf ipv6_devconf __read_mostly = { .dad_transmits = 1, .rtr_solicits = MAX_RTR_SOLICITATIONS, .rtr_solicit_interval = RTR_SOLICITATION_INTERVAL, + .rtr_solicit_max_interval = RTR_SOLICITATION_MAX_INTERVAL, .rtr_solicit_delay = MAX_RTR_SOLICITATION_DELAY, .use_tempaddr = 0, .temp_valid_lft = TEMP_VALID_LIFETIME, @@ -202,9 +224,11 @@ static struct ipv6_devconf ipv6_devconf __read_mostly = { .accept_ra_rtr_pref = 1, .rtr_probe_interval = 60 * HZ, #ifdef CONFIG_IPV6_ROUTE_INFO + .accept_ra_rt_info_min_plen = 0, .accept_ra_rt_info_max_plen = 0, #endif #endif + .accept_ra_rt_table = 0, .proxy_ndp = 0, .accept_source_route = 0, /* we do not accept RH0 by default. */ .disable_ipv6 = 0, @@ -231,6 +255,7 @@ static struct ipv6_devconf ipv6_devconf_dflt __read_mostly = { .dad_transmits = 1, .rtr_solicits = MAX_RTR_SOLICITATIONS, .rtr_solicit_interval = RTR_SOLICITATION_INTERVAL, + .rtr_solicit_max_interval = RTR_SOLICITATION_MAX_INTERVAL, .rtr_solicit_delay = MAX_RTR_SOLICITATION_DELAY, .use_tempaddr = 0, .temp_valid_lft = TEMP_VALID_LIFETIME, @@ -246,9 +271,11 @@ static struct ipv6_devconf ipv6_devconf_dflt __read_mostly = { .accept_ra_rtr_pref = 1, .rtr_probe_interval = 60 * HZ, #ifdef CONFIG_IPV6_ROUTE_INFO + .accept_ra_rt_info_min_plen = 0, .accept_ra_rt_info_max_plen = 0, #endif #endif + .accept_ra_rt_table = 0, .proxy_ndp = 0, .accept_source_route = 0, /* we do not accept RH0 by default. */ .disable_ipv6 = 0, @@ -2157,6 +2184,31 @@ static void __ipv6_try_regen_rndid(struct inet6_dev *idev, struct in6_addr *tmp __ipv6_regen_rndid(idev); } +u32 addrconf_rt_table(const struct net_device *dev, u32 default_table) { + /* Determines into what table to put autoconf PIO/RIO/default routes + * learned on this device. + * + * - If 0, use the same table for every device. This puts routes into + * one of RT_TABLE_{PREFIX,INFO,DFLT} depending on the type of route + * (but note that these three are currently all equal to + * RT6_TABLE_MAIN). + * - If > 0, use the specified table. + * - If < 0, put routes into table dev->ifindex + (-rt_table). + */ + struct inet6_dev *idev = in6_dev_get(dev); + u32 table; + int sysctl = idev->cnf.accept_ra_rt_table; + if (sysctl == 0) { + table = default_table; + } else if (sysctl > 0) { + table = (u32) sysctl; + } else { + table = (unsigned) dev->ifindex + (-sysctl); + } + in6_dev_put(idev); + return table; +} + /* * Add prefix route. */ @@ -2166,7 +2218,7 @@ addrconf_prefix_route(struct in6_addr *pfx, int plen, struct net_device *dev, unsigned long expires, u32 flags) { struct fib6_config cfg = { - .fc_table = l3mdev_fib_table(dev) ? : RT6_TABLE_PREFIX, + .fc_table = l3mdev_fib_table(dev) ? : addrconf_rt_table(dev, RT6_TABLE_PREFIX), .fc_metric = IP6_RT_PRIO_ADDRCONF, .fc_ifindex = dev->ifindex, .fc_expires = expires, @@ -2199,7 +2251,7 @@ static struct rt6_info *addrconf_get_prefix_route(const struct in6_addr *pfx, struct fib6_node *fn; struct rt6_info *rt = NULL; struct fib6_table *table; - u32 tb_id = l3mdev_fib_table(dev) ? : RT6_TABLE_PREFIX; + u32 tb_id = l3mdev_fib_table(dev) ? : addrconf_rt_table(dev, RT6_TABLE_PREFIX); table = fib6_get_table(dev_net(dev), tb_id); if (!table) @@ -3488,7 +3540,7 @@ static void addrconf_rs_timer(unsigned long data) if (idev->if_flags & IF_RA_RCVD) goto out; - if (idev->rs_probes++ < idev->cnf.rtr_solicits) { + if (idev->rs_probes++ < idev->cnf.rtr_solicits || idev->cnf.rtr_solicits < 0) { write_unlock(&idev->lock); if (!ipv6_get_lladdr(dev, &lladdr, IFA_F_TENTATIVE)) ndisc_send_rs(dev, &lladdr, @@ -3497,11 +3549,13 @@ static void addrconf_rs_timer(unsigned long data) goto put; write_lock(&idev->lock); + idev->rs_interval = rfc3315_s14_backoff_update( + idev->rs_interval, idev->cnf.rtr_solicit_max_interval); /* The wait after the last probe can be shorter */ addrconf_mod_rs_timer(idev, (idev->rs_probes == idev->cnf.rtr_solicits) ? idev->cnf.rtr_solicit_delay : - idev->cnf.rtr_solicit_interval); + idev->rs_interval); } else { /* * Note: we do not support deprecated "all on-link" @@ -3729,7 +3783,7 @@ static void addrconf_dad_completed(struct inet6_ifaddr *ifp) send_mld = ifp->scope == IFA_LINK && ipv6_lonely_lladdr(ifp); send_rs = send_mld && ipv6_accept_ra(ifp->idev) && - ifp->idev->cnf.rtr_solicits > 0 && + ifp->idev->cnf.rtr_solicits != 0 && (dev->flags&IFF_LOOPBACK) == 0; read_unlock_bh(&ifp->idev->lock); @@ -3751,10 +3805,11 @@ static void addrconf_dad_completed(struct inet6_ifaddr *ifp) write_lock_bh(&ifp->idev->lock); spin_lock(&ifp->lock); + ifp->idev->rs_interval = rfc3315_s14_backoff_init( + ifp->idev->cnf.rtr_solicit_interval); ifp->idev->rs_probes = 1; ifp->idev->if_flags |= IF_RS_SENT; - addrconf_mod_rs_timer(ifp->idev, - ifp->idev->cnf.rtr_solicit_interval); + addrconf_mod_rs_timer(ifp->idev, ifp->idev->rs_interval); spin_unlock(&ifp->lock); write_unlock_bh(&ifp->idev->lock); } @@ -4671,6 +4726,8 @@ static inline void ipv6_store_devconf(struct ipv6_devconf *cnf, array[DEVCONF_RTR_SOLICITS] = cnf->rtr_solicits; array[DEVCONF_RTR_SOLICIT_INTERVAL] = jiffies_to_msecs(cnf->rtr_solicit_interval); + array[DEVCONF_RTR_SOLICIT_MAX_INTERVAL] = + jiffies_to_msecs(cnf->rtr_solicit_max_interval); array[DEVCONF_RTR_SOLICIT_DELAY] = jiffies_to_msecs(cnf->rtr_solicit_delay); array[DEVCONF_FORCE_MLD_VERSION] = cnf->force_mld_version; @@ -4692,9 +4749,11 @@ static inline void ipv6_store_devconf(struct ipv6_devconf *cnf, array[DEVCONF_RTR_PROBE_INTERVAL] = jiffies_to_msecs(cnf->rtr_probe_interval); #ifdef CONFIG_IPV6_ROUTE_INFO + array[DEVCONF_ACCEPT_RA_RT_INFO_MIN_PLEN] = cnf->accept_ra_rt_info_min_plen; array[DEVCONF_ACCEPT_RA_RT_INFO_MAX_PLEN] = cnf->accept_ra_rt_info_max_plen; #endif #endif + array[DEVCONF_ACCEPT_RA_RT_TABLE] = cnf->accept_ra_rt_table; array[DEVCONF_PROXY_NDP] = cnf->proxy_ndp; array[DEVCONF_ACCEPT_SOURCE_ROUTE] = cnf->accept_source_route; #ifdef CONFIG_IPV6_OPTIMISTIC_DAD @@ -4878,7 +4937,7 @@ static int inet6_set_iftoken(struct inet6_dev *idev, struct in6_addr *token) return -EINVAL; if (!ipv6_accept_ra(idev)) return -EINVAL; - if (idev->cnf.rtr_solicits <= 0) + if (idev->cnf.rtr_solicits == 0) return -EINVAL; write_lock_bh(&idev->lock); @@ -4903,8 +4962,10 @@ static int inet6_set_iftoken(struct inet6_dev *idev, struct in6_addr *token) if (update_rs) { idev->if_flags |= IF_RS_SENT; + idev->rs_interval = rfc3315_s14_backoff_init( + idev->cnf.rtr_solicit_interval); idev->rs_probes = 1; - addrconf_mod_rs_timer(idev, idev->cnf.rtr_solicit_interval); + addrconf_mod_rs_timer(idev, idev->rs_interval); } /* Well, that's kinda nasty ... */ @@ -5542,6 +5603,13 @@ static struct addrconf_sysctl_table .proc_handler = proc_dointvec_jiffies, }, { + .procname = "router_solicitation_max_interval", + .data = &ipv6_devconf.rtr_solicit_max_interval, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec_jiffies, + }, + { .procname = "router_solicitation_delay", .data = &ipv6_devconf.rtr_solicit_delay, .maxlen = sizeof(int), @@ -5651,6 +5719,13 @@ static struct addrconf_sysctl_table }, #ifdef CONFIG_IPV6_ROUTE_INFO { + .procname = "accept_ra_rt_info_min_plen", + .data = &ipv6_devconf.accept_ra_rt_info_min_plen, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { .procname = "accept_ra_rt_info_max_plen", .data = &ipv6_devconf.accept_ra_rt_info_max_plen, .maxlen = sizeof(int), @@ -5660,6 +5735,13 @@ static struct addrconf_sysctl_table #endif #endif { + .procname = "accept_ra_rt_table", + .data = &ipv6_devconf.accept_ra_rt_table, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = proc_dointvec, + }, + { .procname = "proxy_ndp", .data = &ipv6_devconf.proxy_ndp, .maxlen = sizeof(int), diff --git a/net/ipv6/af_inet6.c b/net/ipv6/af_inet6.c index 37a562fc13d5..8e38e058a66c 100644 --- a/net/ipv6/af_inet6.c +++ b/net/ipv6/af_inet6.c @@ -64,6 +64,20 @@ #include <asm/uaccess.h> #include <linux/mroute6.h> +#ifdef CONFIG_ANDROID_PARANOID_NETWORK +#include <linux/android_aid.h> + +static inline int current_has_network(void) +{ + return in_egroup_p(AID_INET) || capable(CAP_NET_RAW); +} +#else +static inline int current_has_network(void) +{ + return 1; +} +#endif + MODULE_AUTHOR("Cast of dozens"); MODULE_DESCRIPTION("IPv6 protocol stack for Linux"); MODULE_LICENSE("GPL"); @@ -112,6 +126,9 @@ static int inet6_create(struct net *net, struct socket *sock, int protocol, if (protocol < 0 || protocol >= IPPROTO_MAX) return -EINVAL; + if (!current_has_network()) + return -EACCES; + /* Look for the requested type/protocol pair. */ lookup_protocol: err = -ESOCKTNOSUPPORT; @@ -158,8 +175,7 @@ lookup_protocol: } err = -EPERM; - if (sock->type == SOCK_RAW && !kern && - !ns_capable(net->user_ns, CAP_NET_RAW)) + if (sock->type == SOCK_RAW && !kern && !capable(CAP_NET_RAW)) goto out_rcu_unlock; sock->ops = answer->ops; @@ -676,6 +692,7 @@ int inet6_sk_rebuild_header(struct sock *sk) fl6.flowi6_mark = sk->sk_mark; fl6.fl6_dport = inet->inet_dport; fl6.fl6_sport = inet->inet_sport; + fl6.flowi6_uid = sk->sk_uid; security_sk_classify_flow(sk, flowi6_to_flowi(&fl6)); rcu_read_lock(); diff --git a/net/ipv6/ah6.c b/net/ipv6/ah6.c index 519e0730d9f6..98d253d7bed3 100644 --- a/net/ipv6/ah6.c +++ b/net/ipv6/ah6.c @@ -667,9 +667,10 @@ static int ah6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, return 0; if (type == NDISC_REDIRECT) - ip6_redirect(skb, net, skb->dev->ifindex, 0); + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); else - ip6_update_pmtu(skb, net, info, 0, 0); + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); xfrm_state_put(x); return 0; diff --git a/net/ipv6/datagram.c b/net/ipv6/datagram.c index 389b6367a810..60687f60bea1 100644 --- a/net/ipv6/datagram.c +++ b/net/ipv6/datagram.c @@ -166,6 +166,7 @@ ipv4_connected: fl6.flowi6_mark = sk->sk_mark; fl6.fl6_dport = inet->inet_dport; fl6.fl6_sport = inet->inet_sport; + fl6.flowi6_uid = sk->sk_uid; if (!fl6.flowi6_oif) fl6.flowi6_oif = np->sticky_pktinfo.ipi6_ifindex; diff --git a/net/ipv6/esp6.c b/net/ipv6/esp6.c index da158a3acac4..37268a37312c 100644 --- a/net/ipv6/esp6.c +++ b/net/ipv6/esp6.c @@ -480,9 +480,10 @@ static int esp6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, return 0; if (type == NDISC_REDIRECT) - ip6_redirect(skb, net, skb->dev->ifindex, 0); + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); else - ip6_update_pmtu(skb, net, info, 0, 0); + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); xfrm_state_put(x); return 0; diff --git a/net/ipv6/exthdrs_core.c b/net/ipv6/exthdrs_core.c index 9508a20fbf61..840a4388f860 100644 --- a/net/ipv6/exthdrs_core.c +++ b/net/ipv6/exthdrs_core.c @@ -166,15 +166,15 @@ EXPORT_SYMBOL_GPL(ipv6_find_tlv); * to explore inner IPv6 header, eg. ICMPv6 error messages. * * If target header is found, its offset is set in *offset and return protocol - * number. Otherwise, return -1. + * number. Otherwise, return -ENOENT or -EBADMSG. * * If the first fragment doesn't contain the final protocol header or * NEXTHDR_NONE it is considered invalid. * * Note that non-1st fragment is special case that "the protocol number * of last header" is "next header" field in Fragment header. In this case, - * *offset is meaningless and fragment offset is stored in *fragoff if fragoff - * isn't NULL. + * *offset is meaningless. If fragoff is not NULL, the fragment offset is + * stored in *fragoff; if it is NULL, return -EINVAL. * * if flags is not NULL and it's a fragment, then the frag flag * IP6_FH_F_FRAG will be set. If it's an AH header, the @@ -253,9 +253,12 @@ int ipv6_find_hdr(const struct sk_buff *skb, unsigned int *offset, if (target < 0 && ((!ipv6_ext_hdr(hp->nexthdr)) || hp->nexthdr == NEXTHDR_NONE)) { - if (fragoff) + if (fragoff) { *fragoff = _frag_off; - return hp->nexthdr; + return hp->nexthdr; + } else { + return -EINVAL; + } } if (!found) return -ENOENT; diff --git a/net/ipv6/icmp.c b/net/ipv6/icmp.c index fa96e05cf22b..463f533bd289 100644 --- a/net/ipv6/icmp.c +++ b/net/ipv6/icmp.c @@ -92,9 +92,10 @@ static void icmpv6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, struct net *net = dev_net(skb->dev); if (type == ICMPV6_PKT_TOOBIG) - ip6_update_pmtu(skb, net, info, 0, 0); + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); else if (type == NDISC_REDIRECT) - ip6_redirect(skb, net, skb->dev->ifindex, 0); + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); if (!(type & ICMPV6_INFOMSG_MASK)) if (icmp6->icmp6_type == ICMPV6_ECHO_REQUEST) @@ -482,6 +483,7 @@ static void icmp6_send(struct sk_buff *skb, u8 type, u8 code, __u32 info) fl6.flowi6_oif = iif; fl6.fl6_icmp_type = type; fl6.fl6_icmp_code = code; + fl6.flowi6_uid = sock_net_uid(net, NULL); security_skb_classify_flow(skb, flowi6_to_flowi(&fl6)); sk = icmpv6_xmit_lock(net); @@ -586,6 +588,7 @@ static void icmpv6_echo_reply(struct sk_buff *skb) fl6.flowi6_oif = l3mdev_fib_oif(skb->dev); fl6.fl6_icmp_type = ICMPV6_ECHO_REPLY; fl6.flowi6_mark = mark; + fl6.flowi6_uid = sock_net_uid(net, NULL); security_skb_classify_flow(skb, flowi6_to_flowi(&fl6)); sk = icmpv6_xmit_lock(net); diff --git a/net/ipv6/inet6_connection_sock.c b/net/ipv6/inet6_connection_sock.c index b31ab511c767..ab9eda00dde5 100644 --- a/net/ipv6/inet6_connection_sock.c +++ b/net/ipv6/inet6_connection_sock.c @@ -86,6 +86,7 @@ struct dst_entry *inet6_csk_route_req(const struct sock *sk, fl6->flowi6_mark = ireq->ir_mark; fl6->fl6_dport = ireq->ir_rmt_port; fl6->fl6_sport = htons(ireq->ir_num); + fl6->flowi6_uid = sk->sk_uid; security_req_classify_flow(req, flowi6_to_flowi(fl6)); dst = ip6_dst_lookup_flow(sock_net(sk), sk, fl6, final_p); @@ -134,6 +135,7 @@ static struct dst_entry *inet6_csk_route_socket(struct sock *sk, fl6->flowi6_mark = sk->sk_mark; fl6->fl6_sport = inet->inet_sport; fl6->fl6_dport = inet->inet_dport; + fl6->flowi6_uid = sk->sk_uid; security_sk_classify_flow(sk, flowi6_to_flowi(fl6)); rcu_read_lock(); diff --git a/net/ipv6/ip6_fib.c b/net/ipv6/ip6_fib.c index 30eb8bdcdbda..24e91d438f21 100644 --- a/net/ipv6/ip6_fib.c +++ b/net/ipv6/ip6_fib.c @@ -910,6 +910,7 @@ add: fn->fn_flags |= RTN_RTINFO; } nsiblings = iter->rt6i_nsiblings; + iter->rt6i_node = NULL; fib6_purge_rt(iter, fn, info->nl_net); if (fn->rr_ptr == iter) fn->rr_ptr = NULL; @@ -924,6 +925,7 @@ add: break; if (rt6_qualify_for_ecmp(iter)) { *ins = iter->dst.rt6_next; + iter->rt6i_node = NULL; fib6_purge_rt(iter, fn, info->nl_net); if (fn->rr_ptr == iter) fn->rr_ptr = NULL; diff --git a/net/ipv6/ip6_gre.c b/net/ipv6/ip6_gre.c index 23e6a845c7fa..17eee5083548 100644 --- a/net/ipv6/ip6_gre.c +++ b/net/ipv6/ip6_gre.c @@ -804,6 +804,8 @@ static inline int ip6gre_xmit_ipv4(struct sk_buff *skb, struct net_device *dev) if (t->parms.flags & IP6_TNL_F_USE_ORIG_FWMARK) fl6.flowi6_mark = skb->mark; + fl6.flowi6_uid = sock_net_uid(dev_net(dev), NULL); + err = ip6gre_xmit2(skb, dev, dsfield, &fl6, encap_limit, &mtu); if (err != 0) { /* XXX: send ICMP error even if DF is not set. */ @@ -854,6 +856,8 @@ static inline int ip6gre_xmit_ipv6(struct sk_buff *skb, struct net_device *dev) if (t->parms.flags & IP6_TNL_F_USE_ORIG_FWMARK) fl6.flowi6_mark = skb->mark; + fl6.flowi6_uid = sock_net_uid(dev_net(dev), NULL); + err = ip6gre_xmit2(skb, dev, dsfield, &fl6, encap_limit, &mtu); if (err != 0) { if (err == -EMSGSIZE) diff --git a/net/ipv6/ip6_tunnel.c b/net/ipv6/ip6_tunnel.c index 80f88df280d7..f6f3c64b146a 100644 --- a/net/ipv6/ip6_tunnel.c +++ b/net/ipv6/ip6_tunnel.c @@ -1122,6 +1122,8 @@ ip4ip6_tnl_xmit(struct sk_buff *skb, struct net_device *dev) memcpy(&fl6, &t->fl.u.ip6, sizeof(fl6)); fl6.flowi6_proto = IPPROTO_IPIP; + fl6.flowi6_uid = sock_net_uid(dev_net(dev), NULL); + dsfield = ipv4_get_dsfield(iph); if (t->parms.flags & IP6_TNL_F_USE_ORIG_TCLASS) @@ -1179,6 +1181,7 @@ ip6ip6_tnl_xmit(struct sk_buff *skb, struct net_device *dev) memcpy(&fl6, &t->fl.u.ip6, sizeof(fl6)); fl6.flowi6_proto = IPPROTO_IPV6; + fl6.flowi6_uid = sock_net_uid(dev_net(dev), NULL); dsfield = ipv6_get_dsfield(ipv6h); if (t->parms.flags & IP6_TNL_F_USE_ORIG_TCLASS) diff --git a/net/ipv6/ip6_vti.c b/net/ipv6/ip6_vti.c index 13f686253ae4..9df1947e79eb 100644 --- a/net/ipv6/ip6_vti.c +++ b/net/ipv6/ip6_vti.c @@ -630,9 +630,10 @@ static int vti6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, return 0; if (type == NDISC_REDIRECT) - ip6_redirect(skb, net, skb->dev->ifindex, 0); + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); else - ip6_update_pmtu(skb, net, info, 0, 0); + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); xfrm_state_put(x); return 0; diff --git a/net/ipv6/ipcomp6.c b/net/ipv6/ipcomp6.c index 1b9316e1386a..54d165b9845a 100644 --- a/net/ipv6/ipcomp6.c +++ b/net/ipv6/ipcomp6.c @@ -74,9 +74,10 @@ static int ipcomp6_err(struct sk_buff *skb, struct inet6_skb_parm *opt, return 0; if (type == NDISC_REDIRECT) - ip6_redirect(skb, net, skb->dev->ifindex, 0); + ip6_redirect(skb, net, skb->dev->ifindex, 0, + sock_net_uid(net, NULL)); else - ip6_update_pmtu(skb, net, info, 0, 0); + ip6_update_pmtu(skb, net, info, 0, 0, sock_net_uid(net, NULL)); xfrm_state_put(x); return 0; diff --git a/net/ipv6/ndisc.c b/net/ipv6/ndisc.c index 0bf375177a9a..e16a05ca4879 100644 --- a/net/ipv6/ndisc.c +++ b/net/ipv6/ndisc.c @@ -188,7 +188,8 @@ static struct nd_opt_hdr *ndisc_next_option(struct nd_opt_hdr *cur, static inline int ndisc_is_useropt(struct nd_opt_hdr *opt) { return opt->nd_opt_type == ND_OPT_RDNSS || - opt->nd_opt_type == ND_OPT_DNSSL; + opt->nd_opt_type == ND_OPT_DNSSL || + opt->nd_opt_type == ND_OPT_CAPTIVE_PORTAL; } static struct nd_opt_hdr *ndisc_next_useropt(struct nd_opt_hdr *cur, @@ -1358,6 +1359,8 @@ skip_linkparms: if (ri->prefix_len == 0 && !in6_dev->cnf.accept_ra_defrtr) continue; + if (ri->prefix_len < in6_dev->cnf.accept_ra_rt_info_min_plen) + continue; if (ri->prefix_len > in6_dev->cnf.accept_ra_rt_info_max_plen) continue; rt6_route_rcv(skb->dev, (u8 *)p, (p->nd_opt_len) << 3, diff --git a/net/ipv6/netfilter.c b/net/ipv6/netfilter.c index d11c46833d61..39970e212ad5 100644 --- a/net/ipv6/netfilter.c +++ b/net/ipv6/netfilter.c @@ -26,6 +26,7 @@ int ip6_route_me_harder(struct net *net, struct sk_buff *skb) struct flowi6 fl6 = { .flowi6_oif = skb->sk ? skb->sk->sk_bound_dev_if : 0, .flowi6_mark = skb->mark, + .flowi6_uid = sock_net_uid(net, skb->sk), .daddr = iph->daddr, .saddr = iph->saddr, }; diff --git a/net/ipv6/netfilter/ip6_tables.c b/net/ipv6/netfilter/ip6_tables.c index 43d26625b80f..305c60850e7a 100644 --- a/net/ipv6/netfilter/ip6_tables.c +++ b/net/ipv6/netfilter/ip6_tables.c @@ -976,10 +976,6 @@ copy_entries_to_user(unsigned int total_size, return PTR_ERR(counters); loc_cpu_entry = private->entries; - if (copy_to_user(userptr, loc_cpu_entry, total_size) != 0) { - ret = -EFAULT; - goto free_counters; - } /* FIXME: use iterator macros --RR */ /* ... then go back and fix counters and names */ @@ -989,6 +985,10 @@ copy_entries_to_user(unsigned int total_size, const struct xt_entry_target *t; e = (struct ip6t_entry *)(loc_cpu_entry + off); + if (copy_to_user(userptr + off, e, sizeof(*e))) { + ret = -EFAULT; + goto free_counters; + } if (copy_to_user(userptr + off + offsetof(struct ip6t_entry, counters), &counters[num], @@ -1002,23 +1002,14 @@ copy_entries_to_user(unsigned int total_size, i += m->u.match_size) { m = (void *)e + i; - if (copy_to_user(userptr + off + i - + offsetof(struct xt_entry_match, - u.user.name), - m->u.kernel.match->name, - strlen(m->u.kernel.match->name)+1) - != 0) { + if (xt_match_to_user(m, userptr + off + i)) { ret = -EFAULT; goto free_counters; } } t = ip6t_get_target_c(e); - if (copy_to_user(userptr + off + e->target_offset - + offsetof(struct xt_entry_target, - u.user.name), - t->u.kernel.target->name, - strlen(t->u.kernel.target->name)+1) != 0) { + if (xt_target_to_user(t, userptr + off + e->target_offset)) { ret = -EFAULT; goto free_counters; } diff --git a/net/ipv6/netfilter/ip6t_NPT.c b/net/ipv6/netfilter/ip6t_NPT.c index 590f767db5d4..a379d2f79b19 100644 --- a/net/ipv6/netfilter/ip6t_NPT.c +++ b/net/ipv6/netfilter/ip6t_NPT.c @@ -112,6 +112,7 @@ static struct xt_target ip6t_npt_target_reg[] __read_mostly = { .table = "mangle", .target = ip6t_snpt_tg, .targetsize = sizeof(struct ip6t_npt_tginfo), + .usersize = offsetof(struct ip6t_npt_tginfo, adjustment), .checkentry = ip6t_npt_checkentry, .family = NFPROTO_IPV6, .hooks = (1 << NF_INET_LOCAL_IN) | @@ -123,6 +124,7 @@ static struct xt_target ip6t_npt_target_reg[] __read_mostly = { .table = "mangle", .target = ip6t_dnpt_tg, .targetsize = sizeof(struct ip6t_npt_tginfo), + .usersize = offsetof(struct ip6t_npt_tginfo, adjustment), .checkentry = ip6t_npt_checkentry, .family = NFPROTO_IPV6, .hooks = (1 << NF_INET_PRE_ROUTING) | diff --git a/net/ipv6/ping.c b/net/ipv6/ping.c index c846cff26933..40b835720722 100644 --- a/net/ipv6/ping.c +++ b/net/ipv6/ping.c @@ -84,7 +84,7 @@ int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) struct icmp6hdr user_icmph; int addr_type; struct in6_addr *daddr; - int iif = 0; + int oif = 0; struct flowi6 fl6; int err; int hlimit; @@ -106,25 +106,30 @@ int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) if (u->sin6_family != AF_INET6) { return -EAFNOSUPPORT; } - if (sk->sk_bound_dev_if && - sk->sk_bound_dev_if != u->sin6_scope_id) { - return -EINVAL; - } daddr = &(u->sin6_addr); - iif = u->sin6_scope_id; + if (__ipv6_addr_needs_scope_id(ipv6_addr_type(daddr))) + oif = u->sin6_scope_id; } else { if (sk->sk_state != TCP_ESTABLISHED) return -EDESTADDRREQ; daddr = &sk->sk_v6_daddr; } - if (!iif) - iif = sk->sk_bound_dev_if; + if (!oif) + oif = sk->sk_bound_dev_if; + + if (!oif) + oif = np->sticky_pktinfo.ipi6_ifindex; + + if (!oif && ipv6_addr_is_multicast(daddr)) + oif = np->mcast_oif; + else if (!oif) + oif = np->ucast_oif; addr_type = ipv6_addr_type(daddr); - if (__ipv6_addr_needs_scope_id(addr_type) && !iif) - return -EINVAL; - if (addr_type & IPV6_ADDR_MAPPED) + if ((__ipv6_addr_needs_scope_id(addr_type) && !oif) || + (addr_type & IPV6_ADDR_MAPPED) || + (oif && sk->sk_bound_dev_if && oif != sk->sk_bound_dev_if)) return -EINVAL; /* TODO: use ip6_datagram_send_ctl to get options from cmsg */ @@ -134,16 +139,13 @@ int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) fl6.flowi6_proto = IPPROTO_ICMPV6; fl6.saddr = np->saddr; fl6.daddr = *daddr; + fl6.flowi6_oif = oif; fl6.flowi6_mark = sk->sk_mark; + fl6.flowi6_uid = sk->sk_uid; fl6.fl6_icmp_type = user_icmph.icmp6_type; fl6.fl6_icmp_code = user_icmph.icmp6_code; security_sk_classify_flow(sk, flowi6_to_flowi(&fl6)); - if (!fl6.flowi6_oif && ipv6_addr_is_multicast(&fl6.daddr)) - fl6.flowi6_oif = np->mcast_oif; - else if (!fl6.flowi6_oif) - fl6.flowi6_oif = np->ucast_oif; - dst = ip6_sk_dst_lookup_flow(sk, &fl6, daddr); if (IS_ERR(dst)) return PTR_ERR(dst); @@ -155,11 +157,6 @@ int ping_v6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) goto dst_err_out; } - if (!fl6.flowi6_oif && ipv6_addr_is_multicast(&fl6.daddr)) - fl6.flowi6_oif = np->mcast_oif; - else if (!fl6.flowi6_oif) - fl6.flowi6_oif = np->ucast_oif; - pfh.icmph.type = user_icmph.icmp6_type; pfh.icmph.code = user_icmph.icmp6_code; pfh.icmph.checksum = 0; diff --git a/net/ipv6/raw.c b/net/ipv6/raw.c index 67cdcd3d644f..8614321b4c54 100644 --- a/net/ipv6/raw.c +++ b/net/ipv6/raw.c @@ -784,6 +784,7 @@ static int rawv6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) memset(&fl6, 0, sizeof(fl6)); fl6.flowi6_mark = sk->sk_mark; + fl6.flowi6_uid = sk->sk_uid; if (sin6) { if (addr_len < SIN6_LEN_RFC2133) diff --git a/net/ipv6/route.c b/net/ipv6/route.c index f06a76878746..a4b75b2f6781 100644 --- a/net/ipv6/route.c +++ b/net/ipv6/route.c @@ -99,15 +99,12 @@ static void rt6_dst_from_metrics_check(struct rt6_info *rt); static int rt6_score_route(struct rt6_info *rt, int oif, int strict); #ifdef CONFIG_IPV6_ROUTE_INFO -static struct rt6_info *rt6_add_route_info(struct net *net, +static struct rt6_info *rt6_add_route_info(struct net_device *dev, const struct in6_addr *prefix, int prefixlen, - const struct in6_addr *gwaddr, - struct net_device *dev, - unsigned int pref); -static struct rt6_info *rt6_get_route_info(struct net *net, + const struct in6_addr *gwaddr, unsigned int pref); +static struct rt6_info *rt6_get_route_info(struct net_device *dev, const struct in6_addr *prefix, int prefixlen, - const struct in6_addr *gwaddr, - struct net_device *dev); + const struct in6_addr *gwaddr); #endif struct uncached_list { @@ -758,7 +755,6 @@ static bool rt6_is_gw_or_nonexthop(const struct rt6_info *rt) int rt6_route_rcv(struct net_device *dev, u8 *opt, int len, const struct in6_addr *gwaddr) { - struct net *net = dev_net(dev); struct route_info *rinfo = (struct route_info *) opt; struct in6_addr prefix_buf, *prefix; unsigned int pref; @@ -803,8 +799,7 @@ int rt6_route_rcv(struct net_device *dev, u8 *opt, int len, if (rinfo->prefix_len == 0) rt = rt6_get_dflt_router(gwaddr, dev); else - rt = rt6_get_route_info(net, prefix, rinfo->prefix_len, - gwaddr, dev); + rt = rt6_get_route_info(dev, prefix, rinfo->prefix_len, gwaddr); if (rt && !lifetime) { ip6_del_rt(rt); @@ -812,8 +807,7 @@ int rt6_route_rcv(struct net_device *dev, u8 *opt, int len, } if (!rt && lifetime) - rt = rt6_add_route_info(net, prefix, rinfo->prefix_len, gwaddr, - dev, pref); + rt = rt6_add_route_info(dev, prefix, rinfo->prefix_len, gwaddr, pref); else if (rt) rt->rt6i_flags = RTF_ROUTEINFO | (rt->rt6i_flags & ~RTF_PREF_MASK) | RTF_PREF(pref); @@ -1401,7 +1395,7 @@ static void ip6_rt_update_pmtu(struct dst_entry *dst, struct sock *sk, } void ip6_update_pmtu(struct sk_buff *skb, struct net *net, __be32 mtu, - int oif, u32 mark) + int oif, u32 mark, kuid_t uid) { const struct ipv6hdr *iph = (struct ipv6hdr *) skb->data; struct dst_entry *dst; @@ -1413,6 +1407,7 @@ void ip6_update_pmtu(struct sk_buff *skb, struct net *net, __be32 mtu, fl6.daddr = iph->daddr; fl6.saddr = iph->saddr; fl6.flowlabel = ip6_flowinfo(iph); + fl6.flowi6_uid = uid; dst = ip6_route_output(net, NULL, &fl6); if (!dst->error) @@ -1428,7 +1423,7 @@ void ip6_sk_update_pmtu(struct sk_buff *skb, struct sock *sk, __be32 mtu) if (!oif && skb->dev) oif = l3mdev_master_ifindex(skb->dev); - ip6_update_pmtu(skb, sock_net(sk), mtu, oif, sk->sk_mark); + ip6_update_pmtu(skb, sock_net(sk), mtu, oif, sk->sk_mark, sk->sk_uid); } EXPORT_SYMBOL_GPL(ip6_sk_update_pmtu); @@ -1509,7 +1504,8 @@ static struct dst_entry *ip6_route_redirect(struct net *net, flags, __ip6_route_redirect); } -void ip6_redirect(struct sk_buff *skb, struct net *net, int oif, u32 mark) +void ip6_redirect(struct sk_buff *skb, struct net *net, int oif, u32 mark, + kuid_t uid) { const struct ipv6hdr *iph = (struct ipv6hdr *) skb->data; struct dst_entry *dst; @@ -1522,6 +1518,7 @@ void ip6_redirect(struct sk_buff *skb, struct net *net, int oif, u32 mark) fl6.daddr = iph->daddr; fl6.saddr = iph->saddr; fl6.flowlabel = ip6_flowinfo(iph); + fl6.flowi6_uid = uid; dst = ip6_route_redirect(net, &fl6, &ipv6_hdr(skb)->saddr); rt6_do_redirect(dst, NULL, skb); @@ -1543,6 +1540,7 @@ void ip6_redirect_no_header(struct sk_buff *skb, struct net *net, int oif, fl6.flowi6_mark = mark; fl6.daddr = msg->dest; fl6.saddr = iph->daddr; + fl6.flowi6_uid = sock_net_uid(net, NULL); dst = ip6_route_redirect(net, &fl6, &iph->saddr); rt6_do_redirect(dst, NULL, skb); @@ -1551,7 +1549,8 @@ void ip6_redirect_no_header(struct sk_buff *skb, struct net *net, int oif, void ip6_sk_redirect(struct sk_buff *skb, struct sock *sk) { - ip6_redirect(skb, sock_net(sk), sk->sk_bound_dev_if, sk->sk_mark); + ip6_redirect(skb, sock_net(sk), sk->sk_bound_dev_if, sk->sk_mark, + sk->sk_uid); } EXPORT_SYMBOL_GPL(ip6_sk_redirect); @@ -1768,6 +1767,37 @@ static int ip6_convert_metrics(struct mx6_config *mxc, return -EINVAL; } +static struct rt6_info *ip6_nh_lookup_table(struct net *net, + struct fib6_config *cfg, + const struct in6_addr *gw_addr) +{ + struct flowi6 fl6 = { + .flowi6_oif = cfg->fc_ifindex, + .daddr = *gw_addr, + .saddr = cfg->fc_prefsrc, + }; + struct fib6_table *table; + struct rt6_info *rt; + int flags = 0; + + table = fib6_get_table(net, cfg->fc_table); + if (!table) + return NULL; + + if (!ipv6_addr_any(&cfg->fc_prefsrc)) + flags |= RT6_LOOKUP_F_HAS_SADDR; + + rt = ip6_pol_route(net, table, cfg->fc_ifindex, &fl6, flags); + + /* if table lookup failed, fall back to full lookup */ + if (rt == net->ipv6.ip6_null_entry) { + ip6_rt_put(rt); + rt = NULL; + } + + return rt; +} + static struct rt6_info *ip6_route_info_create(struct fib6_config *cfg) { struct net *net = cfg->fc_nlinfo.nl_net; @@ -1943,7 +1973,7 @@ static struct rt6_info *ip6_route_info_create(struct fib6_config *cfg) rt->rt6i_gateway = *gw_addr; if (gwa_type != (IPV6_ADDR_LINKLOCAL|IPV6_ADDR_UNICAST)) { - struct rt6_info *grt; + struct rt6_info *grt = NULL; /* IPv6 strictly inhibits using not link-local addresses as nexthop address. @@ -1955,7 +1985,12 @@ static struct rt6_info *ip6_route_info_create(struct fib6_config *cfg) if (!(gwa_type & IPV6_ADDR_UNICAST)) goto out; - grt = rt6_lookup(net, gw_addr, NULL, cfg->fc_ifindex, 1); + if (cfg->fc_table) + grt = ip6_nh_lookup_table(net, cfg, gw_addr); + + if (!grt) + grt = rt6_lookup(net, gw_addr, NULL, + cfg->fc_ifindex, 1); err = -EHOSTUNREACH; if (!grt) @@ -2274,18 +2309,16 @@ static void ip6_rt_copy_init(struct rt6_info *rt, struct rt6_info *ort) } #ifdef CONFIG_IPV6_ROUTE_INFO -static struct rt6_info *rt6_get_route_info(struct net *net, +static struct rt6_info *rt6_get_route_info(struct net_device *dev, const struct in6_addr *prefix, int prefixlen, - const struct in6_addr *gwaddr, - struct net_device *dev) + const struct in6_addr *gwaddr) { - u32 tb_id = l3mdev_fib_table(dev) ? : RT6_TABLE_INFO; - int ifindex = dev->ifindex; struct fib6_node *fn; struct rt6_info *rt = NULL; struct fib6_table *table; - table = fib6_get_table(net, tb_id); + table = fib6_get_table(dev_net(dev), + addrconf_rt_table(dev, RT6_TABLE_INFO)); if (!table) return NULL; @@ -2295,7 +2328,7 @@ static struct rt6_info *rt6_get_route_info(struct net *net, goto out; for (rt = fn->leaf; rt; rt = rt->dst.rt6_next) { - if (rt->dst.dev->ifindex != ifindex) + if (rt->dst.dev->ifindex != dev->ifindex) continue; if ((rt->rt6i_flags & (RTF_ROUTEINFO|RTF_GATEWAY)) != (RTF_ROUTEINFO|RTF_GATEWAY)) continue; @@ -2309,11 +2342,9 @@ out: return rt; } -static struct rt6_info *rt6_add_route_info(struct net *net, +static struct rt6_info *rt6_add_route_info(struct net_device *dev, const struct in6_addr *prefix, int prefixlen, - const struct in6_addr *gwaddr, - struct net_device *dev, - unsigned int pref) + const struct in6_addr *gwaddr, unsigned int pref) { struct fib6_config cfg = { .fc_metric = IP6_RT_PRIO_USER, @@ -2323,10 +2354,10 @@ static struct rt6_info *rt6_add_route_info(struct net *net, RTF_UP | RTF_PREF(pref), .fc_nlinfo.portid = 0, .fc_nlinfo.nlh = NULL, - .fc_nlinfo.nl_net = net, + .fc_nlinfo.nl_net = dev_net(dev), }; - cfg.fc_table = l3mdev_fib_table(dev) ? : RT6_TABLE_INFO, + cfg.fc_table = l3mdev_fib_table_by_index(dev_net(dev), dev->ifindex) ? : addrconf_rt_table(dev, RT6_TABLE_INFO); cfg.fc_dst = *prefix; cfg.fc_gateway = *gwaddr; @@ -2336,17 +2367,17 @@ static struct rt6_info *rt6_add_route_info(struct net *net, ip6_route_add(&cfg); - return rt6_get_route_info(net, prefix, prefixlen, gwaddr, dev); + return rt6_get_route_info(dev, prefix, prefixlen, gwaddr); } #endif struct rt6_info *rt6_get_dflt_router(const struct in6_addr *addr, struct net_device *dev) { - u32 tb_id = l3mdev_fib_table(dev) ? : RT6_TABLE_DFLT; struct rt6_info *rt; struct fib6_table *table; - table = fib6_get_table(dev_net(dev), tb_id); + table = fib6_get_table(dev_net(dev), + addrconf_rt_table(dev, RT6_TABLE_MAIN)); if (!table) return NULL; @@ -2368,7 +2399,7 @@ struct rt6_info *rt6_add_dflt_router(const struct in6_addr *gwaddr, unsigned int pref) { struct fib6_config cfg = { - .fc_table = l3mdev_fib_table(dev) ? : RT6_TABLE_DFLT, + .fc_table = l3mdev_fib_table(dev) ? : addrconf_rt_table(dev, RT6_TABLE_DFLT), .fc_metric = IP6_RT_PRIO_USER, .fc_ifindex = dev->ifindex, .fc_flags = RTF_GATEWAY | RTF_ADDRCONF | RTF_DEFAULT | @@ -2380,54 +2411,22 @@ struct rt6_info *rt6_add_dflt_router(const struct in6_addr *gwaddr, cfg.fc_gateway = *gwaddr; - if (!ip6_route_add(&cfg)) { - struct fib6_table *table; - - table = fib6_get_table(dev_net(dev), cfg.fc_table); - if (table) - table->flags |= RT6_TABLE_HAS_DFLT_ROUTER; - } + ip6_route_add(&cfg); return rt6_get_dflt_router(gwaddr, dev); } -static void __rt6_purge_dflt_routers(struct fib6_table *table) -{ - struct rt6_info *rt; - -restart: - read_lock_bh(&table->tb6_lock); - for (rt = table->tb6_root.leaf; rt; rt = rt->dst.rt6_next) { - if (rt->rt6i_flags & (RTF_DEFAULT | RTF_ADDRCONF) && - (!rt->rt6i_idev || rt->rt6i_idev->cnf.accept_ra != 2)) { - dst_hold(&rt->dst); - read_unlock_bh(&table->tb6_lock); - ip6_del_rt(rt); - goto restart; - } - } - read_unlock_bh(&table->tb6_lock); - table->flags &= ~RT6_TABLE_HAS_DFLT_ROUTER; +int rt6_addrconf_purge(struct rt6_info *rt, void *arg) { + if (rt->rt6i_flags & (RTF_DEFAULT | RTF_ADDRCONF) && + (!rt->rt6i_idev || rt->rt6i_idev->cnf.accept_ra != 2)) + return -1; + return 0; } void rt6_purge_dflt_routers(struct net *net) { - struct fib6_table *table; - struct hlist_head *head; - unsigned int h; - - rcu_read_lock(); - - for (h = 0; h < FIB6_TABLE_HASHSZ; h++) { - head = &net->ipv6.fib_table_hash[h]; - hlist_for_each_entry_rcu(table, head, tb6_hlist) { - if (table->flags & RT6_TABLE_HAS_DFLT_ROUTER) - __rt6_purge_dflt_routers(table); - } - } - - rcu_read_unlock(); + fib6_clean_all(net, rt6_addrconf_purge, NULL); } static void rtmsg_to_fib6_config(struct net *net, @@ -2753,6 +2752,7 @@ static const struct nla_policy rtm_ipv6_policy[RTA_MAX+1] = { [RTA_PREF] = { .type = NLA_U8 }, [RTA_ENCAP_TYPE] = { .type = NLA_U16 }, [RTA_ENCAP] = { .type = NLA_NESTED }, + [RTA_UID] = { .type = NLA_U32 }, [RTA_TABLE] = { .type = NLA_U32 }, }; @@ -3315,6 +3315,12 @@ static int inet6_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh) if (tb[RTA_MARK]) fl6.flowi6_mark = nla_get_u32(tb[RTA_MARK]); + if (tb[RTA_UID]) + fl6.flowi6_uid = make_kuid(current_user_ns(), + nla_get_u32(tb[RTA_UID])); + else + fl6.flowi6_uid = iif ? INVALID_UID : current_uid(); + if (iif) { struct net_device *dev; int flags = 0; diff --git a/net/ipv6/syncookies.c b/net/ipv6/syncookies.c index d2792580b112..2133cc5e6a74 100644 --- a/net/ipv6/syncookies.c +++ b/net/ipv6/syncookies.c @@ -229,6 +229,7 @@ struct sock *cookie_v6_check(struct sock *sk, struct sk_buff *skb) fl6.flowi6_mark = ireq->ir_mark; fl6.fl6_dport = ireq->ir_rmt_port; fl6.fl6_sport = inet_sk(sk)->inet_sport; + fl6.flowi6_uid = sk->sk_uid; security_req_classify_flow(req, flowi6_to_flowi(&fl6)); dst = ip6_dst_lookup_flow(sock_net(sk), sk, &fl6, final_p); diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index 68c9d033c718..728c5ac4d2b3 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -239,6 +239,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, fl6.flowi6_mark = sk->sk_mark; fl6.fl6_dport = usin->sin6_port; fl6.fl6_sport = inet->inet_sport; + fl6.flowi6_uid = sk->sk_uid; opt = rcu_dereference_protected(np->opt, sock_owned_by_user(sk)); final_p = fl6_update_dst(&fl6, opt, &final); @@ -840,6 +841,7 @@ static void tcp_v6_send_response(const struct sock *sk, struct sk_buff *skb, u32 fl6.flowi6_mark = IP6_REPLY_MARK(net, skb->mark); fl6.fl6_dport = t1->dest; fl6.fl6_sport = t1->source; + fl6.flowi6_uid = sock_net_uid(net, sk && sk_fullsock(sk) ? sk : NULL); security_skb_classify_flow(skb, flowi6_to_flowi(&fl6)); /* Pass a socket to ip6_dst_lookup either it is for RST @@ -1945,6 +1947,7 @@ struct proto tcpv6_prot = { .proto_cgroup = tcp_proto_cgroup, #endif .clear_sk = tcp_v6_clear_sk, + .diag_destroy = tcp_abort, }; static const struct inet6_protocol tcpv6_protocol = { diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c index be570cd7c9ae..00f262ca4290 100644 --- a/net/ipv6/udp.c +++ b/net/ipv6/udp.c @@ -1247,6 +1247,7 @@ do_udp_sendmsg: fl6.flowi6_oif = np->sticky_pktinfo.ipi6_ifindex; fl6.flowi6_mark = sk->sk_mark; + fl6.flowi6_uid = sk->sk_uid; if (msg->msg_controllen) { opt = &opt_space; @@ -1555,6 +1556,7 @@ struct proto udpv6_prot = { .compat_getsockopt = compat_udpv6_getsockopt, #endif .clear_sk = udp_v6_clear_sk, + .diag_destroy = udp_abort, }; static struct inet_protosw udpv6_protosw = { diff --git a/net/ipv6/xfrm6_policy.c b/net/ipv6/xfrm6_policy.c index 1ca0c2f3d92b..4d0c7115f78e 100644 --- a/net/ipv6/xfrm6_policy.c +++ b/net/ipv6/xfrm6_policy.c @@ -29,15 +29,17 @@ static struct xfrm_policy_afinfo xfrm6_policy_afinfo; static struct dst_entry *xfrm6_dst_lookup(struct net *net, int tos, int oif, const xfrm_address_t *saddr, - const xfrm_address_t *daddr) + const xfrm_address_t *daddr, + u32 mark) { struct flowi6 fl6; struct dst_entry *dst; int err; memset(&fl6, 0, sizeof(fl6)); - fl6.flowi6_oif = oif; + fl6.flowi6_oif = l3mdev_master_ifindex_by_index(net, oif); fl6.flowi6_flags = FLOWI_FLAG_SKIP_NH_OIF; + fl6.flowi6_mark = mark; memcpy(&fl6.daddr, daddr, sizeof(fl6.daddr)); if (saddr) memcpy(&fl6.saddr, saddr, sizeof(fl6.saddr)); @@ -54,12 +56,13 @@ static struct dst_entry *xfrm6_dst_lookup(struct net *net, int tos, int oif, } static int xfrm6_get_saddr(struct net *net, int oif, - xfrm_address_t *saddr, xfrm_address_t *daddr) + xfrm_address_t *saddr, xfrm_address_t *daddr, + u32 mark) { struct dst_entry *dst; struct net_device *dev; - dst = xfrm6_dst_lookup(net, 0, oif, NULL, daddr); + dst = xfrm6_dst_lookup(net, 0, oif, NULL, daddr, mark); if (IS_ERR(dst)) return -EHOSTUNREACH; diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c index bc2c5284d1cb..bcb7bce86d98 100644 --- a/net/l2tp/l2tp_ip6.c +++ b/net/l2tp/l2tp_ip6.c @@ -543,6 +543,7 @@ static int l2tp_ip6_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) memset(&fl6, 0, sizeof(fl6)); fl6.flowi6_mark = sk->sk_mark; + fl6.flowi6_uid = sk->sk_uid; if (lsa) { if (addr_len < SIN6_LEN_RFC2133) diff --git a/net/mac80211/ibss.c b/net/mac80211/ibss.c index 50fa92fe7d24..99723f3fe448 100644 --- a/net/mac80211/ibss.c +++ b/net/mac80211/ibss.c @@ -427,7 +427,7 @@ static void ieee80211_sta_join_ibss(struct ieee80211_sub_if_data *sdata, case NL80211_CHAN_WIDTH_5: case NL80211_CHAN_WIDTH_10: cfg80211_chandef_create(&chandef, cbss->channel, - NL80211_CHAN_WIDTH_20_NOHT); + NL80211_CHAN_NO_HT); chandef.width = sdata->u.ibss.chandef.width; break; case NL80211_CHAN_WIDTH_80: @@ -438,7 +438,7 @@ static void ieee80211_sta_join_ibss(struct ieee80211_sub_if_data *sdata, default: /* fall back to 20 MHz for unsupported modes */ cfg80211_chandef_create(&chandef, cbss->channel, - NL80211_CHAN_WIDTH_20_NOHT); + NL80211_CHAN_NO_HT); break; } diff --git a/net/mac80211/rate.c b/net/mac80211/rate.c index a4d9e9ee06be..bde0572d9c2d 100644 --- a/net/mac80211/rate.c +++ b/net/mac80211/rate.c @@ -173,9 +173,11 @@ ieee80211_rate_control_ops_get(const char *name) /* try default if specific alg requested but not found */ ops = ieee80211_try_rate_control_ops_get(ieee80211_default_rc_algo); - /* try built-in one if specific alg requested but not found */ - if (!ops && strlen(CONFIG_MAC80211_RC_DEFAULT)) + /* Note: check for > 0 is intentional to avoid clang warning */ + if (!ops && (strlen(CONFIG_MAC80211_RC_DEFAULT) > 0)) + /* try built-in one if specific alg requested but not found */ ops = ieee80211_try_rate_control_ops_get(CONFIG_MAC80211_RC_DEFAULT); + kernel_param_unlock(THIS_MODULE); return ops; diff --git a/net/netfilter/Kconfig b/net/netfilter/Kconfig index c244711a0b91..634f60331387 100644 --- a/net/netfilter/Kconfig +++ b/net/netfilter/Kconfig @@ -1278,6 +1278,8 @@ config NETFILTER_XT_MATCH_OWNER based on who created the socket: the user or group. It is also possible to check whether a socket actually exists. + Conflicts with '"quota, tag, uid" match' + config NETFILTER_XT_MATCH_POLICY tristate 'IPsec "policy" match support' depends on XFRM @@ -1311,6 +1313,22 @@ config NETFILTER_XT_MATCH_PKTTYPE To compile it as a module, choose M here. If unsure, say N. +config NETFILTER_XT_MATCH_QTAGUID + bool '"quota, tag, owner" match and stats support' + depends on NETFILTER_XT_MATCH_SOCKET + depends on NETFILTER_XT_MATCH_OWNER=n + help + This option replaces the `owner' match. In addition to matching + on uid, it keeps stats based on a tag assigned to a socket. + The full tag is comprised of a UID and an accounting tag. + The tags are assignable to sockets from user space (e.g. a download + manager can assign the socket to another UID for accounting). + Stats and control are done via /proc/net/xt_qtaguid/. + It replaces owner as it takes the same arguments, but should + really be recognized by the iptables tool. + + If unsure, say `N'. + config NETFILTER_XT_MATCH_QUOTA tristate '"quota" match support' depends on NETFILTER_ADVANCED @@ -1321,6 +1339,29 @@ config NETFILTER_XT_MATCH_QUOTA If you want to compile it as a module, say M here and read <file:Documentation/kbuild/modules.txt>. If unsure, say `N'. +config NETFILTER_XT_MATCH_QUOTA2 + tristate '"quota2" match support' + depends on NETFILTER_ADVANCED + help + This option adds a `quota2' match, which allows to match on a + byte counter correctly and not per CPU. + It allows naming the quotas. + This is based on http://xtables-addons.git.sourceforge.net + + If you want to compile it as a module, say M here and read + <file:Documentation/kbuild/modules.txt>. If unsure, say `N'. + +config NETFILTER_XT_MATCH_QUOTA2_LOG + bool '"quota2" Netfilter LOG support' + depends on NETFILTER_XT_MATCH_QUOTA2 + default n + help + This option allows `quota2' to log ONCE when a quota limit + is passed. It logs via NETLINK using the NETLINK_NFLOG family. + It logs similarly to how ipt_ULOG would without data. + + If unsure, say `N'. + config NETFILTER_XT_MATCH_RATEEST tristate '"rateest" match support' depends on NETFILTER_ADVANCED diff --git a/net/netfilter/Makefile b/net/netfilter/Makefile index 7638c36b498c..ad6a8aa63b1f 100644 --- a/net/netfilter/Makefile +++ b/net/netfilter/Makefile @@ -156,7 +156,9 @@ obj-$(CONFIG_NETFILTER_XT_MATCH_CGROUP) += xt_cgroup.o obj-$(CONFIG_NETFILTER_XT_MATCH_PHYSDEV) += xt_physdev.o obj-$(CONFIG_NETFILTER_XT_MATCH_PKTTYPE) += xt_pkttype.o obj-$(CONFIG_NETFILTER_XT_MATCH_POLICY) += xt_policy.o +obj-$(CONFIG_NETFILTER_XT_MATCH_QTAGUID) += xt_qtaguid_print.o xt_qtaguid.o obj-$(CONFIG_NETFILTER_XT_MATCH_QUOTA) += xt_quota.o +obj-$(CONFIG_NETFILTER_XT_MATCH_QUOTA2) += xt_quota2.o obj-$(CONFIG_NETFILTER_XT_MATCH_RATEEST) += xt_rateest.o obj-$(CONFIG_NETFILTER_XT_MATCH_REALM) += xt_realm.o obj-$(CONFIG_NETFILTER_XT_MATCH_RECENT) += xt_recent.o diff --git a/net/netfilter/x_tables.c b/net/netfilter/x_tables.c index 480ccd52a73f..2909c34dda7a 100644 --- a/net/netfilter/x_tables.c +++ b/net/netfilter/x_tables.c @@ -266,6 +266,60 @@ struct xt_target *xt_request_find_target(u8 af, const char *name, u8 revision) } EXPORT_SYMBOL_GPL(xt_request_find_target); + +static int xt_obj_to_user(u16 __user *psize, u16 size, + void __user *pname, const char *name, + u8 __user *prev, u8 rev) +{ + if (put_user(size, psize)) + return -EFAULT; + if (copy_to_user(pname, name, strlen(name) + 1)) + return -EFAULT; + if (put_user(rev, prev)) + return -EFAULT; + + return 0; +} + +#define XT_OBJ_TO_USER(U, K, TYPE, C_SIZE) \ + xt_obj_to_user(&U->u.TYPE##_size, C_SIZE ? : K->u.TYPE##_size, \ + U->u.user.name, K->u.kernel.TYPE->name, \ + &U->u.user.revision, K->u.kernel.TYPE->revision) + +int xt_data_to_user(void __user *dst, const void *src, + int usersize, int size) +{ + usersize = usersize ? : size; + if (copy_to_user(dst, src, usersize)) + return -EFAULT; + if (usersize != size && clear_user(dst + usersize, size - usersize)) + return -EFAULT; + + return 0; +} +EXPORT_SYMBOL_GPL(xt_data_to_user); + +#define XT_DATA_TO_USER(U, K, TYPE, C_SIZE) \ + xt_data_to_user(U->data, K->data, \ + K->u.kernel.TYPE->usersize, \ + C_SIZE ? : K->u.kernel.TYPE->TYPE##size) + +int xt_match_to_user(const struct xt_entry_match *m, + struct xt_entry_match __user *u) +{ + return XT_OBJ_TO_USER(u, m, match, 0) || + XT_DATA_TO_USER(u, m, match, 0); +} +EXPORT_SYMBOL_GPL(xt_match_to_user); + +int xt_target_to_user(const struct xt_entry_target *t, + struct xt_entry_target __user *u) +{ + return XT_OBJ_TO_USER(u, t, target, 0) || + XT_DATA_TO_USER(u, t, target, 0); +} +EXPORT_SYMBOL_GPL(xt_target_to_user); + static int match_revfn(u8 af, const char *name, u8 revision, int *bestp) { const struct xt_match *m; diff --git a/net/netfilter/xt_CT.c b/net/netfilter/xt_CT.c index febcfac7e3df..f2fe60ce286f 100644 --- a/net/netfilter/xt_CT.c +++ b/net/netfilter/xt_CT.c @@ -380,6 +380,7 @@ static struct xt_target xt_ct_tg_reg[] __read_mostly = { .name = "CT", .family = NFPROTO_UNSPEC, .targetsize = sizeof(struct xt_ct_target_info), + .usersize = offsetof(struct xt_ct_target_info, ct), .checkentry = xt_ct_tg_check_v0, .destroy = xt_ct_tg_destroy_v0, .target = xt_ct_target_v0, @@ -391,6 +392,7 @@ static struct xt_target xt_ct_tg_reg[] __read_mostly = { .family = NFPROTO_UNSPEC, .revision = 1, .targetsize = sizeof(struct xt_ct_target_info_v1), + .usersize = offsetof(struct xt_ct_target_info, ct), .checkentry = xt_ct_tg_check_v1, .destroy = xt_ct_tg_destroy_v1, .target = xt_ct_target_v1, @@ -402,6 +404,7 @@ static struct xt_target xt_ct_tg_reg[] __read_mostly = { .family = NFPROTO_UNSPEC, .revision = 2, .targetsize = sizeof(struct xt_ct_target_info_v1), + .usersize = offsetof(struct xt_ct_target_info, ct), .checkentry = xt_ct_tg_check_v2, .destroy = xt_ct_tg_destroy_v1, .target = xt_ct_target_v1, diff --git a/net/netfilter/xt_IDLETIMER.c b/net/netfilter/xt_IDLETIMER.c index 8a1d2af3eed0..ff490161d107 100644 --- a/net/netfilter/xt_IDLETIMER.c +++ b/net/netfilter/xt_IDLETIMER.c @@ -5,6 +5,7 @@ * After timer expires a kevent will be sent. * * Copyright (C) 2004, 2010 Nokia Corporation + * * Written by Timo Teras <ext-timo.teras@nokia.com> * * Converted to x_tables and reworked for upstream inclusion @@ -38,8 +39,17 @@ #include <linux/netfilter/xt_IDLETIMER.h> #include <linux/kdev_t.h> #include <linux/kobject.h> +#include <linux/skbuff.h> #include <linux/workqueue.h> #include <linux/sysfs.h> +#include <linux/rtc.h> +#include <linux/time.h> +#include <linux/math64.h> +#include <linux/suspend.h> +#include <linux/notifier.h> +#include <net/net_namespace.h> +#include <net/sock.h> +#include <net/inet_sock.h> struct idletimer_tg_attr { struct attribute attr; @@ -55,14 +65,110 @@ struct idletimer_tg { struct kobject *kobj; struct idletimer_tg_attr attr; + struct timespec delayed_timer_trigger; + struct timespec last_modified_timer; + struct timespec last_suspend_time; + struct notifier_block pm_nb; + + int timeout; unsigned int refcnt; + bool work_pending; + bool send_nl_msg; + bool active; + uid_t uid; }; static LIST_HEAD(idletimer_tg_list); static DEFINE_MUTEX(list_mutex); +static DEFINE_SPINLOCK(timestamp_lock); static struct kobject *idletimer_tg_kobj; +static bool check_for_delayed_trigger(struct idletimer_tg *timer, + struct timespec *ts) +{ + bool state; + struct timespec temp; + spin_lock_bh(×tamp_lock); + timer->work_pending = false; + if ((ts->tv_sec - timer->last_modified_timer.tv_sec) > timer->timeout || + timer->delayed_timer_trigger.tv_sec != 0) { + state = false; + temp.tv_sec = timer->timeout; + temp.tv_nsec = 0; + if (timer->delayed_timer_trigger.tv_sec != 0) { + temp = timespec_add(timer->delayed_timer_trigger, temp); + ts->tv_sec = temp.tv_sec; + ts->tv_nsec = temp.tv_nsec; + timer->delayed_timer_trigger.tv_sec = 0; + timer->work_pending = true; + schedule_work(&timer->work); + } else { + temp = timespec_add(timer->last_modified_timer, temp); + ts->tv_sec = temp.tv_sec; + ts->tv_nsec = temp.tv_nsec; + } + } else { + state = timer->active; + } + spin_unlock_bh(×tamp_lock); + return state; +} + +static void notify_netlink_uevent(const char *iface, struct idletimer_tg *timer) +{ + char iface_msg[NLMSG_MAX_SIZE]; + char state_msg[NLMSG_MAX_SIZE]; + char timestamp_msg[NLMSG_MAX_SIZE]; + char uid_msg[NLMSG_MAX_SIZE]; + char *envp[] = { iface_msg, state_msg, timestamp_msg, uid_msg, NULL }; + int res; + struct timespec ts; + uint64_t time_ns; + bool state; + + res = snprintf(iface_msg, NLMSG_MAX_SIZE, "INTERFACE=%s", + iface); + if (NLMSG_MAX_SIZE <= res) { + pr_err("message too long (%d)", res); + return; + } + + get_monotonic_boottime(&ts); + state = check_for_delayed_trigger(timer, &ts); + res = snprintf(state_msg, NLMSG_MAX_SIZE, "STATE=%s", + state ? "active" : "inactive"); + + if (NLMSG_MAX_SIZE <= res) { + pr_err("message too long (%d)", res); + return; + } + + if (state) { + res = snprintf(uid_msg, NLMSG_MAX_SIZE, "UID=%u", timer->uid); + if (NLMSG_MAX_SIZE <= res) + pr_err("message too long (%d)", res); + } else { + res = snprintf(uid_msg, NLMSG_MAX_SIZE, "UID="); + if (NLMSG_MAX_SIZE <= res) + pr_err("message too long (%d)", res); + } + + time_ns = timespec_to_ns(&ts); + res = snprintf(timestamp_msg, NLMSG_MAX_SIZE, "TIME_NS=%llu", time_ns); + if (NLMSG_MAX_SIZE <= res) { + timestamp_msg[0] = '\0'; + pr_err("message too long (%d)", res); + } + + pr_debug("putting nlmsg: <%s> <%s> <%s> <%s>\n", iface_msg, state_msg, + timestamp_msg, uid_msg); + kobject_uevent_env(idletimer_tg_kobj, KOBJ_CHANGE, envp); + return; + + +} + static struct idletimer_tg *__idletimer_tg_find_by_label(const char *label) { @@ -83,6 +189,7 @@ static ssize_t idletimer_tg_show(struct kobject *kobj, struct attribute *attr, { struct idletimer_tg *timer; unsigned long expires = 0; + unsigned long now = jiffies; mutex_lock(&list_mutex); @@ -92,11 +199,15 @@ static ssize_t idletimer_tg_show(struct kobject *kobj, struct attribute *attr, mutex_unlock(&list_mutex); - if (time_after(expires, jiffies)) + if (time_after(expires, now)) return sprintf(buf, "%u\n", - jiffies_to_msecs(expires - jiffies) / 1000); + jiffies_to_msecs(expires - now) / 1000); - return sprintf(buf, "0\n"); + if (timer->send_nl_msg) + return sprintf(buf, "0 %d\n", + jiffies_to_msecs(now - expires) / 1000); + else + return sprintf(buf, "0\n"); } static void idletimer_tg_work(struct work_struct *work) @@ -105,6 +216,9 @@ static void idletimer_tg_work(struct work_struct *work) work); sysfs_notify(idletimer_tg_kobj, NULL, timer->attr.attr.name); + + if (timer->send_nl_msg) + notify_netlink_uevent(timer->attr.attr.name, timer); } static void idletimer_tg_expired(unsigned long data) @@ -112,8 +226,55 @@ static void idletimer_tg_expired(unsigned long data) struct idletimer_tg *timer = (struct idletimer_tg *) data; pr_debug("timer %s expired\n", timer->attr.attr.name); - + spin_lock_bh(×tamp_lock); + timer->active = false; + timer->work_pending = true; schedule_work(&timer->work); + spin_unlock_bh(×tamp_lock); +} + +static int idletimer_resume(struct notifier_block *notifier, + unsigned long pm_event, void *unused) +{ + struct timespec ts; + unsigned long time_diff, now = jiffies; + struct idletimer_tg *timer = container_of(notifier, + struct idletimer_tg, pm_nb); + if (!timer) + return NOTIFY_DONE; + switch (pm_event) { + case PM_SUSPEND_PREPARE: + get_monotonic_boottime(&timer->last_suspend_time); + break; + case PM_POST_SUSPEND: + spin_lock_bh(×tamp_lock); + if (!timer->active) { + spin_unlock_bh(×tamp_lock); + break; + } + /* since jiffies are not updated when suspended now represents + * the time it would have suspended */ + if (time_after(timer->timer.expires, now)) { + get_monotonic_boottime(&ts); + ts = timespec_sub(ts, timer->last_suspend_time); + time_diff = timespec_to_jiffies(&ts); + if (timer->timer.expires > (time_diff + now)) { + mod_timer_pending(&timer->timer, + (timer->timer.expires - time_diff)); + } else { + del_timer(&timer->timer); + timer->timer.expires = 0; + timer->active = false; + timer->work_pending = true; + schedule_work(&timer->work); + } + } + spin_unlock_bh(×tamp_lock); + break; + default: + break; + } + return NOTIFY_DONE; } static int idletimer_check_sysfs_name(const char *name, unsigned int size) @@ -166,6 +327,21 @@ static int idletimer_tg_create(struct idletimer_tg_info *info) setup_timer(&info->timer->timer, idletimer_tg_expired, (unsigned long) info->timer); info->timer->refcnt = 1; + info->timer->send_nl_msg = (info->send_nl_msg == 0) ? false : true; + info->timer->active = true; + info->timer->timeout = info->timeout; + + info->timer->delayed_timer_trigger.tv_sec = 0; + info->timer->delayed_timer_trigger.tv_nsec = 0; + info->timer->work_pending = false; + info->timer->uid = 0; + get_monotonic_boottime(&info->timer->last_modified_timer); + + info->timer->pm_nb.notifier_call = idletimer_resume; + ret = register_pm_notifier(&info->timer->pm_nb); + if (ret) + printk(KERN_WARNING "[%s] Failed to register pm notifier %d\n", + __func__, ret); INIT_WORK(&info->timer->work, idletimer_tg_work); @@ -182,6 +358,42 @@ out: return ret; } +static void reset_timer(const struct idletimer_tg_info *info, + struct sk_buff *skb) +{ + unsigned long now = jiffies; + struct idletimer_tg *timer = info->timer; + bool timer_prev; + + spin_lock_bh(×tamp_lock); + timer_prev = timer->active; + timer->active = true; + /* timer_prev is used to guard overflow problem in time_before*/ + if (!timer_prev || time_before(timer->timer.expires, now)) { + pr_debug("Starting Checkentry timer (Expired, Jiffies): %lu, %lu\n", + timer->timer.expires, now); + + /* Stores the uid resposible for waking up the radio */ + if (skb && (skb->sk)) { + timer->uid = from_kuid_munged(current_user_ns(), + sock_i_uid(skb_to_full_sk(skb))); + } + + /* checks if there is a pending inactive notification*/ + if (timer->work_pending) + timer->delayed_timer_trigger = timer->last_modified_timer; + else { + timer->work_pending = true; + schedule_work(&timer->work); + } + } + + get_monotonic_boottime(&timer->last_modified_timer); + mod_timer(&timer->timer, + msecs_to_jiffies(info->timeout * 1000) + now); + spin_unlock_bh(×tamp_lock); +} + /* * The actual xt_tables plugin. */ @@ -189,15 +401,23 @@ static unsigned int idletimer_tg_target(struct sk_buff *skb, const struct xt_action_param *par) { const struct idletimer_tg_info *info = par->targinfo; + unsigned long now = jiffies; pr_debug("resetting timer %s, timeout period %u\n", info->label, info->timeout); BUG_ON(!info->timer); - mod_timer(&info->timer->timer, - msecs_to_jiffies(info->timeout * 1000) + jiffies); + info->timer->active = true; + if (time_before(info->timer->timer.expires, now)) { + schedule_work(&info->timer->work); + pr_debug("Starting timer %s (Expired, Jiffies): %lu, %lu\n", + info->label, info->timer->timer.expires, now); + } + + /* TODO: Avoid modifying timers on each packet */ + reset_timer(info, skb); return XT_CONTINUE; } @@ -206,7 +426,7 @@ static int idletimer_tg_checkentry(const struct xt_tgchk_param *par) struct idletimer_tg_info *info = par->targinfo; int ret; - pr_debug("checkentry targinfo%s\n", info->label); + pr_debug("checkentry targinfo %s\n", info->label); if (info->timeout == 0) { pr_debug("timeout value is zero\n"); @@ -228,9 +448,7 @@ static int idletimer_tg_checkentry(const struct xt_tgchk_param *par) info->timer = __idletimer_tg_find_by_label(info->label); if (info->timer) { info->timer->refcnt++; - mod_timer(&info->timer->timer, - msecs_to_jiffies(info->timeout * 1000) + jiffies); - + reset_timer(info, NULL); pr_debug("increased refcnt of timer %s to %u\n", info->label, info->timer->refcnt); } else { @@ -243,6 +461,7 @@ static int idletimer_tg_checkentry(const struct xt_tgchk_param *par) } mutex_unlock(&list_mutex); + return 0; } @@ -260,11 +479,13 @@ static void idletimer_tg_destroy(const struct xt_tgdtor_param *par) list_del(&info->timer->entry); del_timer_sync(&info->timer->timer); sysfs_remove_file(idletimer_tg_kobj, &info->timer->attr.attr); + unregister_pm_notifier(&info->timer->pm_nb); + cancel_work_sync(&info->timer->work); kfree(info->timer->attr.attr.name); kfree(info->timer); } else { pr_debug("decreased refcnt of timer %s to %u\n", - info->label, info->timer->refcnt); + info->label, info->timer->refcnt); } mutex_unlock(&list_mutex); @@ -272,9 +493,11 @@ static void idletimer_tg_destroy(const struct xt_tgdtor_param *par) static struct xt_target idletimer_tg __read_mostly = { .name = "IDLETIMER", + .revision = 1, .family = NFPROTO_UNSPEC, .target = idletimer_tg_target, .targetsize = sizeof(struct idletimer_tg_info), + .usersize = offsetof(struct idletimer_tg_info, timer), .checkentry = idletimer_tg_checkentry, .destroy = idletimer_tg_destroy, .me = THIS_MODULE, @@ -337,3 +560,4 @@ MODULE_DESCRIPTION("Xtables: idle time monitor"); MODULE_LICENSE("GPL v2"); MODULE_ALIAS("ipt_IDLETIMER"); MODULE_ALIAS("ip6t_IDLETIMER"); +MODULE_ALIAS("arpt_IDLETIMER"); diff --git a/net/netfilter/xt_LED.c b/net/netfilter/xt_LED.c index 0858fe17e14a..2d1c5c169a26 100644 --- a/net/netfilter/xt_LED.c +++ b/net/netfilter/xt_LED.c @@ -198,6 +198,7 @@ static struct xt_target led_tg_reg __read_mostly = { .family = NFPROTO_UNSPEC, .target = led_tg, .targetsize = sizeof(struct xt_led_info), + .usersize = offsetof(struct xt_led_info, internal_data), .checkentry = led_tg_check, .destroy = led_tg_destroy, .me = THIS_MODULE, diff --git a/net/netfilter/xt_RATEEST.c b/net/netfilter/xt_RATEEST.c index 3b5c76ae4375..6768d4d2ffd0 100644 --- a/net/netfilter/xt_RATEEST.c +++ b/net/netfilter/xt_RATEEST.c @@ -181,6 +181,7 @@ static struct xt_target xt_rateest_tg_reg __read_mostly = { .checkentry = xt_rateest_tg_checkentry, .destroy = xt_rateest_tg_destroy, .targetsize = sizeof(struct xt_rateest_target_info), + .usersize = offsetof(struct xt_rateest_target_info, est), .me = THIS_MODULE, }; diff --git a/net/netfilter/xt_TEE.c b/net/netfilter/xt_TEE.c index 3eff7b67cdf2..d597b504a82e 100644 --- a/net/netfilter/xt_TEE.c +++ b/net/netfilter/xt_TEE.c @@ -127,6 +127,7 @@ static struct xt_target tee_tg_reg[] __read_mostly = { .family = NFPROTO_IPV4, .target = tee_tg4, .targetsize = sizeof(struct xt_tee_tginfo), + .usersize = offsetof(struct xt_tee_tginfo, priv), .checkentry = tee_tg_check, .destroy = tee_tg_destroy, .me = THIS_MODULE, @@ -138,6 +139,7 @@ static struct xt_target tee_tg_reg[] __read_mostly = { .family = NFPROTO_IPV6, .target = tee_tg6, .targetsize = sizeof(struct xt_tee_tginfo), + .usersize = offsetof(struct xt_tee_tginfo, priv), .checkentry = tee_tg_check, .destroy = tee_tg_destroy, .me = THIS_MODULE, diff --git a/net/netfilter/xt_bpf.c b/net/netfilter/xt_bpf.c index 7b993f25aab9..2e26df862d38 100644 --- a/net/netfilter/xt_bpf.c +++ b/net/netfilter/xt_bpf.c @@ -60,6 +60,7 @@ static struct xt_match bpf_mt_reg __read_mostly = { .match = bpf_mt, .destroy = bpf_mt_destroy, .matchsize = sizeof(struct xt_bpf_info), + .usersize = offsetof(struct xt_bpf_info, filter), .me = THIS_MODULE, }; diff --git a/net/netfilter/xt_connlimit.c b/net/netfilter/xt_connlimit.c index 99bbc829868d..aab11f7ab547 100644 --- a/net/netfilter/xt_connlimit.c +++ b/net/netfilter/xt_connlimit.c @@ -437,6 +437,7 @@ static struct xt_match connlimit_mt_reg __read_mostly = { .checkentry = connlimit_mt_check, .match = connlimit_mt, .matchsize = sizeof(struct xt_connlimit_info), + .usersize = offsetof(struct xt_connlimit_info, data), .destroy = connlimit_mt_destroy, .me = THIS_MODULE, }; diff --git a/net/netfilter/xt_hashlimit.c b/net/netfilter/xt_hashlimit.c index 7381be0cdcdf..d893cc133de4 100644 --- a/net/netfilter/xt_hashlimit.c +++ b/net/netfilter/xt_hashlimit.c @@ -726,6 +726,7 @@ static struct xt_match hashlimit_mt_reg[] __read_mostly = { .family = NFPROTO_IPV4, .match = hashlimit_mt, .matchsize = sizeof(struct xt_hashlimit_mtinfo1), + .usersize = offsetof(struct xt_hashlimit_mtinfo1, hinfo), .checkentry = hashlimit_mt_check, .destroy = hashlimit_mt_destroy, .me = THIS_MODULE, @@ -737,6 +738,7 @@ static struct xt_match hashlimit_mt_reg[] __read_mostly = { .family = NFPROTO_IPV6, .match = hashlimit_mt, .matchsize = sizeof(struct xt_hashlimit_mtinfo1), + .usersize = offsetof(struct xt_hashlimit_mtinfo1, hinfo), .checkentry = hashlimit_mt_check, .destroy = hashlimit_mt_destroy, .me = THIS_MODULE, diff --git a/net/netfilter/xt_limit.c b/net/netfilter/xt_limit.c index bef850596558..e84de7656289 100644 --- a/net/netfilter/xt_limit.c +++ b/net/netfilter/xt_limit.c @@ -193,6 +193,7 @@ static struct xt_match limit_mt_reg __read_mostly = { .compat_from_user = limit_mt_compat_from_user, .compat_to_user = limit_mt_compat_to_user, #endif + .usersize = offsetof(struct xt_rateinfo, prev), .me = THIS_MODULE, }; diff --git a/net/netfilter/xt_nfacct.c b/net/netfilter/xt_nfacct.c index 3048a7e3a90a..e9adf6ebca30 100644 --- a/net/netfilter/xt_nfacct.c +++ b/net/netfilter/xt_nfacct.c @@ -62,6 +62,7 @@ static struct xt_match nfacct_mt_reg __read_mostly = { .match = nfacct_mt, .destroy = nfacct_mt_destroy, .matchsize = sizeof(struct xt_nfacct_match_info), + .usersize = offsetof(struct xt_nfacct_match_info, nfacct), .me = THIS_MODULE, }; diff --git a/net/netfilter/xt_qtaguid.c b/net/netfilter/xt_qtaguid.c new file mode 100644 index 000000000000..500b5c944a41 --- /dev/null +++ b/net/netfilter/xt_qtaguid.c @@ -0,0 +1,3011 @@ +/* + * Kernel iptables module to track stats for packets based on user tags. + * + * (C) 2011 Google, Inc + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + */ + +/* + * There are run-time debug flags enabled via the debug_mask module param, or + * via the DEFAULT_DEBUG_MASK. See xt_qtaguid_internal.h. + */ +#define DEBUG + +#include <linux/file.h> +#include <linux/inetdevice.h> +#include <linux/module.h> +#include <linux/miscdevice.h> +#include <linux/netfilter/x_tables.h> +#include <linux/netfilter/xt_qtaguid.h> +#include <linux/ratelimit.h> +#include <linux/seq_file.h> +#include <linux/skbuff.h> +#include <linux/workqueue.h> +#include <net/addrconf.h> +#include <net/sock.h> +#include <net/tcp.h> +#include <net/udp.h> + +#if defined(CONFIG_IP6_NF_IPTABLES) || defined(CONFIG_IP6_NF_IPTABLES_MODULE) +#include <linux/netfilter_ipv6/ip6_tables.h> +#endif + +#include <linux/netfilter/xt_socket.h> +#include "xt_qtaguid_internal.h" +#include "xt_qtaguid_print.h" +#include "../../fs/proc/internal.h" + +/* + * We only use the xt_socket funcs within a similar context to avoid unexpected + * return values. + */ +#define XT_SOCKET_SUPPORTED_HOOKS \ + ((1 << NF_INET_PRE_ROUTING) | (1 << NF_INET_LOCAL_IN)) + + +static const char *module_procdirname = "xt_qtaguid"; +static struct proc_dir_entry *xt_qtaguid_procdir; + +static unsigned int proc_iface_perms = S_IRUGO; +module_param_named(iface_perms, proc_iface_perms, uint, S_IRUGO | S_IWUSR); + +static struct proc_dir_entry *xt_qtaguid_stats_file; +static unsigned int proc_stats_perms = S_IRUGO; +module_param_named(stats_perms, proc_stats_perms, uint, S_IRUGO | S_IWUSR); + +static struct proc_dir_entry *xt_qtaguid_ctrl_file; + +/* Everybody can write. But proc_ctrl_write_limited is true by default which + * limits what can be controlled. See the can_*() functions. + */ +static unsigned int proc_ctrl_perms = S_IRUGO | S_IWUGO; +module_param_named(ctrl_perms, proc_ctrl_perms, uint, S_IRUGO | S_IWUSR); + +/* Limited by default, so the gid of the ctrl and stats proc entries + * will limit what can be done. See the can_*() functions. + */ +static bool proc_stats_readall_limited = true; +static bool proc_ctrl_write_limited = true; + +module_param_named(stats_readall_limited, proc_stats_readall_limited, bool, + S_IRUGO | S_IWUSR); +module_param_named(ctrl_write_limited, proc_ctrl_write_limited, bool, + S_IRUGO | S_IWUSR); + +/* + * Limit the number of active tags (via socket tags) for a given UID. + * Multiple processes could share the UID. + */ +static int max_sock_tags = DEFAULT_MAX_SOCK_TAGS; +module_param(max_sock_tags, int, S_IRUGO | S_IWUSR); + +/* + * After the kernel has initiallized this module, it is still possible + * to make it passive. + * Setting passive to Y: + * - the iface stats handling will not act on notifications. + * - iptables matches will never match. + * - ctrl commands silently succeed. + * - stats are always empty. + * This is mostly usefull when a bug is suspected. + */ +static bool module_passive; +module_param_named(passive, module_passive, bool, S_IRUGO | S_IWUSR); + +/* + * Control how qtaguid data is tracked per proc/uid. + * Setting tag_tracking_passive to Y: + * - don't create proc specific structs to track tags + * - don't check that active tag stats exceed some limits. + * - don't clean up socket tags on process exits. + * This is mostly usefull when a bug is suspected. + */ +static bool qtu_proc_handling_passive; +module_param_named(tag_tracking_passive, qtu_proc_handling_passive, bool, + S_IRUGO | S_IWUSR); + +#define QTU_DEV_NAME "xt_qtaguid" + +uint qtaguid_debug_mask = DEFAULT_DEBUG_MASK; +module_param_named(debug_mask, qtaguid_debug_mask, uint, S_IRUGO | S_IWUSR); + +/*---------------------------------------------------------------------------*/ +static const char *iface_stat_procdirname = "iface_stat"; +static struct proc_dir_entry *iface_stat_procdir; +/* + * The iface_stat_all* will go away once userspace gets use to the new fields + * that have a format line. + */ +static const char *iface_stat_all_procfilename = "iface_stat_all"; +static struct proc_dir_entry *iface_stat_all_procfile; +static const char *iface_stat_fmt_procfilename = "iface_stat_fmt"; +static struct proc_dir_entry *iface_stat_fmt_procfile; + + +static LIST_HEAD(iface_stat_list); +static DEFINE_SPINLOCK(iface_stat_list_lock); + +static struct rb_root sock_tag_tree = RB_ROOT; +static DEFINE_SPINLOCK(sock_tag_list_lock); + +static struct rb_root tag_counter_set_tree = RB_ROOT; +static DEFINE_SPINLOCK(tag_counter_set_list_lock); + +static struct rb_root uid_tag_data_tree = RB_ROOT; +static DEFINE_SPINLOCK(uid_tag_data_tree_lock); + +static struct rb_root proc_qtu_data_tree = RB_ROOT; +/* No proc_qtu_data_tree_lock; use uid_tag_data_tree_lock */ + +static struct qtaguid_event_counts qtu_events; +/*----------------------------------------------*/ +static bool can_manipulate_uids(void) +{ + /* root pwnd */ + return in_egroup_p(xt_qtaguid_ctrl_file->gid) + || unlikely(!from_kuid(&init_user_ns, current_fsuid())) || unlikely(!proc_ctrl_write_limited) + || unlikely(uid_eq(current_fsuid(), xt_qtaguid_ctrl_file->uid)); +} + +static bool can_impersonate_uid(kuid_t uid) +{ + return uid_eq(uid, current_fsuid()) || can_manipulate_uids(); +} + +static bool can_read_other_uid_stats(kuid_t uid) +{ + /* root pwnd */ + return in_egroup_p(xt_qtaguid_stats_file->gid) + || unlikely(!from_kuid(&init_user_ns, current_fsuid())) || uid_eq(uid, current_fsuid()) + || unlikely(!proc_stats_readall_limited) + || unlikely(uid_eq(current_fsuid(), xt_qtaguid_ctrl_file->uid)); +} + +static inline void dc_add_byte_packets(struct data_counters *counters, int set, + enum ifs_tx_rx direction, + enum ifs_proto ifs_proto, + int bytes, + int packets) +{ + counters->bpc[set][direction][ifs_proto].bytes += bytes; + counters->bpc[set][direction][ifs_proto].packets += packets; +} + +static struct tag_node *tag_node_tree_search(struct rb_root *root, tag_t tag) +{ + struct rb_node *node = root->rb_node; + + while (node) { + struct tag_node *data = rb_entry(node, struct tag_node, node); + int result; + RB_DEBUG("qtaguid: tag_node_tree_search(0x%llx): " + " node=%p data=%p\n", tag, node, data); + result = tag_compare(tag, data->tag); + RB_DEBUG("qtaguid: tag_node_tree_search(0x%llx): " + " data.tag=0x%llx (uid=%u) res=%d\n", + tag, data->tag, get_uid_from_tag(data->tag), result); + if (result < 0) + node = node->rb_left; + else if (result > 0) + node = node->rb_right; + else + return data; + } + return NULL; +} + +static void tag_node_tree_insert(struct tag_node *data, struct rb_root *root) +{ + struct rb_node **new = &(root->rb_node), *parent = NULL; + + /* Figure out where to put new node */ + while (*new) { + struct tag_node *this = rb_entry(*new, struct tag_node, + node); + int result = tag_compare(data->tag, this->tag); + RB_DEBUG("qtaguid: %s(): tag=0x%llx" + " (uid=%u)\n", __func__, + this->tag, + get_uid_from_tag(this->tag)); + parent = *new; + if (result < 0) + new = &((*new)->rb_left); + else if (result > 0) + new = &((*new)->rb_right); + else + BUG(); + } + + /* Add new node and rebalance tree. */ + rb_link_node(&data->node, parent, new); + rb_insert_color(&data->node, root); +} + +static void tag_stat_tree_insert(struct tag_stat *data, struct rb_root *root) +{ + tag_node_tree_insert(&data->tn, root); +} + +static struct tag_stat *tag_stat_tree_search(struct rb_root *root, tag_t tag) +{ + struct tag_node *node = tag_node_tree_search(root, tag); + if (!node) + return NULL; + return rb_entry(&node->node, struct tag_stat, tn.node); +} + +static void tag_counter_set_tree_insert(struct tag_counter_set *data, + struct rb_root *root) +{ + tag_node_tree_insert(&data->tn, root); +} + +static struct tag_counter_set *tag_counter_set_tree_search(struct rb_root *root, + tag_t tag) +{ + struct tag_node *node = tag_node_tree_search(root, tag); + if (!node) + return NULL; + return rb_entry(&node->node, struct tag_counter_set, tn.node); + +} + +static void tag_ref_tree_insert(struct tag_ref *data, struct rb_root *root) +{ + tag_node_tree_insert(&data->tn, root); +} + +static struct tag_ref *tag_ref_tree_search(struct rb_root *root, tag_t tag) +{ + struct tag_node *node = tag_node_tree_search(root, tag); + if (!node) + return NULL; + return rb_entry(&node->node, struct tag_ref, tn.node); +} + +static struct sock_tag *sock_tag_tree_search(struct rb_root *root, + const struct sock *sk) +{ + struct rb_node *node = root->rb_node; + + while (node) { + struct sock_tag *data = rb_entry(node, struct sock_tag, + sock_node); + if (sk < data->sk) + node = node->rb_left; + else if (sk > data->sk) + node = node->rb_right; + else + return data; + } + return NULL; +} + +static void sock_tag_tree_insert(struct sock_tag *data, struct rb_root *root) +{ + struct rb_node **new = &(root->rb_node), *parent = NULL; + + /* Figure out where to put new node */ + while (*new) { + struct sock_tag *this = rb_entry(*new, struct sock_tag, + sock_node); + parent = *new; + if (data->sk < this->sk) + new = &((*new)->rb_left); + else if (data->sk > this->sk) + new = &((*new)->rb_right); + else + BUG(); + } + + /* Add new node and rebalance tree. */ + rb_link_node(&data->sock_node, parent, new); + rb_insert_color(&data->sock_node, root); +} + +static void sock_tag_tree_erase(struct rb_root *st_to_free_tree) +{ + struct rb_node *node; + struct sock_tag *st_entry; + + node = rb_first(st_to_free_tree); + while (node) { + st_entry = rb_entry(node, struct sock_tag, sock_node); + node = rb_next(node); + CT_DEBUG("qtaguid: %s(): " + "erase st: sk=%p tag=0x%llx (uid=%u)\n", __func__, + st_entry->sk, + st_entry->tag, + get_uid_from_tag(st_entry->tag)); + rb_erase(&st_entry->sock_node, st_to_free_tree); + sock_put(st_entry->sk); + kfree(st_entry); + } +} + +static struct proc_qtu_data *proc_qtu_data_tree_search(struct rb_root *root, + const pid_t pid) +{ + struct rb_node *node = root->rb_node; + + while (node) { + struct proc_qtu_data *data = rb_entry(node, + struct proc_qtu_data, + node); + if (pid < data->pid) + node = node->rb_left; + else if (pid > data->pid) + node = node->rb_right; + else + return data; + } + return NULL; +} + +static void proc_qtu_data_tree_insert(struct proc_qtu_data *data, + struct rb_root *root) +{ + struct rb_node **new = &(root->rb_node), *parent = NULL; + + /* Figure out where to put new node */ + while (*new) { + struct proc_qtu_data *this = rb_entry(*new, + struct proc_qtu_data, + node); + parent = *new; + if (data->pid < this->pid) + new = &((*new)->rb_left); + else if (data->pid > this->pid) + new = &((*new)->rb_right); + else + BUG(); + } + + /* Add new node and rebalance tree. */ + rb_link_node(&data->node, parent, new); + rb_insert_color(&data->node, root); +} + +static void uid_tag_data_tree_insert(struct uid_tag_data *data, + struct rb_root *root) +{ + struct rb_node **new = &(root->rb_node), *parent = NULL; + + /* Figure out where to put new node */ + while (*new) { + struct uid_tag_data *this = rb_entry(*new, + struct uid_tag_data, + node); + parent = *new; + if (data->uid < this->uid) + new = &((*new)->rb_left); + else if (data->uid > this->uid) + new = &((*new)->rb_right); + else + BUG(); + } + + /* Add new node and rebalance tree. */ + rb_link_node(&data->node, parent, new); + rb_insert_color(&data->node, root); +} + +static struct uid_tag_data *uid_tag_data_tree_search(struct rb_root *root, + uid_t uid) +{ + struct rb_node *node = root->rb_node; + + while (node) { + struct uid_tag_data *data = rb_entry(node, + struct uid_tag_data, + node); + if (uid < data->uid) + node = node->rb_left; + else if (uid > data->uid) + node = node->rb_right; + else + return data; + } + return NULL; +} + +/* + * Allocates a new uid_tag_data struct if needed. + * Returns a pointer to the found or allocated uid_tag_data. + * Returns a PTR_ERR on failures, and lock is not held. + * If found is not NULL: + * sets *found to true if not allocated. + * sets *found to false if allocated. + */ +struct uid_tag_data *get_uid_data(uid_t uid, bool *found_res) +{ + struct uid_tag_data *utd_entry; + + /* Look for top level uid_tag_data for the UID */ + utd_entry = uid_tag_data_tree_search(&uid_tag_data_tree, uid); + DR_DEBUG("qtaguid: get_uid_data(%u) utd=%p\n", uid, utd_entry); + + if (found_res) + *found_res = utd_entry; + if (utd_entry) + return utd_entry; + + utd_entry = kzalloc(sizeof(*utd_entry), GFP_ATOMIC); + if (!utd_entry) { + pr_err("qtaguid: get_uid_data(%u): " + "tag data alloc failed\n", uid); + return ERR_PTR(-ENOMEM); + } + + utd_entry->uid = uid; + utd_entry->tag_ref_tree = RB_ROOT; + uid_tag_data_tree_insert(utd_entry, &uid_tag_data_tree); + DR_DEBUG("qtaguid: get_uid_data(%u) new utd=%p\n", uid, utd_entry); + return utd_entry; +} + +/* Never returns NULL. Either PTR_ERR or a valid ptr. */ +static struct tag_ref *new_tag_ref(tag_t new_tag, + struct uid_tag_data *utd_entry) +{ + struct tag_ref *tr_entry; + int res; + + if (utd_entry->num_active_tags + 1 > max_sock_tags) { + pr_info("qtaguid: new_tag_ref(0x%llx): " + "tag ref alloc quota exceeded. max=%d\n", + new_tag, max_sock_tags); + res = -EMFILE; + goto err_res; + + } + + tr_entry = kzalloc(sizeof(*tr_entry), GFP_ATOMIC); + if (!tr_entry) { + pr_err("qtaguid: new_tag_ref(0x%llx): " + "tag ref alloc failed\n", + new_tag); + res = -ENOMEM; + goto err_res; + } + tr_entry->tn.tag = new_tag; + /* tr_entry->num_sock_tags handled by caller */ + utd_entry->num_active_tags++; + tag_ref_tree_insert(tr_entry, &utd_entry->tag_ref_tree); + DR_DEBUG("qtaguid: new_tag_ref(0x%llx): " + " inserted new tag ref %p\n", + new_tag, tr_entry); + return tr_entry; + +err_res: + return ERR_PTR(res); +} + +static struct tag_ref *lookup_tag_ref(tag_t full_tag, + struct uid_tag_data **utd_res) +{ + struct uid_tag_data *utd_entry; + struct tag_ref *tr_entry; + bool found_utd; + uid_t uid = get_uid_from_tag(full_tag); + + DR_DEBUG("qtaguid: lookup_tag_ref(tag=0x%llx (uid=%u))\n", + full_tag, uid); + + utd_entry = get_uid_data(uid, &found_utd); + if (IS_ERR_OR_NULL(utd_entry)) { + if (utd_res) + *utd_res = utd_entry; + return NULL; + } + + tr_entry = tag_ref_tree_search(&utd_entry->tag_ref_tree, full_tag); + if (utd_res) + *utd_res = utd_entry; + DR_DEBUG("qtaguid: lookup_tag_ref(0x%llx) utd_entry=%p tr_entry=%p\n", + full_tag, utd_entry, tr_entry); + return tr_entry; +} + +/* Never returns NULL. Either PTR_ERR or a valid ptr. */ +static struct tag_ref *get_tag_ref(tag_t full_tag, + struct uid_tag_data **utd_res) +{ + struct uid_tag_data *utd_entry; + struct tag_ref *tr_entry; + + DR_DEBUG("qtaguid: get_tag_ref(0x%llx)\n", + full_tag); + tr_entry = lookup_tag_ref(full_tag, &utd_entry); + BUG_ON(IS_ERR_OR_NULL(utd_entry)); + if (!tr_entry) + tr_entry = new_tag_ref(full_tag, utd_entry); + + if (utd_res) + *utd_res = utd_entry; + DR_DEBUG("qtaguid: get_tag_ref(0x%llx) utd=%p tr=%p\n", + full_tag, utd_entry, tr_entry); + return tr_entry; +} + +/* Checks and maybe frees the UID Tag Data entry */ +static void put_utd_entry(struct uid_tag_data *utd_entry) +{ + /* Are we done with the UID tag data entry? */ + if (RB_EMPTY_ROOT(&utd_entry->tag_ref_tree) && + !utd_entry->num_pqd) { + DR_DEBUG("qtaguid: %s(): " + "erase utd_entry=%p uid=%u " + "by pid=%u tgid=%u uid=%u\n", __func__, + utd_entry, utd_entry->uid, + current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + BUG_ON(utd_entry->num_active_tags); + rb_erase(&utd_entry->node, &uid_tag_data_tree); + kfree(utd_entry); + } else { + DR_DEBUG("qtaguid: %s(): " + "utd_entry=%p still has %d tags %d proc_qtu_data\n", + __func__, utd_entry, utd_entry->num_active_tags, + utd_entry->num_pqd); + BUG_ON(!(utd_entry->num_active_tags || + utd_entry->num_pqd)); + } +} + +/* + * If no sock_tags are using this tag_ref, + * decrements refcount of utd_entry, removes tr_entry + * from utd_entry->tag_ref_tree and frees. + */ +static void free_tag_ref_from_utd_entry(struct tag_ref *tr_entry, + struct uid_tag_data *utd_entry) +{ + DR_DEBUG("qtaguid: %s(): %p tag=0x%llx (uid=%u)\n", __func__, + tr_entry, tr_entry->tn.tag, + get_uid_from_tag(tr_entry->tn.tag)); + if (!tr_entry->num_sock_tags) { + BUG_ON(!utd_entry->num_active_tags); + utd_entry->num_active_tags--; + rb_erase(&tr_entry->tn.node, &utd_entry->tag_ref_tree); + DR_DEBUG("qtaguid: %s(): erased %p\n", __func__, tr_entry); + kfree(tr_entry); + } +} + +static void put_tag_ref_tree(tag_t full_tag, struct uid_tag_data *utd_entry) +{ + struct rb_node *node; + struct tag_ref *tr_entry; + tag_t acct_tag; + + DR_DEBUG("qtaguid: %s(tag=0x%llx (uid=%u))\n", __func__, + full_tag, get_uid_from_tag(full_tag)); + acct_tag = get_atag_from_tag(full_tag); + node = rb_first(&utd_entry->tag_ref_tree); + while (node) { + tr_entry = rb_entry(node, struct tag_ref, tn.node); + node = rb_next(node); + if (!acct_tag || tr_entry->tn.tag == full_tag) + free_tag_ref_from_utd_entry(tr_entry, utd_entry); + } +} + +static ssize_t read_proc_u64(struct file *file, char __user *buf, + size_t size, loff_t *ppos) +{ + uint64_t *valuep = PDE_DATA(file_inode(file)); + char tmp[24]; + size_t tmp_size; + + tmp_size = scnprintf(tmp, sizeof(tmp), "%llu\n", *valuep); + return simple_read_from_buffer(buf, size, ppos, tmp, tmp_size); +} + +static ssize_t read_proc_bool(struct file *file, char __user *buf, + size_t size, loff_t *ppos) +{ + bool *valuep = PDE_DATA(file_inode(file)); + char tmp[24]; + size_t tmp_size; + + tmp_size = scnprintf(tmp, sizeof(tmp), "%u\n", *valuep); + return simple_read_from_buffer(buf, size, ppos, tmp, tmp_size); +} + +static int get_active_counter_set(tag_t tag) +{ + int active_set = 0; + struct tag_counter_set *tcs; + + MT_DEBUG("qtaguid: get_active_counter_set(tag=0x%llx)" + " (uid=%u)\n", + tag, get_uid_from_tag(tag)); + /* For now we only handle UID tags for active sets */ + tag = get_utag_from_tag(tag); + spin_lock_bh(&tag_counter_set_list_lock); + tcs = tag_counter_set_tree_search(&tag_counter_set_tree, tag); + if (tcs) + active_set = tcs->active_set; + spin_unlock_bh(&tag_counter_set_list_lock); + return active_set; +} + +/* + * Find the entry for tracking the specified interface. + * Caller must hold iface_stat_list_lock + */ +static struct iface_stat *get_iface_entry(const char *ifname) +{ + struct iface_stat *iface_entry; + + /* Find the entry for tracking the specified tag within the interface */ + if (ifname == NULL) { + pr_info("qtaguid: iface_stat: get() NULL device name\n"); + return NULL; + } + + /* Iterate over interfaces */ + list_for_each_entry(iface_entry, &iface_stat_list, list) { + if (!strcmp(ifname, iface_entry->ifname)) + goto done; + } + iface_entry = NULL; +done: + return iface_entry; +} + +/* This is for fmt2 only */ +static void pp_iface_stat_header(struct seq_file *m) +{ + seq_puts(m, + "ifname " + "total_skb_rx_bytes total_skb_rx_packets " + "total_skb_tx_bytes total_skb_tx_packets " + "rx_tcp_bytes rx_tcp_packets " + "rx_udp_bytes rx_udp_packets " + "rx_other_bytes rx_other_packets " + "tx_tcp_bytes tx_tcp_packets " + "tx_udp_bytes tx_udp_packets " + "tx_other_bytes tx_other_packets\n" + ); +} + +static void pp_iface_stat_line(struct seq_file *m, + struct iface_stat *iface_entry) +{ + struct data_counters *cnts; + int cnt_set = 0; /* We only use one set for the device */ + cnts = &iface_entry->totals_via_skb; + seq_printf(m, "%s %llu %llu %llu %llu %llu %llu %llu %llu " + "%llu %llu %llu %llu %llu %llu %llu %llu\n", + iface_entry->ifname, + dc_sum_bytes(cnts, cnt_set, IFS_RX), + dc_sum_packets(cnts, cnt_set, IFS_RX), + dc_sum_bytes(cnts, cnt_set, IFS_TX), + dc_sum_packets(cnts, cnt_set, IFS_TX), + cnts->bpc[cnt_set][IFS_RX][IFS_TCP].bytes, + cnts->bpc[cnt_set][IFS_RX][IFS_TCP].packets, + cnts->bpc[cnt_set][IFS_RX][IFS_UDP].bytes, + cnts->bpc[cnt_set][IFS_RX][IFS_UDP].packets, + cnts->bpc[cnt_set][IFS_RX][IFS_PROTO_OTHER].bytes, + cnts->bpc[cnt_set][IFS_RX][IFS_PROTO_OTHER].packets, + cnts->bpc[cnt_set][IFS_TX][IFS_TCP].bytes, + cnts->bpc[cnt_set][IFS_TX][IFS_TCP].packets, + cnts->bpc[cnt_set][IFS_TX][IFS_UDP].bytes, + cnts->bpc[cnt_set][IFS_TX][IFS_UDP].packets, + cnts->bpc[cnt_set][IFS_TX][IFS_PROTO_OTHER].bytes, + cnts->bpc[cnt_set][IFS_TX][IFS_PROTO_OTHER].packets); +} + +struct proc_iface_stat_fmt_info { + int fmt; +}; + +static void *iface_stat_fmt_proc_start(struct seq_file *m, loff_t *pos) +{ + struct proc_iface_stat_fmt_info *p = m->private; + loff_t n = *pos; + + /* + * This lock will prevent iface_stat_update() from changing active, + * and in turn prevent an interface from unregistering itself. + */ + spin_lock_bh(&iface_stat_list_lock); + + if (unlikely(module_passive)) + return NULL; + + if (!n && p->fmt == 2) + pp_iface_stat_header(m); + + return seq_list_start(&iface_stat_list, n); +} + +static void *iface_stat_fmt_proc_next(struct seq_file *m, void *p, loff_t *pos) +{ + return seq_list_next(p, &iface_stat_list, pos); +} + +static void iface_stat_fmt_proc_stop(struct seq_file *m, void *p) +{ + spin_unlock_bh(&iface_stat_list_lock); +} + +static int iface_stat_fmt_proc_show(struct seq_file *m, void *v) +{ + struct proc_iface_stat_fmt_info *p = m->private; + struct iface_stat *iface_entry; + struct rtnl_link_stats64 dev_stats, *stats; + struct rtnl_link_stats64 no_dev_stats = {0}; + + + CT_DEBUG("qtaguid:proc iface_stat_fmt pid=%u tgid=%u uid=%u\n", + current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + + iface_entry = list_entry(v, struct iface_stat, list); + + if (iface_entry->active) { + stats = dev_get_stats(iface_entry->net_dev, + &dev_stats); + } else { + stats = &no_dev_stats; + } + /* + * If the meaning of the data changes, then update the fmtX + * string. + */ + if (p->fmt == 1) { + seq_printf(m, "%s %d %llu %llu %llu %llu %llu %llu %llu %llu\n", + iface_entry->ifname, + iface_entry->active, + iface_entry->totals_via_dev[IFS_RX].bytes, + iface_entry->totals_via_dev[IFS_RX].packets, + iface_entry->totals_via_dev[IFS_TX].bytes, + iface_entry->totals_via_dev[IFS_TX].packets, + stats->rx_bytes, stats->rx_packets, + stats->tx_bytes, stats->tx_packets + ); + } else { + pp_iface_stat_line(m, iface_entry); + } + return 0; +} + +static const struct file_operations read_u64_fops = { + .read = read_proc_u64, + .llseek = default_llseek, +}; + +static const struct file_operations read_bool_fops = { + .read = read_proc_bool, + .llseek = default_llseek, +}; + +static void iface_create_proc_worker(struct work_struct *work) +{ + struct proc_dir_entry *proc_entry; + struct iface_stat_work *isw = container_of(work, struct iface_stat_work, + iface_work); + struct iface_stat *new_iface = isw->iface_entry; + + /* iface_entries are not deleted, so safe to manipulate. */ + proc_entry = proc_mkdir(new_iface->ifname, iface_stat_procdir); + if (IS_ERR_OR_NULL(proc_entry)) { + pr_err("qtaguid: iface_stat: create_proc(): alloc failed.\n"); + kfree(isw); + return; + } + + new_iface->proc_ptr = proc_entry; + + proc_create_data("tx_bytes", proc_iface_perms, proc_entry, + &read_u64_fops, + &new_iface->totals_via_dev[IFS_TX].bytes); + proc_create_data("rx_bytes", proc_iface_perms, proc_entry, + &read_u64_fops, + &new_iface->totals_via_dev[IFS_RX].bytes); + proc_create_data("tx_packets", proc_iface_perms, proc_entry, + &read_u64_fops, + &new_iface->totals_via_dev[IFS_TX].packets); + proc_create_data("rx_packets", proc_iface_perms, proc_entry, + &read_u64_fops, + &new_iface->totals_via_dev[IFS_RX].packets); + proc_create_data("active", proc_iface_perms, proc_entry, + &read_bool_fops, &new_iface->active); + + IF_DEBUG("qtaguid: iface_stat: create_proc(): done " + "entry=%p dev=%s\n", new_iface, new_iface->ifname); + kfree(isw); +} + +/* + * Will set the entry's active state, and + * update the net_dev accordingly also. + */ +static void _iface_stat_set_active(struct iface_stat *entry, + struct net_device *net_dev, + bool activate) +{ + if (activate) { + entry->net_dev = net_dev; + entry->active = true; + IF_DEBUG("qtaguid: %s(%s): " + "enable tracking. rfcnt=%d\n", __func__, + entry->ifname, + __this_cpu_read(*net_dev->pcpu_refcnt)); + } else { + entry->active = false; + entry->net_dev = NULL; + IF_DEBUG("qtaguid: %s(%s): " + "disable tracking. rfcnt=%d\n", __func__, + entry->ifname, + __this_cpu_read(*net_dev->pcpu_refcnt)); + + } +} + +/* Caller must hold iface_stat_list_lock */ +static struct iface_stat *iface_alloc(struct net_device *net_dev) +{ + struct iface_stat *new_iface; + struct iface_stat_work *isw; + + new_iface = kzalloc(sizeof(*new_iface), GFP_ATOMIC); + if (new_iface == NULL) { + pr_err("qtaguid: iface_stat: create(%s): " + "iface_stat alloc failed\n", net_dev->name); + return NULL; + } + new_iface->ifname = kstrdup(net_dev->name, GFP_ATOMIC); + if (new_iface->ifname == NULL) { + pr_err("qtaguid: iface_stat: create(%s): " + "ifname alloc failed\n", net_dev->name); + kfree(new_iface); + return NULL; + } + spin_lock_init(&new_iface->tag_stat_list_lock); + new_iface->tag_stat_tree = RB_ROOT; + _iface_stat_set_active(new_iface, net_dev, true); + + /* + * ipv6 notifier chains are atomic :( + * No create_proc_read_entry() for you! + */ + isw = kmalloc(sizeof(*isw), GFP_ATOMIC); + if (!isw) { + pr_err("qtaguid: iface_stat: create(%s): " + "work alloc failed\n", new_iface->ifname); + _iface_stat_set_active(new_iface, net_dev, false); + kfree(new_iface->ifname); + kfree(new_iface); + return NULL; + } + isw->iface_entry = new_iface; + INIT_WORK(&isw->iface_work, iface_create_proc_worker); + schedule_work(&isw->iface_work); + list_add(&new_iface->list, &iface_stat_list); + return new_iface; +} + +static void iface_check_stats_reset_and_adjust(struct net_device *net_dev, + struct iface_stat *iface) +{ + struct rtnl_link_stats64 dev_stats, *stats; + bool stats_rewound; + + stats = dev_get_stats(net_dev, &dev_stats); + /* No empty packets */ + stats_rewound = + (stats->rx_bytes < iface->last_known[IFS_RX].bytes) + || (stats->tx_bytes < iface->last_known[IFS_TX].bytes); + + IF_DEBUG("qtaguid: %s(%s): iface=%p netdev=%p " + "bytes rx/tx=%llu/%llu " + "active=%d last_known=%d " + "stats_rewound=%d\n", __func__, + net_dev ? net_dev->name : "?", + iface, net_dev, + stats->rx_bytes, stats->tx_bytes, + iface->active, iface->last_known_valid, stats_rewound); + + if (iface->active && iface->last_known_valid && stats_rewound) { + pr_warn_once("qtaguid: iface_stat: %s(%s): " + "iface reset its stats unexpectedly\n", __func__, + net_dev->name); + + iface->totals_via_dev[IFS_TX].bytes += + iface->last_known[IFS_TX].bytes; + iface->totals_via_dev[IFS_TX].packets += + iface->last_known[IFS_TX].packets; + iface->totals_via_dev[IFS_RX].bytes += + iface->last_known[IFS_RX].bytes; + iface->totals_via_dev[IFS_RX].packets += + iface->last_known[IFS_RX].packets; + iface->last_known_valid = false; + IF_DEBUG("qtaguid: %s(%s): iface=%p " + "used last known bytes rx/tx=%llu/%llu\n", __func__, + iface->ifname, iface, iface->last_known[IFS_RX].bytes, + iface->last_known[IFS_TX].bytes); + } +} + +/* + * Create a new entry for tracking the specified interface. + * Do nothing if the entry already exists. + * Called when an interface is configured with a valid IP address. + */ +static void iface_stat_create(struct net_device *net_dev, + struct in_ifaddr *ifa) +{ + struct in_device *in_dev = NULL; + const char *ifname; + struct iface_stat *entry; + __be32 ipaddr = 0; + struct iface_stat *new_iface; + + IF_DEBUG("qtaguid: iface_stat: create(%s): ifa=%p netdev=%p\n", + net_dev ? net_dev->name : "?", + ifa, net_dev); + if (!net_dev) { + pr_err("qtaguid: iface_stat: create(): no net dev\n"); + return; + } + + ifname = net_dev->name; + if (!ifa) { + in_dev = in_dev_get(net_dev); + if (!in_dev) { + pr_err("qtaguid: iface_stat: create(%s): no inet dev\n", + ifname); + return; + } + IF_DEBUG("qtaguid: iface_stat: create(%s): in_dev=%p\n", + ifname, in_dev); + for (ifa = in_dev->ifa_list; ifa; ifa = ifa->ifa_next) { + IF_DEBUG("qtaguid: iface_stat: create(%s): " + "ifa=%p ifa_label=%s\n", + ifname, ifa, + ifa->ifa_label); + if (!strcmp(ifname, ifa->ifa_label)) + break; + } + } + + if (!ifa) { + IF_DEBUG("qtaguid: iface_stat: create(%s): no matching IP\n", + ifname); + goto done_put; + } + ipaddr = ifa->ifa_local; + + spin_lock_bh(&iface_stat_list_lock); + entry = get_iface_entry(ifname); + if (entry != NULL) { + IF_DEBUG("qtaguid: iface_stat: create(%s): entry=%p\n", + ifname, entry); + iface_check_stats_reset_and_adjust(net_dev, entry); + _iface_stat_set_active(entry, net_dev, true); + IF_DEBUG("qtaguid: %s(%s): " + "tracking now %d on ip=%pI4\n", __func__, + entry->ifname, true, &ipaddr); + goto done_unlock_put; + } + + new_iface = iface_alloc(net_dev); + IF_DEBUG("qtaguid: iface_stat: create(%s): done " + "entry=%p ip=%pI4\n", ifname, new_iface, &ipaddr); +done_unlock_put: + spin_unlock_bh(&iface_stat_list_lock); +done_put: + if (in_dev) + in_dev_put(in_dev); +} + +static void iface_stat_create_ipv6(struct net_device *net_dev, + struct inet6_ifaddr *ifa) +{ + struct in_device *in_dev; + const char *ifname; + struct iface_stat *entry; + struct iface_stat *new_iface; + int addr_type; + + IF_DEBUG("qtaguid: iface_stat: create6(): ifa=%p netdev=%p->name=%s\n", + ifa, net_dev, net_dev ? net_dev->name : ""); + if (!net_dev) { + pr_err("qtaguid: iface_stat: create6(): no net dev!\n"); + return; + } + ifname = net_dev->name; + + in_dev = in_dev_get(net_dev); + if (!in_dev) { + pr_err("qtaguid: iface_stat: create6(%s): no inet dev\n", + ifname); + return; + } + + IF_DEBUG("qtaguid: iface_stat: create6(%s): in_dev=%p\n", + ifname, in_dev); + + if (!ifa) { + IF_DEBUG("qtaguid: iface_stat: create6(%s): no matching IP\n", + ifname); + goto done_put; + } + addr_type = ipv6_addr_type(&ifa->addr); + + spin_lock_bh(&iface_stat_list_lock); + entry = get_iface_entry(ifname); + if (entry != NULL) { + IF_DEBUG("qtaguid: %s(%s): entry=%p\n", __func__, + ifname, entry); + iface_check_stats_reset_and_adjust(net_dev, entry); + _iface_stat_set_active(entry, net_dev, true); + IF_DEBUG("qtaguid: %s(%s): " + "tracking now %d on ip=%pI6c\n", __func__, + entry->ifname, true, &ifa->addr); + goto done_unlock_put; + } + + new_iface = iface_alloc(net_dev); + IF_DEBUG("qtaguid: iface_stat: create6(%s): done " + "entry=%p ip=%pI6c\n", ifname, new_iface, &ifa->addr); + +done_unlock_put: + spin_unlock_bh(&iface_stat_list_lock); +done_put: + in_dev_put(in_dev); +} + +static struct sock_tag *get_sock_stat_nl(const struct sock *sk) +{ + MT_DEBUG("qtaguid: get_sock_stat_nl(sk=%p)\n", sk); + return sock_tag_tree_search(&sock_tag_tree, sk); +} + +static int ipx_proto(const struct sk_buff *skb, + struct xt_action_param *par) +{ + int thoff = 0, tproto; + + switch (par->family) { + case NFPROTO_IPV6: + tproto = ipv6_find_hdr(skb, &thoff, -1, NULL, NULL); + if (tproto < 0) + MT_DEBUG("%s(): transport header not found in ipv6" + " skb=%p\n", __func__, skb); + break; + case NFPROTO_IPV4: + tproto = ip_hdr(skb)->protocol; + break; + default: + tproto = IPPROTO_RAW; + } + return tproto; +} + +static void +data_counters_update(struct data_counters *dc, int set, + enum ifs_tx_rx direction, int proto, int bytes) +{ + switch (proto) { + case IPPROTO_TCP: + dc_add_byte_packets(dc, set, direction, IFS_TCP, bytes, 1); + break; + case IPPROTO_UDP: + dc_add_byte_packets(dc, set, direction, IFS_UDP, bytes, 1); + break; + case IPPROTO_IP: + default: + dc_add_byte_packets(dc, set, direction, IFS_PROTO_OTHER, bytes, + 1); + break; + } +} + +/* + * Update stats for the specified interface. Do nothing if the entry + * does not exist (when a device was never configured with an IP address). + * Called when an device is being unregistered. + */ +static void iface_stat_update(struct net_device *net_dev, bool stash_only) +{ + struct rtnl_link_stats64 dev_stats, *stats; + struct iface_stat *entry; + + stats = dev_get_stats(net_dev, &dev_stats); + spin_lock_bh(&iface_stat_list_lock); + entry = get_iface_entry(net_dev->name); + if (entry == NULL) { + IF_DEBUG("qtaguid: iface_stat: update(%s): not tracked\n", + net_dev->name); + spin_unlock_bh(&iface_stat_list_lock); + return; + } + + IF_DEBUG("qtaguid: %s(%s): entry=%p\n", __func__, + net_dev->name, entry); + if (!entry->active) { + IF_DEBUG("qtaguid: %s(%s): already disabled\n", __func__, + net_dev->name); + spin_unlock_bh(&iface_stat_list_lock); + return; + } + + if (stash_only) { + entry->last_known[IFS_TX].bytes = stats->tx_bytes; + entry->last_known[IFS_TX].packets = stats->tx_packets; + entry->last_known[IFS_RX].bytes = stats->rx_bytes; + entry->last_known[IFS_RX].packets = stats->rx_packets; + entry->last_known_valid = true; + IF_DEBUG("qtaguid: %s(%s): " + "dev stats stashed rx/tx=%llu/%llu\n", __func__, + net_dev->name, stats->rx_bytes, stats->tx_bytes); + spin_unlock_bh(&iface_stat_list_lock); + return; + } + entry->totals_via_dev[IFS_TX].bytes += stats->tx_bytes; + entry->totals_via_dev[IFS_TX].packets += stats->tx_packets; + entry->totals_via_dev[IFS_RX].bytes += stats->rx_bytes; + entry->totals_via_dev[IFS_RX].packets += stats->rx_packets; + /* We don't need the last_known[] anymore */ + entry->last_known_valid = false; + _iface_stat_set_active(entry, net_dev, false); + IF_DEBUG("qtaguid: %s(%s): " + "disable tracking. rx/tx=%llu/%llu\n", __func__, + net_dev->name, stats->rx_bytes, stats->tx_bytes); + spin_unlock_bh(&iface_stat_list_lock); +} + +/* Guarantied to return a net_device that has a name */ +static void get_dev_and_dir(const struct sk_buff *skb, + struct xt_action_param *par, + enum ifs_tx_rx *direction, + const struct net_device **el_dev) +{ + BUG_ON(!direction || !el_dev); + + if (par->in) { + *el_dev = par->in; + *direction = IFS_RX; + } else if (par->out) { + *el_dev = par->out; + *direction = IFS_TX; + } else { + pr_err("qtaguid[%d]: %s(): no par->in/out?!!\n", + par->hooknum, __func__); + BUG(); + } + if (skb->dev && *el_dev != skb->dev) { + MT_DEBUG("qtaguid[%d]: skb->dev=%p %s vs par->%s=%p %s\n", + par->hooknum, skb->dev, skb->dev->name, + *direction == IFS_RX ? "in" : "out", *el_dev, + (*el_dev)->name); + } +} + +/* + * Update stats for the specified interface from the skb. + * Do nothing if the entry + * does not exist (when a device was never configured with an IP address). + * Called on each sk. + */ +static void iface_stat_update_from_skb(const struct sk_buff *skb, + struct xt_action_param *par) +{ + struct iface_stat *entry; + const struct net_device *el_dev; + enum ifs_tx_rx direction; + int bytes = skb->len; + int proto; + + get_dev_and_dir(skb, par, &direction, &el_dev); + proto = ipx_proto(skb, par); + MT_DEBUG("qtaguid[%d]: iface_stat: %s(%s): " + "type=%d fam=%d proto=%d dir=%d\n", + par->hooknum, __func__, el_dev->name, el_dev->type, + par->family, proto, direction); + + spin_lock_bh(&iface_stat_list_lock); + entry = get_iface_entry(el_dev->name); + if (entry == NULL) { + IF_DEBUG("qtaguid[%d]: iface_stat: %s(%s): not tracked\n", + par->hooknum, __func__, el_dev->name); + spin_unlock_bh(&iface_stat_list_lock); + return; + } + + IF_DEBUG("qtaguid[%d]: %s(%s): entry=%p\n", par->hooknum, __func__, + el_dev->name, entry); + + data_counters_update(&entry->totals_via_skb, 0, direction, proto, + bytes); + spin_unlock_bh(&iface_stat_list_lock); +} + +static void tag_stat_update(struct tag_stat *tag_entry, + enum ifs_tx_rx direction, int proto, int bytes) +{ + int active_set; + active_set = get_active_counter_set(tag_entry->tn.tag); + MT_DEBUG("qtaguid: tag_stat_update(tag=0x%llx (uid=%u) set=%d " + "dir=%d proto=%d bytes=%d)\n", + tag_entry->tn.tag, get_uid_from_tag(tag_entry->tn.tag), + active_set, direction, proto, bytes); + data_counters_update(&tag_entry->counters, active_set, direction, + proto, bytes); + if (tag_entry->parent_counters) + data_counters_update(tag_entry->parent_counters, active_set, + direction, proto, bytes); +} + +/* + * Create a new entry for tracking the specified {acct_tag,uid_tag} within + * the interface. + * iface_entry->tag_stat_list_lock should be held. + */ +static struct tag_stat *create_if_tag_stat(struct iface_stat *iface_entry, + tag_t tag) +{ + struct tag_stat *new_tag_stat_entry = NULL; + IF_DEBUG("qtaguid: iface_stat: %s(): ife=%p tag=0x%llx" + " (uid=%u)\n", __func__, + iface_entry, tag, get_uid_from_tag(tag)); + new_tag_stat_entry = kzalloc(sizeof(*new_tag_stat_entry), GFP_ATOMIC); + if (!new_tag_stat_entry) { + pr_err("qtaguid: iface_stat: tag stat alloc failed\n"); + goto done; + } + new_tag_stat_entry->tn.tag = tag; + tag_stat_tree_insert(new_tag_stat_entry, &iface_entry->tag_stat_tree); +done: + return new_tag_stat_entry; +} + +static void if_tag_stat_update(const char *ifname, uid_t uid, + const struct sock *sk, enum ifs_tx_rx direction, + int proto, int bytes) +{ + struct tag_stat *tag_stat_entry; + tag_t tag, acct_tag; + tag_t uid_tag; + struct data_counters *uid_tag_counters; + struct sock_tag *sock_tag_entry; + struct iface_stat *iface_entry; + struct tag_stat *new_tag_stat = NULL; + MT_DEBUG("qtaguid: if_tag_stat_update(ifname=%s " + "uid=%u sk=%p dir=%d proto=%d bytes=%d)\n", + ifname, uid, sk, direction, proto, bytes); + + spin_lock_bh(&iface_stat_list_lock); + iface_entry = get_iface_entry(ifname); + if (!iface_entry) { + pr_err_ratelimited("qtaguid: tag_stat: stat_update() " + "%s not found\n", ifname); + spin_unlock_bh(&iface_stat_list_lock); + return; + } + /* It is ok to process data when an iface_entry is inactive */ + + MT_DEBUG("qtaguid: tag_stat: stat_update() dev=%s entry=%p\n", + ifname, iface_entry); + + /* + * Look for a tagged sock. + * It will have an acct_uid. + */ + spin_lock_bh(&sock_tag_list_lock); + sock_tag_entry = sk ? get_sock_stat_nl(sk) : NULL; + if (sock_tag_entry) { + tag = sock_tag_entry->tag; + acct_tag = get_atag_from_tag(tag); + uid_tag = get_utag_from_tag(tag); + } + spin_unlock_bh(&sock_tag_list_lock); + if (!sock_tag_entry) { + acct_tag = make_atag_from_value(0); + tag = combine_atag_with_uid(acct_tag, uid); + uid_tag = make_tag_from_uid(uid); + } + MT_DEBUG("qtaguid: tag_stat: stat_update(): " + " looking for tag=0x%llx (uid=%u) in ife=%p\n", + tag, get_uid_from_tag(tag), iface_entry); + /* Loop over tag list under this interface for {acct_tag,uid_tag} */ + spin_lock_bh(&iface_entry->tag_stat_list_lock); + + tag_stat_entry = tag_stat_tree_search(&iface_entry->tag_stat_tree, + tag); + if (tag_stat_entry) { + /* + * Updating the {acct_tag, uid_tag} entry handles both stats: + * {0, uid_tag} will also get updated. + */ + tag_stat_update(tag_stat_entry, direction, proto, bytes); + goto unlock; + } + + /* Loop over tag list under this interface for {0,uid_tag} */ + tag_stat_entry = tag_stat_tree_search(&iface_entry->tag_stat_tree, + uid_tag); + if (!tag_stat_entry) { + /* Here: the base uid_tag did not exist */ + /* + * No parent counters. So + * - No {0, uid_tag} stats and no {acc_tag, uid_tag} stats. + */ + new_tag_stat = create_if_tag_stat(iface_entry, uid_tag); + if (!new_tag_stat) + goto unlock; + uid_tag_counters = &new_tag_stat->counters; + } else { + uid_tag_counters = &tag_stat_entry->counters; + } + + if (acct_tag) { + /* Create the child {acct_tag, uid_tag} and hook up parent. */ + new_tag_stat = create_if_tag_stat(iface_entry, tag); + if (!new_tag_stat) + goto unlock; + new_tag_stat->parent_counters = uid_tag_counters; + } else { + /* + * For new_tag_stat to be still NULL here would require: + * {0, uid_tag} exists + * and {acct_tag, uid_tag} doesn't exist + * AND acct_tag == 0. + * Impossible. This reassures us that new_tag_stat + * below will always be assigned. + */ + BUG_ON(!new_tag_stat); + } + tag_stat_update(new_tag_stat, direction, proto, bytes); +unlock: + spin_unlock_bh(&iface_entry->tag_stat_list_lock); + spin_unlock_bh(&iface_stat_list_lock); +} + +static int iface_netdev_event_handler(struct notifier_block *nb, + unsigned long event, void *ptr) { + struct net_device *dev = netdev_notifier_info_to_dev(ptr); + + if (unlikely(module_passive)) + return NOTIFY_DONE; + + IF_DEBUG("qtaguid: iface_stat: netdev_event(): " + "ev=0x%lx/%s netdev=%p->name=%s\n", + event, netdev_evt_str(event), dev, dev ? dev->name : ""); + + switch (event) { + case NETDEV_UP: + iface_stat_create(dev, NULL); + atomic64_inc(&qtu_events.iface_events); + break; + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + iface_stat_update(dev, event == NETDEV_DOWN); + atomic64_inc(&qtu_events.iface_events); + break; + } + return NOTIFY_DONE; +} + +static int iface_inet6addr_event_handler(struct notifier_block *nb, + unsigned long event, void *ptr) +{ + struct inet6_ifaddr *ifa = ptr; + struct net_device *dev; + + if (unlikely(module_passive)) + return NOTIFY_DONE; + + IF_DEBUG("qtaguid: iface_stat: inet6addr_event(): " + "ev=0x%lx/%s ifa=%p\n", + event, netdev_evt_str(event), ifa); + + switch (event) { + case NETDEV_UP: + BUG_ON(!ifa || !ifa->idev); + dev = (struct net_device *)ifa->idev->dev; + iface_stat_create_ipv6(dev, ifa); + atomic64_inc(&qtu_events.iface_events); + break; + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + BUG_ON(!ifa || !ifa->idev); + dev = (struct net_device *)ifa->idev->dev; + iface_stat_update(dev, event == NETDEV_DOWN); + atomic64_inc(&qtu_events.iface_events); + break; + } + return NOTIFY_DONE; +} + +static int iface_inetaddr_event_handler(struct notifier_block *nb, + unsigned long event, void *ptr) +{ + struct in_ifaddr *ifa = ptr; + struct net_device *dev; + + if (unlikely(module_passive)) + return NOTIFY_DONE; + + IF_DEBUG("qtaguid: iface_stat: inetaddr_event(): " + "ev=0x%lx/%s ifa=%p\n", + event, netdev_evt_str(event), ifa); + + switch (event) { + case NETDEV_UP: + BUG_ON(!ifa || !ifa->ifa_dev); + dev = ifa->ifa_dev->dev; + iface_stat_create(dev, ifa); + atomic64_inc(&qtu_events.iface_events); + break; + case NETDEV_DOWN: + case NETDEV_UNREGISTER: + BUG_ON(!ifa || !ifa->ifa_dev); + dev = ifa->ifa_dev->dev; + iface_stat_update(dev, event == NETDEV_DOWN); + atomic64_inc(&qtu_events.iface_events); + break; + } + return NOTIFY_DONE; +} + +static struct notifier_block iface_netdev_notifier_blk = { + .notifier_call = iface_netdev_event_handler, +}; + +static struct notifier_block iface_inetaddr_notifier_blk = { + .notifier_call = iface_inetaddr_event_handler, +}; + +static struct notifier_block iface_inet6addr_notifier_blk = { + .notifier_call = iface_inet6addr_event_handler, +}; + +static const struct seq_operations iface_stat_fmt_proc_seq_ops = { + .start = iface_stat_fmt_proc_start, + .next = iface_stat_fmt_proc_next, + .stop = iface_stat_fmt_proc_stop, + .show = iface_stat_fmt_proc_show, +}; + +static int proc_iface_stat_fmt_open(struct inode *inode, struct file *file) +{ + struct proc_iface_stat_fmt_info *s; + + s = __seq_open_private(file, &iface_stat_fmt_proc_seq_ops, + sizeof(struct proc_iface_stat_fmt_info)); + if (!s) + return -ENOMEM; + + s->fmt = (uintptr_t)PDE_DATA(inode); + return 0; +} + +static const struct file_operations proc_iface_stat_fmt_fops = { + .open = proc_iface_stat_fmt_open, + .read = seq_read, + .llseek = seq_lseek, + .release = seq_release_private, +}; + +static int __init iface_stat_init(struct proc_dir_entry *parent_procdir) +{ + int err; + + iface_stat_procdir = proc_mkdir(iface_stat_procdirname, parent_procdir); + if (!iface_stat_procdir) { + pr_err("qtaguid: iface_stat: init failed to create proc entry\n"); + err = -1; + goto err; + } + + iface_stat_all_procfile = proc_create_data(iface_stat_all_procfilename, + proc_iface_perms, + parent_procdir, + &proc_iface_stat_fmt_fops, + (void *)1 /* fmt1 */); + if (!iface_stat_all_procfile) { + pr_err("qtaguid: iface_stat: init " + " failed to create stat_old proc entry\n"); + err = -1; + goto err_zap_entry; + } + + iface_stat_fmt_procfile = proc_create_data(iface_stat_fmt_procfilename, + proc_iface_perms, + parent_procdir, + &proc_iface_stat_fmt_fops, + (void *)2 /* fmt2 */); + if (!iface_stat_fmt_procfile) { + pr_err("qtaguid: iface_stat: init " + " failed to create stat_all proc entry\n"); + err = -1; + goto err_zap_all_stats_entry; + } + + + err = register_netdevice_notifier(&iface_netdev_notifier_blk); + if (err) { + pr_err("qtaguid: iface_stat: init " + "failed to register dev event handler\n"); + goto err_zap_all_stats_entries; + } + err = register_inetaddr_notifier(&iface_inetaddr_notifier_blk); + if (err) { + pr_err("qtaguid: iface_stat: init " + "failed to register ipv4 dev event handler\n"); + goto err_unreg_nd; + } + + err = register_inet6addr_notifier(&iface_inet6addr_notifier_blk); + if (err) { + pr_err("qtaguid: iface_stat: init " + "failed to register ipv6 dev event handler\n"); + goto err_unreg_ip4_addr; + } + return 0; + +err_unreg_ip4_addr: + unregister_inetaddr_notifier(&iface_inetaddr_notifier_blk); +err_unreg_nd: + unregister_netdevice_notifier(&iface_netdev_notifier_blk); +err_zap_all_stats_entries: + remove_proc_entry(iface_stat_fmt_procfilename, parent_procdir); +err_zap_all_stats_entry: + remove_proc_entry(iface_stat_all_procfilename, parent_procdir); +err_zap_entry: + remove_proc_entry(iface_stat_procdirname, parent_procdir); +err: + return err; +} + +static struct sock *qtaguid_find_sk(const struct sk_buff *skb, + struct xt_action_param *par) +{ + struct sock *sk; + unsigned int hook_mask = (1 << par->hooknum); + + MT_DEBUG("qtaguid[%d]: find_sk(skb=%p) family=%d\n", + par->hooknum, skb, par->family); + + /* + * Let's not abuse the the xt_socket_get*_sk(), or else it will + * return garbage SKs. + */ + if (!(hook_mask & XT_SOCKET_SUPPORTED_HOOKS)) + return NULL; + + switch (par->family) { + case NFPROTO_IPV6: + sk = xt_socket_lookup_slow_v6(dev_net(skb->dev), skb, par->in); + break; + case NFPROTO_IPV4: + sk = xt_socket_lookup_slow_v4(dev_net(skb->dev), skb, par->in); + break; + default: + return NULL; + } + + if (sk) { + MT_DEBUG("qtaguid[%d]: %p->sk_proto=%u->sk_state=%d\n", + par->hooknum, sk, sk->sk_protocol, sk->sk_state); + /* + * When in TCP_TIME_WAIT the sk is not a "struct sock" but + * "struct inet_timewait_sock" which is missing fields. + */ + if (!sk_fullsock(sk) || sk->sk_state == TCP_TIME_WAIT) { + sock_gen_put(sk); + sk = NULL; + } + } + return sk; +} + +static void account_for_uid(const struct sk_buff *skb, + const struct sock *alternate_sk, uid_t uid, + struct xt_action_param *par) +{ + const struct net_device *el_dev; + enum ifs_tx_rx direction; + int proto; + + get_dev_and_dir(skb, par, &direction, &el_dev); + proto = ipx_proto(skb, par); + MT_DEBUG("qtaguid[%d]: dev name=%s type=%d fam=%d proto=%d dir=%d\n", + par->hooknum, el_dev->name, el_dev->type, + par->family, proto, direction); + + if_tag_stat_update(el_dev->name, uid, + skb->sk ? skb->sk : alternate_sk, + direction, + proto, skb->len); +} + +static bool qtaguid_mt(const struct sk_buff *skb, struct xt_action_param *par) +{ + const struct xt_qtaguid_match_info *info = par->matchinfo; + const struct file *filp; + bool got_sock = false; + struct sock *sk; + kuid_t sock_uid; + bool res; + bool set_sk_callback_lock = false; + /* + * TODO: unhack how to force just accounting. + * For now we only do tag stats when the uid-owner is not requested + */ + bool do_tag_stat = !(info->match & XT_QTAGUID_UID); + + if (unlikely(module_passive)) + return (info->match ^ info->invert) == 0; + + MT_DEBUG("qtaguid[%d]: entered skb=%p par->in=%p/out=%p fam=%d\n", + par->hooknum, skb, par->in, par->out, par->family); + + atomic64_inc(&qtu_events.match_calls); + if (skb == NULL) { + res = (info->match ^ info->invert) == 0; + goto ret_res; + } + + switch (par->hooknum) { + case NF_INET_PRE_ROUTING: + case NF_INET_POST_ROUTING: + atomic64_inc(&qtu_events.match_calls_prepost); + iface_stat_update_from_skb(skb, par); + /* + * We are done in pre/post. The skb will get processed + * further alter. + */ + res = (info->match ^ info->invert); + goto ret_res; + break; + /* default: Fall through and do UID releated work */ + } + + sk = skb_to_full_sk(skb); + /* + * When in TCP_TIME_WAIT the sk is not a "struct sock" but + * "struct inet_timewait_sock" which is missing fields. + * So we ignore it. + */ + if (sk && sk->sk_state == TCP_TIME_WAIT) + sk = NULL; + if (sk == NULL) { + /* + * A missing sk->sk_socket happens when packets are in-flight + * and the matching socket is already closed and gone. + */ + sk = qtaguid_find_sk(skb, par); + /* + * If we got the socket from the find_sk(), we will need to put + * it back, as nf_tproxy_get_sock_v4() got it. + */ + got_sock = sk; + if (sk) + atomic64_inc(&qtu_events.match_found_sk_in_ct); + else + atomic64_inc(&qtu_events.match_found_no_sk_in_ct); + } else { + atomic64_inc(&qtu_events.match_found_sk); + } + MT_DEBUG("qtaguid[%d]: sk=%p got_sock=%d fam=%d proto=%d\n", + par->hooknum, sk, got_sock, par->family, ipx_proto(skb, par)); + + if (!sk) { + /* + * Here, the qtaguid_find_sk() using connection tracking + * couldn't find the owner, so for now we just count them + * against the system. + */ + if (do_tag_stat) + account_for_uid(skb, sk, 0, par); + MT_DEBUG("qtaguid[%d]: leaving (sk=NULL)\n", par->hooknum); + res = (info->match ^ info->invert) == 0; + atomic64_inc(&qtu_events.match_no_sk); + goto put_sock_ret_res; + } else if (info->match & info->invert & XT_QTAGUID_SOCKET) { + res = false; + goto put_sock_ret_res; + } + sock_uid = sk->sk_uid; + if (do_tag_stat) + account_for_uid(skb, sk, from_kuid(&init_user_ns, sock_uid), + par); + + /* + * The following two tests fail the match when: + * id not in range AND no inverted condition requested + * or id in range AND inverted condition requested + * Thus (!a && b) || (a && !b) == a ^ b + */ + if (info->match & XT_QTAGUID_UID) { + kuid_t uid_min = make_kuid(&init_user_ns, info->uid_min); + kuid_t uid_max = make_kuid(&init_user_ns, info->uid_max); + + if ((uid_gte(sock_uid, uid_min) && + uid_lte(sock_uid, uid_max)) ^ + !(info->invert & XT_QTAGUID_UID)) { + MT_DEBUG("qtaguid[%d]: leaving uid not matching\n", + par->hooknum); + res = false; + goto put_sock_ret_res; + } + } + if (info->match & XT_QTAGUID_GID) { + kgid_t gid_min = make_kgid(&init_user_ns, info->gid_min); + kgid_t gid_max = make_kgid(&init_user_ns, info->gid_max); + set_sk_callback_lock = true; + read_lock_bh(&sk->sk_callback_lock); + MT_DEBUG("qtaguid[%d]: sk=%p->sk_socket=%p->file=%p\n", + par->hooknum, sk, sk->sk_socket, + sk->sk_socket ? sk->sk_socket->file : (void *)-1LL); + filp = sk->sk_socket ? sk->sk_socket->file : NULL; + if (!filp) { + res = ((info->match ^ info->invert) & + XT_QTAGUID_GID) == 0; + atomic64_inc(&qtu_events.match_no_sk_gid); + goto put_sock_ret_res; + } + MT_DEBUG("qtaguid[%d]: filp...uid=%u\n", + par->hooknum, filp ? + from_kuid(&init_user_ns, filp->f_cred->fsuid) : -1); + if ((gid_gte(filp->f_cred->fsgid, gid_min) && + gid_lte(filp->f_cred->fsgid, gid_max)) ^ + !(info->invert & XT_QTAGUID_GID)) { + MT_DEBUG("qtaguid[%d]: leaving gid not matching\n", + par->hooknum); + res = false; + goto put_sock_ret_res; + } + } + MT_DEBUG("qtaguid[%d]: leaving matched\n", par->hooknum); + res = true; + +put_sock_ret_res: + if (got_sock) + sock_gen_put(sk); + if (set_sk_callback_lock) + read_unlock_bh(&sk->sk_callback_lock); +ret_res: + MT_DEBUG("qtaguid[%d]: left %d\n", par->hooknum, res); + return res; +} + +#ifdef DDEBUG +/* + * This function is not in xt_qtaguid_print.c because of locks visibility. + * The lock of sock_tag_list must be aquired before calling this function + */ +static void prdebug_full_state_locked(int indent_level, const char *fmt, ...) +{ + va_list args; + char *fmt_buff; + char *buff; + + if (!unlikely(qtaguid_debug_mask & DDEBUG_MASK)) + return; + + fmt_buff = kasprintf(GFP_ATOMIC, + "qtaguid: %s(): %s {\n", __func__, fmt); + BUG_ON(!fmt_buff); + va_start(args, fmt); + buff = kvasprintf(GFP_ATOMIC, + fmt_buff, args); + BUG_ON(!buff); + pr_debug("%s", buff); + kfree(fmt_buff); + kfree(buff); + va_end(args); + + prdebug_sock_tag_tree(indent_level, &sock_tag_tree); + + spin_lock_bh(&uid_tag_data_tree_lock); + prdebug_uid_tag_data_tree(indent_level, &uid_tag_data_tree); + prdebug_proc_qtu_data_tree(indent_level, &proc_qtu_data_tree); + spin_unlock_bh(&uid_tag_data_tree_lock); + + spin_lock_bh(&iface_stat_list_lock); + prdebug_iface_stat_list(indent_level, &iface_stat_list); + spin_unlock_bh(&iface_stat_list_lock); + + pr_debug("qtaguid: %s(): }\n", __func__); +} +#else +static void prdebug_full_state_locked(int indent_level, const char *fmt, ...) {} +#endif + +struct proc_ctrl_print_info { + struct sock *sk; /* socket found by reading to sk_pos */ + loff_t sk_pos; +}; + +static void *qtaguid_ctrl_proc_next(struct seq_file *m, void *v, loff_t *pos) +{ + struct proc_ctrl_print_info *pcpi = m->private; + struct sock_tag *sock_tag_entry = v; + struct rb_node *node; + + (*pos)++; + + if (!v || v == SEQ_START_TOKEN) + return NULL; + + node = rb_next(&sock_tag_entry->sock_node); + if (!node) { + pcpi->sk = NULL; + sock_tag_entry = SEQ_START_TOKEN; + } else { + sock_tag_entry = rb_entry(node, struct sock_tag, sock_node); + pcpi->sk = sock_tag_entry->sk; + } + pcpi->sk_pos = *pos; + return sock_tag_entry; +} + +static void *qtaguid_ctrl_proc_start(struct seq_file *m, loff_t *pos) +{ + struct proc_ctrl_print_info *pcpi = m->private; + struct sock_tag *sock_tag_entry; + struct rb_node *node; + + spin_lock_bh(&sock_tag_list_lock); + + if (unlikely(module_passive)) + return NULL; + + if (*pos == 0) { + pcpi->sk_pos = 0; + node = rb_first(&sock_tag_tree); + if (!node) { + pcpi->sk = NULL; + return SEQ_START_TOKEN; + } + sock_tag_entry = rb_entry(node, struct sock_tag, sock_node); + pcpi->sk = sock_tag_entry->sk; + } else { + sock_tag_entry = (pcpi->sk ? get_sock_stat_nl(pcpi->sk) : + NULL) ?: SEQ_START_TOKEN; + if (*pos != pcpi->sk_pos) { + /* seq_read skipped a next call */ + *pos = pcpi->sk_pos; + return qtaguid_ctrl_proc_next(m, sock_tag_entry, pos); + } + } + return sock_tag_entry; +} + +static void qtaguid_ctrl_proc_stop(struct seq_file *m, void *v) +{ + spin_unlock_bh(&sock_tag_list_lock); +} + +/* + * Procfs reader to get all active socket tags using style "1)" as described in + * fs/proc/generic.c + */ +static int qtaguid_ctrl_proc_show(struct seq_file *m, void *v) +{ + struct sock_tag *sock_tag_entry = v; + uid_t uid; + + CT_DEBUG("qtaguid: proc ctrl pid=%u tgid=%u uid=%u\n", + current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + + if (sock_tag_entry != SEQ_START_TOKEN) { + int sk_ref_count; + uid = get_uid_from_tag(sock_tag_entry->tag); + CT_DEBUG("qtaguid: proc_read(): sk=%p tag=0x%llx (uid=%u) " + "pid=%u\n", + sock_tag_entry->sk, + sock_tag_entry->tag, + uid, + sock_tag_entry->pid + ); + sk_ref_count = atomic_read( + &sock_tag_entry->sk->sk_refcnt); + seq_printf(m, "sock=%pK tag=0x%llx (uid=%u) pid=%u " + "f_count=%d\n", + sock_tag_entry->sk, + sock_tag_entry->tag, uid, + sock_tag_entry->pid, sk_ref_count); + } else { + seq_printf(m, "events: sockets_tagged=%llu " + "sockets_untagged=%llu " + "counter_set_changes=%llu " + "delete_cmds=%llu " + "iface_events=%llu " + "match_calls=%llu " + "match_calls_prepost=%llu " + "match_found_sk=%llu " + "match_found_sk_in_ct=%llu " + "match_found_no_sk_in_ct=%llu " + "match_no_sk=%llu " + "match_no_sk_gid=%llu\n", + (u64)atomic64_read(&qtu_events.sockets_tagged), + (u64)atomic64_read(&qtu_events.sockets_untagged), + (u64)atomic64_read(&qtu_events.counter_set_changes), + (u64)atomic64_read(&qtu_events.delete_cmds), + (u64)atomic64_read(&qtu_events.iface_events), + (u64)atomic64_read(&qtu_events.match_calls), + (u64)atomic64_read(&qtu_events.match_calls_prepost), + (u64)atomic64_read(&qtu_events.match_found_sk), + (u64)atomic64_read(&qtu_events.match_found_sk_in_ct), + (u64)atomic64_read(&qtu_events.match_found_no_sk_in_ct), + (u64)atomic64_read(&qtu_events.match_no_sk), + (u64)atomic64_read(&qtu_events.match_no_sk_gid)); + + /* Count the following as part of the last item_index. No need + * to lock the sock_tag_list here since it is already locked when + * starting the seq_file operation + */ + prdebug_full_state_locked(0, "proc ctrl"); + } + + return 0; +} + +/* + * Delete socket tags, and stat tags associated with a given + * accouting tag and uid. + */ +static int ctrl_cmd_delete(const char *input) +{ + char cmd; + int uid_int; + kuid_t uid; + uid_t entry_uid; + tag_t acct_tag; + tag_t tag; + int res, argc; + struct iface_stat *iface_entry; + struct rb_node *node; + struct sock_tag *st_entry; + struct rb_root st_to_free_tree = RB_ROOT; + struct tag_stat *ts_entry; + struct tag_counter_set *tcs_entry; + struct tag_ref *tr_entry; + struct uid_tag_data *utd_entry; + + argc = sscanf(input, "%c %llu %u", &cmd, &acct_tag, &uid_int); + uid = make_kuid(&init_user_ns, uid_int); + CT_DEBUG("qtaguid: ctrl_delete(%s): argc=%d cmd=%c " + "user_tag=0x%llx uid=%u\n", input, argc, cmd, + acct_tag, uid_int); + if (argc < 2) { + res = -EINVAL; + goto err; + } + if (!valid_atag(acct_tag)) { + pr_info("qtaguid: ctrl_delete(%s): invalid tag\n", input); + res = -EINVAL; + goto err; + } + if (argc < 3) { + uid = current_fsuid(); + uid_int = from_kuid(&init_user_ns, uid); + } else if (!can_impersonate_uid(uid)) { + pr_info("qtaguid: ctrl_delete(%s): " + "insufficient priv from pid=%u tgid=%u uid=%u\n", + input, current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + res = -EPERM; + goto err; + } + + tag = combine_atag_with_uid(acct_tag, uid_int); + CT_DEBUG("qtaguid: ctrl_delete(%s): " + "looking for tag=0x%llx (uid=%u)\n", + input, tag, uid_int); + + /* Delete socket tags */ + spin_lock_bh(&sock_tag_list_lock); + spin_lock_bh(&uid_tag_data_tree_lock); + node = rb_first(&sock_tag_tree); + while (node) { + st_entry = rb_entry(node, struct sock_tag, sock_node); + entry_uid = get_uid_from_tag(st_entry->tag); + node = rb_next(node); + if (entry_uid != uid_int) + continue; + + CT_DEBUG("qtaguid: ctrl_delete(%s): st tag=0x%llx (uid=%u)\n", + input, st_entry->tag, entry_uid); + + if (!acct_tag || st_entry->tag == tag) { + rb_erase(&st_entry->sock_node, &sock_tag_tree); + /* Can't sockfd_put() within spinlock, do it later. */ + sock_tag_tree_insert(st_entry, &st_to_free_tree); + tr_entry = lookup_tag_ref(st_entry->tag, NULL); + BUG_ON(tr_entry->num_sock_tags <= 0); + tr_entry->num_sock_tags--; + /* + * TODO: remove if, and start failing. + * This is a hack to work around the fact that in some + * places we have "if (IS_ERR_OR_NULL(pqd_entry))" + * and are trying to work around apps + * that didn't open the /dev/xt_qtaguid. + */ + if (st_entry->list.next && st_entry->list.prev) + list_del(&st_entry->list); + } + } + spin_unlock_bh(&uid_tag_data_tree_lock); + spin_unlock_bh(&sock_tag_list_lock); + + sock_tag_tree_erase(&st_to_free_tree); + + /* Delete tag counter-sets */ + spin_lock_bh(&tag_counter_set_list_lock); + /* Counter sets are only on the uid tag, not full tag */ + tcs_entry = tag_counter_set_tree_search(&tag_counter_set_tree, tag); + if (tcs_entry) { + CT_DEBUG("qtaguid: ctrl_delete(%s): " + "erase tcs: tag=0x%llx (uid=%u) set=%d\n", + input, + tcs_entry->tn.tag, + get_uid_from_tag(tcs_entry->tn.tag), + tcs_entry->active_set); + rb_erase(&tcs_entry->tn.node, &tag_counter_set_tree); + kfree(tcs_entry); + } + spin_unlock_bh(&tag_counter_set_list_lock); + + /* + * If acct_tag is 0, then all entries belonging to uid are + * erased. + */ + spin_lock_bh(&iface_stat_list_lock); + list_for_each_entry(iface_entry, &iface_stat_list, list) { + spin_lock_bh(&iface_entry->tag_stat_list_lock); + node = rb_first(&iface_entry->tag_stat_tree); + while (node) { + ts_entry = rb_entry(node, struct tag_stat, tn.node); + entry_uid = get_uid_from_tag(ts_entry->tn.tag); + node = rb_next(node); + + CT_DEBUG("qtaguid: ctrl_delete(%s): " + "ts tag=0x%llx (uid=%u)\n", + input, ts_entry->tn.tag, entry_uid); + + if (entry_uid != uid_int) + continue; + if (!acct_tag || ts_entry->tn.tag == tag) { + CT_DEBUG("qtaguid: ctrl_delete(%s): " + "erase ts: %s 0x%llx %u\n", + input, iface_entry->ifname, + get_atag_from_tag(ts_entry->tn.tag), + entry_uid); + rb_erase(&ts_entry->tn.node, + &iface_entry->tag_stat_tree); + kfree(ts_entry); + } + } + spin_unlock_bh(&iface_entry->tag_stat_list_lock); + } + spin_unlock_bh(&iface_stat_list_lock); + + /* Cleanup the uid_tag_data */ + spin_lock_bh(&uid_tag_data_tree_lock); + node = rb_first(&uid_tag_data_tree); + while (node) { + utd_entry = rb_entry(node, struct uid_tag_data, node); + entry_uid = utd_entry->uid; + node = rb_next(node); + + CT_DEBUG("qtaguid: ctrl_delete(%s): " + "utd uid=%u\n", + input, entry_uid); + + if (entry_uid != uid_int) + continue; + /* + * Go over the tag_refs, and those that don't have + * sock_tags using them are freed. + */ + put_tag_ref_tree(tag, utd_entry); + put_utd_entry(utd_entry); + } + spin_unlock_bh(&uid_tag_data_tree_lock); + + atomic64_inc(&qtu_events.delete_cmds); + res = 0; + +err: + return res; +} + +static int ctrl_cmd_counter_set(const char *input) +{ + char cmd; + uid_t uid = 0; + tag_t tag; + int res, argc; + struct tag_counter_set *tcs; + int counter_set; + + argc = sscanf(input, "%c %d %u", &cmd, &counter_set, &uid); + CT_DEBUG("qtaguid: ctrl_counterset(%s): argc=%d cmd=%c " + "set=%d uid=%u\n", input, argc, cmd, + counter_set, uid); + if (argc != 3) { + res = -EINVAL; + goto err; + } + if (counter_set < 0 || counter_set >= IFS_MAX_COUNTER_SETS) { + pr_info("qtaguid: ctrl_counterset(%s): invalid counter_set range\n", + input); + res = -EINVAL; + goto err; + } + if (!can_manipulate_uids()) { + pr_info("qtaguid: ctrl_counterset(%s): " + "insufficient priv from pid=%u tgid=%u uid=%u\n", + input, current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + res = -EPERM; + goto err; + } + + tag = make_tag_from_uid(uid); + spin_lock_bh(&tag_counter_set_list_lock); + tcs = tag_counter_set_tree_search(&tag_counter_set_tree, tag); + if (!tcs) { + tcs = kzalloc(sizeof(*tcs), GFP_ATOMIC); + if (!tcs) { + spin_unlock_bh(&tag_counter_set_list_lock); + pr_err("qtaguid: ctrl_counterset(%s): " + "failed to alloc counter set\n", + input); + res = -ENOMEM; + goto err; + } + tcs->tn.tag = tag; + tag_counter_set_tree_insert(tcs, &tag_counter_set_tree); + CT_DEBUG("qtaguid: ctrl_counterset(%s): added tcs tag=0x%llx " + "(uid=%u) set=%d\n", + input, tag, get_uid_from_tag(tag), counter_set); + } + tcs->active_set = counter_set; + spin_unlock_bh(&tag_counter_set_list_lock); + atomic64_inc(&qtu_events.counter_set_changes); + res = 0; + +err: + return res; +} + +static int ctrl_cmd_tag(const char *input) +{ + char cmd; + int sock_fd = 0; + kuid_t uid; + unsigned int uid_int = 0; + tag_t acct_tag = make_atag_from_value(0); + tag_t full_tag; + struct socket *el_socket; + int res, argc; + struct sock_tag *sock_tag_entry; + struct tag_ref *tag_ref_entry; + struct uid_tag_data *uid_tag_data_entry; + struct proc_qtu_data *pqd_entry; + + /* Unassigned args will get defaulted later. */ + argc = sscanf(input, "%c %d %llu %u", &cmd, &sock_fd, &acct_tag, &uid_int); + uid = make_kuid(&init_user_ns, uid_int); + CT_DEBUG("qtaguid: ctrl_tag(%s): argc=%d cmd=%c sock_fd=%d " + "acct_tag=0x%llx uid=%u\n", input, argc, cmd, sock_fd, + acct_tag, uid_int); + if (argc < 2) { + res = -EINVAL; + goto err; + } + el_socket = sockfd_lookup(sock_fd, &res); /* This locks the file */ + if (!el_socket) { + pr_info("qtaguid: ctrl_tag(%s): failed to lookup" + " sock_fd=%d err=%d pid=%u tgid=%u uid=%u\n", + input, sock_fd, res, current->pid, current->tgid, + from_kuid(&init_user_ns, current_fsuid())); + goto err; + } + CT_DEBUG("qtaguid: ctrl_tag(%s): socket->...->sk_refcnt=%d ->sk=%p\n", + input, atomic_read(&el_socket->sk->sk_refcnt), + el_socket->sk); + if (argc < 3) { + acct_tag = make_atag_from_value(0); + } else if (!valid_atag(acct_tag)) { + pr_info("qtaguid: ctrl_tag(%s): invalid tag\n", input); + res = -EINVAL; + goto err_put; + } + CT_DEBUG("qtaguid: ctrl_tag(%s): " + "pid=%u tgid=%u uid=%u euid=%u fsuid=%u " + "ctrl.gid=%u in_group()=%d in_egroup()=%d\n", + input, current->pid, current->tgid, + from_kuid(&init_user_ns, current_uid()), + from_kuid(&init_user_ns, current_euid()), + from_kuid(&init_user_ns, current_fsuid()), + from_kgid(&init_user_ns, xt_qtaguid_ctrl_file->gid), + in_group_p(xt_qtaguid_ctrl_file->gid), + in_egroup_p(xt_qtaguid_ctrl_file->gid)); + if (argc < 4) { + uid = current_fsuid(); + uid_int = from_kuid(&init_user_ns, uid); + } else if (!can_impersonate_uid(uid)) { + pr_info("qtaguid: ctrl_tag(%s): " + "insufficient priv from pid=%u tgid=%u uid=%u\n", + input, current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + res = -EPERM; + goto err_put; + } + full_tag = combine_atag_with_uid(acct_tag, uid_int); + + spin_lock_bh(&sock_tag_list_lock); + spin_lock_bh(&uid_tag_data_tree_lock); + sock_tag_entry = get_sock_stat_nl(el_socket->sk); + tag_ref_entry = get_tag_ref(full_tag, &uid_tag_data_entry); + if (IS_ERR(tag_ref_entry)) { + res = PTR_ERR(tag_ref_entry); + spin_unlock_bh(&uid_tag_data_tree_lock); + spin_unlock_bh(&sock_tag_list_lock); + goto err_put; + } + tag_ref_entry->num_sock_tags++; + if (sock_tag_entry) { + struct tag_ref *prev_tag_ref_entry; + + CT_DEBUG("qtaguid: ctrl_tag(%s): retag for sk=%p " + "st@%p ...->sk_refcnt=%d\n", + input, el_socket->sk, sock_tag_entry, + atomic_read(&el_socket->sk->sk_refcnt)); + prev_tag_ref_entry = lookup_tag_ref(sock_tag_entry->tag, + &uid_tag_data_entry); + BUG_ON(IS_ERR_OR_NULL(prev_tag_ref_entry)); + BUG_ON(prev_tag_ref_entry->num_sock_tags <= 0); + prev_tag_ref_entry->num_sock_tags--; + sock_tag_entry->tag = full_tag; + } else { + CT_DEBUG("qtaguid: ctrl_tag(%s): newtag for sk=%p\n", + input, el_socket->sk); + sock_tag_entry = kzalloc(sizeof(*sock_tag_entry), + GFP_ATOMIC); + if (!sock_tag_entry) { + pr_err("qtaguid: ctrl_tag(%s): " + "socket tag alloc failed\n", + input); + BUG_ON(tag_ref_entry->num_sock_tags <= 0); + tag_ref_entry->num_sock_tags--; + free_tag_ref_from_utd_entry(tag_ref_entry, + uid_tag_data_entry); + spin_unlock_bh(&uid_tag_data_tree_lock); + spin_unlock_bh(&sock_tag_list_lock); + res = -ENOMEM; + goto err_put; + } + /* + * Hold the sk refcount here to make sure the sk pointer cannot + * be freed and reused + */ + sock_hold(el_socket->sk); + sock_tag_entry->sk = el_socket->sk; + sock_tag_entry->pid = current->tgid; + sock_tag_entry->tag = combine_atag_with_uid(acct_tag, uid_int); + pqd_entry = proc_qtu_data_tree_search( + &proc_qtu_data_tree, current->tgid); + /* + * TODO: remove if, and start failing. + * At first, we want to catch user-space code that is not + * opening the /dev/xt_qtaguid. + */ + if (IS_ERR_OR_NULL(pqd_entry)) + pr_warn_once( + "qtaguid: %s(): " + "User space forgot to open /dev/xt_qtaguid? " + "pid=%u tgid=%u uid=%u\n", __func__, + current->pid, current->tgid, + from_kuid(&init_user_ns, current_fsuid())); + else + list_add(&sock_tag_entry->list, + &pqd_entry->sock_tag_list); + + sock_tag_tree_insert(sock_tag_entry, &sock_tag_tree); + atomic64_inc(&qtu_events.sockets_tagged); + } + spin_unlock_bh(&uid_tag_data_tree_lock); + spin_unlock_bh(&sock_tag_list_lock); + /* We keep the ref to the sk until it is untagged */ + CT_DEBUG("qtaguid: ctrl_tag(%s): done st@%p ...->sk_refcnt=%d\n", + input, sock_tag_entry, + atomic_read(&el_socket->sk->sk_refcnt)); + sockfd_put(el_socket); + return 0; + +err_put: + CT_DEBUG("qtaguid: ctrl_tag(%s): done. ...->sk_refcnt=%d\n", + input, atomic_read(&el_socket->sk->sk_refcnt) - 1); + /* Release the sock_fd that was grabbed by sockfd_lookup(). */ + sockfd_put(el_socket); + return res; + +err: + CT_DEBUG("qtaguid: ctrl_tag(%s): done.\n", input); + return res; +} + +static int ctrl_cmd_untag(const char *input) +{ + char cmd; + int sock_fd = 0; + struct socket *el_socket; + int res, argc; + + argc = sscanf(input, "%c %d", &cmd, &sock_fd); + CT_DEBUG("qtaguid: ctrl_untag(%s): argc=%d cmd=%c sock_fd=%d\n", + input, argc, cmd, sock_fd); + if (argc < 2) { + res = -EINVAL; + return res; + } + el_socket = sockfd_lookup(sock_fd, &res); /* This locks the file */ + if (!el_socket) { + pr_info("qtaguid: ctrl_untag(%s): failed to lookup" + " sock_fd=%d err=%d pid=%u tgid=%u uid=%u\n", + input, sock_fd, res, current->pid, current->tgid, + from_kuid(&init_user_ns, current_fsuid())); + return res; + } + CT_DEBUG("qtaguid: ctrl_untag(%s): socket->...->f_count=%ld ->sk=%p\n", + input, atomic_long_read(&el_socket->file->f_count), + el_socket->sk); + res = qtaguid_untag(el_socket, false); + sockfd_put(el_socket); + return res; +} + +int qtaguid_untag(struct socket *el_socket, bool kernel) +{ + int res; + pid_t pid; + struct sock_tag *sock_tag_entry; + struct tag_ref *tag_ref_entry; + struct uid_tag_data *utd_entry; + struct proc_qtu_data *pqd_entry; + + spin_lock_bh(&sock_tag_list_lock); + sock_tag_entry = get_sock_stat_nl(el_socket->sk); + if (!sock_tag_entry) { + spin_unlock_bh(&sock_tag_list_lock); + res = -EINVAL; + return res; + } + /* + * The socket already belongs to the current process + * so it can do whatever it wants to it. + */ + rb_erase(&sock_tag_entry->sock_node, &sock_tag_tree); + + tag_ref_entry = lookup_tag_ref(sock_tag_entry->tag, &utd_entry); + BUG_ON(!tag_ref_entry); + BUG_ON(tag_ref_entry->num_sock_tags <= 0); + spin_lock_bh(&uid_tag_data_tree_lock); + if (kernel) + pid = sock_tag_entry->pid; + else + pid = current->tgid; + pqd_entry = proc_qtu_data_tree_search( + &proc_qtu_data_tree, pid); + /* + * TODO: remove if, and start failing. + * At first, we want to catch user-space code that is not + * opening the /dev/xt_qtaguid. + */ + if (IS_ERR_OR_NULL(pqd_entry)) + pr_warn_once("qtaguid: %s(): " + "User space forgot to open /dev/xt_qtaguid? " + "pid=%u tgid=%u sk_pid=%u, uid=%u\n", __func__, + current->pid, current->tgid, sock_tag_entry->pid, + from_kuid(&init_user_ns, current_fsuid())); + /* + * This check is needed because tagging from a process that + * didn’t open /dev/xt_qtaguid still adds the sock_tag_entry + * to sock_tag_tree. + */ + if (sock_tag_entry->list.next) + list_del(&sock_tag_entry->list); + + spin_unlock_bh(&uid_tag_data_tree_lock); + /* + * We don't free tag_ref from the utd_entry here, + * only during a cmd_delete(). + */ + tag_ref_entry->num_sock_tags--; + spin_unlock_bh(&sock_tag_list_lock); + /* + * Release the sock_fd that was grabbed at tag time. + */ + sock_put(sock_tag_entry->sk); + CT_DEBUG("qtaguid: done. st@%p ...->sk_refcnt=%d\n", + sock_tag_entry, + atomic_read(&el_socket->sk->sk_refcnt)); + + kfree(sock_tag_entry); + atomic64_inc(&qtu_events.sockets_untagged); + + return 0; +} + +static ssize_t qtaguid_ctrl_parse(const char *input, size_t count) +{ + char cmd; + ssize_t res; + + CT_DEBUG("qtaguid: ctrl(%s): pid=%u tgid=%u uid=%u\n", + input, current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + + cmd = input[0]; + /* Collect params for commands */ + switch (cmd) { + case 'd': + res = ctrl_cmd_delete(input); + break; + + case 's': + res = ctrl_cmd_counter_set(input); + break; + + case 't': + res = ctrl_cmd_tag(input); + break; + + case 'u': + res = ctrl_cmd_untag(input); + break; + + default: + res = -EINVAL; + goto err; + } + if (!res) + res = count; +err: + CT_DEBUG("qtaguid: ctrl(%s): res=%zd\n", input, res); + return res; +} + +#define MAX_QTAGUID_CTRL_INPUT_LEN 255 +static ssize_t qtaguid_ctrl_proc_write(struct file *file, const char __user *buffer, + size_t count, loff_t *offp) +{ + char input_buf[MAX_QTAGUID_CTRL_INPUT_LEN]; + + if (unlikely(module_passive)) + return count; + + if (count >= MAX_QTAGUID_CTRL_INPUT_LEN) + return -EINVAL; + + if (copy_from_user(input_buf, buffer, count)) + return -EFAULT; + + input_buf[count] = '\0'; + return qtaguid_ctrl_parse(input_buf, count); +} + +struct proc_print_info { + struct iface_stat *iface_entry; + int item_index; + tag_t tag; /* tag found by reading to tag_pos */ + off_t tag_pos; + int tag_item_index; +}; + +static void pp_stats_header(struct seq_file *m) +{ + seq_puts(m, + "idx iface acct_tag_hex uid_tag_int cnt_set " + "rx_bytes rx_packets " + "tx_bytes tx_packets " + "rx_tcp_bytes rx_tcp_packets " + "rx_udp_bytes rx_udp_packets " + "rx_other_bytes rx_other_packets " + "tx_tcp_bytes tx_tcp_packets " + "tx_udp_bytes tx_udp_packets " + "tx_other_bytes tx_other_packets\n"); +} + +static int pp_stats_line(struct seq_file *m, struct tag_stat *ts_entry, + int cnt_set) +{ + struct data_counters *cnts; + tag_t tag = ts_entry->tn.tag; + uid_t stat_uid = get_uid_from_tag(tag); + struct proc_print_info *ppi = m->private; + /* Detailed tags are not available to everybody */ + if (!can_read_other_uid_stats(make_kuid(&init_user_ns,stat_uid))) { + CT_DEBUG("qtaguid: stats line: " + "%s 0x%llx %u: insufficient priv " + "from pid=%u tgid=%u uid=%u stats.gid=%u\n", + ppi->iface_entry->ifname, + get_atag_from_tag(tag), stat_uid, + current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid()), + from_kgid(&init_user_ns,xt_qtaguid_stats_file->gid)); + return 0; + } + ppi->item_index++; + cnts = &ts_entry->counters; + seq_printf(m, "%d %s 0x%llx %u %u " + "%llu %llu " + "%llu %llu " + "%llu %llu " + "%llu %llu " + "%llu %llu " + "%llu %llu " + "%llu %llu " + "%llu %llu\n", + ppi->item_index, + ppi->iface_entry->ifname, + get_atag_from_tag(tag), + stat_uid, + cnt_set, + dc_sum_bytes(cnts, cnt_set, IFS_RX), + dc_sum_packets(cnts, cnt_set, IFS_RX), + dc_sum_bytes(cnts, cnt_set, IFS_TX), + dc_sum_packets(cnts, cnt_set, IFS_TX), + cnts->bpc[cnt_set][IFS_RX][IFS_TCP].bytes, + cnts->bpc[cnt_set][IFS_RX][IFS_TCP].packets, + cnts->bpc[cnt_set][IFS_RX][IFS_UDP].bytes, + cnts->bpc[cnt_set][IFS_RX][IFS_UDP].packets, + cnts->bpc[cnt_set][IFS_RX][IFS_PROTO_OTHER].bytes, + cnts->bpc[cnt_set][IFS_RX][IFS_PROTO_OTHER].packets, + cnts->bpc[cnt_set][IFS_TX][IFS_TCP].bytes, + cnts->bpc[cnt_set][IFS_TX][IFS_TCP].packets, + cnts->bpc[cnt_set][IFS_TX][IFS_UDP].bytes, + cnts->bpc[cnt_set][IFS_TX][IFS_UDP].packets, + cnts->bpc[cnt_set][IFS_TX][IFS_PROTO_OTHER].bytes, + cnts->bpc[cnt_set][IFS_TX][IFS_PROTO_OTHER].packets); + return seq_has_overflowed(m) ? -ENOSPC : 1; +} + +static bool pp_sets(struct seq_file *m, struct tag_stat *ts_entry) +{ + int ret; + int counter_set; + for (counter_set = 0; counter_set < IFS_MAX_COUNTER_SETS; + counter_set++) { + ret = pp_stats_line(m, ts_entry, counter_set); + if (ret < 0) + return false; + } + return true; +} + +static int qtaguid_stats_proc_iface_stat_ptr_valid(struct iface_stat *ptr) +{ + struct iface_stat *iface_entry; + + if (!ptr) + return false; + + list_for_each_entry(iface_entry, &iface_stat_list, list) + if (iface_entry == ptr) + return true; + return false; +} + +static void qtaguid_stats_proc_next_iface_entry(struct proc_print_info *ppi) +{ + spin_unlock_bh(&ppi->iface_entry->tag_stat_list_lock); + list_for_each_entry_continue(ppi->iface_entry, &iface_stat_list, list) { + spin_lock_bh(&ppi->iface_entry->tag_stat_list_lock); + return; + } + ppi->iface_entry = NULL; +} + +static void *qtaguid_stats_proc_next(struct seq_file *m, void *v, loff_t *pos) +{ + struct proc_print_info *ppi = m->private; + struct tag_stat *ts_entry; + struct rb_node *node; + + if (!v) { + pr_err("qtaguid: %s(): unexpected v: NULL\n", __func__); + return NULL; + } + + (*pos)++; + + if (!ppi->iface_entry || unlikely(module_passive)) + return NULL; + + if (v == SEQ_START_TOKEN) + node = rb_first(&ppi->iface_entry->tag_stat_tree); + else + node = rb_next(&((struct tag_stat *)v)->tn.node); + + while (!node) { + qtaguid_stats_proc_next_iface_entry(ppi); + if (!ppi->iface_entry) + return NULL; + node = rb_first(&ppi->iface_entry->tag_stat_tree); + } + + ts_entry = rb_entry(node, struct tag_stat, tn.node); + ppi->tag = ts_entry->tn.tag; + ppi->tag_pos = *pos; + ppi->tag_item_index = ppi->item_index; + return ts_entry; +} + +static void *qtaguid_stats_proc_start(struct seq_file *m, loff_t *pos) +{ + struct proc_print_info *ppi = m->private; + struct tag_stat *ts_entry = NULL; + + spin_lock_bh(&iface_stat_list_lock); + + if (*pos == 0) { + ppi->item_index = 1; + ppi->tag_pos = 0; + if (list_empty(&iface_stat_list)) { + ppi->iface_entry = NULL; + } else { + ppi->iface_entry = list_first_entry(&iface_stat_list, + struct iface_stat, + list); + spin_lock_bh(&ppi->iface_entry->tag_stat_list_lock); + } + return SEQ_START_TOKEN; + } + if (!qtaguid_stats_proc_iface_stat_ptr_valid(ppi->iface_entry)) { + if (ppi->iface_entry) { + pr_err("qtaguid: %s(): iface_entry %p not found\n", + __func__, ppi->iface_entry); + ppi->iface_entry = NULL; + } + return NULL; + } + + spin_lock_bh(&ppi->iface_entry->tag_stat_list_lock); + + if (!ppi->tag_pos) { + /* seq_read skipped first next call */ + ts_entry = SEQ_START_TOKEN; + } else { + ts_entry = tag_stat_tree_search( + &ppi->iface_entry->tag_stat_tree, ppi->tag); + if (!ts_entry) { + pr_info("qtaguid: %s(): tag_stat.tag 0x%llx not found. Abort.\n", + __func__, ppi->tag); + return NULL; + } + } + + if (*pos == ppi->tag_pos) { /* normal resume */ + ppi->item_index = ppi->tag_item_index; + } else { + /* seq_read skipped a next call */ + *pos = ppi->tag_pos; + ts_entry = qtaguid_stats_proc_next(m, ts_entry, pos); + } + + return ts_entry; +} + +static void qtaguid_stats_proc_stop(struct seq_file *m, void *v) +{ + struct proc_print_info *ppi = m->private; + if (ppi->iface_entry) + spin_unlock_bh(&ppi->iface_entry->tag_stat_list_lock); + spin_unlock_bh(&iface_stat_list_lock); +} + +/* + * Procfs reader to get all tag stats using style "1)" as described in + * fs/proc/generic.c + * Groups all protocols tx/rx bytes. + */ +static int qtaguid_stats_proc_show(struct seq_file *m, void *v) +{ + struct tag_stat *ts_entry = v; + + if (v == SEQ_START_TOKEN) + pp_stats_header(m); + else + pp_sets(m, ts_entry); + + return 0; +} + +/*------------------------------------------*/ +static int qtudev_open(struct inode *inode, struct file *file) +{ + struct uid_tag_data *utd_entry; + struct proc_qtu_data *pqd_entry; + struct proc_qtu_data *new_pqd_entry; + int res; + bool utd_entry_found; + + if (unlikely(qtu_proc_handling_passive)) + return 0; + + DR_DEBUG("qtaguid: qtudev_open(): pid=%u tgid=%u uid=%u\n", + current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + + spin_lock_bh(&uid_tag_data_tree_lock); + + /* Look for existing uid data, or alloc one. */ + utd_entry = get_uid_data(from_kuid(&init_user_ns, current_fsuid()), &utd_entry_found); + if (IS_ERR_OR_NULL(utd_entry)) { + res = PTR_ERR(utd_entry); + goto err_unlock; + } + + /* Look for existing PID based proc_data */ + pqd_entry = proc_qtu_data_tree_search(&proc_qtu_data_tree, + current->tgid); + if (pqd_entry) { + pr_err("qtaguid: qtudev_open(): %u/%u %u " + "%s already opened\n", + current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid()), + QTU_DEV_NAME); + res = -EBUSY; + goto err_unlock_free_utd; + } + + new_pqd_entry = kzalloc(sizeof(*new_pqd_entry), GFP_ATOMIC); + if (!new_pqd_entry) { + pr_err("qtaguid: qtudev_open(): %u/%u %u: " + "proc data alloc failed\n", + current->pid, current->tgid, from_kuid(&init_user_ns, current_fsuid())); + res = -ENOMEM; + goto err_unlock_free_utd; + } + new_pqd_entry->pid = current->tgid; + INIT_LIST_HEAD(&new_pqd_entry->sock_tag_list); + new_pqd_entry->parent_tag_data = utd_entry; + utd_entry->num_pqd++; + + proc_qtu_data_tree_insert(new_pqd_entry, + &proc_qtu_data_tree); + + spin_unlock_bh(&uid_tag_data_tree_lock); + DR_DEBUG("qtaguid: tracking data for uid=%u in pqd=%p\n", + from_kuid(&init_user_ns, current_fsuid()), new_pqd_entry); + file->private_data = new_pqd_entry; + return 0; + +err_unlock_free_utd: + if (!utd_entry_found) { + rb_erase(&utd_entry->node, &uid_tag_data_tree); + kfree(utd_entry); + } +err_unlock: + spin_unlock_bh(&uid_tag_data_tree_lock); + return res; +} + +static int qtudev_release(struct inode *inode, struct file *file) +{ + struct proc_qtu_data *pqd_entry = file->private_data; + struct uid_tag_data *utd_entry = pqd_entry->parent_tag_data; + struct sock_tag *st_entry; + struct rb_root st_to_free_tree = RB_ROOT; + struct list_head *entry, *next; + struct tag_ref *tr; + + if (unlikely(qtu_proc_handling_passive)) + return 0; + + /* + * Do not trust the current->pid, it might just be a kworker cleaning + * up after a dead proc. + */ + DR_DEBUG("qtaguid: qtudev_release(): " + "pid=%u tgid=%u uid=%u " + "pqd_entry=%p->pid=%u utd_entry=%p->active_tags=%d\n", + current->pid, current->tgid, pqd_entry->parent_tag_data->uid, + pqd_entry, pqd_entry->pid, utd_entry, + utd_entry->num_active_tags); + + spin_lock_bh(&sock_tag_list_lock); + spin_lock_bh(&uid_tag_data_tree_lock); + + list_for_each_safe(entry, next, &pqd_entry->sock_tag_list) { + st_entry = list_entry(entry, struct sock_tag, list); + DR_DEBUG("qtaguid: %s(): " + "erase sock_tag=%p->sk=%p pid=%u tgid=%u uid=%u\n", + __func__, + st_entry, st_entry->sk, + current->pid, current->tgid, + pqd_entry->parent_tag_data->uid); + + utd_entry = uid_tag_data_tree_search( + &uid_tag_data_tree, + get_uid_from_tag(st_entry->tag)); + BUG_ON(IS_ERR_OR_NULL(utd_entry)); + DR_DEBUG("qtaguid: %s(): " + "looking for tag=0x%llx in utd_entry=%p\n", __func__, + st_entry->tag, utd_entry); + tr = tag_ref_tree_search(&utd_entry->tag_ref_tree, + st_entry->tag); + BUG_ON(!tr); + BUG_ON(tr->num_sock_tags <= 0); + tr->num_sock_tags--; + free_tag_ref_from_utd_entry(tr, utd_entry); + + rb_erase(&st_entry->sock_node, &sock_tag_tree); + list_del(&st_entry->list); + /* Can't sockfd_put() within spinlock, do it later. */ + sock_tag_tree_insert(st_entry, &st_to_free_tree); + + /* + * Try to free the utd_entry if no other proc_qtu_data is + * using it (num_pqd is 0) and it doesn't have active tags + * (num_active_tags is 0). + */ + put_utd_entry(utd_entry); + } + + rb_erase(&pqd_entry->node, &proc_qtu_data_tree); + BUG_ON(pqd_entry->parent_tag_data->num_pqd < 1); + pqd_entry->parent_tag_data->num_pqd--; + put_utd_entry(pqd_entry->parent_tag_data); + kfree(pqd_entry); + file->private_data = NULL; + + spin_unlock_bh(&uid_tag_data_tree_lock); + spin_unlock_bh(&sock_tag_list_lock); + + + sock_tag_tree_erase(&st_to_free_tree); + + spin_lock_bh(&sock_tag_list_lock); + prdebug_full_state_locked(0, "%s(): pid=%u tgid=%u", __func__, + current->pid, current->tgid); + spin_unlock_bh(&sock_tag_list_lock); + return 0; +} + +/*------------------------------------------*/ +static const struct file_operations qtudev_fops = { + .owner = THIS_MODULE, + .open = qtudev_open, + .release = qtudev_release, +}; + +static struct miscdevice qtu_device = { + .minor = MISC_DYNAMIC_MINOR, + .name = QTU_DEV_NAME, + .fops = &qtudev_fops, + /* How sad it doesn't allow for defaults: .mode = S_IRUGO | S_IWUSR */ +}; + +static const struct seq_operations proc_qtaguid_ctrl_seqops = { + .start = qtaguid_ctrl_proc_start, + .next = qtaguid_ctrl_proc_next, + .stop = qtaguid_ctrl_proc_stop, + .show = qtaguid_ctrl_proc_show, +}; + +static int proc_qtaguid_ctrl_open(struct inode *inode, struct file *file) +{ + return seq_open_private(file, &proc_qtaguid_ctrl_seqops, + sizeof(struct proc_ctrl_print_info)); +} + +static const struct file_operations proc_qtaguid_ctrl_fops = { + .open = proc_qtaguid_ctrl_open, + .read = seq_read, + .write = qtaguid_ctrl_proc_write, + .llseek = seq_lseek, + .release = seq_release_private, +}; + +static const struct seq_operations proc_qtaguid_stats_seqops = { + .start = qtaguid_stats_proc_start, + .next = qtaguid_stats_proc_next, + .stop = qtaguid_stats_proc_stop, + .show = qtaguid_stats_proc_show, +}; + +static int proc_qtaguid_stats_open(struct inode *inode, struct file *file) +{ + return seq_open_private(file, &proc_qtaguid_stats_seqops, + sizeof(struct proc_print_info)); +} + +static const struct file_operations proc_qtaguid_stats_fops = { + .open = proc_qtaguid_stats_open, + .read = seq_read, + .llseek = seq_lseek, + .release = seq_release_private, +}; + +/*------------------------------------------*/ +static int __init qtaguid_proc_register(struct proc_dir_entry **res_procdir) +{ + int ret; + *res_procdir = proc_mkdir(module_procdirname, init_net.proc_net); + if (!*res_procdir) { + pr_err("qtaguid: failed to create proc/.../xt_qtaguid\n"); + ret = -ENOMEM; + goto no_dir; + } + + xt_qtaguid_ctrl_file = proc_create_data("ctrl", proc_ctrl_perms, + *res_procdir, + &proc_qtaguid_ctrl_fops, + NULL); + if (!xt_qtaguid_ctrl_file) { + pr_err("qtaguid: failed to create xt_qtaguid/ctrl " + " file\n"); + ret = -ENOMEM; + goto no_ctrl_entry; + } + + xt_qtaguid_stats_file = proc_create_data("stats", proc_stats_perms, + *res_procdir, + &proc_qtaguid_stats_fops, + NULL); + if (!xt_qtaguid_stats_file) { + pr_err("qtaguid: failed to create xt_qtaguid/stats " + "file\n"); + ret = -ENOMEM; + goto no_stats_entry; + } + /* + * TODO: add support counter hacking + * xt_qtaguid_stats_file->write_proc = qtaguid_stats_proc_write; + */ + return 0; + +no_stats_entry: + remove_proc_entry("ctrl", *res_procdir); +no_ctrl_entry: + remove_proc_entry("xt_qtaguid", NULL); +no_dir: + return ret; +} + +static struct xt_match qtaguid_mt_reg __read_mostly = { + /* + * This module masquerades as the "owner" module so that iptables + * tools can deal with it. + */ + .name = "owner", + .revision = 1, + .family = NFPROTO_UNSPEC, + .match = qtaguid_mt, + .matchsize = sizeof(struct xt_qtaguid_match_info), + .me = THIS_MODULE, +}; + +static int __init qtaguid_mt_init(void) +{ + if (qtaguid_proc_register(&xt_qtaguid_procdir) + || iface_stat_init(xt_qtaguid_procdir) + || xt_register_match(&qtaguid_mt_reg) + || misc_register(&qtu_device)) + return -1; + return 0; +} + +/* + * TODO: allow unloading of the module. + * For now stats are permanent. + * Kconfig forces'y/n' and never an 'm'. + */ + +module_init(qtaguid_mt_init); +MODULE_AUTHOR("jpa <jpa@google.com>"); +MODULE_DESCRIPTION("Xtables: socket owner+tag matching and associated stats"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_owner"); +MODULE_ALIAS("ip6t_owner"); +MODULE_ALIAS("ipt_qtaguid"); +MODULE_ALIAS("ip6t_qtaguid"); diff --git a/net/netfilter/xt_qtaguid_internal.h b/net/netfilter/xt_qtaguid_internal.h new file mode 100644 index 000000000000..c7052707a6a4 --- /dev/null +++ b/net/netfilter/xt_qtaguid_internal.h @@ -0,0 +1,350 @@ +/* + * Kernel iptables module to track stats for packets based on user tags. + * + * (C) 2011 Google, Inc + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + */ +#ifndef __XT_QTAGUID_INTERNAL_H__ +#define __XT_QTAGUID_INTERNAL_H__ + +#include <linux/types.h> +#include <linux/rbtree.h> +#include <linux/spinlock_types.h> +#include <linux/workqueue.h> + +/* Iface handling */ +#define IDEBUG_MASK (1<<0) +/* Iptable Matching. Per packet. */ +#define MDEBUG_MASK (1<<1) +/* Red-black tree handling. Per packet. */ +#define RDEBUG_MASK (1<<2) +/* procfs ctrl/stats handling */ +#define CDEBUG_MASK (1<<3) +/* dev and resource tracking */ +#define DDEBUG_MASK (1<<4) + +/* E.g (IDEBUG_MASK | CDEBUG_MASK | DDEBUG_MASK) */ +#define DEFAULT_DEBUG_MASK 0 + +/* + * (Un)Define these *DEBUG to compile out/in the pr_debug calls. + * All undef: text size ~ 0x3030; all def: ~ 0x4404. + */ +#define IDEBUG +#define MDEBUG +#define RDEBUG +#define CDEBUG +#define DDEBUG + +#define MSK_DEBUG(mask, ...) do { \ + if (unlikely(qtaguid_debug_mask & (mask))) \ + pr_debug(__VA_ARGS__); \ + } while (0) +#ifdef IDEBUG +#define IF_DEBUG(...) MSK_DEBUG(IDEBUG_MASK, __VA_ARGS__) +#else +#define IF_DEBUG(...) no_printk(__VA_ARGS__) +#endif +#ifdef MDEBUG +#define MT_DEBUG(...) MSK_DEBUG(MDEBUG_MASK, __VA_ARGS__) +#else +#define MT_DEBUG(...) no_printk(__VA_ARGS__) +#endif +#ifdef RDEBUG +#define RB_DEBUG(...) MSK_DEBUG(RDEBUG_MASK, __VA_ARGS__) +#else +#define RB_DEBUG(...) no_printk(__VA_ARGS__) +#endif +#ifdef CDEBUG +#define CT_DEBUG(...) MSK_DEBUG(CDEBUG_MASK, __VA_ARGS__) +#else +#define CT_DEBUG(...) no_printk(__VA_ARGS__) +#endif +#ifdef DDEBUG +#define DR_DEBUG(...) MSK_DEBUG(DDEBUG_MASK, __VA_ARGS__) +#else +#define DR_DEBUG(...) no_printk(__VA_ARGS__) +#endif + +extern uint qtaguid_debug_mask; + +/*---------------------------------------------------------------------------*/ +/* + * Tags: + * + * They represent what the data usage counters will be tracked against. + * By default a tag is just based on the UID. + * The UID is used as the base for policing, and can not be ignored. + * So a tag will always at least represent a UID (uid_tag). + * + * A tag can be augmented with an "accounting tag" which is associated + * with a UID. + * User space can set the acct_tag portion of the tag which is then used + * with sockets: all data belonging to that socket will be counted against the + * tag. The policing is then based on the tag's uid_tag portion, + * and stats are collected for the acct_tag portion separately. + * + * There could be + * a: {acct_tag=1, uid_tag=10003} + * b: {acct_tag=2, uid_tag=10003} + * c: {acct_tag=3, uid_tag=10003} + * d: {acct_tag=0, uid_tag=10003} + * a, b, and c represent tags associated with specific sockets. + * d is for the totals for that uid, including all untagged traffic. + * Typically d is used with policing/quota rules. + * + * We want tag_t big enough to distinguish uid_t and acct_tag. + * It might become a struct if needed. + * Nothing should be using it as an int. + */ +typedef uint64_t tag_t; /* Only used via accessors */ + +#define TAG_UID_MASK 0xFFFFFFFFULL +#define TAG_ACCT_MASK (~0xFFFFFFFFULL) + +static inline int tag_compare(tag_t t1, tag_t t2) +{ + return t1 < t2 ? -1 : t1 == t2 ? 0 : 1; +} + +static inline tag_t combine_atag_with_uid(tag_t acct_tag, uid_t uid) +{ + return acct_tag | uid; +} +static inline tag_t make_tag_from_uid(uid_t uid) +{ + return uid; +} +static inline uid_t get_uid_from_tag(tag_t tag) +{ + return tag & TAG_UID_MASK; +} +static inline tag_t get_utag_from_tag(tag_t tag) +{ + return tag & TAG_UID_MASK; +} +static inline tag_t get_atag_from_tag(tag_t tag) +{ + return tag & TAG_ACCT_MASK; +} + +static inline bool valid_atag(tag_t tag) +{ + return !(tag & TAG_UID_MASK); +} +static inline tag_t make_atag_from_value(uint32_t value) +{ + return (uint64_t)value << 32; +} +/*---------------------------------------------------------------------------*/ + +/* + * Maximum number of socket tags that a UID is allowed to have active. + * Multiple processes belonging to the same UID contribute towards this limit. + * Special UIDs that can impersonate a UID also contribute (e.g. download + * manager, ...) + */ +#define DEFAULT_MAX_SOCK_TAGS 1024 + +/* + * For now we only track 2 sets of counters. + * The default set is 0. + * Userspace can activate another set for a given uid being tracked. + */ +#define IFS_MAX_COUNTER_SETS 2 + +enum ifs_tx_rx { + IFS_TX, + IFS_RX, + IFS_MAX_DIRECTIONS +}; + +/* For now, TCP, UDP, the rest */ +enum ifs_proto { + IFS_TCP, + IFS_UDP, + IFS_PROTO_OTHER, + IFS_MAX_PROTOS +}; + +struct byte_packet_counters { + uint64_t bytes; + uint64_t packets; +}; + +struct data_counters { + struct byte_packet_counters bpc[IFS_MAX_COUNTER_SETS][IFS_MAX_DIRECTIONS][IFS_MAX_PROTOS]; +}; + +static inline uint64_t dc_sum_bytes(struct data_counters *counters, + int set, + enum ifs_tx_rx direction) +{ + return counters->bpc[set][direction][IFS_TCP].bytes + + counters->bpc[set][direction][IFS_UDP].bytes + + counters->bpc[set][direction][IFS_PROTO_OTHER].bytes; +} + +static inline uint64_t dc_sum_packets(struct data_counters *counters, + int set, + enum ifs_tx_rx direction) +{ + return counters->bpc[set][direction][IFS_TCP].packets + + counters->bpc[set][direction][IFS_UDP].packets + + counters->bpc[set][direction][IFS_PROTO_OTHER].packets; +} + + +/* Generic X based nodes used as a base for rb_tree ops */ +struct tag_node { + struct rb_node node; + tag_t tag; +}; + +struct tag_stat { + struct tag_node tn; + struct data_counters counters; + /* + * If this tag is acct_tag based, we need to count against the + * matching parent uid_tag. + */ + struct data_counters *parent_counters; +}; + +struct iface_stat { + struct list_head list; /* in iface_stat_list */ + char *ifname; + bool active; + /* net_dev is only valid for active iface_stat */ + struct net_device *net_dev; + + struct byte_packet_counters totals_via_dev[IFS_MAX_DIRECTIONS]; + struct data_counters totals_via_skb; + /* + * We keep the last_known, because some devices reset their counters + * just before NETDEV_UP, while some will reset just before + * NETDEV_REGISTER (which is more normal). + * So now, if the device didn't do a NETDEV_UNREGISTER and we see + * its current dev stats smaller that what was previously known, we + * assume an UNREGISTER and just use the last_known. + */ + struct byte_packet_counters last_known[IFS_MAX_DIRECTIONS]; + /* last_known is usable when last_known_valid is true */ + bool last_known_valid; + + struct proc_dir_entry *proc_ptr; + + struct rb_root tag_stat_tree; + spinlock_t tag_stat_list_lock; +}; + +/* This is needed to create proc_dir_entries from atomic context. */ +struct iface_stat_work { + struct work_struct iface_work; + struct iface_stat *iface_entry; +}; + +/* + * Track tag that this socket is transferring data for, and not necessarily + * the uid that owns the socket. + * This is the tag against which tag_stat.counters will be billed. + * These structs need to be looked up by sock and pid. + */ +struct sock_tag { + struct rb_node sock_node; + struct sock *sk; /* Only used as a number, never dereferenced */ + /* Used to associate with a given pid */ + struct list_head list; /* in proc_qtu_data.sock_tag_list */ + pid_t pid; + + tag_t tag; +}; + +struct qtaguid_event_counts { + /* Various successful events */ + atomic64_t sockets_tagged; + atomic64_t sockets_untagged; + atomic64_t counter_set_changes; + atomic64_t delete_cmds; + atomic64_t iface_events; /* Number of NETDEV_* events handled */ + + atomic64_t match_calls; /* Number of times iptables called mt */ + /* Number of times iptables called mt from pre or post routing hooks */ + atomic64_t match_calls_prepost; + /* + * match_found_sk_*: numbers related to the netfilter matching + * function finding a sock for the sk_buff. + * Total skbs processed is sum(match_found*). + */ + atomic64_t match_found_sk; /* An sk was already in the sk_buff. */ + /* The connection tracker had or didn't have the sk. */ + atomic64_t match_found_sk_in_ct; + atomic64_t match_found_no_sk_in_ct; + /* + * No sk could be found. No apparent owner. Could happen with + * unsolicited traffic. + */ + atomic64_t match_no_sk; + /* + * The file ptr in the sk_socket wasn't there and we couldn't get GID. + * This might happen for traffic while the socket is being closed. + */ + atomic64_t match_no_sk_gid; +}; + +/* Track the set active_set for the given tag. */ +struct tag_counter_set { + struct tag_node tn; + int active_set; +}; + +/*----------------------------------------------*/ +/* + * The qtu uid data is used to track resources that are created directly or + * indirectly by processes (uid tracked). + * It is shared by the processes with the same uid. + * Some of the resource will be counted to prevent further rogue allocations, + * some will need freeing once the owner process (uid) exits. + */ +struct uid_tag_data { + struct rb_node node; + uid_t uid; + + /* + * For the uid, how many accounting tags have been set. + */ + int num_active_tags; + /* Track the number of proc_qtu_data that reference it */ + int num_pqd; + struct rb_root tag_ref_tree; + /* No tag_node_tree_lock; use uid_tag_data_tree_lock */ +}; + +struct tag_ref { + struct tag_node tn; + + /* + * This tracks the number of active sockets that have a tag on them + * which matches this tag_ref.tn.tag. + * A tag ref can live on after the sockets are untagged. + * A tag ref can only be removed during a tag delete command. + */ + int num_sock_tags; +}; + +struct proc_qtu_data { + struct rb_node node; + pid_t pid; + + struct uid_tag_data *parent_tag_data; + + /* Tracks the sock_tags that need freeing upon this proc's death */ + struct list_head sock_tag_list; + /* No spinlock_t sock_tag_list_lock; use the global one. */ +}; + +/*----------------------------------------------*/ +#endif /* ifndef __XT_QTAGUID_INTERNAL_H__ */ diff --git a/net/netfilter/xt_qtaguid_print.c b/net/netfilter/xt_qtaguid_print.c new file mode 100644 index 000000000000..2a7190d285e6 --- /dev/null +++ b/net/netfilter/xt_qtaguid_print.c @@ -0,0 +1,566 @@ +/* + * Pretty printing Support for iptables xt_qtaguid module. + * + * (C) 2011 Google, Inc + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + */ + +/* + * Most of the functions in this file just waste time if DEBUG is not defined. + * The matching xt_qtaguid_print.h will static inline empty funcs if the needed + * debug flags ore not defined. + * Those funcs that fail to allocate memory will panic as there is no need to + * hobble allong just pretending to do the requested work. + */ + +#define DEBUG + +#include <linux/fs.h> +#include <linux/gfp.h> +#include <linux/net.h> +#include <linux/rbtree.h> +#include <linux/slab.h> +#include <linux/spinlock_types.h> +#include <net/sock.h> + +#include "xt_qtaguid_internal.h" +#include "xt_qtaguid_print.h" + +#ifdef DDEBUG + +static void _bug_on_err_or_null(void *ptr) +{ + if (IS_ERR_OR_NULL(ptr)) { + pr_err("qtaguid: kmalloc failed\n"); + BUG(); + } +} + +char *pp_tag_t(tag_t *tag) +{ + char *res; + + if (!tag) + res = kasprintf(GFP_ATOMIC, "tag_t@null{}"); + else + res = kasprintf(GFP_ATOMIC, + "tag_t@%p{tag=0x%llx, uid=%u}", + tag, *tag, get_uid_from_tag(*tag)); + _bug_on_err_or_null(res); + return res; +} + +char *pp_data_counters(struct data_counters *dc, bool showValues) +{ + char *res; + + if (!dc) + res = kasprintf(GFP_ATOMIC, "data_counters@null{}"); + else if (showValues) + res = kasprintf( + GFP_ATOMIC, "data_counters@%p{" + "set0{" + "rx{" + "tcp{b=%llu, p=%llu}, " + "udp{b=%llu, p=%llu}," + "other{b=%llu, p=%llu}}, " + "tx{" + "tcp{b=%llu, p=%llu}, " + "udp{b=%llu, p=%llu}," + "other{b=%llu, p=%llu}}}, " + "set1{" + "rx{" + "tcp{b=%llu, p=%llu}, " + "udp{b=%llu, p=%llu}," + "other{b=%llu, p=%llu}}, " + "tx{" + "tcp{b=%llu, p=%llu}, " + "udp{b=%llu, p=%llu}," + "other{b=%llu, p=%llu}}}}", + dc, + dc->bpc[0][IFS_RX][IFS_TCP].bytes, + dc->bpc[0][IFS_RX][IFS_TCP].packets, + dc->bpc[0][IFS_RX][IFS_UDP].bytes, + dc->bpc[0][IFS_RX][IFS_UDP].packets, + dc->bpc[0][IFS_RX][IFS_PROTO_OTHER].bytes, + dc->bpc[0][IFS_RX][IFS_PROTO_OTHER].packets, + dc->bpc[0][IFS_TX][IFS_TCP].bytes, + dc->bpc[0][IFS_TX][IFS_TCP].packets, + dc->bpc[0][IFS_TX][IFS_UDP].bytes, + dc->bpc[0][IFS_TX][IFS_UDP].packets, + dc->bpc[0][IFS_TX][IFS_PROTO_OTHER].bytes, + dc->bpc[0][IFS_TX][IFS_PROTO_OTHER].packets, + dc->bpc[1][IFS_RX][IFS_TCP].bytes, + dc->bpc[1][IFS_RX][IFS_TCP].packets, + dc->bpc[1][IFS_RX][IFS_UDP].bytes, + dc->bpc[1][IFS_RX][IFS_UDP].packets, + dc->bpc[1][IFS_RX][IFS_PROTO_OTHER].bytes, + dc->bpc[1][IFS_RX][IFS_PROTO_OTHER].packets, + dc->bpc[1][IFS_TX][IFS_TCP].bytes, + dc->bpc[1][IFS_TX][IFS_TCP].packets, + dc->bpc[1][IFS_TX][IFS_UDP].bytes, + dc->bpc[1][IFS_TX][IFS_UDP].packets, + dc->bpc[1][IFS_TX][IFS_PROTO_OTHER].bytes, + dc->bpc[1][IFS_TX][IFS_PROTO_OTHER].packets); + else + res = kasprintf(GFP_ATOMIC, "data_counters@%p{...}", dc); + _bug_on_err_or_null(res); + return res; +} + +char *pp_tag_node(struct tag_node *tn) +{ + char *tag_str; + char *res; + + if (!tn) { + res = kasprintf(GFP_ATOMIC, "tag_node@null{}"); + _bug_on_err_or_null(res); + return res; + } + tag_str = pp_tag_t(&tn->tag); + res = kasprintf(GFP_ATOMIC, + "tag_node@%p{tag=%s}", + tn, tag_str); + _bug_on_err_or_null(res); + kfree(tag_str); + return res; +} + +char *pp_tag_ref(struct tag_ref *tr) +{ + char *tn_str; + char *res; + + if (!tr) { + res = kasprintf(GFP_ATOMIC, "tag_ref@null{}"); + _bug_on_err_or_null(res); + return res; + } + tn_str = pp_tag_node(&tr->tn); + res = kasprintf(GFP_ATOMIC, + "tag_ref@%p{%s, num_sock_tags=%d}", + tr, tn_str, tr->num_sock_tags); + _bug_on_err_or_null(res); + kfree(tn_str); + return res; +} + +char *pp_tag_stat(struct tag_stat *ts) +{ + char *tn_str; + char *counters_str; + char *parent_counters_str; + char *res; + + if (!ts) { + res = kasprintf(GFP_ATOMIC, "tag_stat@null{}"); + _bug_on_err_or_null(res); + return res; + } + tn_str = pp_tag_node(&ts->tn); + counters_str = pp_data_counters(&ts->counters, true); + parent_counters_str = pp_data_counters(ts->parent_counters, false); + res = kasprintf(GFP_ATOMIC, + "tag_stat@%p{%s, counters=%s, parent_counters=%s}", + ts, tn_str, counters_str, parent_counters_str); + _bug_on_err_or_null(res); + kfree(tn_str); + kfree(counters_str); + kfree(parent_counters_str); + return res; +} + +char *pp_iface_stat(struct iface_stat *is) +{ + char *res; + if (!is) { + res = kasprintf(GFP_ATOMIC, "iface_stat@null{}"); + } else { + struct data_counters *cnts = &is->totals_via_skb; + res = kasprintf(GFP_ATOMIC, "iface_stat@%p{" + "list=list_head{...}, " + "ifname=%s, " + "total_dev={rx={bytes=%llu, " + "packets=%llu}, " + "tx={bytes=%llu, " + "packets=%llu}}, " + "total_skb={rx={bytes=%llu, " + "packets=%llu}, " + "tx={bytes=%llu, " + "packets=%llu}}, " + "last_known_valid=%d, " + "last_known={rx={bytes=%llu, " + "packets=%llu}, " + "tx={bytes=%llu, " + "packets=%llu}}, " + "active=%d, " + "net_dev=%p, " + "proc_ptr=%p, " + "tag_stat_tree=rb_root{...}}", + is, + is->ifname, + is->totals_via_dev[IFS_RX].bytes, + is->totals_via_dev[IFS_RX].packets, + is->totals_via_dev[IFS_TX].bytes, + is->totals_via_dev[IFS_TX].packets, + dc_sum_bytes(cnts, 0, IFS_RX), + dc_sum_packets(cnts, 0, IFS_RX), + dc_sum_bytes(cnts, 0, IFS_TX), + dc_sum_packets(cnts, 0, IFS_TX), + is->last_known_valid, + is->last_known[IFS_RX].bytes, + is->last_known[IFS_RX].packets, + is->last_known[IFS_TX].bytes, + is->last_known[IFS_TX].packets, + is->active, + is->net_dev, + is->proc_ptr); + } + _bug_on_err_or_null(res); + return res; +} + +char *pp_sock_tag(struct sock_tag *st) +{ + char *tag_str; + char *res; + + if (!st) { + res = kasprintf(GFP_ATOMIC, "sock_tag@null{}"); + _bug_on_err_or_null(res); + return res; + } + tag_str = pp_tag_t(&st->tag); + res = kasprintf(GFP_ATOMIC, "sock_tag@%p{" + "sock_node=rb_node{...}, " + "sk=%p (f_count=%d), list=list_head{...}, " + "pid=%u, tag=%s}", + st, st->sk, atomic_read( + &st->sk->sk_refcnt), + st->pid, tag_str); + _bug_on_err_or_null(res); + kfree(tag_str); + return res; +} + +char *pp_uid_tag_data(struct uid_tag_data *utd) +{ + char *res; + + if (!utd) + res = kasprintf(GFP_ATOMIC, "uid_tag_data@null{}"); + else + res = kasprintf(GFP_ATOMIC, "uid_tag_data@%p{" + "uid=%u, num_active_acct_tags=%d, " + "num_pqd=%d, " + "tag_node_tree=rb_root{...}, " + "proc_qtu_data_tree=rb_root{...}}", + utd, utd->uid, + utd->num_active_tags, utd->num_pqd); + _bug_on_err_or_null(res); + return res; +} + +char *pp_proc_qtu_data(struct proc_qtu_data *pqd) +{ + char *parent_tag_data_str; + char *res; + + if (!pqd) { + res = kasprintf(GFP_ATOMIC, "proc_qtu_data@null{}"); + _bug_on_err_or_null(res); + return res; + } + parent_tag_data_str = pp_uid_tag_data(pqd->parent_tag_data); + res = kasprintf(GFP_ATOMIC, "proc_qtu_data@%p{" + "node=rb_node{...}, pid=%u, " + "parent_tag_data=%s, " + "sock_tag_list=list_head{...}}", + pqd, pqd->pid, parent_tag_data_str + ); + _bug_on_err_or_null(res); + kfree(parent_tag_data_str); + return res; +} + +/*------------------------------------------*/ +void prdebug_sock_tag_tree(int indent_level, + struct rb_root *sock_tag_tree) +{ + struct rb_node *node; + struct sock_tag *sock_tag_entry; + char *str; + + if (!unlikely(qtaguid_debug_mask & DDEBUG_MASK)) + return; + + if (RB_EMPTY_ROOT(sock_tag_tree)) { + str = "sock_tag_tree=rb_root{}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + return; + } + + str = "sock_tag_tree=rb_root{"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + indent_level++; + for (node = rb_first(sock_tag_tree); + node; + node = rb_next(node)) { + sock_tag_entry = rb_entry(node, struct sock_tag, sock_node); + str = pp_sock_tag(sock_tag_entry); + pr_debug("%*d: %s,\n", indent_level*2, indent_level, str); + kfree(str); + } + indent_level--; + str = "}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); +} + +void prdebug_sock_tag_list(int indent_level, + struct list_head *sock_tag_list) +{ + struct sock_tag *sock_tag_entry; + char *str; + + if (!unlikely(qtaguid_debug_mask & DDEBUG_MASK)) + return; + + if (list_empty(sock_tag_list)) { + str = "sock_tag_list=list_head{}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + return; + } + + str = "sock_tag_list=list_head{"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + indent_level++; + list_for_each_entry(sock_tag_entry, sock_tag_list, list) { + str = pp_sock_tag(sock_tag_entry); + pr_debug("%*d: %s,\n", indent_level*2, indent_level, str); + kfree(str); + } + indent_level--; + str = "}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); +} + +void prdebug_proc_qtu_data_tree(int indent_level, + struct rb_root *proc_qtu_data_tree) +{ + char *str; + struct rb_node *node; + struct proc_qtu_data *proc_qtu_data_entry; + + if (!unlikely(qtaguid_debug_mask & DDEBUG_MASK)) + return; + + if (RB_EMPTY_ROOT(proc_qtu_data_tree)) { + str = "proc_qtu_data_tree=rb_root{}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + return; + } + + str = "proc_qtu_data_tree=rb_root{"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + indent_level++; + for (node = rb_first(proc_qtu_data_tree); + node; + node = rb_next(node)) { + proc_qtu_data_entry = rb_entry(node, + struct proc_qtu_data, + node); + str = pp_proc_qtu_data(proc_qtu_data_entry); + pr_debug("%*d: %s,\n", indent_level*2, indent_level, + str); + kfree(str); + indent_level++; + prdebug_sock_tag_list(indent_level, + &proc_qtu_data_entry->sock_tag_list); + indent_level--; + + } + indent_level--; + str = "}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); +} + +void prdebug_tag_ref_tree(int indent_level, struct rb_root *tag_ref_tree) +{ + char *str; + struct rb_node *node; + struct tag_ref *tag_ref_entry; + + if (!unlikely(qtaguid_debug_mask & DDEBUG_MASK)) + return; + + if (RB_EMPTY_ROOT(tag_ref_tree)) { + str = "tag_ref_tree{}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + return; + } + + str = "tag_ref_tree{"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + indent_level++; + for (node = rb_first(tag_ref_tree); + node; + node = rb_next(node)) { + tag_ref_entry = rb_entry(node, + struct tag_ref, + tn.node); + str = pp_tag_ref(tag_ref_entry); + pr_debug("%*d: %s,\n", indent_level*2, indent_level, + str); + kfree(str); + } + indent_level--; + str = "}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); +} + +void prdebug_uid_tag_data_tree(int indent_level, + struct rb_root *uid_tag_data_tree) +{ + char *str; + struct rb_node *node; + struct uid_tag_data *uid_tag_data_entry; + + if (!unlikely(qtaguid_debug_mask & DDEBUG_MASK)) + return; + + if (RB_EMPTY_ROOT(uid_tag_data_tree)) { + str = "uid_tag_data_tree=rb_root{}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + return; + } + + str = "uid_tag_data_tree=rb_root{"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + indent_level++; + for (node = rb_first(uid_tag_data_tree); + node; + node = rb_next(node)) { + uid_tag_data_entry = rb_entry(node, struct uid_tag_data, + node); + str = pp_uid_tag_data(uid_tag_data_entry); + pr_debug("%*d: %s,\n", indent_level*2, indent_level, str); + kfree(str); + if (!RB_EMPTY_ROOT(&uid_tag_data_entry->tag_ref_tree)) { + indent_level++; + prdebug_tag_ref_tree(indent_level, + &uid_tag_data_entry->tag_ref_tree); + indent_level--; + } + } + indent_level--; + str = "}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); +} + +void prdebug_tag_stat_tree(int indent_level, + struct rb_root *tag_stat_tree) +{ + char *str; + struct rb_node *node; + struct tag_stat *ts_entry; + + if (!unlikely(qtaguid_debug_mask & DDEBUG_MASK)) + return; + + if (RB_EMPTY_ROOT(tag_stat_tree)) { + str = "tag_stat_tree{}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + return; + } + + str = "tag_stat_tree{"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + indent_level++; + for (node = rb_first(tag_stat_tree); + node; + node = rb_next(node)) { + ts_entry = rb_entry(node, struct tag_stat, tn.node); + str = pp_tag_stat(ts_entry); + pr_debug("%*d: %s\n", indent_level*2, indent_level, + str); + kfree(str); + } + indent_level--; + str = "}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); +} + +void prdebug_iface_stat_list(int indent_level, + struct list_head *iface_stat_list) +{ + char *str; + struct iface_stat *iface_entry; + + if (!unlikely(qtaguid_debug_mask & DDEBUG_MASK)) + return; + + if (list_empty(iface_stat_list)) { + str = "iface_stat_list=list_head{}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + return; + } + + str = "iface_stat_list=list_head{"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + indent_level++; + list_for_each_entry(iface_entry, iface_stat_list, list) { + str = pp_iface_stat(iface_entry); + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); + kfree(str); + + spin_lock_bh(&iface_entry->tag_stat_list_lock); + if (!RB_EMPTY_ROOT(&iface_entry->tag_stat_tree)) { + indent_level++; + prdebug_tag_stat_tree(indent_level, + &iface_entry->tag_stat_tree); + indent_level--; + } + spin_unlock_bh(&iface_entry->tag_stat_list_lock); + } + indent_level--; + str = "}"; + pr_debug("%*d: %s\n", indent_level*2, indent_level, str); +} + +#endif /* ifdef DDEBUG */ +/*------------------------------------------*/ +static const char * const netdev_event_strings[] = { + "netdev_unknown", + "NETDEV_UP", + "NETDEV_DOWN", + "NETDEV_REBOOT", + "NETDEV_CHANGE", + "NETDEV_REGISTER", + "NETDEV_UNREGISTER", + "NETDEV_CHANGEMTU", + "NETDEV_CHANGEADDR", + "NETDEV_GOING_DOWN", + "NETDEV_CHANGENAME", + "NETDEV_FEAT_CHANGE", + "NETDEV_BONDING_FAILOVER", + "NETDEV_PRE_UP", + "NETDEV_PRE_TYPE_CHANGE", + "NETDEV_POST_TYPE_CHANGE", + "NETDEV_POST_INIT", + "NETDEV_UNREGISTER_BATCH", + "NETDEV_RELEASE", + "NETDEV_NOTIFY_PEERS", + "NETDEV_JOIN", +}; + +const char *netdev_evt_str(int netdev_event) +{ + if (netdev_event < 0 + || netdev_event >= ARRAY_SIZE(netdev_event_strings)) + return "bad event num"; + return netdev_event_strings[netdev_event]; +} diff --git a/net/netfilter/xt_qtaguid_print.h b/net/netfilter/xt_qtaguid_print.h new file mode 100644 index 000000000000..b63871a0be5a --- /dev/null +++ b/net/netfilter/xt_qtaguid_print.h @@ -0,0 +1,120 @@ +/* + * Pretty printing Support for iptables xt_qtaguid module. + * + * (C) 2011 Google, Inc + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License version 2 as + * published by the Free Software Foundation. + */ +#ifndef __XT_QTAGUID_PRINT_H__ +#define __XT_QTAGUID_PRINT_H__ + +#include "xt_qtaguid_internal.h" + +#ifdef DDEBUG + +char *pp_tag_t(tag_t *tag); +char *pp_data_counters(struct data_counters *dc, bool showValues); +char *pp_tag_node(struct tag_node *tn); +char *pp_tag_ref(struct tag_ref *tr); +char *pp_tag_stat(struct tag_stat *ts); +char *pp_iface_stat(struct iface_stat *is); +char *pp_sock_tag(struct sock_tag *st); +char *pp_uid_tag_data(struct uid_tag_data *qtd); +char *pp_proc_qtu_data(struct proc_qtu_data *pqd); + +/*------------------------------------------*/ +void prdebug_sock_tag_list(int indent_level, + struct list_head *sock_tag_list); +void prdebug_sock_tag_tree(int indent_level, + struct rb_root *sock_tag_tree); +void prdebug_proc_qtu_data_tree(int indent_level, + struct rb_root *proc_qtu_data_tree); +void prdebug_tag_ref_tree(int indent_level, struct rb_root *tag_ref_tree); +void prdebug_uid_tag_data_tree(int indent_level, + struct rb_root *uid_tag_data_tree); +void prdebug_tag_stat_tree(int indent_level, + struct rb_root *tag_stat_tree); +void prdebug_iface_stat_list(int indent_level, + struct list_head *iface_stat_list); + +#else + +/*------------------------------------------*/ +static inline char *pp_tag_t(tag_t *tag) +{ + return NULL; +} +static inline char *pp_data_counters(struct data_counters *dc, bool showValues) +{ + return NULL; +} +static inline char *pp_tag_node(struct tag_node *tn) +{ + return NULL; +} +static inline char *pp_tag_ref(struct tag_ref *tr) +{ + return NULL; +} +static inline char *pp_tag_stat(struct tag_stat *ts) +{ + return NULL; +} +static inline char *pp_iface_stat(struct iface_stat *is) +{ + return NULL; +} +static inline char *pp_sock_tag(struct sock_tag *st) +{ + return NULL; +} +static inline char *pp_uid_tag_data(struct uid_tag_data *qtd) +{ + return NULL; +} +static inline char *pp_proc_qtu_data(struct proc_qtu_data *pqd) +{ + return NULL; +} + +/*------------------------------------------*/ +static inline +void prdebug_sock_tag_list(int indent_level, + struct list_head *sock_tag_list) +{ +} +static inline +void prdebug_sock_tag_tree(int indent_level, + struct rb_root *sock_tag_tree) +{ +} +static inline +void prdebug_proc_qtu_data_tree(int indent_level, + struct rb_root *proc_qtu_data_tree) +{ +} +static inline +void prdebug_tag_ref_tree(int indent_level, struct rb_root *tag_ref_tree) +{ +} +static inline +void prdebug_uid_tag_data_tree(int indent_level, + struct rb_root *uid_tag_data_tree) +{ +} +static inline +void prdebug_tag_stat_tree(int indent_level, + struct rb_root *tag_stat_tree) +{ +} +static inline +void prdebug_iface_stat_list(int indent_level, + struct list_head *iface_stat_list) +{ +} +#endif +/*------------------------------------------*/ +const char *netdev_evt_str(int netdev_event); +#endif /* ifndef __XT_QTAGUID_PRINT_H__ */ diff --git a/net/netfilter/xt_quota.c b/net/netfilter/xt_quota.c index 44c8eb4c9d66..10d61a6eed71 100644 --- a/net/netfilter/xt_quota.c +++ b/net/netfilter/xt_quota.c @@ -73,6 +73,7 @@ static struct xt_match quota_mt_reg __read_mostly = { .checkentry = quota_mt_check, .destroy = quota_mt_destroy, .matchsize = sizeof(struct xt_quota_info), + .usersize = offsetof(struct xt_quota_info, master), .me = THIS_MODULE, }; diff --git a/net/netfilter/xt_quota2.c b/net/netfilter/xt_quota2.c new file mode 100644 index 000000000000..b1123626c714 --- /dev/null +++ b/net/netfilter/xt_quota2.c @@ -0,0 +1,397 @@ +/* + * xt_quota2 - enhanced xt_quota that can count upwards and in packets + * as a minimal accounting match. + * by Jan Engelhardt <jengelh@medozas.de>, 2008 + * + * Originally based on xt_quota.c: + * netfilter module to enforce network quotas + * Sam Johnston <samj@samj.net> + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License; either + * version 2 of the License, as published by the Free Software Foundation. + */ +#include <linux/list.h> +#include <linux/module.h> +#include <linux/proc_fs.h> +#include <linux/skbuff.h> +#include <linux/spinlock.h> +#include <asm/atomic.h> +#include <net/netlink.h> + +#include <linux/netfilter/x_tables.h> +#include <linux/netfilter/xt_quota2.h> + +#ifdef CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG +/* For compatibility, these definitions are copied from the + * deprecated header file <linux/netfilter_ipv4/ipt_ULOG.h> */ +#define ULOG_MAC_LEN 80 +#define ULOG_PREFIX_LEN 32 + +/* Format of the ULOG packets passed through netlink */ +typedef struct ulog_packet_msg { + unsigned long mark; + long timestamp_sec; + long timestamp_usec; + unsigned int hook; + char indev_name[IFNAMSIZ]; + char outdev_name[IFNAMSIZ]; + size_t data_len; + char prefix[ULOG_PREFIX_LEN]; + unsigned char mac_len; + unsigned char mac[ULOG_MAC_LEN]; + unsigned char payload[0]; +} ulog_packet_msg_t; +#endif + +/** + * @lock: lock to protect quota writers from each other + */ +struct xt_quota_counter { + u_int64_t quota; + spinlock_t lock; + struct list_head list; + atomic_t ref; + char name[sizeof(((struct xt_quota_mtinfo2 *)NULL)->name)]; + struct proc_dir_entry *procfs_entry; +}; + +#ifdef CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG +/* Harald's favorite number +1 :D From ipt_ULOG.C */ +static int qlog_nl_event = 112; +module_param_named(event_num, qlog_nl_event, uint, S_IRUGO | S_IWUSR); +MODULE_PARM_DESC(event_num, + "Event number for NETLINK_NFLOG message. 0 disables log." + "111 is what ipt_ULOG uses."); +static struct sock *nflognl; +#endif + +static LIST_HEAD(counter_list); +static DEFINE_SPINLOCK(counter_list_lock); + +static struct proc_dir_entry *proc_xt_quota; +static unsigned int quota_list_perms = S_IRUGO | S_IWUSR; +static kuid_t quota_list_uid = KUIDT_INIT(0); +static kgid_t quota_list_gid = KGIDT_INIT(0); +module_param_named(perms, quota_list_perms, uint, S_IRUGO | S_IWUSR); + +#ifdef CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG +static void quota2_log(unsigned int hooknum, + const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const char *prefix) +{ + ulog_packet_msg_t *pm; + struct sk_buff *log_skb; + size_t size; + struct nlmsghdr *nlh; + + if (!qlog_nl_event) + return; + + size = NLMSG_SPACE(sizeof(*pm)); + size = max(size, (size_t)NLMSG_GOODSIZE); + log_skb = alloc_skb(size, GFP_ATOMIC); + if (!log_skb) { + pr_err("xt_quota2: cannot alloc skb for logging\n"); + return; + } + + nlh = nlmsg_put(log_skb, /*pid*/0, /*seq*/0, qlog_nl_event, + sizeof(*pm), 0); + if (!nlh) { + pr_err("xt_quota2: nlmsg_put failed\n"); + kfree_skb(log_skb); + return; + } + pm = nlmsg_data(nlh); + memset(pm, 0, sizeof(*pm)); + if (skb->tstamp.tv64 == 0) + __net_timestamp((struct sk_buff *)skb); + pm->hook = hooknum; + if (prefix != NULL) + strlcpy(pm->prefix, prefix, sizeof(pm->prefix)); + if (in) + strlcpy(pm->indev_name, in->name, sizeof(pm->indev_name)); + if (out) + strlcpy(pm->outdev_name, out->name, sizeof(pm->outdev_name)); + + NETLINK_CB(log_skb).dst_group = 1; + pr_debug("throwing 1 packets to netlink group 1\n"); + netlink_broadcast(nflognl, log_skb, 0, 1, GFP_ATOMIC); +} +#else +static void quota2_log(unsigned int hooknum, + const struct sk_buff *skb, + const struct net_device *in, + const struct net_device *out, + const char *prefix) +{ +} +#endif /* if+else CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG */ + +static ssize_t quota_proc_read(struct file *file, char __user *buf, + size_t size, loff_t *ppos) +{ + struct xt_quota_counter *e = PDE_DATA(file_inode(file)); + char tmp[24]; + size_t tmp_size; + + spin_lock_bh(&e->lock); + tmp_size = scnprintf(tmp, sizeof(tmp), "%llu\n", e->quota); + spin_unlock_bh(&e->lock); + return simple_read_from_buffer(buf, size, ppos, tmp, tmp_size); +} + +static ssize_t quota_proc_write(struct file *file, const char __user *input, + size_t size, loff_t *ppos) +{ + struct xt_quota_counter *e = PDE_DATA(file_inode(file)); + char buf[sizeof("18446744073709551616")]; + + if (size > sizeof(buf)) + size = sizeof(buf); + if (copy_from_user(buf, input, size) != 0) + return -EFAULT; + buf[sizeof(buf)-1] = '\0'; + if (size < sizeof(buf)) + buf[size] = '\0'; + + spin_lock_bh(&e->lock); + e->quota = simple_strtoull(buf, NULL, 0); + spin_unlock_bh(&e->lock); + return size; +} + +static const struct file_operations q2_counter_fops = { + .read = quota_proc_read, + .write = quota_proc_write, + .llseek = default_llseek, +}; + +static struct xt_quota_counter * +q2_new_counter(const struct xt_quota_mtinfo2 *q, bool anon) +{ + struct xt_quota_counter *e; + unsigned int size; + + /* Do not need all the procfs things for anonymous counters. */ + size = anon ? offsetof(typeof(*e), list) : sizeof(*e); + e = kmalloc(size, GFP_KERNEL); + if (e == NULL) + return NULL; + + e->quota = q->quota; + spin_lock_init(&e->lock); + if (!anon) { + INIT_LIST_HEAD(&e->list); + atomic_set(&e->ref, 1); + strlcpy(e->name, q->name, sizeof(e->name)); + } + return e; +} + +/** + * q2_get_counter - get ref to counter or create new + * @name: name of counter + */ +static struct xt_quota_counter * +q2_get_counter(const struct xt_quota_mtinfo2 *q) +{ + struct proc_dir_entry *p; + struct xt_quota_counter *e = NULL; + struct xt_quota_counter *new_e; + + if (*q->name == '\0') + return q2_new_counter(q, true); + + /* No need to hold a lock while getting a new counter */ + new_e = q2_new_counter(q, false); + if (new_e == NULL) + goto out; + + spin_lock_bh(&counter_list_lock); + list_for_each_entry(e, &counter_list, list) + if (strcmp(e->name, q->name) == 0) { + atomic_inc(&e->ref); + spin_unlock_bh(&counter_list_lock); + kfree(new_e); + pr_debug("xt_quota2: old counter name=%s", e->name); + return e; + } + e = new_e; + pr_debug("xt_quota2: new_counter name=%s", e->name); + list_add_tail(&e->list, &counter_list); + /* The entry having a refcount of 1 is not directly destructible. + * This func has not yet returned the new entry, thus iptables + * has not references for destroying this entry. + * For another rule to try to destroy it, it would 1st need for this + * func* to be re-invoked, acquire a new ref for the same named quota. + * Nobody will access the e->procfs_entry either. + * So release the lock. */ + spin_unlock_bh(&counter_list_lock); + + /* create_proc_entry() is not spin_lock happy */ + p = e->procfs_entry = proc_create_data(e->name, quota_list_perms, + proc_xt_quota, &q2_counter_fops, e); + + if (IS_ERR_OR_NULL(p)) { + spin_lock_bh(&counter_list_lock); + list_del(&e->list); + spin_unlock_bh(&counter_list_lock); + goto out; + } + proc_set_user(p, quota_list_uid, quota_list_gid); + return e; + + out: + kfree(e); + return NULL; +} + +static int quota_mt2_check(const struct xt_mtchk_param *par) +{ + struct xt_quota_mtinfo2 *q = par->matchinfo; + + pr_debug("xt_quota2: check() flags=0x%04x", q->flags); + + if (q->flags & ~XT_QUOTA_MASK) + return -EINVAL; + + q->name[sizeof(q->name)-1] = '\0'; + if (*q->name == '.' || strchr(q->name, '/') != NULL) { + printk(KERN_ERR "xt_quota.3: illegal name\n"); + return -EINVAL; + } + + q->master = q2_get_counter(q); + if (q->master == NULL) { + printk(KERN_ERR "xt_quota.3: memory alloc failure\n"); + return -ENOMEM; + } + + return 0; +} + +static void quota_mt2_destroy(const struct xt_mtdtor_param *par) +{ + struct xt_quota_mtinfo2 *q = par->matchinfo; + struct xt_quota_counter *e = q->master; + + if (*q->name == '\0') { + kfree(e); + return; + } + + spin_lock_bh(&counter_list_lock); + if (!atomic_dec_and_test(&e->ref)) { + spin_unlock_bh(&counter_list_lock); + return; + } + + list_del(&e->list); + spin_unlock_bh(&counter_list_lock); + remove_proc_entry(e->name, proc_xt_quota); + kfree(e); +} + +static bool +quota_mt2(const struct sk_buff *skb, struct xt_action_param *par) +{ + struct xt_quota_mtinfo2 *q = (void *)par->matchinfo; + struct xt_quota_counter *e = q->master; + int charge = (q->flags & XT_QUOTA_PACKET) ? 1 : skb->len; + bool no_change = q->flags & XT_QUOTA_NO_CHANGE; + bool ret = q->flags & XT_QUOTA_INVERT; + + spin_lock_bh(&e->lock); + if (q->flags & XT_QUOTA_GROW) { + /* + * While no_change is pointless in "grow" mode, we will + * implement it here simply to have a consistent behavior. + */ + if (!no_change) + e->quota += charge; + ret = true; /* note: does not respect inversion (bug??) */ + } else { + if (e->quota > charge) { + if (!no_change) + e->quota -= charge; + ret = !ret; + } else if (e->quota) { + /* We are transitioning, log that fact. */ + quota2_log(par->hooknum, + skb, + par->in, + par->out, + q->name); + /* we do not allow even small packets from now on */ + e->quota = 0; + } + } + spin_unlock_bh(&e->lock); + return ret; +} + +static struct xt_match quota_mt2_reg[] __read_mostly = { + { + .name = "quota2", + .revision = 3, + .family = NFPROTO_IPV4, + .checkentry = quota_mt2_check, + .match = quota_mt2, + .destroy = quota_mt2_destroy, + .matchsize = sizeof(struct xt_quota_mtinfo2), + .usersize = offsetof(struct xt_quota_mtinfo2, master), + .me = THIS_MODULE, + }, + { + .name = "quota2", + .revision = 3, + .family = NFPROTO_IPV6, + .checkentry = quota_mt2_check, + .match = quota_mt2, + .destroy = quota_mt2_destroy, + .matchsize = sizeof(struct xt_quota_mtinfo2), + .usersize = offsetof(struct xt_quota_mtinfo2, master), + .me = THIS_MODULE, + }, +}; + +static int __init quota_mt2_init(void) +{ + int ret; + pr_debug("xt_quota2: init()"); + +#ifdef CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG + nflognl = netlink_kernel_create(&init_net, NETLINK_NFLOG, NULL); + if (!nflognl) + return -ENOMEM; +#endif + + proc_xt_quota = proc_mkdir("xt_quota", init_net.proc_net); + if (proc_xt_quota == NULL) + return -EACCES; + + ret = xt_register_matches(quota_mt2_reg, ARRAY_SIZE(quota_mt2_reg)); + if (ret < 0) + remove_proc_entry("xt_quota", init_net.proc_net); + pr_debug("xt_quota2: init() %d", ret); + return ret; +} + +static void __exit quota_mt2_exit(void) +{ + xt_unregister_matches(quota_mt2_reg, ARRAY_SIZE(quota_mt2_reg)); + remove_proc_entry("xt_quota", init_net.proc_net); +} + +module_init(quota_mt2_init); +module_exit(quota_mt2_exit); +MODULE_DESCRIPTION("Xtables: countdown quota match; up counter"); +MODULE_AUTHOR("Sam Johnston <samj@samj.net>"); +MODULE_AUTHOR("Jan Engelhardt <jengelh@medozas.de>"); +MODULE_LICENSE("GPL"); +MODULE_ALIAS("ipt_quota2"); +MODULE_ALIAS("ip6t_quota2"); diff --git a/net/netfilter/xt_rateest.c b/net/netfilter/xt_rateest.c index 7720b036d76a..d46dc9ff591f 100644 --- a/net/netfilter/xt_rateest.c +++ b/net/netfilter/xt_rateest.c @@ -135,6 +135,7 @@ static struct xt_match xt_rateest_mt_reg __read_mostly = { .checkentry = xt_rateest_mt_checkentry, .destroy = xt_rateest_mt_destroy, .matchsize = sizeof(struct xt_rateest_match_info), + .usersize = offsetof(struct xt_rateest_match_info, est1), .me = THIS_MODULE, }; diff --git a/net/netfilter/xt_socket.c b/net/netfilter/xt_socket.c index 2ec08f04b816..8a2a489b2cd3 100644 --- a/net/netfilter/xt_socket.c +++ b/net/netfilter/xt_socket.c @@ -143,11 +143,12 @@ static bool xt_socket_sk_is_transparent(struct sock *sk) } } -static struct sock *xt_socket_lookup_slow_v4(struct net *net, +struct sock *xt_socket_lookup_slow_v4(struct net *net, const struct sk_buff *skb, const struct net_device *indev) { const struct iphdr *iph = ip_hdr(skb); + struct sock *sk = skb->sk; __be32 uninitialized_var(daddr), uninitialized_var(saddr); __be16 uninitialized_var(dport), uninitialized_var(sport); u8 uninitialized_var(protocol); @@ -198,9 +199,16 @@ static struct sock *xt_socket_lookup_slow_v4(struct net *net, } #endif - return xt_socket_get_sock_v4(net, protocol, saddr, daddr, - sport, dport, indev); + if (sk) + atomic_inc(&sk->sk_refcnt); + else + sk = xt_socket_get_sock_v4(dev_net(skb->dev), protocol, + saddr, daddr, sport, dport, + indev); + + return sk; } +EXPORT_SYMBOL(xt_socket_lookup_slow_v4); static bool socket_match(const struct sk_buff *skb, struct xt_action_param *par, @@ -232,8 +240,7 @@ socket_match(const struct sk_buff *skb, struct xt_action_param *par, transparent) pskb->mark = sk->sk_mark; - if (sk != skb->sk) - sock_gen_put(sk); + sock_gen_put(sk); if (wildcard || !transparent) sk = NULL; @@ -336,10 +343,11 @@ xt_socket_get_sock_v6(struct net *net, const u8 protocol, return NULL; } -static struct sock *xt_socket_lookup_slow_v6(struct net *net, +struct sock *xt_socket_lookup_slow_v6(struct net *net, const struct sk_buff *skb, const struct net_device *indev) { + struct sock *sk = skb->sk; __be16 uninitialized_var(dport), uninitialized_var(sport); const struct in6_addr *daddr = NULL, *saddr = NULL; struct ipv6hdr *iph = ipv6_hdr(skb); @@ -373,9 +381,16 @@ static struct sock *xt_socket_lookup_slow_v6(struct net *net, return NULL; } - return xt_socket_get_sock_v6(net, tproto, saddr, daddr, - sport, dport, indev); + if (sk) + atomic_inc(&sk->sk_refcnt); + else + sk = xt_socket_get_sock_v6(dev_net(skb->dev), tproto, + saddr, daddr, sport, dport, + indev); + + return sk; } +EXPORT_SYMBOL(xt_socket_lookup_slow_v6); static bool socket_mt6_v1_v2_v3(const struct sk_buff *skb, struct xt_action_param *par) diff --git a/net/netfilter/xt_statistic.c b/net/netfilter/xt_statistic.c index 11de55e7a868..8710fdba2ae2 100644 --- a/net/netfilter/xt_statistic.c +++ b/net/netfilter/xt_statistic.c @@ -84,6 +84,7 @@ static struct xt_match xt_statistic_mt_reg __read_mostly = { .checkentry = statistic_mt_check, .destroy = statistic_mt_destroy, .matchsize = sizeof(struct xt_statistic_info), + .usersize = offsetof(struct xt_statistic_info, master), .me = THIS_MODULE, }; diff --git a/net/netfilter/xt_string.c b/net/netfilter/xt_string.c index 0bc3460319c8..423293ee57c2 100644 --- a/net/netfilter/xt_string.c +++ b/net/netfilter/xt_string.c @@ -77,6 +77,7 @@ static struct xt_match xt_string_mt_reg __read_mostly = { .match = string_mt, .destroy = string_mt_destroy, .matchsize = sizeof(struct xt_string_info), + .usersize = offsetof(struct xt_string_info, config), .me = THIS_MODULE, }; diff --git a/net/rfkill/Kconfig b/net/rfkill/Kconfig index 598d374f6a35..44c3be9d1021 100644 --- a/net/rfkill/Kconfig +++ b/net/rfkill/Kconfig @@ -10,6 +10,11 @@ menuconfig RFKILL To compile this driver as a module, choose M here: the module will be called rfkill. +config RFKILL_PM + bool "Power off on suspend" + depends on RFKILL && PM + default y + # LED trigger support config RFKILL_LEDS bool diff --git a/net/rfkill/core.c b/net/rfkill/core.c index ad927a6ca2a1..8662a976f666 100644 --- a/net/rfkill/core.c +++ b/net/rfkill/core.c @@ -802,8 +802,7 @@ void rfkill_resume_polling(struct rfkill *rfkill) } EXPORT_SYMBOL(rfkill_resume_polling); -#ifdef CONFIG_PM_SLEEP -static int rfkill_suspend(struct device *dev) +static __maybe_unused int rfkill_suspend(struct device *dev) { struct rfkill *rfkill = to_rfkill(dev); @@ -812,7 +811,7 @@ static int rfkill_suspend(struct device *dev) return 0; } -static int rfkill_resume(struct device *dev) +static __maybe_unused int rfkill_resume(struct device *dev) { struct rfkill *rfkill = to_rfkill(dev); bool cur; @@ -828,17 +827,13 @@ static int rfkill_resume(struct device *dev) } static SIMPLE_DEV_PM_OPS(rfkill_pm_ops, rfkill_suspend, rfkill_resume); -#define RFKILL_PM_OPS (&rfkill_pm_ops) -#else -#define RFKILL_PM_OPS NULL -#endif static struct class rfkill_class = { .name = "rfkill", .dev_release = rfkill_release, .dev_groups = rfkill_dev_groups, .dev_uevent = rfkill_dev_uevent, - .pm = RFKILL_PM_OPS, + .pm = IS_ENABLED(CONFIG_RFKILL_PM) ? &rfkill_pm_ops : NULL, }; bool rfkill_blocked(struct rfkill *rfkill) diff --git a/net/socket.c b/net/socket.c index 1392461d391a..4f122228284c 100644 --- a/net/socket.c +++ b/net/socket.c @@ -509,9 +509,26 @@ static ssize_t sockfs_listxattr(struct dentry *dentry, char *buffer, return used; } +static int sockfs_setattr(struct dentry *dentry, struct iattr *iattr) +{ + int err = simple_setattr(dentry, iattr); + + if (!err && (iattr->ia_valid & ATTR_UID)) { + struct socket *sock = SOCKET_I(d_inode(dentry)); + + if (sock->sk) + sock->sk->sk_uid = iattr->ia_uid; + else + err = -ENOENT; + } + + return err; +} + static const struct inode_operations sockfs_inode_ops = { .getxattr = sockfs_getxattr, .listxattr = sockfs_listxattr, + .setattr = sockfs_setattr, }; /** @@ -553,12 +570,17 @@ static struct socket *sock_alloc(void) * an inode not a file. */ -void sock_release(struct socket *sock) +static void __sock_release(struct socket *sock, struct inode *inode) { if (sock->ops) { struct module *owner = sock->ops->owner; + if (inode) + inode_lock(inode); sock->ops->release(sock); + sock->sk = NULL; + if (inode) + inode_unlock(inode); sock->ops = NULL; module_put(owner); } @@ -573,6 +595,11 @@ void sock_release(struct socket *sock) } sock->file = NULL; } + +void sock_release(struct socket *sock) +{ + __sock_release(sock, NULL); +} EXPORT_SYMBOL(sock_release); void __sock_tx_timestamp(const struct sock *sk, __u8 *tx_flags) @@ -1009,7 +1036,7 @@ static int sock_mmap(struct file *file, struct vm_area_struct *vma) static int sock_close(struct inode *inode, struct file *filp) { - sock_release(SOCKET_I(inode)); + __sock_release(SOCKET_I(inode), inode); return 0; } diff --git a/net/vmw_vsock/Kconfig b/net/vmw_vsock/Kconfig index 14810abedc2e..8831e7c42167 100644 --- a/net/vmw_vsock/Kconfig +++ b/net/vmw_vsock/Kconfig @@ -26,3 +26,23 @@ config VMWARE_VMCI_VSOCKETS To compile this driver as a module, choose M here: the module will be called vmw_vsock_vmci_transport. If unsure, say N. + +config VIRTIO_VSOCKETS + tristate "virtio transport for Virtual Sockets" + depends on VSOCKETS && VIRTIO + select VIRTIO_VSOCKETS_COMMON + help + This module implements a virtio transport for Virtual Sockets. + + Enable this transport if your Virtual Machine host supports Virtual + Sockets over virtio. + + To compile this driver as a module, choose M here: the module will be + called vmw_vsock_virtio_transport. If unsure, say N. + +config VIRTIO_VSOCKETS_COMMON + tristate + help + This option is selected by any driver which needs to access + the virtio_vsock. The module will be called + vmw_vsock_virtio_transport_common. diff --git a/net/vmw_vsock/Makefile b/net/vmw_vsock/Makefile index 2ce52d70f224..bc27c70e0e59 100644 --- a/net/vmw_vsock/Makefile +++ b/net/vmw_vsock/Makefile @@ -1,7 +1,13 @@ obj-$(CONFIG_VSOCKETS) += vsock.o obj-$(CONFIG_VMWARE_VMCI_VSOCKETS) += vmw_vsock_vmci_transport.o +obj-$(CONFIG_VIRTIO_VSOCKETS) += vmw_vsock_virtio_transport.o +obj-$(CONFIG_VIRTIO_VSOCKETS_COMMON) += vmw_vsock_virtio_transport_common.o vsock-y += af_vsock.o vsock_addr.o vmw_vsock_vmci_transport-y += vmci_transport.o vmci_transport_notify.o \ vmci_transport_notify_qstate.o + +vmw_vsock_virtio_transport-y += virtio_transport.o + +vmw_vsock_virtio_transport_common-y += virtio_transport_common.o diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 537d57558c21..baab5f65fbeb 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -61,6 +61,14 @@ * function will also cleanup rejected sockets, those that reach the connected * state but leave it before they have been accepted. * + * - Lock ordering for pending or accept queue sockets is: + * + * lock_sock(listener); + * lock_sock_nested(pending, SINGLE_DEPTH_NESTING); + * + * Using explicit nested locking keeps lockdep happy since normally only one + * lock of a given class may be taken at a time. + * * - Sockets created by user action will be cleaned up when the user process * calls close(2), causing our release implementation to be called. Our release * implementation will perform some cleanup then drop the last reference so our @@ -337,6 +345,16 @@ static bool vsock_in_connected_table(struct vsock_sock *vsk) return ret; } +void vsock_remove_sock(struct vsock_sock *vsk) +{ + if (vsock_in_bound_table(vsk)) + vsock_remove_bound(vsk); + + if (vsock_in_connected_table(vsk)) + vsock_remove_connected(vsk); +} +EXPORT_SYMBOL_GPL(vsock_remove_sock); + void vsock_for_each_connected_socket(void (*fn)(struct sock *sk)) { int i; @@ -444,10 +462,12 @@ static void vsock_pending_work(struct work_struct *work) cleanup = true; lock_sock(listener); - lock_sock(sk); + lock_sock_nested(sk, SINGLE_DEPTH_NESTING); if (vsock_is_pending(sk)) { vsock_remove_pending(listener, sk); + + listener->sk_ack_backlog--; } else if (!vsk->rejected) { /* We are not on the pending list and accept() did not reject * us, so we must have been accepted by our user process. We @@ -458,8 +478,6 @@ static void vsock_pending_work(struct work_struct *work) goto out; } - listener->sk_ack_backlog--; - /* We need to remove ourself from the global connected sockets list so * incoming packets can't find this socket, and to reduce the reference * count. @@ -661,12 +679,6 @@ static void __vsock_release(struct sock *sk) vsk = vsock_sk(sk); pending = NULL; /* Compiler warning. */ - if (vsock_in_bound_table(vsk)) - vsock_remove_bound(vsk); - - if (vsock_in_connected_table(vsk)) - vsock_remove_connected(vsk); - transport->release(vsk); lock_sock(sk); @@ -1100,10 +1112,19 @@ static const struct proto_ops vsock_dgram_ops = { .sendpage = sock_no_sendpage, }; +static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) +{ + if (!transport->cancel_pkt) + return -EOPNOTSUPP; + + return transport->cancel_pkt(vsk); +} + static void vsock_connect_timeout(struct work_struct *work) { struct sock *sk; struct vsock_sock *vsk; + int cancel = 0; vsk = container_of(work, struct vsock_sock, connect_work.work); sk = sk_vsock(vsk); @@ -1114,8 +1135,11 @@ static void vsock_connect_timeout(struct work_struct *work) sk->sk_state = SS_UNCONNECTED; sk->sk_err = ETIMEDOUT; sk->sk_error_report(sk); + cancel = 1; } release_sock(sk); + if (cancel) + vsock_transport_cancel_pkt(vsk); sock_put(sk); } @@ -1222,11 +1246,13 @@ static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr, err = sock_intr_errno(timeout); sk->sk_state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED; + vsock_transport_cancel_pkt(vsk); goto out_wait; } else if (timeout == 0) { err = -ETIMEDOUT; sk->sk_state = SS_UNCONNECTED; sock->state = SS_UNCONNECTED; + vsock_transport_cancel_pkt(vsk); goto out_wait; } @@ -1303,7 +1329,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags) if (connected) { listener->sk_ack_backlog--; - lock_sock(connected); + lock_sock_nested(connected, SINGLE_DEPTH_NESTING); vconnected = vsock_sk(connected); /* If the listener socket has received an error, then we should @@ -1993,7 +2019,16 @@ void vsock_core_exit(void) } EXPORT_SYMBOL_GPL(vsock_core_exit); +const struct vsock_transport *vsock_core_get_transport(void) +{ + /* vsock_register_mutex not taken since only the transport uses this + * function and only while registered. + */ + return transport; +} +EXPORT_SYMBOL_GPL(vsock_core_get_transport); + MODULE_AUTHOR("VMware, Inc."); MODULE_DESCRIPTION("VMware Virtual Socket Family"); -MODULE_VERSION("1.0.1.0-k"); +MODULE_VERSION("1.0.2.0-k"); MODULE_LICENSE("GPL v2"); diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c new file mode 100644 index 000000000000..936d7eee62d0 --- /dev/null +++ b/net/vmw_vsock/virtio_transport.c @@ -0,0 +1,620 @@ +/* + * virtio transport for vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + * + * Some of the code is take from Gerd Hoffmann <kraxel@redhat.com>'s + * early virtio-vsock proof-of-concept bits. + * + * This work is licensed under the terms of the GNU GPL, version 2. + */ +#include <linux/spinlock.h> +#include <linux/module.h> +#include <linux/list.h> +#include <linux/atomic.h> +#include <linux/virtio.h> +#include <linux/virtio_ids.h> +#include <linux/virtio_config.h> +#include <linux/virtio_vsock.h> +#include <net/sock.h> +#include <linux/mutex.h> +#include <net/af_vsock.h> + +static struct workqueue_struct *virtio_vsock_workqueue; +static struct virtio_vsock *the_virtio_vsock; +static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */ + +struct virtio_vsock { + struct virtio_device *vdev; + struct virtqueue *vqs[VSOCK_VQ_MAX]; + + /* Virtqueue processing is deferred to a workqueue */ + struct work_struct tx_work; + struct work_struct rx_work; + struct work_struct event_work; + + /* The following fields are protected by tx_lock. vqs[VSOCK_VQ_TX] + * must be accessed with tx_lock held. + */ + struct mutex tx_lock; + + struct work_struct send_pkt_work; + spinlock_t send_pkt_list_lock; + struct list_head send_pkt_list; + + atomic_t queued_replies; + + /* The following fields are protected by rx_lock. vqs[VSOCK_VQ_RX] + * must be accessed with rx_lock held. + */ + struct mutex rx_lock; + int rx_buf_nr; + int rx_buf_max_nr; + + /* The following fields are protected by event_lock. + * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held. + */ + struct mutex event_lock; + struct virtio_vsock_event event_list[8]; + + u32 guest_cid; +}; + +static struct virtio_vsock *virtio_vsock_get(void) +{ + return the_virtio_vsock; +} + +static u32 virtio_transport_get_local_cid(void) +{ + struct virtio_vsock *vsock = virtio_vsock_get(); + + return vsock->guest_cid; +} + +static void +virtio_transport_send_pkt_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, send_pkt_work); + struct virtqueue *vq; + bool added = false; + bool restart_rx = false; + + mutex_lock(&vsock->tx_lock); + + vq = vsock->vqs[VSOCK_VQ_TX]; + + for (;;) { + struct virtio_vsock_pkt *pkt; + struct scatterlist hdr, buf, *sgs[2]; + int ret, in_sg = 0, out_sg = 0; + bool reply; + + spin_lock_bh(&vsock->send_pkt_list_lock); + if (list_empty(&vsock->send_pkt_list)) { + spin_unlock_bh(&vsock->send_pkt_list_lock); + break; + } + + pkt = list_first_entry(&vsock->send_pkt_list, + struct virtio_vsock_pkt, list); + list_del_init(&pkt->list); + spin_unlock_bh(&vsock->send_pkt_list_lock); + + reply = pkt->reply; + + sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); + sgs[out_sg++] = &hdr; + if (pkt->buf) { + sg_init_one(&buf, pkt->buf, pkt->len); + sgs[out_sg++] = &buf; + } + + ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL); + /* Usually this means that there is no more space available in + * the vq + */ + if (ret < 0) { + spin_lock_bh(&vsock->send_pkt_list_lock); + list_add(&pkt->list, &vsock->send_pkt_list); + spin_unlock_bh(&vsock->send_pkt_list_lock); + break; + } + + if (reply) { + struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; + int val; + + val = atomic_dec_return(&vsock->queued_replies); + + /* Do we now have resources to resume rx processing? */ + if (val + 1 == virtqueue_get_vring_size(rx_vq)) + restart_rx = true; + } + + added = true; + } + + if (added) + virtqueue_kick(vq); + + mutex_unlock(&vsock->tx_lock); + + if (restart_rx) + queue_work(virtio_vsock_workqueue, &vsock->rx_work); +} + +static int +virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt) +{ + struct virtio_vsock *vsock; + int len = pkt->len; + + vsock = virtio_vsock_get(); + if (!vsock) { + virtio_transport_free_pkt(pkt); + return -ENODEV; + } + + if (pkt->reply) + atomic_inc(&vsock->queued_replies); + + spin_lock_bh(&vsock->send_pkt_list_lock); + list_add_tail(&pkt->list, &vsock->send_pkt_list); + spin_unlock_bh(&vsock->send_pkt_list_lock); + + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); + return len; +} + +static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) +{ + int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; + struct virtio_vsock_pkt *pkt; + struct scatterlist hdr, buf, *sgs[2]; + struct virtqueue *vq; + int ret; + + vq = vsock->vqs[VSOCK_VQ_RX]; + + do { + pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + break; + + pkt->buf = kmalloc(buf_len, GFP_KERNEL); + if (!pkt->buf) { + virtio_transport_free_pkt(pkt); + break; + } + + pkt->len = buf_len; + + sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); + sgs[0] = &hdr; + + sg_init_one(&buf, pkt->buf, buf_len); + sgs[1] = &buf; + ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL); + if (ret) { + virtio_transport_free_pkt(pkt); + break; + } + vsock->rx_buf_nr++; + } while (vq->num_free); + if (vsock->rx_buf_nr > vsock->rx_buf_max_nr) + vsock->rx_buf_max_nr = vsock->rx_buf_nr; + virtqueue_kick(vq); +} + +static void virtio_transport_tx_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, tx_work); + struct virtqueue *vq; + bool added = false; + + vq = vsock->vqs[VSOCK_VQ_TX]; + mutex_lock(&vsock->tx_lock); + do { + struct virtio_vsock_pkt *pkt; + unsigned int len; + + virtqueue_disable_cb(vq); + while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) { + virtio_transport_free_pkt(pkt); + added = true; + } + } while (!virtqueue_enable_cb(vq)); + mutex_unlock(&vsock->tx_lock); + + if (added) + queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); +} + +/* Is there space left for replies to rx packets? */ +static bool virtio_transport_more_replies(struct virtio_vsock *vsock) +{ + struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX]; + int val; + + smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */ + val = atomic_read(&vsock->queued_replies); + + return val < virtqueue_get_vring_size(vq); +} + +static void virtio_transport_rx_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, rx_work); + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_RX]; + + mutex_lock(&vsock->rx_lock); + + do { + virtqueue_disable_cb(vq); + for (;;) { + struct virtio_vsock_pkt *pkt; + unsigned int len; + + if (!virtio_transport_more_replies(vsock)) { + /* Stop rx until the device processes already + * pending replies. Leave rx virtqueue + * callbacks disabled. + */ + goto out; + } + + pkt = virtqueue_get_buf(vq, &len); + if (!pkt) { + break; + } + + vsock->rx_buf_nr--; + + /* Drop short/long packets */ + if (unlikely(len < sizeof(pkt->hdr) || + len > sizeof(pkt->hdr) + pkt->len)) { + virtio_transport_free_pkt(pkt); + continue; + } + + pkt->len = len - sizeof(pkt->hdr); + virtio_transport_recv_pkt(pkt); + } + } while (!virtqueue_enable_cb(vq)); + +out: + if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2) + virtio_vsock_rx_fill(vsock); + mutex_unlock(&vsock->rx_lock); +} + +/* event_lock must be held */ +static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock, + struct virtio_vsock_event *event) +{ + struct scatterlist sg; + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_EVENT]; + + sg_init_one(&sg, event, sizeof(*event)); + + return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL); +} + +/* event_lock must be held */ +static void virtio_vsock_event_fill(struct virtio_vsock *vsock) +{ + size_t i; + + for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) { + struct virtio_vsock_event *event = &vsock->event_list[i]; + + virtio_vsock_event_fill_one(vsock, event); + } + + virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); +} + +static void virtio_vsock_reset_sock(struct sock *sk) +{ + lock_sock(sk); + sk->sk_state = SS_UNCONNECTED; + sk->sk_err = ECONNRESET; + sk->sk_error_report(sk); + release_sock(sk); +} + +static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock) +{ + struct virtio_device *vdev = vsock->vdev; + u64 guest_cid; + + vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid), + &guest_cid, sizeof(guest_cid)); + vsock->guest_cid = le64_to_cpu(guest_cid); +} + +/* event_lock must be held */ +static void virtio_vsock_event_handle(struct virtio_vsock *vsock, + struct virtio_vsock_event *event) +{ + switch (le32_to_cpu(event->id)) { + case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: + virtio_vsock_update_guest_cid(vsock); + vsock_for_each_connected_socket(virtio_vsock_reset_sock); + break; + } +} + +static void virtio_transport_event_work(struct work_struct *work) +{ + struct virtio_vsock *vsock = + container_of(work, struct virtio_vsock, event_work); + struct virtqueue *vq; + + vq = vsock->vqs[VSOCK_VQ_EVENT]; + + mutex_lock(&vsock->event_lock); + + do { + struct virtio_vsock_event *event; + unsigned int len; + + virtqueue_disable_cb(vq); + while ((event = virtqueue_get_buf(vq, &len)) != NULL) { + if (len == sizeof(*event)) + virtio_vsock_event_handle(vsock, event); + + virtio_vsock_event_fill_one(vsock, event); + } + } while (!virtqueue_enable_cb(vq)); + + virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); + + mutex_unlock(&vsock->event_lock); +} + +static void virtio_vsock_event_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->event_work); +} + +static void virtio_vsock_tx_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->tx_work); +} + +static void virtio_vsock_rx_done(struct virtqueue *vq) +{ + struct virtio_vsock *vsock = vq->vdev->priv; + + if (!vsock) + return; + queue_work(virtio_vsock_workqueue, &vsock->rx_work); +} + +static struct virtio_transport virtio_transport = { + .transport = { + .get_local_cid = virtio_transport_get_local_cid, + + .init = virtio_transport_do_socket_init, + .destruct = virtio_transport_destruct, + .release = virtio_transport_release, + .connect = virtio_transport_connect, + .shutdown = virtio_transport_shutdown, + + .dgram_bind = virtio_transport_dgram_bind, + .dgram_dequeue = virtio_transport_dgram_dequeue, + .dgram_enqueue = virtio_transport_dgram_enqueue, + .dgram_allow = virtio_transport_dgram_allow, + + .stream_dequeue = virtio_transport_stream_dequeue, + .stream_enqueue = virtio_transport_stream_enqueue, + .stream_has_data = virtio_transport_stream_has_data, + .stream_has_space = virtio_transport_stream_has_space, + .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, + .stream_is_active = virtio_transport_stream_is_active, + .stream_allow = virtio_transport_stream_allow, + + .notify_poll_in = virtio_transport_notify_poll_in, + .notify_poll_out = virtio_transport_notify_poll_out, + .notify_recv_init = virtio_transport_notify_recv_init, + .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, + .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, + .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, + .notify_send_init = virtio_transport_notify_send_init, + .notify_send_pre_block = virtio_transport_notify_send_pre_block, + .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, + .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, + + .set_buffer_size = virtio_transport_set_buffer_size, + .set_min_buffer_size = virtio_transport_set_min_buffer_size, + .set_max_buffer_size = virtio_transport_set_max_buffer_size, + .get_buffer_size = virtio_transport_get_buffer_size, + .get_min_buffer_size = virtio_transport_get_min_buffer_size, + .get_max_buffer_size = virtio_transport_get_max_buffer_size, + }, + + .send_pkt = virtio_transport_send_pkt, +}; + +static int virtio_vsock_probe(struct virtio_device *vdev) +{ + vq_callback_t *callbacks[] = { + virtio_vsock_rx_done, + virtio_vsock_tx_done, + virtio_vsock_event_done, + }; + static const char * const names[] = { + "rx", + "tx", + "event", + }; + struct virtio_vsock *vsock = NULL; + int ret; + + ret = mutex_lock_interruptible(&the_virtio_vsock_mutex); + if (ret) + return ret; + + /* Only one virtio-vsock device per guest is supported */ + if (the_virtio_vsock) { + ret = -EBUSY; + goto out; + } + + vsock = kzalloc(sizeof(*vsock), GFP_KERNEL); + if (!vsock) { + ret = -ENOMEM; + goto out; + } + + vsock->vdev = vdev; + + ret = vsock->vdev->config->find_vqs(vsock->vdev, VSOCK_VQ_MAX, + vsock->vqs, callbacks, names); + if (ret < 0) + goto out; + + virtio_vsock_update_guest_cid(vsock); + + ret = vsock_core_init(&virtio_transport.transport); + if (ret < 0) + goto out_vqs; + + vsock->rx_buf_nr = 0; + vsock->rx_buf_max_nr = 0; + atomic_set(&vsock->queued_replies, 0); + + vdev->priv = vsock; + the_virtio_vsock = vsock; + mutex_init(&vsock->tx_lock); + mutex_init(&vsock->rx_lock); + mutex_init(&vsock->event_lock); + spin_lock_init(&vsock->send_pkt_list_lock); + INIT_LIST_HEAD(&vsock->send_pkt_list); + INIT_WORK(&vsock->rx_work, virtio_transport_rx_work); + INIT_WORK(&vsock->tx_work, virtio_transport_tx_work); + INIT_WORK(&vsock->event_work, virtio_transport_event_work); + INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work); + + mutex_lock(&vsock->rx_lock); + virtio_vsock_rx_fill(vsock); + mutex_unlock(&vsock->rx_lock); + + mutex_lock(&vsock->event_lock); + virtio_vsock_event_fill(vsock); + mutex_unlock(&vsock->event_lock); + + mutex_unlock(&the_virtio_vsock_mutex); + return 0; + +out_vqs: + vsock->vdev->config->del_vqs(vsock->vdev); +out: + kfree(vsock); + mutex_unlock(&the_virtio_vsock_mutex); + return ret; +} + +static void virtio_vsock_remove(struct virtio_device *vdev) +{ + struct virtio_vsock *vsock = vdev->priv; + struct virtio_vsock_pkt *pkt; + + flush_work(&vsock->rx_work); + flush_work(&vsock->tx_work); + flush_work(&vsock->event_work); + flush_work(&vsock->send_pkt_work); + + vdev->config->reset(vdev); + + mutex_lock(&vsock->rx_lock); + while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX]))) + virtio_transport_free_pkt(pkt); + mutex_unlock(&vsock->rx_lock); + + mutex_lock(&vsock->tx_lock); + while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX]))) + virtio_transport_free_pkt(pkt); + mutex_unlock(&vsock->tx_lock); + + spin_lock_bh(&vsock->send_pkt_list_lock); + while (!list_empty(&vsock->send_pkt_list)) { + pkt = list_first_entry(&vsock->send_pkt_list, + struct virtio_vsock_pkt, list); + list_del(&pkt->list); + virtio_transport_free_pkt(pkt); + } + spin_unlock_bh(&vsock->send_pkt_list_lock); + + mutex_lock(&the_virtio_vsock_mutex); + the_virtio_vsock = NULL; + vsock_core_exit(); + mutex_unlock(&the_virtio_vsock_mutex); + + vdev->config->del_vqs(vdev); + + kfree(vsock); +} + +static struct virtio_device_id id_table[] = { + { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID }, + { 0 }, +}; + +static unsigned int features[] = { +}; + +static struct virtio_driver virtio_vsock_driver = { + .feature_table = features, + .feature_table_size = ARRAY_SIZE(features), + .driver.name = KBUILD_MODNAME, + .driver.owner = THIS_MODULE, + .id_table = id_table, + .probe = virtio_vsock_probe, + .remove = virtio_vsock_remove, +}; + +static int __init virtio_vsock_init(void) +{ + int ret; + + virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0); + if (!virtio_vsock_workqueue) + return -ENOMEM; + ret = register_virtio_driver(&virtio_vsock_driver); + if (ret) + destroy_workqueue(virtio_vsock_workqueue); + return ret; +} + +static void __exit virtio_vsock_exit(void) +{ + unregister_virtio_driver(&virtio_vsock_driver); + destroy_workqueue(virtio_vsock_workqueue); +} + +module_init(virtio_vsock_init); +module_exit(virtio_vsock_exit); +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("virtio transport for vsock"); +MODULE_DEVICE_TABLE(virtio, id_table); diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c new file mode 100644 index 000000000000..ddcae46ae408 --- /dev/null +++ b/net/vmw_vsock/virtio_transport_common.c @@ -0,0 +1,1003 @@ +/* + * common code for virtio vsock + * + * Copyright (C) 2013-2015 Red Hat, Inc. + * Author: Asias He <asias@redhat.com> + * Stefan Hajnoczi <stefanha@redhat.com> + * + * This work is licensed under the terms of the GNU GPL, version 2. + */ +#include <linux/spinlock.h> +#include <linux/module.h> +#include <linux/ctype.h> +#include <linux/list.h> +#include <linux/virtio.h> +#include <linux/virtio_ids.h> +#include <linux/virtio_config.h> +#include <linux/virtio_vsock.h> + +#include <net/sock.h> +#include <net/af_vsock.h> + +#define CREATE_TRACE_POINTS +#include <trace/events/vsock_virtio_transport_common.h> + +/* How long to wait for graceful shutdown of a connection */ +#define VSOCK_CLOSE_TIMEOUT (8 * HZ) + +uint virtio_transport_max_vsock_pkt_buf_size = 64 * 1024; +module_param(virtio_transport_max_vsock_pkt_buf_size, uint, 0444); +EXPORT_SYMBOL_GPL(virtio_transport_max_vsock_pkt_buf_size); + +static const struct virtio_transport *virtio_transport_get_ops(void) +{ + const struct vsock_transport *t = vsock_core_get_transport(); + + return container_of(t, struct virtio_transport, transport); +} + +struct virtio_vsock_pkt * +virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, + size_t len, + u32 src_cid, + u32 src_port, + u32 dst_cid, + u32 dst_port) +{ + struct virtio_vsock_pkt *pkt; + int err; + + pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); + if (!pkt) + return NULL; + + pkt->hdr.type = cpu_to_le16(info->type); + pkt->hdr.op = cpu_to_le16(info->op); + pkt->hdr.src_cid = cpu_to_le64(src_cid); + pkt->hdr.dst_cid = cpu_to_le64(dst_cid); + pkt->hdr.src_port = cpu_to_le32(src_port); + pkt->hdr.dst_port = cpu_to_le32(dst_port); + pkt->hdr.flags = cpu_to_le32(info->flags); + pkt->len = len; + pkt->hdr.len = cpu_to_le32(len); + pkt->reply = info->reply; + pkt->vsk = info->vsk; + + if (info->msg && len > 0) { + pkt->buf = kmalloc(len, GFP_KERNEL); + if (!pkt->buf) + goto out_pkt; + err = memcpy_from_msg(pkt->buf, info->msg, len); + if (err) + goto out; + } + + trace_virtio_transport_alloc_pkt(src_cid, src_port, + dst_cid, dst_port, + len, + info->type, + info->op, + info->flags); + + return pkt; + +out: + kfree(pkt->buf); +out_pkt: + kfree(pkt); + return NULL; +} +EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt); + +static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, + struct virtio_vsock_pkt_info *info) +{ + u32 src_cid, src_port, dst_cid, dst_port; + struct virtio_vsock_sock *vvs; + struct virtio_vsock_pkt *pkt; + u32 pkt_len = info->pkt_len; + + src_cid = vm_sockets_get_local_cid(); + src_port = vsk->local_addr.svm_port; + if (!info->remote_cid) { + dst_cid = vsk->remote_addr.svm_cid; + dst_port = vsk->remote_addr.svm_port; + } else { + dst_cid = info->remote_cid; + dst_port = info->remote_port; + } + + vvs = vsk->trans; + + /* we can send less than pkt_len bytes */ + if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) + pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; + + /* virtio_transport_get_credit might return less than pkt_len credit */ + pkt_len = virtio_transport_get_credit(vvs, pkt_len); + + /* Do not send zero length OP_RW pkt */ + if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) + return pkt_len; + + pkt = virtio_transport_alloc_pkt(info, pkt_len, + src_cid, src_port, + dst_cid, dst_port); + if (!pkt) { + virtio_transport_put_credit(vvs, pkt_len); + return -ENOMEM; + } + + virtio_transport_inc_tx_pkt(vvs, pkt); + + return virtio_transport_get_ops()->send_pkt(pkt); +} + +static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, + struct virtio_vsock_pkt *pkt) +{ + vvs->rx_bytes += pkt->len; +} + +static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, + struct virtio_vsock_pkt *pkt) +{ + vvs->rx_bytes -= pkt->len; + vvs->fwd_cnt += pkt->len; +} + +void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) +{ + spin_lock_bh(&vvs->tx_lock); + pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); + pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); + spin_unlock_bh(&vvs->tx_lock); +} +EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); + +u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) +{ + u32 ret; + + spin_lock_bh(&vvs->tx_lock); + ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); + if (ret > credit) + ret = credit; + vvs->tx_cnt += ret; + spin_unlock_bh(&vvs->tx_lock); + + return ret; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_credit); + +void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) +{ + spin_lock_bh(&vvs->tx_lock); + vvs->tx_cnt -= credit; + spin_unlock_bh(&vvs->tx_lock); +} +EXPORT_SYMBOL_GPL(virtio_transport_put_credit); + +static int virtio_transport_send_credit_update(struct vsock_sock *vsk, + int type, + struct virtio_vsock_hdr *hdr) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, + .type = type, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +static ssize_t +virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct virtio_vsock_pkt *pkt; + size_t bytes, total = 0; + int err = -EFAULT; + + spin_lock_bh(&vvs->rx_lock); + while (total < len && !list_empty(&vvs->rx_queue)) { + pkt = list_first_entry(&vvs->rx_queue, + struct virtio_vsock_pkt, list); + + bytes = len - total; + if (bytes > pkt->len - pkt->off) + bytes = pkt->len - pkt->off; + + /* sk_lock is held by caller so no one else can dequeue. + * Unlock rx_lock since memcpy_to_msg() may sleep. + */ + spin_unlock_bh(&vvs->rx_lock); + + err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); + if (err) + goto out; + + spin_lock_bh(&vvs->rx_lock); + + total += bytes; + pkt->off += bytes; + if (pkt->off == pkt->len) { + virtio_transport_dec_rx_pkt(vvs, pkt); + list_del(&pkt->list); + virtio_transport_free_pkt(pkt); + } + } + spin_unlock_bh(&vvs->rx_lock); + + /* Send a credit pkt to peer */ + virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, + NULL); + + return total; + +out: + if (total) + err = total; + return err; +} + +ssize_t +virtio_transport_stream_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags) +{ + if (flags & MSG_PEEK) + return -EOPNOTSUPP; + + return virtio_transport_stream_do_dequeue(vsk, msg, len); +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); + +int +virtio_transport_dgram_dequeue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len, int flags) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); + +s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + spin_lock_bh(&vvs->rx_lock); + bytes = vvs->rx_bytes; + spin_unlock_bh(&vvs->rx_lock); + + return bytes; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); + +static s64 virtio_transport_has_space(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); + if (bytes < 0) + bytes = 0; + + return bytes; +} + +s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + s64 bytes; + + spin_lock_bh(&vvs->tx_lock); + bytes = virtio_transport_has_space(vsk); + spin_unlock_bh(&vvs->tx_lock); + + return bytes; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); + +int virtio_transport_do_socket_init(struct vsock_sock *vsk, + struct vsock_sock *psk) +{ + struct virtio_vsock_sock *vvs; + + vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); + if (!vvs) + return -ENOMEM; + + vsk->trans = vvs; + vvs->vsk = vsk; + if (psk) { + struct virtio_vsock_sock *ptrans = psk->trans; + + vvs->buf_size = ptrans->buf_size; + vvs->buf_size_min = ptrans->buf_size_min; + vvs->buf_size_max = ptrans->buf_size_max; + vvs->peer_buf_alloc = ptrans->peer_buf_alloc; + } else { + vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; + vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; + vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; + } + + vvs->buf_alloc = vvs->buf_size; + + spin_lock_init(&vvs->rx_lock); + spin_lock_init(&vvs->tx_lock); + INIT_LIST_HEAD(&vvs->rx_queue); + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); + +u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + return vvs->buf_size; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); + +u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + return vvs->buf_size_min; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); + +u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + return vvs->buf_size_max; +} +EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); + +void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) + val = VIRTIO_VSOCK_MAX_BUF_SIZE; + if (val < vvs->buf_size_min) + vvs->buf_size_min = val; + if (val > vvs->buf_size_max) + vvs->buf_size_max = val; + vvs->buf_size = val; + vvs->buf_alloc = val; +} +EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); + +void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) + val = VIRTIO_VSOCK_MAX_BUF_SIZE; + if (val > vvs->buf_size) + vvs->buf_size = val; + vvs->buf_size_min = val; +} +EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); + +void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) + val = VIRTIO_VSOCK_MAX_BUF_SIZE; + if (val < vvs->buf_size) + vvs->buf_size = val; + vvs->buf_size_max = val; +} +EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); + +int +virtio_transport_notify_poll_in(struct vsock_sock *vsk, + size_t target, + bool *data_ready_now) +{ + if (vsock_stream_has_data(vsk)) + *data_ready_now = true; + else + *data_ready_now = false; + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); + +int +virtio_transport_notify_poll_out(struct vsock_sock *vsk, + size_t target, + bool *space_avail_now) +{ + s64 free_space; + + free_space = vsock_stream_has_space(vsk); + if (free_space > 0) + *space_avail_now = true; + else if (free_space == 0) + *space_avail_now = false; + + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); + +int virtio_transport_notify_recv_init(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); + +int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); + +int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, + size_t target, struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); + +int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, + size_t target, ssize_t copied, bool data_read, + struct vsock_transport_recv_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); + +int virtio_transport_notify_send_init(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); + +int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); + +int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, + struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); + +int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, + ssize_t written, struct vsock_transport_send_notify_data *data) +{ + return 0; +} +EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); + +u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + return vvs->buf_size; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); + +bool virtio_transport_stream_is_active(struct vsock_sock *vsk) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); + +bool virtio_transport_stream_allow(u32 cid, u32 port) +{ + return true; +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); + +int virtio_transport_dgram_bind(struct vsock_sock *vsk, + struct sockaddr_vm *addr) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); + +bool virtio_transport_dgram_allow(u32 cid, u32 port) +{ + return false; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); + +int virtio_transport_connect(struct vsock_sock *vsk) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_REQUEST, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_connect); + +int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_SHUTDOWN, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .flags = (mode & RCV_SHUTDOWN ? + VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | + (mode & SEND_SHUTDOWN ? + VIRTIO_VSOCK_SHUTDOWN_SEND : 0), + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_shutdown); + +int +virtio_transport_dgram_enqueue(struct vsock_sock *vsk, + struct sockaddr_vm *remote_addr, + struct msghdr *msg, + size_t dgram_len) +{ + return -EOPNOTSUPP; +} +EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); + +ssize_t +virtio_transport_stream_enqueue(struct vsock_sock *vsk, + struct msghdr *msg, + size_t len) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RW, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .msg = msg, + .pkt_len = len, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} +EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); + +void virtio_transport_destruct(struct vsock_sock *vsk) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + + kfree(vvs); +} +EXPORT_SYMBOL_GPL(virtio_transport_destruct); + +static int virtio_transport_reset(struct vsock_sock *vsk, + struct virtio_vsock_pkt *pkt) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RST, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .reply = !!pkt, + .vsk = vsk, + }; + + /* Send RST only if the original pkt is not a RST pkt */ + if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) + return 0; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +/* Normally packets are associated with a socket. There may be no socket if an + * attempt was made to connect to a socket that does not exist. + */ +static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RST, + .type = le16_to_cpu(pkt->hdr.type), + .reply = true, + }; + + /* Send RST only if the original pkt is not a RST pkt */ + if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) + return 0; + + pkt = virtio_transport_alloc_pkt(&info, 0, + le64_to_cpu(pkt->hdr.dst_cid), + le32_to_cpu(pkt->hdr.dst_port), + le64_to_cpu(pkt->hdr.src_cid), + le32_to_cpu(pkt->hdr.src_port)); + if (!pkt) + return -ENOMEM; + + return virtio_transport_get_ops()->send_pkt(pkt); +} + +static void virtio_transport_wait_close(struct sock *sk, long timeout) +{ + if (timeout) { + DEFINE_WAIT(wait); + + do { + prepare_to_wait(sk_sleep(sk), &wait, + TASK_INTERRUPTIBLE); + if (sk_wait_event(sk, &timeout, + sock_flag(sk, SOCK_DONE))) + break; + } while (!signal_pending(current) && timeout); + + finish_wait(sk_sleep(sk), &wait); + } +} + +static void virtio_transport_do_close(struct vsock_sock *vsk, + bool cancel_timeout) +{ + struct sock *sk = sk_vsock(vsk); + + sock_set_flag(sk, SOCK_DONE); + vsk->peer_shutdown = SHUTDOWN_MASK; + if (vsock_stream_has_data(vsk) <= 0) + sk->sk_state = SS_DISCONNECTING; + sk->sk_state_change(sk); + + if (vsk->close_work_scheduled && + (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { + vsk->close_work_scheduled = false; + + vsock_remove_sock(vsk); + + /* Release refcnt obtained when we scheduled the timeout */ + sock_put(sk); + } +} + +static void virtio_transport_close_timeout(struct work_struct *work) +{ + struct vsock_sock *vsk = + container_of(work, struct vsock_sock, close_work.work); + struct sock *sk = sk_vsock(vsk); + + sock_hold(sk); + lock_sock(sk); + + if (!sock_flag(sk, SOCK_DONE)) { + (void)virtio_transport_reset(vsk, NULL); + + virtio_transport_do_close(vsk, false); + } + + vsk->close_work_scheduled = false; + + release_sock(sk); + sock_put(sk); +} + +/* User context, vsk->sk is locked */ +static bool virtio_transport_close(struct vsock_sock *vsk) +{ + struct sock *sk = &vsk->sk; + + if (!(sk->sk_state == SS_CONNECTED || + sk->sk_state == SS_DISCONNECTING)) + return true; + + /* Already received SHUTDOWN from peer, reply with RST */ + if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { + (void)virtio_transport_reset(vsk, NULL); + return true; + } + + if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) + (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); + + if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) + virtio_transport_wait_close(sk, sk->sk_lingertime); + + if (sock_flag(sk, SOCK_DONE)) { + return true; + } + + sock_hold(sk); + INIT_DELAYED_WORK(&vsk->close_work, + virtio_transport_close_timeout); + vsk->close_work_scheduled = true; + schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); + return false; +} + +void virtio_transport_release(struct vsock_sock *vsk) +{ + struct sock *sk = &vsk->sk; + bool remove_sock = true; + + lock_sock(sk); + if (sk->sk_type == SOCK_STREAM) + remove_sock = virtio_transport_close(vsk); + release_sock(sk); + + if (remove_sock) + vsock_remove_sock(vsk); +} +EXPORT_SYMBOL_GPL(virtio_transport_release); + +static int +virtio_transport_recv_connecting(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + int err; + int skerr; + + switch (le16_to_cpu(pkt->hdr.op)) { + case VIRTIO_VSOCK_OP_RESPONSE: + sk->sk_state = SS_CONNECTED; + sk->sk_socket->state = SS_CONNECTED; + vsock_insert_connected(vsk); + sk->sk_state_change(sk); + break; + case VIRTIO_VSOCK_OP_INVALID: + break; + case VIRTIO_VSOCK_OP_RST: + skerr = ECONNRESET; + err = 0; + goto destroy; + default: + skerr = EPROTO; + err = -EINVAL; + goto destroy; + } + return 0; + +destroy: + virtio_transport_reset(vsk, pkt); + sk->sk_state = SS_UNCONNECTED; + sk->sk_err = skerr; + sk->sk_error_report(sk); + return err; +} + +static int +virtio_transport_recv_connected(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct virtio_vsock_sock *vvs = vsk->trans; + int err = 0; + + switch (le16_to_cpu(pkt->hdr.op)) { + case VIRTIO_VSOCK_OP_RW: + pkt->len = le32_to_cpu(pkt->hdr.len); + pkt->off = 0; + + spin_lock_bh(&vvs->rx_lock); + virtio_transport_inc_rx_pkt(vvs, pkt); + list_add_tail(&pkt->list, &vvs->rx_queue); + spin_unlock_bh(&vvs->rx_lock); + + sk->sk_data_ready(sk); + return err; + case VIRTIO_VSOCK_OP_CREDIT_UPDATE: + sk->sk_write_space(sk); + break; + case VIRTIO_VSOCK_OP_SHUTDOWN: + if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) + vsk->peer_shutdown |= RCV_SHUTDOWN; + if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) + vsk->peer_shutdown |= SEND_SHUTDOWN; + if (vsk->peer_shutdown == SHUTDOWN_MASK && + vsock_stream_has_data(vsk) <= 0) + sk->sk_state = SS_DISCONNECTING; + if (le32_to_cpu(pkt->hdr.flags)) + sk->sk_state_change(sk); + break; + case VIRTIO_VSOCK_OP_RST: + virtio_transport_do_close(vsk, true); + break; + default: + err = -EINVAL; + break; + } + + virtio_transport_free_pkt(pkt); + return err; +} + +static void +virtio_transport_recv_disconnecting(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + + if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) + virtio_transport_do_close(vsk, true); +} + +static int +virtio_transport_send_response(struct vsock_sock *vsk, + struct virtio_vsock_pkt *pkt) +{ + struct virtio_vsock_pkt_info info = { + .op = VIRTIO_VSOCK_OP_RESPONSE, + .type = VIRTIO_VSOCK_TYPE_STREAM, + .remote_cid = le64_to_cpu(pkt->hdr.src_cid), + .remote_port = le32_to_cpu(pkt->hdr.src_port), + .reply = true, + .vsk = vsk, + }; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +/* Handle server socket */ +static int +virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct vsock_sock *vchild; + struct sock *child; + + if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { + virtio_transport_reset(vsk, pkt); + return -EINVAL; + } + + if (sk_acceptq_is_full(sk)) { + virtio_transport_reset(vsk, pkt); + return -ENOMEM; + } + + child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, + sk->sk_type, 0); + if (!child) { + virtio_transport_reset(vsk, pkt); + return -ENOMEM; + } + + sk->sk_ack_backlog++; + + lock_sock_nested(child, SINGLE_DEPTH_NESTING); + + child->sk_state = SS_CONNECTED; + + vchild = vsock_sk(child); + vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid), + le32_to_cpu(pkt->hdr.dst_port)); + vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), + le32_to_cpu(pkt->hdr.src_port)); + + vsock_insert_connected(vchild); + vsock_enqueue_accept(sk, child); + virtio_transport_send_response(vchild, pkt); + + release_sock(child); + + sk->sk_data_ready(sk); + return 0; +} + +static bool virtio_transport_space_update(struct sock *sk, + struct virtio_vsock_pkt *pkt) +{ + struct vsock_sock *vsk = vsock_sk(sk); + struct virtio_vsock_sock *vvs = vsk->trans; + bool space_available; + + /* buf_alloc and fwd_cnt is always included in the hdr */ + spin_lock_bh(&vvs->tx_lock); + vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); + vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); + space_available = virtio_transport_has_space(vsk); + spin_unlock_bh(&vvs->tx_lock); + return space_available; +} + +/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex + * lock. + */ +void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) +{ + struct sockaddr_vm src, dst; + struct vsock_sock *vsk; + struct sock *sk; + bool space_available; + + vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid), + le32_to_cpu(pkt->hdr.src_port)); + vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid), + le32_to_cpu(pkt->hdr.dst_port)); + + trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, + dst.svm_cid, dst.svm_port, + le32_to_cpu(pkt->hdr.len), + le16_to_cpu(pkt->hdr.type), + le16_to_cpu(pkt->hdr.op), + le32_to_cpu(pkt->hdr.flags), + le32_to_cpu(pkt->hdr.buf_alloc), + le32_to_cpu(pkt->hdr.fwd_cnt)); + + if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { + (void)virtio_transport_reset_no_sock(pkt); + goto free_pkt; + } + + /* The socket must be in connected or bound table + * otherwise send reset back + */ + sk = vsock_find_connected_socket(&src, &dst); + if (!sk) { + sk = vsock_find_bound_socket(&dst); + if (!sk) { + (void)virtio_transport_reset_no_sock(pkt); + goto free_pkt; + } + } + + vsk = vsock_sk(sk); + + space_available = virtio_transport_space_update(sk, pkt); + + lock_sock(sk); + + /* Update CID in case it has changed after a transport reset event */ + vsk->local_addr.svm_cid = dst.svm_cid; + + if (space_available) + sk->sk_write_space(sk); + + switch (sk->sk_state) { + case VSOCK_SS_LISTEN: + virtio_transport_recv_listen(sk, pkt); + virtio_transport_free_pkt(pkt); + break; + case SS_CONNECTING: + virtio_transport_recv_connecting(sk, pkt); + virtio_transport_free_pkt(pkt); + break; + case SS_CONNECTED: + virtio_transport_recv_connected(sk, pkt); + break; + case SS_DISCONNECTING: + virtio_transport_recv_disconnecting(sk, pkt); + virtio_transport_free_pkt(pkt); + break; + default: + virtio_transport_free_pkt(pkt); + break; + } + release_sock(sk); + + /* Release refcnt obtained when we fetched this socket out of the + * bound or connected list. + */ + sock_put(sk); + return; + +free_pkt: + virtio_transport_free_pkt(pkt); +} +EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); + +void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) +{ + kfree(pkt->buf); + kfree(pkt); +} +EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); + +MODULE_LICENSE("GPL v2"); +MODULE_AUTHOR("Asias He"); +MODULE_DESCRIPTION("common code for virtio vsock"); diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c index 1f3f34b56840..c09efcdf72d2 100644 --- a/net/vmw_vsock/vmci_transport.c +++ b/net/vmw_vsock/vmci_transport.c @@ -1679,6 +1679,8 @@ static void vmci_transport_destruct(struct vsock_sock *vsk) static void vmci_transport_release(struct vsock_sock *vsk) { + vsock_remove_sock(vsk); + if (!vmci_handle_is_invalid(vmci_trans(vsk)->dg_handle)) { vmci_datagram_destroy_handle(vmci_trans(vsk)->dg_handle); vmci_trans(vsk)->dg_handle = VMCI_INVALID_HANDLE; @@ -1770,11 +1772,8 @@ static int vmci_transport_dgram_dequeue(struct vsock_sock *vsk, /* Retrieve the head sk_buff from the socket's receive queue. */ err = 0; skb = skb_recv_datagram(&vsk->sk, flags, noblock, &err); - if (err) - return err; - if (!skb) - return -EAGAIN; + return err; dg = (struct vmci_datagram *)skb->data; if (!dg) @@ -2089,7 +2088,7 @@ static u32 vmci_transport_get_local_cid(void) return vmci_get_context_id(); } -static struct vsock_transport vmci_transport = { +static const struct vsock_transport vmci_transport = { .init = vmci_transport_socket_init, .destruct = vmci_transport_destruct, .release = vmci_transport_release, @@ -2189,7 +2188,7 @@ module_exit(vmci_transport_exit); MODULE_AUTHOR("VMware, Inc."); MODULE_DESCRIPTION("VMCI transport for Virtual Sockets"); -MODULE_VERSION("1.0.3.0-k"); +MODULE_VERSION("1.0.4.0-k"); MODULE_LICENSE("GPL v2"); MODULE_ALIAS("vmware_vsock"); MODULE_ALIAS_NETPROTO(PF_VSOCK); diff --git a/net/wireless/nl80211.c b/net/wireless/nl80211.c index eb25998a0032..858d1e36d8a5 100644 --- a/net/wireless/nl80211.c +++ b/net/wireless/nl80211.c @@ -3650,7 +3650,7 @@ static bool nl80211_put_sta_rate(struct sk_buff *msg, struct rate_info *info, struct nlattr *rate; u32 bitrate; u16 bitrate_compat; - enum nl80211_attrs rate_flg; + enum nl80211_rate_info rate_flg; rate = nla_nest_start(msg, attr); if (!rate) diff --git a/net/wireless/scan.c b/net/wireless/scan.c index 018457e86e60..26ad02f13d38 100644 --- a/net/wireless/scan.c +++ b/net/wireless/scan.c @@ -69,7 +69,7 @@ module_param(bss_entries_limit, int, 0644); MODULE_PARM_DESC(bss_entries_limit, "limit to number of scan BSS entries (per wiphy, default 1000)"); -#define IEEE80211_SCAN_RESULT_EXPIRE (30 * HZ) +#define IEEE80211_SCAN_RESULT_EXPIRE (7 * HZ) static void bss_free(struct cfg80211_internal_bss *bss) { diff --git a/net/xfrm/Kconfig b/net/xfrm/Kconfig index bda1a13628a8..16cdd2c9221f 100644 --- a/net/xfrm/Kconfig +++ b/net/xfrm/Kconfig @@ -20,6 +20,17 @@ config XFRM_USER If unsure, say Y. +config XFRM_USER_COMPAT + tristate "Compatible ABI support" + depends on XFRM_USER && COMPAT_FOR_U64_ALIGNMENT && \ + HAVE_EFFICIENT_UNALIGNED_ACCESS + select WANT_COMPAT_NETLINK_MESSAGES + help + Transformation(XFRM) user configuration interface like IPsec + used by compatible Linux applications. + + If unsure, say N. + config XFRM_SUB_POLICY bool "Transformation sub policy support" depends on XFRM diff --git a/net/xfrm/Makefile b/net/xfrm/Makefile index c0e961983f17..516c78e22e2a 100644 --- a/net/xfrm/Makefile +++ b/net/xfrm/Makefile @@ -8,4 +8,5 @@ obj-$(CONFIG_XFRM) := xfrm_policy.o xfrm_state.o xfrm_hash.o \ obj-$(CONFIG_XFRM_STATISTICS) += xfrm_proc.o obj-$(CONFIG_XFRM_ALGO) += xfrm_algo.o obj-$(CONFIG_XFRM_USER) += xfrm_user.o +obj-$(CONFIG_XFRM_USER_COMPAT) += xfrm_compat.o obj-$(CONFIG_XFRM_IPCOMP) += xfrm_ipcomp.o diff --git a/net/xfrm/xfrm_algo.c b/net/xfrm/xfrm_algo.c index f07224d8b88f..9adfdd711b31 100644 --- a/net/xfrm/xfrm_algo.c +++ b/net/xfrm/xfrm_algo.c @@ -239,7 +239,7 @@ static struct xfrm_algo_desc aalg_list[] = { .uinfo = { .auth = { - .icv_truncbits = 96, + .icv_truncbits = 128, .icv_fullbits = 256, } }, diff --git a/net/xfrm/xfrm_compat.c b/net/xfrm/xfrm_compat.c new file mode 100644 index 000000000000..7158078a71f1 --- /dev/null +++ b/net/xfrm/xfrm_compat.c @@ -0,0 +1,685 @@ +// SPDX-License-Identifier: GPL-2.0 +/* + * XFRM compat layer + * Author: Dmitry Safonov <dima@arista.com> + * Based on code and translator idea by: Florian Westphal <fw@strlen.de> + */ +#include <linux/compat.h> +#include <linux/vmalloc.h> +#include <linux/xfrm.h> +#include <net/xfrm.h> + +struct compat_xfrm_lifetime_cfg { + compat_u64 soft_byte_limit, hard_byte_limit; + compat_u64 soft_packet_limit, hard_packet_limit; + compat_u64 soft_add_expires_seconds, hard_add_expires_seconds; + compat_u64 soft_use_expires_seconds, hard_use_expires_seconds; +}; /* same size on 32bit, but only 4 byte alignment required */ + +struct compat_xfrm_lifetime_cur { + compat_u64 bytes, packets, add_time, use_time; +}; /* same size on 32bit, but only 4 byte alignment required */ + +struct compat_xfrm_userpolicy_info { + struct xfrm_selector sel; + struct compat_xfrm_lifetime_cfg lft; + struct compat_xfrm_lifetime_cur curlft; + __u32 priority, index; + u8 dir, action, flags, share; + /* 4 bytes additional padding on 64bit */ +}; + +struct compat_xfrm_usersa_info { + struct xfrm_selector sel; + struct xfrm_id id; + xfrm_address_t saddr; + struct compat_xfrm_lifetime_cfg lft; + struct compat_xfrm_lifetime_cur curlft; + struct xfrm_stats stats; + __u32 seq, reqid; + u16 family; + u8 mode, replay_window, flags; + /* 4 bytes additional padding on 64bit */ +}; + +struct compat_xfrm_user_acquire { + struct xfrm_id id; + xfrm_address_t saddr; + struct xfrm_selector sel; + struct compat_xfrm_userpolicy_info policy; + /* 4 bytes additional padding on 64bit */ + __u32 aalgos, ealgos, calgos, seq; +}; + +struct compat_xfrm_userspi_info { + struct compat_xfrm_usersa_info info; + /* 4 bytes additional padding on 64bit */ + __u32 min, max; +}; + +struct compat_xfrm_user_expire { + struct compat_xfrm_usersa_info state; + /* 8 bytes additional padding on 64bit */ + u8 hard; +}; + +struct compat_xfrm_user_polexpire { + struct compat_xfrm_userpolicy_info pol; + /* 8 bytes additional padding on 64bit */ + u8 hard; +}; + +#define XMSGSIZE(type) sizeof(struct type) + +static const int compat_msg_min[XFRM_NR_MSGTYPES] = { + [XFRM_MSG_NEWSA - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_usersa_info), + [XFRM_MSG_DELSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_id), + [XFRM_MSG_GETSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_id), + [XFRM_MSG_NEWPOLICY - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_userpolicy_info), + [XFRM_MSG_DELPOLICY - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_GETPOLICY - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_ALLOCSPI - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_userspi_info), + [XFRM_MSG_ACQUIRE - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_user_acquire), + [XFRM_MSG_EXPIRE - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_user_expire), + [XFRM_MSG_UPDPOLICY - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_userpolicy_info), + [XFRM_MSG_UPDSA - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_usersa_info), + [XFRM_MSG_POLEXPIRE - XFRM_MSG_BASE] = XMSGSIZE(compat_xfrm_user_polexpire), + [XFRM_MSG_FLUSHSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_flush), + [XFRM_MSG_FLUSHPOLICY - XFRM_MSG_BASE] = 0, + [XFRM_MSG_NEWAE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_aevent_id), + [XFRM_MSG_GETAE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_aevent_id), + [XFRM_MSG_REPORT - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_report), + [XFRM_MSG_MIGRATE - XFRM_MSG_BASE] = XMSGSIZE(xfrm_userpolicy_id), + [XFRM_MSG_NEWSADINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_GETSADINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_NEWSPDINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_GETSPDINFO - XFRM_MSG_BASE] = sizeof(u32), + [XFRM_MSG_MAPPING - XFRM_MSG_BASE] = XMSGSIZE(xfrm_user_mapping) +}; + +static const struct nla_policy compat_policy[XFRMA_MAX+1] = { + [XFRMA_SA] = { .len = XMSGSIZE(compat_xfrm_usersa_info)}, + [XFRMA_POLICY] = { .len = XMSGSIZE(compat_xfrm_userpolicy_info)}, + [XFRMA_LASTUSED] = { .type = NLA_U64}, + [XFRMA_ALG_AUTH_TRUNC] = { .len = sizeof(struct xfrm_algo_auth)}, + [XFRMA_ALG_AEAD] = { .len = sizeof(struct xfrm_algo_aead) }, + [XFRMA_ALG_AUTH] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ALG_CRYPT] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ALG_COMP] = { .len = sizeof(struct xfrm_algo) }, + [XFRMA_ENCAP] = { .len = sizeof(struct xfrm_encap_tmpl) }, + [XFRMA_TMPL] = { .len = sizeof(struct xfrm_user_tmpl) }, + [XFRMA_SEC_CTX] = { .len = sizeof(struct xfrm_sec_ctx) }, + [XFRMA_LTIME_VAL] = { .len = sizeof(struct xfrm_lifetime_cur) }, + [XFRMA_REPLAY_VAL] = { .len = sizeof(struct xfrm_replay_state) }, + [XFRMA_REPLAY_THRESH] = { .type = NLA_U32 }, + [XFRMA_ETIMER_THRESH] = { .type = NLA_U32 }, + [XFRMA_SRCADDR] = { .len = sizeof(xfrm_address_t) }, + [XFRMA_COADDR] = { .len = sizeof(xfrm_address_t) }, + [XFRMA_POLICY_TYPE] = { .len = sizeof(struct xfrm_userpolicy_type)}, + [XFRMA_MIGRATE] = { .len = sizeof(struct xfrm_user_migrate) }, + [XFRMA_KMADDRESS] = { .len = sizeof(struct xfrm_user_kmaddress) }, + [XFRMA_MARK] = { .len = sizeof(struct xfrm_mark) }, + [XFRMA_TFCPAD] = { .type = NLA_U32 }, + [XFRMA_REPLAY_ESN_VAL] = { .len = sizeof(struct xfrm_replay_state_esn) }, + [XFRMA_SA_EXTRA_FLAGS] = { .type = NLA_U32 }, + [XFRMA_PROTO] = { .type = NLA_U8 }, + [XFRMA_ADDRESS_FILTER] = { .len = sizeof(struct xfrm_address_filter) }, + [XFRMA_OUTPUT_MARK] = { .type = NLA_U32 }, +}; + +static inline void *kvmalloc(size_t size, gfp_t flags) +{ + void *ret; + + ret = kmalloc(size, flags | __GFP_NOWARN); + if (!ret) + ret = __vmalloc(size, flags, PAGE_KERNEL); + return ret; +} + +static inline bool nla_need_padding_for_64bit(struct sk_buff *skb) +{ +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS + if (IS_ALIGNED((unsigned long)skb_tail_pointer(skb), 8)) + return true; +#endif + return false; +} + +static inline int nla_total_size_64bit(int payload) +{ + return NLA_ALIGN(nla_attr_size(payload)) +#ifndef CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS + + NLA_ALIGN(nla_attr_size(0)) +#endif + ; +} + +static inline int nla_align_64bit(struct sk_buff *skb, int padattr) +{ + if (nla_need_padding_for_64bit(skb) && + !nla_reserve(skb, padattr, 0)) + return -EMSGSIZE; + + return 0; +} + +struct nlattr *__nla_reserve_64bit(struct sk_buff *skb, int attrtype, + int attrlen, int padattr) +{ + nla_align_64bit(skb, padattr); + + return __nla_reserve(skb, attrtype, attrlen); +} + +static inline void __nla_put_64bit(struct sk_buff *skb, int attrtype, + int attrlen, const void *data, int padattr) +{ + struct nlattr *nla; + + nla = __nla_reserve_64bit(skb, attrtype, attrlen, padattr); + memcpy(nla_data(nla), data, attrlen); +} + +static inline int nla_put_64bit(struct sk_buff *skb, int attrtype, + int attrlen, const void *data, int padattr) +{ + size_t len; + + if (nla_need_padding_for_64bit(skb)) + len = nla_total_size_64bit(attrlen); + else + len = nla_total_size(attrlen); + if (unlikely(skb_tailroom(skb) < len)) + return -EMSGSIZE; + + __nla_put_64bit(skb, attrtype, attrlen, data, padattr); + return 0; +} + +static struct nlmsghdr *xfrm_nlmsg_put_compat(struct sk_buff *skb, + const struct nlmsghdr *nlh_src, u16 type) +{ + int payload = compat_msg_min[type]; + int src_len = xfrm_msg_min[type]; + struct nlmsghdr *nlh_dst; + + /* Compat messages are shorter or equal to native (+padding) */ + if (WARN_ON_ONCE(src_len < payload)) + return ERR_PTR(-EMSGSIZE); + + nlh_dst = nlmsg_put(skb, nlh_src->nlmsg_pid, nlh_src->nlmsg_seq, + nlh_src->nlmsg_type, payload, nlh_src->nlmsg_flags); + if (!nlh_dst) + return ERR_PTR(-EMSGSIZE); + + memset(nlmsg_data(nlh_dst), 0, payload); + + switch (nlh_src->nlmsg_type) { + /* Compat message has the same layout as native */ + case XFRM_MSG_DELSA: + case XFRM_MSG_DELPOLICY: + case XFRM_MSG_FLUSHSA: + case XFRM_MSG_FLUSHPOLICY: + case XFRM_MSG_NEWAE: + case XFRM_MSG_REPORT: + case XFRM_MSG_MIGRATE: + case XFRM_MSG_NEWSADINFO: + case XFRM_MSG_NEWSPDINFO: + case XFRM_MSG_MAPPING: + WARN_ON_ONCE(src_len != payload); + memcpy(nlmsg_data(nlh_dst), nlmsg_data(nlh_src), src_len); + break; + /* 4 byte alignment for trailing u64 on native, but not on compat */ + case XFRM_MSG_NEWSA: + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_UPDSA: + case XFRM_MSG_UPDPOLICY: + WARN_ON_ONCE(src_len != payload + 4); + memcpy(nlmsg_data(nlh_dst), nlmsg_data(nlh_src), payload); + break; + case XFRM_MSG_EXPIRE: { + const struct xfrm_user_expire *src_ue = nlmsg_data(nlh_src); + struct compat_xfrm_user_expire *dst_ue = nlmsg_data(nlh_dst); + + /* compat_xfrm_user_expire has 4-byte smaller state */ + memcpy(dst_ue, src_ue, sizeof(dst_ue->state)); + dst_ue->hard = src_ue->hard; + break; + } + case XFRM_MSG_ACQUIRE: { + const struct xfrm_user_acquire *src_ua = nlmsg_data(nlh_src); + struct compat_xfrm_user_acquire *dst_ua = nlmsg_data(nlh_dst); + + memcpy(dst_ua, src_ua, offsetof(struct compat_xfrm_user_acquire, aalgos)); + dst_ua->aalgos = src_ua->aalgos; + dst_ua->ealgos = src_ua->ealgos; + dst_ua->calgos = src_ua->calgos; + dst_ua->seq = src_ua->seq; + break; + } + case XFRM_MSG_POLEXPIRE: { + const struct xfrm_user_polexpire *src_upe = nlmsg_data(nlh_src); + struct compat_xfrm_user_polexpire *dst_upe = nlmsg_data(nlh_dst); + + /* compat_xfrm_user_polexpire has 4-byte smaller state */ + memcpy(dst_upe, src_upe, sizeof(dst_upe->pol)); + dst_upe->hard = src_upe->hard; + break; + } + case XFRM_MSG_ALLOCSPI: { + const struct xfrm_userspi_info *src_usi = nlmsg_data(nlh_src); + struct compat_xfrm_userspi_info *dst_usi = nlmsg_data(nlh_dst); + + /* compat_xfrm_user_polexpire has 4-byte smaller state */ + memcpy(dst_usi, src_usi, sizeof(src_usi->info)); + dst_usi->min = src_usi->min; + dst_usi->max = src_usi->max; + break; + } + /* Not being sent by kernel */ + case XFRM_MSG_GETSA: + case XFRM_MSG_GETPOLICY: + case XFRM_MSG_GETAE: + case XFRM_MSG_GETSADINFO: + case XFRM_MSG_GETSPDINFO: + default: + WARN_ONCE(1, "unsupported nlmsg_type %d", nlh_src->nlmsg_type); + return ERR_PTR(-EOPNOTSUPP); + } + + return nlh_dst; +} + +static int xfrm_nla_cpy(struct sk_buff *dst, const struct nlattr *src, int len) +{ + return nla_put(dst, src->nla_type, len, nla_data(src)); +} + +static int xfrm_xlate64_attr(struct sk_buff *dst, const struct nlattr *src) +{ + switch (src->nla_type) { + case XFRMA_PAD: + case XFRMA_OFFLOAD_DEV: + /* Ignore */ + return 0; + case XFRMA_ALG_AUTH: + case XFRMA_ALG_CRYPT: + case XFRMA_ALG_COMP: + case XFRMA_ENCAP: + case XFRMA_TMPL: + return xfrm_nla_cpy(dst, src, nla_len(src)); + case XFRMA_SA: + return xfrm_nla_cpy(dst, src, XMSGSIZE(compat_xfrm_usersa_info)); + case XFRMA_POLICY: + return xfrm_nla_cpy(dst, src, XMSGSIZE(compat_xfrm_userpolicy_info)); + case XFRMA_SEC_CTX: + return xfrm_nla_cpy(dst, src, nla_len(src)); + case XFRMA_LTIME_VAL: + return nla_put_64bit(dst, src->nla_type, nla_len(src), + nla_data(src), XFRMA_PAD); + case XFRMA_REPLAY_VAL: + case XFRMA_REPLAY_THRESH: + case XFRMA_ETIMER_THRESH: + case XFRMA_SRCADDR: + case XFRMA_COADDR: + return xfrm_nla_cpy(dst, src, nla_len(src)); + case XFRMA_LASTUSED: + return nla_put_64bit(dst, src->nla_type, nla_len(src), + nla_data(src), XFRMA_PAD); + case XFRMA_POLICY_TYPE: + case XFRMA_MIGRATE: + case XFRMA_ALG_AEAD: + case XFRMA_KMADDRESS: + case XFRMA_ALG_AUTH_TRUNC: + case XFRMA_MARK: + case XFRMA_TFCPAD: + case XFRMA_REPLAY_ESN_VAL: + case XFRMA_SA_EXTRA_FLAGS: + case XFRMA_PROTO: + case XFRMA_ADDRESS_FILTER: + case XFRMA_OUTPUT_MARK: + return xfrm_nla_cpy(dst, src, nla_len(src)); + default: + BUILD_BUG_ON(XFRMA_MAX != XFRMA_OUTPUT_MARK); + WARN_ONCE(1, "unsupported nla_type %d", src->nla_type); + return -EOPNOTSUPP; + } +} + +/* Take kernel-built (64bit layout) and create 32bit layout for userspace */ +static int xfrm_xlate64(struct sk_buff *dst, const struct nlmsghdr *nlh_src) +{ + u16 type = nlh_src->nlmsg_type - XFRM_MSG_BASE; + const struct nlattr *nla, *attrs; + struct nlmsghdr *nlh_dst; + int len, remaining; + + nlh_dst = xfrm_nlmsg_put_compat(dst, nlh_src, type); + if (IS_ERR(nlh_dst)) + return PTR_ERR(nlh_dst); + + attrs = nlmsg_attrdata(nlh_src, xfrm_msg_min[type]); + len = nlmsg_attrlen(nlh_src, xfrm_msg_min[type]); + + nla_for_each_attr(nla, attrs, len, remaining) { + int err = xfrm_xlate64_attr(dst, nla); + + if (err) + return err; + } + + nlmsg_end(dst, nlh_dst); + + return 0; +} + +static int xfrm_alloc_compat(struct sk_buff *skb, const struct nlmsghdr *nlh_src) +{ + u16 type = nlh_src->nlmsg_type - XFRM_MSG_BASE; + struct sk_buff *new = NULL; + int err; + + if (WARN_ON_ONCE(type >= ARRAY_SIZE(xfrm_msg_min))) + return -EOPNOTSUPP; + + if (skb_shinfo(skb)->frag_list == NULL) { + new = alloc_skb(skb->len + skb_tailroom(skb), GFP_ATOMIC); + if (!new) + return -ENOMEM; + skb_shinfo(skb)->frag_list = new; + } + + err = xfrm_xlate64(skb_shinfo(skb)->frag_list, nlh_src); + if (err) { + if (new) { + kfree_skb(new); + skb_shinfo(skb)->frag_list = NULL; + } + return err; + } + + return 0; +} + +/* Calculates len of translated 64-bit message. */ +static size_t xfrm_user_rcv_calculate_len64(const struct nlmsghdr *src, + struct nlattr *attrs[XFRMA_MAX+1]) +{ + size_t len = nlmsg_len(src); + + switch (src->nlmsg_type) { + case XFRM_MSG_NEWSA: + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_ALLOCSPI: + case XFRM_MSG_ACQUIRE: + case XFRM_MSG_UPDPOLICY: + case XFRM_MSG_UPDSA: + len += 4; + break; + case XFRM_MSG_EXPIRE: + case XFRM_MSG_POLEXPIRE: + len += 8; + break; + default: + break; + } + + if (attrs[XFRMA_SA]) + len += 4; + if (attrs[XFRMA_POLICY]) + len += 4; + + /* XXX: some attrs may need to be realigned + * if !CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS + */ + + return len; +} + +static int xfrm_attr_cpy32(void *dst, size_t *pos, const struct nlattr *src, + size_t size, int copy_len, int payload) +{ + struct nlmsghdr *nlmsg = dst; + struct nlattr *nla; + + if (WARN_ON_ONCE(copy_len > payload)) + copy_len = payload; + + if (size - *pos < nla_attr_size(payload)) + return -ENOBUFS; + + nla = dst + *pos; + + memcpy(nla, src, nla_attr_size(copy_len)); + nla->nla_len = nla_attr_size(payload); + *pos += nla_attr_size(payload); + nlmsg->nlmsg_len += nla->nla_len; + + memset(dst + *pos, 0, payload - copy_len); + *pos += payload - copy_len; + + return 0; +} + +static int xfrm_xlate32_attr(void *dst, const struct nlattr *nla, + size_t *pos, size_t size) +{ + int type = nla_type(nla); + u16 pol_len32, pol_len64; + int err; + + if (type > XFRMA_MAX) { + BUILD_BUG_ON(XFRMA_MAX != XFRMA_OUTPUT_MARK); + return -EOPNOTSUPP; + } + if (nla_len(nla) < compat_policy[type].len) { + return -EOPNOTSUPP; + } + + pol_len32 = compat_policy[type].len; + pol_len64 = xfrma_policy[type].len; + + /* XFRMA_SA and XFRMA_POLICY - need to know how-to translate */ + if (pol_len32 != pol_len64) { + if (nla_len(nla) != compat_policy[type].len) { + return -EOPNOTSUPP; + } + err = xfrm_attr_cpy32(dst, pos, nla, size, pol_len32, pol_len64); + if (err) + return err; + } + + return xfrm_attr_cpy32(dst, pos, nla, size, nla_len(nla), nla_len(nla)); +} + +static int xfrm_xlate32(struct nlmsghdr *dst, const struct nlmsghdr *src, + struct nlattr *attrs[XFRMA_MAX+1], + size_t size, u8 type) +{ + size_t pos; + int i; + + memcpy(dst, src, NLMSG_HDRLEN); + dst->nlmsg_len = NLMSG_HDRLEN + xfrm_msg_min[type]; + memset(nlmsg_data(dst), 0, xfrm_msg_min[type]); + + switch (src->nlmsg_type) { + /* Compat message has the same layout as native */ + case XFRM_MSG_DELSA: + case XFRM_MSG_GETSA: + case XFRM_MSG_DELPOLICY: + case XFRM_MSG_GETPOLICY: + case XFRM_MSG_FLUSHSA: + case XFRM_MSG_FLUSHPOLICY: + case XFRM_MSG_NEWAE: + case XFRM_MSG_GETAE: + case XFRM_MSG_REPORT: + case XFRM_MSG_MIGRATE: + case XFRM_MSG_NEWSADINFO: + case XFRM_MSG_GETSADINFO: + case XFRM_MSG_NEWSPDINFO: + case XFRM_MSG_GETSPDINFO: + case XFRM_MSG_MAPPING: + memcpy(nlmsg_data(dst), nlmsg_data(src), compat_msg_min[type]); + break; + /* 4 byte alignment for trailing u64 on native, but not on compat */ + case XFRM_MSG_NEWSA: + case XFRM_MSG_NEWPOLICY: + case XFRM_MSG_UPDSA: + case XFRM_MSG_UPDPOLICY: + memcpy(nlmsg_data(dst), nlmsg_data(src), compat_msg_min[type]); + break; + case XFRM_MSG_EXPIRE: { + const struct compat_xfrm_user_expire *src_ue = nlmsg_data(src); + struct xfrm_user_expire *dst_ue = nlmsg_data(dst); + + /* compat_xfrm_user_expire has 4-byte smaller state */ + memcpy(dst_ue, src_ue, sizeof(src_ue->state)); + dst_ue->hard = src_ue->hard; + break; + } + case XFRM_MSG_ACQUIRE: { + const struct compat_xfrm_user_acquire *src_ua = nlmsg_data(src); + struct xfrm_user_acquire *dst_ua = nlmsg_data(dst); + + memcpy(dst_ua, src_ua, offsetof(struct compat_xfrm_user_acquire, aalgos)); + dst_ua->aalgos = src_ua->aalgos; + dst_ua->ealgos = src_ua->ealgos; + dst_ua->calgos = src_ua->calgos; + dst_ua->seq = src_ua->seq; + break; + } + case XFRM_MSG_POLEXPIRE: { + const struct compat_xfrm_user_polexpire *src_upe = nlmsg_data(src); + struct xfrm_user_polexpire *dst_upe = nlmsg_data(dst); + + /* compat_xfrm_user_polexpire has 4-byte smaller state */ + memcpy(dst_upe, src_upe, sizeof(src_upe->pol)); + dst_upe->hard = src_upe->hard; + break; + } + case XFRM_MSG_ALLOCSPI: { + const struct compat_xfrm_userspi_info *src_usi = nlmsg_data(src); + struct xfrm_userspi_info *dst_usi = nlmsg_data(dst); + + /* compat_xfrm_user_polexpire has 4-byte smaller state */ + memcpy(dst_usi, src_usi, sizeof(src_usi->info)); + dst_usi->min = src_usi->min; + dst_usi->max = src_usi->max; + break; + } + default: + return -EOPNOTSUPP; + } + pos = dst->nlmsg_len; + + for (i = 1; i < XFRMA_MAX + 1; i++) { + int err; + + if (i == XFRMA_PAD || i == XFRMA_OFFLOAD_DEV) + continue; + + if (!attrs[i]) + continue; + + err = xfrm_xlate32_attr(dst, attrs[i], &pos, size); + if (err) + return err; + } + + return 0; +} + +static struct nlmsghdr *xfrm_user_rcv_msg_compat(const struct nlmsghdr *h32, + int maxtype, const struct nla_policy *policy) +{ + /* netlink_rcv_skb() checks if a message has full (struct nlmsghdr) */ + u16 type = h32->nlmsg_type - XFRM_MSG_BASE; + struct nlattr *attrs[XFRMA_MAX+1]; + struct nlmsghdr *h64; + size_t len; + int err; + + BUILD_BUG_ON(ARRAY_SIZE(xfrm_msg_min) != ARRAY_SIZE(compat_msg_min)); + + if (type >= ARRAY_SIZE(xfrm_msg_min)) + return ERR_PTR(-EINVAL); + + /* Don't call parse: the message might have only nlmsg header */ + if ((h32->nlmsg_type == XFRM_MSG_GETSA || + h32->nlmsg_type == XFRM_MSG_GETPOLICY) && + (h32->nlmsg_flags & NLM_F_DUMP)) + return NULL; + + err = nlmsg_parse(h32, compat_msg_min[type], attrs, + maxtype ? : XFRMA_MAX, policy ? : compat_policy); + if (err < 0) + return ERR_PTR(err); + + len = xfrm_user_rcv_calculate_len64(h32, attrs); + /* The message doesn't need translation */ + if (len == nlmsg_len(h32)) + return NULL; + + len += NLMSG_HDRLEN; + h64 = kvmalloc(len, GFP_KERNEL | __GFP_ZERO); + if (!h64) + return ERR_PTR(-ENOMEM); + + err = xfrm_xlate32(h64, h32, attrs, len, type); + if (err < 0) { + kvfree(h64); + return ERR_PTR(err); + } + + return h64; +} + +static int xfrm_user_policy_compat(u8 **pdata32, int optlen) +{ + struct compat_xfrm_userpolicy_info *p = (void *)*pdata32; + u8 *src_templates, *dst_templates; + u8 *data64; + + if (optlen < sizeof(*p)) + return -EINVAL; + + data64 = kmalloc_track_caller(optlen + 4, GFP_USER | __GFP_NOWARN); + if (!data64) + return -ENOMEM; + + memcpy(data64, *pdata32, sizeof(*p)); + memset(data64 + sizeof(*p), 0, 4); + + src_templates = *pdata32 + sizeof(*p); + dst_templates = data64 + sizeof(*p) + 4; + memcpy(dst_templates, src_templates, optlen - sizeof(*p)); + + kfree(*pdata32); + *pdata32 = data64; + return 0; +} + +static struct xfrm_translator xfrm_translator = { + .owner = THIS_MODULE, + .alloc_compat = xfrm_alloc_compat, + .rcv_msg_compat = xfrm_user_rcv_msg_compat, + .xlate_user_policy_sockptr = xfrm_user_policy_compat, +}; + +static int __init xfrm_compat_init(void) +{ + return xfrm_register_translator(&xfrm_translator); +} + +static void __exit xfrm_compat_exit(void) +{ + xfrm_unregister_translator(&xfrm_translator); +} + +module_init(xfrm_compat_init); +module_exit(xfrm_compat_exit); +MODULE_LICENSE("GPL"); +MODULE_AUTHOR("Dmitry Safonov"); +MODULE_DESCRIPTION("XFRM 32-bit compatibility layer"); diff --git a/net/xfrm/xfrm_output.c b/net/xfrm/xfrm_output.c index 3e45e778404d..6986da7e4ce8 100644 --- a/net/xfrm/xfrm_output.c +++ b/net/xfrm/xfrm_output.c @@ -66,6 +66,9 @@ static int xfrm_output_one(struct sk_buff *skb, int err) goto error_nolock; } + if (x->props.output_mark) + skb->mark = x->props.output_mark; + err = x->outer_mode->output(x, skb); if (err) { XFRM_INC_STATS(net, LINUX_MIB_XFRMOUTSTATEMODEERROR); diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c index e62d4819089c..2c34636218c0 100644 --- a/net/xfrm/xfrm_policy.c +++ b/net/xfrm/xfrm_policy.c @@ -119,7 +119,7 @@ static inline struct dst_entry *__xfrm_dst_lookup(struct net *net, int tos, int oif, const xfrm_address_t *saddr, const xfrm_address_t *daddr, - int family) + int family, u32 mark) { struct xfrm_policy_afinfo *afinfo; struct dst_entry *dst; @@ -128,7 +128,7 @@ static inline struct dst_entry *__xfrm_dst_lookup(struct net *net, if (unlikely(afinfo == NULL)) return ERR_PTR(-EAFNOSUPPORT); - dst = afinfo->dst_lookup(net, tos, oif, saddr, daddr); + dst = afinfo->dst_lookup(net, tos, oif, saddr, daddr, mark); xfrm_policy_put_afinfo(afinfo); @@ -139,7 +139,7 @@ static inline struct dst_entry *xfrm_dst_lookup(struct xfrm_state *x, int tos, int oif, xfrm_address_t *prev_saddr, xfrm_address_t *prev_daddr, - int family) + int family, u32 mark) { struct net *net = xs_net(x); xfrm_address_t *saddr = &x->props.saddr; @@ -155,7 +155,7 @@ static inline struct dst_entry *xfrm_dst_lookup(struct xfrm_state *x, daddr = x->coaddr; } - dst = __xfrm_dst_lookup(net, tos, oif, saddr, daddr, family); + dst = __xfrm_dst_lookup(net, tos, oif, saddr, daddr, family, mark); if (!IS_ERR(dst)) { if (prev_saddr != saddr) @@ -1404,14 +1404,14 @@ int __xfrm_sk_clone_policy(struct sock *sk, const struct sock *osk) static int xfrm_get_saddr(struct net *net, int oif, xfrm_address_t *local, - xfrm_address_t *remote, unsigned short family) + xfrm_address_t *remote, unsigned short family, u32 mark) { int err; struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family); if (unlikely(afinfo == NULL)) return -EINVAL; - err = afinfo->get_saddr(net, oif, local, remote); + err = afinfo->get_saddr(net, oif, local, remote, mark); xfrm_policy_put_afinfo(afinfo); return err; } @@ -1442,7 +1442,7 @@ xfrm_tmpl_resolve_one(struct xfrm_policy *policy, const struct flowi *fl, if (xfrm_addr_any(local, tmpl->encap_family)) { error = xfrm_get_saddr(net, fl->flowi_oif, &tmp, remote, - tmpl->encap_family); + tmpl->encap_family, 0); if (error) goto fail; local = &tmp; @@ -1721,7 +1721,8 @@ static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy, if (xfrm[i]->props.mode != XFRM_MODE_TRANSPORT) { family = xfrm[i]->props.family; dst = xfrm_dst_lookup(xfrm[i], tos, fl->flowi_oif, - &saddr, &daddr, family); + &saddr, &daddr, family, + xfrm[i]->props.output_mark); err = PTR_ERR(dst); if (IS_ERR(dst)) goto put_states; diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index ed05cd7a4ef2..4875b5167858 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -1361,6 +1361,15 @@ out: if (x1->curlft.use_time) xfrm_state_check_expire(x1); + if (x->props.output_mark) { + spin_lock_bh(&net->xfrm.xfrm_state_lock); + + x1->props.output_mark = x->props.output_mark; + + __xfrm_state_bump_genids(x1); + spin_unlock_bh(&net->xfrm.xfrm_state_lock); + } + err = 0; x->km.state = XFRM_STATE_DEAD; __xfrm_state_put(x); @@ -1847,6 +1856,66 @@ bool km_is_alive(const struct km_event *c) } EXPORT_SYMBOL(km_is_alive); +#if IS_ENABLED(CONFIG_XFRM_USER_COMPAT) +static DEFINE_SPINLOCK(xfrm_translator_lock); +static struct xfrm_translator __rcu *xfrm_translator; + +struct xfrm_translator *xfrm_get_translator(void) +{ + struct xfrm_translator *xtr; + + rcu_read_lock(); + xtr = rcu_dereference(xfrm_translator); + if (unlikely(!xtr)) + goto out; + if (!try_module_get(xtr->owner)) + xtr = NULL; +out: + rcu_read_unlock(); + return xtr; +} +EXPORT_SYMBOL_GPL(xfrm_get_translator); + +void xfrm_put_translator(struct xfrm_translator *xtr) +{ + module_put(xtr->owner); +} +EXPORT_SYMBOL_GPL(xfrm_put_translator); + +int xfrm_register_translator(struct xfrm_translator *xtr) +{ + int err = 0; + + spin_lock_bh(&xfrm_translator_lock); + if (unlikely(xfrm_translator != NULL)) + err = -EEXIST; + else + rcu_assign_pointer(xfrm_translator, xtr); + spin_unlock_bh(&xfrm_translator_lock); + + return err; +} +EXPORT_SYMBOL_GPL(xfrm_register_translator); + +int xfrm_unregister_translator(struct xfrm_translator *xtr) +{ + int err = 0; + + spin_lock_bh(&xfrm_translator_lock); + if (likely(xfrm_translator != NULL)) { + if (rcu_access_pointer(xfrm_translator) != xtr) + err = -EINVAL; + else + RCU_INIT_POINTER(xfrm_translator, NULL); + } + spin_unlock_bh(&xfrm_translator_lock); + synchronize_rcu(); + + return err; +} +EXPORT_SYMBOL_GPL(xfrm_unregister_translator); +#endif + int xfrm_user_policy(struct sock *sk, int optname, u8 __user *optval, int optlen) { int err; @@ -1854,11 +1923,6 @@ int xfrm_user_policy(struct sock *sk, int optname, u8 __user *optval, int optlen struct xfrm_mgr *km; struct xfrm_policy *pol = NULL; -#ifdef CONFIG_COMPAT - if (is_compat_task()) - return -EOPNOTSUPP; -#endif - if (!optval && !optlen) { xfrm_sk_policy_insert(sk, XFRM_POLICY_IN, NULL); xfrm_sk_policy_insert(sk, XFRM_POLICY_OUT, NULL); @@ -1877,6 +1941,23 @@ int xfrm_user_policy(struct sock *sk, int optname, u8 __user *optval, int optlen if (copy_from_user(data, optval, optlen)) goto out; + /* Use the 64-bit / untranslated format on Android, even for compat */ + if (!IS_ENABLED(CONFIG_ANDROID) || IS_ENABLED(CONFIG_XFRM_USER_COMPAT)) { + if (is_compat_task()) { + struct xfrm_translator *xtr = xfrm_get_translator(); + + if (!xtr) + return -EOPNOTSUPP; + + err = xtr->xlate_user_policy_sockptr(&data, optlen); + xfrm_put_translator(xtr); + if (err) { + kfree(data); + return err; + } + } + } + err = -EINVAL; rcu_read_lock(); list_for_each_entry_rcu(km, &xfrm_km_list, list) { diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c index 98ea6ebc7317..3fd866867ce9 100644 --- a/net/xfrm/xfrm_user.c +++ b/net/xfrm/xfrm_user.c @@ -605,6 +605,9 @@ static struct xfrm_state *xfrm_state_construct(struct net *net, xfrm_mark_get(attrs, &x->mark); + if (attrs[XFRMA_OUTPUT_MARK]) + x->props.output_mark = nla_get_u32(attrs[XFRMA_OUTPUT_MARK]); + err = __xfrm_init_state(x, false); if (err) goto error; @@ -889,6 +892,11 @@ static int copy_to_user_state_extra(struct xfrm_state *x, &x->replay); if (ret) goto out; + if (x->props.output_mark) { + ret = nla_put_u32(skb, XFRMA_OUTPUT_MARK, x->props.output_mark); + if (ret) + goto out; + } if (x->security) ret = copy_sec_ctx(x->security, skb); out: @@ -900,6 +908,7 @@ static int dump_one_state(struct xfrm_state *x, int count, void *ptr) struct xfrm_dump_info *sp = ptr; struct sk_buff *in_skb = sp->in_skb; struct sk_buff *skb = sp->out_skb; + struct xfrm_translator *xtr; struct xfrm_usersa_info *p; struct nlmsghdr *nlh; int err; @@ -917,6 +926,18 @@ static int dump_one_state(struct xfrm_state *x, int count, void *ptr) return err; } nlmsg_end(skb, nlh); + + xtr = xfrm_get_translator(); + if (xtr) { + err = xtr->alloc_compat(skb, nlh); + + xfrm_put_translator(xtr); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + } + return 0; } @@ -931,7 +952,6 @@ static int xfrm_dump_sa_done(struct netlink_callback *cb) return 0; } -static const struct nla_policy xfrma_policy[XFRMA_MAX+1]; static int xfrm_dump_sa(struct sk_buff *skb, struct netlink_callback *cb) { struct net *net = sock_net(skb->sk); @@ -1008,12 +1028,24 @@ static inline int xfrm_nlmsg_multicast(struct net *net, struct sk_buff *skb, u32 pid, unsigned int group) { struct sock *nlsk = rcu_dereference(net->xfrm.nlsk); + struct xfrm_translator *xtr; if (!nlsk) { kfree_skb(skb); return -EPIPE; } + xtr = xfrm_get_translator(); + if (xtr) { + int err = xtr->alloc_compat(skb, nlmsg_hdr(skb)); + + xfrm_put_translator(xtr); + if (err) { + kfree_skb(skb); + return err; + } + } + return nlmsg_multicast(nlsk, skb, pid, group, GFP_ATOMIC); } @@ -1231,6 +1263,7 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, struct net *net = sock_net(skb->sk); struct xfrm_state *x; struct xfrm_userspi_info *p; + struct xfrm_translator *xtr; struct sk_buff *resp_skb; xfrm_address_t *daddr; int family; @@ -1276,6 +1309,17 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, goto out; } + xtr = xfrm_get_translator(); + if (xtr) { + err = xtr->alloc_compat(skb, nlmsg_hdr(skb)); + + xfrm_put_translator(xtr); + if (err) { + kfree_skb(resp_skb); + goto out; + } + } + err = nlmsg_unicast(net->xfrm.nlsk, resp_skb, NETLINK_CB(skb).portid); out: @@ -1679,6 +1723,7 @@ static int dump_one_policy(struct xfrm_policy *xp, int dir, int count, void *ptr struct xfrm_userpolicy_info *p; struct sk_buff *in_skb = sp->in_skb; struct sk_buff *skb = sp->out_skb; + struct xfrm_translator *xtr; struct nlmsghdr *nlh; int err; @@ -1701,6 +1746,18 @@ static int dump_one_policy(struct xfrm_policy *xp, int dir, int count, void *ptr return err; } nlmsg_end(skb, nlh); + + xtr = xfrm_get_translator(); + if (xtr) { + err = xtr->alloc_compat(skb, nlh); + + xfrm_put_translator(xtr); + if (err) { + nlmsg_cancel(skb, nlh); + return err; + } + } + return 0; } @@ -1747,6 +1804,10 @@ static struct sk_buff *xfrm_policy_netlink(struct sk_buff *in_skb, struct sk_buff *skb; int err; + err = verify_policy_dir(dir); + if (err) + return ERR_PTR(err); + skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); if (!skb) return ERR_PTR(-ENOMEM); @@ -2275,6 +2336,10 @@ static int xfrm_do_migrate(struct sk_buff *skb, struct nlmsghdr *nlh, int n = 0; struct net *net = sock_net(skb->sk); + err = verify_policy_dir(pi->dir); + if (err) + return err; + if (attrs[XFRMA_MIGRATE] == NULL) return -EINVAL; @@ -2390,6 +2455,11 @@ static int xfrm_send_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, { struct net *net = &init_net; struct sk_buff *skb; + int err; + + err = verify_policy_dir(dir); + if (err) + return err; skb = nlmsg_new(xfrm_migrate_msgsize(num_migrate, !!k), GFP_ATOMIC); if (skb == NULL) @@ -2412,7 +2482,7 @@ static int xfrm_send_migrate(const struct xfrm_selector *sel, u8 dir, u8 type, #define XMSGSIZE(type) sizeof(struct type) -static const int xfrm_msg_min[XFRM_NR_MSGTYPES] = { +const int xfrm_msg_min[XFRM_NR_MSGTYPES] = { [XFRM_MSG_NEWSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_info), [XFRM_MSG_DELSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_id), [XFRM_MSG_GETSA - XFRM_MSG_BASE] = XMSGSIZE(xfrm_usersa_id), @@ -2435,10 +2505,11 @@ static const int xfrm_msg_min[XFRM_NR_MSGTYPES] = { [XFRM_MSG_NEWSPDINFO - XFRM_MSG_BASE] = sizeof(u32), [XFRM_MSG_GETSPDINFO - XFRM_MSG_BASE] = sizeof(u32), }; +EXPORT_SYMBOL_GPL(xfrm_msg_min); #undef XMSGSIZE -static const struct nla_policy xfrma_policy[XFRMA_MAX+1] = { +const struct nla_policy xfrma_policy[XFRMA_MAX+1] = { [XFRMA_SA] = { .len = sizeof(struct xfrm_usersa_info)}, [XFRMA_POLICY] = { .len = sizeof(struct xfrm_userpolicy_info)}, [XFRMA_LASTUSED] = { .type = NLA_U64}, @@ -2465,7 +2536,9 @@ static const struct nla_policy xfrma_policy[XFRMA_MAX+1] = { [XFRMA_SA_EXTRA_FLAGS] = { .type = NLA_U32 }, [XFRMA_PROTO] = { .type = NLA_U8 }, [XFRMA_ADDRESS_FILTER] = { .len = sizeof(struct xfrm_address_filter) }, + [XFRMA_OUTPUT_MARK] = { .len = NLA_U32 }, }; +EXPORT_SYMBOL_GPL(xfrma_policy); static const struct nla_policy xfrma_spd_policy[XFRMA_SPD_MAX+1] = { [XFRMA_SPD_IPV4_HTHRESH] = { .len = sizeof(struct xfrmu_spdhthresh) }, @@ -2514,13 +2587,9 @@ static int xfrm_user_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh) struct net *net = sock_net(skb->sk); struct nlattr *attrs[XFRMA_MAX+1]; const struct xfrm_link *link; + struct nlmsghdr *nlh64 = NULL; int type, err; -#ifdef CONFIG_COMPAT - if (is_compat_task()) - return -EOPNOTSUPP; -#endif - type = nlh->nlmsg_type; if (type > XFRM_MSG_MAX) return -EINVAL; @@ -2532,32 +2601,58 @@ static int xfrm_user_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh) if (!netlink_net_capable(skb, CAP_NET_ADMIN)) return -EPERM; + /* Use the 64-bit / untranslated format on Android, even for compat */ + if (!IS_ENABLED(CONFIG_ANDROID) || IS_ENABLED(CONFIG_XFRM_USER_COMPAT)) { + if (is_compat_task()) { + struct xfrm_translator *xtr = xfrm_get_translator(); + + if (!xtr) + return -EOPNOTSUPP; + + nlh64 = xtr->rcv_msg_compat(nlh, link->nla_max, + link->nla_pol); + xfrm_put_translator(xtr); + if (IS_ERR(nlh64)) + return PTR_ERR(nlh64); + if (nlh64) + nlh = nlh64; + } + } + if ((type == (XFRM_MSG_GETSA - XFRM_MSG_BASE) || type == (XFRM_MSG_GETPOLICY - XFRM_MSG_BASE)) && (nlh->nlmsg_flags & NLM_F_DUMP)) { - if (link->dump == NULL) - return -EINVAL; + struct netlink_dump_control c = { + .start = link->start, + .dump = link->dump, + .done = link->done, + }; - { - struct netlink_dump_control c = { - .start = link->start, - .dump = link->dump, - .done = link->done, - }; - return netlink_dump_start(net->xfrm.nlsk, skb, nlh, &c); + if (link->dump == NULL) { + err = -EINVAL; + goto err; } + + err = netlink_dump_start(net->xfrm.nlsk, skb, nlh, &c); + goto err; } err = nlmsg_parse(nlh, xfrm_msg_min[type], attrs, link->nla_max ? : XFRMA_MAX, link->nla_pol ? : xfrma_policy); if (err < 0) - return err; + goto err; - if (link->doit == NULL) - return -EINVAL; + if (link->doit == NULL) { + err = -EINVAL; + goto err; + } + + err = link->doit(skb, nlh, attrs); - return link->doit(skb, nlh, attrs); +err: + kvfree(nlh64); + return err; } static void xfrm_netlink_rcv(struct sk_buff *skb) @@ -2684,6 +2779,8 @@ static inline size_t xfrm_sa_len(struct xfrm_state *x) l += nla_total_size(sizeof(*x->coaddr)); if (x->props.extra_flags) l += nla_total_size(sizeof(x->props.extra_flags)); + if (x->props.output_mark) + l += nla_total_size(sizeof(x->props.output_mark)); /* Must count x->lastused as it may become non-zero behind our back. */ l += nla_total_size(sizeof(u64)); @@ -3047,6 +3144,11 @@ out_free_skb: static int xfrm_send_policy_notify(struct xfrm_policy *xp, int dir, const struct km_event *c) { + int err; + + err = verify_policy_dir(dir); + if (err) + return err; switch (c->event) { case XFRM_MSG_NEWPOLICY: |