summaryrefslogtreecommitdiffstats
path: root/drivers/vhost
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vhost')
-rw-r--r--drivers/vhost/scsi.c70
-rw-r--r--drivers/vhost/vdpa.c6
-rw-r--r--drivers/vhost/vhost.c126
-rw-r--r--drivers/vhost/vhost.h3
4 files changed, 140 insertions, 65 deletions
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index 282aac45c690..006ffacf1c56 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -210,6 +210,7 @@ struct vhost_scsi {
struct vhost_scsi_tmf {
struct vhost_work vwork;
+ struct work_struct flush_work;
struct vhost_scsi *vhost;
struct vhost_scsi_virtqueue *svq;
@@ -358,14 +359,23 @@ static void vhost_scsi_release_tmf_res(struct vhost_scsi_tmf *tmf)
vhost_scsi_put_inflight(inflight);
}
+static void vhost_scsi_drop_cmds(struct vhost_scsi_virtqueue *svq)
+{
+ struct vhost_scsi_cmd *cmd, *t;
+ struct llist_node *llnode;
+
+ llnode = llist_del_all(&svq->completion_list);
+ llist_for_each_entry_safe(cmd, t, llnode, tvc_completion_list)
+ vhost_scsi_release_cmd_res(&cmd->tvc_se_cmd);
+}
+
static void vhost_scsi_release_cmd(struct se_cmd *se_cmd)
{
if (se_cmd->se_cmd_flags & SCF_SCSI_TMR_CDB) {
struct vhost_scsi_tmf *tmf = container_of(se_cmd,
struct vhost_scsi_tmf, se_cmd);
- struct vhost_virtqueue *vq = &tmf->svq->vq;
- vhost_vq_work_queue(vq, &tmf->vwork);
+ schedule_work(&tmf->flush_work);
} else {
struct vhost_scsi_cmd *cmd = container_of(se_cmd,
struct vhost_scsi_cmd, tvc_se_cmd);
@@ -373,7 +383,8 @@ static void vhost_scsi_release_cmd(struct se_cmd *se_cmd)
struct vhost_scsi_virtqueue, vq);
llist_add(&cmd->tvc_completion_list, &svq->completion_list);
- vhost_vq_work_queue(&svq->vq, &svq->completion_work);
+ if (!vhost_vq_work_queue(&svq->vq, &svq->completion_work))
+ vhost_scsi_drop_cmds(svq);
}
}
@@ -497,10 +508,8 @@ again:
vq_err(vq, "Faulted on vhost_scsi_send_event\n");
}
-static void vhost_scsi_evt_work(struct vhost_work *work)
+static void vhost_scsi_complete_events(struct vhost_scsi *vs, bool drop)
{
- struct vhost_scsi *vs = container_of(work, struct vhost_scsi,
- vs_event_work);
struct vhost_virtqueue *vq = &vs->vqs[VHOST_SCSI_VQ_EVT].vq;
struct vhost_scsi_evt *evt, *t;
struct llist_node *llnode;
@@ -508,12 +517,20 @@ static void vhost_scsi_evt_work(struct vhost_work *work)
mutex_lock(&vq->mutex);
llnode = llist_del_all(&vs->vs_event_list);
llist_for_each_entry_safe(evt, t, llnode, list) {
- vhost_scsi_do_evt_work(vs, evt);
+ if (!drop)
+ vhost_scsi_do_evt_work(vs, evt);
vhost_scsi_free_evt(vs, evt);
}
mutex_unlock(&vq->mutex);
}
+static void vhost_scsi_evt_work(struct vhost_work *work)
+{
+ struct vhost_scsi *vs = container_of(work, struct vhost_scsi,
+ vs_event_work);
+ vhost_scsi_complete_events(vs, false);
+}
+
static int vhost_scsi_copy_sgl_to_iov(struct vhost_scsi_cmd *cmd)
{
struct iov_iter *iter = &cmd->saved_iter;
@@ -1270,33 +1287,32 @@ static void vhost_scsi_tmf_resp_work(struct vhost_work *work)
{
struct vhost_scsi_tmf *tmf = container_of(work, struct vhost_scsi_tmf,
vwork);
- struct vhost_virtqueue *ctl_vq, *vq;
- int resp_code, i;
-
- if (tmf->scsi_resp == TMR_FUNCTION_COMPLETE) {
- /*
- * Flush IO vqs that don't share a worker with the ctl to make
- * sure they have sent their responses before us.
- */
- ctl_vq = &tmf->vhost->vqs[VHOST_SCSI_VQ_CTL].vq;
- for (i = VHOST_SCSI_VQ_IO; i < tmf->vhost->dev.nvqs; i++) {
- vq = &tmf->vhost->vqs[i].vq;
-
- if (vhost_vq_is_setup(vq) &&
- vq->worker != ctl_vq->worker)
- vhost_vq_flush(vq);
- }
+ int resp_code;
+ if (tmf->scsi_resp == TMR_FUNCTION_COMPLETE)
resp_code = VIRTIO_SCSI_S_FUNCTION_SUCCEEDED;
- } else {
+ else
resp_code = VIRTIO_SCSI_S_FUNCTION_REJECTED;
- }
vhost_scsi_send_tmf_resp(tmf->vhost, &tmf->svq->vq, tmf->in_iovs,
tmf->vq_desc, &tmf->resp_iov, resp_code);
vhost_scsi_release_tmf_res(tmf);
}
+static void vhost_scsi_tmf_flush_work(struct work_struct *work)
+{
+ struct vhost_scsi_tmf *tmf = container_of(work, struct vhost_scsi_tmf,
+ flush_work);
+ struct vhost_virtqueue *vq = &tmf->svq->vq;
+ /*
+ * Make sure we have sent responses for other commands before we
+ * send our response.
+ */
+ vhost_dev_flush(vq->dev);
+ if (!vhost_vq_work_queue(vq, &tmf->vwork))
+ vhost_scsi_release_tmf_res(tmf);
+}
+
static void
vhost_scsi_handle_tmf(struct vhost_scsi *vs, struct vhost_scsi_tpg *tpg,
struct vhost_virtqueue *vq,
@@ -1320,6 +1336,7 @@ vhost_scsi_handle_tmf(struct vhost_scsi *vs, struct vhost_scsi_tpg *tpg,
if (!tmf)
goto send_reject;
+ INIT_WORK(&tmf->flush_work, vhost_scsi_tmf_flush_work);
vhost_work_init(&tmf->vwork, vhost_scsi_tmf_resp_work);
tmf->vhost = vs;
tmf->svq = svq;
@@ -1509,7 +1526,8 @@ vhost_scsi_send_evt(struct vhost_scsi *vs, struct vhost_virtqueue *vq,
}
llist_add(&evt->list, &vs->vs_event_list);
- vhost_vq_work_queue(vq, &vs->vs_event_work);
+ if (!vhost_vq_work_queue(vq, &vs->vs_event_work))
+ vhost_scsi_complete_events(vs, true);
}
static void vhost_scsi_evt_handle_kick(struct vhost_work *work)
diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c
index ba52d128aeb7..63a53680a85c 100644
--- a/drivers/vhost/vdpa.c
+++ b/drivers/vhost/vdpa.c
@@ -1548,7 +1548,7 @@ static void vhost_vdpa_release_dev(struct device *device)
struct vhost_vdpa *v =
container_of(device, struct vhost_vdpa, dev);
- ida_simple_remove(&vhost_vdpa_ida, v->minor);
+ ida_free(&vhost_vdpa_ida, v->minor);
kfree(v->vqs);
kfree(v);
}
@@ -1571,8 +1571,8 @@ static int vhost_vdpa_probe(struct vdpa_device *vdpa)
if (!v)
return -ENOMEM;
- minor = ida_simple_get(&vhost_vdpa_ida, 0,
- VHOST_VDPA_DEV_MAX, GFP_KERNEL);
+ minor = ida_alloc_max(&vhost_vdpa_ida, VHOST_VDPA_DEV_MAX - 1,
+ GFP_KERNEL);
if (minor < 0) {
kfree(v);
return minor;
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index 8995730ce0bf..b60955682474 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -263,34 +263,37 @@ bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
}
EXPORT_SYMBOL_GPL(vhost_vq_work_queue);
-void vhost_vq_flush(struct vhost_virtqueue *vq)
-{
- struct vhost_flush_struct flush;
-
- init_completion(&flush.wait_event);
- vhost_work_init(&flush.work, vhost_flush_work);
-
- if (vhost_vq_work_queue(vq, &flush.work))
- wait_for_completion(&flush.wait_event);
-}
-EXPORT_SYMBOL_GPL(vhost_vq_flush);
-
/**
- * vhost_worker_flush - flush a worker
+ * __vhost_worker_flush - flush a worker
* @worker: worker to flush
*
- * This does not use RCU to protect the worker, so the device or worker
- * mutex must be held.
+ * The worker's flush_mutex must be held.
*/
-static void vhost_worker_flush(struct vhost_worker *worker)
+static void __vhost_worker_flush(struct vhost_worker *worker)
{
struct vhost_flush_struct flush;
+ if (!worker->attachment_cnt || worker->killed)
+ return;
+
init_completion(&flush.wait_event);
vhost_work_init(&flush.work, vhost_flush_work);
vhost_worker_queue(worker, &flush.work);
+ /*
+ * Drop mutex in case our worker is killed and it needs to take the
+ * mutex to force cleanup.
+ */
+ mutex_unlock(&worker->mutex);
wait_for_completion(&flush.wait_event);
+ mutex_lock(&worker->mutex);
+}
+
+static void vhost_worker_flush(struct vhost_worker *worker)
+{
+ mutex_lock(&worker->mutex);
+ __vhost_worker_flush(worker);
+ mutex_unlock(&worker->mutex);
}
void vhost_dev_flush(struct vhost_dev *dev)
@@ -298,15 +301,8 @@ void vhost_dev_flush(struct vhost_dev *dev)
struct vhost_worker *worker;
unsigned long i;
- xa_for_each(&dev->worker_xa, i, worker) {
- mutex_lock(&worker->mutex);
- if (!worker->attachment_cnt) {
- mutex_unlock(&worker->mutex);
- continue;
- }
+ xa_for_each(&dev->worker_xa, i, worker)
vhost_worker_flush(worker);
- mutex_unlock(&worker->mutex);
- }
}
EXPORT_SYMBOL_GPL(vhost_dev_flush);
@@ -392,7 +388,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
__vhost_vq_meta_reset(vq);
}
-static bool vhost_worker(void *data)
+static bool vhost_run_work_list(void *data)
{
struct vhost_worker *worker = data;
struct vhost_work *work, *work_next;
@@ -417,6 +413,40 @@ static bool vhost_worker(void *data)
return !!node;
}
+static void vhost_worker_killed(void *data)
+{
+ struct vhost_worker *worker = data;
+ struct vhost_dev *dev = worker->dev;
+ struct vhost_virtqueue *vq;
+ int i, attach_cnt = 0;
+
+ mutex_lock(&worker->mutex);
+ worker->killed = true;
+
+ for (i = 0; i < dev->nvqs; i++) {
+ vq = dev->vqs[i];
+
+ mutex_lock(&vq->mutex);
+ if (worker ==
+ rcu_dereference_check(vq->worker,
+ lockdep_is_held(&vq->mutex))) {
+ rcu_assign_pointer(vq->worker, NULL);
+ attach_cnt++;
+ }
+ mutex_unlock(&vq->mutex);
+ }
+
+ worker->attachment_cnt -= attach_cnt;
+ if (attach_cnt)
+ synchronize_rcu();
+ /*
+ * Finish vhost_worker_flush calls and any other works that snuck in
+ * before the synchronize_rcu.
+ */
+ vhost_run_work_list(worker);
+ mutex_unlock(&worker->mutex);
+}
+
static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
{
kfree(vq->indirect);
@@ -631,9 +661,11 @@ static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
if (!worker)
return NULL;
+ worker->dev = dev;
snprintf(name, sizeof(name), "vhost-%d", current->pid);
- vtsk = vhost_task_create(vhost_worker, worker, name);
+ vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
+ worker, name);
if (!vtsk)
goto free_worker;
@@ -664,22 +696,37 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
{
struct vhost_worker *old_worker;
- old_worker = rcu_dereference_check(vq->worker,
- lockdep_is_held(&vq->dev->mutex));
-
mutex_lock(&worker->mutex);
- worker->attachment_cnt++;
- mutex_unlock(&worker->mutex);
+ if (worker->killed) {
+ mutex_unlock(&worker->mutex);
+ return;
+ }
+
+ mutex_lock(&vq->mutex);
+
+ old_worker = rcu_dereference_check(vq->worker,
+ lockdep_is_held(&vq->mutex));
rcu_assign_pointer(vq->worker, worker);
+ worker->attachment_cnt++;
- if (!old_worker)
+ if (!old_worker) {
+ mutex_unlock(&vq->mutex);
+ mutex_unlock(&worker->mutex);
return;
+ }
+ mutex_unlock(&vq->mutex);
+ mutex_unlock(&worker->mutex);
+
/*
* Take the worker mutex to make sure we see the work queued from
* device wide flushes which doesn't use RCU for execution.
*/
mutex_lock(&old_worker->mutex);
- old_worker->attachment_cnt--;
+ if (old_worker->killed) {
+ mutex_unlock(&old_worker->mutex);
+ return;
+ }
+
/*
* We don't want to call synchronize_rcu for every vq during setup
* because it will slow down VM startup. If we haven't done
@@ -690,6 +737,8 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
mutex_lock(&vq->mutex);
if (!vhost_vq_get_backend(vq) && !vq->kick) {
mutex_unlock(&vq->mutex);
+
+ old_worker->attachment_cnt--;
mutex_unlock(&old_worker->mutex);
/*
* vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
@@ -705,7 +754,8 @@ static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
/* Make sure new vq queue/flush/poll calls see the new worker */
synchronize_rcu();
/* Make sure whatever was queued gets run */
- vhost_worker_flush(old_worker);
+ __vhost_worker_flush(old_worker);
+ old_worker->attachment_cnt--;
mutex_unlock(&old_worker->mutex);
}
@@ -754,10 +804,16 @@ static int vhost_free_worker(struct vhost_dev *dev,
return -ENODEV;
mutex_lock(&worker->mutex);
- if (worker->attachment_cnt) {
+ if (worker->attachment_cnt || worker->killed) {
mutex_unlock(&worker->mutex);
return -EBUSY;
}
+ /*
+ * A flush might have raced and snuck in before attachment_cnt was set
+ * to zero. Make sure flushes are flushed from the queue before
+ * freeing.
+ */
+ __vhost_worker_flush(worker);
mutex_unlock(&worker->mutex);
vhost_worker_destroy(dev, worker);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 9e942fcda5c3..bb75a292d50c 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -28,12 +28,14 @@ struct vhost_work {
struct vhost_worker {
struct vhost_task *vtsk;
+ struct vhost_dev *dev;
/* Used to serialize device wide flushing with worker swapping. */
struct mutex mutex;
struct llist_head work_list;
u64 kcov_handle;
u32 id;
int attachment_cnt;
+ bool killed;
};
/* Poll a file (eventfd or socket) */
@@ -205,7 +207,6 @@ int vhost_get_vq_desc(struct vhost_virtqueue *,
struct vhost_log *log, unsigned int *log_num);
void vhost_discard_vq_desc(struct vhost_virtqueue *, int n);
-void vhost_vq_flush(struct vhost_virtqueue *vq);
bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work);
bool vhost_vq_has_work(struct vhost_virtqueue *vq);
bool vhost_vq_is_setup(struct vhost_virtqueue *vq);