
"会用了工具不代表会写并发——就像有了锤子不代表会盖房子"
前面三篇我们学了:
现在你可能想:"太好了,我可以写并发程序了!"
等等,先别急。
给你一堆砖头、水泥、钢筋,你就能盖房子吗?显然不能。你还需要设计图纸和施工规范。
并发编程也是一样。知道工具怎么用只是第一步,更重要的是知道在什么场景下用什么工具,以及如何避免常见的并发陷阱。
今天我们就来学习并发编程的经典模式和最佳实践:线程池、死锁预防、并发设计模式,以及一些"血泪教训"总结出来的经验法则。
并发模式是解决常见并发问题的可复用方案。就像设计模式一样,它们是经过验证的最佳实践。
常见的并发模式:
想象你要处理 1000 个任务:
方案 1:每个任务创建一个线程
for task in tasks {
thread::spawn(|| process(task)); // ❌ 糟糕!
}
问题:
方案 2:线程池
let pool = ThreadPool::new(); // 只创建 4 个线程
for task in tasks {
pool.execute(|| process(task)); // 复用线程
}
优点:
死锁发生的四个条件(缺一不可):
预防死锁 = 破坏至少一个条件
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
// 任务类型
type Job = Box<dyn FnOnce() + Send + 'static>;
// 线程池
pub struct ThreadPool {
workers: Vec<Worker>,
sender: Option<mpsc::Sender<Job>>,
}
// 工作线程
struct Worker {
id: usize,
thread: Option<thread::JoinHandle<()>>,
}
impl ThreadPool {
/// 创建新的线程池
pub fn new(size: usize) -> ThreadPool {
assert!(size > );
let (sender, receiver) = mpsc::channel();
let receiver = Arc::new(Mutex::new(receiver));
let mut workers = Vec::with_capacity(size);
for id in ..size {
workers.push(Worker::new(id, Arc::clone(&receiver)));
}
ThreadPool {
workers,
sender: Some(sender),
}
}
/// 执行任务
pub fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(f);
self.sender.as_ref().unwrap().send(job).unwrap();
}
}
impl Worker {
fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
let thread = thread::spawn(move || loop {
// 获取任务
let job = receiver.lock().unwrap().recv();
match job {
Ok(job) => {
println!("工人 {} 收到任务", id);
job();
}
Err(_) => {
println!("工人 {} 断开连接,退出", id);
break;
}
}
});
Worker {
id,
thread: Some(thread),
}
}
}
impl Drop for ThreadPool {
fn drop(&mut self) {
println!("发送关闭信号给所有工人");
drop(self.sender.take()); // 关闭 channel
println!("等待所有工人完成");
for worker in &mut self.workers {
println!("关闭工人 {}", worker.id);
if let Some(thread) = worker.thread.take() {
thread.join().unwrap();
}
}
}
}
fn main() {
let pool = ThreadPool::new();
for i in .. {
pool.execute(move || {
println!("任务 {} 执行中", i);
std::thread::sleep(std::time::Duration::from_millis());
println!("任务 {} 完成", i);
});
}
// 等待任务完成
std::thread::sleep(std::time::Duration::from_secs());
// pool 离开作用域,自动清理
}
use std::sync::{Arc, Mutex};
use std::thread;
// ❌ 错误示例:可能导致死锁
fn deadlock_example() {
let mutex1 = Arc::new(Mutex::new());
let mutex2 = Arc::new(Mutex::new());
let m1 = Arc::clone(&mutex1);
let m2 = Arc::clone(&mutex2);
let t1 = thread::spawn(move || {
let _lock1 = m1.lock().unwrap();
thread::sleep(std::time::Duration::from_millis());
let _lock2 = m2.lock().unwrap(); // 可能死锁
});
let t2 = thread::spawn(move || {
let _lock2 = m2.lock().unwrap();
thread::sleep(std::time::Duration::from_millis());
let _lock1 = m1.lock().unwrap(); // 可能死锁
});
t1.join().unwrap();
t2.join().unwrap();
}
// ✅ 正确示例:固定锁顺序
fn no_deadlock() {
let mutex1 = Arc::new(Mutex::new());
let mutex2 = Arc::new(Mutex::new());
let m1 = Arc::clone(&mutex1);
let m2 = Arc::clone(&mutex2);
// 两个线程都先锁 mutex1,再锁 mutex2
let t1 = thread::spawn(move || {
let _lock1 = m1.lock().unwrap();
thread::sleep(std::time::Duration::from_millis());
let _lock2 = m2.lock().unwrap();
});
let t2 = thread::spawn(move || {
let _lock1 = mutex1.lock().unwrap(); // 先锁同一个
thread::sleep(std::time::Duration::from_millis());
let _lock2 = mutex2.lock().unwrap(); // 再锁另一个
});
t1.join().unwrap();
t2.join().unwrap();
}
// ✅ 正确示例:使用 try_lock 避免无限等待
fn try_lock_example() {
let mutex1 = Arc::new(Mutex::new());
let mutex2 = Arc::new(Mutex::new());
let m1 = Arc::clone(&mutex1);
let m2 = Arc::clone(&mutex2);
let t1 = thread::spawn(move || {
if let Ok(_lock1) = m1.try_lock() {
thread::sleep(std::time::Duration::from_millis());
if let Ok(_lock2) = m2.try_lock() {
// 成功获取两个锁
}
}
});
t1.join().unwrap();
}
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
use std::time::Duration;
fn producer_consumer_example() {
let (tx, rx) = mpsc::channel();
let rx = Arc::new(Mutex::new(rx));
// 多个生产者
let mut producers = vec![];
for i in .. {
let tx_clone = tx.clone();
let handle = thread::spawn(move || {
for j in .. {
tx_clone.send(format!("生产者 {} - 消息 {}", i, j)).unwrap();
thread::sleep(Duration::from_millis());
}
});
producers.push(handle);
}
drop(tx); // 关闭原始发送端
// 多个消费者
let mut consumers = vec![];
for i in .. {
let rx_clone = Arc::clone(&rx);
let handle = thread::spawn(move || {
loop {
let msg = {
let rx = rx_clone.lock().unwrap();
rx.recv()
};
match msg {
Ok(m) => println!("消费者 {} 收到:{}", i, m),
Err(_) => break, // channel 关闭
}
}
});
consumers.push(handle);
}
for handle in producers {
handle.join().unwrap();
}
for handle in consumers {
handle.join().unwrap();
}
}
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
// 每个线程有自己的任务队列
struct WorkStealingPool {
queues: Vec<Arc<Mutex<VecDeque<i32>>>>,
}
impl WorkStealingPool {
fn new(num_threads: usize) -> Self {
let queues = (..num_threads)
.map(|_| Arc::new(Mutex::new(VecDeque::new())))
.collect();
WorkStealingPool { queues }
}
fn add_task(&self, thread_id: usize, task: i32) {
let queue = &self.queues[thread_id];
queue.lock().unwrap().push_back(task);
}
// 从自己的队列获取任务,如果没有就从别人那里"偷"
fn steal_task(&self, thread_id: usize) -> Option<i32> {
// 先尝试从自己的队列获取
{
let mut queue = self.queues[thread_id].lock().unwrap();
if let Some(task) = queue.pop_front() {
return Some(task);
}
}
// 自己的队列为空,从别人那里偷
for (i, queue) in self.queues.iter().enumerate() {
if i != thread_id {
let mut q = queue.lock().unwrap();
if let Some(task) = q.pop_front() {
println!("线程 {} 从线程 {} 偷到任务", thread_id, i);
return Some(task);
}
}
}
None
}
}
use std::sync::Mutex;
struct BadCounter {
value: Mutex<i32>,
}
impl BadCounter {
fn increment(&self) {
// ❌ 锁的粒度过大
let mut value = self.value.lock().unwrap();
// 下面这些操作其实不需要锁
println!("准备增加");
std::thread::sleep(std::time::Duration::from_millis());
*value += ;
println!("增加完成");
// 锁持有时间过长
}
}
// ✅ 更好的做法
struct GoodCounter {
value: Mutex<i32>,
}
impl GoodCounter {
fn increment(&self) {
println!("准备增加");
std::thread::sleep(std::time::Duration::from_millis());
// 只在必要时持有锁
{
let mut value = self.value.lock().unwrap();
*value += ;
}
println!("增加完成");
}
}
use std::sync::mpsc;
use std::thread;
fn main() {
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
tx.send("你好").unwrap();
// tx drop,channel 关闭
});
// ❌ 无限循环
// loop {
// let msg = rx.recv().unwrap(); // channel 关闭后会 panic
// println!("{}", msg);
// }
// ✅ 正确处理
for msg in rx {
println!("{}", msg);
}
// 或者
// loop {
// match rx.recv() {
// Ok(msg) => println!("{}", msg),
// Err(_) => break, // channel 关闭
// }
// }
}
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
struct ThreadPool {
workers: Vec<Worker>,
sender: mpsc::Sender<Box<dyn FnOnce() + Send + 'static>>,
}
struct Worker {
thread: thread::JoinHandle<()>,
}
impl ThreadPool {
fn new(size: usize) -> Self {
let (tx, rx) = mpsc::channel();
let rx = Arc::new(Mutex::new(rx));
let workers = (..size)
.map(|id| {
let rx = Arc::clone(&rx);
let thread = thread::spawn(move || loop {
let job = rx.lock().unwrap().recv();
match job {
Ok(job) => {
// ❌ 如果 job() panic,整个线程就挂了
job();
}
Err(_) => break,
}
});
Worker { thread }
})
.collect();
ThreadPool { workers, sender: tx }
}
fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
self.sender.send(Box::new(f)).unwrap();
}
}
// ✅ 更好的做法:捕获 panic
fn better_worker(rx: Arc<Mutex<mpsc::Receiver<Box<dyn FnOnce() + Send + 'static>>>>) {
loop {
let job = rx.lock().unwrap().recv();
match job {
Ok(job) => {
// 用 catch_unwind 捕获 panic
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
job();
}));
if let Err(e) = result {
eprintln!("任务 panic 了:{:?}", e);
// 线程继续运行,处理下一个任务
}
}
Err(_) => break,
}
}
}
use std::collections::HashSet;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
// 简化的爬虫示例
struct WebCrawler {
visited: Arc<Mutex<HashSet<String>>>,
to_visit: Arc<Mutex<Vec<String>>>,
}
impl WebCrawler {
fn new() -> Self {
WebCrawler {
visited: Arc::new(Mutex::new(HashSet::new())),
to_visit: Arc::new(Mutex::new(vec!["https://example.com".to_string()])),
}
}
fn crawl(&self, num_threads: usize) {
let mut handles = vec![];
for i in ..num_threads {
let visited = Arc::clone(&self.visited);
let to_visit = Arc::clone(&self.to_visit);
let handle = thread::spawn(move || {
loop {
// 获取下一个 URL
let url = {
let mut queue = to_visit.lock().unwrap();
queue.pop()
};
match url {
Some(url) => {
// 检查是否访问过
let mut visited_set = visited.lock().unwrap();
if visited_set.contains(&url) {
continue;
}
visited_set.insert(url.clone());
// 模拟爬取
println!("线程 {} 爬取:{}", i, url);
thread::sleep(Duration::from_millis());
// 模拟发现新链接
let new_urls = vec![
format!("{}/page1", url),
format!("{}/page2", url),
];
let mut queue = to_visit.lock().unwrap();
for new_url in new_urls {
queue.push(new_url);
}
}
None => break, // 没有更多 URL
}
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
}
use std::sync::mpsc;
use std::thread;
// 数据处理管道:读取 -> 转换 -> 过滤 -> 输出
fn data_pipeline() {
let (tx1, rx1) = mpsc::channel();
let (tx2, rx2) = mpsc::channel();
let (tx3, rx3) = mpsc::channel();
// 阶段 1:读取数据
let reader = thread::spawn(move || {
for i in ..= {
tx1.send(i).unwrap();
}
});
// 阶段 2:转换数据(乘以 2)
let transformer = thread::spawn(move || {
for num in rx1 {
tx2.send(num * ).unwrap();
}
});
// 阶段 3:过滤数据(只保留大于 10 的)
let filter = thread::spawn(move || {
for num in rx2 {
if num > {
tx3.send(num).unwrap();
}
}
});
// 阶段 4:输出结果
let output = thread::spawn(move || {
for num in rx3 {
println!("输出:{}", num);
}
});
reader.join().unwrap();
transformer.join().unwrap();
filter.join().unwrap();
output.join().unwrap();
}

try_lock、缩小锁范围。catch_unwind 捕获,避免整个线程池挂掉。下篇预告: 线程并发学完了,但还有一种更高效的并发方式——异步编程!不用创建线程就能同时处理成千上万个任务,这是什么魔法?下篇我们学习 async/await、Future、Pin、Waker,揭开异步编程的神秘面纱!