summaryrefslogtreecommitdiffstats
path: root/drivers/vfio
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/vfio')
-rw-r--r--drivers/vfio/vfio_iommu_type1.c17
1 files changed, 11 insertions, 6 deletions
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 372e4f626138..8549cb111627 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -380,10 +380,10 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
* first page and all consecutive pages with the same locking.
*/
static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
- long npage, unsigned long *pfn_base)
+ long npage, unsigned long *pfn_base,
+ bool lock_cap, unsigned long limit)
{
- unsigned long pfn = 0, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
- bool lock_cap = capable(CAP_IPC_LOCK);
+ unsigned long pfn = 0;
long ret, pinned = 0, lock_acct = 0;
bool rsvd;
dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
@@ -924,13 +924,15 @@ static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
unsigned long vaddr = dma->vaddr;
size_t size = map_size;
long npage;
- unsigned long pfn;
+ unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
+ bool lock_cap = capable(CAP_IPC_LOCK);
int ret = 0;
while (size) {
/* Pin a contiguous chunk of memory */
npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
- size >> PAGE_SHIFT, &pfn);
+ size >> PAGE_SHIFT, &pfn,
+ lock_cap, limit);
if (npage <= 0) {
WARN_ON(!npage);
ret = (int)npage;
@@ -1040,6 +1042,8 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
{
struct vfio_domain *d;
struct rb_node *n;
+ unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
+ bool lock_cap = capable(CAP_IPC_LOCK);
int ret;
/* Arbitrarily pick the first domain in the list for lookups */
@@ -1086,7 +1090,8 @@ static int vfio_iommu_replay(struct vfio_iommu *iommu,
npage = vfio_pin_pages_remote(dma, vaddr,
n >> PAGE_SHIFT,
- &pfn);
+ &pfn, lock_cap,
+ limit);
if (npage <= 0) {
WARN_ON(!npage);
ret = (int)npage;