网络_文件传输

创建项目

新建项目FileServer

FileHeader

新建头文件file_header.h

需要设计为plain type,即简单结构,不能有虚函数表。
需要设计:

  1. 偏移量
  2. 大小

名字、标号、文件类型这些属性暂不考虑。

1
2
3
4
5
6
7
8
// file_header.h
#pragma once
class FileHeader
{
public:
unsigned long long offset;
unsigned long long size;
};

还需要考虑,问答的机制。

  1. 第一次请求文件时,需要得知文件总大小,
  2. 之后传输时则请求若干个不同的偏移量起始的文件片段。
  3. 传输完毕时,需要FINISH标志来提示结束。

所以,请求的阶段、内容不一样时,就需要用不同类型的文件头来区分。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
#pragma once
enum class HeaderType
{
FILE_SIZE = 0,
SEGMENT,
FINISH
};
class FileHeader
{
public:
unsigned long long offset;
unsigned long long size;
enum class HeaderType type;
};

FileServer

新建源文件file_server.cpp,内容拷贝stream_server_threadpool_coroutine.cpp
为了简洁,暂时不用线程池(仍保留协程)

打开文件

此处的代码位置位于服务端程序打印客户端信息和发送消息之间:

需要使用fstream,构造一个fstream对象。需要传入文件路径、打开方式。
打开方式见:fstream::open
在此例我们使用读取+二进制方式打开文件:std::fstream::in | std::fstream::binary
最后别忘了关闭fs。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
// file_server.cpp
#include <fstream>
agave::IAsyncAction worker_async(sockaddr_in client_addr, SOCKET work_sock)
{
co_await agave::resume_background();
// print client info ...
// ...

// open target file
std::fstream fs(L"d:/test", std::fstream::in | std::fstream::binary);

fs.close();
::closesocket(work_sock);
work_sock = INVALID_SOCKET;
}

读取FileHeader

这个FileHeader是客户端发送给服务端的请求。
服务端需要查看客户端的请求类型。才能做出下一步的响应。
写成一个函数。

  1. 需要两个参数:一个FileHeader的引用,供写入提前定义好的空对象。一个work_sock,将从此sock上收发FileHeader。
  2. 由于是TCP的流式传输,一次可能接收不完。所以每次接收后都需要计算剩余的大小,即remainder = 固定的FileHeader大小 - 已接收的大小,以及记录下一次接收的偏移量offset += 已接收的大小
  3. 如果recv返回值小于等于0则接收结束。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
bool read_file_header(FileHeader& file_header, SOCKET work_sock)
{
long long unsigned remainder = sizeof(file_header);
long long unsigned offset = 0;
while (remainder > 0)
{
auto bytes_received = ::recv(
work_sock,
reinterpret_cast<char*>(&file_header) + offset,
remainder,
0);
// 返回值为0时,代表连接已关闭
if (bytes_received == SOCKET_ERROR || bytes_received == 0)
{
return false;
}
else
{
remainder -= bytes_received;
offset += bytes_received;
}
}
}

在打开文件后,执行读取FileHeader。成功接受FileHeader后,便可以通过switch-case判断HeaderType以进行下一步操作(响应给客户端)。需要循环执行。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// file_server.cpp
#include <fstream>
agave::IAsyncAction worker_async(sockaddr_in client_addr, SOCKET work_sock)
{
co_await agave::resume_background();
// print client info ...
// ...

// open target file
// ...

// check headerType
FileHeader file_header;
while (true)
{
if (!read_file_header(file_header, work_sock))
{
break;
}
switch (file_header.type)
{
case HeaderType::FILE_SIZE:
// ...
break;
case HeaderType::SEGMENT:
// ...
break;
case HeaderType::FINISH:
// ...
break;
default:
// ...
break;
}
}
fs.close();
::closesocket(work_sock);
work_sock = INVALID_SOCKET;
}

获取文件大小后响应

根据fstream给出的方法,可以获取文件大小。
具体地,要使用的是istream中的方法。因为是读取(istream中的i代表in,是以内存为视角的,写入到内存中,即为读取)。
而对于读取,后缀都以g来区分。比如seek函数有seekgseekp(又比如tellgtellp),前者则是istream中的,后者是ostream中的,istream和ostream的缓冲区是两个独立的,所以要进行区分。
函数比较简短,可以考虑设置为inline。

