diff options
-rw-r--r-- | drivers/vfio/vfio_iommu_type1.c | 190 |
1 files changed, 96 insertions, 94 deletions
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c index 6f3fbc48a6c7..0e863b3ddcab 100644 --- a/drivers/vfio/vfio_iommu_type1.c +++ b/drivers/vfio/vfio_iommu_type1.c @@ -31,6 +31,7 @@ #include <linux/module.h> #include <linux/mm.h> #include <linux/pci.h> /* pci_bus_type */ +#include <linux/rbtree.h> #include <linux/sched.h> #include <linux/slab.h> #include <linux/uaccess.h> @@ -50,13 +51,13 @@ MODULE_PARM_DESC(allow_unsafe_interrupts, struct vfio_iommu { struct iommu_domain *domain; struct mutex lock; - struct list_head dma_list; + struct rb_root dma_list; struct list_head group_list; bool cache; }; struct vfio_dma { - struct list_head next; + struct rb_node node; dma_addr_t iova; /* Device address */ unsigned long vaddr; /* Process virtual addr */ long npage; /* Number of pages */ @@ -75,6 +76,49 @@ struct vfio_group { #define NPAGE_TO_SIZE(npage) ((size_t)(npage) << PAGE_SHIFT) +static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu, + dma_addr_t start, size_t size) +{ + struct rb_node *node = iommu->dma_list.rb_node; + + while (node) { + struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node); + + if (start + size <= dma->iova) + node = node->rb_left; + else if (start >= dma->iova + NPAGE_TO_SIZE(dma->npage)) + node = node->rb_right; + else + return dma; + } + + return NULL; +} + +static void vfio_insert_dma(struct vfio_iommu *iommu, struct vfio_dma *new) +{ + struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL; + struct vfio_dma *dma; + + while (*link) { + parent = *link; + dma = rb_entry(parent, struct vfio_dma, node); + + if (new->iova + NPAGE_TO_SIZE(new->npage) <= dma->iova) + link = &(*link)->rb_left; + else + link = &(*link)->rb_right; + } + + rb_link_node(&new->node, parent, link); + rb_insert_color(&new->node, &iommu->dma_list); +} + +static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *old) +{ + rb_erase(&old->node, &iommu->dma_list); +} + struct vwork { struct mm_struct *mm; long npage; @@ -289,31 +333,8 @@ static int __vfio_dma_map(struct vfio_iommu *iommu, dma_addr_t iova, return 0; } -static inline bool ranges_overlap(dma_addr_t start1, size_t size1, - dma_addr_t start2, size_t size2) -{ - if (start1 < start2) - return (start2 - start1 < size1); - else if (start2 < start1) - return (start1 - start2 < size2); - return (size1 > 0 && size2 > 0); -} - -static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu, - dma_addr_t start, size_t size) -{ - struct vfio_dma *dma; - - list_for_each_entry(dma, &iommu->dma_list, next) { - if (ranges_overlap(dma->iova, NPAGE_TO_SIZE(dma->npage), - start, size)) - return dma; - } - return NULL; -} - -static long vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, - size_t size, struct vfio_dma *dma) +static int vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, + size_t size, struct vfio_dma *dma) { struct vfio_dma *split; long npage_lo, npage_hi; @@ -322,10 +343,9 @@ static long vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, if (start <= dma->iova && start + size >= dma->iova + NPAGE_TO_SIZE(dma->npage)) { vfio_dma_unmap(iommu, dma->iova, dma->npage, dma->prot); - list_del(&dma->next); - npage_lo = dma->npage; + vfio_remove_dma(iommu, dma); kfree(dma); - return npage_lo; + return 0; } /* Overlap low address of existing range */ @@ -339,7 +359,7 @@ static long vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, dma->iova += overlap; dma->vaddr += overlap; dma->npage -= npage_lo; - return npage_lo; + return 0; } /* Overlap high address of existing range */ @@ -351,7 +371,7 @@ static long vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, vfio_dma_unmap(iommu, start, npage_hi, dma->prot); dma->npage -= npage_hi; - return npage_hi; + return 0; } /* Split existing */ @@ -370,16 +390,16 @@ static long vfio_remove_dma_overlap(struct vfio_iommu *iommu, dma_addr_t start, split->iova = start + size; split->vaddr = dma->vaddr + NPAGE_TO_SIZE(npage_lo) + size; split->prot = dma->prot; - list_add(&split->next, &iommu->dma_list); - return size >> PAGE_SHIFT; + vfio_insert_dma(iommu, split); + return 0; } static int vfio_dma_do_unmap(struct vfio_iommu *iommu, struct vfio_iommu_type1_dma_unmap *unmap) { - long ret = 0, npage = unmap->size >> PAGE_SHIFT; - struct vfio_dma *dma, *tmp; uint64_t mask; + struct vfio_dma *dma; + int ret = 0; mask = ((uint64_t)1 << __ffs(iommu->domain->ops->pgsize_bitmap)) - 1; @@ -393,25 +413,19 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu, mutex_lock(&iommu->lock); - list_for_each_entry_safe(dma, tmp, &iommu->dma_list, next) { - if (ranges_overlap(dma->iova, NPAGE_TO_SIZE(dma->npage), - unmap->iova, unmap->size)) { - ret = vfio_remove_dma_overlap(iommu, unmap->iova, - unmap->size, dma); - if (ret > 0) - npage -= ret; - if (ret < 0 || npage == 0) - break; - } - } + while (!ret && (dma = vfio_find_dma(iommu, + unmap->iova, unmap->size))) + ret = vfio_remove_dma_overlap(iommu, unmap->iova, + unmap->size, dma); + mutex_unlock(&iommu->lock); - return ret > 0 ? 0 : (int)ret; + return ret; } static int vfio_dma_do_map(struct vfio_iommu *iommu, struct vfio_iommu_type1_dma_map *map) { - struct vfio_dma *dma, *pdma = NULL; + struct vfio_dma *dma; dma_addr_t iova = map->iova; unsigned long locked, lock_limit, vaddr = map->vaddr; size_t size = map->size; @@ -452,6 +466,10 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu, if (!npage) return -EINVAL; + dma = kzalloc(sizeof *dma, GFP_KERNEL); + if (!dma) + return -ENOMEM; + mutex_lock(&iommu->lock); if (vfio_find_dma(iommu, iova, size)) { @@ -473,62 +491,45 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu, if (ret) goto out_lock; + dma->npage = npage; + dma->iova = iova; + dma->vaddr = vaddr; + dma->prot = prot; + /* Check if we abut a region below - nothing below 0 */ if (iova) { - dma = vfio_find_dma(iommu, iova - 1, 1); - if (dma && dma->prot == prot && - dma->vaddr + NPAGE_TO_SIZE(dma->npage) == vaddr) { - - dma->npage += npage; - iova = dma->iova; - vaddr = dma->vaddr; + struct vfio_dma *tmp = vfio_find_dma(iommu, iova - 1, 1); + if (tmp && tmp->prot == prot && + tmp->vaddr + NPAGE_TO_SIZE(tmp->npage) == vaddr) { + vfio_remove_dma(iommu, tmp); + dma->npage += tmp->npage; + dma->iova = iova = tmp->iova; + dma->vaddr = vaddr = tmp->vaddr; + kfree(tmp); npage = dma->npage; size = NPAGE_TO_SIZE(npage); - - pdma = dma; } } /* Check if we abut a region above - nothing above ~0 + 1 */ if (iova + size) { - dma = vfio_find_dma(iommu, iova + size, 1); - if (dma && dma->prot == prot && - dma->vaddr == vaddr + size) { - - dma->npage += npage; - dma->iova = iova; - dma->vaddr = vaddr; - - /* - * If merged above and below, remove previously - * merged entry. New entry covers it. - */ - if (pdma) { - list_del(&pdma->next); - kfree(pdma); - } - pdma = dma; + struct vfio_dma *tmp = vfio_find_dma(iommu, iova + size, 1); + if (tmp && tmp->prot == prot && + tmp->vaddr == vaddr + size) { + vfio_remove_dma(iommu, tmp); + dma->npage += tmp->npage; + kfree(tmp); + npage = dma->npage; + size = NPAGE_TO_SIZE(npage); } } - /* Isolated, new region */ - if (!pdma) { - dma = kzalloc(sizeof *dma, GFP_KERNEL); - if (!dma) { - ret = -ENOMEM; - vfio_dma_unmap(iommu, iova, npage, prot); - goto out_lock; - } - - dma->npage = npage; - dma->iova = iova; - dma->vaddr = vaddr; - dma->prot = prot; - list_add(&dma->next, &iommu->dma_list); - } + vfio_insert_dma(iommu, dma); out_lock: mutex_unlock(&iommu->lock); + if (ret) + kfree(dma); return ret; } @@ -606,7 +607,7 @@ static void *vfio_iommu_type1_open(unsigned long arg) return ERR_PTR(-ENOMEM); INIT_LIST_HEAD(&iommu->group_list); - INIT_LIST_HEAD(&iommu->dma_list); + iommu->dma_list = RB_ROOT; mutex_init(&iommu->lock); /* @@ -640,7 +641,7 @@ static void vfio_iommu_type1_release(void *iommu_data) { struct vfio_iommu *iommu = iommu_data; struct vfio_group *group, *group_tmp; - struct vfio_dma *dma, *dma_tmp; + struct rb_node *node; list_for_each_entry_safe(group, group_tmp, &iommu->group_list, next) { iommu_detach_group(iommu->domain, group->iommu_group); @@ -648,9 +649,10 @@ static void vfio_iommu_type1_release(void *iommu_data) kfree(group); } - list_for_each_entry_safe(dma, dma_tmp, &iommu->dma_list, next) { + while ((node = rb_first(&iommu->dma_list))) { + struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node); vfio_dma_unmap(iommu, dma->iova, dma->npage, dma->prot); - list_del(&dma->next); + vfio_remove_dma(iommu, dma); kfree(dma); } |