Source code for council.filters.basic_filter
from typing import List, Optional
from council.contexts import AgentContext, ScoredChatMessage
from council.filters import FilterBase
[docs]
class BasicFilter(FilterBase):
    """
    Keep evaluated messages whose score reaches a threshold, optionally capped in count.

    Messages are taken from ``context.evaluation`` in the order it yields them; this
    filter never re-sorts them, so ``top_k`` keeps the *first* passing messages rather
    than necessarily the highest-scored ones.
    """

    def __init__(self, score_threshold: Optional[float] = None, top_k: Optional[int] = None) -> None:
        """
        Args:
            score_threshold: messages whose ``score`` is below this value are dropped;
                ``None`` disables score filtering.
            top_k: upper bound on how many messages are returned; ``None`` or a
                value ``<= 0`` means no bound.
        """
        super().__init__()
        self._score_threshold = score_threshold
        self._top_k = top_k

    def _execute(self, context: AgentContext) -> List[ScoredChatMessage]:
        """Apply the score threshold, then truncate to ``top_k`` when a positive cap is set."""
        kept = self._filter(context)
        cap = self._top_k
        # Only a positive cap truncates; None, 0 and negative values leave the list whole.
        should_truncate = cap is not None and cap > 0
        return kept[:cap] if should_truncate else kept

    def _filter(self, context: AgentContext) -> List[ScoredChatMessage]:
        """
        Return a new list of the evaluated messages that pass the score threshold.

        When ``context.evaluation`` is ``None`` an empty list is returned.
        """
        candidates = context.evaluation
        if candidates is None:
            return []
        threshold = self._score_threshold
        if threshold is None:
            # No threshold configured: copy every evaluated message through.
            return list(candidates)
        return [message for message in candidates if message.score >= threshold]