1
2
3
4
5
6
7
inline long long unsigned get_file_size(std::fstream& fs)
{
fs.seekg(0, std::fstream::end);
long long unsigned length = fs.tellg();
fs.seekg(0, std::fstream::beg);
return length;
}

获取完文件大小后,填入要响应的FileHeader的信息,发送给客户端,需要封装一个send_file_header函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// file_server.cpp
// ...
switch (file_header.type)
{
case HeaderType::FILE_SIZE:
{
file_header.type = HeaderType::FILE_SIZE;
file_header.offset = 0;
file_header.size = get_file_size(fs);
send_file_header(file_header, work_sock);
}
break;
case HeaderType::SEGMENT:
// ...
break;
case HeaderType::FINISH:
// ...
break;
default:
// ...
break;
}
}
// ...
}

send_file_header

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
bool send_file_header(FileHeader& file_header, SOCKET work_sock)
{
long long unsigned remainder = sizeof(file_header);
long long unsigned offset = 0;
while (remainder > 0)
{
auto bytes_sent = ::send(
work_sock,
reinterpret_cast<const char*>(&file_header) + offset,
remainder,
0);
if (bytes_sent == SOCKET_ERROR)
{
return false;
}
else
{
remainder -= bytes_sent;
offset += bytes_sent;
}
}
return true;
}

响应片段

首先需要读取片段,再发送。
在此之前,定义一个固定的文件片段大小,在file_header.h中定义为常量,512字节。

1
2
3
4
// file_header.h
// fixed File Segment Size
constexpr unsigned SEGMENT_SIZE{ 512 };
// ...

然后,可以以此大小作为缓冲区大小。每次读取一个片段都把内容放到这个缓冲区中,再发送走。最后不要忘记释放缓冲区。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// file_server.cpp
// ...
// open target file
// ...

// check headerType
FileHeader file_header;
char* buf = new char(SEGMENT_SIZE);
while (true)
{
// ...
}

delete[]buf;
buf = nullptr;
fs.close();
::closesocket(work_sock);
work_sock = INVALID_SOCKET;
}

响应的代码:
read_segment_from_file用于读取一个文件片段,然后写入到一个buf缓冲区中。
send_segment用于把buf缓冲区从SOCKET发送到网络。

read_segmentsend_segment错误时,不能直接break,这回直接退出while循环,可以加一个标志位is_exit,在执行完该case后根据此标志来决定是否退出循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// ...
bool is_exit{ false };
while (true)
{
if (!read_file_header(file_header, work_sock))
{
break;
}
switch (file_header.type)
{
case HeaderType::FILE_SIZE:
{
// ...
}
break;
case HeaderType::SEGMENT:
{
if (!read_segment_from_file(fs, buf, file_header.offset, file_header.size))
{
is_exit = true;
}
if (!send_segment(work_sock, buf, file_header.size))
{
is_exit = true;
}
}
break;
case HeaderType::FINISH:
is_exit = true;
break;
default:
break;
}
if (is_exit)
break;
}
// ...

读取片段

1
2
3
4
5
6
7
8
9
inline bool read_segment_from_file(std::fstream& fs, char * buf, long long unsigned offset, long long unsigned length)
{
fs.seekg(offset, std::fstream::beg);
fs.read(buf, length);
// read不会返回读取的量,需要用gcount来看读取了多少
if (fs.gcount() != length)
return false;
return true;
}

发送片段

send_file_header非常相似,只不过把FileHeader换成了buf,且有一个实际的length参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
bool send_segment(SOCKET work_sock, const char* buf, long long unsigned length)
{
long long unsigned remainder = length;
long long offset = 0;
while (remainder > 0)
{
auto bytes_sent = ::send(
work_sock,
reinterpret_cast<const char*>(buf) + offset,
remainder,
0);
if (bytes_sent == SOCKET_ERROR)
{
return false;
}
else
{
remainder -= bytes_sent;
offset += bytes_sent;
}
}
return true;
}

FileClient

