首页 > 解决方案 > 此代码正在运行,但现在我从 __getattribute__ 中的计数器收到递归错误

问题描述

我需要帮助使我的 Query 对象没有递归错误。为了检查正确数量的函数已被链调用,我添加了一个计数器来获取导致递归错误的属性。当我只有几个功能时没有发生这种情况,但现在一直在发生,我需要修复

改变功能。使用自我、对象等。我使用的是 python 3.5,这段代码在较小的时候可以工作,但后来出现了错误


import mysql.connector as sql

class Database:

    def __init__(self):

        # Attributes
        host = "127.0.0.1"
        database = "database"
        user = "root"
        password = None

        self.connection = sql.connect(
            host = host,
            database = database,
            user = user,
            password = password
        )

        self.cursor = self.connection.cursor()


class Query(object):

    """

        Object initialisers and attributes
        ----------------------------------

    """

    def __init__(self, table):

        super().__init__()

        self.query = None
        self._table = table
        self._count = 0
        self._countRequired = 1
        self._return = False
        self._leadMethods = ["select", "update", "insert", "join"]
        self._selectorMethods = ["where"]
        self._returnTypes = ["fetchAll", "get", "first", "chunk", "last", "row"]
        self._specialCases = ["count", "len", "desc", "asc", "orderBy"]

    def __call__(self):

        if self._count >= self._countRequired:
            self._execute()
        else:
            raise NotImplementedError("Check amount of methods for SQL statement first")

    def __getattr__(self, attr, *args):
        # Updatable
        return "Attributes not found!"

    def __getattribute__(*args):
        args[0]._count += 1
        return object.__getattribute__(*args)


    """

        Private methods and static methods
        ----------------------------------

    """

    @staticmethod
    def count(self, columns):

        chain = ""
        if columns:

            if type(columns) == type([]):

                for i in range(0, len(columns)):
                    if i == len(columns):
                        chain += "COUNT(" + column + ") as count_" + column
                        break
                    chain += "COUNT(" + column + ") as count_" + column + ", "

            elif type(columns) == type(""):

                chain = "COUNT(" + columns + ") as count_" + columns

        return chain

    @staticmethod
    def len(self, columns):

        chain = ""
        if columns:

            if type(columns) == type([]):

                for i in range(0, len(columns)):
                    if i == len(columns):
                        chain += "LEN(" + column + ") as len_" + column
                        break
                    chain += "LEN(" + column + ") as len_" + column + ", "

            elif type(columns) == type(""):

                chain = "LEN(" + columns + ") as len_" + columns

        return chain

    def _execute(self):
        query = self.query
        try:
            self.cursor.execute(query)
            if self._return == True:
                return self.cursor.fetchall()
            return True
        except Exception as e:
            return (False, e)

        self.connection.commit()


    """

        Public methods
        --------------

    """

    def select(self, columns=None, args=None):

        values = None
        if args:
            specialCases = []

            for arg in args:
                for case in self._specialCases[:2]:
                    if arg == case:
                        specialCases += args

            values = []

            for case in specialCases:
                try:
                    case = "" + case
                    x = getattr(self, case)(self, columns)
                    values += [x]
                except Exception as e:
                    print("Special case doesn't exist.")

        selector = ""
        if columns == []:

            for i in range(0, len(columns)):
                if i == len(columns):
                    selector += columns[i]
                    break
                selector += columns[i] + ","
        elif columns:

            selector = columns

        else:
            selector = "*"

        if self.query == None:

            sql = "SELECT " + selector + " FROM " + self._table

            self.query = sql

            self._countRequired += 2
            self._return = True

        return self

    def update(self, columns=None, values=None):

        statement = "UPDATE " + self._table + " SET "

        if len(columns) != len(values):
            raise Exception("Not equal amount of columns and values")

        temp = ""
        if type(columns) == type([]) and type(values) == type([]):
            for i in range(0, len(columns)):
                if i == len(columns):
                    temp += columns[i]  + " = " + values[i]
                    break

        elif type(columns) == type("") and type(values) == type(""):
            temp += str(columns) + " = " + str(values)

        statement += temp

        self._query = statement
        self._countRequired += 1

        return self


    def insert(self, columns=None, values=None):

        statment = "INSERT INTO " + self._table + " "

        temp = "("

        if type(columns) == type([]) and type(values) == type([]):
            for i in range(0, len(columns)):
                if i == len(columns):
                    temp += str(columns[i]) +  ") VALUES ("
                    break
                temp += str(columns[i]) + ", "

            for i in range(0, len(values)):
                if values[i] == None:
                    values[i] = "NULL"
                if i == len(VALUES):
                    temp += str(values[i]) + ")"
                    break
                temp += str(values[i]) + ", "


        return self

    def join (self, table, compareOne, operator, compareTwo, type="INNER"):

        statement = "{} JOIN {} ON {} {} {} ".format(str(type), str(table), str(compareOne), str(operator), str(compareTwo))

        self._query += statement

        return self

    def where(self, column=None, value=None, operator=None):

        statement = " WHERE "

        if not operator:
            operator = "="

        if column and value:
            temp = ""

            if type(column) == type([]) and type(value) == type([]):
                for i in range(0, len(column)):
                    if i == len(column):
                        temp += "{} {} {} ".format(column[i], operator, value[i])
                        break

                    temp += "{} {} {},".format(column[i], operator, value[i])


            elif type(column) == type([]):

                for i in range(0, len(column)):
                    if i == len(column):
                        temp += "{} {} {} ".format(column[i], operator, value)
                        break

                    temp += "{} {} {},".format(column[i], operator, value)


            elif type(value) == type([]):

                for i in range(0, len(value)):
                    if i == len(value):
                        temp += "{} {} {} ".format(column, operator, value[i])
                        break

                    temp += "{} {} {},".format(column, operator, value[i])

            else:

                temp = "{} {} {} ".format(column, operator, value)


            statement += temp

            self.query += statement

        return self

    def fetchAll(self):
        return self

    def get(self):
        return self

    def first(self):
        self.query += "LIMIT 1"
        return self

    def last(self):
        self.query += "DESC LIMIT 1"
        return self

    def row(self, n):
        self.query += "AND id = " + str(n)
        return self

    def chunk(self, n, start=0):
        sql = "AND id BETWEEN " + str(start) + " AND " + str(n)
        self.query += sql
        return self

    """

        Extra Functions

    """

    def orderBy(self, type):

        if type in possible:
            self._query += "ORDER BY " + str(type)
        else:
            raise Exception("Invalid orderBy value. Should be ASC or DESC")
        return self

    def override(self, password, sql):
        if password == "password":
            self._query = SQL
        return self

应该可以输入:

q = Query('table').select().where('name', 'test').get()()

并返回结果

"SELECT * FROM table WHERE name=test;"

标签: pythonrecursionfluent-interface

解决方案


修复 了我创建了一个函数数组,不包括 count 属性,这意味着我只计算它是否是流利界面可接受的方法

def __getattribute__(self, name):
        methods = ["select", "update", "insert", "join", "where", "get", "row", "last", "first", "fetchAll", "chunk", "override", "orderBy"]
        if name in methods:
            self._count += 1
        return object.__getattribute__(self, name)

推荐阅读