Python 设计模式 —— 工厂模式

工厂模式的概念

在编程中,我们经常需要创建对象。如果直接在代码中使用 new 关键字 (C++) 或类的构造函数,会导致:

  • 代码耦合度高:创建对象的代码与具体类紧密绑定
  • 维护困难:当需要修改或添加新的对象类型时,需要修改多处代码
  • 违反开闭原则:对扩展开放,对修改关闭的原则被破坏

工厂模式通过将对象的创建过程封装起来,解决了这些问题.

简单工厂模式

内容和具体实现

简单工厂模式比较简单,其主要的部分是:

  • 抽象基类
  • 基于基类实现的子类
  • 简单工厂类,直接调用这个工厂类的方法可以创建子类.

一个简单的实现如下,考虑多臂老虎机的基类,可以看到:

Python
from abc import ABC, abstractmethod
# 产品接口,或者说是抽象基类
class Bandit(ABC):
    @abstractmethod
    def pull():
        # 结果由子类实现
        pass

这里的 ABCabstractmethod 的主要作用是:ABC 指示的类是抽象基类,其不能被实例化,并且如果其子类有 abstractmethod 未实现,则子类也将无法实例化; abstractmethod 表示方法是抽象方法,子类必须实现,否则报错.

一个子类如果写为

Python
class BernoulliBandit(Bandit):
    pass

然后直接调用 BernoulliBandit.pull() ,则将会有如下报错:

Python
Traceback (most recent call last):
  File "e:\Projects\设计模式\工厂模式\简单工厂模式.py", line 23, in <module>
    bandit = BernoulliBandit()
             ^^^^^^^^^^^^^^^^^
TypeError: Can't instantiate abstract class BernoulliBandit without an implementation for abstract method 'pull'

因此我们这里给出一个子类的实现:

Python
import random
# 具体实现
class BernoulliBandit(Bandit):
    def pull(self):
        return random.randint(0, 1)

if __name__ == "__main__":
    bandit = BernoulliBandit()
    print(bandit.pull())

总体而言,这个 BernoulliBandit 实现了 pull 方法,所以 abstractmethod 此时不会报错,同理还能实现一些不同的老虎机,比如标准正态老虎机.

Python
import numpy as np
class StdNormalBandit(Bandit):
    def pull(self):
        return np.random.randn()

然后最重要的就是简单工厂类:

Python
class BanditFactory:
    @staticmethod
    def create_bandit(bandit_type):
        if bandit_type == "bernoulli":
            return BernoulliBandit()
        elif bandit_type == "normal":
            return StdNormalBandit()
        else:
            raise ValueError(f"未知的老虎机类型:{bandit_type}")

这个的作用就是简单工厂,只有一个实例化子类的方法,并且这个方法是静态方法(也就是无需实例化,可以直接调用,无法访问 self 的方法).

下面给出完整的方法:

Python
# 使用简单工厂模式来实现多臂老虎机
import numpy as np
from abc import ABC, abstractmethod


# 产品接口,或者说是抽象基类
class Bandit(ABC):
    @abstractmethod
    def pull(self):
        # 结果由子类实现
        pass


# 具体实现
class BernoulliBandit(Bandit):
    def pull(self):
        return np.random.randint(0, 2)


class StdNormalBandit(Bandit):
    def pull(self):
        return np.random.randn()


class BanditFactory:
    @staticmethod
    def create_bandit(bandit_type):
        if bandit_type == "bernoulli":
            return BernoulliBandit()
        elif bandit_type == "normal":
            return StdNormalBandit()
        else:
            raise ValueError(f"未知的老虎机类型:{bandit_type}")


if __name__ == "__main__":
    bandit = BanditFactory.create_bandit(bandit_type="bernoulli")
    print(bandit.pull())

优点与缺点

优点:

  • 客户端与具体产品类解耦
  • 职责分离,易于维护

缺点:

  • 添加新产品需要修改工厂类,违反开闭原则
  • 工厂类职责过重,不符合单一职责原则

工厂方法模式

内容和具体实现

工厂方法模式通过让子类决定创建什么对象来解决简单工厂模式的问题,因为之前的简单工厂类职责过重,所以我们现在应该考虑的问题是解耦.

还是多臂老虎机的例子,此时我们之前的老虎机工厂类有问题就是为了添加新的分布(老虎机)需要修改工厂类,过于麻烦,那么,我们何不直接将其设置抽象出来,每个都对应写一个 BanditConfig 呢?

同时,为了说明工厂方法模式的好处,我们不妨思考,之前实现的老虎机显然是不够格的,对于 Bernoulli 老虎机,我们应该给出 Bernoulli 分布的参数 p ;正态分布老虎机则需要给出均值以及方差,但是简单工厂类是很难包纳所有这些分布的稀奇古怪的参数的. 当然,一个实现就是使用:

Python
def create_bandit(bandit_type, **kwargs):
	if bandit_type == "bernoulli":
		p = kwargs.get('p', 0.5)   # 默认0.5

只是,这种实现不能给使用者或者编写者带来任何信息. 为了改进,我们每个老虎机设置都可以单独继承自一个配置文件基类. 这里我给出完整的代码:

Python
# 使用简单工厂模式来实现多臂老虎机
from numpy.ma import mean
import numpy as np
from abc import ABC, abstractmethod


