重学Java并发 - 手写并发工具类

重学Java并发 - 手写并发工具类,第1张

文章目录
      • 手写线程池
      • 手写阻塞队列

最近再重学Java多线程的内容,Java中线程的同步基本是靠两种方式,一是Object自带的Monitor机制,通过Object上的wait/notify实现的等待/通知模式;二是JUC并发包下的Lock系列API,底层是通过LockSupport的park/unpark来实现的等待通知(park和unpark是native方法,由JVM实现,在Linux下是借助pthread_cond_wait和pthread_cond_signal实现)

手写线程池

使用Object的wait/notify机制,手写一个简易的线程池,进行练习,以加深对并发的理解(包括线程中断机制等)

定义一个线程池接口

package experiment;

/**
 * @Author yogurtzzz
 * @Date 2022/4/20 15:08
 **/
public interface ThreadPool {

	// 提交一个任务到线程池
	void execute(Runnable runnable);

	// 增加线程池中的线程
	void addWorker(int n);

	// 移除线程池中的线程
	void removeWorker(int n);

	// 关闭线程池
	void shutdown();

	// 打印线程池状态
	void printStatus();
}

编写实现类

package experiment;

import java.util.HashSet;
import java.util.LinkedList;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * @Author yogurtzzz
 * @Date 2022/4/20 15:09
 **/
public class DefaultThreadPool implements ThreadPool {

	enum State {
		INITILIAZING, RUNNING, TERMINATED
	}

	private final LinkedList<Runnable> jobList = new LinkedList<>();

	private final Set<Worker> workerSet = new HashSet<>(); // 总的 Worker

	private final Set<Worker> busyWorkers = new HashSet<>();

	private final Set<Worker> idleWorkers = new HashSet<>();

	private final AtomicLong completedJobCnt = new AtomicLong();

	private final AtomicLong processingJobCnt = new AtomicLong();

	private State state;

	private static final int MAX_SIZE = 100;

	private static final int MIN_SIZE = 1;

	private static final AtomicInteger threadNum = new AtomicInteger();

	public DefaultThreadPool(int initialSize) {
		state = State.INITILIAZING;
		int n = initialSize < MIN_SIZE ? MIN_SIZE : initialSize > MAX_SIZE ? MAX_SIZE : initialSize;
		addWorkers(n);
		state = State.RUNNING;
	}

	private void addWorkers(int n) {
		if (n <= 0) return;
		synchronized (workerSet) {
			int size = workerSet.size();
			int remain = MAX_SIZE - size; // 最多还能添加多少个Worker
			if (n > remain) n = remain;
			for (int i = 0; i < n; i++) {
				Worker worker = new Worker();
				Thread t = new Thread(worker, "Worker-" + threadNum.getAndIncrement());
				worker.t = t;
				t.start();
				workerSet.add(worker);
			}
		}
	}

	private void removeWorkers(int n) {
		if (n <= 0) return;
		synchronized (workerSet) {
			int size = workerSet.size();
			int remain = size - n; // 减掉后还剩多少个
			if (remain < MIN_SIZE) n = size - 1; // 保证最少还剩一个线程
			// 减掉n个线程
			int i = 0;
			for (Worker w : idleWorkers) {
				if (i >= n) break; // 已经减完
				w.shutdown();
				i++; // 已经减掉了一个
			}
			if (i < n) {
				for (Worker w : busyWorkers) {
					if (i >= n) break;
					w.shutdown();
					i++;
				}
			}
		}
	}

	@Override
	public void execute(Runnable runnable) {
		synchronized (jobList) {
			jobList.add(runnable);
			jobList.notify();
		}
	}

	@Override
	public void addWorker(int n) {
		addWorkers(n);
	}

	@Override
	public void removeWorker(int n) {
		removeWorkers(n);
	}

	@Override
	public void shutdown() {
		synchronized (workerSet) {
			for (Worker w : workerSet) w.shutdown();
		}
		state = State.TERMINATED;
	}

	@Override
	public void printStatus() {
		System.out.printf("Pool State: %s, WorkerCnt: %d, BusyWorker: %d, IdleWorker: %d, CompletedTask: %d, ProcessingTask: %d, WaitingTask: %d\n",
				state, workerSet.size(), busyWorkers.size(), idleWorkers.size(), completedJobCnt.get(), processingJobCnt.get(), jobList.size());
	}

	private class Worker implements Runnable {

		private Thread t;