创建项目FileClient,新建源文件file_client.cpp,内容拷贝basic_stream_client.cpp
添加头文件(Add Existing Item)file_header.h
在原先测试收发消息的代码处更换为:发送FileHeader。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#include <WinSock2.h>
#include <iostream>
#include <format>
#include <ws2tcpip.h>
#include "../FileServer/file_header.h"
#pragma comment (lib, "Ws2_32")
int main()
{
WORD wVersionRequested;
WSADATA wsaData;
int err;
wVersionRequested = MAKEWORD(2, 2);

err = ::WSAStartup(wVersionRequested, &wsaData);
if (err != 0)
{
std::wcout << std::format(L"WSAStartup failed with error : {}\n", err);
return 1;
}

SOCKET sock = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (sock == INVALID_SOCKET)
{
err = ::WSAGetLastError();
return 1;
}

sockaddr_in server_addr;
server_addr.sin_family = AF_INET;
if (1 != ::inet_pton(AF_INET, "127.0.0.1", &server_addr.sin_addr))
{
err = ::WSAGetLastError();
return 1;
}
server_addr.sin_port = htons(9008);

if (SOCKET_ERROR == ::connect(sock, reinterpret_cast<const sockaddr*>(&server_addr), sizeof(server_addr)))
{
err = ::WSAGetLastError();
return 1;
}

// adjust as follow:
FileHeader file_header;
file_header.type = HeaderType::FILE_SIZE;
// send FileHeader
// ...

::closesocket(sock);
sock = INVALID_SOCKET;

::WSACleanup();
return 0;
}

file_foundation

此处又会用到了FileServer中写过的send_file_header函数,可想而知后面同样会用到其他写过的函数。
所以我们单独在FileServer项目中再新建一个单独的头文件file_foundation.h,把之前在file_server.cpp单独声明的函数剪切到此处,以便两端都可以引用。
注意也需要剪切这些函数需要的头文件。相应地file_server.cpp可以只引用file_foundation.h

1
2
3
4
5
6
7
8
9
10
// file_foundation.h
#pragma once
#include <fstream>
#include "file_header.h"
#include <WinSock2.h>
bool read_file_header(FileHeader& file_header, SOCKET work_sock);
bool send_file_header(FileHeader& file_header, SOCKET work_sock);
bool send_segment(SOCKET work_sock, const char* buf, long long unsigned length);
long long unsigned get_file_size(std::fstream& fs);
bool read_segment_from_file(std::fstream& fs, char* buf, long long unsigned offset, long long unsigned length);

相应地,需要在FileServer项目中新建file_foundation.cpp,剪切实现部分。
此时编译测试,发现之前在file_server.cpp中声明的inline函数在分开的file_foundation时就不能通过编译了,需要去掉inline。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// file_foundation.cpp
#include "file_foundation.h"
bool read_file_header(FileHeader& file_header, SOCKET work_sock)
{
long long unsigned remainder = sizeof(file_header);
long long unsigned offset = 0;
while (remainder > 0)
{
auto bytes_received = ::recv(
work_sock,
reinterpret_cast<char*>(&file_header) + offset,
remainder,
0);
// 返回值为0时,代表连接已关闭
if (bytes_received == SOCKET_ERROR || bytes_received == 0)
{
return false;
}
else
{
remainder -= bytes_received;
offset += bytes_received;
}
}
}
bool send_file_header(FileHeader& file_header, SOCKET work_sock)
{
long long unsigned remainder = sizeof(file_header);
long long unsigned offset = 0;
while (remainder > 0)
{
auto bytes_sent = ::send(
work_sock,
reinterpret_cast<const char*>(&file_header) + offset,
remainder,
0);
if (bytes_sent == SOCKET_ERROR)
{
return false;
}
else
{
remainder -= bytes_sent;
offset += bytes_sent;
}
}
return true;
}
bool send_segment(SOCKET work_sock, const char* buf, long long unsigned length)
{
long long unsigned remainder = length;
long long offset = 0;
while (remainder > 0)
{
auto bytes_sent = ::send(
work_sock,
buf + offset,
remainder,
0);
if (bytes_sent == SOCKET_ERROR)
{
return false;
}
else
{
remainder -= bytes_sent;
offset += bytes_sent;
}
}
return true;
}
long long unsigned get_file_size(std::fstream& fs)
{
fs.seekg(0, std::fstream::end);
long long unsigned length = fs.tellg();
fs.seekg(0, std::fstream::beg);
return length;
}
bool read_segment_from_file(std::fstream& fs, char* buf, long long unsigned offset, long long unsigned length)
{
fs.seekg(offset, std::fstream::beg);
fs.read(buf, length);
// read不会返回读取的量,需要用gcount来看读取了多少
if (fs.gcount() != length)
return false;
return true;
}

