Write the Code. Change the World.

6月 15

虽然 dify 支持创建多个 api-token,我们也修改了 api_tokens 表,加上了 account_id 字段进行关联,但是通过 api-token 转换成 current_user 时,得到的 user 总是 tenant_account_joins 表中 role 字段为 owner 的用户。既然想让每个用户都有单独的 token,就要让 current_user 能对应到各自的用户身上。

api-key 到 account 的转换

从 post("/v1/datasets") 接口往回追溯,就可以找到绑定用户的地方:api\controllers\service_api\dataset\dataset.py

这里获取并校验了当前用户:

assert isinstance(current_user, Account)

api\controllers\service_api\wraps.py 里有查询用户的相关逻辑。原来的代码只会查出该工作空间中 role 为 owner 的账号:

        tenant_account_join = db.session.execute(
            select(Tenant, TenantAccountJoin)
            .where(Tenant.id == api_token.tenant_id)
            .where(TenantAccountJoin.tenant_id == Tenant.id)
            .where(TenantAccountJoin.role.in_(["owner"]))
            .where(Tenant.status == TenantStatus.NORMAL)
        ).one_or_none()  # TODO: only owner information is required, so only one is returned.
        if tenant_account_join:
            tenant, ta = tenant_account_join
            account = db.session.get(Account, ta.account_id)
            # Login admin
            if account:
                account.current_tenant = tenant
                current_app.login_manager._update_request_context_with_user(account)  # type: ignore
                user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore
            else:
                raise Unauthorized("Tenant owner account does not exist.")

既然之前已经在 api_tokens 表中加了 account_id 字段,那么就应该优先根据 api_tokens 表中的 account_id 查询用户;如果该 token 没有关联 account_id,再走原来的逻辑。修改后的代码如下:

def validate_dataset_token[R](view: Callable[..., R]) -> Callable[..., R]:
    positional_parameters = [
        parameter
        for parameter in inspect.signature(view).parameters.values()
        if parameter.kind
        in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    expects_bound_instance = bool(
        positional_parameters and positional_parameters[0].name in {"self", "cls"}
    )

    @wraps(view)
    def decorated(*args: object, **kwargs: object) -> R:
        api_token = validate_and_get_api_token("dataset")

        if api_token.account_id:
            account = db.session.get(Account, api_token.account_id)
            if not account:
                raise Unauthorized(
                    "By account_id. Token associated account does not exist."
                )

            tenant = db.session.get(Tenant, api_token.tenant_id)
            if not tenant or tenant.status == TenantStatus.ARCHIVE:
                raise Forbidden("Workspace is invalid or archived")

            account.current_tenant = tenant
            current_app.login_manager._update_request_context_with_user(account)  # type: ignore
            user_logged_in.send(current_app._get_current_object(), user=account)
        else:
            # Flask may pass URL path parameters positionally, so inspect both kwargs and args.
            dataset_id = kwargs.get("dataset_id")

            if not dataset_id and args:
                potential_id = args[0]
                try:
                    str_id = str(potential_id)
                    if len(str_id) == 36 and str_id.count("-") == 4:
                        dataset_id = str_id
                except Exception:
                    logger.exception("Failed to parse dataset_id from positional args")

            if dataset_id:
                dataset_id = str(dataset_id)
                dataset = db.session.scalar(
                    select(Dataset)
                    .where(
                        Dataset.id == dataset_id,
                        Dataset.tenant_id == api_token.tenant_id,
                    )
                    .limit(1)
                )
                if not dataset:
                    raise NotFound("Dataset not found.")
                if not dataset.enable_api:
                    raise Forbidden("Dataset api access is not enabled.")

            tenant_account_join = db.session.execute(
                select(Tenant, TenantAccountJoin)
                .where(Tenant.id == api_token.tenant_id)
                .where(TenantAccountJoin.tenant_id == Tenant.id)
                .where(TenantAccountJoin.role.in_(["owner"]))
                .where(Tenant.status == TenantStatus.NORMAL)
            ).one_or_none()  # TODO: only owner information is required, so only one is returned.
            if tenant_account_join:
                tenant, ta = tenant_account_join
                account = db.session.get(Account, ta.account_id)
                # Login admin
                if account:
                    account.current_tenant = tenant
                    current_app.login_manager._update_request_context_with_user(account)  # type: ignore
                    user_logged_in.send(current_app._get_current_object(), user=current_user)  # type: ignore
                else:
                    raise Unauthorized("Tenant owner account does not exist.")
            else:
                raise Unauthorized("Tenant does not exist.")

        if expects_bound_instance:
            if not args:
                raise TypeError(
                    "validate_dataset_token expected a bound resource instance."
                )
            return view(args[0], api_token.tenant_id, *args[1:], **kwargs)

        return view(api_token.tenant_id, *args, **kwargs)

    return decorated

与原来的逻辑相比,新增的是下面这段:

        if api_token.account_id:
            account = db.session.get(Account, api_token.account_id)
            if not account:
                raise Unauthorized(
                    "By account_id. Token associated account does not exist."
                )

            tenant = db.session.get(Tenant, api_token.tenant_id)
            if not tenant or tenant.status == TenantStatus.ARCHIVE:
                raise Forbidden("Workspace is invalid or archived")

            account.current_tenant = tenant
            current_app.login_manager._update_request_context_with_user(account)  # type: ignore
            user_logged_in.send(current_app._get_current_object(), user=account)

到这里看似已经没问题了,但运行时仍然会报错。api_token = validate_and_get_api_token("dataset") 返回的 api_token 可能来自 ApiTokenCache,而 ApiTokenCache 缓存的对象里根本没有定义 account_id 字段,所以访问 api_token.account_id 时就会报错。如果不走缓存,由于之前修改数据库表时已经在模型中加入了 account_id 字段,自然不会报错。因此只需要在 ApiTokenCache 的缓存对象 CachedApiToken 中加入 account_id 即可:

# api\services\api_token_service.py

    id: str
    app_id: str | None
    tenant_id: str | None
    account_id: str | None
    type: str
    token: str
    last_used_at: datetime | None
    created_at: datetime | None

        cached = CachedApiToken(
            id=str(api_token.id),
            app_id=str(api_token.app_id) if api_token.app_id else None,
            tenant_id=str(api_token.tenant_id) if api_token.tenant_id else None,
            account_id=str(api_token.account_id) if api_token.account_id else None,
            type=api_token.type,
            token=api_token.token,
            last_used_at=api_token.last_used_at,
            created_at=api_token.created_at,
        )

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注