summaryrefslogtreecommitdiffstats
path: root/net/vmw_vsock/vmci_transport.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/vmci_transport.c')
-rw-r--r--net/vmw_vsock/vmci_transport.c32
1 files changed, 29 insertions, 3 deletions
diff --git a/net/vmw_vsock/vmci_transport.c b/net/vmw_vsock/vmci_transport.c
index b6c8c9cc8d72..86030ecb53dd 100644
--- a/net/vmw_vsock/vmci_transport.c
+++ b/net/vmw_vsock/vmci_transport.c
@@ -57,6 +57,7 @@ static bool vmci_transport_old_proto_override(bool *old_pkt_proto);
static u16 vmci_transport_new_proto_supported_versions(void);
static bool vmci_transport_proto_to_notify_struct(struct sock *sk, u16 *proto,
bool old_pkt_proto);
+static bool vmci_check_transport(struct vsock_sock *vsk);
struct vmci_transport_recv_pkt_info {
struct work_struct work;
@@ -1017,6 +1018,16 @@ static int vmci_transport_recv_listen(struct sock *sk,
vsock_addr_init(&vpending->remote_addr, pkt->dg.src.context,
pkt->src_port);
+ err = vsock_assign_transport(vpending, vsock_sk(sk));
+ /* Transport assigned (looking at remote_addr) must be the same
+ * where we received the request.
+ */
+ if (err || !vmci_check_transport(vpending)) {
+ vmci_transport_send_reset(sk, pkt);
+ sock_put(pending);
+ return err;
+ }
+
/* If the proposed size fits within our min/max, accept it. Otherwise
* propose our own size.
*/
@@ -2008,7 +2019,7 @@ static u32 vmci_transport_get_local_cid(void)
return vmci_get_context_id();
}
-static const struct vsock_transport vmci_transport = {
+static struct vsock_transport vmci_transport = {
.init = vmci_transport_socket_init,
.destruct = vmci_transport_destruct,
.release = vmci_transport_release,
@@ -2038,10 +2049,25 @@ static const struct vsock_transport vmci_transport = {
.get_local_cid = vmci_transport_get_local_cid,
};
+static bool vmci_check_transport(struct vsock_sock *vsk)
+{
+ return vsk->transport == &vmci_transport;
+}
+
static int __init vmci_transport_init(void)
{
+ int features = VSOCK_TRANSPORT_F_DGRAM | VSOCK_TRANSPORT_F_H2G;
+ int cid;
int err;
+ cid = vmci_get_context_id();
+
+ if (cid == VMCI_INVALID_ID)
+ return -EINVAL;
+
+ if (cid != VMCI_HOST_CONTEXT_ID)
+ features |= VSOCK_TRANSPORT_F_G2H;
+
/* Create the datagram handle that we will use to send and receive all
* VSocket control messages for this context.
*/
@@ -2065,7 +2091,7 @@ static int __init vmci_transport_init(void)
goto err_destroy_stream_handle;
}
- err = vsock_core_init(&vmci_transport);
+ err = vsock_core_register(&vmci_transport, features);
if (err < 0)
goto err_unsubscribe;
@@ -2096,7 +2122,7 @@ static void __exit vmci_transport_exit(void)
vmci_transport_qp_resumed_sub_id = VMCI_INVALID_ID;
}
- vsock_core_exit();
+ vsock_core_unregister(&vmci_transport);
}
module_exit(vmci_transport_exit);