继续编写file_client

在FileClient项目下Add Existing Item:file_foundation.hfile_foundation.cpp
此时,file_client就可以引用file_foundation.h,复用send_file_header等函数:

接下来需要处理的就是:

  1. 发送FileHeader,请求获取文件总大小,读取到FileHeader中
  2. 文件总大小 / 缓冲区大小计算将要下载的片段数,以及计算最后一个片段大小
  3. fstream打开文件(需要指定接收文件到哪个位置),以out、binary、trunc(追加)的方式打开
  4. for循环
    1. 每次都发一个FileHeader,请求获取一个文件片段
    2. 读取Segment,先写入到buf。(此处对应的函数为read_segment,是客户端从网络读取下载。而read_segment_from_file是服务端从本地文件读取片段)
    3. fs.write拷贝buf内容到fs设置好的硬盘位置中。
    4. 循环完毕后按以上步骤特殊处理结尾片段。
  5. 发送FileHeader,表示FINISH。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// ...
FileHeader file_header;
file_header.type = HeaderType::FILE_SIZE;
send_file_header(file_header, sock);
read_file_header(file_header, sock);
long long unsigned count = file_header.size / SEGMENT_SIZE;
long long unsigned last_size = file_header.size - count * SEGMENT_SIZE;
// open a file to write by downloading
char* buf = new char[SEGMENT_SIZE];
std::fstream fs(L"./testRecv", std::fstream::out | std::fstream::binary | std::fstream::trunc);

if (!fs)
{
goto __Target;
}
for (long long unsigned i = 0; i < count; ++i)
{
// request segment to server
file_header.type = HeaderType::SEGMENT;
file_header.offset = i * SEGMENT_SIZE;
file_header.size = SEGMENT_SIZE;
if (!send_file_header(file_header, sock))
{
break;
}
// download from server
if (!read_segment(sock, buf, SEGMENT_SIZE))
{
break;
}
// write to HardDisk
fs.write(buf, SEGMENT_SIZE);
if (!fs)
{
break;
}
}
if (last_size > 0)
{
// request segment to server
file_header.type = HeaderType::SEGMENT;
file_header.offset = count * SEGMENT_SIZE;
file_header.size = last_size;
if (send_file_header(file_header, sock))
{
// download from server
if (read_segment(sock, buf, last_size))
{
// write to HardDisk
fs.write(buf, last_size);
}
}
}
file_header.type = HeaderType::FINISH;
send_file_header(file_header, sock);

__Target:
fs.close();
delete[]buf;
::closesocket(sock);
sock = INVALID_SOCKET;

::WSACleanup();
return 0;
}

测试

启动项目设置为FileServer,直接运行(点击Local Windows Debugger)。
右键FileClient,选择Debug,Start New Instance,即可运行客户端。
发现,在缓冲区大小为512字节时(file_header.hSEGMENT_SIZE),传输时间稍长,可以设置为512 * 1024即512KB,传输速度即可翻倍。

线程池_bubo

创建项目

在已有或新建的解决方案里,新建项目,取名“ThreadPool”。
项目配置标准为C++20标准。

ITask

新建头文件ITask.h

1
2
3
4
5
6
7
8
namespace thpool
{
class ITask
{
public:
virtual void run_task(void) = 0;
};
}

基于信号量的ThreadPool

