diff options
Diffstat (limited to 'fs/cifs/smb2ops.c')
-rw-r--r-- | fs/cifs/smb2ops.c | 378 |
1 files changed, 181 insertions, 197 deletions
diff --git a/fs/cifs/smb2ops.c b/fs/cifs/smb2ops.c index 3fea94212b73..f79b075f2992 100644 --- a/fs/cifs/smb2ops.c +++ b/fs/cifs/smb2ops.c @@ -4238,8 +4238,8 @@ fill_transform_hdr(struct smb2_transform_hdr *tr_hdr, unsigned int orig_len, static void *smb2_aead_req_alloc(struct crypto_aead *tfm, const struct smb_rqst *rqst, int num_rqst, const u8 *sig, u8 **iv, - struct aead_request **req, struct scatterlist **sgl, - unsigned int *num_sgs) + struct aead_request **req, struct sg_table *sgt, + unsigned int *num_sgs, size_t *sensitive_size) { unsigned int req_size = sizeof(**req) + crypto_aead_reqsize(tfm); unsigned int iv_size = crypto_aead_ivsize(tfm); @@ -4247,43 +4247,45 @@ static void *smb2_aead_req_alloc(struct crypto_aead *tfm, const struct smb_rqst u8 *p; *num_sgs = cifs_get_num_sgs(rqst, num_rqst, sig); + if (IS_ERR_VALUE((long)(int)*num_sgs)) + return ERR_PTR(*num_sgs); len = iv_size; len += crypto_aead_alignmask(tfm) & ~(crypto_tfm_ctx_alignment() - 1); len = ALIGN(len, crypto_tfm_ctx_alignment()); len += req_size; len = ALIGN(len, __alignof__(struct scatterlist)); - len += *num_sgs * sizeof(**sgl); + len += array_size(*num_sgs, sizeof(struct scatterlist)); + *sensitive_size = len; - p = kmalloc(len, GFP_ATOMIC); + p = kvzalloc(len, GFP_NOFS); if (!p) - return NULL; + return ERR_PTR(-ENOMEM); *iv = (u8 *)PTR_ALIGN(p, crypto_aead_alignmask(tfm) + 1); *req = (struct aead_request *)PTR_ALIGN(*iv + iv_size, crypto_tfm_ctx_alignment()); - *sgl = (struct scatterlist *)PTR_ALIGN((u8 *)*req + req_size, - __alignof__(struct scatterlist)); + sgt->sgl = (struct scatterlist *)PTR_ALIGN((u8 *)*req + req_size, + __alignof__(struct scatterlist)); return p; } -static void *smb2_get_aead_req(struct crypto_aead *tfm, const struct smb_rqst *rqst, +static void *smb2_get_aead_req(struct crypto_aead *tfm, struct smb_rqst *rqst, int num_rqst, const u8 *sig, u8 **iv, - struct aead_request **req, struct scatterlist **sgl) + struct aead_request **req, struct scatterlist **sgl, + size_t *sensitive_size) { - unsigned int off, len, skip; - struct scatterlist *sg; - unsigned int num_sgs; - unsigned long addr; - int i, j; + struct sg_table sgtable = {}; + unsigned int skip, num_sgs, i, j; + ssize_t rc; void *p; - p = smb2_aead_req_alloc(tfm, rqst, num_rqst, sig, iv, req, sgl, &num_sgs); - if (!p) - return NULL; + p = smb2_aead_req_alloc(tfm, rqst, num_rqst, sig, iv, req, &sgtable, + &num_sgs, sensitive_size); + if (IS_ERR(p)) + return ERR_CAST(p); - sg_init_table(*sgl, num_sgs); - sg = *sgl; + sg_init_marker(sgtable.sgl, num_sgs); /* * The first rqst has a transform header where the @@ -4291,30 +4293,29 @@ static void *smb2_get_aead_req(struct crypto_aead *tfm, const struct smb_rqst *r */ skip = 20; - /* Assumes the first rqst has a transform header as the first iov. - * I.e. - * rqst[0].rq_iov[0] is transform header - * rqst[0].rq_iov[1+] data to be encrypted/decrypted - * rqst[1+].rq_iov[0+] data to be encrypted/decrypted - */ for (i = 0; i < num_rqst; i++) { - for (j = 0; j < rqst[i].rq_nvec; j++) { - struct kvec *iov = &rqst[i].rq_iov[j]; + struct iov_iter *iter = &rqst[i].rq_iter; + size_t count = iov_iter_count(iter); - addr = (unsigned long)iov->iov_base + skip; - len = iov->iov_len - skip; - sg = cifs_sg_set_buf(sg, (void *)addr, len); + for (j = 0; j < rqst[i].rq_nvec; j++) { + cifs_sg_set_buf(&sgtable, + rqst[i].rq_iov[j].iov_base + skip, + rqst[i].rq_iov[j].iov_len - skip); /* See the above comment on the 'skip' assignment */ skip = 0; } - for (j = 0; j < rqst[i].rq_npages; j++) { - rqst_page_get_length(&rqst[i], j, &len, &off); - sg_set_page(sg++, rqst[i].rq_pages[j], len, off); - } + sgtable.orig_nents = sgtable.nents; + + rc = netfs_extract_iter_to_sg(iter, count, &sgtable, + num_sgs - sgtable.nents, 0); + iov_iter_revert(iter, rc); + sgtable.orig_nents = sgtable.nents; } - cifs_sg_set_buf(sg, sig, SMB2_SIGNATURE_SIZE); + cifs_sg_set_buf(&sgtable, sig, SMB2_SIGNATURE_SIZE); + sg_mark_end(&sgtable.sgl[sgtable.nents - 1]); + *sgl = sgtable.sgl; return p; } @@ -4368,6 +4369,7 @@ crypt_message(struct TCP_Server_Info *server, int num_rqst, struct crypto_aead *tfm; unsigned int crypt_len = le32_to_cpu(tr_hdr->OriginalMessageSize); void *creq; + size_t sensitive_size; rc = smb2_get_enc_key(server, le64_to_cpu(tr_hdr->SessionId), enc, key); if (rc) { @@ -4401,9 +4403,10 @@ crypt_message(struct TCP_Server_Info *server, int num_rqst, return rc; } - creq = smb2_get_aead_req(tfm, rqst, num_rqst, sign, &iv, &req, &sg); - if (unlikely(!creq)) - return -ENOMEM; + creq = smb2_get_aead_req(tfm, rqst, num_rqst, sign, &iv, &req, &sg, + &sensitive_size); + if (IS_ERR(creq)) + return PTR_ERR(creq); if (!enc) { memcpy(sign, &tr_hdr->Signature, SMB2_SIGNATURE_SIZE); @@ -4431,22 +4434,35 @@ crypt_message(struct TCP_Server_Info *server, int num_rqst, if (!rc && enc) memcpy(&tr_hdr->Signature, sign, SMB2_SIGNATURE_SIZE); - kfree_sensitive(creq); + kvfree_sensitive(creq, sensitive_size); return rc; } +/* + * Clear a read buffer, discarding the folios which have XA_MARK_0 set. + */ +static void cifs_clear_xarray_buffer(struct xarray *buffer) +{ + struct folio *folio; + + XA_STATE(xas, buffer, 0); + + rcu_read_lock(); + xas_for_each_marked(&xas, folio, ULONG_MAX, XA_MARK_0) { + folio_put(folio); + } + rcu_read_unlock(); + xa_destroy(buffer); +} + void smb3_free_compound_rqst(int num_rqst, struct smb_rqst *rqst) { - int i, j; + int i; - for (i = 0; i < num_rqst; i++) { - if (rqst[i].rq_pages) { - for (j = rqst[i].rq_npages - 1; j >= 0; j--) - put_page(rqst[i].rq_pages[j]); - kfree(rqst[i].rq_pages); - } - } + for (i = 0; i < num_rqst; i++) + if (!xa_empty(&rqst[i].rq_buffer)) + cifs_clear_xarray_buffer(&rqst[i].rq_buffer); } /* @@ -4466,9 +4482,8 @@ static int smb3_init_transform_rq(struct TCP_Server_Info *server, int num_rqst, struct smb_rqst *new_rq, struct smb_rqst *old_rq) { - struct page **pages; struct smb2_transform_hdr *tr_hdr = new_rq[0].rq_iov[0].iov_base; - unsigned int npages; + struct page *page; unsigned int orig_len = 0; int i, j; int rc = -ENOMEM; @@ -4476,40 +4491,45 @@ smb3_init_transform_rq(struct TCP_Server_Info *server, int num_rqst, for (i = 1; i < num_rqst; i++) { struct smb_rqst *old = &old_rq[i - 1]; struct smb_rqst *new = &new_rq[i]; + struct xarray *buffer = &new->rq_buffer; + size_t size = iov_iter_count(&old->rq_iter), seg, copied = 0; orig_len += smb_rqst_len(server, old); new->rq_iov = old->rq_iov; new->rq_nvec = old->rq_nvec; - npages = old->rq_npages; - if (!npages) - continue; - - pages = kmalloc_array(npages, sizeof(struct page *), - GFP_KERNEL); - if (!pages) - goto err_free; - - new->rq_pages = pages; - new->rq_npages = npages; - new->rq_offset = old->rq_offset; - new->rq_pagesz = old->rq_pagesz; - new->rq_tailsz = old->rq_tailsz; - - for (j = 0; j < npages; j++) { - pages[j] = alloc_page(GFP_KERNEL|__GFP_HIGHMEM); - if (!pages[j]) - goto err_free; - } + xa_init(buffer); - /* copy pages form the old */ - for (j = 0; j < npages; j++) { - unsigned int offset, len; + if (size > 0) { + unsigned int npages = DIV_ROUND_UP(size, PAGE_SIZE); + + for (j = 0; j < npages; j++) { + void *o; + + rc = -ENOMEM; + page = alloc_page(GFP_KERNEL|__GFP_HIGHMEM); + if (!page) + goto err_free; + page->index = j; + o = xa_store(buffer, j, page, GFP_KERNEL); + if (xa_is_err(o)) { + rc = xa_err(o); + put_page(page); + goto err_free; + } - rqst_page_get_length(new, j, &len, &offset); + xa_set_mark(buffer, j, XA_MARK_0); - memcpy_page(new->rq_pages[j], offset, - old->rq_pages[j], offset, len); + seg = min_t(size_t, size - copied, PAGE_SIZE); + if (copy_page_from_iter(page, 0, seg, &old->rq_iter) != seg) { + rc = -EFAULT; + goto err_free; + } + copied += seg; + } + iov_iter_xarray(&new->rq_iter, ITER_SOURCE, + buffer, 0, size); + new->rq_iter_size = size; } } @@ -4538,12 +4558,12 @@ smb3_is_transform_hdr(void *buf) static int decrypt_raw_data(struct TCP_Server_Info *server, char *buf, - unsigned int buf_data_size, struct page **pages, - unsigned int npages, unsigned int page_data_size, + unsigned int buf_data_size, struct iov_iter *iter, bool is_offloaded) { struct kvec iov[2]; struct smb_rqst rqst = {NULL}; + size_t iter_size = 0; int rc; iov[0].iov_base = buf; @@ -4553,10 +4573,11 @@ decrypt_raw_data(struct TCP_Server_Info *server, char *buf, rqst.rq_iov = iov; rqst.rq_nvec = 2; - rqst.rq_pages = pages; - rqst.rq_npages = npages; - rqst.rq_pagesz = PAGE_SIZE; - rqst.rq_tailsz = (page_data_size % PAGE_SIZE) ? : PAGE_SIZE; + if (iter) { + rqst.rq_iter = *iter; + rqst.rq_iter_size = iov_iter_count(iter); + iter_size = iov_iter_count(iter); + } rc = crypt_message(server, 1, &rqst, 0); cifs_dbg(FYI, "Decrypt message returned %d\n", rc); @@ -4567,73 +4588,37 @@ decrypt_raw_data(struct TCP_Server_Info *server, char *buf, memmove(buf, iov[1].iov_base, buf_data_size); if (!is_offloaded) - server->total_read = buf_data_size + page_data_size; + server->total_read = buf_data_size + iter_size; return rc; } static int -read_data_into_pages(struct TCP_Server_Info *server, struct page **pages, - unsigned int npages, unsigned int len) +cifs_copy_pages_to_iter(struct xarray *pages, unsigned int data_size, + unsigned int skip, struct iov_iter *iter) { - int i; - int length; + struct page *page; + unsigned long index; - for (i = 0; i < npages; i++) { - struct page *page = pages[i]; - size_t n; + xa_for_each(pages, index, page) { + size_t n, len = min_t(unsigned int, PAGE_SIZE - skip, data_size); - n = len; - if (len >= PAGE_SIZE) { - /* enough data to fill the page */ - n = PAGE_SIZE; - len -= n; - } else { - zero_user(page, len, PAGE_SIZE - len); - len = 0; + n = copy_page_to_iter(page, skip, len, iter); + if (n != len) { + cifs_dbg(VFS, "%s: something went wrong\n", __func__); + return -EIO; } - length = cifs_read_page_from_socket(server, page, 0, n); - if (length < 0) - return length; - server->total_read += length; + data_size -= n; + skip = 0; } return 0; } static int -init_read_bvec(struct page **pages, unsigned int npages, unsigned int data_size, - unsigned int cur_off, struct bio_vec **page_vec) -{ - struct bio_vec *bvec; - int i; - - bvec = kcalloc(npages, sizeof(struct bio_vec), GFP_KERNEL); - if (!bvec) - return -ENOMEM; - - for (i = 0; i < npages; i++) { - bvec[i].bv_page = pages[i]; - bvec[i].bv_offset = (i == 0) ? cur_off : 0; - bvec[i].bv_len = min_t(unsigned int, PAGE_SIZE, data_size); - data_size -= bvec[i].bv_len; - } - - if (data_size != 0) { - cifs_dbg(VFS, "%s: something went wrong\n", __func__); - kfree(bvec); - return -EIO; - } - - *page_vec = bvec; - return 0; -} - -static int handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid, - char *buf, unsigned int buf_len, struct page **pages, - unsigned int npages, unsigned int page_data_size, - bool is_offloaded) + char *buf, unsigned int buf_len, struct xarray *pages, + unsigned int pages_len, bool is_offloaded) { unsigned int data_offset; unsigned int data_len; @@ -4642,9 +4627,6 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid, unsigned int pad_len; struct cifs_readdata *rdata = mid->callback_data; struct smb2_hdr *shdr = (struct smb2_hdr *)buf; - struct bio_vec *bvec = NULL; - struct iov_iter iter; - struct kvec iov; int length; bool use_rdma_mr = false; @@ -4733,7 +4715,7 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid, return 0; } - if (data_len > page_data_size - pad_len) { + if (data_len > pages_len - pad_len) { /* data_len is corrupt -- discard frame */ rdata->result = -EIO; if (is_offloaded) @@ -4743,8 +4725,9 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid, return 0; } - rdata->result = init_read_bvec(pages, npages, page_data_size, - cur_off, &bvec); + /* Copy the data to the output I/O iterator. */ + rdata->result = cifs_copy_pages_to_iter(pages, pages_len, + cur_off, &rdata->iter); if (rdata->result != 0) { if (is_offloaded) mid->mid_state = MID_RESPONSE_MALFORMED; @@ -4752,14 +4735,16 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid, dequeue_mid(mid, rdata->result); return 0; } + rdata->got_bytes = pages_len; - iov_iter_bvec(&iter, ITER_SOURCE, bvec, npages, data_len); } else if (buf_len >= data_offset + data_len) { /* read response payload is in buf */ - WARN_ONCE(npages > 0, "read data can be either in buf or in pages"); - iov.iov_base = buf + data_offset; - iov.iov_len = data_len; - iov_iter_kvec(&iter, ITER_SOURCE, &iov, 1, data_len); + WARN_ONCE(pages && !xa_empty(pages), + "read data can be either in buf or in pages"); + length = copy_to_iter(buf + data_offset, data_len, &rdata->iter); + if (length < 0) + return length; + rdata->got_bytes = data_len; } else { /* read response payload cannot be in both buf and pages */ WARN_ONCE(1, "buf can not contain only a part of read data"); @@ -4771,26 +4756,18 @@ handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid, return 0; } - length = rdata->copy_into_pages(server, rdata, &iter); - - kfree(bvec); - - if (length < 0) - return length; - if (is_offloaded) mid->mid_state = MID_RESPONSE_RECEIVED; else dequeue_mid(mid, false); - return length; + return 0; } struct smb2_decrypt_work { struct work_struct decrypt; struct TCP_Server_Info *server; - struct page **ppages; + struct xarray buffer; char *buf; - unsigned int npages; unsigned int len; }; @@ -4799,11 +4776,13 @@ static void smb2_decrypt_offload(struct work_struct *work) { struct smb2_decrypt_work *dw = container_of(work, struct smb2_decrypt_work, decrypt); - int i, rc; + int rc; struct mid_q_entry *mid; + struct iov_iter iter; + iov_iter_xarray(&iter, ITER_DEST, &dw->buffer, 0, dw->len); rc = decrypt_raw_data(dw->server, dw->buf, dw->server->vals->read_rsp_size, - dw->ppages, dw->npages, dw->len, true); + &iter, true); if (rc) { cifs_dbg(VFS, "error decrypting rc=%d\n", rc); goto free_pages; @@ -4817,7 +4796,7 @@ static void smb2_decrypt_offload(struct work_struct *work) mid->decrypted = true; rc = handle_read_data(dw->server, mid, dw->buf, dw->server->vals->read_rsp_size, - dw->ppages, dw->npages, dw->len, + &dw->buffer, dw->len, true); if (rc >= 0) { #ifdef CONFIG_CIFS_STATS2 @@ -4850,10 +4829,7 @@ static void smb2_decrypt_offload(struct work_struct *work) } free_pages: - for (i = dw->npages-1; i >= 0; i--) - put_page(dw->ppages[i]); - - kfree(dw->ppages); + cifs_clear_xarray_buffer(&dw->buffer); cifs_small_buf_release(dw->buf); kfree(dw); } @@ -4863,47 +4839,66 @@ static int receive_encrypted_read(struct TCP_Server_Info *server, struct mid_q_entry **mid, int *num_mids) { + struct page *page; char *buf = server->smallbuf; struct smb2_transform_hdr *tr_hdr = (struct smb2_transform_hdr *)buf; - unsigned int npages; - struct page **pages; - unsigned int len; + struct iov_iter iter; + unsigned int len, npages; unsigned int buflen = server->pdu_size; int rc; int i = 0; struct smb2_decrypt_work *dw; + dw = kzalloc(sizeof(struct smb2_decrypt_work), GFP_KERNEL); + if (!dw) + return -ENOMEM; + xa_init(&dw->buffer); + INIT_WORK(&dw->decrypt, smb2_decrypt_offload); + dw->server = server; + *num_mids = 1; len = min_t(unsigned int, buflen, server->vals->read_rsp_size + sizeof(struct smb2_transform_hdr)) - HEADER_SIZE(server) + 1; rc = cifs_read_from_socket(server, buf + HEADER_SIZE(server) - 1, len); if (rc < 0) - return rc; + goto free_dw; server->total_read += rc; len = le32_to_cpu(tr_hdr->OriginalMessageSize) - server->vals->read_rsp_size; + dw->len = len; npages = DIV_ROUND_UP(len, PAGE_SIZE); - pages = kmalloc_array(npages, sizeof(struct page *), GFP_KERNEL); - if (!pages) { - rc = -ENOMEM; - goto discard_data; - } - + rc = -ENOMEM; for (; i < npages; i++) { - pages[i] = alloc_page(GFP_KERNEL|__GFP_HIGHMEM); - if (!pages[i]) { - rc = -ENOMEM; + void *old; + + page = alloc_page(GFP_KERNEL|__GFP_HIGHMEM); + if (!page) + goto discard_data; + page->index = i; + old = xa_store(&dw->buffer, i, page, GFP_KERNEL); + if (xa_is_err(old)) { + rc = xa_err(old); + put_page(page); goto discard_data; } + xa_set_mark(&dw->buffer, i, XA_MARK_0); } - /* read read data into pages */ - rc = read_data_into_pages(server, pages, npages, len); - if (rc) - goto free_pages; + iov_iter_xarray(&iter, ITER_DEST, &dw->buffer, 0, npages * PAGE_SIZE); + + /* Read the data into the buffer and clear excess bufferage. */ + rc = cifs_read_iter_from_socket(server, &iter, dw->len); + if (rc < 0) + goto discard_data; + + server->total_read += rc; + if (rc < npages * PAGE_SIZE) + iov_iter_zero(npages * PAGE_SIZE - rc, &iter); + iov_iter_revert(&iter, npages * PAGE_SIZE); + iov_iter_truncate(&iter, dw->len); rc = cifs_discard_remaining_data(server); if (rc) @@ -4916,39 +4911,28 @@ receive_encrypted_read(struct TCP_Server_Info *server, struct mid_q_entry **mid, if ((server->min_offload) && (server->in_flight > 1) && (server->pdu_size >= server->min_offload)) { - dw = kmalloc(sizeof(struct smb2_decrypt_work), GFP_KERNEL); - if (dw == NULL) - goto non_offloaded_decrypt; - dw->buf = server->smallbuf; server->smallbuf = (char *)cifs_small_buf_get(); - INIT_WORK(&dw->decrypt, smb2_decrypt_offload); - - dw->npages = npages; - dw->server = server; - dw->ppages = pages; - dw->len = len; queue_work(decrypt_wq, &dw->decrypt); *num_mids = 0; /* worker thread takes care of finding mid */ return -1; } -non_offloaded_decrypt: rc = decrypt_raw_data(server, buf, server->vals->read_rsp_size, - pages, npages, len, false); + &iter, false); if (rc) goto free_pages; *mid = smb2_find_mid(server, buf); - if (*mid == NULL) + if (*mid == NULL) { cifs_dbg(FYI, "mid not found\n"); - else { + } else { cifs_dbg(FYI, "mid found\n"); (*mid)->decrypted = true; rc = handle_read_data(server, *mid, buf, server->vals->read_rsp_size, - pages, npages, len, false); + &dw->buffer, dw->len, false); if (rc >= 0) { if (server->ops->is_network_name_deleted) { server->ops->is_network_name_deleted(buf, @@ -4958,9 +4942,9 @@ non_offloaded_decrypt: } free_pages: - for (i = i - 1; i >= 0; i--) - put_page(pages[i]); - kfree(pages); + cifs_clear_xarray_buffer(&dw->buffer); +free_dw: + kfree(dw); return rc; discard_data: cifs_discard_remaining_data(server); @@ -4998,7 +4982,7 @@ receive_encrypted_standard(struct TCP_Server_Info *server, server->total_read += length; buf_size = pdu_length - sizeof(struct smb2_transform_hdr); - length = decrypt_raw_data(server, buf, buf_size, NULL, 0, 0, false); + length = decrypt_raw_data(server, buf, buf_size, NULL, false); if (length) return length; @@ -5097,7 +5081,7 @@ smb3_handle_read_data(struct TCP_Server_Info *server, struct mid_q_entry *mid) char *buf = server->large_buf ? server->bigbuf : server->smallbuf; return handle_read_data(server, mid, buf, server->pdu_size, - NULL, 0, 0, false); + NULL, 0, false); } static int |