diff options
Diffstat (limited to 'drivers/infiniband/core/cache.c')
-rw-r--r-- | drivers/infiniband/core/cache.c | 73 |
1 files changed, 61 insertions, 12 deletions
diff --git a/drivers/infiniband/core/cache.c b/drivers/infiniband/core/cache.c index a53c7713d77a..099d922ae7bd 100644 --- a/drivers/infiniband/core/cache.c +++ b/drivers/infiniband/core/cache.c @@ -78,11 +78,22 @@ enum gid_table_entry_state { GID_TABLE_ENTRY_PENDING_DEL = 3, }; +struct roce_gid_ndev_storage { + struct rcu_head rcu_head; + struct net_device *ndev; +}; + struct ib_gid_table_entry { struct kref kref; struct work_struct del_work; struct ib_gid_attr attr; void *context; + /* Store the ndev pointer to release reference later on in + * call_rcu context because by that time gid_table_entry + * and attr might be already freed. So keep a copy of it. + * ndev_storage is freed by rcu callback. + */ + struct roce_gid_ndev_storage *ndev_storage; enum gid_table_entry_state state; }; @@ -206,6 +217,20 @@ static void schedule_free_gid(struct kref *kref) queue_work(ib_wq, &entry->del_work); } +static void put_gid_ndev(struct rcu_head *head) +{ + struct roce_gid_ndev_storage *storage = + container_of(head, struct roce_gid_ndev_storage, rcu_head); + + WARN_ON(!storage->ndev); + /* At this point its safe to release netdev reference, + * as all callers working on gid_attr->ndev are done + * using this netdev. + */ + dev_put(storage->ndev); + kfree(storage); +} + static void free_gid_entry_locked(struct ib_gid_table_entry *entry) { struct ib_device *device = entry->attr.device; @@ -228,8 +253,8 @@ static void free_gid_entry_locked(struct ib_gid_table_entry *entry) /* Now this index is ready to be allocated */ write_unlock_irq(&table->rwlock); - if (entry->attr.ndev) - dev_put(entry->attr.ndev); + if (entry->ndev_storage) + call_rcu(&entry->ndev_storage->rcu_head, put_gid_ndev); kfree(entry); } @@ -266,14 +291,25 @@ static struct ib_gid_table_entry * alloc_gid_entry(const struct ib_gid_attr *attr) { struct ib_gid_table_entry *entry; + struct net_device *ndev; entry = kzalloc(sizeof(*entry), GFP_KERNEL); if (!entry) return NULL; + + ndev = rcu_dereference_protected(attr->ndev, 1); + if (ndev) { + entry->ndev_storage = kzalloc(sizeof(*entry->ndev_storage), + GFP_KERNEL); + if (!entry->ndev_storage) { + kfree(entry); + return NULL; + } + dev_hold(ndev); + entry->ndev_storage->ndev = ndev; + } kref_init(&entry->kref); memcpy(&entry->attr, attr, sizeof(*attr)); - if (entry->attr.ndev) - dev_hold(entry->attr.ndev); INIT_WORK(&entry->del_work, free_gid_work); entry->state = GID_TABLE_ENTRY_INVALID; return entry; @@ -343,6 +379,7 @@ static int add_roce_gid(struct ib_gid_table_entry *entry) static void del_gid(struct ib_device *ib_dev, u8 port, struct ib_gid_table *table, int ix) { + struct roce_gid_ndev_storage *ndev_storage; struct ib_gid_table_entry *entry; lockdep_assert_held(&table->lock); @@ -360,6 +397,13 @@ static void del_gid(struct ib_device *ib_dev, u8 port, table->data_vec[ix] = NULL; write_unlock_irq(&table->rwlock); + ndev_storage = entry->ndev_storage; + if (ndev_storage) { + entry->ndev_storage = NULL; + rcu_assign_pointer(entry->attr.ndev, NULL); + call_rcu(&ndev_storage->rcu_head, put_gid_ndev); + } + if (rdma_cap_roce_gid_table(ib_dev, port)) ib_dev->ops.del_gid(&entry->attr, &entry->context); @@ -1244,8 +1288,12 @@ struct net_device *rdma_read_gid_attr_ndev_rcu(const struct ib_gid_attr *attr) read_lock_irqsave(&table->rwlock, flags); valid = is_gid_entry_valid(table->data_vec[attr->index]); - if (valid && attr->ndev && (READ_ONCE(attr->ndev->flags) & IFF_UP)) - ndev = attr->ndev; + if (valid) { + ndev = rcu_dereference(attr->ndev); + if (!ndev || + (ndev && ((READ_ONCE(ndev->flags) & IFF_UP) == 0))) + ndev = ERR_PTR(-ENODEV); + } read_unlock_irqrestore(&table->rwlock, flags); return ndev; } @@ -1281,10 +1329,12 @@ int rdma_read_gid_l2_fields(const struct ib_gid_attr *attr, { struct net_device *ndev; - ndev = attr->ndev; - if (!ndev) - return -EINVAL; - + rcu_read_lock(); + ndev = rcu_dereference(attr->ndev); + if (!ndev) { + rcu_read_unlock(); + return -ENODEV; + } if (smac) ether_addr_copy(smac, ndev->dev_addr); if (vlan_id) { @@ -1296,12 +1346,11 @@ int rdma_read_gid_l2_fields(const struct ib_gid_attr *attr, * device is vlan device, consider vlan id of the * the lower vlan device for this gid entry. */ - rcu_read_lock(); netdev_walk_all_lower_dev_rcu(attr->ndev, get_lower_dev_vlan, vlan_id); - rcu_read_unlock(); } } + rcu_read_unlock(); return 0; } EXPORT_SYMBOL(rdma_read_gid_l2_fields); |