新建头文件“ThreadPool.h”

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// ThreadPool.h
#include <thread>
#include <mutex>
#include <semaphore>
#include <list>
namespace thpool
{
class ITask;
class ThreadPool
{
public:
ThreadPool(int max_num);
~ThreadPool();
void add_task(std::shared_ptr<ITask> task);
void add_task(std::list<std::shared_ptr<ITask>> task_lst);
private:
// capacity
int _max_num;
// current running task threads
int _alive_num;
std::counting_semaphore<100> _semaphore;
std::mutex _access_mx;
bool _is_exit{ false };
std::list<std::shared_ptr<ITask>> _task_queue;
};
}

实现:

  1. 初始化_max_num,规定线程池最大线程数
  2. 初始化_alive_num,线程池刚创建时,正在执行任务的线程为0。
  3. 初始化_semaphore为0。

共创建_max_num个线程,每个线程基于信号量获取任务,如果成功则_alive_num加1,如果此时_is_exit标志为真则意味着线程池即将析构,_alive_num减1,并退出循环。
如果_is_exit标志不为真,则尝试从任务队列中提取任务,需要用互斥锁同步,在外层先简单判断队列大小是否大于0,然后再用锁去再次获取真实值(这样外层先判断,内层再加锁判断,可以加大条件为真的概率,避免锁太急切地加,降低性能),再去执行提取出的任务。循环,直到任务队列大小为0,退出while循环后_alive_num减1。
退出大的while时(_is_exit标志为真时),_max_num减1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
// ThreadPool.cpp
#include "ThreadPool.h"
#include "ITask.h"

thpool::ThreadPool::ThreadPool(int max_num)
: _max_num{ max_num }, _alive_num{ 0 }, _semaphore{ 0 }
{
for (int i = 0; i < _max_num; ++i)
{
std::jthread th([this](void)
{
while (true)
{
_semaphore.acquire();
++_alive_num;
if (_is_exit)
{
--_alive_num;
break;
}

while (_task_queue.size() > 0)
{
std::unique_lock lck{ _access_mx };
if (_task_queue.size() > 0)
{
auto task = _task_queue.front();
_task_queue.pop_front();
lck.unlock();
task->run_task();
}
}
--_alive_num;
}
--_max_num;
});
th.detach();
}
}

void thpool::ThreadPool::add_task(std::shared_ptr<ITask> task)
{
if (!task)
return;
std::unique_lock lck{ _access_mx };
_task_queue.push_back(task);
lck.unlock();
// 判断是否有空余的线程,如果有则唤醒一个
if (_max_num - _alive_num > 0)
{
_semaphore.release(1);
}
}

但是以上程序存在一些问题:

  1. _semaphore.acquire()++_alive_num不是同步的。
  2. _is_exit的判断和--_alive_num也不是同步的。
  3. 在每个线程中都进行while循环反复判断任务队列是否为空,可能会导致一个线程一直独占_access_mx互斥量。
  4. 在执行完任务后的--_alive_num步骤,是和执行结束是不同步的,有可能_max_num个线程同时卡在这一步,导致add_task函数中的_max_num - _alive_num > 0的条件不成立,导致_semaphore.release(1)不会执行。这将导致系统实际增加了任务,却没有增加任务的信号量。有小概率死锁(死锁在_semaphore.acquire())。
  5. 综上所述,最好把_semaphore_alive_num融为一体,不要割裂两者。(也就是说,能不能让信号量的量和_alive_num无缝保持一致)
    1. 操作系统的具体实现,如Windows、Linux是可以随时获取信号量的量的,但是我们现在使用的是跨平台C++信号量,没有提供获取量大小的方法,因此必须有一个_alive_num记录。
    2. 可以在_semaphore.acquire()++_alive_num这两个动作整体加锁吗?不可以,因为_semaphore本身就是会阻塞的东西,如果加了锁后,_semaphore也阻塞了,那么锁就不能解开了。
    3. 上面提到的,加了锁后,里面的东西阻塞了,想要把锁解开,有一样东西可以实现:条件变量。但是,条件变量没有像信号量记录数目的功能(要么是notify_one,要么是notify_all),因此不行。
  6. 最简单的彻底解决4死锁的问题的方法是,add_task方法中不再判断_max_num - _alive_num > 0的条件,即无论如何,在添加任务时,都要_semaphore.release(1)。这样做的副作用就是要把std::counting_semaphore _semaphore在声明时,定义其为一个最大值为无限大(最大数)的信号量,可以用std::counting_semaphore<> _semaphore表示。

