diff options
Diffstat (limited to 'kernel/bpf/verifier.c')
-rw-r--r-- | kernel/bpf/verifier.c | 1676 |
1 files changed, 673 insertions, 1003 deletions
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index af2819d5c8ee..405da1f9e724 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -342,27 +342,6 @@ struct btf *btf_vmlinux; static DEFINE_MUTEX(bpf_verifier_lock); static DEFINE_MUTEX(bpf_percpu_ma_lock); -static const struct bpf_line_info * -find_linfo(const struct bpf_verifier_env *env, u32 insn_off) -{ - const struct bpf_line_info *linfo; - const struct bpf_prog *prog; - u32 i, nr_linfo; - - prog = env->prog; - nr_linfo = prog->aux->nr_linfo; - - if (!nr_linfo || insn_off >= prog->len) - return NULL; - - linfo = prog->aux->linfo; - for (i = 1; i < nr_linfo; i++) - if (insn_off < linfo[i].insn_off) - break; - - return &linfo[i - 1]; -} - __printf(2, 3) static void verbose(void *private_data, const char *fmt, ...) { struct bpf_verifier_env *env = private_data; @@ -376,42 +355,6 @@ __printf(2, 3) static void verbose(void *private_data, const char *fmt, ...) va_end(args); } -static const char *ltrim(const char *s) -{ - while (isspace(*s)) - s++; - - return s; -} - -__printf(3, 4) static void verbose_linfo(struct bpf_verifier_env *env, - u32 insn_off, - const char *prefix_fmt, ...) -{ - const struct bpf_line_info *linfo; - - if (!bpf_verifier_log_needed(&env->log)) - return; - - linfo = find_linfo(env, insn_off); - if (!linfo || linfo == env->prev_linfo) - return; - - if (prefix_fmt) { - va_list args; - - va_start(args, prefix_fmt); - bpf_verifier_vlog(&env->log, prefix_fmt, args); - va_end(args); - } - - verbose(env, "%s\n", - ltrim(btf_name_by_offset(env->prog->aux->btf, - linfo->line_off))); - - env->prev_linfo = linfo; -} - static void verbose_invalid_scalar(struct bpf_verifier_env *env, struct bpf_reg_state *reg, struct tnum *range, const char *ctx, @@ -430,21 +373,6 @@ static void verbose_invalid_scalar(struct bpf_verifier_env *env, verbose(env, " should have been in %s\n", tn_buf); } -static bool type_is_pkt_pointer(enum bpf_reg_type type) -{ - type = base_type(type); - return type == PTR_TO_PACKET || - type == PTR_TO_PACKET_META; -} - -static bool type_is_sk_pointer(enum bpf_reg_type type) -{ - return type == PTR_TO_SOCKET || - type == PTR_TO_SOCK_COMMON || - type == PTR_TO_TCP_SOCK || - type == PTR_TO_XDP_SOCK; -} - static bool type_may_be_null(u32 type) { return type & PTR_MAYBE_NULL; @@ -468,16 +396,6 @@ static bool reg_not_null(const struct bpf_reg_state *reg) type == PTR_TO_MEM; } -static bool type_is_ptr_alloc_obj(u32 type) -{ - return base_type(type) == PTR_TO_BTF_ID && type_flag(type) & MEM_ALLOC; -} - -static bool type_is_non_owning_ref(u32 type) -{ - return type_is_ptr_alloc_obj(type) && type_flag(type) & NON_OWN_REF; -} - static struct btf_record *reg_btf_record(const struct bpf_reg_state *reg) { struct btf_record *rec = NULL; @@ -605,83 +523,6 @@ static bool is_cmpxchg_insn(const struct bpf_insn *insn) insn->imm == BPF_CMPXCHG; } -/* string representation of 'enum bpf_reg_type' - * - * Note that reg_type_str() can not appear more than once in a single verbose() - * statement. - */ -static const char *reg_type_str(struct bpf_verifier_env *env, - enum bpf_reg_type type) -{ - char postfix[16] = {0}, prefix[64] = {0}; - static const char * const str[] = { - [NOT_INIT] = "?", - [SCALAR_VALUE] = "scalar", - [PTR_TO_CTX] = "ctx", - [CONST_PTR_TO_MAP] = "map_ptr", - [PTR_TO_MAP_VALUE] = "map_value", - [PTR_TO_STACK] = "fp", - [PTR_TO_PACKET] = "pkt", - [PTR_TO_PACKET_META] = "pkt_meta", - [PTR_TO_PACKET_END] = "pkt_end", - [PTR_TO_FLOW_KEYS] = "flow_keys", - [PTR_TO_SOCKET] = "sock", - [PTR_TO_SOCK_COMMON] = "sock_common", - [PTR_TO_TCP_SOCK] = "tcp_sock", - [PTR_TO_TP_BUFFER] = "tp_buffer", - [PTR_TO_XDP_SOCK] = "xdp_sock", - [PTR_TO_BTF_ID] = "ptr_", - [PTR_TO_MEM] = "mem", - [PTR_TO_BUF] = "buf", - [PTR_TO_FUNC] = "func", - [PTR_TO_MAP_KEY] = "map_key", - [CONST_PTR_TO_DYNPTR] = "dynptr_ptr", - }; - - if (type & PTR_MAYBE_NULL) { - if (base_type(type) == PTR_TO_BTF_ID) - strncpy(postfix, "or_null_", 16); - else - strncpy(postfix, "_or_null", 16); - } - - snprintf(prefix, sizeof(prefix), "%s%s%s%s%s%s%s", - type & MEM_RDONLY ? "rdonly_" : "", - type & MEM_RINGBUF ? "ringbuf_" : "", - type & MEM_USER ? "user_" : "", - type & MEM_PERCPU ? "percpu_" : "", - type & MEM_RCU ? "rcu_" : "", - type & PTR_UNTRUSTED ? "untrusted_" : "", - type & PTR_TRUSTED ? "trusted_" : "" - ); - - snprintf(env->tmp_str_buf, TMP_STR_BUF_LEN, "%s%s%s", - prefix, str[base_type(type)], postfix); - return env->tmp_str_buf; -} - -static char slot_type_char[] = { - [STACK_INVALID] = '?', - [STACK_SPILL] = 'r', - [STACK_MISC] = 'm', - [STACK_ZERO] = '0', - [STACK_DYNPTR] = 'd', - [STACK_ITER] = 'i', -}; - -static void print_liveness(struct bpf_verifier_env *env, - enum bpf_reg_liveness live) -{ - if (live & (REG_LIVE_READ | REG_LIVE_WRITTEN | REG_LIVE_DONE)) - verbose(env, "_"); - if (live & REG_LIVE_READ) - verbose(env, "r"); - if (live & REG_LIVE_WRITTEN) - verbose(env, "w"); - if (live & REG_LIVE_DONE) - verbose(env, "D"); -} - static int __get_spi(s32 off) { return (-off - 1) / BPF_REG_SIZE; @@ -751,87 +592,6 @@ static const char *btf_type_name(const struct btf *btf, u32 id) return btf_name_by_offset(btf, btf_type_by_id(btf, id)->name_off); } -static const char *dynptr_type_str(enum bpf_dynptr_type type) -{ - switch (type) { - case BPF_DYNPTR_TYPE_LOCAL: - return "local"; - case BPF_DYNPTR_TYPE_RINGBUF: - return "ringbuf"; - case BPF_DYNPTR_TYPE_SKB: - return "skb"; - case BPF_DYNPTR_TYPE_XDP: - return "xdp"; - case BPF_DYNPTR_TYPE_INVALID: - return "<invalid>"; - default: - WARN_ONCE(1, "unknown dynptr type %d\n", type); - return "<unknown>"; - } -} - -static const char *iter_type_str(const struct btf *btf, u32 btf_id) -{ - if (!btf || btf_id == 0) - return "<invalid>"; - - /* we already validated that type is valid and has conforming name */ - return btf_type_name(btf, btf_id) + sizeof(ITER_PREFIX) - 1; -} - -static const char *iter_state_str(enum bpf_iter_state state) -{ - switch (state) { - case BPF_ITER_STATE_ACTIVE: - return "active"; - case BPF_ITER_STATE_DRAINED: - return "drained"; - case BPF_ITER_STATE_INVALID: - return "<invalid>"; - default: - WARN_ONCE(1, "unknown iter state %d\n", state); - return "<unknown>"; - } -} - -static void mark_reg_scratched(struct bpf_verifier_env *env, u32 regno) -{ - env->scratched_regs |= 1U << regno; -} - -static void mark_stack_slot_scratched(struct bpf_verifier_env *env, u32 spi) -{ - env->scratched_stack_slots |= 1ULL << spi; -} - -static bool reg_scratched(const struct bpf_verifier_env *env, u32 regno) -{ - return (env->scratched_regs >> regno) & 1; -} - -static bool stack_slot_scratched(const struct bpf_verifier_env *env, u64 regno) -{ - return (env->scratched_stack_slots >> regno) & 1; -} - -static bool verifier_state_scratched(const struct bpf_verifier_env *env) -{ - return env->scratched_regs || env->scratched_stack_slots; -} - -static void mark_verifier_state_clean(struct bpf_verifier_env *env) -{ - env->scratched_regs = 0U; - env->scratched_stack_slots = 0ULL; -} - -/* Used for printing the entire verifier state. */ -static void mark_verifier_state_scratched(struct bpf_verifier_env *env) -{ - env->scratched_regs = ~0U; - env->scratched_stack_slots = ~0ULL; -} - static enum bpf_dynptr_type arg_to_dynptr_type(enum bpf_arg_type arg_type) { switch (arg_type & DYNPTR_TYPE_FLAG_MASK) { @@ -1371,226 +1131,6 @@ static void scrub_spilled_slot(u8 *stype) *stype = STACK_MISC; } -static void print_scalar_ranges(struct bpf_verifier_env *env, - const struct bpf_reg_state *reg, - const char **sep) -{ - struct { - const char *name; - u64 val; - bool omit; - } minmaxs[] = { - {"smin", reg->smin_value, reg->smin_value == S64_MIN}, - {"smax", reg->smax_value, reg->smax_value == S64_MAX}, - {"umin", reg->umin_value, reg->umin_value == 0}, - {"umax", reg->umax_value, reg->umax_value == U64_MAX}, - {"smin32", (s64)reg->s32_min_value, reg->s32_min_value == S32_MIN}, - {"smax32", (s64)reg->s32_max_value, reg->s32_max_value == S32_MAX}, - {"umin32", reg->u32_min_value, reg->u32_min_value == 0}, - {"umax32", reg->u32_max_value, reg->u32_max_value == U32_MAX}, - }, *m1, *m2, *mend = &minmaxs[ARRAY_SIZE(minmaxs)]; - bool neg1, neg2; - - for (m1 = &minmaxs[0]; m1 < mend; m1++) { - if (m1->omit) - continue; - - neg1 = m1->name[0] == 's' && (s64)m1->val < 0; - - verbose(env, "%s%s=", *sep, m1->name); - *sep = ","; - - for (m2 = m1 + 2; m2 < mend; m2 += 2) { - if (m2->omit || m2->val != m1->val) - continue; - /* don't mix negatives with positives */ - neg2 = m2->name[0] == 's' && (s64)m2->val < 0; - if (neg2 != neg1) - continue; - m2->omit = true; - verbose(env, "%s=", m2->name); - } - - verbose(env, m1->name[0] == 's' ? "%lld" : "%llu", m1->val); - } -} - -static void print_verifier_state(struct bpf_verifier_env *env, - const struct bpf_func_state *state, - bool print_all) -{ - const struct bpf_reg_state *reg; - enum bpf_reg_type t; - int i; - - if (state->frameno) - verbose(env, " frame%d:", state->frameno); - for (i = 0; i < MAX_BPF_REG; i++) { - reg = &state->regs[i]; - t = reg->type; - if (t == NOT_INIT) - continue; - if (!print_all && !reg_scratched(env, i)) - continue; - verbose(env, " R%d", i); - print_liveness(env, reg->live); - verbose(env, "="); - if (t == SCALAR_VALUE && reg->precise) - verbose(env, "P"); - if ((t == SCALAR_VALUE || t == PTR_TO_STACK) && - tnum_is_const(reg->var_off)) { - /* reg->off should be 0 for SCALAR_VALUE */ - verbose(env, "%s", t == SCALAR_VALUE ? "" : reg_type_str(env, t)); - verbose(env, "%lld", reg->var_off.value + reg->off); - } else { - const char *sep = ""; - - verbose(env, "%s", reg_type_str(env, t)); - if (base_type(t) == PTR_TO_BTF_ID) - verbose(env, "%s", btf_type_name(reg->btf, reg->btf_id)); - verbose(env, "("); -/* - * _a stands for append, was shortened to avoid multiline statements below. - * This macro is used to output a comma separated list of attributes. - */ -#define verbose_a(fmt, ...) ({ verbose(env, "%s" fmt, sep, __VA_ARGS__); sep = ","; }) - - if (reg->id) - verbose_a("id=%d", reg->id); - if (reg->ref_obj_id) - verbose_a("ref_obj_id=%d", reg->ref_obj_id); - if (type_is_non_owning_ref(reg->type)) - verbose_a("%s", "non_own_ref"); - if (t != SCALAR_VALUE) - verbose_a("off=%d", reg->off); - if (type_is_pkt_pointer(t)) - verbose_a("r=%d", reg->range); - else if (base_type(t) == CONST_PTR_TO_MAP || - base_type(t) == PTR_TO_MAP_KEY || - base_type(t) == PTR_TO_MAP_VALUE) - verbose_a("ks=%d,vs=%d", - reg->map_ptr->key_size, - reg->map_ptr->value_size); - if (tnum_is_const(reg->var_off)) { - /* Typically an immediate SCALAR_VALUE, but - * could be a pointer whose offset is too big - * for reg->off - */ - verbose_a("imm=%llx", reg->var_off.value); - } else { - print_scalar_ranges(env, reg, &sep); - if (!tnum_is_unknown(reg->var_off)) { - char tn_buf[48]; - - tnum_strn(tn_buf, sizeof(tn_buf), reg->var_off); - verbose_a("var_off=%s", tn_buf); - } - } -#undef verbose_a - - verbose(env, ")"); - } - } - for (i = 0; i < state->allocated_stack / BPF_REG_SIZE; i++) { - char types_buf[BPF_REG_SIZE + 1]; - bool valid = false; - int j; - - for (j = 0; j < BPF_REG_SIZE; j++) { - if (state->stack[i].slot_type[j] != STACK_INVALID) - valid = true; - types_buf[j] = slot_type_char[state->stack[i].slot_type[j]]; - } - types_buf[BPF_REG_SIZE] = 0; - if (!valid) - continue; - if (!print_all && !stack_slot_scratched(env, i)) - continue; - switch (state->stack[i].slot_type[BPF_REG_SIZE - 1]) { - case STACK_SPILL: - reg = &state->stack[i].spilled_ptr; - t = reg->type; - - verbose(env, " fp%d", (-i - 1) * BPF_REG_SIZE); - print_liveness(env, reg->live); - verbose(env, "=%s", t == SCALAR_VALUE ? "" : reg_type_str(env, t)); - if (t == SCALAR_VALUE && reg->precise) - verbose(env, "P"); - if (t == SCALAR_VALUE && tnum_is_const(reg->var_off)) - verbose(env, "%lld", reg->var_off.value + reg->off); - break; - case STACK_DYNPTR: - i += BPF_DYNPTR_NR_SLOTS - 1; - reg = &state->stack[i].spilled_ptr; - - verbose(env, " fp%d", (-i - 1) * BPF_REG_SIZE); - print_liveness(env, reg->live); - verbose(env, "=dynptr_%s", dynptr_type_str(reg->dynptr.type)); - if (reg->ref_obj_id) - verbose(env, "(ref_id=%d)", reg->ref_obj_id); - break; - case STACK_ITER: - /* only main slot has ref_obj_id set; skip others */ - reg = &state->stack[i].spilled_ptr; - if (!reg->ref_obj_id) - continue; - - verbose(env, " fp%d", (-i - 1) * BPF_REG_SIZE); - print_liveness(env, reg->live); - verbose(env, "=iter_%s(ref_id=%d,state=%s,depth=%u)", - iter_type_str(reg->iter.btf, reg->iter.btf_id), - reg->ref_obj_id, iter_state_str(reg->iter.state), - reg->iter.depth); - break; - case STACK_MISC: - case STACK_ZERO: - default: - reg = &state->stack[i].spilled_ptr; - - for (j = 0; j < BPF_REG_SIZE; j++) - types_buf[j] = slot_type_char[state->stack[i].slot_type[j]]; - types_buf[BPF_REG_SIZE] = 0; - - verbose(env, " fp%d", (-i - 1) * BPF_REG_SIZE); - print_liveness(env, reg->live); - verbose(env, "=%s", types_buf); - break; - } - } - if (state->acquired_refs && state->refs[0].id) { - verbose(env, " refs=%d", state->refs[0].id); - for (i = 1; i < state->acquired_refs; i++) - if (state->refs[i].id) - verbose(env, ",%d", state->refs[i].id); - } - if (state->in_callback_fn) - verbose(env, " cb"); - if (state->in_async_callback_fn) - verbose(env, " async_cb"); - verbose(env, "\n"); - if (!print_all) - mark_verifier_state_clean(env); -} - -static inline u32 vlog_alignment(u32 pos) -{ - return round_up(max(pos + BPF_LOG_MIN_ALIGNMENT / 2, BPF_LOG_ALIGNMENT), - BPF_LOG_MIN_ALIGNMENT) - pos - 1; -} - -static void print_insn_state(struct bpf_verifier_env *env, - const struct bpf_func_state *state) -{ - if (env->prev_log_pos && env->prev_log_pos == env->log.end_pos) { - /* remove new line character */ - bpf_vlog_reset(&env->log, env->prev_log_pos - 1); - verbose(env, "%*c;", vlog_alignment(env->prev_insn_print_pos), ' '); - } else { - verbose(env, "%d:", env->insn_idx); - } - print_verifier_state(env, state, false); -} - /* copy array src of length n * size bytes to dst. dst is reallocated if it's too * small to hold src. This is different from krealloc since we don't want to preserve * the contents of dst. @@ -2341,69 +1881,214 @@ static void __update_reg_bounds(struct bpf_reg_state *reg) /* Uses signed min/max values to inform unsigned, and vice-versa */ static void __reg32_deduce_bounds(struct bpf_reg_state *reg) { - /* Learn sign from signed bounds. - * If we cannot cross the sign boundary, then signed and unsigned bounds - * are the same, so combine. This works even in the negative case, e.g. - * -3 s<= x s<= -1 implies 0xf...fd u<= x u<= 0xf...ff. + /* If upper 32 bits of u64/s64 range don't change, we can use lower 32 + * bits to improve our u32/s32 boundaries. + * + * E.g., the case where we have upper 32 bits as zero ([10, 20] in + * u64) is pretty trivial, it's obvious that in u32 we'll also have + * [10, 20] range. But this property holds for any 64-bit range as + * long as upper 32 bits in that entire range of values stay the same. + * + * E.g., u64 range [0x10000000A, 0x10000000F] ([4294967306, 4294967311] + * in decimal) has the same upper 32 bits throughout all the values in + * that range. As such, lower 32 bits form a valid [0xA, 0xF] ([10, 15]) + * range. + * + * Note also, that [0xA, 0xF] is a valid range both in u32 and in s32, + * following the rules outlined below about u64/s64 correspondence + * (which equally applies to u32 vs s32 correspondence). In general it + * depends on actual hexadecimal values of 32-bit range. They can form + * only valid u32, or only valid s32 ranges in some cases. + * + * So we use all these insights to derive bounds for subregisters here. */ - if (reg->s32_min_value >= 0 || reg->s32_max_value < 0) { - reg->s32_min_value = reg->u32_min_value = - max_t(u32, reg->s32_min_value, reg->u32_min_value); - reg->s32_max_value = reg->u32_max_value = - min_t(u32, reg->s32_max_value, reg->u32_max_value); - return; + if ((reg->umin_value >> 32) == (reg->umax_value >> 32)) { + /* u64 to u32 casting preserves validity of low 32 bits as + * a range, if upper 32 bits are the same + */ + reg->u32_min_value = max_t(u32, reg->u32_min_value, (u32)reg->umin_value); + reg->u32_max_value = min_t(u32, reg->u32_max_value, (u32)reg->umax_value); + + if ((s32)reg->umin_value <= (s32)reg->umax_value) { + reg->s32_min_value = max_t(s32, reg->s32_min_value, (s32)reg->umin_value); + reg->s32_max_value = min_t(s32, reg->s32_max_value, (s32)reg->umax_value); + } + } + if ((reg->smin_value >> 32) == (reg->smax_value >> 32)) { + /* low 32 bits should form a proper u32 range */ + if ((u32)reg->smin_value <= (u32)reg->smax_value) { + reg->u32_min_value = max_t(u32, reg->u32_min_value, (u32)reg->smin_value); + reg->u32_max_value = min_t(u32, reg->u32_max_value, (u32)reg->smax_value); + } + /* low 32 bits should form a proper s32 range */ + if ((s32)reg->smin_value <= (s32)reg->smax_value) { + reg->s32_min_value = max_t(s32, reg->s32_min_value, (s32)reg->smin_value); + reg->s32_max_value = min_t(s32, reg->s32_max_value, (s32)reg->smax_value); + } + } + /* Special case where upper bits form a small sequence of two + * sequential numbers (in 32-bit unsigned space, so 0xffffffff to + * 0x00000000 is also valid), while lower bits form a proper s32 range + * going from negative numbers to positive numbers. E.g., let's say we + * have s64 range [-1, 1] ([0xffffffffffffffff, 0x0000000000000001]). + * Possible s64 values are {-1, 0, 1} ({0xffffffffffffffff, + * 0x0000000000000000, 0x00000000000001}). Ignoring upper 32 bits, + * we still get a valid s32 range [-1, 1] ([0xffffffff, 0x00000001]). + * Note that it doesn't have to be 0xffffffff going to 0x00000000 in + * upper 32 bits. As a random example, s64 range + * [0xfffffff0fffffff0; 0xfffffff100000010], forms a valid s32 range + * [-16, 16] ([0xfffffff0; 0x00000010]) in its 32 bit subregister. + */ + if ((u32)(reg->umin_value >> 32) + 1 == (u32)(reg->umax_value >> 32) && + (s32)reg->umin_value < 0 && (s32)reg->umax_value >= 0) { + reg->s32_min_value = max_t(s32, reg->s32_min_value, (s32)reg->umin_value); + reg->s32_max_value = min_t(s32, reg->s32_max_value, (s32)reg->umax_value); + } + if ((u32)(reg->smin_value >> 32) + 1 == (u32)(reg->smax_value >> 32) && + (s32)reg->smin_value < 0 && (s32)reg->smax_value >= 0) { + reg->s32_min_value = max_t(s32, reg->s32_min_value, (s32)reg->smin_value); + reg->s32_max_value = min_t(s32, reg->s32_max_value, (s32)reg->smax_value); + } + /* if u32 range forms a valid s32 range (due to matching sign bit), + * try to learn from that + */ + if ((s32)reg->u32_min_value <= (s32)reg->u32_max_value) { + reg->s32_min_value = max_t(s32, reg->s32_min_value, reg->u32_min_value); + reg->s32_max_value = min_t(s32, reg->s32_max_value, reg->u32_max_value); } - /* Learn sign from unsigned bounds. Signed bounds cross the sign - * boundary, so we must be careful. + /* If we cannot cross the sign boundary, then signed and unsigned bounds + * are the same, so combine. This works even in the negative case, e.g. + * -3 s<= x s<= -1 implies 0xf...fd u<= x u<= 0xf...ff. */ - if ((s32)reg->u32_max_value >= 0) { - /* Positive. We can't learn anything from the smin, but smax - * is positive, hence safe. - */ - reg->s32_min_value = reg->u32_min_value; - reg->s32_max_value = reg->u32_max_value = - min_t(u32, reg->s32_max_value, reg->u32_max_value); - } else if ((s32)reg->u32_min_value < 0) { - /* Negative. We can't learn anything from the smax, but smin - * is negative, hence safe. - */ - reg->s32_min_value = reg->u32_min_value = - max_t(u32, reg->s32_min_value, reg->u32_min_value); - reg->s32_max_value = reg->u32_max_value; + if ((u32)reg->s32_min_value <= (u32)reg->s32_max_value) { + reg->u32_min_value = max_t(u32, reg->s32_min_value, reg->u32_min_value); + reg->u32_max_value = min_t(u32, reg->s32_max_value, reg->u32_max_value); } } static void __reg64_deduce_bounds(struct bpf_reg_state *reg) { - /* Learn sign from signed bounds. - * If we cannot cross the sign boundary, then signed and unsigned bounds + /* If u64 range forms a valid s64 range (due to matching sign bit), + * try to learn from that. Let's do a bit of ASCII art to see when + * this is happening. Let's take u64 range first: + * + * 0 0x7fffffffffffffff 0x8000000000000000 U64_MAX + * |-------------------------------|--------------------------------| + * + * Valid u64 range is formed when umin and umax are anywhere in the + * range [0, U64_MAX], and umin <= umax. u64 case is simple and + * straightforward. Let's see how s64 range maps onto the same range + * of values, annotated below the line for comparison: + * + * 0 0x7fffffffffffffff 0x8000000000000000 U64_MAX + * |-------------------------------|--------------------------------| + * 0 S64_MAX S64_MIN -1 + * + * So s64 values basically start in the middle and they are logically + * contiguous to the right of it, wrapping around from -1 to 0, and + * then finishing as S64_MAX (0x7fffffffffffffff) right before + * S64_MIN. We can try drawing the continuity of u64 vs s64 values + * more visually as mapped to sign-agnostic range of hex values. + * + * u64 start u64 end + * _______________________________________________________________ + * / \ + * 0 0x7fffffffffffffff 0x8000000000000000 U64_MAX + * |-------------------------------|--------------------------------| + * 0 S64_MAX S64_MIN -1 + * / \ + * >------------------------------ -------------------------------> + * s64 continues... s64 end s64 start s64 "midpoint" + * + * What this means is that, in general, we can't always derive + * something new about u64 from any random s64 range, and vice versa. + * + * But we can do that in two particular cases. One is when entire + * u64/s64 range is *entirely* contained within left half of the above + * diagram or when it is *entirely* contained in the right half. I.e.: + * + * |-------------------------------|--------------------------------| + * ^ ^ ^ ^ + * A B C D + * + * [A, B] and [C, D] are contained entirely in their respective halves + * and form valid contiguous ranges as both u64 and s64 values. [A, B] + * will be non-negative both as u64 and s64 (and in fact it will be + * identical ranges no matter the signedness). [C, D] treated as s64 + * will be a range of negative values, while in u64 it will be + * non-negative range of values larger than 0x8000000000000000. + * + * Now, any other range here can't be represented in both u64 and s64 + * simultaneously. E.g., [A, C], [A, D], [B, C], [B, D] are valid + * contiguous u64 ranges, but they are discontinuous in s64. [B, C] + * in s64 would be properly presented as [S64_MIN, C] and [B, S64_MAX], + * for example. Similarly, valid s64 range [D, A] (going from negative + * to positive values), would be two separate [D, U64_MAX] and [0, A] + * ranges as u64. Currently reg_state can't represent two segments per + * numeric domain, so in such situations we can only derive maximal + * possible range ([0, U64_MAX] for u64, and [S64_MIN, S64_MAX] for s64). + * + * So we use these facts to derive umin/umax from smin/smax and vice + * versa only if they stay within the same "half". This is equivalent + * to checking sign bit: lower half will have sign bit as zero, upper + * half have sign bit 1. Below in code we simplify this by just + * casting umin/umax as smin/smax and checking if they form valid + * range, and vice versa. Those are equivalent checks. + */ + if ((s64)reg->umin_value <= (s64)reg->umax_value) { + reg->smin_value = max_t(s64, reg->smin_value, reg->umin_value); + reg->smax_value = min_t(s64, reg->smax_value, reg->umax_value); + } + /* If we cannot cross the sign boundary, then signed and unsigned bounds * are the same, so combine. This works even in the negative case, e.g. * -3 s<= x s<= -1 implies 0xf...fd u<= x u<= 0xf...ff. */ - if (reg->smin_value >= 0 || reg->smax_value < 0) { - reg->smin_value = reg->umin_value = max_t(u64, reg->smin_value, - reg->umin_value); - reg->smax_value = reg->umax_value = min_t(u64, reg->smax_value, - reg->umax_value); - return; + if ((u64)reg->smin_value <= (u64)reg->smax_value) { + reg->umin_value = max_t(u64, reg->smin_value, reg->umin_value); + reg->umax_value = min_t(u64, reg->smax_value, reg->umax_value); } - /* Learn sign from unsigned bounds. Signed bounds cross the sign - * boundary, so we must be careful. +} + +static void __reg_deduce_mixed_bounds(struct bpf_reg_state *reg) +{ + /* Try to tighten 64-bit bounds from 32-bit knowledge, using 32-bit + * values on both sides of 64-bit range in hope to have tigher range. + * E.g., if r1 is [0x1'00000000, 0x3'80000000], and we learn from + * 32-bit signed > 0 operation that s32 bounds are now [1; 0x7fffffff]. + * With this, we can substitute 1 as low 32-bits of _low_ 64-bit bound + * (0x100000000 -> 0x100000001) and 0x7fffffff as low 32-bits of + * _high_ 64-bit bound (0x380000000 -> 0x37fffffff) and arrive at a + * better overall bounds for r1 as [0x1'000000001; 0x3'7fffffff]. + * We just need to make sure that derived bounds we are intersecting + * with are well-formed ranges in respecitve s64 or u64 domain, just + * like we do with similar kinds of 32-to-64 or 64-to-32 adjustments. */ - if ((s64)reg->umax_value >= 0) { - /* Positive. We can't learn anything from the smin, but smax - * is positive, hence safe. - */ - reg->smin_value = reg->umin_value; - reg->smax_value = reg->umax_value = min_t(u64, reg->smax_value, - reg->umax_value); - } else if ((s64)reg->umin_value < 0) { - /* Negative. We can't learn anything from the smax, but smin - * is negative, hence safe. - */ - reg->smin_value = reg->umin_value = max_t(u64, reg->smin_value, - reg->umin_value); - reg->smax_value = reg->umax_value; + __u64 new_umin, new_umax; + __s64 new_smin, new_smax; + + /* u32 -> u64 tightening, it's always well-formed */ + new_umin = (reg->umin_value & ~0xffffffffULL) | reg->u32_min_value; + new_umax = (reg->umax_value & ~0xffffffffULL) | reg->u32_max_value; + reg->umin_value = max_t(u64, reg->umin_value, new_umin); + reg->umax_value = min_t(u64, reg->umax_value, new_umax); + /* u32 -> s64 tightening, u32 range embedded into s64 preserves range validity */ + new_smin = (reg->smin_value & ~0xffffffffULL) | reg->u32_min_value; + new_smax = (reg->smax_value & ~0xffffffffULL) | reg->u32_max_value; + reg->smin_value = max_t(s64, reg->smin_value, new_smin); + reg->smax_value = min_t(s64, reg->smax_value, new_smax); + + /* if s32 can be treated as valid u32 range, we can use it as well */ + if ((u32)reg->s32_min_value <= (u32)reg->s32_max_value) { + /* s32 -> u64 tightening */ + new_umin = (reg->umin_value & ~0xffffffffULL) | (u32)reg->s32_min_value; + new_umax = (reg->umax_value & ~0xffffffffULL) | (u32)reg->s32_max_value; + reg->umin_value = max_t(u64, reg->umin_value, new_umin); + reg->umax_value = min_t(u64, reg->umax_value, new_umax); + /* s32 -> s64 tightening */ + new_smin = (reg->smin_value & ~0xffffffffULL) | (u32)reg->s32_min_value; + new_smax = (reg->smax_value & ~0xffffffffULL) | (u32)reg->s32_max_value; + reg->smin_value = max_t(s64, reg->smin_value, new_smin); + reg->smax_value = min_t(s64, reg->smax_value, new_smax); } } @@ -2411,6 +2096,7 @@ static void __reg_deduce_bounds(struct bpf_reg_state *reg) { __reg32_deduce_bounds(reg); __reg64_deduce_bounds(reg); + __reg_deduce_mixed_bounds(reg); } /* Attempts to improve var_off based on unsigned min/max information */ @@ -2432,6 +2118,7 @@ static void reg_bounds_sync(struct bpf_reg_state *reg) __update_reg_bounds(reg); /* We might have learned something about the sign bit. */ __reg_deduce_bounds(reg); + __reg_deduce_bounds(reg); /* We might have learned some bits from the bounds. */ __reg_bound_offset(reg); /* Intersecting with the old var_off might have improved our bounds @@ -2441,6 +2128,56 @@ static void reg_bounds_sync(struct bpf_reg_state *reg) __update_reg_bounds(reg); } +static int reg_bounds_sanity_check(struct bpf_verifier_env *env, + struct bpf_reg_state *reg, const char *ctx) +{ + const char *msg; + + if (reg->umin_value > reg->umax_value || + reg->smin_value > reg->smax_value || + reg->u32_min_value > reg->u32_max_value || + reg->s32_min_value > reg->s32_max_value) { + msg = "range bounds violation"; + goto out; + } + + if (tnum_is_const(reg->var_off)) { + u64 uval = reg->var_off.value; + s64 sval = (s64)uval; + + if (reg->umin_value != uval || reg->umax_value != uval || + reg->smin_value != sval || reg->smax_value != sval) { + msg = "const tnum out of sync with range bounds"; + goto out; + } + } + + if (tnum_subreg_is_const(reg->var_off)) { + u32 uval32 = tnum_subreg(reg->var_off).value; + s32 sval32 = (s32)uval32; + + if (reg->u32_min_value != uval32 || reg->u32_max_value != uval32 || + reg->s32_min_value != sval32 || reg->s32_max_value != sval32) { + msg = "const subreg tnum out of sync with range bounds"; + goto out; + } + } + + return 0; +out: + verbose(env, "REG INVARIANTS VIOLATION (%s): %s u64=[%#llx, %#llx] " + "s64=[%#llx, %#llx] u32=[%#x, %#x] s32=[%#x, %#x] var_off=(%#llx, %#llx)\n", + ctx, msg, reg->umin_value, reg->umax_value, + reg->smin_value, reg->smax_value, + reg->u32_min_value, reg->u32_max_value, + reg->s32_min_value, reg->s32_max_value, + reg->var_off.value, reg->var_off.mask); + if (env->test_reg_invariants) + return -EFAULT; + __mark_reg_unbounded(reg); + return 0; +} + static bool __reg32_bound_s64(s32 a) { return a >= 0 && a <= S32_MAX; @@ -2465,51 +2202,6 @@ static void __reg_assign_32_into_64(struct bpf_reg_state *reg) } } -static void __reg_combine_32_into_64(struct bpf_reg_state *reg) -{ - /* special case when 64-bit register has upper 32-bit register - * zeroed. Typically happens after zext or <<32, >>32 sequence - * allowing us to use 32-bit bounds directly, - */ - if (tnum_equals_const(tnum_clear_subreg(reg->var_off), 0)) { - __reg_assign_32_into_64(reg); - } else { - /* Otherwise the best we can do is push lower 32bit known and - * unknown bits into register (var_off set from jmp logic) - * then learn as much as possible from the 64-bit tnum - * known and unknown bits. The previous smin/smax bounds are - * invalid here because of jmp32 compare so mark them unknown - * so they do not impact tnum bounds calculation. - */ - __mark_reg64_unbounded(reg); - } - reg_bounds_sync(reg); -} - -static bool __reg64_bound_s32(s64 a) -{ - return a >= S32_MIN && a <= S32_MAX; -} - -static bool __reg64_bound_u32(u64 a) -{ - return a >= U32_MIN && a <= U32_MAX; -} - -static void __reg_combine_64_into_32(struct bpf_reg_state *reg) -{ - __mark_reg32_unbounded(reg); - if (__reg64_bound_s32(reg->smin_value) && __reg64_bound_s32(reg->smax_value)) { - reg->s32_min_value = (s32)reg->smin_value; - reg->s32_max_value = (s32)reg->smax_value; - } - if (__reg64_bound_u32(reg->umin_value) && __reg64_bound_u32(reg->umax_value)) { - reg->u32_min_value = (u32)reg->umin_value; - reg->u32_max_value = (u32)reg->umax_value; - } - reg_bounds_sync(reg); -} - /* Mark a register as having a completely unknown (scalar) value. */ static void __mark_reg_unknown(const struct bpf_verifier_env *env, struct bpf_reg_state *reg) @@ -4592,9 +4284,17 @@ static bool register_is_null(struct bpf_reg_state *reg) return reg->type == SCALAR_VALUE && tnum_equals_const(reg->var_off, 0); } -static bool register_is_const(struct bpf_reg_state *reg) +/* check if register is a constant scalar value */ +static bool is_reg_const(struct bpf_reg_state *reg, bool subreg32) { - return reg->type == SCALAR_VALUE && tnum_is_const(reg->var_off); + return reg->type == SCALAR_VALUE && + tnum_is_const(subreg32 ? tnum_subreg(reg->var_off) : reg->var_off); +} + +/* assuming is_reg_const() is true, return constant value of a register */ +static u64 reg_const_value(struct bpf_reg_state *reg, bool subreg32) +{ + return subreg32 ? tnum_subreg(reg->var_off).value : reg->var_off.value; } static bool __is_scalar_unbounded(struct bpf_reg_state *reg) @@ -5451,10 +5151,23 @@ BTF_SET_END(rcu_protected_types) static bool rcu_protected_object(const struct btf *btf, u32 btf_id) { if (!btf_is_kernel(btf)) - return false; + return true; return btf_id_set_contains(&rcu_protected_types, btf_id); } +static struct btf_record *kptr_pointee_btf_record(struct btf_field *kptr_field) +{ + struct btf_struct_meta *meta; + + if (btf_is_kernel(kptr_field->kptr.btf)) + return NULL; + + meta = btf_find_struct_meta(kptr_field->kptr.btf, + kptr_field->kptr.btf_id); + + return meta ? meta->record : NULL; +} + static bool rcu_safe_kptr(const struct btf_field *field) { const struct btf_field_kptr *kptr = &field->kptr; @@ -5465,12 +5178,25 @@ static bool rcu_safe_kptr(const struct btf_field *field) static u32 btf_ld_kptr_type(struct bpf_verifier_env *env, struct btf_field *kptr_field) { + struct btf_record *rec; + u32 ret; + + ret = PTR_MAYBE_NULL; if (rcu_safe_kptr(kptr_field) && in_rcu_cs(env)) { - if (kptr_field->type != BPF_KPTR_PERCPU) - return PTR_MAYBE_NULL | MEM_RCU; - return PTR_MAYBE_NULL | MEM_RCU | MEM_PERCPU; + ret |= MEM_RCU; + if (kptr_field->type == BPF_KPTR_PERCPU) + ret |= MEM_PERCPU; + else if (!btf_is_kernel(kptr_field->kptr.btf)) + ret |= MEM_ALLOC; + + rec = kptr_pointee_btf_record(kptr_field); + if (rec && btf_record_has_field(rec, BPF_GRAPH_NODE)) + ret |= NON_OWN_REF; + } else { + ret |= PTR_UNTRUSTED; } - return PTR_MAYBE_NULL | PTR_UNTRUSTED; + + return ret; } static int check_map_kptr_access(struct bpf_verifier_env *env, u32 regno, @@ -6244,9 +5970,10 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size) * values are also truncated so we push 64-bit bounds into * 32-bit bounds. Above were truncated < 32-bits already. */ - if (size >= 4) - return; - __reg_combine_64_into_32(reg); + if (size < 4) { + __mark_reg32_unbounded(reg); + reg_bounds_sync(reg); + } } static void set_sext64_default_val(struct bpf_reg_state *reg, int size) @@ -8626,6 +8353,54 @@ static enum bpf_dynptr_type dynptr_get_type(struct bpf_verifier_env *env, return state->stack[spi].spilled_ptr.dynptr.type; } +static int check_reg_const_str(struct bpf_verifier_env *env, + struct bpf_reg_state *reg, u32 regno) +{ + struct bpf_map *map = reg->map_ptr; + int err; + int map_off; + u64 map_addr; + char *str_ptr; + + if (reg->type != PTR_TO_MAP_VALUE) + return -EINVAL; + + if (!bpf_map_is_rdonly(map)) { + verbose(env, "R%d does not point to a readonly map'\n", regno); + return -EACCES; + } + + if (!tnum_is_const(reg->var_off)) { + verbose(env, "R%d is not a constant address'\n", regno); + return -EACCES; + } + + if (!map->ops->map_direct_value_addr) { + verbose(env, "no direct value access support for this map type\n"); + return -EACCES; + } + + err = check_map_access(env, regno, reg->off, + map->value_size - reg->off, false, + ACCESS_HELPER); + if (err) + return err; + + map_off = reg->off + reg->var_off.value; + err = map->ops->map_direct_value_addr(map, &map_addr, map_off); + if (err) { + verbose(env, "direct value access on string failed\n"); + return err; + } + + str_ptr = (char *)(long)(map_addr); + if (!strnchr(str_ptr + map_off, map->value_size - map_off, 0)) { + verbose(env, "string is not zero-terminated\n"); + return -EINVAL; + } + return 0; +} + static int check_func_arg(struct bpf_verifier_env *env, u32 arg, struct bpf_call_arg_meta *meta, const struct bpf_func_proto *fn, @@ -8870,44 +8645,9 @@ skip_type_check: } case ARG_PTR_TO_CONST_STR: { - struct bpf_map *map = reg->map_ptr; - int map_off; - u64 map_addr; - char *str_ptr; - - if (!bpf_map_is_rdonly(map)) { - verbose(env, "R%d does not point to a readonly map'\n", regno); - return -EACCES; - } - - if (!tnum_is_const(reg->var_off)) { - verbose(env, "R%d is not a constant address'\n", regno); - return -EACCES; - } - - if (!map->ops->map_direct_value_addr) { - verbose(env, "no direct value access support for this map type\n"); - return -EACCES; - } - - err = check_map_access(env, regno, reg->off, - map->value_size - reg->off, false, - ACCESS_HELPER); + err = check_reg_const_str(env, reg, regno); if (err) return err; - - map_off = reg->off + reg->var_off.value; - err = map->ops->map_direct_value_addr(map, &map_addr, map_off); - if (err) { - verbose(env, "direct value access on string failed\n"); - return err; - } - - str_ptr = (char *)(long)(map_addr); - if (!strnchr(str_ptr + map_off, map->value_size - map_off, 0)) { - verbose(env, "string is not zero-terminated\n"); - return -EINVAL; - } break; } case ARG_PTR_TO_KPTR: @@ -9896,14 +9636,15 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx) return 0; } -static void do_refine_retval_range(struct bpf_reg_state *regs, int ret_type, - int func_id, - struct bpf_call_arg_meta *meta) +static int do_refine_retval_range(struct bpf_verifier_env *env, + struct bpf_reg_state *regs, int ret_type, + int func_id, + struct bpf_call_arg_meta *meta) { struct bpf_reg_state *ret_reg = ®s[BPF_REG_0]; if (ret_type != RET_INTEGER) - return; + return 0; switch (func_id) { case BPF_FUNC_get_stack: @@ -9929,6 +9670,8 @@ static void do_refine_retval_range(struct bpf_reg_state *regs, int ret_type, reg_bounds_sync(ret_reg); break; } + + return reg_bounds_sanity_check(env, ret_reg, "retval"); } static int @@ -9998,7 +9741,7 @@ record_func_key(struct bpf_verifier_env *env, struct bpf_call_arg_meta *meta, val = reg->var_off.value; max = map->max_entries; - if (!(register_is_const(reg) && val < max)) { + if (!(is_reg_const(reg, false) && val < max)) { bpf_map_key_store(aux, BPF_MAP_KEY_POISON); return 0; } @@ -10593,7 +10336,9 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn regs[BPF_REG_0].ref_obj_id = id; } - do_refine_retval_range(regs, fn->ret_type, func_id, &meta); + err = do_refine_retval_range(env, regs, fn->ret_type, func_id, &meta); + if (err) + return err; err = check_map_func_compatibility(env, meta.map_ptr, func_id); if (err) @@ -10771,6 +10516,11 @@ static bool is_kfunc_arg_nullable(const struct btf *btf, const struct btf_param return __kfunc_param_match_suffix(btf, arg, "__nullable"); } +static bool is_kfunc_arg_const_str(const struct btf *btf, const struct btf_param *arg) +{ + return __kfunc_param_match_suffix(btf, arg, "__str"); +} + static bool is_kfunc_arg_scalar_with_name(const struct btf *btf, const struct btf_param *arg, const char *name) @@ -10914,6 +10664,7 @@ enum kfunc_ptr_arg_type { KF_ARG_PTR_TO_RB_ROOT, KF_ARG_PTR_TO_RB_NODE, KF_ARG_PTR_TO_NULL, + KF_ARG_PTR_TO_CONST_STR, }; enum special_kfunc_type { @@ -11064,6 +10815,9 @@ get_kfunc_ptr_arg_type(struct bpf_verifier_env *env, if (is_kfunc_arg_rbtree_node(meta->btf, &args[argno])) return KF_ARG_PTR_TO_RB_NODE; + if (is_kfunc_arg_const_str(meta->btf, &args[argno])) + return KF_ARG_PTR_TO_CONST_STR; + if ((base_type(reg->type) == PTR_TO_BTF_ID || reg2btf_ids[base_type(reg->type)])) { if (!btf_type_is_struct(ref_t)) { verbose(env, "kernel function %s args#%d pointer type %s %s is not supported\n", @@ -11695,6 +11449,7 @@ static int check_kfunc_args(struct bpf_verifier_env *env, struct bpf_kfunc_call_ case KF_ARG_PTR_TO_MEM_SIZE: case KF_ARG_PTR_TO_CALLBACK: case KF_ARG_PTR_TO_REFCOUNTED_KPTR: + case KF_ARG_PTR_TO_CONST_STR: /* Trusted by default */ break; default: @@ -11966,6 +11721,15 @@ static int check_kfunc_args(struct bpf_verifier_env *env, struct bpf_kfunc_call_ meta->arg_btf = reg->btf; meta->arg_btf_id = reg->btf_id; break; + case KF_ARG_PTR_TO_CONST_STR: + if (reg->type != PTR_TO_MAP_VALUE) { + verbose(env, "arg#%d doesn't point to a const string\n", i); + return -EINVAL; + } + ret = check_reg_const_str(env, reg, regno); + if (ret) + return ret; + break; } } @@ -14086,13 +13850,12 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn) /* check dest operand */ err = check_reg_arg(env, insn->dst_reg, DST_OP_NO_MARK); + err = err ?: adjust_reg_min_max_vals(env, insn); if (err) return err; - - return adjust_reg_min_max_vals(env, insn); } - return 0; + return reg_bounds_sanity_check(env, ®s[insn->dst_reg], "alu"); } static void find_good_pkt_pointers(struct bpf_verifier_state *vstate, @@ -14174,161 +13937,130 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate, })); } -static int is_branch32_taken(struct bpf_reg_state *reg, u32 val, u8 opcode) -{ - struct tnum subreg = tnum_subreg(reg->var_off); - s32 sval = (s32)val; - - switch (opcode) { - case BPF_JEQ: - if (tnum_is_const(subreg)) - return !!tnum_equals_const(subreg, val); - else if (val < reg->u32_min_value || val > reg->u32_max_value) - return 0; - else if (sval < reg->s32_min_value || sval > reg->s32_max_value) - return 0; - break; - case BPF_JNE: - if (tnum_is_const(subreg)) - return !tnum_equals_const(subreg, val); - else if (val < reg->u32_min_value || val > reg->u32_max_value) - return 1; - else if (sval < reg->s32_min_value || sval > reg->s32_max_value) - return 1; - break; - case BPF_JSET: - if ((~subreg.mask & subreg.value) & val) - return 1; - if (!((subreg.mask | subreg.value) & val)) - return 0; - break; - case BPF_JGT: - if (reg->u32_min_value > val) - return 1; - else if (reg->u32_max_value <= val) - return 0; - break; - case BPF_JSGT: - if (reg->s32_min_value > sval) - return 1; - else if (reg->s32_max_value <= sval) - return 0; - break; - case BPF_JLT: - if (reg->u32_max_value < val) - return 1; - else if (reg->u32_min_value >= val) - return 0; - break; - case BPF_JSLT: - if (reg->s32_max_value < sval) - return 1; - else if (reg->s32_min_value >= sval) - return 0; - break; - case BPF_JGE: - if (reg->u32_min_value >= val) - return 1; - else if (reg->u32_max_value < val) - return 0; - break; - case BPF_JSGE: - if (reg->s32_min_value >= sval) - return 1; - else if (reg->s32_max_value < sval) - return 0; - break; - case BPF_JLE: - if (reg->u32_max_value <= val) - return 1; - else if (reg->u32_min_value > val) - return 0; - break; - case BPF_JSLE: - if (reg->s32_max_value <= sval) - return 1; - else if (reg->s32_min_value > sval) - return 0; - break; - } - - return -1; -} - - -static int is_branch64_taken(struct bpf_reg_state *reg, u64 val, u8 opcode) -{ - s64 sval = (s64)val; +/* + * <reg1> <op> <reg2>, currently assuming reg2 is a constant + */ +static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, + u8 opcode, bool is_jmp32) +{ + struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off; + struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off; + u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value; + u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value; + s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value; + s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value; + u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value; + u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value; + s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value; + s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value; switch (opcode) { case BPF_JEQ: - if (tnum_is_const(reg->var_off)) - return !!tnum_equals_const(reg->var_off, val); - else if (val < reg->umin_value || val > reg->umax_value) + /* constants, umin/umax and smin/smax checks would be + * redundant in this case because they all should match + */ + if (tnum_is_const(t1) && tnum_is_const(t2)) + return t1.value == t2.value; + /* non-overlapping ranges */ + if (umin1 > umax2 || umax1 < umin2) return 0; - else if (sval < reg->smin_value || sval > reg->smax_value) + if (smin1 > smax2 || smax1 < smin2) return 0; + if (!is_jmp32) { + /* if 64-bit ranges are inconclusive, see if we can + * utilize 32-bit subrange knowledge to eliminate + * branches that can't be taken a priori + */ + if (reg1->u32_min_value > reg2->u32_max_value || + reg1->u32_max_value < reg2->u32_min_value) + return 0; + if (reg1->s32_min_value > reg2->s32_max_value || + reg1->s32_max_value < reg2->s32_min_value) + return 0; + } break; case BPF_JNE: - if (tnum_is_const(reg->var_off)) - return !tnum_equals_const(reg->var_off, val); - else if (val < reg->umin_value || val > reg->umax_value) + /* constants, umin/umax and smin/smax checks would be + * redundant in this case because they all should match + */ + if (tnum_is_const(t1) && tnum_is_const(t2)) + return t1.value != t2.value; + /* non-overlapping ranges */ + if (umin1 > umax2 || umax1 < umin2) return 1; - else if (sval < reg->smin_value || sval > reg->smax_value) + if (smin1 > smax2 || smax1 < smin2) return 1; + if (!is_jmp32) { + /* if 64-bit ranges are inconclusive, see if we can + * utilize 32-bit subrange knowledge to eliminate + * branches that can't be taken a priori + */ + if (reg1->u32_min_value > reg2->u32_max_value || + reg1->u32_max_value < reg2->u32_min_value) + return 1; + if (reg1->s32_min_value > reg2->s32_max_value || + reg1->s32_max_value < reg2->s32_min_value) + return 1; + } break; case BPF_JSET: - if ((~reg->var_off.mask & reg->var_off.value) & val) + if (!is_reg_const(reg2, is_jmp32)) { + swap(reg1, reg2); + swap(t1, t2); + } + if (!is_reg_const(reg2, is_jmp32)) + return -1; + if ((~t1.mask & t1.value) & t2.value) return 1; - if (!((reg->var_off.mask | reg->var_off.value) & val)) + if (!((t1.mask | t1.value) & t2.value)) return 0; break; case BPF_JGT: - if (reg->umin_value > val) + if (umin1 > umax2) return 1; - else if (reg->umax_value <= val) + else if (umax1 <= umin2) return 0; break; case BPF_JSGT: - if (reg->smin_value > sval) + if (smin1 > smax2) return 1; - else if (reg->smax_value <= sval) + else if (smax1 <= smin2) return 0; break; case BPF_JLT: - if (reg->umax_value < val) + if (umax1 < umin2) return 1; - else if (reg->umin_value >= val) + else if (umin1 >= umax2) return 0; break; case BPF_JSLT: - if (reg->smax_value < sval) + if (smax1 < smin2) return 1; - else if (reg->smin_value >= sval) + else if (smin1 >= smax2) return 0; break; case BPF_JGE: - if (reg->umin_value >= val) + if (umin1 >= umax2) return 1; - else if (reg->umax_value < val) + else if (umax1 < umin2) return 0; break; case BPF_JSGE: - if (reg->smin_value >= sval) + if (smin1 >= smax2) return 1; - else if (reg->smax_value < sval) + else if (smax1 < smin2) return 0; break; case BPF_JLE: - if (reg->umax_value <= val) + if (umax1 <= umin2) return 1; - else if (reg->umin_value > val) + else if (umin1 > umax2) return 0; break; case BPF_JSLE: - if (reg->smax_value <= sval) + if (smax1 <= smin2) return 1; - else if (reg->smin_value > sval) + else if (smin1 > smax2) return 0; break; } @@ -14336,41 +14068,6 @@ static int is_branch64_taken(struct bpf_reg_state *reg, u64 val, u8 opcode) return -1; } -/* compute branch direction of the expression "if (reg opcode val) goto target;" - * and return: - * 1 - branch will be taken and "goto target" will be executed - * 0 - branch will not be taken and fall-through to next insn - * -1 - unknown. Example: "if (reg < 5)" is unknown when register value - * range [0,10] - */ -static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode, - bool is_jmp32) -{ - if (__is_pointer_value(false, reg)) { - if (!reg_not_null(reg)) - return -1; - - /* If pointer is valid tests against zero will fail so we can - * use this to direct branch taken. - */ - if (val != 0) - return -1; - - switch (opcode) { - case BPF_JEQ: - return 0; - case BPF_JNE: - return 1; - default: - return -1; - } - } - - if (is_jmp32) - return is_branch32_taken(reg, val, opcode); - return is_branch64_taken(reg, val, opcode); -} - static int flip_opcode(u32 opcode) { /* How can we transform "a <op> b" into "b <op> a"? */ @@ -14432,216 +14129,244 @@ static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg, return -1; } -/* Adjusts the register min/max values in the case that the dst_reg is the - * variable register that we are working on, and src_reg is a constant or we're - * simply doing a BPF_K check. - * In JEQ/JNE cases we also adjust the var_off values. +/* compute branch direction of the expression "if (<reg1> opcode <reg2>) goto target;" + * and return: + * 1 - branch will be taken and "goto target" will be executed + * 0 - branch will not be taken and fall-through to next insn + * -1 - unknown. Example: "if (reg1 < 5)" is unknown when register value + * range [0,10] */ -static void reg_set_min_max(struct bpf_reg_state *true_reg, - struct bpf_reg_state *false_reg, - u64 val, u32 val32, - u8 opcode, bool is_jmp32) -{ - struct tnum false_32off = tnum_subreg(false_reg->var_off); - struct tnum false_64off = false_reg->var_off; - struct tnum true_32off = tnum_subreg(true_reg->var_off); - struct tnum true_64off = true_reg->var_off; - s64 sval = (s64)val; - s32 sval32 = (s32)val32; - - /* If the dst_reg is a pointer, we can't learn anything about its - * variable offset from the compare (unless src_reg were a pointer into - * the same object, but we don't bother with that. - * Since false_reg and true_reg have the same type by construction, we - * only need to check one of them for pointerness. - */ - if (__is_pointer_value(false, false_reg)) - return; +static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, + u8 opcode, bool is_jmp32) +{ + if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32) + return is_pkt_ptr_branch_taken(reg1, reg2, opcode); + + if (__is_pointer_value(false, reg1) || __is_pointer_value(false, reg2)) { + u64 val; + + /* arrange that reg2 is a scalar, and reg1 is a pointer */ + if (!is_reg_const(reg2, is_jmp32)) { + opcode = flip_opcode(opcode); + swap(reg1, reg2); + } + /* and ensure that reg2 is a constant */ + if (!is_reg_const(reg2, is_jmp32)) + return -1; + + if (!reg_not_null(reg1)) + return -1; + + /* If pointer is valid tests against zero will fail so we can + * use this to direct branch taken. + */ + val = reg_const_value(reg2, is_jmp32); + if (val != 0) + return -1; + + switch (opcode) { + case BPF_JEQ: + return 0; + case BPF_JNE: + return 1; + default: + return -1; + } + } + /* now deal with two scalars, but not necessarily constants */ + return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32); +} + +/* Opcode that corresponds to a *false* branch condition. + * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2 + */ +static u8 rev_opcode(u8 opcode) +{ switch (opcode) { - /* JEQ/JNE comparison doesn't change the register equivalence. - * - * r1 = r2; - * if (r1 == 42) goto label; - * ... - * label: // here both r1 and r2 are known to be 42. - * - * Hence when marking register as known preserve it's ID. + case BPF_JEQ: return BPF_JNE; + case BPF_JNE: return BPF_JEQ; + /* JSET doesn't have it's reverse opcode in BPF, so add + * BPF_X flag to denote the reverse of that operation */ + case BPF_JSET: return BPF_JSET | BPF_X; + case BPF_JSET | BPF_X: return BPF_JSET; + case BPF_JGE: return BPF_JLT; + case BPF_JGT: return BPF_JLE; + case BPF_JLE: return BPF_JGT; + case BPF_JLT: return BPF_JGE; + case BPF_JSGE: return BPF_JSLT; + case BPF_JSGT: return BPF_JSLE; + case BPF_JSLE: return BPF_JSGT; + case BPF_JSLT: return BPF_JSGE; + default: return 0; + } +} + +/* Refine range knowledge for <reg1> <op> <reg>2 conditional operation. */ +static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2, + u8 opcode, bool is_jmp32) +{ + struct tnum t; + u64 val; + +again: + switch (opcode) { case BPF_JEQ: if (is_jmp32) { - __mark_reg32_known(true_reg, val32); - true_32off = tnum_subreg(true_reg->var_off); + reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value); + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value); + reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value); + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value); + reg2->u32_min_value = reg1->u32_min_value; + reg2->u32_max_value = reg1->u32_max_value; + reg2->s32_min_value = reg1->s32_min_value; + reg2->s32_max_value = reg1->s32_max_value; + + t = tnum_intersect(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off)); + reg1->var_off = tnum_with_subreg(reg1->var_off, t); + reg2->var_off = tnum_with_subreg(reg2->var_off, t); } else { - ___mark_reg_known(true_reg, val); - true_64off = true_reg->var_off; + reg1->umin_value = max(reg1->umin_value, reg2->umin_value); + reg1->umax_value = min(reg1->umax_value, reg2->umax_value); + reg1->smin_value = max(reg1->smin_value, reg2->smin_value); + reg1->smax_value = min(reg1->smax_value, reg2->smax_value); + reg2->umin_value = reg1->umin_value; + reg2->umax_value = reg1->umax_value; + reg2->smin_value = reg1->smin_value; + reg2->smax_value = reg1->smax_value; + + reg1->var_off = tnum_intersect(reg1->var_off, reg2->var_off); + reg2->var_off = reg1->var_off; } break; case BPF_JNE: - if (is_jmp32) { - __mark_reg32_known(false_reg, val32); - false_32off = tnum_subreg(false_reg->var_off); - } else { - ___mark_reg_known(false_reg, val); - false_64off = false_reg->var_off; - } + /* we don't derive any new information for inequality yet */ break; case BPF_JSET: + if (!is_reg_const(reg2, is_jmp32)) + swap(reg1, reg2); + if (!is_reg_const(reg2, is_jmp32)) + break; + val = reg_const_value(reg2, is_jmp32); + /* BPF_JSET (i.e., TRUE branch, *not* BPF_JSET | BPF_X) + * requires single bit to learn something useful. E.g., if we + * know that `r1 & 0x3` is true, then which bits (0, 1, or both) + * are actually set? We can learn something definite only if + * it's a single-bit value to begin with. + * + * BPF_JSET | BPF_X (i.e., negation of BPF_JSET) doesn't have + * this restriction. I.e., !(r1 & 0x3) means neither bit 0 nor + * bit 1 is set, which we can readily use in adjustments. + */ + if (!is_power_of_2(val)) + break; if (is_jmp32) { - false_32off = tnum_and(false_32off, tnum_const(~val32)); - if (is_power_of_2(val32)) - true_32off = tnum_or(true_32off, - tnum_const(val32)); + t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val)); + reg1->var_off = tnum_with_subreg(reg1->var_off, t); } else { - false_64off = tnum_and(false_64off, tnum_const(~val)); - if (is_power_of_2(val)) - true_64off = tnum_or(true_64off, - tnum_const(val)); + reg1->var_off = tnum_or(reg1->var_off, tnum_const(val)); } break; - case BPF_JGE: - case BPF_JGT: - { + case BPF_JSET | BPF_X: /* reverse of BPF_JSET, see rev_opcode() */ + if (!is_reg_const(reg2, is_jmp32)) + swap(reg1, reg2); + if (!is_reg_const(reg2, is_jmp32)) + break; + val = reg_const_value(reg2, is_jmp32); if (is_jmp32) { - u32 false_umax = opcode == BPF_JGT ? val32 : val32 - 1; - u32 true_umin = opcode == BPF_JGT ? val32 + 1 : val32; - - false_reg->u32_max_value = min(false_reg->u32_max_value, - false_umax); - true_reg->u32_min_value = max(true_reg->u32_min_value, - true_umin); + t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val)); + reg1->var_off = tnum_with_subreg(reg1->var_off, t); } else { - u64 false_umax = opcode == BPF_JGT ? val : val - 1; - u64 true_umin = opcode == BPF_JGT ? val + 1 : val; - - false_reg->umax_value = min(false_reg->umax_value, false_umax); - true_reg->umin_value = max(true_reg->umin_value, true_umin); + reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val)); } break; - } - case BPF_JSGE: - case BPF_JSGT: - { + case BPF_JLE: if (is_jmp32) { - s32 false_smax = opcode == BPF_JSGT ? sval32 : sval32 - 1; - s32 true_smin = opcode == BPF_JSGT ? sval32 + 1 : sval32; - - false_reg->s32_max_value = min(false_reg->s32_max_value, false_smax); - true_reg->s32_min_value = max(true_reg->s32_min_value, true_smin); + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value); + reg2->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value); } else { - s64 false_smax = opcode == BPF_JSGT ? sval : sval - 1; - s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval; - - false_reg->smax_value = min(false_reg->smax_value, false_smax); - true_reg->smin_value = max(true_reg->smin_value, true_smin); + reg1->umax_value = min(reg1->umax_value, reg2->umax_value); + reg2->umin_value = max(reg1->umin_value, reg2->umin_value); } break; - } - case BPF_JLE: case BPF_JLT: - { if (is_jmp32) { - u32 false_umin = opcode == BPF_JLT ? val32 : val32 + 1; - u32 true_umax = opcode == BPF_JLT ? val32 - 1 : val32; - - false_reg->u32_min_value = max(false_reg->u32_min_value, - false_umin); - true_reg->u32_max_value = min(true_reg->u32_max_value, - true_umax); + reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value - 1); + reg2->u32_min_value = max(reg1->u32_min_value + 1, reg2->u32_min_value); } else { - u64 false_umin = opcode == BPF_JLT ? val : val + 1; - u64 true_umax = opcode == BPF_JLT ? val - 1 : val; - - false_reg->umin_value = max(false_reg->umin_value, false_umin); - true_reg->umax_value = min(true_reg->umax_value, true_umax); + reg1->umax_value = min(reg1->umax_value, reg2->umax_value - 1); + reg2->umin_value = max(reg1->umin_value + 1, reg2->umin_value); } break; - } case BPF_JSLE: + if (is_jmp32) { + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value); + reg2->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value); + } else { + reg1->smax_value = min(reg1->smax_value, reg2->smax_value); + reg2->smin_value = max(reg1->smin_value, reg2->smin_value); + } + break; case BPF_JSLT: - { if (is_jmp32) { - s32 false_smin = opcode == BPF_JSLT ? sval32 : sval32 + 1; - s32 true_smax = opcode == BPF_JSLT ? sval32 - 1 : sval32; - - false_reg->s32_min_value = max(false_reg->s32_min_value, false_smin); - true_reg->s32_max_value = min(true_reg->s32_max_value, true_smax); + reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value - 1); + reg2->s32_min_value = max(reg1->s32_min_value + 1, reg2->s32_min_value); } else { - s64 false_smin = opcode == BPF_JSLT ? sval : sval + 1; - s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval; - - false_reg->smin_value = max(false_reg->smin_value, false_smin); - true_reg->smax_value = min(true_reg->smax_value, true_smax); + reg1->smax_value = min(reg1->smax_value, reg2->smax_value - 1); + reg2->smin_value = max(reg1->smin_value + 1, reg2->smin_value); } break; - } + case BPF_JGE: + case BPF_JGT: + case BPF_JSGE: + case BPF_JSGT: + /* just reuse LE/LT logic above */ + opcode = flip_opcode(opcode); + swap(reg1, reg2); + goto again; default: return; } - - if (is_jmp32) { - false_reg->var_off = tnum_or(tnum_clear_subreg(false_64off), - tnum_subreg(false_32off)); - true_reg->var_off = tnum_or(tnum_clear_subreg(true_64off), - tnum_subreg(true_32off)); - __reg_combine_32_into_64(false_reg); - __reg_combine_32_into_64(true_reg); - } else { - false_reg->var_off = false_64off; - true_reg->var_off = true_64off; - __reg_combine_64_into_32(false_reg); - __reg_combine_64_into_32(true_reg); - } } -/* Same as above, but for the case that dst_reg holds a constant and src_reg is - * the variable reg. +/* Adjusts the register min/max values in the case that the dst_reg and + * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K + * check, in which case we havea fake SCALAR_VALUE representing insn->imm). + * Technically we can do similar adjustments for pointers to the same object, + * but we don't support that right now. */ -static void reg_set_min_max_inv(struct bpf_reg_state *true_reg, - struct bpf_reg_state *false_reg, - u64 val, u32 val32, - u8 opcode, bool is_jmp32) +static int reg_set_min_max(struct bpf_verifier_env *env, + struct bpf_reg_state *true_reg1, + struct bpf_reg_state *true_reg2, + struct bpf_reg_state *false_reg1, + struct bpf_reg_state *false_reg2, + u8 opcode, bool is_jmp32) { - opcode = flip_opcode(opcode); - /* This uses zero as "not present in table"; luckily the zero opcode, - * BPF_JA, can't get here. + int err; + + /* If either register is a pointer, we can't learn anything about its + * variable offset from the compare (unless they were a pointer into + * the same object, but we don't bother with that). */ - if (opcode) - reg_set_min_max(true_reg, false_reg, val, val32, opcode, is_jmp32); -} - -/* Regs are known to be equal, so intersect their min/max/var_off */ -static void __reg_combine_min_max(struct bpf_reg_state *src_reg, - struct bpf_reg_state *dst_reg) -{ - src_reg->umin_value = dst_reg->umin_value = max(src_reg->umin_value, - dst_reg->umin_value); - src_reg->umax_value = dst_reg->umax_value = min(src_reg->umax_value, - dst_reg->umax_value); - src_reg->smin_value = dst_reg->smin_value = max(src_reg->smin_value, - dst_reg->smin_value); - src_reg->smax_value = dst_reg->smax_value = min(src_reg->smax_value, - dst_reg->smax_value); - src_reg->var_off = dst_reg->var_off = tnum_intersect(src_reg->var_off, - dst_reg->var_off); - reg_bounds_sync(src_reg); - reg_bounds_sync(dst_reg); -} + if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE) + return 0; -static void reg_combine_min_max(struct bpf_reg_state *true_src, - struct bpf_reg_state *true_dst, - struct bpf_reg_state *false_src, - struct bpf_reg_state *false_dst, - u8 opcode) -{ - switch (opcode) { - case BPF_JEQ: - __reg_combine_min_max(true_src, true_dst); - break; - case BPF_JNE: - __reg_combine_min_max(false_src, false_dst); - break; - } + /* fallthrough (FALSE) branch */ + regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32); + reg_bounds_sync(false_reg1); + reg_bounds_sync(false_reg2); + + /* jump (TRUE) branch */ + regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32); + reg_bounds_sync(true_reg1); + reg_bounds_sync(true_reg2); + + err = reg_bounds_sanity_check(env, true_reg1, "true_reg1"); + err = err ?: reg_bounds_sanity_check(env, true_reg2, "true_reg2"); + err = err ?: reg_bounds_sanity_check(env, false_reg1, "false_reg1"); + err = err ?: reg_bounds_sanity_check(env, false_reg2, "false_reg2"); + return err; } static void mark_ptr_or_null_reg(struct bpf_func_state *state, @@ -14839,6 +14564,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, struct bpf_reg_state *regs = this_branch->frame[this_branch->curframe]->regs; struct bpf_reg_state *dst_reg, *other_branch_regs, *src_reg = NULL; struct bpf_reg_state *eq_branch_regs; + struct bpf_reg_state fake_reg = {}; u8 opcode = BPF_OP(insn->code); bool is_jmp32; int pred = -1; @@ -14879,42 +14605,13 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, verbose(env, "BPF_JMP/JMP32 uses reserved fields\n"); return -EINVAL; } + src_reg = &fake_reg; + src_reg->type = SCALAR_VALUE; + __mark_reg_known(src_reg, insn->imm); } is_jmp32 = BPF_CLASS(insn->code) == BPF_JMP32; - - if (BPF_SRC(insn->code) == BPF_K) { - pred = is_branch_taken(dst_reg, insn->imm, opcode, is_jmp32); - } else if (src_reg->type == SCALAR_VALUE && - is_jmp32 && tnum_is_const(tnum_subreg(src_reg->var_off))) { - pred = is_branch_taken(dst_reg, - tnum_subreg(src_reg->var_off).value, - opcode, - is_jmp32); - } else if (src_reg->type == SCALAR_VALUE && - !is_jmp32 && tnum_is_const(src_reg->var_off)) { - pred = is_branch_taken(dst_reg, - src_reg->var_off.value, - opcode, - is_jmp32); - } else if (dst_reg->type == SCALAR_VALUE && - is_jmp32 && tnum_is_const(tnum_subreg(dst_reg->var_off))) { - pred = is_branch_taken(src_reg, - tnum_subreg(dst_reg->var_off).value, - flip_opcode(opcode), - is_jmp32); - } else if (dst_reg->type == SCALAR_VALUE && - !is_jmp32 && tnum_is_const(dst_reg->var_off)) { - pred = is_branch_taken(src_reg, - dst_reg->var_off.value, - flip_opcode(opcode), - is_jmp32); - } else if (reg_is_pkt_pointer_any(dst_reg) && - reg_is_pkt_pointer_any(src_reg) && - !is_jmp32) { - pred = is_pkt_ptr_branch_taken(dst_reg, src_reg, opcode); - } - + pred = is_branch_taken(dst_reg, src_reg, opcode, is_jmp32); if (pred >= 0) { /* If we get here with a dst_reg pointer type it is because * above is_branch_taken() special cased the 0 comparison. @@ -14962,53 +14659,27 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env, return -EFAULT; other_branch_regs = other_branch->frame[other_branch->curframe]->regs; - /* detect if we are comparing against a constant value so we can adjust - * our min/max values for our dst register. - * this is only legit if both are scalars (or pointers to the same - * object, I suppose, see the PTR_MAYBE_NULL related if block below), - * because otherwise the different base pointers mean the offsets aren't - * comparable. - */ if (BPF_SRC(insn->code) == BPF_X) { - struct bpf_reg_state *src_reg = ®s[insn->src_reg]; - - if (dst_reg->type == SCALAR_VALUE && - src_reg->type == SCALAR_VALUE) { - if (tnum_is_const(src_reg->var_off) || - (is_jmp32 && - tnum_is_const(tnum_subreg(src_reg->var_off)))) - reg_set_min_max(&other_branch_regs[insn->dst_reg], - dst_reg, - src_reg->var_off.value, - tnum_subreg(src_reg->var_off).value, - opcode, is_jmp32); - else if (tnum_is_const(dst_reg->var_off) || - (is_jmp32 && - tnum_is_const(tnum_subreg(dst_reg->var_off)))) - reg_set_min_max_inv(&other_branch_regs[insn->src_reg], - src_reg, - dst_reg->var_off.value, - tnum_subreg(dst_reg->var_off).value, - opcode, is_jmp32); - else if (!is_jmp32 && - (opcode == BPF_JEQ || opcode == BPF_JNE)) - /* Comparing for equality, we can combine knowledge */ - reg_combine_min_max(&other_branch_regs[insn->src_reg], - &other_branch_regs[insn->dst_reg], - src_reg, dst_reg, opcode); - if (src_reg->id && - !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) { - find_equal_scalars(this_branch, src_reg); - find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]); - } - - } - } else if (dst_reg->type == SCALAR_VALUE) { - reg_set_min_max(&other_branch_regs[insn->dst_reg], - dst_reg, insn->imm, (u32)insn->imm, - opcode, is_jmp32); + err = reg_set_min_max(env, + &other_branch_regs[insn->dst_reg], + &other_branch_regs[insn->src_reg], + dst_reg, src_reg, opcode, is_jmp32); + } else /* BPF_SRC(insn->code) == BPF_K */ { + err = reg_set_min_max(env, + &other_branch_regs[insn->dst_reg], + src_reg /* fake one */, + dst_reg, src_reg /* same fake one */, + opcode, is_jmp32); } + if (err) + return err; + if (BPF_SRC(insn->code) == BPF_X && + src_reg->type == SCALAR_VALUE && src_reg->id && + !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) { + find_equal_scalars(this_branch, src_reg); + find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]); + } if (dst_reg->type == SCALAR_VALUE && dst_reg->id && !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id)) { find_equal_scalars(this_branch, dst_reg); @@ -17541,10 +17212,8 @@ static int do_check(struct bpf_verifier_env *env) insn->off, BPF_SIZE(insn->code), BPF_READ, insn->dst_reg, false, BPF_MODE(insn->code) == BPF_MEMSX); - if (err) - return err; - - err = save_aux_ptr_type(env, src_reg_type, true); + err = err ?: save_aux_ptr_type(env, src_reg_type, true); + err = err ?: reg_bounds_sanity_check(env, ®s[insn->dst_reg], "ldx"); if (err) return err; } else if (class == BPF_STX) { @@ -20831,6 +20500,7 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr, bpfptr_t uattr, __u3 if (is_priv) env->test_state_freq = attr->prog_flags & BPF_F_TEST_STATE_FREQ; + env->test_reg_invariants = attr->prog_flags & BPF_F_TEST_REG_INVARIANTS; env->explored_states = kvcalloc(state_htab_size(env), sizeof(struct bpf_verifier_state_list *), |