summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--net/ipv6/mcast.c48
-rw-r--r--net/tipc/udp_media.c2
2 files changed, 25 insertions, 25 deletions
diff --git a/net/ipv6/mcast.c b/net/ipv6/mcast.c
index 49b0cebfdcdc..ff536a158b85 100644
--- a/net/ipv6/mcast.c
+++ b/net/ipv6/mcast.c
@@ -114,10 +114,13 @@ int sysctl_mld_qrv __read_mostly = MLD_QRV_DEFAULT;
#define mc_dereference(e, idev) \
rcu_dereference_protected(e, lockdep_is_held(&(idev)->mc_lock))
-#define for_each_pmc_rtnl(np, pmc) \
- for (pmc = rtnl_dereference((np)->ipv6_mc_list); \
+#define sock_dereference(e, sk) \
+ rcu_dereference_protected(e, lockdep_sock_is_held(sk))
+
+#define for_each_pmc_socklock(np, sk, pmc) \
+ for (pmc = sock_dereference((np)->ipv6_mc_list, sk); \
pmc; \
- pmc = rtnl_dereference(pmc->next))
+ pmc = sock_dereference(pmc->next, sk))
#define for_each_pmc_rcu(np, pmc) \
for (pmc = rcu_dereference((np)->ipv6_mc_list); \
@@ -180,7 +183,7 @@ static int __ipv6_sock_mc_join(struct sock *sk, int ifindex,
if (!ipv6_addr_is_multicast(addr))
return -EINVAL;
- for_each_pmc_rtnl(np, mc_lst) {
+ for_each_pmc_socklock(np, sk, mc_lst) {
if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
ipv6_addr_equal(&mc_lst->addr, addr))
return -EADDRINUSE;
@@ -258,7 +261,7 @@ int ipv6_sock_mc_drop(struct sock *sk, int ifindex, const struct in6_addr *addr)
return -EINVAL;
for (lnk = &np->ipv6_mc_list;
- (mc_lst = rtnl_dereference(*lnk)) != NULL;
+ (mc_lst = sock_dereference(*lnk, sk)) != NULL;
lnk = &mc_lst->next) {
if ((ifindex == 0 || mc_lst->ifindex == ifindex) &&
ipv6_addr_equal(&mc_lst->addr, addr)) {
@@ -323,7 +326,7 @@ void __ipv6_sock_mc_close(struct sock *sk)
ASSERT_RTNL();
- while ((mc_lst = rtnl_dereference(np->ipv6_mc_list)) != NULL) {
+ while ((mc_lst = sock_dereference(np->ipv6_mc_list, sk)) != NULL) {
struct net_device *dev;
np->ipv6_mc_list = mc_lst->next;
@@ -350,8 +353,11 @@ void ipv6_sock_mc_close(struct sock *sk)
if (!rcu_access_pointer(np->ipv6_mc_list))
return;
+
rtnl_lock();
+ lock_sock(sk);
__ipv6_sock_mc_close(sk);
+ release_sock(sk);
rtnl_unlock();
}
@@ -381,7 +387,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
err = -EADDRNOTAVAIL;
mutex_lock(&idev->mc_lock);
- for_each_pmc_rtnl(inet6, pmc) {
+ for_each_pmc_socklock(inet6, sk, pmc) {
if (pgsr->gsr_interface && pmc->ifindex != pgsr->gsr_interface)
continue;
if (ipv6_addr_equal(&pmc->addr, group))
@@ -404,7 +410,7 @@ int ip6_mc_source(int add, int omode, struct sock *sk,
pmc->sfmode = omode;
}
- psl = rtnl_dereference(pmc->sflist);
+ psl = sock_dereference(pmc->sflist, sk);
if (!add) {
if (!psl)
goto done; /* err = -EADDRNOTAVAIL */
@@ -511,7 +517,7 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
goto done;
}
- for_each_pmc_rtnl(inet6, pmc) {
+ for_each_pmc_socklock(inet6, sk, pmc) {
if (pmc->ifindex != gsf->gf_interface)
continue;
if (ipv6_addr_equal(&pmc->addr, group))
@@ -552,7 +558,7 @@ int ip6_mc_msfilter(struct sock *sk, struct group_filter *gsf,
}
mutex_lock(&idev->mc_lock);
- psl = rtnl_dereference(pmc->sflist);
+ psl = sock_dereference(pmc->sflist, sk);
if (psl) {
ip6_mc_del_src(idev, group, pmc->sfmode,
psl->sl_count, psl->sl_addr, 0);
@@ -574,40 +580,32 @@ done:
int ip6_mc_msfget(struct sock *sk, struct group_filter *gsf,
struct sockaddr_storage __user *p)
{
- int err, i, count, copycount;
+ struct ipv6_pinfo *inet6 = inet6_sk(sk);
const struct in6_addr *group;
struct ipv6_mc_socklist *pmc;
- struct inet6_dev *idev;
- struct ipv6_pinfo *inet6 = inet6_sk(sk);
struct ip6_sf_socklist *psl;
- struct net *net = sock_net(sk);
+ int i, count, copycount;
group = &((struct sockaddr_in6 *)&gsf->gf_group)->sin6_addr;
if (!ipv6_addr_is_multicast(group))
return -EINVAL;
- idev = ip6_mc_find_dev_rtnl(net, group, gsf->gf_interface);
- if (!idev)
- return -ENODEV;
-
- err = -EADDRNOTAVAIL;
/* changes to the ipv6_mc_list require the socket lock and
- * rtnl lock. We have the socket lock and rcu read lock,
- * so reading the list is safe.
+ * rtnl lock. We have the socket lock, so reading the list is safe.
*/
- for_each_pmc_rtnl(inet6, pmc) {
+ for_each_pmc_socklock(inet6, sk, pmc) {
if (pmc->ifindex != gsf->gf_interface)
continue;
if (ipv6_addr_equal(group, &pmc->addr))
break;
}
if (!pmc) /* must have a prior join */
- return err;
+ return -EADDRNOTAVAIL;
gsf->gf_fmode = pmc->sfmode;
- psl = rtnl_dereference(pmc->sflist);
+ psl = sock_dereference(pmc->sflist, sk);
count = psl ? psl->sl_count : 0;
copycount = count < gsf->gf_numsrc ? count : gsf->gf_numsrc;
@@ -2600,7 +2598,7 @@ static int ip6_mc_leave_src(struct sock *sk, struct ipv6_mc_socklist *iml,
struct ip6_sf_socklist *psl;
int err;
- psl = rtnl_dereference(iml->sflist);
+ psl = sock_dereference(iml->sflist, sk);
if (idev)
mutex_lock(&idev->mc_lock);
diff --git a/net/tipc/udp_media.c b/net/tipc/udp_media.c
index 21e75e28e86a..e556d2cdc064 100644
--- a/net/tipc/udp_media.c
+++ b/net/tipc/udp_media.c
@@ -414,8 +414,10 @@ static int enable_mcast(struct udp_bearer *ub, struct udp_media_addr *remote)
err = ip_mc_join_group(sk, &mreqn);
#if IS_ENABLED(CONFIG_IPV6)
} else {
+ lock_sock(sk);
err = ipv6_stub->ipv6_sock_mc_join(sk, ub->ifindex,
&remote->ipv6);
+ release_sock(sk);
#endif
}
return err;