Python 设计模式 —— 建造者模式

建造者模式的概念

建造者模式有点像工厂模式,它的主要面向的对象是参数较多的对象,例如统计学当中,我们使用回归模型,此时的参数可能相对较多:

Python
class Regressor:
    def __init__(self):
        self._model_type = "ols"  # 模型类型
        self._fit_intercept = True  # 截距是否拟合
        self._standardize = False  # 标准化
        self._alpha = 1.0  # 正则化力度
        self._l1_ratio = 0.5  # l1 比例
        self._solver = "auto"  # 求解器

此时的模型参数较多,可能更多时候会达到几十个,这种情况下,为了使用的简便我们往往需要固定其中某些参数给出子类.

简而言之,工厂模式和建造者模式的主要区别在于:

  • 工厂模式给出的是抽象基类和工厂类,抽象基类的实现可以很简便,子类可以额外定义很多新内容;
  • 建造者模式的具体产品(类)是创建比较复杂(参数多、部件多),我们想将构建过程和表示过程(给入参数的过程)分离;

举个例子就是,工厂模式主要解决的问题是:我们有“电子产品”这类基类。然后想派生出电脑、手机等具体物件;建造者模式则是已经有了电脑的概念,同时内部的各个部件都很复杂,需要单独分离出游戏电脑、办公电脑等细分场景.

下面是建造者模式的主要流程:

建造者模式图

Python 实现

具体产品

一个具体的回归器可以这么写:

Python
class Regressor:
    def __init__(self):
        self._model_type = "ols"  # 模型类型
        self._fit_intercept = True  # 截距是否拟合
        self._standardize = False  # 标准化
        self._alpha = 1.0  # 正则化力度
        self._l1_ratio = 0.5  # l1 比例
        self._solver = "auto"  # 求解器

    def __str__(self) -> str:
        specs: dict = {}
        if self._model_type:
            specs["模型类型"] = self._model_type
        if self._fit_intercept:
            specs["是否拟合截距"] = self._fit_intercept
        if self._standardize:
            specs["是否标准化"] = self._standardize
        if self._alpha:
            specs["正则化力度"] = self._alpha
        if self._l1_ratio:
            specs["L1 系数"] = self._l1_ratio
        if self._solver:
            specs["求解器"] = self._solver

        return "模型信息:\n" + "\n".join(
            [f"{key}: {specs[key]}" for key in specs.keys()]
        )

我们接下来的想法是:实现 LASSO 回归和一般线性回归,那么此时其实只要固定其中几个参数即可,因此使用的就是建造者模式.

抽象建造者

对于抽象建造者,其实主要的方法就是选定几个参数的方法,注意初始化的时候要给出 Regressor() 的实例.

Python
from abc import ABC, abstractmethod
class RegressorBuilder(ABC):
    def __init__(self):
        self.regressor = Regressor()

    @abstractmethod
    def select_type(self):
        pass

    @abstractmethod
    def select_fit_intercept(self):
        pass

    @abstractmethod
    def select_alpha(self):
        pass

    @abstractmethod
    def select_l1_ratio(self):
        pass

    @abstractmethod
    def select_solver(self):
        pass

    def get_regressor(self) -> Regressor:
        return self.regressor

具体建造者

我们以 LASSO 为例,此时固定了正则化力度和惩罚参数,因此可以写清楚每一个具体的参数:

Python
class LassoRegressorBuilder(RegressorBuilder):
    def select_type(self):
        self.regressor._model_type = "lasso"

    def select_fit_intercept(self):
        self.regressor._fit_intercept = True

    def select_alpha(self):
        self.regressor._alpha = 1.0

    def select_l1_ratio(self):
        self.regressor._l1_ratio = 1.0

    def select_solver(self):
        self.regressor._solver = "auto"

指挥者

如果直接使用上述的建造者,我们这些选定的方法就每次都要一行行执行,很麻烦,因此指挥者的主要作用就是不管具体建造者是谁,都依次调用相应的方法.

Python
class RegressorDirector:
    """回归器指挥者"""

    def __init__(self, builder: RegressorBuilder):
        self.builder: RegressorBuilder = builder

    def construct_regressor(self):
        """构建回归模型的完整过程"""
        self.builder.select_alpha()
        self.builder.select_fit_intercept()
        self.builder.select_l1_ratio()
        self.builder.select_solver()
        self.builder.select_type()

    def get_regressor(self) -> Regressor:
        return self.builder.get_regressor()

代码的全部实现

Python
# 建造者模式

# 1. 产品类 (回归模型)
from abc import ABC, abstractmethod


class Regressor:
    def __init__(self):
        self._model_type = "ols"  # 模型类型
        self._fit_intercept = True  # 截距是否拟合
        self._standardize = False  # 标准化
        self._alpha = 1.0  # 正则化力度
        self._l1_ratio = 0.5  # l1 比例
        self._solver = "auto"  # 求解器

    def __str__(self) -> str:
        specs: dict = {}
        if self._model_type:
            specs["模型类型"] = self._model_type
        if self._fit_intercept:
            specs["是否拟合截距"] = self._fit_intercept
        if self._standardize:
            specs["是否标准化"] = self._standardize
        if self._alpha:
            specs["正则化力度"] = self._alpha
        if self._l1_ratio:
            specs["L1 系数"] = self._l1_ratio
        if self._solver:
            specs["求解器"] = self._solver

        return "模型信息:\n" + "\n".join(
            [f"{key}: {specs[key]}" for key in specs.keys()]
        )

# 2. 抽象建造者
class RegressorBuilder(ABC):
    def __init__(self):
        self.regressor = Regressor()

    @abstractmethod
    def select_type(self):
        pass

    @abstractmethod
    def select_fit_intercept(self):
        pass

    @abstractmethod
    def select_alpha(self):
        pass

    @abstractmethod
    def select_l1_ratio(self):
        pass

    @abstractmethod
    def select_solver(self):
        pass

    def get_regressor(self) -> Regressor:
        return self.regressor

# 3. 具体建造者
class LinearRegressorBuilder(RegressorBuilder):
    def select_type(self):
        self.regressor._model_type = "ols"

    def select_fit_intercept(self):
        self.regressor._fit_intercept = True

    def select_alpha(self):
        self.regressor._alpha = 0.0

    def select_l1_ratio(self):
        self.regressor._l1_ratio = 0.0

    def select_solver(self):
        self.regressor._solver = "auto"


class LassoRegressorBuilder(RegressorBuilder):
    def select_type(self):
        self.regressor._model_type = "lasso"

    def select_fit_intercept(self):
        self.regressor._fit_intercept = True

    def select_alpha(self):
        self.regressor._alpha = 1.0

    def select_l1_ratio(self):
        self.regressor._l1_ratio = 1.0

    def select_solver(self):
        self.regressor._solver = "auto"

# 4. 指挥者
class RegressorDirector:
    """回归器指挥者"""

    def __init__(self, builder: RegressorBuilder):
        self.builder: RegressorBuilder = builder

    def construct_regressor(self):
        """构建回归模型的完整过程"""
        self.builder.select_alpha()
        self.builder.select_fit_intercept()
        self.builder.select_l1_ratio()
        self.builder.select_solver()
        self.builder.select_type()

    def get_regressor(self) -> Regressor:
        return self.builder.get_regressor()


# 构建 LASSO 模型
lasso_builder = LassoRegressorBuilder()
lasso_director = RegressorDirector(lasso_builder)
lasso_director.construct_regressor()
lasso: Regressor = lasso_director.get_regressor()
print(lasso)

Leave a Comment

您的邮箱地址不会被公开。 必填项已用 * 标注