# 产品接口,或者说是抽象基类
class Bandit(ABC):
    """老虎机抽象基类"""

    @abstractmethod
    def pull(self):
        # 结果由子类实现
        pass


# 老虎机具体实现
class BernoulliBandit(Bandit):
    """Bernoulli 老虎机"""

    def __init__(self, p: float = 0.5):
        self.p: float = p

    def pull(self):
        return np.random.binomial(n=1, p=self.p)


class NormalBandit(Bandit):
    """正态老虎机"""

    def __init__(self, mean: float = 0.0, variance: float = 1.0):
        self.mean: int | float = mean
        self.variance: int | float = variance

    def pull(self):
        return self.mean + np.random.randn() * np.sqrt(self.variance)


# 配置文件抽象类
class BanditConfig(ABC):
    """老虎机配置类基类"""

    @abstractmethod
    def create_bandit(self):
        pass


# 配置具体实现
class BernoulliConfig(BanditConfig):
    """Bernoulli 老虎机配置"""

    def __init__(self, p: float):
        self.p: float = p

    def create_bandit(self) -> Bandit:
        return BernoulliBandit(p=self.p)


class NormalConfig(BanditConfig):
    """正态老虎机配置"""

    def __init__(self, mean: float = 0.0, variance: float = 1.0):
        self.mean: int | float = mean
        self.variance: int | float = variance

    def create_bandit(self) -> Bandit:
        return NormalBandit(mean=self.mean, variance=self.variance)


if __name__ == "__main__":
    config = NormalConfig()
    bandit = config.create_bandit()
    print(bandit.pull())

抽象工厂模式

抽象工厂模式提供一个创建一系列相关或依赖对象的接口,而无需指定它们具体的类.

上面的定义看着很抽象,实际上很简单,就是考虑之前的 Config 如果能创建多个类,那么它实际上就变成了一个新的工厂,Config 基类也就变成了抽象工厂.

具体案例来看,假如你在写一个多臂老虎机的库,此时有人提议加新的功能,说应该反复采样并分析多臂老虎机拉杆的结果,因此应该加个 Analyzer ,同时负责分析结果. Bernoulli 老虎机只需要看均值,但正态老虎机必须考虑样本均值和方差.

下面直接给出完整的代码:

Python
# 使用工厂方法模式来实现多臂老虎机
import numpy as np
from abc import ABC, abstractmethod


# 产品接口,或者说是抽象基类
class Bandit(ABC):
    """老虎机抽象基类"""

    @abstractmethod
    def pull(self):
        # 结果由子类实现
        pass


# 老虎机具体实现
class BernoulliBandit(Bandit):
    """Bernoulli 老虎机"""

    def __init__(self, p: float = 0.5):
        self.p: float = p

    def pull(self):
        return np.random.binomial(n=1, p=self.p)


class NormalBandit(Bandit):
    """正态老虎机"""

    def __init__(self, mean: float = 0.0, variance: float = 1.0):
        self.mean: int | float = mean
        self.variance: int | float = variance

    def pull(self):
        return self.mean + np.random.randn() * np.sqrt(self.variance)


# 分析器基类
class Analyzer(ABC):
    """老虎机分析器基类"""

    @abstractmethod
    def analyze(self, n: int):
        pass


# 分析器具体实现
class BernoulliAnalyzer(Analyzer):
    def __init__(self, p: float):
        self.bandit = BernoulliBandit(p)

    def analyze(self, n: int) -> None:
        res = []
        for i in range(n):
            res.append(self.bandit.pull())

        print(np.mean(res))


class NormalAnalyzer(Analyzer):
    def __init__(self, mean: float, variance: float):
        self.bandit = NormalBandit(mean, variance)

    def analyze(self, n: int) -> None:
        res = []
        for i in range(n):
            res.append(self.bandit.pull())

        print(np.mean(res), np.var(res))


# 配置文件抽象类
class BanditFactory(ABC):
    """老虎机配置类基类"""

    @abstractmethod
    def create_bandit(self) -> Bandit:
        pass

    @abstractmethod
    def create_analyzer(self) -> Analyzer:
        pass


# 配置具体实现
class BernoulliFactory(BanditFactory):
    """Bernoulli 老虎机配置"""

    def __init__(self, p: float):
        self.p: float = p

    def create_bandit(self) -> Bandit:
        return BernoulliBandit(p=self.p)

    def create_analyzer(self) -> Analyzer:
        return BernoulliAnalyzer(p=self.p)


class NormalFactory(BanditFactory):
    """正态老虎机配置"""

    def __init__(self, mean: float = 0.0, variance: float = 1.0):
        self.mean: int | float = mean
        self.variance: int | float = variance

    def create_bandit(self) -> Bandit:
        return NormalBandit(mean=self.mean, variance=self.variance)

    def create_analyzer(self) -> Analyzer:
        return NormalAnalyzer(mean=self.mean, variance=self.variance)


if __name__ == "__main__":
    config = NormalFactory()
    bandit = config.create_bandit()
    print(bandit.pull())

    analyzer = config.create_analyzer()
    analyzer.analyze(n=20000)

可以看到,所谓的 Factory 其实也就是比 Config 多了一些东西而已. 因此抽象工厂模式可以看作工厂方法模式的一个延伸.

Leave a Comment

您的邮箱地址不会被公开。 必填项已用 * 标注