diff options
-rw-r--r-- | drivers/vfio/pci/vfio_pci_core.c | 1 | ||||
-rw-r--r-- | drivers/vfio/pci/vfio_pci_intrs.c | 91 | ||||
-rw-r--r-- | include/linux/vfio_pci_core.h | 2 |
3 files changed, 48 insertions, 46 deletions
diff --git a/drivers/vfio/pci/vfio_pci_core.c b/drivers/vfio/pci/vfio_pci_core.c index a5ab416cf476..ae0e161c7fc9 100644 --- a/drivers/vfio/pci/vfio_pci_core.c +++ b/drivers/vfio/pci/vfio_pci_core.c @@ -2102,6 +2102,7 @@ int vfio_pci_core_init_dev(struct vfio_device *core_vdev) INIT_LIST_HEAD(&vdev->vma_list); INIT_LIST_HEAD(&vdev->sriov_pfs_item); init_rwsem(&vdev->memory_lock); + xa_init(&vdev->ctx); return 0; } diff --git a/drivers/vfio/pci/vfio_pci_intrs.c b/drivers/vfio/pci/vfio_pci_intrs.c index 96396e1ad085..77957274027c 100644 --- a/drivers/vfio/pci/vfio_pci_intrs.c +++ b/drivers/vfio/pci/vfio_pci_intrs.c @@ -52,25 +52,33 @@ static struct vfio_pci_irq_ctx *vfio_irq_ctx_get(struct vfio_pci_core_device *vdev, unsigned long index) { - if (index >= vdev->num_ctx) - return NULL; - return &vdev->ctx[index]; + return xa_load(&vdev->ctx, index); } -static void vfio_irq_ctx_free_all(struct vfio_pci_core_device *vdev) +static void vfio_irq_ctx_free(struct vfio_pci_core_device *vdev, + struct vfio_pci_irq_ctx *ctx, unsigned long index) { - kfree(vdev->ctx); + xa_erase(&vdev->ctx, index); + kfree(ctx); } -static int vfio_irq_ctx_alloc_num(struct vfio_pci_core_device *vdev, - unsigned long num) +static struct vfio_pci_irq_ctx * +vfio_irq_ctx_alloc(struct vfio_pci_core_device *vdev, unsigned long index) { - vdev->ctx = kcalloc(num, sizeof(struct vfio_pci_irq_ctx), - GFP_KERNEL_ACCOUNT); - if (!vdev->ctx) - return -ENOMEM; + struct vfio_pci_irq_ctx *ctx; + int ret; - return 0; + ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT); + if (!ctx) + return NULL; + + ret = xa_insert(&vdev->ctx, index, ctx, GFP_KERNEL_ACCOUNT); + if (ret) { + kfree(ctx); + return NULL; + } + + return ctx; } /* @@ -226,7 +234,6 @@ static irqreturn_t vfio_intx_handler(int irq, void *dev_id) static int vfio_intx_enable(struct vfio_pci_core_device *vdev) { struct vfio_pci_irq_ctx *ctx; - int ret; if (!is_irq_none(vdev)) return -EINVAL; @@ -234,15 +241,9 @@ static int vfio_intx_enable(struct vfio_pci_core_device *vdev) if (!vdev->pdev->irq) return -ENODEV; - ret = vfio_irq_ctx_alloc_num(vdev, 1); - if (ret) - return ret; - - ctx = vfio_irq_ctx_get(vdev, 0); - if (!ctx) { - vfio_irq_ctx_free_all(vdev); - return -EINVAL; - } + ctx = vfio_irq_ctx_alloc(vdev, 0); + if (!ctx) + return -ENOMEM; vdev->num_ctx = 1; @@ -334,7 +335,7 @@ static void vfio_intx_disable(struct vfio_pci_core_device *vdev) vfio_intx_set_signal(vdev, -1); vdev->irq_type = VFIO_PCI_NUM_IRQS; vdev->num_ctx = 0; - vfio_irq_ctx_free_all(vdev); + vfio_irq_ctx_free(vdev, ctx, 0); } /* @@ -358,10 +359,6 @@ static int vfio_msi_enable(struct vfio_pci_core_device *vdev, int nvec, bool msi if (!is_irq_none(vdev)) return -EINVAL; - ret = vfio_irq_ctx_alloc_num(vdev, nvec); - if (ret) - return ret; - /* return the number of supported vectors if we can't get all: */ cmd = vfio_pci_memory_lock_and_enable(vdev); ret = pci_alloc_irq_vectors(pdev, 1, nvec, flag); @@ -369,7 +366,6 @@ static int vfio_msi_enable(struct vfio_pci_core_device *vdev, int nvec, bool msi if (ret > 0) pci_free_irq_vectors(pdev); vfio_pci_memory_unlock_and_restore(vdev, cmd); - vfio_irq_ctx_free_all(vdev); return ret; } vfio_pci_memory_unlock_and_restore(vdev, cmd); @@ -401,12 +397,13 @@ static int vfio_msi_set_vector_signal(struct vfio_pci_core_device *vdev, if (vector >= vdev->num_ctx) return -EINVAL; - ctx = vfio_irq_ctx_get(vdev, vector); - if (!ctx) - return -EINVAL; irq = pci_irq_vector(pdev, vector); + if (irq < 0) + return -EINVAL; - if (ctx->trigger) { + ctx = vfio_irq_ctx_get(vdev, vector); + + if (ctx) { irq_bypass_unregister_producer(&ctx->producer); cmd = vfio_pci_memory_lock_and_enable(vdev); @@ -414,16 +411,22 @@ static int vfio_msi_set_vector_signal(struct vfio_pci_core_device *vdev, vfio_pci_memory_unlock_and_restore(vdev, cmd); kfree(ctx->name); eventfd_ctx_put(ctx->trigger); - ctx->trigger = NULL; + vfio_irq_ctx_free(vdev, ctx, vector); } if (fd < 0) return 0; + ctx = vfio_irq_ctx_alloc(vdev, vector); + if (!ctx) + return -ENOMEM; + ctx->name = kasprintf(GFP_KERNEL_ACCOUNT, "vfio-msi%s[%d](%s)", msix ? "x" : "", vector, pci_name(pdev)); - if (!ctx->name) - return -ENOMEM; + if (!ctx->name) { + ret = -ENOMEM; + goto out_free_ctx; + } trigger = eventfd_ctx_fdget(fd); if (IS_ERR(trigger)) { @@ -469,6 +472,8 @@ out_put_eventfd_ctx: eventfd_ctx_put(trigger); out_free_name: kfree(ctx->name); +out_free_ctx: + vfio_irq_ctx_free(vdev, ctx, vector); return ret; } @@ -498,16 +503,13 @@ static void vfio_msi_disable(struct vfio_pci_core_device *vdev, bool msix) { struct pci_dev *pdev = vdev->pdev; struct vfio_pci_irq_ctx *ctx; - unsigned int i; + unsigned long i; u16 cmd; - for (i = 0; i < vdev->num_ctx; i++) { - ctx = vfio_irq_ctx_get(vdev, i); - if (ctx) { - vfio_virqfd_disable(&ctx->unmask); - vfio_virqfd_disable(&ctx->mask); - vfio_msi_set_vector_signal(vdev, i, -1, msix); - } + xa_for_each(&vdev->ctx, i, ctx) { + vfio_virqfd_disable(&ctx->unmask); + vfio_virqfd_disable(&ctx->mask); + vfio_msi_set_vector_signal(vdev, i, -1, msix); } cmd = vfio_pci_memory_lock_and_enable(vdev); @@ -523,7 +525,6 @@ static void vfio_msi_disable(struct vfio_pci_core_device *vdev, bool msix) vdev->irq_type = VFIO_PCI_NUM_IRQS; vdev->num_ctx = 0; - vfio_irq_ctx_free_all(vdev); } /* @@ -663,7 +664,7 @@ static int vfio_pci_set_msi_trigger(struct vfio_pci_core_device *vdev, for (i = start; i < start + count; i++) { ctx = vfio_irq_ctx_get(vdev, i); - if (!ctx || !ctx->trigger) + if (!ctx) continue; if (flags & VFIO_IRQ_SET_DATA_NONE) { eventfd_signal(ctx->trigger, 1); diff --git a/include/linux/vfio_pci_core.h b/include/linux/vfio_pci_core.h index 367fd79226a3..61d7873a3973 100644 --- a/include/linux/vfio_pci_core.h +++ b/include/linux/vfio_pci_core.h @@ -59,7 +59,7 @@ struct vfio_pci_core_device { struct perm_bits *msi_perm; spinlock_t irqlock; struct mutex igate; - struct vfio_pci_irq_ctx *ctx; + struct xarray ctx; int num_ctx; int irq_type; int num_regions; |