summaryrefslogtreecommitdiffstats
path: root/net/tls/tls_sw.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r--net/tls/tls_sw.c144
1 files changed, 95 insertions, 49 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 7453f5ae0819..83d67df33f0c 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -53,7 +53,6 @@ static int tls_do_decryption(struct sock *sk,
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- struct strp_msg *rxm = strp_msg(skb);
struct aead_request *aead_req;
int ret;
@@ -71,18 +70,6 @@ static int tls_do_decryption(struct sock *sk,
ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
- if (ret < 0)
- goto out;
-
- rxm->offset += tls_ctx->rx.prepend_size;
- rxm->full_len -= tls_ctx->rx.overhead_size;
- tls_advance_record_sn(sk, &tls_ctx->rx);
-
- ctx->decrypted = true;
-
- ctx->saved_data_ready(sk);
-
-out:
aead_request_free(aead_req);
return ret;
}
@@ -324,7 +311,12 @@ static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
}
}
+ /* Mark the end in the last sg entry if newly added */
+ if (num_elem > *pages_used)
+ sg_mark_end(&to[num_elem - 1]);
out:
+ if (rc)
+ iov_iter_revert(from, size - *size_used);
*size_used = size;
*pages_used = num_elem;
@@ -373,6 +365,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
int record_room;
bool full_record;
int orig_size;
+ bool is_kvec = msg->msg_iter.type & ITER_KVEC;
if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
return -ENOTSUPP;
@@ -421,8 +414,7 @@ alloc_encrypted:
try_to_copy -= required_size - ctx->sg_encrypted_size;
full_record = true;
}
-
- if (full_record || eor) {
+ if (!is_kvec && (full_record || eor)) {
ret = zerocopy_from_iter(sk, &msg->msg_iter,
try_to_copy, &ctx->sg_plaintext_num_elem,
&ctx->sg_plaintext_size,
@@ -434,15 +426,11 @@ alloc_encrypted:
copied += try_to_copy;
ret = tls_push_record(sk, msg->msg_flags, record_type);
- if (!ret)
- continue;
- if (ret == -EAGAIN)
+ if (ret)
goto send_end;
+ continue;
- copied -= try_to_copy;
fallback_to_reg_send:
- iov_iter_revert(&msg->msg_iter,
- ctx->sg_plaintext_size - orig_size);
trim_sg(sk, ctx->sg_plaintext_data,
&ctx->sg_plaintext_num_elem,
&ctx->sg_plaintext_size,
@@ -642,6 +630,9 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
return NULL;
}
+ if (sk->sk_shutdown & RCV_SHUTDOWN)
+ return NULL;
+
if (sock_flag(sk, SOCK_DONE))
return NULL;
@@ -666,8 +657,38 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
return skb;
}
-static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
- struct scatterlist *sgout)
+static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
+ struct scatterlist *sgout, bool *zc)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct strp_msg *rxm = strp_msg(skb);
+ int err = 0;
+
+#ifdef CONFIG_TLS_DEVICE
+ err = tls_device_decrypted(sk, skb);
+ if (err < 0)
+ return err;
+#endif
+ if (!ctx->decrypted) {
+ err = decrypt_skb(sk, skb, sgout);
+ if (err < 0)
+ return err;
+ } else {
+ *zc = false;
+ }
+
+ rxm->offset += tls_ctx->rx.prepend_size;
+ rxm->full_len -= tls_ctx->rx.overhead_size;
+ tls_advance_record_sn(sk, &tls_ctx->rx);
+ ctx->decrypted = true;
+ ctx->saved_data_ready(sk);
+
+ return err;
+}
+
+int decrypt_skb(struct sock *sk, struct sk_buff *skb,
+ struct scatterlist *sgout)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -697,6 +718,10 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
nsg = skb_to_sgvec(skb, &sgin[1],
rxm->offset + tls_ctx->rx.prepend_size,
rxm->full_len - tls_ctx->rx.prepend_size);
+ if (nsg < 0) {
+ ret = nsg;
+ goto out;
+ }
tls_make_aad(ctx->rx_aad_ciphertext,
rxm->full_len - tls_ctx->rx.overhead_size,
@@ -708,6 +733,7 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
rxm->full_len - tls_ctx->rx.overhead_size,
skb, sk->sk_allocation);
+out:
if (sgin != &sgin_arr[0])
kfree(sgin);
@@ -752,6 +778,7 @@ int tls_sw_recvmsg(struct sock *sk,
bool cmsg = false;
int target, err = 0;
long timeo;
+ bool is_kvec = msg->msg_iter.type & ITER_KVEC;
flags |= nonblock;
@@ -795,7 +822,7 @@ int tls_sw_recvmsg(struct sock *sk,
page_count = iov_iter_npages(&msg->msg_iter,
MAX_SKB_FRAGS);
to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
- if (to_copy <= len && page_count < MAX_SKB_FRAGS &&
+ if (!is_kvec && to_copy <= len && page_count < MAX_SKB_FRAGS &&
likely(!(flags & MSG_PEEK))) {
struct scatterlist sgin[MAX_SKB_FRAGS + 1];
int pages = 0;
@@ -812,7 +839,7 @@ int tls_sw_recvmsg(struct sock *sk,
if (err < 0)
goto fallback_to_reg_recv;
- err = decrypt_skb(sk, skb, sgin);
+ err = decrypt_skb_update(sk, skb, sgin, &zc);
for (; pages > 0; pages--)
put_page(sg_page(&sgin[pages]));
if (err < 0) {
@@ -821,7 +848,7 @@ int tls_sw_recvmsg(struct sock *sk,
}
} else {
fallback_to_reg_recv:
- err = decrypt_skb(sk, skb, NULL);
+ err = decrypt_skb_update(sk, skb, NULL, &zc);
if (err < 0) {
tls_err_abort(sk, EBADMSG);
goto recv_end;
@@ -876,6 +903,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
int err = 0;
long timeo;
int chunk;
+ bool zc;
lock_sock(sk);
@@ -892,7 +920,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
}
if (!ctx->decrypted) {
- err = decrypt_skb(sk, skb, NULL);
+ err = decrypt_skb_update(sk, skb, NULL, &zc);
if (err < 0) {
tls_err_abort(sk, EBADMSG);
@@ -981,6 +1009,10 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
goto read_failure;
}
+#ifdef CONFIG_TLS_DEVICE
+ handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
+ *(u64*)tls_ctx->rx.rec_seq);
+#endif
return data_len + TLS_HEADER_SIZE;
read_failure:
@@ -999,7 +1031,7 @@ static void tls_queue(struct strparser *strp, struct sk_buff *skb)
ctx->recv_pkt = skb;
strp_pause(strp);
- strp->sk->sk_state_change(strp->sk);
+ ctx->saved_data_ready(strp->sk);
}
static void tls_data_ready(struct sock *sk)
@@ -1015,23 +1047,20 @@ void tls_sw_free_resources_tx(struct sock *sk)
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
- if (ctx->aead_send)
- crypto_free_aead(ctx->aead_send);
+ crypto_free_aead(ctx->aead_send);
tls_free_both_sg(sk);
kfree(ctx);
}
-void tls_sw_free_resources_rx(struct sock *sk)
+void tls_sw_release_resources_rx(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
if (ctx->aead_recv) {
- if (ctx->recv_pkt) {
- kfree_skb(ctx->recv_pkt);
- ctx->recv_pkt = NULL;
- }
+ kfree_skb(ctx->recv_pkt);
+ ctx->recv_pkt = NULL;
crypto_free_aead(ctx->aead_recv);
strp_stop(&ctx->strp);
write_lock_bh(&sk->sk_callback_lock);
@@ -1041,6 +1070,14 @@ void tls_sw_free_resources_rx(struct sock *sk)
strp_done(&ctx->strp);
lock_sock(sk);
}
+}
+
+void tls_sw_free_resources_rx(struct sock *sk)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+
+ tls_sw_release_resources_rx(sk);
kfree(ctx);
}
@@ -1065,28 +1102,38 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
}
if (tx) {
- sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
- if (!sw_ctx_tx) {
- rc = -ENOMEM;
- goto out;
+ if (!ctx->priv_ctx_tx) {
+ sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
+ if (!sw_ctx_tx) {
+ rc = -ENOMEM;
+ goto out;
+ }
+ ctx->priv_ctx_tx = sw_ctx_tx;
+ } else {
+ sw_ctx_tx =
+ (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
}
- crypto_init_wait(&sw_ctx_tx->async_wait);
- ctx->priv_ctx_tx = sw_ctx_tx;
} else {
- sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
- if (!sw_ctx_rx) {
- rc = -ENOMEM;
- goto out;
+ if (!ctx->priv_ctx_rx) {
+ sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
+ if (!sw_ctx_rx) {
+ rc = -ENOMEM;
+ goto out;
+ }
+ ctx->priv_ctx_rx = sw_ctx_rx;
+ } else {
+ sw_ctx_rx =
+ (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
}
- crypto_init_wait(&sw_ctx_rx->async_wait);
- ctx->priv_ctx_rx = sw_ctx_rx;
}
if (tx) {
+ crypto_init_wait(&sw_ctx_tx->async_wait);
crypto_info = &ctx->crypto_send;
cctx = &ctx->tx;
aead = &sw_ctx_tx->aead_send;
} else {
+ crypto_init_wait(&sw_ctx_rx->async_wait);
crypto_info = &ctx->crypto_recv;
cctx = &ctx->rx;
aead = &sw_ctx_rx->aead_recv;
@@ -1129,12 +1176,11 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
cctx->rec_seq_size = rec_seq_size;
- cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+ cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
if (!cctx->rec_seq) {
rc = -ENOMEM;
goto free_iv;
}
- memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
if (sw_ctx_tx) {
sg_init_table(sw_ctx_tx->sg_encrypted_data,