diff options
-rw-r--r-- | include/net/sock.h | 6 | ||||
-rw-r--r-- | mm/memcontrol.c | 19 | ||||
-rw-r--r-- | net/core/sock.c | 2 |
3 files changed, 22 insertions, 5 deletions
diff --git a/include/net/sock.h b/include/net/sock.h index bb972d254dff..0ed65e3a0bea 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -1103,6 +1103,12 @@ sk_sockets_allocated_read_positive(struct sock *sk) return percpu_counter_sum_positive(prot->sockets_allocated); } +static inline void sk_update_clone(const struct sock *sk, struct sock *newsk) +{ + if (mem_cgroup_sockets_enabled && sk->sk_cgrp) + sock_update_memcg(newsk); +} + static inline int proto_sockets_allocated_sum_positive(struct proto *prot) { diff --git a/mm/memcontrol.c b/mm/memcontrol.c index 94da8ee9e2c2..9c72d5d5372a 100644 --- a/mm/memcontrol.c +++ b/mm/memcontrol.c @@ -381,16 +381,25 @@ static void mem_cgroup_put(struct mem_cgroup *memcg); static bool mem_cgroup_is_root(struct mem_cgroup *memcg); void sock_update_memcg(struct sock *sk) { - /* A socket spends its whole life in the same cgroup */ - if (sk->sk_cgrp) { - WARN_ON(1); - return; - } if (static_branch(&memcg_socket_limit_enabled)) { struct mem_cgroup *memcg; BUG_ON(!sk->sk_prot->proto_cgroup); + /* Socket cloning can throw us here with sk_cgrp already + * filled. It won't however, necessarily happen from + * process context. So the test for root memcg given + * the current task's memcg won't help us in this case. + * + * Respecting the original socket's memcg is a better + * decision in this case. + */ + if (sk->sk_cgrp) { + BUG_ON(mem_cgroup_is_root(sk->sk_cgrp->memcg)); + mem_cgroup_get(sk->sk_cgrp->memcg); + return; + } + rcu_read_lock(); memcg = mem_cgroup_from_task(current); if (!mem_cgroup_is_root(memcg)) { diff --git a/net/core/sock.c b/net/core/sock.c index 002939cfc069..e80b64fbd663 100644 --- a/net/core/sock.c +++ b/net/core/sock.c @@ -1362,6 +1362,8 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority) sk_set_socket(newsk, NULL); newsk->sk_wq = NULL; + sk_update_clone(sk, newsk); + if (newsk->sk_prot->sockets_allocated) sk_sockets_allocated_inc(newsk); |