summaryrefslogtreecommitdiffstats
path: root/drivers/infiniband/core/cache.c
diff options
context:
space:
mode:
Diffstat (limited to 'drivers/infiniband/core/cache.c')
-rw-r--r--drivers/infiniband/core/cache.c73
1 files changed, 61 insertions, 12 deletions
diff --git a/drivers/infiniband/core/cache.c b/drivers/infiniband/core/cache.c
index a53c7713d77a..099d922ae7bd 100644
--- a/drivers/infiniband/core/cache.c
+++ b/drivers/infiniband/core/cache.c
@@ -78,11 +78,22 @@ enum gid_table_entry_state {
GID_TABLE_ENTRY_PENDING_DEL = 3,
};
+struct roce_gid_ndev_storage {
+ struct rcu_head rcu_head;
+ struct net_device *ndev;
+};
+
struct ib_gid_table_entry {
struct kref kref;
struct work_struct del_work;
struct ib_gid_attr attr;
void *context;
+ /* Store the ndev pointer to release reference later on in
+ * call_rcu context because by that time gid_table_entry
+ * and attr might be already freed. So keep a copy of it.
+ * ndev_storage is freed by rcu callback.
+ */
+ struct roce_gid_ndev_storage *ndev_storage;
enum gid_table_entry_state state;
};
@@ -206,6 +217,20 @@ static void schedule_free_gid(struct kref *kref)
queue_work(ib_wq, &entry->del_work);
}
+static void put_gid_ndev(struct rcu_head *head)
+{
+ struct roce_gid_ndev_storage *storage =
+ container_of(head, struct roce_gid_ndev_storage, rcu_head);
+
+ WARN_ON(!storage->ndev);
+ /* At this point its safe to release netdev reference,
+ * as all callers working on gid_attr->ndev are done
+ * using this netdev.
+ */
+ dev_put(storage->ndev);
+ kfree(storage);
+}
+
static void free_gid_entry_locked(struct ib_gid_table_entry *entry)
{
struct ib_device *device = entry->attr.device;
@@ -228,8 +253,8 @@ static void free_gid_entry_locked(struct ib_gid_table_entry *entry)
/* Now this index is ready to be allocated */
write_unlock_irq(&table->rwlock);
- if (entry->attr.ndev)
- dev_put(entry->attr.ndev);
+ if (entry->ndev_storage)
+ call_rcu(&entry->ndev_storage->rcu_head, put_gid_ndev);
kfree(entry);
}
@@ -266,14 +291,25 @@ static struct ib_gid_table_entry *
alloc_gid_entry(const struct ib_gid_attr *attr)
{
struct ib_gid_table_entry *entry;
+ struct net_device *ndev;
entry = kzalloc(sizeof(*entry), GFP_KERNEL);
if (!entry)
return NULL;
+
+ ndev = rcu_dereference_protected(attr->ndev, 1);
+ if (ndev) {
+ entry->ndev_storage = kzalloc(sizeof(*entry->ndev_storage),
+ GFP_KERNEL);
+ if (!entry->ndev_storage) {
+ kfree(entry);
+ return NULL;
+ }
+ dev_hold(ndev);
+ entry->ndev_storage->ndev = ndev;
+ }
kref_init(&entry->kref);
memcpy(&entry->attr, attr, sizeof(*attr));
- if (entry->attr.ndev)
- dev_hold(entry->attr.ndev);
INIT_WORK(&entry->del_work, free_gid_work);
entry->state = GID_TABLE_ENTRY_INVALID;
return entry;
@@ -343,6 +379,7 @@ static int add_roce_gid(struct ib_gid_table_entry *entry)
static void del_gid(struct ib_device *ib_dev, u8 port,
struct ib_gid_table *table, int ix)
{
+ struct roce_gid_ndev_storage *ndev_storage;
struct ib_gid_table_entry *entry;
lockdep_assert_held(&table->lock);
@@ -360,6 +397,13 @@ static void del_gid(struct ib_device *ib_dev, u8 port,
table->data_vec[ix] = NULL;
write_unlock_irq(&table->rwlock);
+ ndev_storage = entry->ndev_storage;
+ if (ndev_storage) {
+ entry->ndev_storage = NULL;
+ rcu_assign_pointer(entry->attr.ndev, NULL);
+ call_rcu(&ndev_storage->rcu_head, put_gid_ndev);
+ }
+
if (rdma_cap_roce_gid_table(ib_dev, port))
ib_dev->ops.del_gid(&entry->attr, &entry->context);
@@ -1244,8 +1288,12 @@ struct net_device *rdma_read_gid_attr_ndev_rcu(const struct ib_gid_attr *attr)
read_lock_irqsave(&table->rwlock, flags);
valid = is_gid_entry_valid(table->data_vec[attr->index]);
- if (valid && attr->ndev && (READ_ONCE(attr->ndev->flags) & IFF_UP))
- ndev = attr->ndev;
+ if (valid) {
+ ndev = rcu_dereference(attr->ndev);
+ if (!ndev ||
+ (ndev && ((READ_ONCE(ndev->flags) & IFF_UP) == 0)))
+ ndev = ERR_PTR(-ENODEV);
+ }
read_unlock_irqrestore(&table->rwlock, flags);
return ndev;
}
@@ -1281,10 +1329,12 @@ int rdma_read_gid_l2_fields(const struct ib_gid_attr *attr,
{
struct net_device *ndev;
- ndev = attr->ndev;
- if (!ndev)
- return -EINVAL;
-
+ rcu_read_lock();
+ ndev = rcu_dereference(attr->ndev);
+ if (!ndev) {
+ rcu_read_unlock();
+ return -ENODEV;
+ }
if (smac)
ether_addr_copy(smac, ndev->dev_addr);
if (vlan_id) {
@@ -1296,12 +1346,11 @@ int rdma_read_gid_l2_fields(const struct ib_gid_attr *attr,
* device is vlan device, consider vlan id of the
* the lower vlan device for this gid entry.
*/
- rcu_read_lock();
netdev_walk_all_lower_dev_rcu(attr->ndev,
get_lower_dev_vlan, vlan_id);
- rcu_read_unlock();
}
}
+ rcu_read_unlock();
return 0;
}
EXPORT_SYMBOL(rdma_read_gid_l2_fields);