From 8a05661b2b266b6dc45af255b3037b00ef31d85d Mon Sep 17 00:00:00 2001 From: Paolo Abeni Date: Mon, 29 Jun 2020 22:26:25 +0200 Subject: mptcp: close poll() races mptcp_poll always return POLLOUT for unblocking connect(), ensure that the socket is a suitable state. The MPTCP_DATA_READY bit is never cleared on accept: ensure we don't leave mptcp_accept() with an empty accept queue and such bit set. Signed-off-by: Paolo Abeni Signed-off-by: Davide Caratti Reviewed-by: Mat Martineau Signed-off-by: David S. Miller --- net/mptcp/protocol.c | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) (limited to 'net/mptcp') diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c index f2b2bd37e371..28ec26d97f96 100644 --- a/net/mptcp/protocol.c +++ b/net/mptcp/protocol.c @@ -1841,6 +1841,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, if (!ssock) goto unlock_fail; + clear_bit(MPTCP_DATA_READY, &msk->flags); sock_hold(ssock->sk); release_sock(sock->sk); @@ -1861,6 +1862,8 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock, } } + if (inet_csk_listen_poll(ssock->sk)) + set_bit(MPTCP_DATA_READY, &msk->flags); sock_put(ssock->sk); return err; @@ -1869,21 +1872,33 @@ unlock_fail: return -EINVAL; } +static __poll_t mptcp_check_readable(struct mptcp_sock *msk) +{ + return test_bit(MPTCP_DATA_READY, &msk->flags) ? EPOLLIN | EPOLLRDNORM : + 0; +} + static __poll_t mptcp_poll(struct file *file, struct socket *sock, struct poll_table_struct *wait) { struct sock *sk = sock->sk; struct mptcp_sock *msk; __poll_t mask = 0; + int state; msk = mptcp_sk(sk); sock_poll_wait(file, sock, wait); - if (test_bit(MPTCP_DATA_READY, &msk->flags)) - mask = EPOLLIN | EPOLLRDNORM; - if (sk_stream_is_writeable(sk) && - test_bit(MPTCP_SEND_SPACE, &msk->flags)) - mask |= EPOLLOUT | EPOLLWRNORM; + state = inet_sk_state_load(sk); + if (state == TCP_LISTEN) + return mptcp_check_readable(msk); + + if (state != TCP_SYN_SENT && state != TCP_SYN_RECV) { + mask |= mptcp_check_readable(msk); + if (sk_stream_is_writeable(sk) && + test_bit(MPTCP_SEND_SPACE, &msk->flags)) + mask |= EPOLLOUT | EPOLLWRNORM; + } if (sk->sk_shutdown & RCV_SHUTDOWN) mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP; -- cgit v1.2.3