summaryrefslogtreecommitdiffstats
path: root/net/tls/tls_sw.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r--net/tls/tls_sw.c80
1 files changed, 35 insertions, 45 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index bd4486819e64..0fc24a5ce208 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -1283,7 +1283,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
static int
tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
- long timeo)
+ bool released, long timeo)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1297,7 +1297,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
return sock_error(sk);
if (!skb_queue_empty(&sk->sk_receive_queue)) {
- __strp_unpause(&ctx->strp);
+ tls_strp_check_rcv(&ctx->strp);
if (tls_strp_msg_ready(ctx))
break;
}
@@ -1311,6 +1311,7 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
if (nonblock || !timeo)
return -EAGAIN;
+ released = true;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
sk_wait_event(sk, &timeo,
@@ -1325,6 +1326,8 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
return sock_intr_errno(timeo);
}
+ tls_strp_msg_load(&ctx->strp, released);
+
return 1;
}
@@ -1570,7 +1573,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
clear_skb = NULL;
if (unlikely(darg->async)) {
- err = tls_strp_msg_hold(sk, skb, &ctx->async_hold);
+ err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
if (err)
__skb_queue_tail(&ctx->async_hold, darg->skb);
return err;
@@ -1734,9 +1737,7 @@ static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
{
- consume_skb(ctx->recv_pkt);
- ctx->recv_pkt = NULL;
- __strp_unpause(&ctx->strp);
+ tls_strp_msg_done(&ctx->strp);
}
/* This function traverses the rx_list in tls receive context to copies the
@@ -1823,7 +1824,7 @@ out:
return copied ? : err;
}
-static void
+static bool
tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
size_t len_left, size_t decrypted, ssize_t done,
size_t *flushed_at)
@@ -1831,14 +1832,14 @@ tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
size_t max_rec;
if (len_left <= decrypted)
- return;
+ return false;
max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
- return;
+ return false;
*flushed_at = done;
- sk_flush_backlog(sk);
+ return sk_flush_backlog(sk);
}
static long tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
@@ -1916,6 +1917,7 @@ int tls_sw_recvmsg(struct sock *sk,
long timeo;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK;
+ bool released = true;
bool bpf_strp_enabled;
bool zc_capable;
@@ -1952,7 +1954,8 @@ int tls_sw_recvmsg(struct sock *sk,
struct tls_decrypt_arg darg;
int to_decrypt, chunk;
- err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
+ err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, released,
+ timeo);
if (err <= 0) {
if (psock) {
chunk = sk_msg_recvmsg(sk, psock, msg, len,
@@ -1968,8 +1971,8 @@ int tls_sw_recvmsg(struct sock *sk,
memset(&darg.inargs, 0, sizeof(darg.inargs));
- rxm = strp_msg(ctx->recv_pkt);
- tlm = tls_msg(ctx->recv_pkt);
+ rxm = strp_msg(tls_strp_msg(ctx));
+ tlm = tls_msg(tls_strp_msg(ctx));
to_decrypt = rxm->full_len - prot->overhead_size;
@@ -2008,8 +2011,9 @@ put_on_rx_list_err:
}
/* periodically flush backlog, and feed strparser */
- tls_read_flush_backlog(sk, prot, len, to_decrypt,
- decrypted + copied, &flushed_at);
+ released = tls_read_flush_backlog(sk, prot, len, to_decrypt,
+ decrypted + copied,
+ &flushed_at);
/* TLS 1.3 may have updated the length by more than overhead */
rxm = strp_msg(darg.skb);
@@ -2020,7 +2024,7 @@ put_on_rx_list_err:
bool partially_consumed = chunk > len;
struct sk_buff *skb = darg.skb;
- DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->recv_pkt);
+ DEBUG_NET_WARN_ON_ONCE(darg.skb == tls_strp_msg(ctx));
if (async) {
/* TLS 1.2-only, to_decrypt must be text len */
@@ -2034,6 +2038,7 @@ put_on_rx_list:
}
if (bpf_strp_enabled) {
+ released = true;
err = sk_psock_tls_strp_read(psock, skb);
if (err != __SK_PASS) {
rxm->offset = rxm->offset + rxm->full_len;
@@ -2140,7 +2145,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct tls_decrypt_arg darg;
err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
- timeo);
+ true, timeo);
if (err <= 0)
goto splice_read_end;
@@ -2204,19 +2209,17 @@ bool tls_sw_sock_is_readable(struct sock *sk)
!skb_queue_empty(&ctx->rx_list);
}
-static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
+int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
{
struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
- struct strp_msg *rxm = strp_msg(skb);
- struct tls_msg *tlm = tls_msg(skb);
size_t cipher_overhead;
size_t data_len = 0;
int ret;
/* Verify that we have a full TLS header, or wait for more data */
- if (rxm->offset + prot->prepend_size > skb->len)
+ if (strp->stm.offset + prot->prepend_size > skb->len)
return 0;
/* Sanity-check size of on-stack buffer. */
@@ -2226,11 +2229,11 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
}
/* Linearize header to local buffer */
- ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
+ ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size);
if (ret < 0)
goto read_failure;
- tlm->control = header[0];
+ strp->mark = header[0];
data_len = ((header[4] & 0xFF) | (header[3] << 8));
@@ -2257,7 +2260,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
}
tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
- TCP_SKB_CB(skb)->seq + rxm->offset);
+ TCP_SKB_CB(skb)->seq + strp->stm.offset);
return data_len + TLS_HEADER_SIZE;
read_failure:
@@ -2266,14 +2269,11 @@ read_failure:
return ret;
}
-static void tls_queue(struct strparser *strp, struct sk_buff *skb)
+void tls_rx_msg_ready(struct tls_strparser *strp)
{
- struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
- struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-
- ctx->recv_pkt = skb;
- strp_pause(strp);
+ struct tls_sw_context_rx *ctx;
+ ctx = container_of(strp, struct tls_sw_context_rx, strp);
ctx->saved_data_ready(strp->sk);
}
@@ -2283,7 +2283,7 @@ static void tls_data_ready(struct sock *sk)
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock;
- strp_data_ready(&ctx->strp);
+ tls_strp_data_ready(&ctx->strp);
psock = sk_psock_get(sk);
if (psock) {
@@ -2359,13 +2359,11 @@ void tls_sw_release_resources_rx(struct sock *sk)
kfree(tls_ctx->rx.iv);
if (ctx->aead_recv) {
- kfree_skb(ctx->recv_pkt);
- ctx->recv_pkt = NULL;
__skb_queue_purge(&ctx->rx_list);
crypto_free_aead(ctx->aead_recv);
- strp_stop(&ctx->strp);
+ tls_strp_stop(&ctx->strp);
/* If tls_sw_strparser_arm() was not called (cleanup paths)
- * we still want to strp_stop(), but sk->sk_data_ready was
+ * we still want to tls_strp_stop(), but sk->sk_data_ready was
* never swapped.
*/
if (ctx->saved_data_ready) {
@@ -2380,7 +2378,7 @@ void tls_sw_strparser_done(struct tls_context *tls_ctx)
{
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- strp_done(&ctx->strp);
+ tls_strp_done(&ctx->strp);
}
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
@@ -2453,8 +2451,6 @@ void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
rx_ctx->saved_data_ready = sk->sk_data_ready;
sk->sk_data_ready = tls_data_ready;
write_unlock_bh(&sk->sk_callback_lock);
-
- strp_check_rcv(&rx_ctx->strp);
}
void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
@@ -2474,7 +2470,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
struct tls_sw_context_rx *sw_ctx_rx = NULL;
struct cipher_context *cctx;
struct crypto_aead **aead;
- struct strp_callbacks cb;
u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
struct crypto_tfm *tfm;
char *iv, *rec_seq, *key, *salt, *cipher_name;
@@ -2708,12 +2703,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
crypto_info->version != TLS_1_3_VERSION &&
!!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
- /* Set up strparser */
- memset(&cb, 0, sizeof(cb));
- cb.rcv_msg = tls_queue;
- cb.parse_msg = tls_read_size;
-
- strp_init(&sw_ctx_rx->strp, sk, &cb);
+ tls_strp_init(&sw_ctx_rx->strp, sk);
}
goto out;