diff options
Diffstat (limited to 'net/ipv4')
-rw-r--r-- | net/ipv4/tcp.c | 5 | ||||
-rw-r--r-- | net/ipv4/tcp_bpf.c | 27 | ||||
-rw-r--r-- | net/ipv4/tcp_ipv4.c | 45 | ||||
-rw-r--r-- | net/ipv4/udp.c | 3 | ||||
-rw-r--r-- | net/ipv4/udp_bpf.c | 1 |
5 files changed, 50 insertions, 31 deletions
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c index e8b48df73c85..f5c336f8b0c8 100644 --- a/net/ipv4/tcp.c +++ b/net/ipv4/tcp.c @@ -486,10 +486,7 @@ static bool tcp_stream_is_readable(struct sock *sk, int target) { if (tcp_epollin_ready(sk, target)) return true; - - if (sk->sk_prot->stream_memory_read) - return sk->sk_prot->stream_memory_read(sk); - return false; + return sk_is_readable(sk); } /* diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c index d3e9386b493e..5f4d6f45d87f 100644 --- a/net/ipv4/tcp_bpf.c +++ b/net/ipv4/tcp_bpf.c @@ -150,19 +150,6 @@ int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg, EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir); #ifdef CONFIG_BPF_SYSCALL -static bool tcp_bpf_stream_read(const struct sock *sk) -{ - struct sk_psock *psock; - bool empty = true; - - rcu_read_lock(); - psock = sk_psock(sk); - if (likely(psock)) - empty = list_empty(&psock->ingress_msg); - rcu_read_unlock(); - return !empty; -} - static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo) { @@ -232,6 +219,7 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock, bool cork = false, enospc = sk_msg_full(msg); struct sock *sk_redir; u32 tosend, delta = 0; + u32 eval = __SK_NONE; int ret; more_data: @@ -275,13 +263,24 @@ more_data: case __SK_REDIRECT: sk_redir = psock->sk_redir; sk_msg_apply_bytes(psock, tosend); + if (!psock->apply_bytes) { + /* Clean up before releasing the sock lock. */ + eval = psock->eval; + psock->eval = __SK_NONE; + psock->sk_redir = NULL; + } if (psock->cork) { cork = true; psock->cork = NULL; } sk_msg_return(sk, msg, tosend); release_sock(sk); + ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags); + + if (eval == __SK_REDIRECT) + sock_put(sk_redir); + lock_sock(sk); if (unlikely(ret < 0)) { int free = sk_msg_free_nocharge(sk, msg); @@ -479,7 +478,7 @@ static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS], prot[TCP_BPF_BASE].unhash = sock_map_unhash; prot[TCP_BPF_BASE].close = sock_map_close; prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg; - prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read; + prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable; prot[TCP_BPF_TX] = prot[TCP_BPF_BASE]; prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg; diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c index 2e62e0d6373a..5b8ce65dfc06 100644 --- a/net/ipv4/tcp_ipv4.c +++ b/net/ipv4/tcp_ipv4.c @@ -1037,6 +1037,20 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req) DEFINE_STATIC_KEY_FALSE(tcp_md5_needed); EXPORT_SYMBOL(tcp_md5_needed); +static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new) +{ + if (!old) + return true; + + /* l3index always overrides non-l3index */ + if (old->l3index && new->l3index == 0) + return false; + if (old->l3index == 0 && new->l3index) + return true; + + return old->prefixlen < new->prefixlen; +} + /* Find the Key structure for an address. */ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index, const union tcp_md5_addr *addr, @@ -1059,7 +1073,7 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index, lockdep_sock_is_held(sk)) { if (key->family != family) continue; - if (key->l3index && key->l3index != l3index) + if (key->flags & TCP_MD5SIG_FLAG_IFINDEX && key->l3index != l3index) continue; if (family == AF_INET) { mask = inet_make_mask(key->prefixlen); @@ -1074,8 +1088,7 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index, match = false; } - if (match && (!best_match || - key->prefixlen > best_match->prefixlen)) + if (match && better_md5_match(best_match, key)) best_match = key; } return best_match; @@ -1085,7 +1098,7 @@ EXPORT_SYMBOL(__tcp_md5_do_lookup); static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk, const union tcp_md5_addr *addr, int family, u8 prefixlen, - int l3index) + int l3index, u8 flags) { const struct tcp_sock *tp = tcp_sk(sk); struct tcp_md5sig_key *key; @@ -1105,7 +1118,9 @@ static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk, lockdep_sock_is_held(sk)) { if (key->family != family) continue; - if (key->l3index && key->l3index != l3index) + if ((key->flags & TCP_MD5SIG_FLAG_IFINDEX) != (flags & TCP_MD5SIG_FLAG_IFINDEX)) + continue; + if (key->l3index != l3index) continue; if (!memcmp(&key->addr, addr, size) && key->prefixlen == prefixlen) @@ -1129,7 +1144,7 @@ EXPORT_SYMBOL(tcp_v4_md5_lookup); /* This can be called on a newly created socket, from other files */ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, - int family, u8 prefixlen, int l3index, + int family, u8 prefixlen, int l3index, u8 flags, const u8 *newkey, u8 newkeylen, gfp_t gfp) { /* Add Key to the list */ @@ -1137,7 +1152,7 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, struct tcp_sock *tp = tcp_sk(sk); struct tcp_md5sig_info *md5sig; - key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index); + key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index, flags); if (key) { /* Pre-existing entry - just update that one. * Note that the key might be used concurrently. @@ -1182,6 +1197,7 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, key->family = family; key->prefixlen = prefixlen; key->l3index = l3index; + key->flags = flags; memcpy(&key->addr, addr, (family == AF_INET6) ? sizeof(struct in6_addr) : sizeof(struct in_addr)); @@ -1191,11 +1207,11 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, EXPORT_SYMBOL(tcp_md5_do_add); int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family, - u8 prefixlen, int l3index) + u8 prefixlen, int l3index, u8 flags) { struct tcp_md5sig_key *key; - key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index); + key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index, flags); if (!key) return -ENOENT; hlist_del_rcu(&key->node); @@ -1229,6 +1245,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, const union tcp_md5_addr *addr; u8 prefixlen = 32; int l3index = 0; + u8 flags; if (optlen < sizeof(cmd)) return -EINVAL; @@ -1239,6 +1256,8 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, if (sin->sin_family != AF_INET) return -EINVAL; + flags = cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX; + if (optname == TCP_MD5SIG_EXT && cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) { prefixlen = cmd.tcpm_prefixlen; @@ -1246,7 +1265,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, return -EINVAL; } - if (optname == TCP_MD5SIG_EXT && + if (optname == TCP_MD5SIG_EXT && cmd.tcpm_ifindex && cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX) { struct net_device *dev; @@ -1267,12 +1286,12 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname, addr = (union tcp_md5_addr *)&sin->sin_addr.s_addr; if (!cmd.tcpm_keylen) - return tcp_md5_do_del(sk, addr, AF_INET, prefixlen, l3index); + return tcp_md5_do_del(sk, addr, AF_INET, prefixlen, l3index, flags); if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN) return -EINVAL; - return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, + return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags, cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL); } @@ -1596,7 +1615,7 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb, * memory, then we end up not copying the key * across. Shucks. */ - tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, + tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags, key->key, key->keylen, GFP_ATOMIC); sk_nocaps_add(newsk, NETIF_F_GSO_MASK); } diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 8536b2a7210b..2fffcf2b54f3 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -2867,6 +2867,9 @@ __poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait) !(sk->sk_shutdown & RCV_SHUTDOWN) && first_packet_length(sk) == -1) mask &= ~(EPOLLIN | EPOLLRDNORM); + /* psock ingress_msg queue should not contain any bad checksum frames */ + if (sk_is_readable(sk)) + mask |= EPOLLIN | EPOLLRDNORM; return mask; } diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c index 7a1d5f473878..bbe6569c9ad3 100644 --- a/net/ipv4/udp_bpf.c +++ b/net/ipv4/udp_bpf.c @@ -114,6 +114,7 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) *prot = *base; prot->close = sock_map_close; prot->recvmsg = udp_bpf_recvmsg; + prot->sock_is_readable = sk_msg_is_readable; } static void udp_bpf_check_v6_needs_rebuild(struct proto *ops) |