		@Override
		public void run() {
			retry:
			while (!Thread.currentThread().isInterrupted()) {
				Runnable runnable = null;
				// 取任务
				synchronized (jobList) {
					while (jobList.isEmpty()) {
						try {
							busyWorkers.remove(this);
							idleWorkers.add(this);
							System.out.printf("%s is waiting for job...\n", Thread.currentThread().getName());
							jobList.wait();
						} catch (InterruptedException e) {
							Thread.currentThread().interrupt(); // 重置中断标志位
							break retry;
						}
					}
					runnable = jobList.poll();
				}
				idleWorkers.remove(this);
				busyWorkers.add(this);
				// 执行任务
				if (runnable != null) {
					processingJobCnt.incrementAndGet();
					runnable.run();
					processingJobCnt.decrementAndGet();
					completedJobCnt.incrementAndGet();
				}
			}
			idleWorkers.remove(this);
			busyWorkers.remove(this);
			workerSet.remove(this);
			System.out.printf("%s is going to terminate...\n", Thread.currentThread().getName());
		}

		public void shutdown() {
			t.interrupt(); // 中断 Worker 所属的线程
		}
	}
}

编写测试类

package experiment;

import java.util.Random;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @Author yogurtzzz
 * @Date 2022/4/20 16:05
 **/
public class Test {

	private static Scanner scanner = new Scanner(System.in);

	private static final AtomicInteger taskCnt = new AtomicInteger();

	private static final Random random = new Random();

	public static void main(String[] args) {
		ThreadPool threadPool = new DefaultThreadPool(20);
		Runnable task = () -> {
			int i = random.nextInt(100);
			System.out.printf("task-%d is processing...will take %d seconds \n", taskCnt.getAndIncrement(), i);
			try {
				TimeUnit.SECONDS.sleep(i);
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
			}
		};
		while (true) {
			int op = 1;
			try {
				op = Integer.parseInt(scanner.nextLine());
			} catch (Exception e) { }

			if (op == 1) threadPool.execute(task);
			else if (op == 2) threadPool.addWorker(1);
			else if (op == 3) threadPool.removeWorker(1);
			else if (op == 4) threadPool.printStatus();
			else if (op == 5) threadPool.shutdown();
		}
	}
}

效果:

手写阻塞队列

手写一版简单的 ArrayBlockingQueue,底层用普通数组模拟一个循环队列,阻塞机制仍然借助Object的wait/notify

package experiment;

/**
 * @Author yogurtzzz
 * @Date 2022/4/20 16:58
 *
 * 简单的阻塞队列
 **/
public class ArrayBlockingQueue<T> {

	private final Object notEmpty = new Object();

	private final Object notFull = new Object();

	private Object[] elements;

	// 双指针循环队列
	private int first;

	private int last;

	public ArrayBlockingQueue(int size) {
		elements = new Object[size + 1]; // 空出一个位置, 用来标识队列是满还是空
		first = 0; // last 用来指示队尾的下一个位置(插入时直接在last位置插入), first 留空
		last = 1; // 当 last 在 first 下一个位置时, 即 last = first + 1 时, 队列为空; 当 first = last时, 队列满
		// 每次插入, 直接在 last 位置插入, 并后移 last;
		// 每次取元素, 在 first + 1 的位置取
	}

	public void put(T e) {
		int c = -1;
		synchronized (notFull) {
			while (isFull()) { // 如果队列满, 则等待
				try {
					notFull.wait();
				} catch (InterruptedException e1) {
					Thread.currentThread().interrupt(); // 传递中断
				}
			}
			c = size();
			elements[last] = e;
			last = (last + 1) % elements.length;
		}
		// 当先前的size为0时, 添加后才进行通知
		if (c == 0) {
			synchronized (notEmpty) {
				notEmpty.notify();
			}
		}
	}

	public T get() {
		T res = null;
		int c = -1;
		synchronized (notEmpty) {
			while (isEmpty()) {
				try {
					notEmpty.wait();
				} catch (InterruptedException e) {
					Thread.currentThread().interrupt(); // 传递中断
				}
			}
			c = size();
			res = (T) elements[(first + 1) % elements.length];
			first = (first + 1) % elements.length;
		}
		if (c == elements.length - 1) {
			synchronized (notFull) {
				notFull.notify();
			}
		}
		return res;
	}

	public int size() {
		return (last - first - 1 + elements.length) % elements.length;
	}

	private boolean isFull() {
		return last == first;
	}

	private boolean isEmpty() {
		return ((first + 1) % elements.length) == last;
	}

}

欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/langs/721869.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-04-25
下一篇 2022-04-25

发表评论

登录后才能评论

评论列表(0条)

保存