summaryrefslogtreecommitdiffstats
path: root/net/vmw_vsock/hyperv_transport.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/vmw_vsock/hyperv_transport.c')
-rw-r--r--net/vmw_vsock/hyperv_transport.c94
1 files changed, 39 insertions, 55 deletions
diff --git a/net/vmw_vsock/hyperv_transport.c b/net/vmw_vsock/hyperv_transport.c
index c443db7af8d4..3c7d07a99fc5 100644
--- a/net/vmw_vsock/hyperv_transport.c
+++ b/net/vmw_vsock/hyperv_transport.c
@@ -13,15 +13,16 @@
#include <linux/hyperv.h>
#include <net/sock.h>
#include <net/af_vsock.h>
+#include <asm/hyperv-tlfs.h>
/* Older (VMBUS version 'VERSION_WIN10' or before) Windows hosts have some
- * stricter requirements on the hv_sock ring buffer size of six 4K pages. Newer
- * hosts don't have this limitation; but, keep the defaults the same for compat.
+ * stricter requirements on the hv_sock ring buffer size of six 4K pages.
+ * hyperv-tlfs defines HV_HYP_PAGE_SIZE as 4K. Newer hosts don't have this
+ * limitation; but, keep the defaults the same for compat.
*/
-#define PAGE_SIZE_4K 4096
-#define RINGBUFFER_HVS_RCV_SIZE (PAGE_SIZE_4K * 6)
-#define RINGBUFFER_HVS_SND_SIZE (PAGE_SIZE_4K * 6)
-#define RINGBUFFER_HVS_MAX_SIZE (PAGE_SIZE_4K * 64)
+#define RINGBUFFER_HVS_RCV_SIZE (HV_HYP_PAGE_SIZE * 6)
+#define RINGBUFFER_HVS_SND_SIZE (HV_HYP_PAGE_SIZE * 6)
+#define RINGBUFFER_HVS_MAX_SIZE (HV_HYP_PAGE_SIZE * 64)
/* The MTU is 16KB per the host side's design */
#define HVS_MTU_SIZE (1024 * 16)
@@ -54,7 +55,8 @@ struct hvs_recv_buf {
* ringbuffer APIs that allow us to directly copy data from userspace buffer
* to VMBus ringbuffer.
*/
-#define HVS_SEND_BUF_SIZE (PAGE_SIZE_4K - sizeof(struct vmpipe_proto_header))
+#define HVS_SEND_BUF_SIZE \
+ (HV_HYP_PAGE_SIZE - sizeof(struct vmpipe_proto_header))
struct hvs_send_buf {
/* The header before the payload data */
@@ -163,6 +165,8 @@ static const guid_t srv_id_template =
GUID_INIT(0x00000000, 0xfacb, 0x11e6, 0xbd, 0x58,
0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3);
+static bool hvs_check_transport(struct vsock_sock *vsk);
+
static bool is_valid_srv_id(const guid_t *id)
{
return !memcmp(&id->b[4], &srv_id_template.b[4], sizeof(guid_t) - 4);
@@ -186,7 +190,8 @@ static void hvs_remote_addr_init(struct sockaddr_vm *remote,
static u32 host_ephemeral_port = MIN_HOST_EPHEMERAL_PORT;
struct sock *sk;
- vsock_addr_init(remote, VMADDR_CID_ANY, VMADDR_PORT_ANY);
+ /* Remote peer is always the host */
+ vsock_addr_init(remote, VMADDR_CID_HOST, VMADDR_PORT_ANY);
while (1) {
/* Wrap around ? */
@@ -358,13 +363,24 @@ static void hvs_open_connection(struct vmbus_channel *chan)
if (sk->sk_ack_backlog >= sk->sk_max_ack_backlog)
goto out;
- new = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
- sk->sk_type, 0);
+ new = vsock_create_connected(sk);
if (!new)
goto out;
new->sk_state = TCP_SYN_SENT;
vnew = vsock_sk(new);
+
+ hvs_addr_init(&vnew->local_addr, if_type);
+ hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);
+
+ ret = vsock_assign_transport(vnew, vsock_sk(sk));
+ /* Transport assigned (looking at remote_addr) must be the
+ * same where we received the request.
+ */
+ if (ret || !hvs_check_transport(vnew)) {
+ sock_put(new);
+ goto out;
+ }
hvs_new = vnew->trans;
hvs_new->chan = chan;
} else {
@@ -393,10 +409,10 @@ static void hvs_open_connection(struct vmbus_channel *chan)
} else {
sndbuf = max_t(int, sk->sk_sndbuf, RINGBUFFER_HVS_SND_SIZE);
sndbuf = min_t(int, sndbuf, RINGBUFFER_HVS_MAX_SIZE);
- sndbuf = ALIGN(sndbuf, PAGE_SIZE);
+ sndbuf = ALIGN(sndbuf, HV_HYP_PAGE_SIZE);
rcvbuf = max_t(int, sk->sk_rcvbuf, RINGBUFFER_HVS_RCV_SIZE);
rcvbuf = min_t(int, rcvbuf, RINGBUFFER_HVS_MAX_SIZE);
- rcvbuf = ALIGN(rcvbuf, PAGE_SIZE);
+ rcvbuf = ALIGN(rcvbuf, HV_HYP_PAGE_SIZE);
}
ret = vmbus_open(chan, sndbuf, rcvbuf, NULL, 0, hvs_channel_cb,
@@ -426,10 +442,7 @@ static void hvs_open_connection(struct vmbus_channel *chan)
if (conn_from_host) {
new->sk_state = TCP_ESTABLISHED;
- sk->sk_ack_backlog++;
-
- hvs_addr_init(&vnew->local_addr, if_type);
- hvs_remote_addr_init(&vnew->remote_addr, &vnew->local_addr);
+ sk_acceptq_added(sk);
hvs_new->vm_srv_id = *if_type;
hvs_new->host_srv_id = *if_instance;
@@ -670,7 +683,7 @@ static ssize_t hvs_stream_enqueue(struct vsock_sock *vsk, struct msghdr *msg,
ssize_t ret = 0;
ssize_t bytes_written = 0;
- BUILD_BUG_ON(sizeof(*send_buf) != PAGE_SIZE_4K);
+ BUILD_BUG_ON(sizeof(*send_buf) != HV_HYP_PAGE_SIZE);
send_buf = kmalloc(sizeof(*send_buf), GFP_KERNEL);
if (!send_buf)
@@ -843,37 +856,9 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
return 0;
}
-static void hvs_set_buffer_size(struct vsock_sock *vsk, u64 val)
-{
- /* Ignored. */
-}
-
-static void hvs_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
-{
- /* Ignored. */
-}
-
-static void hvs_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
-{
- /* Ignored. */
-}
-
-static u64 hvs_get_buffer_size(struct vsock_sock *vsk)
-{
- return -ENOPROTOOPT;
-}
-
-static u64 hvs_get_min_buffer_size(struct vsock_sock *vsk)
-{
- return -ENOPROTOOPT;
-}
-
-static u64 hvs_get_max_buffer_size(struct vsock_sock *vsk)
-{
- return -ENOPROTOOPT;
-}
-
static struct vsock_transport hvs_transport = {
+ .module = THIS_MODULE,
+
.get_local_cid = hvs_get_local_cid,
.init = hvs_sock_init,
@@ -906,14 +891,13 @@ static struct vsock_transport hvs_transport = {
.notify_send_pre_enqueue = hvs_notify_send_pre_enqueue,
.notify_send_post_enqueue = hvs_notify_send_post_enqueue,
- .set_buffer_size = hvs_set_buffer_size,
- .set_min_buffer_size = hvs_set_min_buffer_size,
- .set_max_buffer_size = hvs_set_max_buffer_size,
- .get_buffer_size = hvs_get_buffer_size,
- .get_min_buffer_size = hvs_get_min_buffer_size,
- .get_max_buffer_size = hvs_get_max_buffer_size,
};
+static bool hvs_check_transport(struct vsock_sock *vsk)
+{
+ return vsk->transport == &hvs_transport;
+}
+
static int hvs_probe(struct hv_device *hdev,
const struct hv_vmbus_device_id *dev_id)
{
@@ -962,7 +946,7 @@ static int __init hvs_init(void)
if (ret != 0)
return ret;
- ret = vsock_core_init(&hvs_transport);
+ ret = vsock_core_register(&hvs_transport, VSOCK_TRANSPORT_F_G2H);
if (ret) {
vmbus_driver_unregister(&hvs_drv);
return ret;
@@ -973,7 +957,7 @@ static int __init hvs_init(void)
static void __exit hvs_exit(void)
{
- vsock_core_exit();
+ vsock_core_unregister(&hvs_transport);
vmbus_driver_unregister(&hvs_drv);
}