summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/net/sock.h6
-rw-r--r--mm/memcontrol.c19
-rw-r--r--net/core/sock.c2
3 files changed, 22 insertions, 5 deletions
diff --git a/include/net/sock.h b/include/net/sock.h
index bb972d254dff..0ed65e3a0bea 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1103,6 +1103,12 @@ sk_sockets_allocated_read_positive(struct sock *sk)
return percpu_counter_sum_positive(prot->sockets_allocated);
}
+static inline void sk_update_clone(const struct sock *sk, struct sock *newsk)
+{
+ if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
+ sock_update_memcg(newsk);
+}
+
static inline int
proto_sockets_allocated_sum_positive(struct proto *prot)
{
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 94da8ee9e2c2..9c72d5d5372a 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -381,16 +381,25 @@ static void mem_cgroup_put(struct mem_cgroup *memcg);
static bool mem_cgroup_is_root(struct mem_cgroup *memcg);
void sock_update_memcg(struct sock *sk)
{
- /* A socket spends its whole life in the same cgroup */
- if (sk->sk_cgrp) {
- WARN_ON(1);
- return;
- }
if (static_branch(&memcg_socket_limit_enabled)) {
struct mem_cgroup *memcg;
BUG_ON(!sk->sk_prot->proto_cgroup);
+ /* Socket cloning can throw us here with sk_cgrp already
+ * filled. It won't however, necessarily happen from
+ * process context. So the test for root memcg given
+ * the current task's memcg won't help us in this case.
+ *
+ * Respecting the original socket's memcg is a better
+ * decision in this case.
+ */
+ if (sk->sk_cgrp) {
+ BUG_ON(mem_cgroup_is_root(sk->sk_cgrp->memcg));
+ mem_cgroup_get(sk->sk_cgrp->memcg);
+ return;
+ }
+
rcu_read_lock();
memcg = mem_cgroup_from_task(current);
if (!mem_cgroup_is_root(memcg)) {
diff --git a/net/core/sock.c b/net/core/sock.c
index 002939cfc069..e80b64fbd663 100644
--- a/net/core/sock.c
+++ b/net/core/sock.c
@@ -1362,6 +1362,8 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
sk_set_socket(newsk, NULL);
newsk->sk_wq = NULL;
+ sk_update_clone(sk, newsk);
+
if (newsk->sk_prot->sockets_allocated)
sk_sockets_allocated_inc(newsk);