修改后的版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
// ThreadPool.h
#include <thread>
#include <mutex>
#include <semaphore>
#include <list>
namespace thpool
{
class ITask;
class ThreadPool
{
public:
ThreadPool(int max_num);
~ThreadPool();
void add_task(std::shared_ptr<ITask> task);
void add_task(std::list<std::shared_ptr<ITask>> task_lst);
private:
// capacity
int _max_num;
// current running task threads
int _alive_num;
std::counting_semaphore<> _semaphore;
std::mutex _access_mx;
bool _is_exit{ false };
std::list<std::shared_ptr<ITask>> _task_queue;
};
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
// ThreadPool.cpp
#include "ThreadPool.h"
#include "ITask.h"

thpool::ThreadPool::ThreadPool(int max_num)
: _max_num{ max_num }, _alive_num{ 0 }, _semaphore{ 0 }
{
for (int i = 0; i < _max_num; ++i)
{
std::jthread th([this](void)
{
while (true)
{
_semaphore.acquire();
++_alive_num;
if (_is_exit)
{
--_alive_num;
break;
}

std::unique_lock lck{ _access_mx };
if (_task_queue.size() > 0)
{
auto task = _task_queue.front();
_task_queue.pop_front();
lck.unlock();
task->run_task();
}
--_alive_num;
}
--_max_num;
});
th.detach();
}
}

void thpool::ThreadPool::add_task(std::shared_ptr<ITask> task)
{
if (!task)
return;
std::unique_lock lck{ _access_mx };
_task_queue.push_back(task);
lck.unlock();

_semaphore.release(1);
}

线程池的析构——latch的运用

latch是C++20引入的标准。
实际上是对操作系统同步量的操作的封装,比如在Windows下,latch就是对事件的封装,或者是对WaitFor Single/Mutiple Object的封装。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
// ThreadPool.h
#include <thread>
#include <mutex>
#include <semaphore>
#include <list>
#include <latch>
namespace thpool
{
class ITask;
class ThreadPool
{
public:
// ...
private:
// ...
std::latch _latch;
// ...
};
}

在ThreadPool构造时初始化_latchmax_num,表示需要等待max_num个线程的结束,latch才放行。
同时,要在退出大的while循环时,_max_num减1之后,进行_latch.count_down(),参数默认为1,意为对latch值减1。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// ThreadPool.cpp
#include "ThreadPool.h"
#include "ITask.h"

thpool::ThreadPool::ThreadPool(int max_num)
: _max_num{ max_num }, _alive_num{ 0 }, _semaphore{ 0 }, _latch{ max_num }
{
for (int i = 0; i < _max_num; ++i)
{
std::jthread th([this](void)
{
while (true)
{
_semaphore.acquire();
++_alive_num;
if (_is_exit)
{
--_alive_num;
break;
}

// ...
}
--_max_num;
_latch.count_down();
});
th.detach();
}
}

析构函数,则可以利用latch,让其在析构函数设置_is_exit标志为true且释放_max_num信号量后,wait直到latch值为0时,代表所有线程都结束了,就可以返回了。

1
2
3
4
5
6
7
8
// ThreadPool.cpp
thpool::ThreadPool::~ThreadPool()
{
_is_exit = true;
_semaphore.release(_max_num);
// 等待线程结束,对应latch其值为0时
_latch.wait();
}

测试

新建main_entry.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// main_entry.cpp
#include "ThreadPool.h"
#include "ITask.h"
#include <iostream>
using namespace std::chrono_literals;
class Task : public thpool::ITask
{
public:
virtual void run_task(void) override
{
std::wcout << L"task" << std::endl;
}
};
int main(void)
{
thpool::ThreadPool thread_pool{ 10 };
thread_pool.add_task(std::shared_ptr<thpool::ITask>(new Task));
thread_pool.add_task(std::shared_ptr<thpool::ITask>(new Task));
std::this_thread::sleep_for(10s);
return 0;
}

测试结果:

两个task黏在一起,说明多线程输出的。

在Debug-Windows-Threads中,可以看到目前程序中的线程状况:

可以看到有10个子线程在其中: